"""Tests that each catopt_query DSL object converts to its protocol counterpart."""
from catopt_query.dsl import LocalProblemDSL, SharedVariablesDSL, PlanDeltaDSL
|
|
from catopt_query.protocol import LocalProblem, SharedVariables, PlanDelta
|
|
|
|
|
|
def test_local_problem_dsl_to_protocol():
    """Check LocalProblemDSL.to_protocol() returns a matching LocalProblem.

    The converted object must be a ``LocalProblem`` and every field set on
    the DSL object must compare equal on the converted object.
    """
    source = LocalProblemDSL(
        shard_id="shard-1",
        projection=["a", "b"],
        predicates=["a>0"],
        costs=1.5,
        constraints={"limit": 100},
    )

    converted = source.to_protocol()

    assert isinstance(converted, LocalProblem)
    # Same comparisons, in the same order, as one assert per field.
    for field in ("shard_id", "projection", "predicates", "costs", "constraints"):
        assert getattr(converted, field) == getattr(source, field)
|
|
|
|
|
|
def test_sharedvariables_dsl_to_protocol():
    """Check SharedVariablesDSL.to_protocol() returns a matching SharedVariables.

    The converted object must be a ``SharedVariables`` whose ``version``,
    ``signals`` and ``priors`` compare equal to those of the DSL object.
    """
    source = SharedVariablesDSL(
        version=1,
        signals={"x": 0.5},
        priors={"x": 0.2},
    )

    converted = source.to_protocol()

    assert isinstance(converted, SharedVariables)
    for field in ("version", "signals", "priors"):
        assert getattr(converted, field) == getattr(source, field)
|
|
|
|
|
|
def test_plan_delta_dsl_to_protocol():
    """Check PlanDeltaDSL.to_protocol() returns a matching PlanDelta.

    The converted object must be a ``PlanDelta`` whose ``delta_id``,
    ``timestamp``, ``changes`` and ``contract_id`` compare equal to those of
    the DSL object.
    """
    # Insertion order of this dict fixes the order the fields are compared in.
    fields = {
        "delta_id": "d1",
        "timestamp": 123.0,
        "changes": {"a": 1},
        "contract_id": "c0",
    }
    source = PlanDeltaDSL(**fields)

    converted = source.to_protocol()

    assert isinstance(converted, PlanDelta)
    for field in fields:
        assert getattr(converted, field) == getattr(source, field)
|