interplanetary-edge-orchest.../tests/test_federated.py

55 lines
2.3 KiB
Python

import os
import random
import sys

# Ensure the repository root is on sys.path so the package under test can be
# imported even when it has not been installed into the environment.
ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if ROOT not in sys.path:
    sys.path.insert(0, ROOT)

# This import must come after the sys.path adjustment above, hence E402.
from interplanetary_edge_orchestrator_privacy import Client, Server  # noqa: E402
def generate_dataset(n_samples: int, n_features: int, seed: int = 0):
    """Build a small synthetic linear-regression dataset.

    The draws are made in this order from a private ``random.Random(seed)``:
    ``n_samples`` rows of ``n_features`` standard-normal features, then a
    standard-normal weight vector, then one N(0, 0.1) noise term per row.
    Each target is the row's dot product with the weights plus its noise.
    Because the generator is private, the global ``random`` state is not
    touched, and a given seed always reproduces the same data.

    Returns:
        A tuple ``(X, y, true_w)``: the feature rows, the targets, and the
        weights used to generate them.
    """
    gen = random.Random(seed)
    draw = gen.gauss
    features = [[draw(0.0, 1.0) for _ in range(n_features)] for _ in range(n_samples)]
    weights = [draw(0.0, 1.0) for _ in range(n_features)]
    targets = []
    for row in features:
        # Compute the noiseless response first, then draw the noise term, so
        # the RNG consumption order matches row order.
        clean = sum(a * b for a, b in zip(row, weights))
        targets.append(clean + draw(0.0, 0.1))
    return features, targets, weights
def _sum_squared_error(X, y, w):
    """Return the sum of squared residuals of the linear predictor ``row . w`` over ``(X, y)``."""
    return sum(
        (sum(x_k * w_k for x_k, w_k in zip(row, w)) - target) ** 2
        for row, target in zip(X, y)
    )


def test_basic_federated_aggregation_improves_model():
    """One round of local training plus aggregation should move the global model without worsening it.

    Three clients each train on their own seeded synthetic dataset. The server
    aggregates their updates with DP noise disabled, so the run is
    deterministic. The test then checks two things:
    - the global weights changed;
    - the summed squared error over all client data did not increase
      relative to the server's initial weights.
    """
    random.seed(0)
    n_features = 2
    n_clients = 3
    datasets = []
    clients = []
    for i in range(n_clients):
        X, y, _ = generate_dataset(30, n_features, seed=i + 1)
        datasets.append((X, y))
        clients.append(Client(client_id=i, data_X=X, data_y=y, connected=True))
    server = Server(n_features)
    # Snapshot the starting global weights. They are presumably zeros, but
    # the loss comparison below uses the snapshot rather than assuming that.
    w_init = server.w.copy()
    # Each client trains locally from the current global weights and returns
    # its update.
    updates = []
    for c in clients:
        c.initialize(n_features)
        updates.append(c.train(server.w, lr=0.01, epochs=20))
    # Aggregate updates on the server (no DP noise for determinism in test).
    server.aggregate(updates, noise_scale=0.0, seed=123)
    # Sanity: server weights should have moved away from their initial values.
    assert isinstance(server.w, list) and len(server.w) == n_features
    assert any(abs(w_new - w_old) > 1e-9 for w_new, w_old in zip(server.w, w_init))
    # Aggregated model must not be worse than the starting model.
    # Loss is evaluated on the generated data directly, so the check does not
    # depend on how Client stores its data internally.
    total_initial_loss = sum(_sum_squared_error(X, y, w_init) for X, y in datasets)
    total_final_loss = sum(_sum_squared_error(X, y, server.w) for X, y in datasets)
    assert total_final_loss <= total_initial_loss + 1e-6