nexaroa 0.0.111__py3-none-any.whl

Files changed (78)
  1. neuroshard/__init__.py +93 -0
  2. neuroshard/__main__.py +4 -0
  3. neuroshard/cli.py +466 -0
  4. neuroshard/core/__init__.py +92 -0
  5. neuroshard/core/consensus/verifier.py +252 -0
  6. neuroshard/core/crypto/__init__.py +20 -0
  7. neuroshard/core/crypto/ecdsa.py +392 -0
  8. neuroshard/core/economics/__init__.py +52 -0
  9. neuroshard/core/economics/constants.py +387 -0
  10. neuroshard/core/economics/ledger.py +2111 -0
  11. neuroshard/core/economics/market.py +975 -0
  12. neuroshard/core/economics/wallet.py +168 -0
  13. neuroshard/core/governance/__init__.py +74 -0
  14. neuroshard/core/governance/proposal.py +561 -0
  15. neuroshard/core/governance/registry.py +545 -0
  16. neuroshard/core/governance/versioning.py +332 -0
  17. neuroshard/core/governance/voting.py +453 -0
  18. neuroshard/core/model/__init__.py +30 -0
  19. neuroshard/core/model/dynamic.py +4186 -0
  20. neuroshard/core/model/llm.py +905 -0
  21. neuroshard/core/model/registry.py +164 -0
  22. neuroshard/core/model/scaler.py +387 -0
  23. neuroshard/core/model/tokenizer.py +568 -0
  24. neuroshard/core/network/__init__.py +56 -0
  25. neuroshard/core/network/connection_pool.py +72 -0
  26. neuroshard/core/network/dht.py +130 -0
  27. neuroshard/core/network/dht_plan.py +55 -0
  28. neuroshard/core/network/dht_proof_store.py +516 -0
  29. neuroshard/core/network/dht_protocol.py +261 -0
  30. neuroshard/core/network/dht_service.py +506 -0
  31. neuroshard/core/network/encrypted_channel.py +141 -0
  32. neuroshard/core/network/nat.py +201 -0
  33. neuroshard/core/network/nat_traversal.py +695 -0
  34. neuroshard/core/network/p2p.py +929 -0
  35. neuroshard/core/network/p2p_data.py +150 -0
  36. neuroshard/core/swarm/__init__.py +106 -0
  37. neuroshard/core/swarm/aggregation.py +729 -0
  38. neuroshard/core/swarm/buffers.py +643 -0
  39. neuroshard/core/swarm/checkpoint.py +709 -0
  40. neuroshard/core/swarm/compute.py +624 -0
  41. neuroshard/core/swarm/diloco.py +844 -0
  42. neuroshard/core/swarm/factory.py +1288 -0
  43. neuroshard/core/swarm/heartbeat.py +669 -0
  44. neuroshard/core/swarm/logger.py +487 -0
  45. neuroshard/core/swarm/router.py +658 -0
  46. neuroshard/core/swarm/service.py +640 -0
  47. neuroshard/core/training/__init__.py +29 -0
  48. neuroshard/core/training/checkpoint.py +600 -0
  49. neuroshard/core/training/distributed.py +1602 -0
  50. neuroshard/core/training/global_tracker.py +617 -0
  51. neuroshard/core/training/production.py +276 -0
  52. neuroshard/governance_cli.py +729 -0
  53. neuroshard/grpc_server.py +895 -0
  54. neuroshard/runner.py +3223 -0
  55. neuroshard/sdk/__init__.py +92 -0
  56. neuroshard/sdk/client.py +990 -0
  57. neuroshard/sdk/errors.py +101 -0
  58. neuroshard/sdk/types.py +282 -0
  59. neuroshard/tracker/__init__.py +0 -0
  60. neuroshard/tracker/server.py +864 -0
  61. neuroshard/ui/__init__.py +0 -0
  62. neuroshard/ui/app.py +102 -0
  63. neuroshard/ui/templates/index.html +1052 -0
  64. neuroshard/utils/__init__.py +0 -0
  65. neuroshard/utils/autostart.py +81 -0
  66. neuroshard/utils/hardware.py +121 -0
  67. neuroshard/utils/serialization.py +90 -0
  68. neuroshard/version.py +1 -0
  69. nexaroa-0.0.111.dist-info/METADATA +283 -0
  70. nexaroa-0.0.111.dist-info/RECORD +78 -0
  71. nexaroa-0.0.111.dist-info/WHEEL +5 -0
  72. nexaroa-0.0.111.dist-info/entry_points.txt +4 -0
  73. nexaroa-0.0.111.dist-info/licenses/LICENSE +190 -0
  74. nexaroa-0.0.111.dist-info/top_level.txt +2 -0
  75. protos/__init__.py +0 -0
  76. protos/neuroshard.proto +651 -0
  77. protos/neuroshard_pb2.py +160 -0
  78. protos/neuroshard_pb2_grpc.py +1298 -0
neuroshard/core/swarm/aggregation.py
@@ -0,0 +1,729 @@
+ """
+ Robust Gradient Aggregation - Byzantine-Tolerant Gradient Verification
+
+ Implements robust aggregation for DiLoCo pseudo-gradients:
+ - Statistical validation of gradients
+ - Byzantine-tolerant aggregation (Krum, Median, Trimmed Mean)
+ - Cosine similarity verification
+ - Gradient magnitude checking
+
+ Key insight: "Since DiLoCo syncs less often, bad gradients are more damaging.
+ Enhanced verification required."
+
+ Supports multiple aggregation strategies:
+ 1. Simple Mean: fast but vulnerable to Byzantine nodes
+ 2. Coordinate-wise Median: robust to outliers
+ 3. Trimmed Mean: removes top/bottom percentiles
+ 4. Krum: selects the gradient closest to the majority
+ 5. Multi-Krum: average of the top-k Krum selections
+ 6. Geometric Median: Weiszfeld iteration, robust in L2
+
+ Usage:
+     aggregator = create_robust_aggregator(strategy="trimmed_mean", trim_fraction=0.1)
+
+     # Validate and add incoming gradients in one step
+     accepted, reason = aggregator.add_contribution(
+         peer_id, peer_grads, reference_grads=local_grads
+     )
+
+     # Get aggregated result
+     aggregated = aggregator.aggregate()
+ """
+
+ import logging
+ import math
+ import time
+ from dataclasses import dataclass, field
+ from typing import Dict, List, Optional, Any, Tuple
+ from enum import Enum
+ from collections import defaultdict
+
+ import torch
+ import torch.nn.functional as F
+
+ logger = logging.getLogger(__name__)
+
+
+ class AggregationStrategy(Enum):
+     """Strategy for aggregating gradients."""
+     MEAN = "mean"                          # Simple average
+     MEDIAN = "median"                      # Coordinate-wise median
+     TRIMMED_MEAN = "trimmed_mean"          # Remove top/bottom percentiles
+     KRUM = "krum"                          # Select closest to majority
+     MULTI_KRUM = "multi_krum"              # Weighted top-k
+     GEOMETRIC_MEDIAN = "geometric_median"  # L2 geometric median
+
+
+ @dataclass
+ class ValidationConfig:
+     """Configuration for gradient validation."""
+     # Cosine similarity
+     min_cosine_similarity: float = 0.3  # Minimum alignment with local gradient
+
+     # Magnitude
+     max_magnitude_ratio: float = 10.0  # Max ratio to local gradient norm
+     min_magnitude_ratio: float = 0.1   # Min ratio to local gradient norm
+
+     # Variance
+     max_variance_ratio: float = 100.0  # Max variance ratio
+
+     # Statistical
+     zscore_threshold: float = 3.0  # Max z-score for outlier detection
+
+     # Trust
+     require_signature: bool = False  # Require cryptographic signature
+     min_trust_score: float = 0.0     # Minimum trust score for peer
+
+
+ @dataclass
+ class AggregationConfig:
+     """Configuration for aggregation."""
+     strategy: AggregationStrategy = AggregationStrategy.TRIMMED_MEAN
+
+     # Trimmed mean settings
+     trim_fraction: float = 0.1  # Fraction to trim from each end
+
+     # Krum settings
+     num_byzantine: int = 0  # Expected number of Byzantine nodes
+     multi_krum_k: int = 0   # Number of gradients to select (0 = auto)
+
+     # Geometric median settings
+     max_iterations: int = 100  # Max iterations for convergence
+     tolerance: float = 1e-6    # Convergence tolerance
+
+     # Weighting
+     use_trust_weights: bool = False      # Weight by peer trust scores
+     use_freshness_weights: bool = False  # Weight by gradient freshness
+
+
+ @dataclass
+ class GradientContribution:
+     """A gradient contribution from a peer."""
+     peer_id: str
+     gradients: Dict[str, torch.Tensor]
+     timestamp: float = field(default_factory=time.time)
+     trust_score: float = 1.0
+     signature: Optional[str] = None
+
+     # Validation results
+     is_validated: bool = False
+     validation_reason: str = ""
+
+     @property
+     def age_seconds(self) -> float:
+         return time.time() - self.timestamp
+
+
+ class GradientValidator:
+     """
+     Validates incoming gradients against local reference.
+
+     Performs multiple checks:
+     1. Cosine similarity (direction alignment)
+     2. Magnitude ratio (scale)
+     3. Variance ratio (distribution)
+     4. Z-score outlier detection
+     """
+
+     def __init__(self, config: Optional[ValidationConfig] = None):
+         self.config = config or ValidationConfig()
+
+         # Stats
+         self.validations_performed = 0
+         self.validations_passed = 0
+         self.validations_failed = 0
+         self.failure_reasons: Dict[str, int] = defaultdict(int)
+
+     def validate(
+         self,
+         submitted_grads: Dict[str, torch.Tensor],
+         reference_grads: Dict[str, torch.Tensor],
+         peer_trust: float = 1.0,
+     ) -> Tuple[bool, str]:
+         """
+         Validate submitted gradients against reference.
+
+         Args:
+             submitted_grads: Gradients from peer
+             reference_grads: Local reference gradients
+             peer_trust: Trust score of submitting peer
+
+         Returns:
+             (is_valid, reason) tuple
+         """
+         self.validations_performed += 1
+
+         # Check trust score
+         if peer_trust < self.config.min_trust_score:
+             self._record_failure("low_trust")
+             return False, f"Trust score {peer_trust} below minimum {self.config.min_trust_score}"
+
+         # Validate each parameter
+         for name, submitted in submitted_grads.items():
+             if name not in reference_grads:
+                 continue
+
+             reference = reference_grads[name]
+
+             # Check 1: Cosine similarity (direction)
+             is_valid, reason = self._check_cosine_similarity(submitted, reference, name)
+             if not is_valid:
+                 self._record_failure("cosine_similarity")
+                 return False, reason
+
+             # Check 2: Magnitude ratio (scale)
+             is_valid, reason = self._check_magnitude(submitted, reference, name)
+             if not is_valid:
+                 self._record_failure("magnitude")
+                 return False, reason
+
+             # Check 3: Variance ratio (distribution)
+             is_valid, reason = self._check_variance(submitted, reference, name)
+             if not is_valid:
+                 self._record_failure("variance")
+                 return False, reason
+
+         self.validations_passed += 1
+         return True, "Validation passed"
+
+     def _check_cosine_similarity(
+         self,
+         submitted: torch.Tensor,
+         reference: torch.Tensor,
+         param_name: str,
+     ) -> Tuple[bool, str]:
+         """Check cosine similarity between gradients."""
+         # Flatten for comparison
+         submitted_flat = submitted.flatten()
+         reference_flat = reference.flatten()
+
+         # Handle zero vectors
+         submitted_norm = submitted_flat.norm()
+         reference_norm = reference_flat.norm()
+
+         if submitted_norm == 0 or reference_norm == 0:
+             # Zero gradient - suspicious but might be valid
+             if submitted_norm == 0 and reference_norm == 0:
+                 return True, "Both zero"
+             return True, "One zero vector - allowing"
+
+         # Compute cosine similarity
+         cosine_sim = F.cosine_similarity(
+             submitted_flat.unsqueeze(0),
+             reference_flat.unsqueeze(0)
+         ).item()
+
+         if cosine_sim < self.config.min_cosine_similarity:
+             return False, (
+                 f"Cosine similarity {cosine_sim:.3f} below threshold "
+                 f"{self.config.min_cosine_similarity} for {param_name}"
+             )
+
+         return True, f"cosine={cosine_sim:.3f}"
+
+     def _check_magnitude(
+         self,
+         submitted: torch.Tensor,
+         reference: torch.Tensor,
+         param_name: str,
+     ) -> Tuple[bool, str]:
+         """Check magnitude ratio of gradients."""
+         submitted_norm = submitted.norm().item()
+         reference_norm = reference.norm().item()
+
+         if reference_norm == 0:
+             return True, "Reference norm is zero"
+
+         ratio = submitted_norm / reference_norm
+
+         if ratio > self.config.max_magnitude_ratio:
+             return False, (
+                 f"Magnitude ratio {ratio:.2f} exceeds max "
+                 f"{self.config.max_magnitude_ratio} for {param_name}"
+             )
+
+         if ratio < self.config.min_magnitude_ratio:
+             return False, (
+                 f"Magnitude ratio {ratio:.2f} below min "
+                 f"{self.config.min_magnitude_ratio} for {param_name}"
+             )
+
+         return True, f"magnitude_ratio={ratio:.2f}"
+
+     def _check_variance(
+         self,
+         submitted: torch.Tensor,
+         reference: torch.Tensor,
+         param_name: str,
+     ) -> Tuple[bool, str]:
+         """Check variance ratio of gradients."""
+         submitted_var = submitted.var().item()
+         reference_var = reference.var().item()
+
+         if reference_var == 0:
+             return True, "Reference variance is zero"
+
+         ratio = submitted_var / reference_var
+
+         if ratio > self.config.max_variance_ratio:
+             return False, (
+                 f"Variance ratio {ratio:.2f} exceeds max "
+                 f"{self.config.max_variance_ratio} for {param_name}"
+             )
+
+         return True, f"variance_ratio={ratio:.2f}"
+
+     def _record_failure(self, reason: str):
+         """Record a validation failure."""
+         self.validations_failed += 1
+         self.failure_reasons[reason] += 1
+
+     def get_stats(self) -> Dict[str, Any]:
+         """Get validation statistics."""
+         success_rate = (
+             self.validations_passed / self.validations_performed
+             if self.validations_performed > 0 else 0.0
+         )
+         return {
+             'validations_performed': self.validations_performed,
+             'validations_passed': self.validations_passed,
+             'validations_failed': self.validations_failed,
+             'success_rate': success_rate,
+             'failure_reasons': dict(self.failure_reasons),
+         }
+
+
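
As a quick illustration of the checks above, this hedged sketch (toy tensors, default thresholds, assuming the classes above are in scope) exercises GradientValidator directly; in the package it is normally driven through RobustAggregator.add_contribution:

    import torch

    validator = GradientValidator(ValidationConfig(min_cosine_similarity=0.3))

    reference = {"layer.weight": torch.ones(8)}
    aligned = {"layer.weight": 1.5 * torch.ones(8)}  # same direction, 1.5x scale
    opposed = {"layer.weight": -torch.ones(8)}       # cosine similarity -1.0

    print(validator.validate(aligned, reference))    # (True, 'Validation passed')
    print(validator.validate(opposed, reference))    # (False, 'Cosine similarity -1.000 below threshold ...')
    print(validator.get_stats()["failure_reasons"])  # {'cosine_similarity': 1}
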
+ class RobustAggregator:
+     """
+     Byzantine-tolerant gradient aggregator.
+
+     Supports multiple aggregation strategies for robustness
+     against malicious or faulty nodes.
+     """
+
+     def __init__(
+         self,
+         aggregation_config: Optional[AggregationConfig] = None,
+         validation_config: Optional[ValidationConfig] = None,
+     ):
+         self.agg_config = aggregation_config or AggregationConfig()
+         self.validator = GradientValidator(validation_config)
+
+         # Contributions
+         self.contributions: List[GradientContribution] = []
+         self._lock = None  # For thread safety if needed
+
+         # Stats
+         self.aggregations_performed = 0
+         self.total_contributions_received = 0
+         self.contributions_rejected = 0
+
+     def clear(self):
+         """Clear all contributions."""
+         self.contributions.clear()
+
+     def add_contribution(
+         self,
+         peer_id: str,
+         gradients: Dict[str, torch.Tensor],
+         reference_grads: Optional[Dict[str, torch.Tensor]] = None,
+         trust_score: float = 1.0,
+         validate: bool = True,
+     ) -> Tuple[bool, str]:
+         """
+         Add a gradient contribution from a peer.
+
+         Args:
+             peer_id: ID of contributing peer
+             gradients: Gradient tensors from peer
+             reference_grads: Local reference for validation
+             trust_score: Trust score of peer
+             validate: Whether to validate before adding
+
+         Returns:
+             (accepted, reason) tuple
+         """
+         self.total_contributions_received += 1
+
+         # Create contribution
+         contribution = GradientContribution(
+             peer_id=peer_id,
+             gradients=gradients,
+             trust_score=trust_score,
+         )
+
+         # Validate if reference provided
+         if validate and reference_grads is not None:
+             is_valid, reason = self.validator.validate(
+                 gradients,
+                 reference_grads,
+                 trust_score
+             )
+
+             contribution.is_validated = True
+             contribution.validation_reason = reason
+
+             if not is_valid:
+                 self.contributions_rejected += 1
+                 logger.warning(f"Rejected gradient from {peer_id}: {reason}")
+                 return False, reason
+
+         self.contributions.append(contribution)
+         return True, "Accepted"
+
+     def aggregate(
+         self,
+         local_grads: Optional[Dict[str, torch.Tensor]] = None,
+     ) -> Dict[str, torch.Tensor]:
+         """
+         Aggregate all contributions using the configured strategy.
+
+         Args:
+             local_grads: Optional local gradients to include
+
+         Returns:
+             Aggregated gradients
+         """
+         if not self.contributions and local_grads is None:
+             return {}
+
+         # Add local as contribution if provided
+         all_contributions = list(self.contributions)
+         if local_grads is not None:
+             all_contributions.append(GradientContribution(
+                 peer_id="local",
+                 gradients=local_grads,
+                 trust_score=1.0,
+                 is_validated=True,
+             ))
+
+         # Select aggregation method
+         strategy = self.agg_config.strategy
+
+         if strategy == AggregationStrategy.MEAN:
+             result = self._aggregate_mean(all_contributions)
+         elif strategy == AggregationStrategy.MEDIAN:
+             result = self._aggregate_median(all_contributions)
+         elif strategy == AggregationStrategy.TRIMMED_MEAN:
+             result = self._aggregate_trimmed_mean(all_contributions)
+         elif strategy == AggregationStrategy.KRUM:
+             result = self._aggregate_krum(all_contributions)
+         elif strategy == AggregationStrategy.MULTI_KRUM:
+             result = self._aggregate_multi_krum(all_contributions)
+         elif strategy == AggregationStrategy.GEOMETRIC_MEDIAN:
+             result = self._aggregate_geometric_median(all_contributions)
+         else:
+             result = self._aggregate_mean(all_contributions)
+
+         self.aggregations_performed += 1
+         self.clear()  # Clear contributions after aggregation
+
+         return result
+
+     def _aggregate_mean(
+         self,
+         contributions: List[GradientContribution],
+     ) -> Dict[str, torch.Tensor]:
+         """Simple mean aggregation."""
+         if not contributions:
+             return {}
+
+         # Get all parameter names
+         param_names = set()
+         for c in contributions:
+             param_names.update(c.gradients.keys())
+
+         # Average each parameter
+         result = {}
+         for name in param_names:
+             tensors = [
+                 c.gradients[name] for c in contributions
+                 if name in c.gradients
+             ]
+             if tensors:
+                 # Apply trust weights if configured
+                 if self.agg_config.use_trust_weights:
+                     weights = torch.tensor([
+                         c.trust_score for c in contributions
+                         if name in c.gradients
+                     ])
+                     weights = weights / weights.sum()
+                     result[name] = sum(
+                         w * t for w, t in zip(weights, tensors)
+                     )
+                 else:
+                     result[name] = torch.stack(tensors).mean(dim=0)
+
+         return result
+
+     def _aggregate_median(
+         self,
+         contributions: List[GradientContribution],
+     ) -> Dict[str, torch.Tensor]:
+         """Coordinate-wise median aggregation."""
+         if not contributions:
+             return {}
+
+         param_names = set()
+         for c in contributions:
+             param_names.update(c.gradients.keys())
+
+         result = {}
+         for name in param_names:
+             tensors = [
+                 c.gradients[name] for c in contributions
+                 if name in c.gradients
+             ]
+             if tensors:
+                 stacked = torch.stack(tensors)
+                 result[name] = stacked.median(dim=0)[0]
+
+         return result
+
+     def _aggregate_trimmed_mean(
+         self,
+         contributions: List[GradientContribution],
+     ) -> Dict[str, torch.Tensor]:
+         """Trimmed mean aggregation (removes top/bottom percentiles)."""
+         if not contributions:
+             return {}
+
+         trim_fraction = self.agg_config.trim_fraction
+         n = len(contributions)
+
+         # Number to trim from each end
+         trim_count = int(n * trim_fraction)
+         if 2 * trim_count >= n:
+             trim_count = max(0, n // 2 - 1)
+
+         param_names = set()
+         for c in contributions:
+             param_names.update(c.gradients.keys())
+
+         result = {}
+         for name in param_names:
+             tensors = [
+                 c.gradients[name] for c in contributions
+                 if name in c.gradients
+             ]
+             if tensors:
+                 stacked = torch.stack(tensors)  # [n, ...]
+
+                 if trim_count > 0 and len(tensors) > 2 * trim_count:
+                     # Sort along first dimension and trim
+                     sorted_tensors = stacked.sort(dim=0)[0]
+                     trimmed = sorted_tensors[trim_count:-trim_count]
+                     result[name] = trimmed.mean(dim=0)
+                 else:
+                     result[name] = stacked.mean(dim=0)
+
+         return result
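
For concreteness: with trim_fraction=0.1 and ten contributions, trim_count = int(10 * 0.1) = 1, so for every coordinate the single largest and smallest values are dropped and the remaining eight averaged. The same computation on a toy tensor (values invented):

    import torch

    values = torch.tensor([1.0, 1.1, 0.9, 1.0, 1.2, 0.8, 1.0, 1.1, 0.9, 50.0])
    trim_count = int(len(values) * 0.1)                 # 1
    trimmed = values.sort()[0][trim_count:-trim_count]  # drops 0.8 and 50.0
    print(trimmed.mean())                               # tensor(1.0250)
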
+
+     def _aggregate_krum(
+         self,
+         contributions: List[GradientContribution],
+     ) -> Dict[str, torch.Tensor]:
+         """
+         Krum aggregation - select the gradient closest to the majority.
+
+         Assumes at most f Byzantine nodes out of n.
+         Selects the gradient with the smallest sum of distances to
+         its n - f - 2 closest neighbors.
+         """
+         if not contributions:
+             return {}
+
+         n = len(contributions)
+         f = min(self.agg_config.num_byzantine, n - 2)
+
+         if n <= 2:
+             return self._aggregate_mean(contributions)
+
+         # Compute pairwise distances
+         distances = self._compute_pairwise_distances(contributions)
+
+         # For each gradient, sum distances to its n - f - 2 closest neighbors
+         scores = []
+         keep = max(1, n - f - 2)  # guard against keep == 0, matching Multi-Krum below
+
+         for i in range(n):
+             # Get distances from i to all others
+             dists = distances[i]
+             # Sort and sum closest (excluding self, which is 0 at index 0)
+             sorted_dists = sorted(dists)
+             score = sum(sorted_dists[1:keep + 1])
+             scores.append(score)
+
+         # Select the gradient with minimum score
+         best_idx = min(range(n), key=lambda i: scores[i])
+
+         return contributions[best_idx].gradients
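
A small numeric sketch of Krum in action (scalar "gradients" and peer IDs invented for readability): with n=5 and f=1, each candidate is scored by the sum of distances to its n - f - 2 = 2 nearest neighbors, so the outlier's huge distances keep it from ever being selected:

    import torch

    agg = RobustAggregator(AggregationConfig(
        strategy=AggregationStrategy.KRUM,
        num_byzantine=1,
    ))
    for peer, v in [("a", 1.0), ("b", 1.1), ("c", 0.9), ("d", 1.0), ("e", 10.0)]:
        agg.add_contribution(peer, {"w": torch.tensor([v])})
    print(agg.aggregate())  # {'w': tensor([1.])} - peer "e" is never chosen
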
+
+     def _aggregate_multi_krum(
+         self,
+         contributions: List[GradientContribution],
+     ) -> Dict[str, torch.Tensor]:
+         """
+         Multi-Krum aggregation - average of top-k Krum selections.
+         """
+         if not contributions:
+             return {}
+
+         n = len(contributions)
+         k = self.agg_config.multi_krum_k
+         if k <= 0 or k >= n:
+             k = max(1, n - self.agg_config.num_byzantine)
+
+         # Compute Krum scores
+         f = min(self.agg_config.num_byzantine, n - 2)
+         distances = self._compute_pairwise_distances(contributions)
+
+         scores = []
+         keep = max(1, n - f - 2)
+
+         for i in range(n):
+             dists = distances[i]
+             sorted_dists = sorted(dists)
+             score = sum(sorted_dists[1:keep + 1])
+             scores.append((score, i))
+
+         # Select top-k by score (lower is better)
+         scores.sort(key=lambda x: x[0])
+         selected_indices = [idx for _, idx in scores[:k]]
+
+         # Average selected gradients
+         selected = [contributions[i] for i in selected_indices]
+         return self._aggregate_mean(selected)
+
+     def _aggregate_geometric_median(
+         self,
+         contributions: List[GradientContribution],
+     ) -> Dict[str, torch.Tensor]:
+         """
+         Geometric median aggregation via Weiszfeld algorithm.
+
+         More robust than coordinate-wise median.
+         """
+         if not contributions:
+             return {}
+
+         # Start with mean as initial estimate
+         result = self._aggregate_mean(contributions)
+
+         # Iterative refinement
+         for iteration in range(self.agg_config.max_iterations):
+             prev_result = {k: v.clone() for k, v in result.items()}
+
+             for name in result.keys():
+                 tensors = [
+                     c.gradients[name] for c in contributions
+                     if name in c.gradients
+                 ]
+                 if not tensors:
+                     continue
+
+                 # Compute weighted update
+                 current = result[name]
+                 weights = []
+                 weighted_sum = torch.zeros_like(current)
+
+                 for t in tensors:
+                     dist = (t - current).norm().item()
+                     if dist > 1e-10:
+                         w = 1.0 / dist
+                         weights.append(w)
+                         weighted_sum += w * t
+                     else:
+                         # Point is at the current estimate
+                         weights.append(1e10)
+                         weighted_sum += 1e10 * t
+
+                 total_weight = sum(weights)
+                 if total_weight > 0:
+                     result[name] = weighted_sum / total_weight
+
+             # Check convergence
+             total_change = sum(
+                 (result[k] - prev_result[k]).norm().item()
+                 for k in result.keys()
+             )
+             if total_change < self.agg_config.tolerance:
+                 break
+
+         return result
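
The loop above is the classic Weiszfeld iteration: each pass re-weights every contribution by the inverse of its distance to the current estimate, y <- (sum_i x_i / ||x_i - y||) / (sum_i 1 / ||x_i - y||), so far-away (likely Byzantine) points get very little say. A toy check with invented values:

    import torch

    agg = RobustAggregator(AggregationConfig(
        strategy=AggregationStrategy.GEOMETRIC_MEDIAN,
    ))
    for peer, v in [("a", 1.0), ("b", 1.0), ("c", 1.2), ("d", 100.0)]:
        agg.add_contribution(peer, {"w": torch.tensor([v])})
    print(agg.aggregate())  # {'w': tensor([...])} - roughly 1.0-1.2, not the mean 25.8
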
+
+     def _compute_pairwise_distances(
+         self,
+         contributions: List[GradientContribution],
+     ) -> List[List[float]]:
+         """Compute pairwise L2 distances between all gradients."""
+         n = len(contributions)
+         distances = [[0.0] * n for _ in range(n)]
+
+         for i in range(n):
+             for j in range(i + 1, n):
+                 # Sum squared distances across all parameters
+                 total_dist = 0.0
+                 for name in contributions[i].gradients.keys():
+                     if name in contributions[j].gradients:
+                         diff = contributions[i].gradients[name] - contributions[j].gradients[name]
+                         total_dist += diff.norm().item() ** 2
+
+                 dist = math.sqrt(total_dist)
+                 distances[i][j] = dist
+                 distances[j][i] = dist
+
+         return distances
+
+     def get_stats(self) -> Dict[str, Any]:
+         """Get aggregator statistics."""
+         return {
+             'aggregations_performed': self.aggregations_performed,
+             'total_contributions_received': self.total_contributions_received,
+             'contributions_rejected': self.contributions_rejected,
+             'current_contributions': len(self.contributions),
+             'validation_stats': self.validator.get_stats(),
+         }
+
+
+ # ==================== FACTORY FUNCTIONS ====================
+
+ def create_robust_aggregator(
+     strategy: str = "trimmed_mean",
+     trim_fraction: float = 0.1,
+     num_byzantine: int = 0,
+     min_cosine_similarity: float = 0.3,
+     **kwargs,
+ ) -> RobustAggregator:
+     """
+     Factory function to create a robust aggregator.
+
+     Args:
+         strategy: Aggregation strategy name
+         trim_fraction: Fraction to trim (for trimmed_mean)
+         num_byzantine: Expected Byzantine nodes (for Krum)
+         min_cosine_similarity: Minimum gradient alignment
+         **kwargs: Additional config options (accepted for forward
+             compatibility; currently unused)
+
+     Returns:
+         Configured RobustAggregator
+     """
+     try:
+         agg_strategy = AggregationStrategy(strategy)
+     except ValueError:
+         agg_strategy = AggregationStrategy.TRIMMED_MEAN
+
+     agg_config = AggregationConfig(
+         strategy=agg_strategy,
+         trim_fraction=trim_fraction,
+         num_byzantine=num_byzantine,
+     )
+
+     val_config = ValidationConfig(
+         min_cosine_similarity=min_cosine_similarity,
+     )
+
+     return RobustAggregator(
+         aggregation_config=agg_config,
+         validation_config=val_config,
+     )
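
Putting it together, a minimal end-to-end sketch (peer IDs and tensors invented; assumes the module above is importable):

    import torch

    aggregator = create_robust_aggregator(
        strategy="trimmed_mean",
        trim_fraction=0.1,
        min_cosine_similarity=0.3,
    )

    # Local pseudo-gradients serve as the validation reference.
    local_grads = {"layer.weight": torch.randn(16)}

    for peer_id in ("peer-1", "peer-2", "peer-3"):
        peer_grads = {"layer.weight": local_grads["layer.weight"] + 0.01 * torch.randn(16)}
        accepted, reason = aggregator.add_contribution(
            peer_id, peer_grads, reference_grads=local_grads,
        )

    # Fold the local gradients into the final robust aggregate.
    aggregated = aggregator.aggregate(local_grads=local_grads)
    print(aggregated["layer.weight"].shape)  # torch.Size([16])
    print(aggregator.get_stats())
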