# catopt-grid-category-theore.../catopt_grid/solver.py
#
# 51 lines · 1.9 KiB · Python

from __future__ import annotations
from typing import List, Dict
from .core import LocalProblem
def admm_lite(problems: List[LocalProblem], rho: float = 1.0, max_iter: int = 50) -> Dict:
    """Run a minimal ADMM-style consensus loop over a set of local problems.

    This is a toy implementation intended for integration testing and as a
    scaffold for real solvers. No dual variable is maintained: each iteration
    performs one gradient/proximal-like step on every local variable ``x_i``
    towards the current consensus ``z``, then resets ``z`` to the element-wise
    average of all ``x_i``.

    Args:
        problems: Problems to solve jointly. Each must expose ``dimension``
            and may expose ``objective_grad``, a callable mapping the current
            local variable (a list of floats) to a gradient indexable by
            coordinate. A missing or ``None`` ``objective_grad`` is treated
            as a zero gradient.
        rho: Penalty parameter; the step size is ``1 / rho``. Values below
            ``1e-8`` are clamped to ``1e-8`` to avoid division by zero.
        max_iter: Number of iterations to run. Non-positive values run none.

    Returns:
        A dict with keys ``"X"`` (list of final local variables, one list of
        floats per problem), ``"Z"`` (the consensus vector, or ``None`` when
        ``problems`` is empty) and ``"iterations"`` (iterations performed).

    Raises:
        ValueError: If the problems do not all share the same dimension.
    """
    if not problems:
        return {"X": [], "Z": None, "iterations": 0}

    dim = problems[0].dimension
    if any(p.dimension != dim for p in problems):
        raise ValueError("All problems must have the same dimension for this toy solver.")

    # Local variables and consensus are plain Python lists, all starting at 0.
    X: List[List[float]] = [[0.0] * dim for _ in problems]
    Z: List[float] = [0.0] * dim

    # Loop-invariant step size; the clamp guards against rho <= 0.
    step = 1.0 / max(1e-8, rho)

    def _grad(p: LocalProblem, x: List[float]) -> List[float]:
        # getattr (not attribute access) so that problems which simply do not
        # define objective_grad behave as the docstring promises.
        grad_fn = getattr(p, "objective_grad", None)
        if grad_fn is None:
            return [0.0] * dim
        return grad_fn(x)

    # Report the number of iterations actually run, even for max_iter < 0.
    iterations = max(0, max_iter)
    for _ in range(iterations):
        # Local update: x_i <- x_i - (1/rho) * grad_i(x_i) - (1/rho) * (x_i - z).
        # Every x_i is updated against the consensus from the previous round.
        for x_i, p in zip(X, problems):
            g = _grad(p, x_i)
            for d in range(dim):
                x_i[d] = x_i[d] - step * g[d] - step * (x_i[d] - Z[d])

        # Global consensus update (element-wise average of the local variables).
        for d in range(dim):
            Z[d] = sum(x_i[d] for x_i in X) / len(X)

    return {"X": X, "Z": Z, "iterations": iterations}