diff --git a/mltrail_verifiable_provenance_ledger_for/registry.py b/mltrail_verifiable_provenance_ledger_for/registry.py index a45c4dd..b43f70e 100644 --- a/mltrail_verifiable_provenance_ledger_for/registry.py +++ b/mltrail_verifiable_provenance_ledger_for/registry.py @@ -1,5 +1,19 @@ from typing import Dict, Any + +# Lightweight default schemas for core MLTrail contracts. This enables a +# canonical, language-agnostic registry that adapters can reference for +# conformance checks and schema validation in MVP deployments. +_DEFAULT_CONTRACT_SCHEMAS: Dict[str, Dict[str, Any]] = { + "Experiment": {"fields": ["id", "name", "version", "description", "metadata"]}, + "Run": {"fields": ["id", "experiment_id", "parameters", "metrics", "environment_hash"]}, + "Dataset": {"fields": ["id", "name", "version", "metadata"]}, + "Model": {"fields": ["id", "architecture", "fingerprint", "metadata"]}, + "Environment": {"fields": ["id", "language", "version", "dependencies", "container_hash"]}, + "EvaluationMetric": {"fields": ["name", "value", "unit"]}, + "Policy": {"fields": ["id", "rules", "metadata"]}, +} + class ContractRegistry: def __init__(self) -> None: self._contracts: Dict[str, Dict[str, Any]] = {} @@ -12,3 +26,13 @@ class ContractRegistry: def all_contracts(self) -> Dict[str, Any]: return self._contracts + + def register_default_contracts(self) -> None: + """Register a canonical set of core MLTrail contract schemas. + + This helps adapters and tooling validate payloads against expected + fields, enabling consistent interoperability across languages. + """ + for name, schema in _DEFAULT_CONTRACT_SCHEMAS.items(): + # Use a stable default version for defaults; allow user overrides if needed + self.register_contract(name, schema, version="1.0.0") diff --git a/tests/test_registry_defaults.py b/tests/test_registry_defaults.py new file mode 100644 index 0000000..4a0e74f --- /dev/null +++ b/tests/test_registry_defaults.py @@ -0,0 +1,14 @@ +from mltrail_verifiable_provenance_ledger_for.registry import ContractRegistry, _DEFAULT_CONTRACT_SCHEMAS + + +def test_default_contracts_registered(): + reg = ContractRegistry() + reg.register_default_contracts() + + # Ensure all canonical contracts are registered with their expected fields + for name, expected in _DEFAULT_CONTRACT_SCHEMAS.items(): + contract = reg.get_contract(name) + assert contract is not None, f"Contract {name} not registered" + assert contract.get("schema") is not None, f"Schema for {name} missing" + fields = contract["schema"].get("fields") + assert fields == expected["fields"], f"Fields for {name} do not match expected"