"""Tests for :class:`nebulaforge.federated.SecureAggregator`."""

import math

from nebulaforge.federated import SecureAggregator


def test_aggregate_basic():
    """Aggregating two client updates yields their element-wise mean."""
    agg = SecureAggregator()
    updates = [[1.0, 2.0], [3.0, 4.0]]
    out = agg.aggregate(updates)

    # Element-wise mean of the two updates: [(1 + 3) / 2, (2 + 4) / 2].
    expected = [2.0, 3.0]
    # Check the length explicitly so a short or empty result cannot slip
    # past the zip() below, which stops at the shorter sequence.
    assert len(out) == len(expected)
    # Compare element by element with a tolerance instead of exact float
    # equality on the whole list; this also keeps the check valid if `out`
    # is a tuple or array rather than a plain list.
    assert all(math.isclose(got, want) for got, want in zip(out, expected))


def test_dp_budget_clipping():
    """With ``dp_budget=1.0``, aggregated values stay within [-1.0, 1.0]."""
    agg = SecureAggregator(dp_budget=1.0)
    updates = [[2.0, -2.0]]
    out = agg.aggregate(updates)

    # Expect one output value per input coordinate. Checking this first stops
    # the range assertion below from passing vacuously on an empty result.
    assert len(out) == len(updates[0])
    # NOTE(review): presumably dp_budget acts as the per-value clipping bound
    # (values clipped to +/-1) -- confirm against SecureAggregator. Only the
    # bound is asserted here, not the exact clipped values.
    assert all(-1.0 <= v <= 1.0 for v in out)