# Source: citygrid-policy-driven-fede.../citygrid/solver/admm_lite.py
# (21 lines, 750 B, Python)

from __future__ import annotations
from dataclasses import dataclass
from typing import Dict, Any
@dataclass
class AdmmState:
    """Snapshot of the toy ADMM solver state.

    Attributes:
        local_primal: Mapping from key to local primal value.
        global_dual: Scalar global dual value.
        rho: ADMM penalty parameter (defaults to 1.0).
    """

    # Per-key primal values; what the keys identify (agents? variables?)
    # is not visible here -- TODO confirm against callers.
    local_primal: Dict[str, float]
    # Single scalar dual variable.
    global_dual: float
    # Penalty parameter; optional, with a default of 1.0.
    rho: float = 1.0
def admm_step(state: AdmmState, local_update: Dict[str, float], global_update: float) -> AdmmState:
    """Perform one toy ADMM-style averaging step and return a new state.

    This is a deliberately simplified MVP step, not a full ADMM iteration:
    each primal entry becomes the mean of its current value and the incoming
    update, and the global dual becomes the mean of its current value and
    ``global_update``. ``state.rho`` is copied to the result unchanged and
    does not influence the update.

    Args:
        state: Current solver state. It is not mutated.
        local_update: Proposed per-key primal values. A key present on only
            one side (``local_update`` or ``state.local_primal``) is averaged
            against 0.0, so its value is halved.
        global_update: Proposed value for the global dual.

    Returns:
        A new ``AdmmState`` holding the averaged primal values, the averaged
        global dual, and the same ``rho`` as ``state``.
    """
    current = state.local_primal
    # Union of keys in a deterministic order: existing keys first, then keys
    # that appear only in ``local_update``, each in insertion order. Iterating
    # a ``set`` of strings instead would make the result's key order depend on
    # hash randomisation (PYTHONHASHSEED) and vary from run to run.
    keys = dict.fromkeys([*current, *local_update])
    # Simple pairwise averaging of the primal values.
    new_local = {
        k: (local_update.get(k, 0.0) + current.get(k, 0.0)) / 2.0 for k in keys
    }
    # Averaging, not a residual-based dual-ascent update: the dual moves
    # halfway toward ``global_update``.
    new_global = (state.global_dual + global_update) * 0.5
    return AdmmState(local_primal=new_local, global_dual=new_global, rho=state.rho)