34 lines
1.0 KiB
Python
34 lines
1.0 KiB
Python
import sys
|
|
import os
|
|
# Absolute path of the directory one level above this file's directory.
# It is put on sys.path so the ``catopt_grid`` imports that follow can be
# resolved (presumably when running from a source checkout -- confirm).
repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
# Only insert when missing, so re-importing this module never stacks
# duplicate entries at the front of sys.path.
if repo_root not in sys.path:
    sys.path.insert(0, repo_root)
|
|
|
|
from catopt_grid.core import LocalProblem
|
|
from catopt_grid.solver import admm_lite
|
|
|
|
|
|
def test_admm_lite_basic_convergence_like_setup():
    """Smoke-test ``admm_lite`` on two 2-dimensional local problems.

    Each problem supplies the gradient ``x - target`` for its own target
    ([1, 0] and [0, 1]). The solver is called with ``rho=1.0`` and
    ``max_iter=20``; the test then checks the sizes of the returned ``X`` and
    ``Z``, that ``X`` contains no NaN, and that every component of ``Z`` is
    within 0.5 of the mean of the two targets, [0.5, 0.5].
    """
    import math  # local import: only needed for the NaN check below

    def grad_factory(target):
        """Return a callback computing ``x - target`` component-wise."""

        def g(x):
            return [xi - ti for xi, ti in zip(x, target)]

        return g

    # Two problems with dimension 2
    p1 = LocalProblem(id="p1", dimension=2, objective_grad=grad_factory([1.0, 0.0]))
    p2 = LocalProblem(id="p2", dimension=2, objective_grad=grad_factory([0.0, 1.0]))

    res = admm_lite([p1, p2], rho=1.0, max_iter=20)

    X = res["X"]
    Z = res["Z"]

    # Expect two entries in X (presumably one per problem -- confirm against
    # the solver), each of length 2 and free of NaN.
    assert len(X) == 2
    for xi in X:
        assert len(xi) == 2
        assert not any(math.isnan(val) for val in xi)  # no NaN

    assert len(Z) == 2
    # In this toy setup, the consensus should lie between the two targets
    # [1, 0] and [0, 1], i.e. within 0.5 of their mean in every component.
    target_mean = [0.5, 0.5]
    assert all(abs(a - b) <= 0.5 for a, b in zip(Z, target_mean))
|