# Source listing: 21 lines, 750 B, Python.
from __future__ import annotations

from dataclasses import dataclass, replace
from typing import Dict, Any
|
|
|
|
|
|
@dataclass
class AdmmState:
    """Values carried from one ``admm_step`` call to the next."""

    # Local primal values, keyed by name.
    local_primal: Dict[str, float]
    # Single global dual value.
    global_dual: float
    # Defaults to 1.0 when not supplied.
    rho: float = 1.0
|
|
|
|
|
|
def admm_step(
    state: AdmmState,
    local_update: Dict[str, float],
    global_update: float,
) -> AdmmState:
    """Return a new state that blends ``state`` with the supplied updates.

    This is a toy step for MVP purposes, not a full ADMM iteration: each
    value becomes the plain midpoint of its current value and the
    corresponding update. ``rho`` is carried over unchanged and does not
    enter the arithmetic.

    Args:
        state: Current state; it is not mutated.
        local_update: Proposed local primal values, keyed by name.
        global_update: Proposed global dual value.

    Returns:
        A new state object whose ``local_primal`` and ``global_dual`` hold
        the averaged values; every other field is taken from ``state``.
    """
    current = state.local_primal
    # Current keys first, then keys that appear only in the update, so the
    # result's key order is reproducible (iterating a set of str is not).
    keys = dict.fromkeys((*current, *local_update))
    # A key missing on one side counts as 0.0, so its value is halved.
    new_local = {
        key: (local_update.get(key, 0.0) + current.get(key, 0.0)) / 2.0
        for key in keys
    }
    # Midpoint of the current dual and the proposed one.
    new_global = (state.global_dual + global_update) * 0.5
    return replace(state, local_primal=new_local, global_dual=new_global)
|