nexaroa 0.0.111__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- neuroshard/__init__.py +93 -0
- neuroshard/__main__.py +4 -0
- neuroshard/cli.py +466 -0
- neuroshard/core/__init__.py +92 -0
- neuroshard/core/consensus/verifier.py +252 -0
- neuroshard/core/crypto/__init__.py +20 -0
- neuroshard/core/crypto/ecdsa.py +392 -0
- neuroshard/core/economics/__init__.py +52 -0
- neuroshard/core/economics/constants.py +387 -0
- neuroshard/core/economics/ledger.py +2111 -0
- neuroshard/core/economics/market.py +975 -0
- neuroshard/core/economics/wallet.py +168 -0
- neuroshard/core/governance/__init__.py +74 -0
- neuroshard/core/governance/proposal.py +561 -0
- neuroshard/core/governance/registry.py +545 -0
- neuroshard/core/governance/versioning.py +332 -0
- neuroshard/core/governance/voting.py +453 -0
- neuroshard/core/model/__init__.py +30 -0
- neuroshard/core/model/dynamic.py +4186 -0
- neuroshard/core/model/llm.py +905 -0
- neuroshard/core/model/registry.py +164 -0
- neuroshard/core/model/scaler.py +387 -0
- neuroshard/core/model/tokenizer.py +568 -0
- neuroshard/core/network/__init__.py +56 -0
- neuroshard/core/network/connection_pool.py +72 -0
- neuroshard/core/network/dht.py +130 -0
- neuroshard/core/network/dht_plan.py +55 -0
- neuroshard/core/network/dht_proof_store.py +516 -0
- neuroshard/core/network/dht_protocol.py +261 -0
- neuroshard/core/network/dht_service.py +506 -0
- neuroshard/core/network/encrypted_channel.py +141 -0
- neuroshard/core/network/nat.py +201 -0
- neuroshard/core/network/nat_traversal.py +695 -0
- neuroshard/core/network/p2p.py +929 -0
- neuroshard/core/network/p2p_data.py +150 -0
- neuroshard/core/swarm/__init__.py +106 -0
- neuroshard/core/swarm/aggregation.py +729 -0
- neuroshard/core/swarm/buffers.py +643 -0
- neuroshard/core/swarm/checkpoint.py +709 -0
- neuroshard/core/swarm/compute.py +624 -0
- neuroshard/core/swarm/diloco.py +844 -0
- neuroshard/core/swarm/factory.py +1288 -0
- neuroshard/core/swarm/heartbeat.py +669 -0
- neuroshard/core/swarm/logger.py +487 -0
- neuroshard/core/swarm/router.py +658 -0
- neuroshard/core/swarm/service.py +640 -0
- neuroshard/core/training/__init__.py +29 -0
- neuroshard/core/training/checkpoint.py +600 -0
- neuroshard/core/training/distributed.py +1602 -0
- neuroshard/core/training/global_tracker.py +617 -0
- neuroshard/core/training/production.py +276 -0
- neuroshard/governance_cli.py +729 -0
- neuroshard/grpc_server.py +895 -0
- neuroshard/runner.py +3223 -0
- neuroshard/sdk/__init__.py +92 -0
- neuroshard/sdk/client.py +990 -0
- neuroshard/sdk/errors.py +101 -0
- neuroshard/sdk/types.py +282 -0
- neuroshard/tracker/__init__.py +0 -0
- neuroshard/tracker/server.py +864 -0
- neuroshard/ui/__init__.py +0 -0
- neuroshard/ui/app.py +102 -0
- neuroshard/ui/templates/index.html +1052 -0
- neuroshard/utils/__init__.py +0 -0
- neuroshard/utils/autostart.py +81 -0
- neuroshard/utils/hardware.py +121 -0
- neuroshard/utils/serialization.py +90 -0
- neuroshard/version.py +1 -0
- nexaroa-0.0.111.dist-info/METADATA +283 -0
- nexaroa-0.0.111.dist-info/RECORD +78 -0
- nexaroa-0.0.111.dist-info/WHEEL +5 -0
- nexaroa-0.0.111.dist-info/entry_points.txt +4 -0
- nexaroa-0.0.111.dist-info/licenses/LICENSE +190 -0
- nexaroa-0.0.111.dist-info/top_level.txt +2 -0
- protos/__init__.py +0 -0
- protos/neuroshard.proto +651 -0
- protos/neuroshard_pb2.py +160 -0
- protos/neuroshard_pb2_grpc.py +1298 -0
neuroshard/core/swarm/aggregation.py
@@ -0,0 +1,729 @@
"""
Robust Gradient Aggregation - Byzantine-Tolerant Gradient Verification

Implements robust aggregation for DiLoCo pseudo-gradients:
- Statistical validation of gradients
- Byzantine-tolerant aggregation (Krum, Median, Trimmed Mean)
- Cosine similarity verification
- Gradient magnitude checking

Key Insight: "Since DiLoCo syncs less often, bad gradients are more damaging.
Enhanced verification required."

Supports multiple aggregation strategies:
1. Simple Mean: Fast but vulnerable to Byzantine nodes
2. Coordinate-wise Median: Robust to outliers
3. Trimmed Mean: Removes top/bottom percentiles
4. Krum: Selects the gradient closest to the majority
5. Multi-Krum: Average of the top-k Krum selections
6. Geometric Median: L2 geometric median via Weiszfeld iteration

Usage:
    aggregator = create_robust_aggregator(strategy="trimmed_mean", trim_fraction=0.1)

    # Validate incoming gradients
    is_valid, reason = aggregator.validator.validate(peer_grads, local_grads)

    if is_valid:
        aggregator.add_contribution(peer_id, peer_grads)

    # Get aggregated result
    aggregated = aggregator.aggregate()
"""

import logging
import math
import time
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Any, Tuple
from enum import Enum
from collections import defaultdict

import torch
import torch.nn.functional as F

logger = logging.getLogger(__name__)


class AggregationStrategy(Enum):
    """Strategy for aggregating gradients."""
    MEAN = "mean"                          # Simple average
    MEDIAN = "median"                      # Coordinate-wise median
    TRIMMED_MEAN = "trimmed_mean"          # Remove top/bottom percentiles
    KRUM = "krum"                          # Select closest to majority
    MULTI_KRUM = "multi_krum"              # Average of top-k
    GEOMETRIC_MEDIAN = "geometric_median"  # L2 geometric median


@dataclass
class ValidationConfig:
    """Configuration for gradient validation."""
    # Cosine similarity
    min_cosine_similarity: float = 0.3  # Minimum alignment with local gradient

    # Magnitude
    max_magnitude_ratio: float = 10.0  # Max ratio to local gradient norm
    min_magnitude_ratio: float = 0.1   # Min ratio to local gradient norm

    # Variance
    max_variance_ratio: float = 100.0  # Max variance ratio

    # Statistical
    zscore_threshold: float = 3.0  # Max z-score for outlier detection

    # Trust
    require_signature: bool = False  # Require cryptographic signature
    min_trust_score: float = 0.0     # Minimum trust score for peer


@dataclass
class AggregationConfig:
    """Configuration for aggregation."""
    strategy: AggregationStrategy = AggregationStrategy.TRIMMED_MEAN

    # Trimmed mean settings
    trim_fraction: float = 0.1  # Fraction to trim from each end

    # Krum settings
    num_byzantine: int = 0  # Expected number of Byzantine nodes
    multi_krum_k: int = 0   # Number of gradients to select (0 = auto)

    # Geometric median settings
    max_iterations: int = 100  # Max iterations for convergence
    tolerance: float = 1e-6    # Convergence tolerance

    # Weighting
    use_trust_weights: bool = False      # Weight by peer trust scores
    use_freshness_weights: bool = False  # Weight by gradient freshness


@dataclass
class GradientContribution:
    """A gradient contribution from a peer."""
    peer_id: str
    gradients: Dict[str, torch.Tensor]
    timestamp: float = field(default_factory=time.time)
    trust_score: float = 1.0
    signature: Optional[str] = None

    # Validation results
    is_validated: bool = False
    validation_reason: str = ""

    @property
    def age_seconds(self) -> float:
        return time.time() - self.timestamp


class GradientValidator:
    """
    Validates incoming gradients against local reference.

    Performs multiple checks:
    1. Cosine similarity (direction alignment)
    2. Magnitude ratio (scale)
    3. Variance ratio (distribution)
    4. Z-score outlier detection
    """

    def __init__(self, config: Optional[ValidationConfig] = None):
        self.config = config or ValidationConfig()

        # Stats
        self.validations_performed = 0
        self.validations_passed = 0
        self.validations_failed = 0
        self.failure_reasons: Dict[str, int] = defaultdict(int)

    def validate(
        self,
        submitted_grads: Dict[str, torch.Tensor],
        reference_grads: Dict[str, torch.Tensor],
        peer_trust: float = 1.0,
    ) -> Tuple[bool, str]:
        """
        Validate submitted gradients against reference.

        Args:
            submitted_grads: Gradients from peer
            reference_grads: Local reference gradients
            peer_trust: Trust score of submitting peer

        Returns:
            (is_valid, reason) tuple
        """
        self.validations_performed += 1

        # Check trust score
        if peer_trust < self.config.min_trust_score:
            self._record_failure("low_trust")
            return False, f"Trust score {peer_trust} below minimum {self.config.min_trust_score}"

        # Validate each parameter
        for name, submitted in submitted_grads.items():
            if name not in reference_grads:
                continue

            reference = reference_grads[name]

            # Check 1: Cosine similarity (direction)
            is_valid, reason = self._check_cosine_similarity(submitted, reference, name)
            if not is_valid:
                self._record_failure("cosine_similarity")
                return False, reason

            # Check 2: Magnitude ratio (scale)
            is_valid, reason = self._check_magnitude(submitted, reference, name)
            if not is_valid:
                self._record_failure("magnitude")
                return False, reason

            # Check 3: Variance ratio (distribution)
            is_valid, reason = self._check_variance(submitted, reference, name)
            if not is_valid:
                self._record_failure("variance")
                return False, reason

        self.validations_passed += 1
        return True, "Validation passed"

    def _check_cosine_similarity(
        self,
        submitted: torch.Tensor,
        reference: torch.Tensor,
        param_name: str,
    ) -> Tuple[bool, str]:
        """Check cosine similarity between gradients."""
        # Flatten for comparison
        submitted_flat = submitted.flatten()
        reference_flat = reference.flatten()

        # Handle zero vectors
        submitted_norm = submitted_flat.norm()
        reference_norm = reference_flat.norm()

        if submitted_norm == 0 or reference_norm == 0:
            # Zero gradient - suspicious but might be valid
            if submitted_norm == 0 and reference_norm == 0:
                return True, "Both zero"
            return True, "One zero vector - allowing"

        # Compute cosine similarity
        cosine_sim = F.cosine_similarity(
            submitted_flat.unsqueeze(0),
            reference_flat.unsqueeze(0)
        ).item()

        if cosine_sim < self.config.min_cosine_similarity:
            return False, (
                f"Cosine similarity {cosine_sim:.3f} below threshold "
                f"{self.config.min_cosine_similarity} for {param_name}"
            )

        return True, f"cosine={cosine_sim:.3f}"

    def _check_magnitude(
        self,
        submitted: torch.Tensor,
        reference: torch.Tensor,
        param_name: str,
    ) -> Tuple[bool, str]:
        """Check magnitude ratio of gradients."""
        submitted_norm = submitted.norm().item()
        reference_norm = reference.norm().item()

        if reference_norm == 0:
            return True, "Reference norm is zero"

        ratio = submitted_norm / reference_norm

        if ratio > self.config.max_magnitude_ratio:
            return False, (
                f"Magnitude ratio {ratio:.2f} exceeds max "
                f"{self.config.max_magnitude_ratio} for {param_name}"
            )

        if ratio < self.config.min_magnitude_ratio:
            return False, (
                f"Magnitude ratio {ratio:.2f} below min "
                f"{self.config.min_magnitude_ratio} for {param_name}"
            )

        return True, f"magnitude_ratio={ratio:.2f}"

    def _check_variance(
        self,
        submitted: torch.Tensor,
        reference: torch.Tensor,
        param_name: str,
    ) -> Tuple[bool, str]:
        """Check variance ratio of gradients."""
        submitted_var = submitted.var().item()
        reference_var = reference.var().item()

        if reference_var == 0:
            return True, "Reference variance is zero"

        ratio = submitted_var / reference_var

        if ratio > self.config.max_variance_ratio:
            return False, (
                f"Variance ratio {ratio:.2f} exceeds max "
                f"{self.config.max_variance_ratio} for {param_name}"
            )

        return True, f"variance_ratio={ratio:.2f}"

    def _record_failure(self, reason: str):
        """Record a validation failure."""
        self.validations_failed += 1
        self.failure_reasons[reason] += 1

    def get_stats(self) -> Dict[str, Any]:
        """Get validation statistics."""
        success_rate = (
            self.validations_passed / self.validations_performed
            if self.validations_performed > 0 else 0.0
        )
        return {
            'validations_performed': self.validations_performed,
            'validations_passed': self.validations_passed,
            'validations_failed': self.validations_failed,
            'success_rate': success_rate,
            'failure_reasons': dict(self.failure_reasons),
        }
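

# Illustrative sketch (annotation, not part of the packaged file): how the
# checks above behave on toy tensors with the default ValidationConfig.
#
#   >>> v = GradientValidator()
#   >>> ref = {"w": torch.tensor([1.0, 1.0, 1.0, 1.0])}
#   >>> v.validate({"w": torch.tensor([-1.0, -1.0, -1.0, -1.0])}, ref)[0]
#   False  # sign-flipped gradient fails the cosine check (cosine = -1 < 0.3)
#   >>> v.validate({"w": torch.tensor([100.0, 100.0, 100.0, 100.0])}, ref)[0]
#   False  # 100x scale fails the magnitude check (ratio 100 > 10)
#   >>> v.validate({"w": torch.tensor([1.0, 1.1, 0.9, 1.0])}, ref)[0]
#   True   # aligned, similarly scaled gradient passes all three checks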


class RobustAggregator:
    """
    Byzantine-tolerant gradient aggregator.

    Supports multiple aggregation strategies for robustness
    against malicious or faulty nodes.
    """

    def __init__(
        self,
        aggregation_config: Optional[AggregationConfig] = None,
        validation_config: Optional[ValidationConfig] = None,
    ):
        self.agg_config = aggregation_config or AggregationConfig()
        self.validator = GradientValidator(validation_config)

        # Contributions
        self.contributions: List[GradientContribution] = []
        self._lock = None  # For thread safety if needed

        # Stats
        self.aggregations_performed = 0
        self.total_contributions_received = 0
        self.contributions_rejected = 0

    def clear(self):
        """Clear all contributions."""
        self.contributions.clear()

    def add_contribution(
        self,
        peer_id: str,
        gradients: Dict[str, torch.Tensor],
        reference_grads: Optional[Dict[str, torch.Tensor]] = None,
        trust_score: float = 1.0,
        validate: bool = True,
    ) -> Tuple[bool, str]:
        """
        Add a gradient contribution from a peer.

        Args:
            peer_id: ID of contributing peer
            gradients: Gradient tensors from peer
            reference_grads: Local reference for validation
            trust_score: Trust score of peer
            validate: Whether to validate before adding

        Returns:
            (accepted, reason) tuple
        """
        self.total_contributions_received += 1

        # Create contribution
        contribution = GradientContribution(
            peer_id=peer_id,
            gradients=gradients,
            trust_score=trust_score,
        )

        # Validate if reference provided
        if validate and reference_grads is not None:
            is_valid, reason = self.validator.validate(
                gradients,
                reference_grads,
                trust_score
            )

            contribution.is_validated = True
            contribution.validation_reason = reason

            if not is_valid:
                self.contributions_rejected += 1
                logger.warning(f"Rejected gradient from {peer_id}: {reason}")
                return False, reason

        self.contributions.append(contribution)
        return True, "Accepted"

    def aggregate(
        self,
        local_grads: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Aggregate all contributions using configured strategy.

        Args:
            local_grads: Optional local gradients to include

        Returns:
            Aggregated gradients
        """
        if not self.contributions and local_grads is None:
            return {}

        # Add local as contribution if provided
        all_contributions = list(self.contributions)
        if local_grads is not None:
            all_contributions.append(GradientContribution(
                peer_id="local",
                gradients=local_grads,
                trust_score=1.0,
                is_validated=True,
            ))

        # Select aggregation method
        strategy = self.agg_config.strategy

        if strategy == AggregationStrategy.MEAN:
            result = self._aggregate_mean(all_contributions)
        elif strategy == AggregationStrategy.MEDIAN:
            result = self._aggregate_median(all_contributions)
        elif strategy == AggregationStrategy.TRIMMED_MEAN:
            result = self._aggregate_trimmed_mean(all_contributions)
        elif strategy == AggregationStrategy.KRUM:
            result = self._aggregate_krum(all_contributions)
        elif strategy == AggregationStrategy.MULTI_KRUM:
            result = self._aggregate_multi_krum(all_contributions)
        elif strategy == AggregationStrategy.GEOMETRIC_MEDIAN:
            result = self._aggregate_geometric_median(all_contributions)
        else:
            result = self._aggregate_mean(all_contributions)

        self.aggregations_performed += 1
        self.clear()  # Clear contributions after aggregation

        return result

    def _aggregate_mean(
        self,
        contributions: List[GradientContribution],
    ) -> Dict[str, torch.Tensor]:
        """Simple mean aggregation."""
        if not contributions:
            return {}

        # Get all parameter names
        param_names = set()
        for c in contributions:
            param_names.update(c.gradients.keys())

        # Average each parameter
        result = {}
        for name in param_names:
            tensors = [
                c.gradients[name] for c in contributions
                if name in c.gradients
            ]
            if tensors:
                # Apply trust weights if configured
                if self.agg_config.use_trust_weights:
                    weights = torch.tensor([
                        c.trust_score for c in contributions
                        if name in c.gradients
                    ])
                    weights = weights / weights.sum()
                    result[name] = sum(
                        w * t for w, t in zip(weights, tensors)
                    )
                else:
                    result[name] = torch.stack(tensors).mean(dim=0)

        return result

    def _aggregate_median(
        self,
        contributions: List[GradientContribution],
    ) -> Dict[str, torch.Tensor]:
        """Coordinate-wise median aggregation."""
        if not contributions:
            return {}

        param_names = set()
        for c in contributions:
            param_names.update(c.gradients.keys())

        result = {}
        for name in param_names:
            tensors = [
                c.gradients[name] for c in contributions
                if name in c.gradients
            ]
            if tensors:
                stacked = torch.stack(tensors)
                result[name] = stacked.median(dim=0)[0]

        return result

    def _aggregate_trimmed_mean(
        self,
        contributions: List[GradientContribution],
    ) -> Dict[str, torch.Tensor]:
        """Trimmed mean aggregation (removes top/bottom percentiles)."""
        if not contributions:
            return {}

        trim_fraction = self.agg_config.trim_fraction
        n = len(contributions)

        # Number to trim from each end
        trim_count = int(n * trim_fraction)
        if 2 * trim_count >= n:
            trim_count = max(0, n // 2 - 1)

        param_names = set()
        for c in contributions:
            param_names.update(c.gradients.keys())

        result = {}
        for name in param_names:
            tensors = [
                c.gradients[name] for c in contributions
                if name in c.gradients
            ]
            if tensors:
                stacked = torch.stack(tensors)  # [n, ...]

                if trim_count > 0 and len(tensors) > 2 * trim_count:
                    # Sort along first dimension and trim
                    sorted_tensors = stacked.sort(dim=0)[0]
                    trimmed = sorted_tensors[trim_count:-trim_count]
                    result[name] = trimmed.mean(dim=0)
                else:
                    result[name] = stacked.mean(dim=0)

        return result
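
    # Worked example (annotation, not part of the packaged file): with 10
    # contributions and trim_fraction = 0.1, trim_count = 1, so at every
    # coordinate the single largest and single smallest value are dropped and
    # the remaining 8 are averaged. A single Byzantine contributor therefore
    # cannot shift the result, no matter how extreme its values are.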

    def _aggregate_krum(
        self,
        contributions: List[GradientContribution],
    ) -> Dict[str, torch.Tensor]:
        """
        Krum aggregation - select gradient closest to majority.

        Assumes at most f Byzantine nodes out of n.
        Selects the gradient with the smallest sum of distances to
        its n - f - 2 closest neighbors.
        """
        if not contributions:
            return {}

        n = len(contributions)
        f = min(self.agg_config.num_byzantine, n - 2)

        if n <= 2:
            return self._aggregate_mean(contributions)

        # Compute pairwise distances
        distances = self._compute_pairwise_distances(contributions)

        # For each gradient, sum distances to its n - f - 2 closest neighbors,
        # clamped to at least 1 so the score never degenerates to an empty sum
        scores = []
        keep = max(1, n - f - 2)

        for i in range(n):
            # Get distances from i to all others
            dists = distances[i]
            # Sort and sum closest (excluding self, which is 0)
            sorted_dists = sorted(dists)
            score = sum(sorted_dists[1:keep + 1])  # Skip self (index 0)
            scores.append(score)

        # Select gradient with minimum score
        best_idx = min(range(n), key=lambda i: scores[i])

        return contributions[best_idx].gradients
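
    # Worked example (annotation, not part of the packaged file): with n = 4
    # and num_byzantine = 1, keep = max(1, 4 - 1 - 2) = 1, so each gradient is
    # scored by the distance to its single nearest neighbor. Three honest
    # gradients clustered together score near zero, while an outlier's nearest
    # neighbor is far away, so an honest gradient is selected.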

    def _aggregate_multi_krum(
        self,
        contributions: List[GradientContribution],
    ) -> Dict[str, torch.Tensor]:
        """
        Multi-Krum aggregation - average of top-k Krum selections.
        """
        if not contributions:
            return {}

        n = len(contributions)
        k = self.agg_config.multi_krum_k
        if k <= 0 or k >= n:
            k = max(1, n - self.agg_config.num_byzantine)

        # Compute Krum scores
        f = min(self.agg_config.num_byzantine, n - 2)
        distances = self._compute_pairwise_distances(contributions)

        scores = []
        keep = max(1, n - f - 2)

        for i in range(n):
            dists = distances[i]
            sorted_dists = sorted(dists)
            score = sum(sorted_dists[1:keep + 1])
            scores.append((score, i))

        # Select top-k by score (lower is better)
        scores.sort(key=lambda x: x[0])
        selected_indices = [idx for _, idx in scores[:k]]

        # Average selected gradients
        selected = [contributions[i] for i in selected_indices]
        return self._aggregate_mean(selected)

    def _aggregate_geometric_median(
        self,
        contributions: List[GradientContribution],
    ) -> Dict[str, torch.Tensor]:
        """
        Geometric median aggregation via Weiszfeld algorithm.

        More robust than coordinate-wise median.
        """
        if not contributions:
            return {}

        # Start with mean as initial estimate
        result = self._aggregate_mean(contributions)

        # Iterative refinement
        for iteration in range(self.agg_config.max_iterations):
            prev_result = {k: v.clone() for k, v in result.items()}

            for name in result.keys():
                tensors = [
                    c.gradients[name] for c in contributions
                    if name in c.gradients
                ]
                if not tensors:
                    continue

                # Compute weighted update
                current = result[name]
                weights = []
                weighted_sum = torch.zeros_like(current)

                for t in tensors:
                    dist = (t - current).norm().item()
                    if dist > 1e-10:
                        w = 1.0 / dist
                        weights.append(w)
                        weighted_sum += w * t
                    else:
                        # Point is at current estimate
                        weights.append(1e10)
                        weighted_sum += 1e10 * t

                total_weight = sum(weights)
                if total_weight > 0:
                    result[name] = weighted_sum / total_weight

            # Check convergence
            total_change = sum(
                (result[k] - prev_result[k]).norm().item()
                for k in result.keys()
            )
            if total_change < self.agg_config.tolerance:
                break

        return result
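
    # Annotation (not part of the packaged file): each pass above is one
    # Weiszfeld update,
    #     x_{t+1} = (sum_i x_i / ||x_i - x_t||) / (sum_i 1 / ||x_i - x_t||),
    # i.e. a mean that down-weights points far from the current estimate.
    # This is what makes the geometric median robust to a minority of
    # arbitrarily placed outliers.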

    def _compute_pairwise_distances(
        self,
        contributions: List[GradientContribution],
    ) -> List[List[float]]:
        """Compute pairwise L2 distances between all gradients."""
        n = len(contributions)
        distances = [[0.0] * n for _ in range(n)]

        for i in range(n):
            for j in range(i + 1, n):
                # Sum squared distances across all parameters
                total_dist = 0.0
                for name in contributions[i].gradients.keys():
                    if name in contributions[j].gradients:
                        diff = contributions[i].gradients[name] - contributions[j].gradients[name]
                        total_dist += diff.norm().item() ** 2

                dist = math.sqrt(total_dist)
                distances[i][j] = dist
                distances[j][i] = dist

        return distances

    def get_stats(self) -> Dict[str, Any]:
        """Get aggregator statistics."""
        return {
            'aggregations_performed': self.aggregations_performed,
            'total_contributions_received': self.total_contributions_received,
            'contributions_rejected': self.contributions_rejected,
            'current_contributions': len(self.contributions),
            'validation_stats': self.validator.get_stats(),
        }


# ==================== FACTORY FUNCTIONS ====================

def create_robust_aggregator(
    strategy: str = "trimmed_mean",
    trim_fraction: float = 0.1,
    num_byzantine: int = 0,
    min_cosine_similarity: float = 0.3,
    **kwargs,
) -> RobustAggregator:
    """
    Factory function to create a robust aggregator.

    Args:
        strategy: Aggregation strategy name
        trim_fraction: Fraction to trim (for trimmed_mean)
        num_byzantine: Expected Byzantine nodes (for Krum)
        min_cosine_similarity: Minimum gradient alignment
        **kwargs: Additional config options

    Returns:
        Configured RobustAggregator
    """
    try:
        agg_strategy = AggregationStrategy(strategy)
    except ValueError:
        agg_strategy = AggregationStrategy.TRIMMED_MEAN

    agg_config = AggregationConfig(
        strategy=agg_strategy,
        trim_fraction=trim_fraction,
        num_byzantine=num_byzantine,
    )

    val_config = ValidationConfig(
        min_cosine_similarity=min_cosine_similarity,
    )

    return RobustAggregator(
        aggregation_config=agg_config,
        validation_config=val_config,
    )
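
A minimal end-to-end sketch of the API above. This is an illustration, not taken from the package: the parameter names, tensor shapes, and peer IDs are invented; only `create_robust_aggregator`, `add_contribution`, and `aggregate` come from the module.

import torch

from neuroshard.core.swarm.aggregation import create_robust_aggregator

aggregator = create_robust_aggregator(
    strategy="trimmed_mean",
    trim_fraction=0.1,
    min_cosine_similarity=0.3,
)

# Local pseudo-gradients serve as the validation reference.
local = {"layer.weight": torch.randn(4, 4), "layer.bias": torch.randn(4)}

# Two honest peers (small noise around local) and one scaled-up outlier.
peers = {
    "peer_a": {k: v + 0.01 * torch.randn_like(v) for k, v in local.items()},
    "peer_b": {k: v + 0.01 * torch.randn_like(v) for k, v in local.items()},
    "peer_c": {k: v * 100.0 for k, v in local.items()},  # fails the magnitude check
}

for peer_id, grads in peers.items():
    accepted, reason = aggregator.add_contribution(
        peer_id, grads, reference_grads=local,
    )
    print(peer_id, accepted, reason)

# Aggregate the accepted contributions together with the local gradients.
aggregated = aggregator.aggregate(local_grads=local)
print({k: tuple(v.shape) for k, v in aggregated.items()})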