from __future__ import annotations

from typing import List, Dict, Optional

import numpy as _np

from .core import LocalProblem


def admm_lite(
    problems: List[LocalProblem],
    rho: float = 1.0,
    max_iter: int = 50,
    max_iters: Optional[int] = None,
    tol: float = 1e-4,
) -> AdmmLiteResult:
    """Run a minimal consensus ADMM across a set of LocalProblem instances.

    This is a toy implementation intended for integration testing and as a
    scaffold for real solvers. Each problem ``p`` must expose ``p.dimension``
    together with ``p.Q`` (a ``dim x dim`` matrix) and ``p.c`` (a length-``dim``
    vector). The local x-update solved below,

        x_i = (Q_i + rho I)^{-1} (rho (z - u_i) - c_i),

    is the exact minimiser of ``0.5 x^T Q_i x + c_i^T x + (rho/2)||x - z + u_i||^2``
    when ``Q_i`` is symmetric -- presumably the intended local objective;
    NOTE(review): confirm against ``LocalProblem``.

    Each iteration performs the local solves, averages the local variables
    into the consensus ``z``, then takes a scaled dual step ``u_i += x_i - z``.

    Args:
        problems: Local problems; all must share the same ``dimension``.
        rho: ADMM penalty parameter.
        max_iter: Number of iterations to run.
        max_iters: Alias for ``max_iter`` (some test harnesses use the plural
            spelling); takes precedence when given.
        tol: Currently unused -- the solver always runs exactly ``max_iter``
            iterations. Kept for interface compatibility.

    Returns:
        An :class:`AdmmLiteResult` holding the final local variables ``X``
        (one list per problem), the final consensus ``Z`` (a list), and
        ``history`` (a snapshot of ``Z`` after every iteration).

    Raises:
        ValueError: If the problems do not all share the same dimension.
    """
    # Backwards/forwards-compatible handling for test harness that may pass max_iters
    if max_iters is not None:
        max_iter = int(max_iters)

    problems = list(problems)
    if not problems:
        return AdmmLiteResult([], [], [])

    dims = [p.dimension for p in problems]
    dim = dims[0]
    if any(d != dim for d in dims):
        raise ValueError("All problems must have the same dimension for this toy solver.")

    history: List[List[float]] = []
    if max_iter <= 0:
        # No iterations requested: report the all-zero initial state.
        return AdmmLiteResult([[0.0] * dim for _ in problems], [0.0] * dim, history)

    n = len(problems)
    rho_eye = rho * _np.eye(dim)
    # Q_i + rho*I and c_i do not change between iterations, so build them once.
    systems = [_np.asarray(p.Q) + rho_eye for p in problems]
    offsets = [_np.asarray(p.c) for p in problems]

    Z = _np.zeros(dim)          # consensus variable
    U = _np.zeros((n, dim))     # scaled dual variables, one row per problem
    for _ in range(max_iter):
        # Local update: closed-form solve of (Q_i + rho I) x_i = rho (z - u_i) - c_i
        X = _np.array(
            [
                _np.linalg.solve(M, rho * (Z - u) - c)
                for M, c, u in zip(systems, offsets, U)
            ]
        )
        # Consensus update. The textbook form is mean(x_i + u_i), but U starts
        # at zero and every dual step adds (x_i - mean(x)), so sum(U) stays zero
        # and the plain average of X is equivalent.
        Z = X.sum(axis=0) / n
        # Dual update, evaluated as (U + X) - Z.
        U = U + X - Z
        # Record history for debugging/verification.
        history.append(Z.tolist())

    return AdmmLiteResult(X.tolist(), Z.tolist(), history)


class AdmmLiteResult:
    """Result of :func:`admm_lite` supporting dict-style access and unpacking.

    - ``res["X"]`` returns the final local variables X (one list per problem).
    - ``res["Z"]`` returns the final consensus Z.
    - ``res["history"]`` returns the per-iteration snapshots of Z.
    - Iterating over ``res`` yields ``(Z, history)`` to support
      ``z, history = res`` usage.
    """

    def __init__(self, X, Z, history):
        self.X = X
        self.Z = Z
        self.history = history

    def __getitem__(self, key):
        if key == "X":
            return self.X
        if key == "Z":
            return self.Z
        if key == "history":
            return self.history
        raise KeyError(key)

    def __iter__(self):
        # Unpacks as (Z, history); X is deliberately not part of the tuple form.
        yield self.Z
        yield self.history

    def __repr__(self):
        return (
            f"{type(self).__name__}(X={self.X!r}, Z={self.Z!r}, "
            f"history=<{len(self.history)} entries>)"
        )