# opengrowth-privacy-preservi.../opengrowth_privacy_preservi.../secure_aggregation.py
# (33 lines, 1014 B, Python)

import math
import statistics
class SecureAggregator:
    """Combine per-participant metric dictionaries into summary statistics.

    NOTE: despite the class name, no masking, encryption, or other privacy
    mechanism is applied here; the values in ``local_results`` are read and
    combined in plaintext.
    """

    @staticmethod
    def aggregate(local_results: list, *, z: float = 1.96) -> dict:
        """Return the mean and a normal-approximation confidence interval per metric.

        Args:
            local_results: List of dicts mapping metric name to value. Only
                ``int``/``float`` values are used (``bool`` counts too, since
                it is an ``int`` subclass); any other value is ignored. A
                metric missing from some dicts is aggregated over the dicts
                that contain it.
            z: Non-negative critical value multiplying the standard error.
                The default 1.96 corresponds to a two-sided 95% interval under
                a normal approximation.

        Returns:
            ``{metric: {"mean": m, "ci_lower": lo, "ci_upper": hi}}`` with
            metrics in first-seen order. Metrics that have no numeric values
            are omitted, and an empty input yields ``{}``. When a metric has a
            single value, its interval collapses to that value.
        """
        if not local_results:
            return {}
        # dict.fromkeys keeps first-seen order, so the output is deterministic;
        # iterating a set of str keys varies with hash randomization.
        keys = dict.fromkeys(k for d in local_results for k in d.keys())
        aggregated = {}
        for k in keys:
            values = [
                d[k]
                for d in local_results
                if k in d and isinstance(d[k], (int, float))
            ]
            if not values:
                continue
            n = len(values)
            mean = math.fsum(values) / n
            if n < 2:
                # No spread can be estimated from a single sample.
                ci_lower = ci_upper = mean
            else:
                # The standard error of the mean uses the sample (n - 1)
                # standard deviation; the population version (pstdev)
                # understates the spread and yields a too-narrow interval.
                se = statistics.stdev(values) / math.sqrt(n)
                margin = z * se
                ci_lower = mean - margin
                ci_upper = mean + margin
            aggregated[k] = {"mean": mean, "ci_lower": ci_lower, "ci_upper": ci_upper}
        return aggregated