# Test module for the federated-learning aggregation pipeline
# (interplanetary_edge_orchestrator_privacy).
import os
import sys

# Ensure the repository root is on sys.path so the package under test can be
# imported regardless of the directory pytest is launched from.
ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if ROOT not in sys.path:
    sys.path.insert(0, ROOT)

# NOTE: this import must come AFTER the sys.path bootstrap above.
from interplanetary_edge_orchestrator_privacy import Client, Server

import random
|
def generate_dataset(n_samples: int, n_features: int, seed: int = 0):
    """Build a deterministic synthetic linear-regression dataset.

    Draws standard-normal feature rows and a ground-truth weight vector,
    then produces targets as the linear combination plus N(0, 0.1) noise.
    The draw order (features, then weights, then per-sample noise) is fixed,
    so identical seeds always reproduce identical data.

    Returns a tuple ``(X, y, true_w)``.
    """
    rng = random.Random(seed)

    # Feature matrix: n_samples rows of n_features standard-normal draws.
    features = []
    for _ in range(n_samples):
        features.append([rng.gauss(0.0, 1.0) for _ in range(n_features)])

    # Ground-truth weights, drawn after all feature rows.
    weights = [rng.gauss(0.0, 1.0) for _ in range(n_features)]

    # Noisy linear targets, one noise draw per sample.
    targets = []
    for row in features:
        clean = sum(value * weight for value, weight in zip(row, weights))
        targets.append(clean + rng.gauss(0.0, 0.1))

    return features, targets, weights
|
|
|
|
|
|
def test_basic_federated_aggregation_improves_model():
    """One round of local training plus server aggregation must move the
    global weights away from their zero initialisation and must not increase
    the clients' total squared-error loss (DP noise disabled for determinism).
    """
    random.seed(0)

    n_features = 2
    n_clients = 3

    # Build the clients, each with its own deterministic synthetic dataset.
    clients = []
    for idx in range(n_clients):
        data_X, data_y, _unused = generate_dataset(30, n_features, seed=idx + 1)
        clients.append(Client(client_id=idx, data_X=data_X, data_y=data_y, connected=True))

    server = Server(n_features)

    # Snapshot the initial global weights (expected to be all zeros).
    w_init = list(server.w)

    # Each client trains locally against the current global weights and
    # produces an update (delta) for the server to aggregate.
    deltas = []
    for client in clients:
        client.initialize(n_features)
        deltas.append(client.train(server.w, lr=0.01, epochs=20))

    # Aggregate on the server with DP noise disabled so the test is exact.
    server.aggregate(deltas, noise_scale=0.0, seed=123)

    # The global model must still be a weight list of the right width, and
    # at least one coordinate must have moved away from its initial value.
    assert isinstance(server.w, list) and len(server.w) == n_features
    assert any(abs(server.w[i] - w_init[i]) > 1e-9 for i in range(n_features))

    # Sanity check: the aggregated model should not be worse than the
    # all-zeros model on the clients' own data.
    loss_before = 0.0
    loss_after = 0.0
    for client in clients:
        # Zero weights predict 0.0 for every sample.
        loss_before += sum((0.0 - target) ** 2 for target in client.y)
        # Loss under the new global weights.
        loss_after += sum(
            (sum(x * w for x, w in zip(row, server.w)) - target) ** 2
            for row, target in zip(client.X, client.y)
        )
    assert loss_after <= loss_before + 1e-6
|