build(agent): new-agents-3#dd492b iteration
This commit is contained in:
parent
b7ba8241f2
commit
e42e29f9d3
Binary file not shown.
|
|
@ -0,0 +1,53 @@
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import sys
|
||||||
|
|
||||||
|
# Ensure repository root is on sys.path for package import during tests
|
||||||
|
ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||||
|
if ROOT not in sys.path:
|
||||||
|
sys.path.insert(0, ROOT)
|
||||||
|
|
||||||
|
from interplanetary_edge_orchestrator_privacy import Client, Server
|
||||||
|
|
||||||
|
|
||||||
|
def generate_dataset(n_samples: int, n_features: int, seed: int = 0):
|
||||||
|
rng = random.Random(seed)
|
||||||
|
X = [[rng.gauss(0.0, 1.0) for _ in range(n_features)] for _ in range(n_samples)]
|
||||||
|
true_w = [rng.gauss(0.0, 1.0) for _ in range(n_features)]
|
||||||
|
y = [sum(X[i][k] * true_w[k] for k in range(n_features)) + rng.gauss(0.0, 0.1) for i in range(n_samples)]
|
||||||
|
return X, y, true_w
|
||||||
|
|
||||||
|
|
||||||
|
def test_offline_client_can_cache_and_load_update(tmp_path=None):
|
||||||
|
random.seed(0)
|
||||||
|
n_features = 3
|
||||||
|
n_clients = 2
|
||||||
|
|
||||||
|
# Client 0 will operate online; Client 1 will operate offline to exercise cache.
|
||||||
|
X0, y0, _ = generate_dataset(40, n_features, seed=1)
|
||||||
|
X1, y1, _ = generate_dataset(40, n_features, seed=2)
|
||||||
|
|
||||||
|
online = Client(client_id=0, data_X=X0, data_y=y0, connected=True, cache_dir=str(tmp_path or "cache"))
|
||||||
|
offline = Client(client_id=1, data_X=X1, data_y=y1, connected=False, cache_dir=str(tmp_path or "cache"))
|
||||||
|
server = Server(n_features)
|
||||||
|
|
||||||
|
# Initialize weights and collect updates from online client
|
||||||
|
online.initialize(n_features)
|
||||||
|
upd_online = online.train(server.w, lr=0.01, epochs=5)
|
||||||
|
|
||||||
|
# Train offline client; its update should be cached to disk
|
||||||
|
offline.initialize(n_features)
|
||||||
|
upd_offline = offline.train(server.w, lr=0.01, epochs=5)
|
||||||
|
|
||||||
|
# Load the cached offline update and verify it matches the returned value
|
||||||
|
cached_offline = offline.load_update()
|
||||||
|
assert cached_offline == upd_offline
|
||||||
|
|
||||||
|
# Now perform a server aggregation including the offline update read from cache
|
||||||
|
updates = [upd_online, cached_offline]
|
||||||
|
server.aggregate(updates, noise_scale=0.0, seed=123)
|
||||||
|
|
||||||
|
# Basic sanity: server weights moved from initial zeros
|
||||||
|
w_init = [0.0 for _ in range(n_features)]
|
||||||
|
assert isinstance(server.w, list) and len(server.w) == n_features
|
||||||
|
assert any(abs(server.w[i] - w_init[i]) > 1e-9 for i in range(n_features))
|
||||||
Loading…
Reference in New Issue