catopt-grid-category-theore.../catopt_grid/solver.py

94 lines
3.2 KiB
Python

from __future__ import annotations
from typing import List, Dict, Optional
import numpy as _np
from .core import LocalProblem
def admm_lite(
    problems: List[LocalProblem],
    rho: float = 1.0,
    max_iter: int = 50,
    max_iters: Optional[int] = None,
    tol: float = 1e-4,
) -> 'AdmmLiteResult':
    """
    Run a minimal consensus-ADMM loop over a set of LocalProblem instances.

    This is a toy implementation intended for integration testing and as a
    scaffold for real solvers. It keeps one local variable ``x_i`` and one
    scaled dual variable ``u_i`` per problem, plus a shared consensus
    variable ``z``, all starting at zero. Each iteration performs:

    1. a closed-form local update
       ``x_i = (Q_i + rho I)^{-1} (rho (z - u_i) - c_i)``, which is the exact
       minimiser of ``0.5 x^T Q_i x + c_i^T x + (rho/2) ||x - z + u_i||^2``
       when ``Q_i`` is symmetric and ``Q_i + rho I`` is positive definite;
    2. a consensus update ``z = mean_i(x_i)``;
    3. a dual update ``u_i <- u_i + x_i - z``.

    Args:
        problems: Local problems. Each must expose ``dimension`` (equal for
            all problems), ``Q`` (matrix-like; ``Q + rho * eye(dimension)``
            must form a square, invertible system) and ``c`` (vector-like,
            broadcastable to length ``dimension``). ``objective_grad`` is not
            used by this solver.
        rho: ADMM penalty parameter.
        max_iter: Number of iterations to run.
        max_iters: Alias for ``max_iter``; when given it takes precedence.
        tol: Accepted for API compatibility but currently unused. The solver
            always runs exactly ``max_iter`` iterations (no early stopping).

    Returns:
        An ``AdmmLiteResult`` holding ``X`` (one list of floats per problem),
        ``Z`` (the consensus, a list of floats) and ``history`` (a copy of
        ``Z`` recorded after every iteration). With no problems, all three
        are empty lists.

    Raises:
        ValueError: if the problems do not all have the same dimension.
        numpy.linalg.LinAlgError: if some ``Q_i + rho I`` is singular.
    """
    # Backwards/forwards-compatible handling for test harnesses that pass max_iters.
    if max_iters is not None:
        max_iter = int(max_iters)
    if len(problems) == 0:
        return AdmmLiteResult([], [], [])
    dims = [p.dimension for p in problems]
    dim = dims[0]
    if any(d != dim for d in dims):
        raise ValueError("All problems must have the same dimension for this toy solver.")
    n = len(problems)

    # Local primal variables, consensus and scaled duals all start at zero.
    X: List[List[float]] = [[0.0] * dim for _ in range(n)]
    Z: List[float] = [0.0] * dim
    U: List[List[float]] = [[0.0] * dim for _ in range(n)]
    history: List[List[float]] = []
    if max_iter <= 0:
        return AdmmLiteResult(X, Z, history)

    # The system matrix (Q_i + rho I) and the linear term c_i are identical on
    # every iteration, so build them once rather than once per iteration.
    eye = _np.eye(dim)
    systems = [(_np.asarray(p.Q) + rho * eye, _np.asarray(p.c)) for p in problems]

    for _ in range(max_iter):
        z_arr = _np.asarray(Z)
        # Local update: x_i = (Q_i + rho I)^{-1} (rho (z - u_i) - c_i)
        X = [
            _np.linalg.solve(M, rho * (z_arr - _np.asarray(u_i)) - c_i).tolist()
            for (M, c_i), u_i in zip(systems, U)
        ]
        # Consensus: the textbook update is mean(x_i + u_i), but the duals start
        # at zero and the dual update below keeps their sum at zero, so
        # mean(x_i) is equivalent.
        Z = [sum(column) / n for column in zip(*X)]
        # Scaled dual update: u_i <- u_i + x_i - z
        U = [
            [u + x - z for u, x, z in zip(u_i, x_i, Z)]
            for u_i, x_i in zip(U, X)
        ]
        # Record a snapshot of the consensus for debugging/verification.
        history.append(list(Z))
    return AdmmLiteResult(X, Z, history)
class AdmmLiteResult:
    """Lightweight result container supporting dict-like access and tuple unpacking.

    - ``res["X"]`` returns the local variables ``X``
    - ``res["Z"]`` returns the consensus ``Z``
    - ``res["history"]`` returns the recorded consensus history
    - Iterating over ``res`` yields ``(Z, history)``, so ``z, history = res``
      works.

    Any other key raises ``KeyError``.
    """

    # Keys accepted by __getitem__; each maps to the attribute of the same name.
    _KEYS = ("X", "Z", "history")

    def __init__(self, X, Z, history):
        self.X = X
        self.Z = Z
        self.history = history

    def __getitem__(self, key):
        # Membership is tested against a tuple (equality, no hashing) so that
        # unhashable keys still raise KeyError rather than TypeError.
        if key in self._KEYS:
            return getattr(self, key)
        raise KeyError(key)

    def __iter__(self):
        # Deliberately yields only (Z, history), not X, so that
        # ``z, history = res`` unpacks cleanly.
        yield self.Z
        yield self.history

    def __repr__(self):
        # History can be long, so only its length is shown.
        return (
            f"{type(self).__name__}(X={self.X!r}, Z={self.Z!r}, "
            f"history=<{len(self.history)} entries>)"
        )