diff --git a/cache/client_1_update.pkl b/cache/client_1_update.pkl new file mode 100644 index 0000000..a3ac529 Binary files /dev/null and b/cache/client_1_update.pkl differ diff --git a/tests/test_offline_cache.py b/tests/test_offline_cache.py new file mode 100644 index 0000000..8abd2e3 --- /dev/null +++ b/tests/test_offline_cache.py @@ -0,0 +1,53 @@ +import os +import random +import sys + +# Ensure repository root is on sys.path for package import during tests +ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +if ROOT not in sys.path: + sys.path.insert(0, ROOT) + +from interplanetary_edge_orchestrator_privacy import Client, Server + + +def generate_dataset(n_samples: int, n_features: int, seed: int = 0): + rng = random.Random(seed) + X = [[rng.gauss(0.0, 1.0) for _ in range(n_features)] for _ in range(n_samples)] + true_w = [rng.gauss(0.0, 1.0) for _ in range(n_features)] + y = [sum(X[i][k] * true_w[k] for k in range(n_features)) + rng.gauss(0.0, 0.1) for i in range(n_samples)] + return X, y, true_w + + +def test_offline_client_can_cache_and_load_update(tmp_path=None): + random.seed(0) + n_features = 3 + n_clients = 2 + + # Client 0 will operate online; Client 1 will operate offline to exercise cache. + X0, y0, _ = generate_dataset(40, n_features, seed=1) + X1, y1, _ = generate_dataset(40, n_features, seed=2) + + online = Client(client_id=0, data_X=X0, data_y=y0, connected=True, cache_dir=str(tmp_path or "cache")) + offline = Client(client_id=1, data_X=X1, data_y=y1, connected=False, cache_dir=str(tmp_path or "cache")) + server = Server(n_features) + + # Initialize weights and collect updates from online client + online.initialize(n_features) + upd_online = online.train(server.w, lr=0.01, epochs=5) + + # Train offline client; its update should be cached to disk + offline.initialize(n_features) + upd_offline = offline.train(server.w, lr=0.01, epochs=5) + + # Load the cached offline update and verify it matches the returned value + cached_offline = offline.load_update() + assert cached_offline == upd_offline + + # Now perform a server aggregation including the offline update read from cache + updates = [upd_online, cached_offline] + server.aggregate(updates, noise_scale=0.0, seed=123) + + # Basic sanity: server weights moved from initial zeros + w_init = [0.0 for _ in range(n_features)] + assert isinstance(server.w, list) and len(server.w) == n_features + assert any(abs(server.w[i] - w_init[i]) > 1e-9 for i in range(n_features))