# Source: catopt-query-category-theor.../tests/test_dsl.py
# (39 lines, 1.3 KiB, Python)

from catopt_query.dsl import LocalProblemDSL, SharedVariablesDSL, PlanDeltaDSL
from catopt_query.protocol import LocalProblem, SharedVariables, PlanDelta
def test_local_problem_dsl_to_protocol():
    """Check that LocalProblemDSL.to_protocol() yields an equivalent LocalProblem.

    The converted object must be a LocalProblem instance, and each of
    shard_id, projection, predicates, costs and constraints must compare
    equal to the value held by the DSL object.
    """
    source = LocalProblemDSL(
        shard_id="shard-1",
        projection=["a", "b"],
        predicates=["a>0"],
        costs=1.5,
        constraints={"limit": 100},
    )
    converted = source.to_protocol()
    assert isinstance(converted, LocalProblem)
    # Compare fields in the same order as the constructor arguments.
    for field_name in ("shard_id", "projection", "predicates", "costs", "constraints"):
        assert getattr(converted, field_name) == getattr(source, field_name)
def test_sharedvariables_dsl_to_protocol():
    """Check that SharedVariablesDSL.to_protocol() yields an equivalent SharedVariables.

    The converted object must be a SharedVariables instance whose version,
    signals and priors compare equal to those of the DSL object.
    """
    spec = SharedVariablesDSL(
        version=1,
        signals={"x": 0.5},
        priors={"x": 0.2},
    )
    result = spec.to_protocol()
    assert isinstance(result, SharedVariables)
    for attr in ("version", "signals", "priors"):
        assert getattr(result, attr) == getattr(spec, attr)
def test_plan_delta_dsl_to_protocol():
    """Check that PlanDeltaDSL.to_protocol() yields an equivalent PlanDelta.

    The converted object must be a PlanDelta instance whose delta_id,
    timestamp, changes and contract_id compare equal to those of the DSL
    object.
    """
    delta_spec = PlanDeltaDSL(
        delta_id="d1",
        timestamp=123.0,
        changes={"a": 1},
        contract_id="c0",
    )
    protocol_delta = delta_spec.to_protocol()
    assert isinstance(protocol_delta, PlanDelta)
    for attr in ("delta_id", "timestamp", "changes", "contract_id"):
        assert getattr(protocol_delta, attr) == getattr(delta_spec, attr)