33 lines
1014 B
Python
33 lines
1014 B
Python
import math
|
|
import statistics
|
|
|
|
|
|
class SecureAggregator:
    """Combine per-participant metric dictionaries into summary statistics."""

    @staticmethod
    def aggregate(local_results: list, *, z: float = 1.96) -> dict:
        """Aggregate numeric metrics across a list of result dictionaries.

        For every key that holds an ``int``/``float`` value in at least one
        dictionary, compute the mean of those values and a normal-approximation
        confidence interval ``mean +/- z * s / sqrt(n)``, where ``s`` is the
        sample (n - 1 denominator) standard deviation.

        Args:
            local_results: Dictionaries mapping metric names to values.
                Non-numeric values are ignored, and a key need not appear in
                every dictionary.
            z: Critical value for the interval (keyword-only). The default
                1.96 gives an approximate 95% interval under a normal
                approximation; for small ``n`` a Student-t quantile would be
                more accurate.

        Returns:
            ``{key: {"mean": ..., "ci_lower": ..., "ci_upper": ...}}`` with
            keys in first-seen order. Keys with no numeric values are omitted.
            With a single value the interval collapses to the mean. Empty
            input returns ``{}``.
        """
        if not local_results:
            return {}

        # dict.fromkeys preserves first-seen order (a set does not), so the
        # output ordering is deterministic across runs.
        keys = dict.fromkeys(k for d in local_results for k in d)

        aggregated = {}
        for k in keys:
            # NOTE: bool is a subclass of int, so True/False values are
            # counted as 1/0 here, matching the previous behavior.
            values = [
                d[k]
                for d in local_results
                if k in d and isinstance(d[k], (int, float))
            ]
            if not values:
                continue

            n = len(values)
            mean = math.fsum(values) / n
            if n < 2:
                # A spread cannot be estimated from one observation.
                ci_lower = ci_upper = mean
            else:
                # Sample standard deviation: the population form (pstdev)
                # underestimates spread and yields intervals that are too narrow.
                margin = z * statistics.stdev(values) / math.sqrt(n)
                ci_lower, ci_upper = mean - margin, mean + margin

            aggregated[k] = {"mean": mean, "ci_lower": ci_lower, "ci_upper": ci_upper}

        return aggregated