nexaroa 0.0.111__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78) hide show
  1. neuroshard/__init__.py +93 -0
  2. neuroshard/__main__.py +4 -0
  3. neuroshard/cli.py +466 -0
  4. neuroshard/core/__init__.py +92 -0
  5. neuroshard/core/consensus/verifier.py +252 -0
  6. neuroshard/core/crypto/__init__.py +20 -0
  7. neuroshard/core/crypto/ecdsa.py +392 -0
  8. neuroshard/core/economics/__init__.py +52 -0
  9. neuroshard/core/economics/constants.py +387 -0
  10. neuroshard/core/economics/ledger.py +2111 -0
  11. neuroshard/core/economics/market.py +975 -0
  12. neuroshard/core/economics/wallet.py +168 -0
  13. neuroshard/core/governance/__init__.py +74 -0
  14. neuroshard/core/governance/proposal.py +561 -0
  15. neuroshard/core/governance/registry.py +545 -0
  16. neuroshard/core/governance/versioning.py +332 -0
  17. neuroshard/core/governance/voting.py +453 -0
  18. neuroshard/core/model/__init__.py +30 -0
  19. neuroshard/core/model/dynamic.py +4186 -0
  20. neuroshard/core/model/llm.py +905 -0
  21. neuroshard/core/model/registry.py +164 -0
  22. neuroshard/core/model/scaler.py +387 -0
  23. neuroshard/core/model/tokenizer.py +568 -0
  24. neuroshard/core/network/__init__.py +56 -0
  25. neuroshard/core/network/connection_pool.py +72 -0
  26. neuroshard/core/network/dht.py +130 -0
  27. neuroshard/core/network/dht_plan.py +55 -0
  28. neuroshard/core/network/dht_proof_store.py +516 -0
  29. neuroshard/core/network/dht_protocol.py +261 -0
  30. neuroshard/core/network/dht_service.py +506 -0
  31. neuroshard/core/network/encrypted_channel.py +141 -0
  32. neuroshard/core/network/nat.py +201 -0
  33. neuroshard/core/network/nat_traversal.py +695 -0
  34. neuroshard/core/network/p2p.py +929 -0
  35. neuroshard/core/network/p2p_data.py +150 -0
  36. neuroshard/core/swarm/__init__.py +106 -0
  37. neuroshard/core/swarm/aggregation.py +729 -0
  38. neuroshard/core/swarm/buffers.py +643 -0
  39. neuroshard/core/swarm/checkpoint.py +709 -0
  40. neuroshard/core/swarm/compute.py +624 -0
  41. neuroshard/core/swarm/diloco.py +844 -0
  42. neuroshard/core/swarm/factory.py +1288 -0
  43. neuroshard/core/swarm/heartbeat.py +669 -0
  44. neuroshard/core/swarm/logger.py +487 -0
  45. neuroshard/core/swarm/router.py +658 -0
  46. neuroshard/core/swarm/service.py +640 -0
  47. neuroshard/core/training/__init__.py +29 -0
  48. neuroshard/core/training/checkpoint.py +600 -0
  49. neuroshard/core/training/distributed.py +1602 -0
  50. neuroshard/core/training/global_tracker.py +617 -0
  51. neuroshard/core/training/production.py +276 -0
  52. neuroshard/governance_cli.py +729 -0
  53. neuroshard/grpc_server.py +895 -0
  54. neuroshard/runner.py +3223 -0
  55. neuroshard/sdk/__init__.py +92 -0
  56. neuroshard/sdk/client.py +990 -0
  57. neuroshard/sdk/errors.py +101 -0
  58. neuroshard/sdk/types.py +282 -0
  59. neuroshard/tracker/__init__.py +0 -0
  60. neuroshard/tracker/server.py +864 -0
  61. neuroshard/ui/__init__.py +0 -0
  62. neuroshard/ui/app.py +102 -0
  63. neuroshard/ui/templates/index.html +1052 -0
  64. neuroshard/utils/__init__.py +0 -0
  65. neuroshard/utils/autostart.py +81 -0
  66. neuroshard/utils/hardware.py +121 -0
  67. neuroshard/utils/serialization.py +90 -0
  68. neuroshard/version.py +1 -0
  69. nexaroa-0.0.111.dist-info/METADATA +283 -0
  70. nexaroa-0.0.111.dist-info/RECORD +78 -0
  71. nexaroa-0.0.111.dist-info/WHEEL +5 -0
  72. nexaroa-0.0.111.dist-info/entry_points.txt +4 -0
  73. nexaroa-0.0.111.dist-info/licenses/LICENSE +190 -0
  74. nexaroa-0.0.111.dist-info/top_level.txt +2 -0
  75. protos/__init__.py +0 -0
  76. protos/neuroshard.proto +651 -0
  77. protos/neuroshard_pb2.py +160 -0
  78. protos/neuroshard_pb2_grpc.py +1298 -0
