# admm_lite: minimal ADMM-lite consensus solver (toy implementation for integration testing).
from __future__ import annotations

from typing import List, Dict, Optional

import numpy as _np

from .core import LocalProblem


def admm_lite(
    problems: List[LocalProblem],
    rho: float = 1.0,
    max_iter: int = 50,
    max_iters: Optional[int] = None,
    tol: float = 1e-4,
) -> 'AdmmLiteResult':
    """
    A minimal ADMM-lite solver across a set of LocalProblem instances.

    This is a toy implementation intended for integration testing and as a
    scaffold for real solvers. Each problem ``i`` must expose a quadratic
    objective through attributes ``Q`` (a ``(dim, dim)`` matrix), ``c``
    (a length-``dim`` vector) and ``dimension``. The solver maintains a
    local variable ``x_i`` per problem, a consensus variable ``z`` and a
    scaled dual variable ``u_i``, iterating the standard scaled-form ADMM
    updates:

        x_i <- (Q_i + rho I)^{-1} (rho (z - u_i) - c_i)
        z   <- mean_i x_i
        u_i <- u_i + x_i - z

    Parameters
    ----------
    problems : list of LocalProblem
        Local subproblems; all must share the same dimension.
    rho : float
        ADMM penalty parameter.
    max_iter : int
        Maximum number of iterations.
    max_iters : int, optional
        Alias for ``max_iter`` kept for compatibility with harnesses that
        use the plural spelling; when given it overrides ``max_iter``.
    tol : float
        Convergence tolerance: iteration stops early once both the primal
        residual ``max_i ||x_i - z||_inf`` and the consensus change
        ``||z - z_prev||_inf`` fall below ``tol``.

    Returns
    -------
    AdmmLiteResult
        Container holding the final local variables ``X``, the consensus
        ``Z``, and the per-iteration ``history`` of consensus values
        (all as plain Python lists).

    Raises
    ------
    ValueError
        If the problems do not all share the same dimension.
    """
    # Backwards/forwards-compatible handling for harnesses that pass max_iters.
    if max_iters is not None:
        max_iter = int(max_iters)

    # Nothing to solve: return an empty result rather than failing below.
    if not problems:
        return AdmmLiteResult([], [], [])

    dims = [p.dimension for p in problems]
    if any(d != dims[0] for d in dims):
        raise ValueError("All problems must have the same dimension for this toy solver.")
    dim = dims[0]

    n = len(problems)
    X_arr = _np.zeros((n, dim))   # local variables x_i, one row per problem
    U_arr = _np.zeros((n, dim))   # scaled dual variables u_i
    Z_arr = _np.zeros(dim)        # consensus variable z
    I = _np.eye(dim)

    history: List[List[float]] = []
    for _ in range(max_iter):
        Z_prev = Z_arr
        # Local closed-form update: x_i = (Q_i + rho I)^{-1} (rho (z - u_i) - c_i)
        for i, p in enumerate(problems):
            rhs = rho * (Z_arr - U_arr[i]) - p.c
            X_arr[i] = _np.linalg.solve(p.Q + rho * I, rhs)
        # Global consensus update: element-wise average of the x_i.
        Z_arr = X_arr.mean(axis=0)
        # Scaled dual update.
        U_arr += X_arr - Z_arr
        # Record consensus for debugging/verification.
        history.append(Z_arr.tolist())
        # Early stopping: both primal residual and consensus drift below tol.
        primal_res = float(_np.abs(X_arr - Z_arr).max())
        z_change = float(_np.abs(Z_arr - Z_prev).max())
        if primal_res < tol and z_change < tol:
            break

    return AdmmLiteResult([row.tolist() for row in X_arr], Z_arr.tolist(), history)


class AdmmLiteResult:
    """Lightweight result container for ``admm_lite``.

    Supports both dict-like access and tuple unpacking:

    - ``res["X"]`` returns the local variables X
    - ``res["Z"]`` returns the consensus Z
    - ``res["history"]`` returns the per-iteration consensus history
    - iterating over ``res`` yields ``(Z, history)`` to support
      ``z, history = res`` usage
    """

    def __init__(self, X, Z, history):
        # Final local variables, one entry per problem.
        self.X = X
        # Final consensus vector.
        self.Z = Z
        # Consensus vector recorded after each iteration.
        self.history = history

    def __getitem__(self, key):
        """Dict-style access by key; raises KeyError for unknown keys."""
        if key == "X":
            return self.X
        if key == "Z":
            return self.Z
        # "history" is accepted in addition to "X"/"Z" since the attribute
        # is public and the solver always populates it.
        if key == "history":
            return self.history
        raise KeyError(key)

    def __iter__(self):
        # Unpacking order is (Z, history), matching `z, history = res`.
        yield self.Z
        yield self.history

    def __repr__(self):
        return (
            f"{type(self).__name__}(X={self.X!r}, Z={self.Z!r}, "
            f"history={self.history!r})"
        )