build(agent): new-agents-3#dd492b iteration
This commit is contained in:
parent
b7ba8241f2
commit
e42e29f9d3
Binary file not shown.
|
|
@ -0,0 +1,53 @@
|
|||
import os
|
||||
import random
|
||||
import sys
|
||||
|
||||
# Ensure repository root is on sys.path for package import during tests
|
||||
ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
if ROOT not in sys.path:
|
||||
sys.path.insert(0, ROOT)
|
||||
|
||||
from interplanetary_edge_orchestrator_privacy import Client, Server
|
||||
|
||||
|
||||
def generate_dataset(n_samples: int, n_features: int, seed: int = 0):
|
||||
rng = random.Random(seed)
|
||||
X = [[rng.gauss(0.0, 1.0) for _ in range(n_features)] for _ in range(n_samples)]
|
||||
true_w = [rng.gauss(0.0, 1.0) for _ in range(n_features)]
|
||||
y = [sum(X[i][k] * true_w[k] for k in range(n_features)) + rng.gauss(0.0, 0.1) for i in range(n_samples)]
|
||||
return X, y, true_w
|
||||
|
||||
|
||||
def test_offline_client_can_cache_and_load_update(tmp_path=None):
|
||||
random.seed(0)
|
||||
n_features = 3
|
||||
n_clients = 2
|
||||
|
||||
# Client 0 will operate online; Client 1 will operate offline to exercise cache.
|
||||
X0, y0, _ = generate_dataset(40, n_features, seed=1)
|
||||
X1, y1, _ = generate_dataset(40, n_features, seed=2)
|
||||
|
||||
online = Client(client_id=0, data_X=X0, data_y=y0, connected=True, cache_dir=str(tmp_path or "cache"))
|
||||
offline = Client(client_id=1, data_X=X1, data_y=y1, connected=False, cache_dir=str(tmp_path or "cache"))
|
||||
server = Server(n_features)
|
||||
|
||||
# Initialize weights and collect updates from online client
|
||||
online.initialize(n_features)
|
||||
upd_online = online.train(server.w, lr=0.01, epochs=5)
|
||||
|
||||
# Train offline client; its update should be cached to disk
|
||||
offline.initialize(n_features)
|
||||
upd_offline = offline.train(server.w, lr=0.01, epochs=5)
|
||||
|
||||
# Load the cached offline update and verify it matches the returned value
|
||||
cached_offline = offline.load_update()
|
||||
assert cached_offline == upd_offline
|
||||
|
||||
# Now perform a server aggregation including the offline update read from cache
|
||||
updates = [upd_online, cached_offline]
|
||||
server.aggregate(updates, noise_scale=0.0, seed=123)
|
||||
|
||||
# Basic sanity: server weights moved from initial zeros
|
||||
w_init = [0.0 for _ in range(n_features)]
|
||||
assert isinstance(server.w, list) and len(server.w) == n_features
|
||||
assert any(abs(server.w[i] - w_init[i]) > 1e-9 for i in range(n_features))
|
||||
Loading…
Reference in New Issue