@@ -0,0 +1,658 @@
1
+ """
2
+ Swarm Router - Fault-Tolerant Multipath Routing for NeuroShard
3
+
4
+ Implements the "Swarm" network layer with:
5
+ - K-candidate routing (multiple peers per layer range)
6
+ - Automatic failover within 200ms
7
+ - Weighted scoring (latency + queue_depth)
8
+ - Probabilistic send with retry
9
+
10
+ Key Directive: "If Node A hangs, the packet must automatically flow to Node B
11
+ without crashing the run."
12
+
13
+ Reference: SWARM Parallelism papers for randomized pipeline construction.
14
+ """
15
+
16
+ import asyncio
17
+ import grpc
18
+ import logging
19
+ import random
20
+ import time
21
+ import torch
22
+ from dataclasses import dataclass, field
23
+ from typing import Dict, List, Optional, Tuple, Any
24
+ from collections import defaultdict
25
+
26
+ # Try importing proto definitions
27
+ try:
28
+ from protos import neuroshard_pb2
29
+ from protos import neuroshard_pb2_grpc
30
+ GRPC_AVAILABLE = True
31
+ except ImportError:
32
+ GRPC_AVAILABLE = False
33
+
34
+ logger = logging.getLogger(__name__)
35
+
36
+
37
@dataclass
class PeerCandidate:
    """
    A candidate peer for routing.

    Candidates are ranked by ``score()``: a weighted combination of latency
    and queue depth plus penalties for recent failures and stale heartbeats.
    Lower score = better candidate.
    """
    node_id: str
    grpc_addr: str
    layer_range: Tuple[int, int]  # (start_layer, end_layer), end exclusive

    # Performance metrics (rolling averages)
    latency_ms: float = 100.0  # Default 100ms until a real sample arrives
    queue_depth: int = 0       # Current work queue size

    # Health tracking
    last_heartbeat: float = field(default_factory=time.time)
    consecutive_failures: int = 0
    total_requests: int = 0
    successful_requests: int = 0

    # Capacity info from heartbeat
    available_memory_mb: int = 0
    gpu_utilization: float = 0.0
    is_accepting_activations: bool = True

    def score(self, weights: Optional[Dict[str, float]] = None) -> float:
        """
        Compute weighted score for peer selection.

        Lower score = better candidate.

        Default weights prioritize queue depth (60%) over latency (40%).

        Args:
            weights: Optional mapping with "latency" and/or "queue" keys.
                Keys that are missing fall back to the default weight.

        Returns:
            Base score in [0, 1] (for weights summing to 1) plus up to 1.0
            of failure penalty and up to 1.0 of staleness penalty.
        """
        defaults = {"latency": 0.4, "queue": 0.6}
        # Merge over the defaults so a partial mapping does not raise KeyError.
        w = {**defaults, **weights} if weights else defaults

        # Normalize: latency 0-500ms maps to 0-1, queue 0-100 maps to 0-1
        latency_norm = min(1.0, self.latency_ms / 500.0)
        queue_norm = min(1.0, self.queue_depth / 100.0)

        total = w["latency"] * latency_norm + w["queue"] * queue_norm

        # Penalty for recent failures (exponential backoff, capped at 1.0).
        # 0.1 * 2**4 already exceeds the cap, so clamping the exponent at 4
        # yields the same value while avoiding an OverflowError from
        # converting an enormous integer to float on long failure streaks.
        if self.consecutive_failures > 0:
            exponent = min(self.consecutive_failures, 4)
            total += min(1.0, 0.1 * (2 ** exponent))

        # Penalty for stale heartbeat (>15s), saturating at 60s
        staleness = time.time() - self.last_heartbeat
        if staleness > 15.0:
            total += min(1.0, staleness / 60.0)

        return total

    @property
    def success_rate(self) -> float:
        """Fraction of recorded requests that succeeded (1.0 if none yet)."""
        if self.total_requests == 0:
            return 1.0  # Assume good until proven otherwise
        return self.successful_requests / self.total_requests

    def covers_layer(self, layer: int) -> bool:
        """Return True if ``layer`` lies in the half-open ``layer_range``."""
        start, end = self.layer_range[0], self.layer_range[1]
        return start <= layer < end

    def update_latency(self, latency_ms: float, alpha: float = 0.3):
        """Fold a new latency sample into the exponential moving average."""
        self.latency_ms = alpha * latency_ms + (1 - alpha) * self.latency_ms

    def record_success(self, latency_ms: float):
        """Record a successful request and reset the failure streak."""
        self.total_requests += 1
        self.successful_requests += 1
        self.consecutive_failures = 0
        self.update_latency(latency_ms)

    def record_failure(self):
        """Record a failed request and extend the failure streak."""
        self.total_requests += 1
        self.consecutive_failures += 1
