# catopt-query-category-theor.../catopt_query/protocol.py
# (110 lines, 3.4 KiB, Python)

from __future__ import annotations
from dataclasses import dataclass, asdict, field
from typing import Dict, Any, List
class LocalProblem:
    """Per-shard sub-problem with backward-compatible field aliases.

    Accepts either ``projection`` or the legacy name ``projected_attrs`` in
    the constructor and in dictionary representations. When both are given,
    ``projection`` wins unless it is ``None``.

    Instances are mutable, compare by value (via :meth:`to_dict`), and are
    therefore unhashable.
    """

    # __eq__ is value-based and the fields are mutable, so instances must not
    # be hashable. Python already does this implicitly when __eq__ is defined;
    # it is stated here for clarity.
    __hash__ = None  # type: ignore[assignment]

    def __init__(self, shard_id: str, projection=None, projected_attrs=None,
                 predicates=None, costs=0.0, constraints=None):
        """Build a local problem.

        Args:
            shard_id: Identifier of the shard this problem belongs to.
            projection: Iterable of projected attribute names; copied to a list.
            projected_attrs: Legacy alias for ``projection``; used only when
                ``projection`` is ``None``.
            predicates: Iterable of predicates; copied to a list
                (default: empty).
            costs: Cost value, coerced with ``float()``.
            constraints: Mapping of constraints; shallow-copied to a dict
                (default: empty).
        """
        self.shard_id = shard_id
        # Support both naming styles for compatibility with older callers.
        if projection is None:
            projection = projected_attrs if projected_attrs is not None else []
        self.projection = list(projection)
        self.predicates = list(predicates) if predicates is not None else []
        self.costs = float(costs)
        self.constraints = dict(constraints) if constraints is not None else {}

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict representation of this problem.

        The projection is emitted under both ``projection`` and the legacy
        ``projected_attrs`` key. Containers are copied, so mutating the
        returned dict affects neither this instance nor the other alias key.
        """
        return {
            "shard_id": self.shard_id,
            "projection": list(self.projection),
            "projected_attrs": list(self.projection),  # alias for backwards-compat
            "predicates": list(self.predicates),
            "costs": self.costs,
            "constraints": dict(self.constraints),
        }

    @staticmethod
    def from_dict(d: Dict[str, Any]) -> "LocalProblem":
        """Build a :class:`LocalProblem` from a dict like :meth:`to_dict` returns.

        ``shard_id`` is required. All other keys are optional, and an optional
        key whose value is ``None`` is treated as missing. The projection may
        be given as ``projection`` or as the legacy ``projected_attrs``.

        Raises:
            KeyError: if ``shard_id`` is absent (including when ``d`` is None).
        """
        data = {} if d is None else d
        costs = data.get("costs")
        return LocalProblem(
            shard_id=data["shard_id"],
            # Pass both spellings and let __init__ resolve the alias, so a
            # present-but-None "projection" still falls back to
            # "projected_attrs" instead of silently discarding it.
            projection=data.get("projection"),
            projected_attrs=data.get("projected_attrs"),
            predicates=data.get("predicates"),
            # float(None) would raise; treat an explicit None as "not given".
            costs=costs if costs is not None else 0.0,
            constraints=data.get("constraints"),
        )

    def __eq__(self, other: object) -> bool:
        """Compare by value; defer to the other operand for foreign types."""
        if not isinstance(other, LocalProblem):
            return NotImplemented
        return self.to_dict() == other.to_dict()

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(shard_id={self.shard_id!r}, "
            f"projection={self.projection!r}, "
            f"predicates={self.predicates!r}, "
            f"costs={self.costs!r}, constraints={self.constraints!r})"
        )
@dataclass(frozen=True)
class SharedVariables:
    """Versioned snapshot of shared ``signals`` and ``priors``.

    ``signals`` and ``priors`` are annotated as ``str -> float`` mappings;
    this is not validated. ``frozen=True`` prevents reassigning fields, but
    the dicts themselves stay mutable. Because the fields are dicts,
    ``hash()`` on an instance raises ``TypeError``.
    """

    version: int
    signals: Dict[str, float]
    priors: Dict[str, float]

    def to_dict(self) -> Dict[str, Any]:
        """Return a deep-copied plain-dict view built by ``dataclasses.asdict``."""
        return asdict(self)

    @staticmethod
    def from_dict(d: Dict[str, Any]) -> "SharedVariables":
        """Build an instance from a dict like :meth:`to_dict` returns.

        ``d`` may be ``None``, which is treated as an empty dict. A missing
        key, or a key whose value is ``None``, falls back to ``version=0`` and
        empty mappings. ``version`` is coerced with ``int()``; the mappings
        are shallow-copied.
        """
        data = {} if d is None else d
        version = data.get("version")
        signals = data.get("signals")
        priors = data.get("priors")
        return SharedVariables(
            # int()/dict() would raise on an explicit None; treat it as missing.
            version=int(version) if version is not None else 0,
            signals=dict(signals) if signals is not None else {},
            priors=dict(priors) if priors is not None else {},
        )
@dataclass(frozen=True)
class PlanDelta:
    """Identified, timestamped set of plan ``changes``.

    ``contract_id`` is optional and defaults to ``""``. Because ``changes``
    is a dict, ``hash()`` on an instance raises ``TypeError`` despite
    ``frozen=True``.
    """

    delta_id: str
    timestamp: float
    changes: Dict[str, Any]
    contract_id: str = ""

    def to_dict(self) -> Dict[str, Any]:
        """Return a deep-copied plain-dict view built by ``dataclasses.asdict``."""
        return asdict(self)

    @staticmethod
    def from_dict(d: Dict[str, Any]) -> "PlanDelta":
        """Build an instance from a dict like :meth:`to_dict` returns.

        ``d`` may be ``None``, which is treated as an empty dict. Missing keys
        default to ``""`` for the id fields, ``0.0`` for ``timestamp`` and
        ``{}`` for ``changes``. An explicit ``None`` for ``timestamp`` or
        ``changes`` is treated the same as a missing key. ``timestamp`` is
        coerced with ``float()``; ``changes`` is shallow-copied.
        """
        data = {} if d is None else d
        timestamp = data.get("timestamp")
        changes = data.get("changes")
        return PlanDelta(
            delta_id=data.get("delta_id", ""),
            # float()/dict() would raise on an explicit None; treat it as missing.
            timestamp=float(timestamp) if timestamp is not None else 0.0,
            changes=dict(changes) if changes is not None else {},
            contract_id=data.get("contract_id", ""),
        )
@dataclass(frozen=True)
class CanonicalPlan:
    """Plan described by a ``projection``, ``predicates`` and an ``estimated_cost``.

    Because the fields include lists, ``hash()`` on an instance raises
    ``TypeError`` despite ``frozen=True``.
    """

    projection: List[str]
    predicates: List[str]
    estimated_cost: float

    def to_dict(self) -> Dict[str, Any]:
        """Return a deep-copied plain-dict view built by ``dataclasses.asdict``."""
        return asdict(self)

    @staticmethod
    def from_dict(d: Dict[str, Any]) -> "CanonicalPlan":
        """Build an instance from a dict like :meth:`to_dict` returns.

        ``d`` may be ``None``, which is treated as an empty dict. A missing
        key, or a key whose value is ``None``, falls back to an empty list
        or ``0.0``. Sequences are copied into new lists, and
        ``estimated_cost`` is coerced with ``float()``.
        """
        data = {} if d is None else d
        projection = data.get("projection")
        predicates = data.get("predicates")
        cost = data.get("estimated_cost")
        return CanonicalPlan(
            # list()/float() would raise on an explicit None; treat it as missing.
            projection=list(projection) if projection is not None else [],
            predicates=list(predicates) if predicates is not None else [],
            estimated_cost=float(cost) if cost is not None else 0.0,
        )
@dataclass(frozen=True)
class DualVariables:
    """Frozen container for a mapping of named ``multipliers``.

    ``multipliers`` is annotated as ``str -> float`` (not validated).
    """
    # default_factory gives every instance its own fresh empty dict instead of
    # sharing one mutable default across instances.
    multipliers: Dict[str, float] = field(default_factory=dict)
    def to_dict(self) -> Dict[str, Any]:
        """Return a deep-copied plain-dict view built by ``dataclasses.asdict``."""
        return asdict(self)