110 lines
3.4 KiB
Python
110 lines
3.4 KiB
Python
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass, asdict, field
|
|
from typing import Dict, Any, List
|
|
|
|
|
|
class LocalProblem:
    """Backward-compatible LocalProblem with alias support.

    Accepts either 'projection' or 'projected_attrs' in the constructor and
    in dictionary representations. 'projection' wins when both are supplied
    and non-None.

    Instances are mutable and compare by value (via ``to_dict``), so they are
    deliberately unhashable.
    """

    # Defining __eq__ already disables hashing implicitly; state it explicitly
    # so the intent is visible to readers.
    __hash__ = None  # type: ignore[assignment]

    def __init__(self, shard_id: str, projection=None, projected_attrs=None,
                 predicates=None, costs=0.0, constraints=None):
        """Create a local problem for one shard.

        Args:
            shard_id: Identifier of the shard this problem belongs to.
            projection: Projected attributes (preferred spelling).
            projected_attrs: Legacy alias used when ``projection`` is None.
            predicates: Iterable of predicates; None means empty.
            costs: Value coerced with ``float``.
            constraints: Mapping copied into a new dict; None means empty.
        """
        self.shard_id = shard_id
        # Support both naming styles for compatibility with older tests
        if projection is None:
            projection = projected_attrs if projected_attrs is not None else []
        self.projection = list(projection)
        self.predicates = list(predicates) if predicates is not None else []
        self.costs = float(costs)
        self.constraints = dict(constraints) if constraints is not None else {}

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict snapshot of this problem.

        Containers are copied, so callers mutating the result cannot alter
        this instance. The 'projection' and 'projected_attrs' entries also do
        not alias each other.
        """
        return {
            "shard_id": self.shard_id,
            "projection": list(self.projection),
            "projected_attrs": list(self.projection),  # alias for backwards-compat
            "predicates": list(self.predicates),
            "costs": self.costs,
            "constraints": dict(self.constraints),
        }

    @staticmethod
    def from_dict(d: Dict[str, Any]) -> "LocalProblem":
        """Build a LocalProblem from ``to_dict`` output or a legacy payload.

        Missing or None values for optional keys fall back to the
        constructor defaults.

        Raises:
            KeyError: if 'shard_id' is missing, including when ``d`` is None.
        """
        if d is None:
            d = {}
        costs = d.get("costs")
        return LocalProblem(
            shard_id=d["shard_id"],
            # Hand both spellings to the constructor so its fallback also
            # applies when 'projection' is present but explicitly None.
            projection=d.get("projection"),
            projected_attrs=d.get("projected_attrs"),
            predicates=d.get("predicates"),
            costs=costs if costs is not None else 0.0,
            constraints=d.get("constraints"),
        )

    def __eq__(self, other: object) -> bool:
        """Compare by serialized value; defer to ``other`` for foreign types."""
        if not isinstance(other, LocalProblem):
            return NotImplemented
        return self.to_dict() == other.to_dict()

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(shard_id={self.shard_id!r}, "
            f"projection={self.projection!r}, predicates={self.predicates!r}, "
            f"costs={self.costs!r}, constraints={self.constraints!r})"
        )
|
|
|
|
|
|
@dataclass(frozen=True)
class SharedVariables:
    """Immutable, versioned container of named signal and prior values."""

    version: int
    signals: Dict[str, float]
    priors: Dict[str, float]

    def to_dict(self) -> Dict[str, Any]:
        """Serialize every field into a fresh, recursively copied plain dict."""
        return asdict(self)

    @staticmethod
    def from_dict(d: Dict[str, Any]) -> "SharedVariables":
        """Rebuild an instance from ``to_dict`` output.

        Absent keys default to version 0 and empty mappings. The mappings are
        shallow-copied, so the input dict is never shared.
        """
        raw_version = d.get("version", 0)
        signal_map = d.get("signals", {})
        prior_map = d.get("priors", {})
        return SharedVariables(int(raw_version), dict(signal_map), dict(prior_map))
|
|
@dataclass(frozen=True)
class PlanDelta:
    """Immutable record of plan changes with an id, a timestamp and an optional contract id."""

    delta_id: str
    timestamp: float
    changes: Dict[str, Any]
    contract_id: str = ""

    def to_dict(self) -> Dict[str, Any]:
        """Return all fields as a recursively copied plain dict."""
        return asdict(self)

    @staticmethod
    def from_dict(d: Dict[str, Any]) -> "PlanDelta":
        """Inverse of ``to_dict``.

        Missing keys fall back to empty strings, a zero timestamp, or an
        empty change set. The timestamp is coerced with ``float``.
        """
        kwargs = {
            "delta_id": d.get("delta_id", ""),
            "timestamp": float(d.get("timestamp", 0.0)),
            "changes": dict(d.get("changes", {})),
            "contract_id": d.get("contract_id", ""),
        }
        return PlanDelta(**kwargs)
|
|
|
|
|
|
@dataclass(frozen=True)
class CanonicalPlan:
    """Immutable plan made of a projection list, a predicate list and an estimated cost."""

    projection: List[str]
    predicates: List[str]
    estimated_cost: float

    def to_dict(self) -> Dict[str, Any]:
        """Return all fields as a recursively copied plain dict."""
        return asdict(self)

    @staticmethod
    def from_dict(d: Dict[str, Any]) -> "CanonicalPlan":
        """Inverse of ``to_dict``.

        Sequences are materialized as new lists and the cost is coerced with
        ``float``. Missing keys default to empty lists and a zero cost.
        """
        lookup = d.get
        return CanonicalPlan(
            list(lookup("projection", [])),
            list(lookup("predicates", [])),
            float(lookup("estimated_cost", 0.0)),
        )
|
|
|
|
|
|
@dataclass(frozen=True)
class DualVariables:
    """Immutable mapping of named multipliers; defaults to an empty mapping."""

    multipliers: Dict[str, float] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Return the multipliers as a recursively copied plain dict."""
        return asdict(self)

    @staticmethod
    def from_dict(d: Dict[str, Any]) -> "DualVariables":
        """Inverse of ``to_dict``, matching the other containers' ``from_dict``.

        A None payload, or a missing/None 'multipliers' entry, yields an
        empty mapping. The mapping is shallow-copied so the input is never
        shared.
        """
        if d is None:
            d = {}
        return DualVariables(multipliers=dict(d.get("multipliers") or {}))
|