121
+
122
@dataclass
class RoutingResult:
    """
    Outcome of one routing operation.

    Only ``success`` is required; every other field has a default so a
    failure can be reported with just an error message.
    """
    success: bool                     # True if some peer accepted the send
    peer_used: Optional[str] = None   # node_id of the peer that handled it
    latency_ms: float = 0.0           # Latency of the successful attempt
    attempts: int = 0                 # Number of candidates tried
    error: Optional[str] = None       # Failure description, if any
    response_data: Any = None         # Response object returned by the peer
131
+
132
+
133
class SwarmRouter:
    """
    Fault-tolerant multipath router for distributed inference/training.

    Key Features:
    - Returns K candidates per layer range (not just one)
    - Automatic failover: if primary doesn't ACK in 200ms, try secondary
    - Weighted scoring: 0.4 * latency + 0.6 * queue_depth
    - Probabilistic load balancing for similarly-scored candidates

    State kept by the router:
    - ``peer_stats``: node_id -> PeerCandidate, fed by heartbeats
    - ``_channels`` / ``_stubs``: gRPC channel and stub caches, both keyed
      by the peer's ``grpc_addr`` string (NOT by node_id)
    """

    ACK_TIMEOUT_MS = 200  # Failover if no ACK in 200ms
    K_CANDIDATES = 3  # Return top-K candidates per layer range
    CACHE_TTL_SECONDS = 30  # Peer-info cache TTL (not referenced in this class)

    def __init__(
        self,
        dht_protocol: Optional[Any] = None,
        layer_pool: Optional[Any] = None,
        tracker_url: Optional[str] = None,
    ):
        """
        Initialize SwarmRouter.

        Args:
            dht_protocol: DHT protocol for peer discovery
            layer_pool: DynamicLayerPool for local layer info
            tracker_url: Fallback tracker URL for peer discovery
        """
        self.dht = dht_protocol
        self.layer_pool = layer_pool
        self.tracker_url = tracker_url

        # Peer state cache (node_id -> PeerCandidate)
        self.peer_stats: Dict[str, PeerCandidate] = {}
        self._peer_lock = asyncio.Lock()

        # gRPC connection pool (grpc_addr -> channel / stub)
        self._channels: Dict[str, grpc.aio.Channel] = {}
        self._stubs: Dict[str, Any] = {}

        # Routing metrics
        self.total_routes = 0
        self.successful_routes = 0
        self.failover_count = 0

        # Background tasks
        self._cleanup_task: Optional[asyncio.Task] = None

    # Property aliases for compatibility with SwarmServiceMixin
    @property
    def total_sends(self) -> int:
        """Alias for total_routes."""
        return self.total_routes

    @property
    def successful_sends(self) -> int:
        """Alias for successful_routes."""
        return self.successful_routes

    async def start(self):
        """Start background maintenance tasks (no-op if already running)."""
        # A second start() must not orphan the first cleanup task.
        if self._cleanup_task is not None and not self._cleanup_task.done():
            return
        self._cleanup_task = asyncio.create_task(self._cleanup_loop())
        logger.info("SwarmRouter started")

    async def stop(self):
        """Cancel the cleanup task and close every cached gRPC channel."""
        if self._cleanup_task:
            self._cleanup_task.cancel()
            try:
                await self._cleanup_task
            except asyncio.CancelledError:
                pass
            self._cleanup_task = None

        # Close all gRPC channels
        for channel in self._channels.values():
            await channel.close()
        self._channels.clear()
        self._stubs.clear()

        logger.info("SwarmRouter stopped")

    async def _cleanup_loop(self):
        """Periodically cleanup stale peers and connections."""
        while True:
            try:
                await asyncio.sleep(60)  # Every minute
                await self._cleanup_stale_peers()
            except asyncio.CancelledError:
                break
            except Exception as e:
                # Keep the loop alive; one failed sweep should not end it.
                logger.error(f"Cleanup error: {e}")

    async def _cleanup_stale_peers(self):
        """
        Remove peers that haven't sent heartbeat in >60s.

        Channels and stubs are cached by gRPC address, so the address of each
        evicted peer is used to release them. A channel is kept open when a
        remaining peer still advertises the same address.
        """
        now = time.time()
        stale_threshold = 60.0

        async with self._peer_lock:
            stale_peers = [
                node_id for node_id, peer in self.peer_stats.items()
                if (now - peer.last_heartbeat) > stale_threshold
            ]

            stale_addrs = {
                self.peer_stats.pop(node_id).grpc_addr for node_id in stale_peers
            }
            live_addrs = {peer.grpc_addr for peer in self.peer_stats.values()}

            for addr in stale_addrs - live_addrs:
                self._stubs.pop(addr, None)
                channel = self._channels.pop(addr, None)
                if channel is not None:
                    await channel.close()

        if stale_peers:
            logger.info(f"Cleaned up {len(stale_peers)} stale peers")

    def register_peer(self, peer: PeerCandidate):
        """
        Register or replace a peer entry, keyed by its node_id.

        Called by SwarmHeartbeatService when receiving capacity broadcasts.
        """
        self.peer_stats[peer.node_id] = peer

    def update_peer_from_heartbeat(
        self,
        node_id: str,
        grpc_addr: str,
        layer_range: Tuple[int, int],
        queue_depth: int,
        available_memory_mb: int,
        gpu_utilization: float,
        is_accepting: bool,
    ):
        """
        Update peer info from heartbeat message.

        A known peer keeps its latency and success history; everything the
        heartbeat reports (including address and layer range, which may
        change between heartbeats) is refreshed. An unknown peer is created
        with default performance metrics.
        """
        peer = self.peer_stats.get(node_id)
        if peer is None:
            # New peer
            self.peer_stats[node_id] = PeerCandidate(
                node_id=node_id,
                grpc_addr=grpc_addr,
                layer_range=layer_range,
                queue_depth=queue_depth,
                available_memory_mb=available_memory_mb,
                gpu_utilization=gpu_utilization,
                is_accepting_activations=is_accepting,
                last_heartbeat=time.time(),
            )
            return

        # Routing on an outdated address or layer range would send
        # activations to the wrong place, so refresh them as well.
        peer.grpc_addr = grpc_addr
        peer.layer_range = layer_range
        peer.queue_depth = queue_depth
        peer.available_memory_mb = available_memory_mb
        peer.gpu_utilization = gpu_utilization
        peer.is_accepting_activations = is_accepting
        peer.last_heartbeat = time.time()

    def get_candidates(self, target_layer: int) -> List[PeerCandidate]:
        """
        Get K candidates for a target layer, sorted by score.

        Returns candidates from:
        1. Local cache (fastest) - peers we know from heartbeats
        2. DHT lookup (if fewer than K found)
        3. Tracker fallback (if still fewer than K)

        NOTE(review): the tracker fallback performs a blocking HTTP request;
        when this is called from a coroutine it stalls the event loop for up
        to the request timeout.

        Args:
            target_layer: The layer index we need to route to

        Returns:
            List of up to K candidates, sorted by score (best first)
        """
        # Strategy 1: Local cache (from heartbeats)
        candidates: List[PeerCandidate] = [
            info for info in self.peer_stats.values()
            if info.covers_layer(target_layer) and info.is_accepting_activations
        ]
        # Track node_ids once instead of rebuilding a set per discovered peer.
        seen = {c.node_id for c in candidates}

        def _merge(discovered: List[PeerCandidate]) -> None:
            """Append discovered peers whose node_id is not yet present."""
            for peer in discovered:
                if peer.node_id not in seen:
                    seen.add(peer.node_id)
                    candidates.append(peer)

        # Strategy 2: DHT lookup (if not enough candidates)
        if len(candidates) < self.K_CANDIDATES and self.dht:
            _merge(self._dht_lookup_layer(target_layer))

        # Strategy 3: Tracker fallback (if still not enough)
        if len(candidates) < self.K_CANDIDATES and self.tracker_url:
            _merge(self._tracker_lookup_layer(target_layer))

        # Sort by score, best (lowest) first
        candidates.sort(key=lambda c: c.score())

        # Add small random shuffle among similarly-scored candidates
        # to distribute load
        if len(candidates) > 1:
            candidates = self._probabilistic_shuffle(candidates)

        return candidates[:self.K_CANDIDATES]

    def _probabilistic_shuffle(
        self,
        candidates: List[PeerCandidate],
        epsilon: float = 0.1
    ) -> List[PeerCandidate]:
        """
        Add small random shuffle among similarly-scored candidates.

        Walks adjacent pairs once; a pair whose scores differ by less than
        ``epsilon`` is swapped with 30% probability. The input list is not
        modified.
        """
        if len(candidates) < 2:
            return candidates

        result = candidates.copy()

        for i in range(len(result) - 1):
            # Check if next candidate has similar score
            score_diff = abs(result[i].score() - result[i + 1].score())
            if score_diff < epsilon and random.random() < 0.3:
                result[i], result[i + 1] = result[i + 1], result[i]

        return result

    def _dht_lookup_layer(self, target_layer: int) -> List[PeerCandidate]:
        """
        Look up peers for a layer via DHT.

        Each DHT value is expected to be a JSON object with node_id,
        grpc_addr, start_layer and end_layer. Malformed values and values
        without an address are skipped. Any failure of the lookup itself
        yields an empty list.
        """
        if not self.dht:
            return []

        import json

        try:
            # DHT key format: "layer:{layer_num}"
            key = f"layer:{target_layer}"
            values = self.dht.lookup_values(key, k=self.K_CANDIDATES * 2)

            candidates = []
            for value in values:
                # Parse peer info from DHT value
                try:
                    info = json.loads(value)
                    peer = PeerCandidate(
                        node_id=info.get("node_id", ""),
                        grpc_addr=info.get("grpc_addr", ""),
                        layer_range=(info.get("start_layer", 0), info.get("end_layer", 0)),
                    )
                except (ValueError, TypeError, AttributeError):
                    # Not JSON, or not a JSON object: ignore this entry.
                    continue
                if not peer.grpc_addr:
                    # No address to dial; it would only waste a failover slot.
                    continue
                candidates.append(peer)

            return candidates
        except Exception as e:
            logger.debug(f"DHT lookup failed: {e}")
            return []

    def _tracker_lookup_layer(self, target_layer: int) -> List[PeerCandidate]:
        """
        Look up peers for a layer via centralized tracker.

        Issues ``GET {tracker_url}/peers?layer=<target_layer>`` with a 2s
        timeout. Non-200 responses and any request/parse failure yield an
        empty list; individual malformed peer entries are skipped.
        """
        if not self.tracker_url:
            return []

        try:
            import requests
            resp = requests.get(
                f"{self.tracker_url}/peers",
                params={"layer": target_layer},
                timeout=2.0
            )

            if resp.status_code != 200:
                return []

            candidates = []
            for peer_info in resp.json().get("peers", []):
                try:
                    # Parse shard range (e.g., "0-12"); a single number "n"
                    # is treated as the one-layer range (n, n + 1).
                    shard_str = peer_info.get("shard_range", "0-0")
                    parts = shard_str.split("-")
                    start = int(parts[0])
                    end = int(parts[1]) if len(parts) > 1 else start + 1

                    peer = PeerCandidate(
                        node_id=peer_info.get("node_id", peer_info.get("url", "")),
                        grpc_addr=peer_info.get("url", ""),
                        layer_range=(start, end),
                    )
                except (ValueError, TypeError, AttributeError, IndexError):
                    # Malformed entry: skip it, keep the rest.
                    continue
                if not peer.grpc_addr:
                    # No address to dial; it would only waste a failover slot.
                    continue
                candidates.append(peer)

            return candidates
        except Exception as e:
            logger.debug(f"Tracker lookup failed: {e}")
            return []

    async def _get_channel(self, grpc_addr: str) -> grpc.aio.Channel:
        """
        Get or create gRPC channel for address.

        Channels are cached under the address exactly as given; a leading
        ``http://`` or ``https://`` is stripped only for dialing. The channel
        is always created with ``insecure_channel`` regardless of scheme.
        """
        if grpc_addr not in self._channels:
            # Keepalive settings so dead nodes are detected on idle channels
            options = [
                ('grpc.keepalive_time_ms', 30000),  # Ping every 30 seconds
                ('grpc.keepalive_timeout_ms', 10000),  # 10 second timeout
                ('grpc.keepalive_permit_without_calls', True),  # Ping even when idle
                ('grpc.http2.max_pings_without_data', 0),  # Unlimited pings
            ]

            # Strip only a leading scheme (not occurrences later in the string)
            addr = grpc_addr
            for scheme in ("http://", "https://"):
                if addr.startswith(scheme):
                    addr = addr[len(scheme):]
                    break

            self._channels[grpc_addr] = grpc.aio.insecure_channel(addr, options=options)

        return self._channels[grpc_addr]

    async def _get_stub(self, grpc_addr: str) -> Any:
        """
        Get or create gRPC stub for address.

        Raises:
            RuntimeError: if the proto modules could not be imported.
        """
        if not GRPC_AVAILABLE:
            raise RuntimeError("gRPC not available")

        if grpc_addr not in self._stubs:
            channel = await self._get_channel(grpc_addr)
            self._stubs[grpc_addr] = neuroshard_pb2_grpc.NeuroShardServiceStub(channel)

        return self._stubs[grpc_addr]

    async def send_with_failover(
        self,
        tensor: torch.Tensor,
        target_layer: int,
        session_id: str,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> RoutingResult:
        """
        Send tensor to target layer with automatic failover.

        Algorithm:
        1. Get K candidates sorted by score
        2. Try primary candidate
        3. If no ACK in 200ms, try secondary
        4. Continue until success or all candidates exhausted

        Per-peer failures never propagate; they are recorded on the
        candidate and summarized in the returned result.

        Args:
            tensor: Activation tensor to send
            target_layer: Target layer index
            session_id: Session identifier for routing
            metadata: Optional metadata to include

        Returns:
            RoutingResult with success status and response data
        """
        self.total_routes += 1
        candidates = self.get_candidates(target_layer)

        if not candidates:
            logger.error(f"No candidates available for layer {target_layer}")
            return RoutingResult(
                success=False,
                error=f"No candidates for layer {target_layer}",
                attempts=0,
            )

        last_error: Optional[Exception] = None
        timeout_s = self.ACK_TIMEOUT_MS / 1000.0

        for i, candidate in enumerate(candidates):
            # Monotonic clock: wall-clock adjustments must not skew latency.
            start_time = time.perf_counter()

            try:
                # Attempt send with timeout
                result = await asyncio.wait_for(
                    self._send_to_peer(candidate, tensor, session_id, metadata),
                    timeout=timeout_s
                )

                latency_ms = (time.perf_counter() - start_time) * 1000

                # Record success
                candidate.record_success(latency_ms)
                self.successful_routes += 1

                if i > 0:
                    self.failover_count += 1
                    logger.info(f"Failover to candidate {i}: {candidate.node_id[:8]}...")

                return RoutingResult(
                    success=True,
                    peer_used=candidate.node_id,
                    latency_ms=latency_ms,
                    attempts=i + 1,
                    response_data=result,
                )

            except asyncio.TimeoutError:
                candidate.record_failure()
                logger.warning(
                    f"Peer {candidate.node_id[:8]}... timed out after {self.ACK_TIMEOUT_MS}ms, "
                    f"trying next ({i+1}/{len(candidates)})"
                )
                last_error = TimeoutError(f"Peer {candidate.node_id[:8]} timed out")
                continue

            except grpc.aio.AioRpcError as e:
                candidate.record_failure()
                logger.warning(
                    f"Peer {candidate.node_id[:8]}... gRPC error: {e.code()}, "
                    f"trying next ({i+1}/{len(candidates)})"
                )
                last_error = e
                continue

            except Exception as e:
                candidate.record_failure()
                logger.warning(
                    f"Peer {candidate.node_id[:8]}... failed: {e}, "
                    f"trying next ({i+1}/{len(candidates)})"
                )
                last_error = e
                continue

        # All candidates failed
        error_msg = f"All {len(candidates)} candidates failed for layer {target_layer}"
        if last_error:
            error_msg += f": {last_error}"

        logger.error(error_msg)

        return RoutingResult(
            success=False,
            error=error_msg,
            attempts=len(candidates),
        )

    async def _send_to_peer(
        self,
        candidate: PeerCandidate,
        tensor: torch.Tensor,
        session_id: str,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> Any:
        """
        Send activation tensor to a specific peer.

        Tries the SwarmForward RPC first; if the peer answers UNIMPLEMENTED,
        retries once with PipelineForward. The tensor is sent as raw bytes
        plus its shape; the dtype is not transmitted.

        Raises:
            RuntimeError: if the proto modules could not be imported.
            grpc.aio.AioRpcError: for any RPC failure other than an
                UNIMPLEMENTED SwarmForward.
        """
        if not GRPC_AVAILABLE:
            raise RuntimeError("gRPC not available")

        stub = await self._get_stub(candidate.grpc_addr)

        # Serialize tensor. detach() is required because numpy() refuses
        # tensors that are part of an autograd graph.
        tensor_bytes = tensor.detach().cpu().numpy().tobytes()
        shape = list(tensor.shape)
        meta = metadata or {}
        sender_url = meta.get("source_node", "") if metadata else ""

        # Build request - use SwarmForward if available
        try:
            # Try SwarmForward (new swarm-aware RPC)
            request = neuroshard_pb2.SwarmForwardRequest(
                session_id=session_id,
                request_id=f"swarm_{time.time()}_{session_id[:8]}",
                hidden_states=tensor_bytes,
                hidden_shape=shape,
                # tensor_dtype inferred from shape/context
                target_layer=candidate.layer_range[0],
                sender_url=sender_url,
                micro_batch_id=meta.get("micro_batch_id", 0) if metadata else 0,
                priority=meta.get("priority", 10) if metadata else 10,
            )

            response = await stub.SwarmForward(request)
            return response

        except grpc.aio.AioRpcError as e:
            if e.code() == grpc.StatusCode.UNIMPLEMENTED:
                logger.warning(f"SwarmForward not implemented on {candidate.node_id}, falling back")
                # Fall back to PipelineForward
                request = neuroshard_pb2.PipelineForwardRequest(
                    session_id=session_id,
                    request_id=f"pipe_{time.time()}",
                    hidden_states=tensor_bytes,
                    hidden_shape=shape,
                    target_shard=candidate.layer_range[0],
                    sender_url=sender_url,
                )
                response = await stub.PipelineForward(request)
                return response
            raise

    def get_stats(self) -> Dict[str, Any]:
        """Get router statistics (rates are 0.0 before any route attempt)."""
        success_rate = (
            self.successful_routes / self.total_routes
            if self.total_routes > 0 else 0.0
        )
        failover_rate = (
            self.failover_count / self.total_routes
            if self.total_routes > 0 else 0.0
        )

        return {
            "total_routes": self.total_routes,
            "successful_routes": self.successful_routes,
            "success_rate": success_rate,
            "failover_count": self.failover_count,
            "failover_rate": failover_rate,
            "known_peers": len(self.peer_stats),
            "active_channels": len(self._channels),
        }
649
+
650
+
651
+ # Convenience function for backwards compatibility
652
async def get_swarm_candidates(
    router: SwarmRouter,
    target_layer: int,
) -> List[PeerCandidate]:
    """
    Return the router's ranked candidates for ``target_layer``.

    Awaitable wrapper that delegates directly to
    ``router.get_candidates(target_layer)``.
    """
    ranked = router.get_candidates(target_layer)
    return ranked
658
+