gridverse-open-low-code-pla.../gridverse/registry.py

52 lines
2.2 KiB
Python

import time
import uuid
from typing import Dict, Any
from .contracts import LocalProblem, SharedVariables, PlanDelta, ConstraintSet, DeviceInfo
class GraphContractRegistry:
    """Registry of contract and adapter payloads keyed by ``(type, version)``.

    Contracts and adapters are kept in two separate dictionaries, so the same
    ``(type, version)`` pair may be registered in both without collision.
    """

    def __init__(self) -> None:
        # (contract_type, version) -> contract payload
        self._contracts: Dict[tuple, Dict[str, Any]] = {}
        # (adapter_type, version) -> adapter payload. Declared here rather than
        # created lazily on the first register_adapter() call, so every method
        # can rely on the attribute existing.
        self._adapters: Dict[tuple, Dict[str, Any]] = {}

    def register_contract(self, contract_type: str, version: str, payload: Dict[str, Any]) -> None:
        """Store *payload* under ``(contract_type, version)``, replacing any prior entry."""
        self._contracts[(contract_type, version)] = payload

    def get_contract(self, contract_type: str, version: str) -> Dict[str, Any]:
        """Return the payload registered for ``(contract_type, version)``.

        Returns a new empty dict when nothing is registered under that key.
        """
        return self._contracts.get((contract_type, version), {})

    def conformance_check(self, contract: Dict[str, Any]) -> bool:
        """Return True if *contract* is a dict containing ``"type"`` and ``"payload"`` keys.

        This is a minimal structural check: the values are not inspected.
        Any non-dict input yields False.
        """
        if not isinstance(contract, dict):
            return False
        return {"type", "payload"}.issubset(contract.keys())

    def register_adapter(self, adapter_type: str, version: str, payload: Dict[str, Any]) -> None:
        """Store *payload* under ``(adapter_type, version)`` in the adapter namespace."""
        self._adapters[(adapter_type, version)] = payload

    def conformance_test(self, adapter_iface: Dict[str, Any], contract_schema: Dict[str, Any]) -> bool:
        """Return True if both the described adapter and contract are registered.

        The adapter is looked up by ``(adapter_iface["name"], adapter_iface["version"])``
        and the contract by ``(contract_schema["name"], contract_schema["version"])``.
        Missing keys are treated as ``None``. An entry whose stored payload is
        ``None`` counts as not registered.
        """
        adapter_key = (adapter_iface.get("name"), adapter_iface.get("version"))
        contract_key = (contract_schema.get("name"), contract_schema.get("version"))
        adapters_ok = self._adapters.get(adapter_key) is not None
        contracts_ok = self._contracts.get(contract_key) is not None
        return adapters_ok and contracts_ok
class ContractRegistry:
    """Simplified contract registry kept for backward compatibility with tests.

    Each payload is stored under the composite key ``(contract_type, version)``.
    Looking up a key that was never registered yields ``None``.
    """

    def __init__(self) -> None:
        # composite (contract_type, version) key -> arbitrary payload
        self._store: Dict[tuple, Any] = dict()

    def register_contract(self, contract_type: str, version: str, payload: Any) -> None:
        """Store *payload* under ``(contract_type, version)``, overwriting any prior entry."""
        composite_key = (contract_type, version)
        self._store[composite_key] = payload

    def get_contract(self, contract_type: str, version: str) -> Any:
        """Return the payload for ``(contract_type, version)``, or ``None`` if unregistered."""
        try:
            return self._store[(contract_type, version)]
        except KeyError:
            return None