21 lines
671 B
Python
21 lines
671 B
Python
import math
|
|
import random
|
|
from interplanetary_edge_orchestrator_privacy import Client, Server
|
|
|
|
|
|
def test_training_clipping_applies():
    """Check that ``Client.train`` bounds the returned update by ``clip_norm``.

    A client is built over a tiny dataset with large-magnitude features and
    trained for one epoch with ``clip_norm`` set. The L2 norm of the update
    that ``train`` returns must not exceed ``clip_norm``, up to a small
    floating-point tolerance.
    """
    # Seed the global RNG so any randomness used inside Client is reproducible.
    # NOTE(review): what Client uses randomness for (init, shuffling, DP noise)
    # is not visible here — confirm.
    random.seed(0)

    # One source of truth for the bound, used by both the call and the check.
    clip_norm = 1.0
    tolerance = 1e-9

    # Large feature values are meant to encourage an update larger than
    # clip_norm, so that clipping actually has to kick in.
    # NOTE(review): with zero initial weights and all-zero targets, whether the
    # *unclipped* update is large depends on Client's model/loss, which is not
    # visible here. If it is already small, this test passes without really
    # exercising clipping — confirm against Client.train.
    X = [[100.0, 100.0], [100.0, -100.0]]
    y = [0.0, 0.0]

    c = Client(client_id=99, data_X=X, data_y=y, connected=True)
    c.initialize(n_features=2)

    w = [0.0, 0.0]

    # Clip updates to a small norm to enforce DP-like behaviour.
    update = c.train(w, lr=0.01, epochs=1, clip_norm=clip_norm)

    # L2 norm of the returned update vector. A NaN or inf norm also fails the
    # comparison below, so a numerically broken update cannot slip through.
    norm = math.sqrt(sum(v * v for v in update))
    assert norm <= clip_norm + tolerance, (
        f"update norm {norm} exceeds clip_norm {clip_norm}"
    )