diff --git a/core/contract_registry.py b/core/contract_registry.py index 5d2dd41..060989e 100644 --- a/core/contract_registry.py +++ b/core/contract_registry.py @@ -1,41 +1,42 @@ -from typing import Dict, List, Any, Optional +from __future__ import annotations + +from typing import Any, Dict, Optional class ContractRegistry: - """ - Lightweight, versioned contract registry for CatOpt-Graph MVP. - Stores schema definitions per contract name and version. - Provides a simple conformance check against adapter-provided data. + """Lightweight, in-memory, versioned contract registry for MVP tests. + + - Contracts are registered by name and version. + - Each contract has a schema (dict) describing LocalProblem/SharedVariables/etc. + - Exposes simple get_contract(name, version), register_contract(...), and conformance checks. + - Exposes list_versions(name) to mirror test expectations. """ def __init__(self) -> None: - # Structure: { name: { version: schema_dict } } - self._registry: Dict[str, Dict[str, Dict[str, Any]]] = {} + self._store: Dict[str, Dict[str, Dict[str, Any]]] = {} def register_contract(self, name: str, version: str, schema: Dict[str, Any]) -> None: - if name not in self._registry: - self._registry[name] = {} - self._registry[name][version] = schema + """Register or update a contract schema for a given name/version.""" + # Store a shallow copy to avoid accidental external mutation + self._store.setdefault(name, {})[version] = dict(schema) def get_contract(self, name: str, version: str) -> Optional[Dict[str, Any]]: - return self._registry.get(name, {}).get(version) + return self._store.get(name, {}).get(version) - def list_versions(self, name: str) -> List[str]: - return list(self._registry.get(name, {}).keys()) + def list_versions(self, name: str) -> list[str]: + return sorted(list(self._store.get(name, {}).keys())) def conformance_check(self, name: str, version: str, adapter_data: Dict[str, Any]) -> bool: - """ - Very lightweight conformance check: - - The contract schema defines required_fields. - - adapter_data must contain all required fields at top level. - This is a stub to be extended by real conformance tests. + """Check if adapter_data satisfies the contract's required_fields. + + Returns True if all required_fields (defined in schema) are present in adapter_data. + If no contract or required_fields are defined, returns False to be safe. """ contract = self.get_contract(name, version) if contract is None: return False - required = contract.get("required_fields", []) - # If required_fields not provided, assume no conformance requirement - for field in required: - if field not in adapter_data: - return False - return True + + required = contract.get("required_fields") or [] + if not isinstance(required, (list, tuple)): + return False + return all(field in adapter_data for field in required)