catopt-grid-category-theore.../tests/test_catopt_grid.py

34 lines
1.0 KiB
Python

import math
import os
import sys
# Make ``catopt_grid`` importable straight from a source checkout (no install
# needed): put the parent of the directory containing this file at the front
# of sys.path, unless it is already listed there.
repo_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if repo_root not in sys.path:
    sys.path.insert(0, repo_root)
from catopt_grid.core import LocalProblem
from catopt_grid.solver import admm_lite
def test_admm_lite_basic_convergence_like_setup():
    """Smoke-test ``admm_lite`` on two 2-D quadratic local problems.

    Each local gradient is ``x - target``, i.e. the gradient of
    ``0.5 * ||x - target||**2``. The sum of the two objectives is minimized
    at the mean of the two targets, ``[0.5, 0.5]``, so the returned consensus
    vector is expected to lie near that point. The final check only requires
    each coordinate to be within 0.5 of that mean.
    """

    def grad_factory(target):
        # Close over ``target`` so each problem gets its own gradient function.
        def g(x):
            return [xi - ti for xi, ti in zip(x, target)]

        return g

    # Two problems with dimension 2 and different targets.
    p1 = LocalProblem(id="p1", dimension=2, objective_grad=grad_factory([1.0, 0.0]))
    p2 = LocalProblem(id="p2", dimension=2, objective_grad=grad_factory([0.0, 1.0]))
    res = admm_lite([p1, p2], rho=1.0, max_iter=20)
    X = res["X"]
    Z = res["Z"]

    # One local iterate per problem, each of the problem's dimension.
    assert len(X) == 2
    for xi in X:
        assert len(xi) == 2
        # isfinite rejects NaN *and* +/-inf, so a diverging solver is caught too.
        assert all(math.isfinite(val) for val in xi), f"non-finite local iterate: {xi}"

    assert len(Z) == 2
    assert all(math.isfinite(val) for val in Z), f"non-finite consensus: {Z}"

    # In this toy setup, the consensus should lie between the two targets
    # [1, 0] and [0, 1], i.e. within 0.5 of their mean in every coordinate.
    target_mean = [0.5, 0.5]
    assert all(
        abs(a - b) <= 0.5 for a, b in zip(Z, target_mean)
    ), f"consensus {Z} too far from {target_mean}"