catopt-graph-graph-calculus.../core/contract_registry.py

65 lines
2.8 KiB
Python

from __future__ import annotations

import re
from typing import Any, Dict, Optional
class ContractRegistry:
    """Lightweight, in-memory, versioned contract registry for MVP tests.

    - Contracts are registered by name and version.
    - Each contract has a schema (dict) describing LocalProblem/SharedVariables/etc.
    - Exposes simple get_contract(name, version), register_contract(...), and conformance checks.
    - Exposes list_versions(name) to mirror test expectations.
    """

    def __init__(self) -> None:
        # name -> version -> schema; each schema is a registry-owned shallow copy.
        self._store: Dict[str, Dict[str, Dict[str, Any]]] = {}

    def register_contract(self, name: str, version: str, schema: Dict[str, Any]) -> None:
        """Register or update a contract schema for a given name/version.

        An existing schema under the same name/version is replaced.
        """
        # Store a shallow copy so later changes to the caller's dict do not leak
        # into the registry (nested containers are still shared).
        self._store.setdefault(name, {})[version] = dict(schema)

    def get_contract(self, name: str, version: str) -> Optional[Dict[str, Any]]:
        """Return a shallow copy of the schema for name/version, or None if unknown.

        A copy is returned so callers cannot mutate registry state through the
        result, matching the copy taken in register_contract.
        """
        contract = self._store.get(name, {}).get(version)
        return None if contract is None else dict(contract)

    def list_versions(self, name: str) -> list[str]:
        """Return the versions registered under name in natural order ([] if unknown).

        Digit runs compare numerically, so "1.2" sorts before "1.10"
        (a plain string sort would put "1.10" first).
        """
        return sorted(self._store.get(name, {}), key=self._version_sort_key)

    @staticmethod
    def _version_sort_key(version: str) -> tuple:
        """Build a natural-sort key for a version identifier."""
        text = str(version)
        # A capturing group makes re.split alternate text / digit-run chunks, so
        # odd indices are always digit runs and equal positions across keys
        # always compare str-with-str or int-with-int.
        parts = re.split(r"(\d+)", text)
        # The raw text breaks ties such as "01" vs "1" deterministically.
        return ([int(p) if i % 2 else p for i, p in enumerate(parts)], text)

    def conformance_check(self, name: str, version: str, adapter_data: Dict[str, Any]) -> bool:
        """Check whether adapter_data contains every key listed in required_fields.

        Returns False when name/version is not registered, or when the schema's
        required_fields is set but is not a list/tuple. Returns True when
        required_fields is missing or empty, because there is nothing to check.
        """
        contract = self.get_contract(name, version)
        if contract is None:
            return False
        # A falsy value (missing key, None, empty list) means "no requirements".
        required = contract.get("required_fields") or []
        if not isinstance(required, (list, tuple)):
            return False
        return all(field in adapter_data for field in required)

    def migrate_contract(self, name: str, old_version: str, new_version: str, new_schema: Dict[str, Any]) -> None:
        """Migrate a contract from an old version to a new version.

        - Mark the old_version as migrated_to new_version (if present).
        - Register the new_version with the provided new_schema.
        This enables simple, MVP-aligned contract evolution for adapters during the MVP lifecycle.
        """
        versions = self._store.setdefault(name, {})
        if old_version in versions:
            # Mutates the stored schema in place; copies previously returned by
            # get_contract will not see the marker.
            versions[old_version]["migrated_to"] = new_version
        versions[new_version] = dict(new_schema)

    def get_migration_target(self, name: str, version: str) -> Optional[str]:
        """Return the version that this version migrated to, if any."""
        contract = self.get_contract(name, version)
        if contract is None:
            return None
        return contract.get("migrated_to")