nebulaforge-offline-resilie.../tests/test_federated.py

15 lines
413 B
Python

import math

from nebulaforge.federated import SecureAggregator


def test_aggregate_basic():
    """Aggregating two updates yields their coordinate-wise mean."""
    agg = SecureAggregator()
    updates = [[1.0, 2.0], [3.0, 4.0]]
    expected = [2.0, 3.0]
    # Materialize as a list so the checks work whether aggregate() returns a
    # list, tuple, or other sequence type.
    out = list(agg.aggregate(updates))
    # Check the length first: zip() below would silently truncate a short result.
    assert len(out) == len(expected), out
    # Compare with a tolerance because the aggregation path may involve float
    # arithmetic (e.g. masking) that is not bit-exact.
    assert all(math.isclose(got, want) for got, want in zip(out, expected)), out


def test_dp_budget_clipping():
    """With dp_budget=1.0, every aggregated coordinate lies within [-1.0, 1.0]."""
    agg = SecureAggregator(dp_budget=1.0)
    updates = [[2.0, -2.0]]
    out = list(agg.aggregate(updates))
    # Guard against a vacuous pass: all() over an empty result is True.
    assert len(out) == len(updates[0]), out
    # Both inputs lie outside [-1, 1], so this only passes if clipping happened.
    # NOTE(review): assumes per-coordinate clipping to +/-dp_budget, as the
    # original comment ("clipped to +/-1") stated; confirm against SecureAggregator.
    assert all(-1.0 <= v <= 1.0 for v in out), out