build(agent): new-agents-3#dd492b iteration

This commit is contained in:
agent-dd492b85242a98c5 2026-04-19 20:56:08 +02:00
parent b7ba8241f2
commit e42e29f9d3
2 changed files with 53 additions and 0 deletions

BIN
cache/client_1_update.pkl vendored Normal file

Binary file not shown.

View File

@ -0,0 +1,53 @@
import os
import random
import sys
# Ensure repository root is on sys.path for package import during tests
ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if ROOT not in sys.path:
sys.path.insert(0, ROOT)
from interplanetary_edge_orchestrator_privacy import Client, Server
def generate_dataset(n_samples: int, n_features: int, seed: int = 0):
rng = random.Random(seed)
X = [[rng.gauss(0.0, 1.0) for _ in range(n_features)] for _ in range(n_samples)]
true_w = [rng.gauss(0.0, 1.0) for _ in range(n_features)]
y = [sum(X[i][k] * true_w[k] for k in range(n_features)) + rng.gauss(0.0, 0.1) for i in range(n_samples)]
return X, y, true_w
def test_offline_client_can_cache_and_load_update(tmp_path=None):
random.seed(0)
n_features = 3
n_clients = 2
# Client 0 will operate online; Client 1 will operate offline to exercise cache.
X0, y0, _ = generate_dataset(40, n_features, seed=1)
X1, y1, _ = generate_dataset(40, n_features, seed=2)
online = Client(client_id=0, data_X=X0, data_y=y0, connected=True, cache_dir=str(tmp_path or "cache"))
offline = Client(client_id=1, data_X=X1, data_y=y1, connected=False, cache_dir=str(tmp_path or "cache"))
server = Server(n_features)
# Initialize weights and collect updates from online client
online.initialize(n_features)
upd_online = online.train(server.w, lr=0.01, epochs=5)
# Train offline client; its update should be cached to disk
offline.initialize(n_features)
upd_offline = offline.train(server.w, lr=0.01, epochs=5)
# Load the cached offline update and verify it matches the returned value
cached_offline = offline.load_update()
assert cached_offline == upd_offline
# Now perform a server aggregation including the offline update read from cache
updates = [upd_online, cached_offline]
server.aggregate(updates, noise_scale=0.0, seed=123)
# Basic sanity: server weights moved from initial zeros
w_init = [0.0 for _ in range(n_features)]
assert isinstance(server.w, list) and len(server.w) == n_features
assert any(abs(server.w[i] - w_init[i]) > 1e-9 for i in range(n_features))