From 895f41d19a39b463f6c573968d72730385eba146 Mon Sep 17 00:00:00 2001 From: agent-a6e6ec231c5f7801 Date: Sun, 19 Apr 2026 22:20:41 +0200 Subject: [PATCH] build(agent): new-agents#a6e6ec iteration --- README.md | 3 + catopt_query/__init__.py | 4 ++ catopt_query/adapters.py | 55 ++++++++++++++++ catopt_query/adapters/__init__.py | 54 +++++++++++++++- catopt_query/dsl.py | 100 ++++++++++++++++++++++++++++++ catopt_query/protocol.py | 82 ++++++++++++++++++++++++ tests/test_adapters_ext.py | 34 ++++++++++ tests/test_dsl.py | 38 ++++++++++++ tests/test_protocol_ext.py | 32 ++++++++++ 9 files changed, 399 insertions(+), 3 deletions(-) create mode 100644 catopt_query/dsl.py create mode 100644 tests/test_adapters_ext.py create mode 100644 tests/test_dsl.py create mode 100644 tests/test_protocol_ext.py diff --git a/README.md b/README.md index 1fd9d20..328e96d 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,9 @@ Project structure - catopt_query/solver.py: tiny ADMM-lite style cross-shard planner - tests/: pytest-based tests for protocol, adapters, and solver +- DSL Primitives (new): +- PrivacyBudget, AuditLog, PolicyBlock, GraphOfContracts and extended governance hooks you can use to describe cross-shard contracts and privacy constraints. + How to run - Run tests: pytest -q - Build: python -m build diff --git a/catopt_query/__init__.py b/catopt_query/__init__.py index 7fa3b04..46f36b3 100644 --- a/catopt_query/__init__.py +++ b/catopt_query/__init__.py @@ -4,6 +4,7 @@ Exports a small, testable surface for the MVP. """ from .protocol import LocalProblem, SharedVariables, PlanDelta, CanonicalPlan +from .dsl import LocalProblemDSL, SharedVariablesDSL, PlanDeltaDSL from .solver import optimize_across_shards __all__ = [ @@ -12,4 +13,7 @@ __all__ = [ "PlanDelta", "CanonicalPlan", "optimize_across_shards", + "LocalProblemDSL", + "SharedVariablesDSL", + "PlanDeltaDSL", ] diff --git a/catopt_query/adapters.py b/catopt_query/adapters.py index a111e0e..88d5e6d 100644 --- a/catopt_query/adapters.py +++ b/catopt_query/adapters.py @@ -25,3 +25,58 @@ class Adapter: predicates=vendor_plan.predicates, estimated_cost=float(vendor_plan.price), ) + + +@dataclass +class PostgresVendorPlan: + shard_id: str + table: str + projection: list + predicates: list + price: float + + +class PostgresAdapter(Adapter): + """Concrete adapter mapping a PostgresVendorPlan into a CanonicalPlan.""" + + def to_canonical(self, vendor_plan: VendorPlan) -> CanonicalPlan: + # Accept either a PostgresVendorPlan or a generic VendorPlan-like object + if isinstance(vendor_plan, PostgresVendorPlan): + vp = vendor_plan + return CanonicalPlan( + projection=vp.projection, + predicates=vp.predicates, + estimated_cost=float(vp.price), + ) + proj = getattr(vendor_plan, "projection", []) + preds = getattr(vendor_plan, "predicates", []) + price = float(getattr(vendor_plan, "price", 0.0)) + return CanonicalPlan(projection=proj, predicates=preds, estimated_cost=price) + + +@dataclass +class MongoVendorPlan: + shard_id: str + collection: str + projection: list + predicates: list + price: float + + +class MongoAdapter(Adapter): + """Concrete adapter mapping a MongoVendorPlan into a CanonicalPlan.""" + + def to_canonical(self, vendor_plan: VendorPlan) -> CanonicalPlan: + # Accept either a MongoVendorPlan or a generic VendorPlan-like object + if isinstance(vendor_plan, MongoVendorPlan): + vp = vendor_plan + return CanonicalPlan( + projection=vp.projection, + predicates=vp.predicates, + estimated_cost=float(vp.price), + ) + # Fallback: try to treat as a generic VendorPlan + proj = getattr(vendor_plan, "projection", []) + preds = getattr(vendor_plan, "predicates", []) + price = float(getattr(vendor_plan, "price", 0.0)) + return CanonicalPlan(projection=proj, predicates=preds, estimated_cost=price) diff --git a/catopt_query/adapters/__init__.py b/catopt_query/adapters/__init__.py index 60ae175..52dcc0f 100644 --- a/catopt_query/adapters/__init__.py +++ b/catopt_query/adapters/__init__.py @@ -2,8 +2,10 @@ This module intentionally avoids importing heavy vendor adapters at import time to keep unit tests fast and isolated. It provides minimal, test-aligned -interfaces: VendorPlan (a vendor-supplied plan) and Adapter (translator to -the canonical protocol).""" +interfaces. It now includes vendor-specific adapters (Postgres and Mongo) and +their corresponding vendor plan data structures, in addition to the generic +Adapter/VendorPlan compatibility for older tests. +""" from typing import List from dataclasses import dataclass @@ -11,6 +13,45 @@ from dataclasses import dataclass from catopt_query.protocol import CanonicalPlan +@dataclass(frozen=True) +class PostgresVendorPlan: + shard_id: str + table: str + projection: List[str] + predicates: List[str] + price: float + + +class PostgresAdapter: + def to_canonical(self, vp: PostgresVendorPlan) -> CanonicalPlan: + # Map a PostgreSQL-style vendor plan into the canonical plan structure + return CanonicalPlan( + projection=vp.projection, + predicates=vp.predicates, + estimated_cost=vp.price, + ) + + +@dataclass(frozen=True) +class MongoVendorPlan: + shard_id: str + collection: str + projection: List[str] + predicates: List[str] + price: float + + +class MongoAdapter: + def to_canonical(self, vp: MongoVendorPlan) -> CanonicalPlan: + # Map a MongoDB-like vendor plan into the canonical plan structure + return CanonicalPlan( + projection=vp.projection, + predicates=vp.predicates, + estimated_cost=vp.price, + ) + + +# Backwards-compatibility generic types (left intact for older tests) @dataclass(frozen=True) class VendorPlan: shard_id: str @@ -29,4 +70,11 @@ class Adapter: ) -__all__ = ["Adapter", "VendorPlan"] +__all__ = [ + "PostgresAdapter", + "PostgresVendorPlan", + "MongoAdapter", + "MongoVendorPlan", + "Adapter", + "VendorPlan", +] diff --git a/catopt_query/dsl.py b/catopt_query/dsl.py new file mode 100644 index 0000000..46d9284 --- /dev/null +++ b/catopt_query/dsl.py @@ -0,0 +1,100 @@ +from __future__ import annotations + +from dataclasses import dataclass, asdict +from typing import Dict, List, Any + +from .protocol import LocalProblem, SharedVariables, PlanDelta + + +@dataclass +class LocalProblemDSL: + """Minimal DSL wrapper for a per-shard LocalProblem. + + This DSL is designed to be mapped into the canonical LocalProblem protocol + using adapters. It captures the same conceptual fields as LocalProblem, but + presents a stable, serializable interface for cross-adapter contracts. + """ + shard_id: str + projection: List[str] + predicates: List[str] + costs: float + constraints: Dict[str, Any] + + def to_protocol(self) -> LocalProblem: + # Use the canonical constructor; LocalProblem accepts both 'projection' and + # 'projected_attrs' aliases for compatibility. + return LocalProblem( + shard_id=self.shard_id, + projection=self.projection, + predicates=self.predicates, + costs=self.costs, + constraints=self.constraints, + ) + + def to_dict(self) -> Dict[str, Any]: + return asdict(self) + + @staticmethod + def from_dict(d: Dict[str, Any]) -> "LocalProblemDSL": + if d is None: + d = {} + return LocalProblemDSL( + shard_id=d.get("shard_id", ""), + projection=d.get("projection", d.get("projected_attrs", [])), + predicates=d.get("predicates", []), + costs=float(d.get("costs", 0.0)), + constraints=d.get("constraints", {}), + ) + + +@dataclass(frozen=True) +class SharedVariablesDSL: + version: int + signals: Dict[str, float] + priors: Dict[str, float] + + def to_dict(self) -> Dict[str, Any]: + return asdict(self) + @staticmethod + def from_dict(d: Dict[str, Any]) -> "SharedVariablesDSL": + return SharedVariablesDSL( + version=int(d.get("version", 0)), + signals=dict(d.get("signals", {})), + priors=dict(d.get("priors", {})), + ) + + def to_protocol(self) -> SharedVariables: + # Convert to the canonical SharedVariables dataclass defined in protocol.py + return SharedVariables( + version=self.version, + signals=self.signals, + priors=self.priors, + ) + + +@dataclass(frozen=True) +class PlanDeltaDSL: + delta_id: str + timestamp: float + changes: Dict[str, Any] + contract_id: str = "" + + def to_dict(self) -> Dict[str, Any]: + return asdict(self) + + def to_protocol(self) -> PlanDelta: + return PlanDelta( + delta_id=self.delta_id, + timestamp=self.timestamp, + changes=dict(self.changes), + contract_id=self.contract_id, + ) + + @staticmethod + def from_dict(d: Dict[str, Any]) -> "PlanDeltaDSL": + return PlanDeltaDSL( + delta_id=d.get("delta_id", ""), + timestamp=float(d.get("timestamp", 0.0)), + changes=dict(d.get("changes", {})), + contract_id=d.get("contract_id", ""), + ) diff --git a/catopt_query/protocol.py b/catopt_query/protocol.py index 7a417e6..9b004e9 100644 --- a/catopt_query/protocol.py +++ b/catopt_query/protocol.py @@ -107,3 +107,85 @@ class DualVariables: def to_dict(self) -> Dict[str, Any]: return asdict(self) + + +@dataclass(frozen=True) +class PrivacyBudget: + """Simple privacy budget token for a signal. + + This lightweight abstraction can carry a budget and an expiry, allowing + adapters to honor privacy constraints when sharing signals across shards. + """ + signal: float + budget: float + expiry: float # timestamp (epoch seconds) + + def to_dict(self) -> Dict[str, Any]: + return asdict(self) + + @staticmethod + def from_dict(d: Dict[str, Any]) -> "PrivacyBudget": + return PrivacyBudget( + signal=float(d.get("signal", 0.0)), + budget=float(d.get("budget", 0.0)), + expiry=float(d.get("expiry", 0.0)), + ) + + +@dataclass(frozen=True) +class AuditLog: + """Tamper-evident-like audit entry for governance and replay tracing.""" + entry: str + signer: str + timestamp: float + contract_id: str + version: int + + def to_dict(self) -> Dict[str, Any]: + return asdict(self) + + @staticmethod + def from_dict(d: Dict[str, Any]) -> "AuditLog": + return AuditLog( + entry=str(d.get("entry", "")), + signer=str(d.get("signer", "")), + timestamp=float(d.get("timestamp", 0.0)), + contract_id=str(d.get("contract_id", "")), + version=int(d.get("version", 0)), + ) + + +@dataclass(frozen=True) +class PolicyBlock: + """Contain safety and exposure controls for a data-sharing contract.""" + safety: str + exposure_controls: Dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> Dict[str, Any]: + return asdict(self) + + @staticmethod + def from_dict(d: Dict[str, Any]) -> "PolicyBlock": + return PolicyBlock( + safety=str(d.get("safety", "")), + exposure_controls=dict(d.get("exposure_controls", {})), + ) + + +@dataclass(frozen=True) +class GraphOfContracts: + """Lightweight registry entry mapping adapters to domains and versions.""" + adapter_id: str + supported_domains: List[str] + contract_version: str + + def to_dict(self) -> Dict[str, Any]: + return asdict(self) + + @staticmethod + def from_dict(d: Dict[str, Any]) -> "GraphOfContracts": + return GraphOfContracts( + adapter_id=str(d.get("adapter_id", "")), + supported_domains=list(d.get("supported_domains", [])), + contract_version=str(d.get("contract_version", "")), + ) diff --git a/tests/test_adapters_ext.py b/tests/test_adapters_ext.py new file mode 100644 index 0000000..d4a7c07 --- /dev/null +++ b/tests/test_adapters_ext.py @@ -0,0 +1,34 @@ +from catopt_query.adapters import PostgresAdapter, PostgresVendorPlan, MongoAdapter, MongoVendorPlan +from catopt_query.protocol import CanonicalPlan + + +def test_postgres_adapter_roundtrip(): + vendor = PostgresVendorPlan( + shard_id="s1", + table="t1", + projection=["a", "b"], + predicates=["a>0"], + price=10.0, + ) + adapter = PostgresAdapter() + canon = adapter.to_canonical(vendor) + assert isinstance(canon, CanonicalPlan) + assert canon.projection == ["a", "b"] + assert canon.predicates == ["a>0"] + assert canon.estimated_cost == 10.0 + + +def test_mongo_adapter_roundtrip(): + vendor = MongoVendorPlan( + shard_id="s2", + collection="coll", + projection=["x"], + predicates=["x<5"], + price=5.5, + ) + adapter = MongoAdapter() + canon = adapter.to_canonical(vendor) + assert isinstance(canon, CanonicalPlan) + assert canon.projection == ["x"] + assert canon.predicates == ["x<5"] + assert canon.estimated_cost == 5.5 diff --git a/tests/test_dsl.py b/tests/test_dsl.py new file mode 100644 index 0000000..ece2af0 --- /dev/null +++ b/tests/test_dsl.py @@ -0,0 +1,38 @@ +from catopt_query.dsl import LocalProblemDSL, SharedVariablesDSL, PlanDeltaDSL +from catopt_query.protocol import LocalProblem, SharedVariables, PlanDelta + + +def test_local_problem_dsl_to_protocol(): + dsl = LocalProblemDSL( + shard_id="shard-1", + projection=["a", "b"], + predicates=["a>0"], + costs=1.5, + constraints={"limit": 100}, + ) + lp = dsl.to_protocol() + assert isinstance(lp, LocalProblem) + assert lp.shard_id == dsl.shard_id + assert lp.projection == dsl.projection + assert lp.predicates == dsl.predicates + assert lp.costs == dsl.costs + assert lp.constraints == dsl.constraints + + +def test_sharedvariables_dsl_to_protocol(): + dsl = SharedVariablesDSL(version=1, signals={"x": 0.5}, priors={"x": 0.2}) + sv = dsl.to_protocol() + assert isinstance(sv, SharedVariables) + assert sv.version == dsl.version + assert sv.signals == dsl.signals + assert sv.priors == dsl.priors + + +def test_plan_delta_dsl_to_protocol(): + dsl = PlanDeltaDSL(delta_id="d1", timestamp=123.0, changes={"a": 1}, contract_id="c0") + pd = dsl.to_protocol() + assert isinstance(pd, PlanDelta) + assert pd.delta_id == dsl.delta_id + assert pd.timestamp == dsl.timestamp + assert pd.changes == dsl.changes + assert pd.contract_id == dsl.contract_id diff --git a/tests/test_protocol_ext.py b/tests/test_protocol_ext.py new file mode 100644 index 0000000..f5e32f0 --- /dev/null +++ b/tests/test_protocol_ext.py @@ -0,0 +1,32 @@ +import math +from catopt_query.protocol import PrivacyBudget, AuditLog, PolicyBlock, GraphOfContracts, SharedVariables, PlanDelta + + +def test_privacy_budget_roundtrip(): + d = { + "signal": 0.5, + "budget": 1.5, + "expiry": 1700000000.0, + } + pb = PrivacyBudget.from_dict(d) + assert pb.signal == 0.5 + assert pb.budget == 1.5 + assert pb.expiry == 1700000000.0 + assert PrivacyBudget.from_dict(pb.to_dict()) == pb + + +def test_audit_log_roundtrip(): + al = AuditLog(entry="test-entry", signer="alice", timestamp=1234.5, contract_id="c1", version=2) + as_dict = al.to_dict() + al2 = AuditLog.from_dict(as_dict) + assert al2 == al + + +def test_policy_block_roundtrip(): + pb = PolicyBlock(safety="strict", exposure_controls={"read": True, "write": False}) + assert PolicyBlock.from_dict(pb.to_dict()) == pb + + +def test_graph_of_contracts_roundtrip(): + go = GraphOfContracts(adapter_id="mongo", supported_domains=["finance", "sensor"], contract_version="v1") + assert GraphOfContracts.from_dict(go.to_dict()) == go