"""Tests for federated aggregation in ``interplanetary_edge_orchestrator_privacy``."""

import os
import random
import sys

# Ensure repository root is on sys.path for package import during tests
ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if ROOT not in sys.path:
    sys.path.insert(0, ROOT)

from interplanetary_edge_orchestrator_privacy import Client, Server  # noqa: E402


def generate_dataset(n_samples: int, n_features: int, seed: int = 0, true_w=None):
    """Generate a synthetic linear-regression dataset.

    Features are drawn i.i.d. from N(0, 1); each target is the dot product of
    its feature row with ``true_w`` plus Gaussian noise with standard
    deviation 0.1.

    Args:
        n_samples: Number of rows to generate.
        n_features: Number of feature columns per row.
        seed: Seed for a private ``random.Random`` instance, so the global
            ``random`` state is neither read nor modified.
        true_w: Optional ground-truth weight vector of length ``n_features``.
            When omitted, one is drawn from N(0, 1) with the same generator,
            after the features, which preserves the original draw order.
            Pass a shared vector to give several datasets the same
            underlying model.

    Returns:
        Tuple ``(X, y, true_w)``:
        - ``X`` is a list of ``n_samples`` rows of ``n_features`` floats.
        - ``y`` is a list of ``n_samples`` floats.
        - ``true_w`` is the weight vector used to produce ``y``.

    Raises:
        ValueError: If ``true_w`` is given with a length other than ``n_features``.
    """
    rng = random.Random(seed)
    X = [[rng.gauss(0.0, 1.0) for _ in range(n_features)] for _ in range(n_samples)]
    if true_w is None:
        true_w = [rng.gauss(0.0, 1.0) for _ in range(n_features)]
    else:
        true_w = list(true_w)
        if len(true_w) != n_features:
            raise ValueError(
                f"true_w has length {len(true_w)}, expected n_features={n_features}"
            )
    y = [sum(x * w for x, w in zip(row, true_w)) + rng.gauss(0.0, 0.1) for row in X]
    return X, y, true_w


def _sum_squared_error(X, y, w):
    """Return the summed squared error of the linear model ``w`` on ``(X, y)``."""
    return sum(
        (sum(x * wk for x, wk in zip(row, w)) - target) ** 2
        for row, target in zip(X, y)
    )


def test_basic_federated_aggregation_improves_model():
    random.seed(0)
    n_features = 2
    n_clients = 3

    # Every client samples from the same underlying linear model (with its own
    # features and noise). Without a shared ground truth, "the aggregated model
    # improves" is not a well-posed claim.
    truth_rng = random.Random(0)
    true_w = [truth_rng.gauss(0.0, 1.0) for _ in range(n_features)]

    clients = []
    for i in range(n_clients):
        X, y, _ = generate_dataset(30, n_features, seed=i + 1, true_w=true_w)
        clients.append(Client(client_id=i, data_X=X, data_y=y, connected=True))

    server = Server(n_features)
    # Snapshot of the starting global weights. The baseline loss below is
    # computed against this snapshot rather than assuming it is all zeros.
    w_init = server.w.copy()

    # Each client trains locally and returns its update
    updates = []
    for c in clients:
        c.initialize(n_features)
        updates.append(c.train(server.w, lr=0.01, epochs=20))

    # Aggregate updates on the server (no DP noise for determinism in test)
    server.aggregate(updates, noise_scale=0.0, seed=123)

    # Sanity: server weights should have moved away from their initial values
    assert isinstance(server.w, list) and len(server.w) == n_features
    assert any(abs(new - old) > 1e-9 for new, old in zip(server.w, w_init))

    # Aggregated weights should not fit the clients' data worse than the start
    total_initial_loss = sum(_sum_squared_error(c.X, c.y, w_init) for c in clients)
    total_final_loss = sum(_sum_squared_error(c.X, c.y, server.w) for c in clients)
    assert total_final_loss <= total_initial_loss + 1e-6