nexaroa 0.0.111__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. neuroshard/__init__.py +93 -0
  2. neuroshard/__main__.py +4 -0
  3. neuroshard/cli.py +466 -0
  4. neuroshard/core/__init__.py +92 -0
  5. neuroshard/core/consensus/verifier.py +252 -0
  6. neuroshard/core/crypto/__init__.py +20 -0
  7. neuroshard/core/crypto/ecdsa.py +392 -0
  8. neuroshard/core/economics/__init__.py +52 -0
  9. neuroshard/core/economics/constants.py +387 -0
  10. neuroshard/core/economics/ledger.py +2111 -0
  11. neuroshard/core/economics/market.py +975 -0
  12. neuroshard/core/economics/wallet.py +168 -0
  13. neuroshard/core/governance/__init__.py +74 -0
  14. neuroshard/core/governance/proposal.py +561 -0
  15. neuroshard/core/governance/registry.py +545 -0
  16. neuroshard/core/governance/versioning.py +332 -0
  17. neuroshard/core/governance/voting.py +453 -0
  18. neuroshard/core/model/__init__.py +30 -0
  19. neuroshard/core/model/dynamic.py +4186 -0
  20. neuroshard/core/model/llm.py +905 -0
  21. neuroshard/core/model/registry.py +164 -0
  22. neuroshard/core/model/scaler.py +387 -0
  23. neuroshard/core/model/tokenizer.py +568 -0
  24. neuroshard/core/network/__init__.py +56 -0
  25. neuroshard/core/network/connection_pool.py +72 -0
  26. neuroshard/core/network/dht.py +130 -0
  27. neuroshard/core/network/dht_plan.py +55 -0
  28. neuroshard/core/network/dht_proof_store.py +516 -0
  29. neuroshard/core/network/dht_protocol.py +261 -0
  30. neuroshard/core/network/dht_service.py +506 -0
  31. neuroshard/core/network/encrypted_channel.py +141 -0
  32. neuroshard/core/network/nat.py +201 -0
  33. neuroshard/core/network/nat_traversal.py +695 -0
  34. neuroshard/core/network/p2p.py +929 -0
  35. neuroshard/core/network/p2p_data.py +150 -0
  36. neuroshard/core/swarm/__init__.py +106 -0
  37. neuroshard/core/swarm/aggregation.py +729 -0
  38. neuroshard/core/swarm/buffers.py +643 -0
  39. neuroshard/core/swarm/checkpoint.py +709 -0
  40. neuroshard/core/swarm/compute.py +624 -0
  41. neuroshard/core/swarm/diloco.py +844 -0
  42. neuroshard/core/swarm/factory.py +1288 -0
  43. neuroshard/core/swarm/heartbeat.py +669 -0
  44. neuroshard/core/swarm/logger.py +487 -0
  45. neuroshard/core/swarm/router.py +658 -0
  46. neuroshard/core/swarm/service.py +640 -0
  47. neuroshard/core/training/__init__.py +29 -0
  48. neuroshard/core/training/checkpoint.py +600 -0
  49. neuroshard/core/training/distributed.py +1602 -0
  50. neuroshard/core/training/global_tracker.py +617 -0
  51. neuroshard/core/training/production.py +276 -0
  52. neuroshard/governance_cli.py +729 -0
  53. neuroshard/grpc_server.py +895 -0
  54. neuroshard/runner.py +3223 -0
  55. neuroshard/sdk/__init__.py +92 -0
  56. neuroshard/sdk/client.py +990 -0
  57. neuroshard/sdk/errors.py +101 -0
  58. neuroshard/sdk/types.py +282 -0
  59. neuroshard/tracker/__init__.py +0 -0
  60. neuroshard/tracker/server.py +864 -0
  61. neuroshard/ui/__init__.py +0 -0
  62. neuroshard/ui/app.py +102 -0
  63. neuroshard/ui/templates/index.html +1052 -0
  64. neuroshard/utils/__init__.py +0 -0
  65. neuroshard/utils/autostart.py +81 -0
  66. neuroshard/utils/hardware.py +121 -0
  67. neuroshard/utils/serialization.py +90 -0
  68. neuroshard/version.py +1 -0
  69. nexaroa-0.0.111.dist-info/METADATA +283 -0
  70. nexaroa-0.0.111.dist-info/RECORD +78 -0
  71. nexaroa-0.0.111.dist-info/WHEEL +5 -0
  72. nexaroa-0.0.111.dist-info/entry_points.txt +4 -0
  73. nexaroa-0.0.111.dist-info/licenses/LICENSE +190 -0
  74. nexaroa-0.0.111.dist-info/top_level.txt +2 -0
  75. protos/__init__.py +0 -0
  76. protos/neuroshard.proto +651 -0
  77. protos/neuroshard_pb2.py +160 -0
  78. protos/neuroshard_pb2_grpc.py +1298 -0
@@ -0,0 +1,669 @@
1
+ """
2
+ Swarm Heartbeat Service - Capacity Advertisement Protocol
3
+
4
+ Implements UDP-based heartbeat broadcasting for swarm awareness:
5
+ - Every node broadcasts capacity bitmask every 5 seconds
6
+ - Peers learn about available compute before routing
7
+ - Enables load-aware routing decisions
8
+
9
+ Key Directive: "Nodes must broadcast a lightweight Capacity Bitmask every 5 seconds
10
+ so peers know who is ready to receive work."
11
+
12
+ Capacity Bitmask contains:
13
+ - Available memory (MB)
14
+ - Queue depth (current work queue size)
15
+ - Layer range (which layers this node handles)
16
+ - GPU utilization (0-100%)
17
+ - Status flags (training, accepting inference, accepting activations)
18
+ """
19
+
20
+ import asyncio
21
+ import json
22
+ import logging
23
+ import socket
24
+ import struct
25
+ import threading
26
+ import time
27
+ from dataclasses import dataclass, field
28
+ from typing import Dict, List, Optional, Callable, Tuple, Any
29
+
30
+ logger = logging.getLogger(__name__)
31
+
32
+
33
@dataclass
class CapacityBitmask:
    """
    Lightweight capacity advertisement broadcast every 5 seconds.

    Designed to fit in a single UDP packet: a 62-byte fixed header plus a
    variable-length, null-terminated gRPC address (at most 128 bytes).
    Contains all info needed for routing decisions.
    """
    node_id: str
    timestamp: float = field(default_factory=time.time)

    # Network address (for gRPC connections)
    grpc_addr: str = ""

    # Capacity info
    available_memory_mb: int = 0            # Free GPU/system memory
    queue_depth: int = 0                    # Current work queue size (0-65535)
    layer_range: Tuple[int, int] = (0, 0)   # (start_layer, end_layer)

    # Utilization metrics
    gpu_utilization: float = 0.0            # 0-100%
    network_saturation: float = 0.0         # 0-100%

    # Status flags
    is_training: bool = False
    is_accepting_inference: bool = True
    is_accepting_activations: bool = True

    # Optional: current training step for sync coordination
    training_step: int = 0

    # Wire-format constants (no annotations, so these are class attributes,
    # not dataclass fields).
    _MAGIC = 0x4E455552                       # "NEUR"
    _VERSION = 1
    _HEADER_FMT = '>I B 32s d I H H H B B B I'
    _HEADER_SIZE = struct.calcsize(_HEADER_FMT)  # 62 bytes

    @staticmethod
    def _clamp(value: float, lo: int, hi: int) -> int:
        """Coerce *value* into [lo, hi] as an int.

        struct.pack raises struct.error on out-of-range integers; clamping
        here keeps a bogus metric (e.g. negative queue depth, huge memory
        reading) from crashing the broadcast loop.
        """
        return max(lo, min(hi, int(value)))

    def to_bytes(self) -> bytes:
        """
        Serialize to compact binary format (62-byte header + address).

        Format:
        - 4 bytes: magic number (0x4E455552 = "NEUR")
        - 1 byte: version
        - 32 bytes: node_id (truncated/padded)
        - 8 bytes: timestamp (double)
        - 4 bytes: available_memory_mb (uint32)
        - 2 bytes: queue_depth (uint16)
        - 2 bytes: layer_start (uint16)
        - 2 bytes: layer_end (uint16)
        - 1 byte: gpu_utilization (uint8, 0-100)
        - 1 byte: network_saturation (uint8, 0-100)
        - 1 byte: flags (bit field)
        - 4 bytes: training_step (uint32)
        - remaining: grpc_addr (variable, null-terminated)
        """
        # Status bit field: bit0=training, bit1=inference, bit2=activations.
        flags = 0
        if self.is_training:
            flags |= 0x01
        if self.is_accepting_inference:
            flags |= 0x02
        if self.is_accepting_activations:
            flags |= 0x04

        # Truncate/pad node_id to exactly 32 bytes.
        node_id_bytes = self.node_id.encode('utf-8')[:32].ljust(32, b'\x00')

        # Pack fixed fields; every integer is clamped to its wire width so
        # struct.pack can never raise on out-of-range values.
        header = struct.pack(
            self._HEADER_FMT,
            self._MAGIC,                                      # Magic "NEUR"
            self._VERSION,                                    # Version
            node_id_bytes,                                    # Node ID (32 bytes)
            self.timestamp,                                   # Timestamp
            self._clamp(self.available_memory_mb, 0, 0xFFFFFFFF),  # Memory
            self._clamp(self.queue_depth, 0, 0xFFFF),         # Queue depth
            self._clamp(self.layer_range[0], 0, 0xFFFF),      # Layer start
            self._clamp(self.layer_range[1], 0, 0xFFFF),      # Layer end
            self._clamp(self.gpu_utilization, 0, 100),        # GPU util
            self._clamp(self.network_saturation, 0, 100),     # Network sat
            flags,                                            # Status flags
            self._clamp(self.training_step, 0, 0xFFFFFFFF),   # Training step
        )

        # Append grpc_addr. Truncate the payload to 127 bytes BEFORE adding
        # the terminator so the null byte is never cut off (the previous
        # scheme sliced after appending and could drop it).
        addr_bytes = self.grpc_addr.encode('utf-8')[:127] + b'\x00'

        return header + addr_bytes

    @classmethod
    def from_bytes(cls, data: bytes) -> Optional['CapacityBitmask']:
        """Deserialize from binary format; returns None on malformed input."""
        if len(data) < cls._HEADER_SIZE:  # Too short to hold the fixed header
            return None

        try:
            # Unpack fixed header
            (magic, version, node_id_bytes, timestamp, memory, queue,
             layer_start, layer_end, gpu_util, net_sat, flags, step) = \
                struct.unpack(cls._HEADER_FMT, data[:cls._HEADER_SIZE])

            # Validate magic number
            if magic != cls._MAGIC:
                logger.debug(f"Invalid magic number: {magic:#x}")
                return None

            # Version check (best effort: newer versions are still parsed
            # for forward compatibility)
            if version > cls._VERSION:
                logger.debug(f"Unknown version: {version}")

            # Decode node_id (strip null padding)
            node_id = node_id_bytes.rstrip(b'\x00').decode('utf-8', errors='replace')

            # Extract grpc_addr from remaining bytes (up to the null byte,
            # or the whole tail if no terminator survived transit).
            grpc_addr = ""
            if len(data) > cls._HEADER_SIZE:
                addr_data = data[cls._HEADER_SIZE:]
                null_idx = addr_data.find(b'\x00')
                if null_idx >= 0:
                    addr_data = addr_data[:null_idx]
                grpc_addr = addr_data.decode('utf-8', errors='replace')

            return cls(
                node_id=node_id,
                timestamp=timestamp,
                grpc_addr=grpc_addr,
                available_memory_mb=memory,
                queue_depth=queue,
                layer_range=(layer_start, layer_end),
                gpu_utilization=float(gpu_util),
                network_saturation=float(net_sat),
                is_training=bool(flags & 0x01),
                is_accepting_inference=bool(flags & 0x02),
                is_accepting_activations=bool(flags & 0x04),
                training_step=step,
            )

        except Exception as e:
            logger.error(f"Failed to deserialize heartbeat: {e}")
            return None

    def to_json(self) -> str:
        """Serialize to JSON (for debugging/logging)."""
        return json.dumps({
            "node_id": self.node_id,
            "timestamp": self.timestamp,
            "grpc_addr": self.grpc_addr,
            "available_memory_mb": self.available_memory_mb,
            "queue_depth": self.queue_depth,
            "layer_range": list(self.layer_range),
            "gpu_utilization": self.gpu_utilization,
            "network_saturation": self.network_saturation,
            "is_training": self.is_training,
            "is_accepting_inference": self.is_accepting_inference,
            "is_accepting_activations": self.is_accepting_activations,
            "training_step": self.training_step,
        })

    @classmethod
    def from_json(cls, data: str) -> Optional['CapacityBitmask']:
        """Deserialize from JSON; returns None on parse failure."""
        try:
            d = json.loads(data)
            return cls(
                node_id=d.get("node_id", ""),
                timestamp=d.get("timestamp", time.time()),
                grpc_addr=d.get("grpc_addr", ""),
                available_memory_mb=d.get("available_memory_mb", 0),
                queue_depth=d.get("queue_depth", 0),
                layer_range=tuple(d.get("layer_range", [0, 0])),
                gpu_utilization=d.get("gpu_utilization", 0.0),
                network_saturation=d.get("network_saturation", 0.0),
                is_training=d.get("is_training", False),
                is_accepting_inference=d.get("is_accepting_inference", True),
                is_accepting_activations=d.get("is_accepting_activations", True),
                training_step=d.get("training_step", 0),
            )
        except Exception as e:
            logger.error(f"Failed to parse heartbeat JSON: {e}")
            return None
209
+
210
+
211
+ class SwarmHeartbeatService:
212
+ """
213
+ Broadcasts and receives capacity heartbeats over UDP.
214
+
215
+ Features:
216
+ - Periodic broadcast (default: every 5 seconds)
217
+ - Multicast or direct peer-to-peer UDP
218
+ - Stale peer detection (dead after 15s)
219
+ - Integration with SwarmRouter for routing updates
220
+
221
+ Usage:
222
+ service = SwarmHeartbeatService(node_id="abc123", udp_port=9999)
223
+ service.set_capacity_callback(get_my_capacity)
224
+ service.set_peer_update_callback(router.update_peer_from_heartbeat)
225
+ service.start()
226
+ """
227
+
228
+ HEARTBEAT_INTERVAL = 5.0 # Broadcast every 5 seconds
229
+ STALE_THRESHOLD = 15.0 # Consider peer dead after 15s
230
+ DEAD_THRESHOLD = 60.0 # Remove peer after 60s
231
+
232
+ # Multicast group for local network discovery
233
+ MULTICAST_GROUP = "239.255.42.42"
234
+ MULTICAST_PORT = 9999
235
+
236
+ def __init__(
237
+ self,
238
+ node_id: str,
239
+ udp_port: int = MULTICAST_PORT,
240
+ grpc_addr: str = "",
241
+ use_multicast: bool = True,
242
+ known_peers: Optional[List[str]] = None,
243
+ ):
244
+ """
245
+ Initialize heartbeat service.
246
+
247
+ Args:
248
+ node_id: This node's unique identifier
249
+ udp_port: UDP port for heartbeat traffic
250
+ grpc_addr: This node's gRPC address (for inclusion in heartbeat)
251
+ use_multicast: Whether to use multicast (True) or unicast (False)
252
+ known_peers: List of peer UDP addresses for unicast mode
253
+ """
254
+ self.node_id = node_id
255
+ self.udp_port = udp_port
256
+ self.grpc_addr = grpc_addr
257
+ self.use_multicast = use_multicast
258
+ self.known_peer_addrs = known_peers or []
259
+
260
+ # Peer state
261
+ self.peer_capacities: Dict[str, CapacityBitmask] = {}
262
+ self._peer_lock = threading.Lock()
263
+
264
+ # Local capacity (updated via update_local_capacity())
265
+ self._local_capacity: Optional[CapacityBitmask] = None
266
+
267
+ # Callbacks
268
+ self._capacity_callback: Optional[Callable[[], CapacityBitmask]] = None
269
+ self._peer_update_callback: Optional[Callable[[CapacityBitmask], None]] = None
270
+
271
+ # State
272
+ self.running = False
273
+ self._broadcast_thread: Optional[threading.Thread] = None
274
+ self._listen_thread: Optional[threading.Thread] = None
275
+ self._socket: Optional[socket.socket] = None
276
+
277
+ # Metrics
278
+ self.heartbeats_sent = 0
279
+ self.heartbeats_received = 0
280
+ self.peers_discovered = 0
281
+ self.peers_lost = 0
282
+
283
+ @property
284
+ def broadcast_count(self) -> int:
285
+ """Alias for heartbeats_sent."""
286
+ return self.heartbeats_sent
287
+
288
+ def set_capacity_callback(self, callback: Callable[[], CapacityBitmask]):
289
+ """
290
+ Set callback to get current node capacity.
291
+
292
+ The callback should return a CapacityBitmask with current state.
293
+ Called every heartbeat interval.
294
+ """
295
+ self._capacity_callback = callback
296
+
297
+ def set_peer_update_callback(self, callback: Callable[[CapacityBitmask], None]):
298
+ """
299
+ Set callback for peer capacity updates.
300
+
301
+ Called when a heartbeat is received from a peer.
302
+ Typically used to update SwarmRouter's peer stats.
303
+ """
304
+ self._peer_update_callback = callback
305
+
306
+ def add_known_peer(self, peer_addr: str):
307
+ """Add a peer address for unicast heartbeats."""
308
+ if peer_addr not in self.known_peer_addrs:
309
+ self.known_peer_addrs.append(peer_addr)
310
+
311
+ def update_local_capacity(
312
+ self,
313
+ available_memory_mb: int,
314
+ queue_depth: int,
315
+ layer_range: Tuple[int, int],
316
+ gpu_utilization: float = 0.0,
317
+ is_training: bool = False,
318
+ is_accepting_activations: bool = True,
319
+ ):
320
+ """
321
+ Update local capacity info for next heartbeat broadcast.
322
+
323
+ Called by SwarmServiceMixin to update capacity state.
324
+ The actual broadcast happens in the background thread.
325
+ """
326
+ self._local_capacity = CapacityBitmask(
327
+ node_id=self.node_id,
328
+ grpc_addr=self.grpc_addr,
329
+ available_memory_mb=available_memory_mb,
330
+ queue_depth=queue_depth,
331
+ layer_range=layer_range,
332
+ gpu_utilization=gpu_utilization,
333
+ is_training=is_training,
334
+ is_accepting_activations=is_accepting_activations,
335
+ )
336
+
337
+ def start(self):
338
+ """Start heartbeat broadcast and listener."""
339
+ if self.running:
340
+ return
341
+
342
+ self.running = True
343
+
344
+ # Setup UDP socket
345
+ self._setup_socket()
346
+
347
+ # Start broadcast thread
348
+ self._broadcast_thread = threading.Thread(
349
+ target=self._broadcast_loop,
350
+ daemon=True,
351
+ name="SwarmHeartbeat-Broadcast"
352
+ )
353
+ self._broadcast_thread.start()
354
+
355
+ # Start listener thread
356
+ self._listen_thread = threading.Thread(
357
+ target=self._listen_loop,
358
+ daemon=True,
359
+ name="SwarmHeartbeat-Listen"
360
+ )
361
+ self._listen_thread.start()
362
+
363
+ logger.info(f"SwarmHeartbeatService started on port {self.udp_port}")
364
+
365
+ def stop(self):
366
+ """Stop heartbeat service."""
367
+ self.running = False
368
+
369
+ if self._socket:
370
+ try:
371
+ self._socket.close()
372
+ except:
373
+ pass
374
+
375
+ if self._broadcast_thread:
376
+ self._broadcast_thread.join(timeout=2.0)
377
+ if self._listen_thread:
378
+ self._listen_thread.join(timeout=2.0)
379
+
380
+ logger.info("SwarmHeartbeatService stopped")
381
+
382
+ def _setup_socket(self):
383
+ """Setup UDP socket for heartbeat traffic."""
384
+ self._socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
385
+ self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
386
+
387
+ # Allow multiple processes on same host (for testing)
388
+ try:
389
+ self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
390
+ except AttributeError:
391
+ pass # Not available on all platforms
392
+
393
+ if self.use_multicast:
394
+ # Join multicast group
395
+ self._socket.bind(('', self.udp_port))
396
+ group = socket.inet_aton(self.MULTICAST_GROUP)
397
+ mreq = struct.pack('4sL', group, socket.INADDR_ANY)
398
+ self._socket.setsockopt(
399
+ socket.IPPROTO_IP,
400
+ socket.IP_ADD_MEMBERSHIP,
401
+ mreq
402
+ )
403
+ # Set TTL for multicast
404
+ self._socket.setsockopt(
405
+ socket.IPPROTO_IP,
406
+ socket.IP_MULTICAST_TTL,
407
+ 2
408
+ )
409
+ else:
410
+ self._socket.bind(('0.0.0.0', self.udp_port))
411
+
412
+ # Non-blocking with timeout
413
+ self._socket.settimeout(1.0)
414
+
415
+ def _broadcast_loop(self):
416
+ """Periodically broadcast capacity heartbeat."""
417
+ while self.running:
418
+ try:
419
+ self._broadcast_heartbeat()
420
+ time.sleep(self.HEARTBEAT_INTERVAL)
421
+ except Exception as e:
422
+ if self.running:
423
+ logger.error(f"Broadcast error: {e}")
424
+ time.sleep(1.0)
425
+
426
+ def _broadcast_heartbeat(self):
427
+ """Send a single heartbeat broadcast."""
428
+ if not self._capacity_callback:
429
+ # Create default capacity
430
+ capacity = CapacityBitmask(
431
+ node_id=self.node_id,
432
+ grpc_addr=self.grpc_addr,
433
+ )
434
+ else:
435
+ capacity = self._capacity_callback()
436
+ # Ensure our info is set
437
+ capacity.node_id = self.node_id
438
+ if not capacity.grpc_addr:
439
+ capacity.grpc_addr = self.grpc_addr
440
+
441
+ data = capacity.to_bytes()
442
+
443
+ if self.use_multicast:
444
+ # Send to multicast group
445
+ self._socket.sendto(data, (self.MULTICAST_GROUP, self.udp_port))
446
+ else:
447
+ # Send to known peers
448
+ for peer_addr in self.known_peer_addrs:
449
+ try:
450
+ host, port = self._parse_addr(peer_addr)
451
+ self._socket.sendto(data, (host, port))
452
+ except Exception as e:
453
+ logger.debug(f"Failed to send to {peer_addr}: {e}")
454
+
455
+ self.heartbeats_sent += 1
456
+
457
+ def _parse_addr(self, addr: str) -> Tuple[str, int]:
458
+ """Parse address string to (host, port)."""
459
+ if ':' in addr:
460
+ host, port = addr.rsplit(':', 1)
461
+ return host, int(port)
462
+ return addr, self.udp_port
463
+
464
+ def _listen_loop(self):
465
+ """Listen for heartbeats from peers."""
466
+ while self.running:
467
+ try:
468
+ data, addr = self._socket.recvfrom(256)
469
+ self._handle_heartbeat(data, addr)
470
+ except socket.timeout:
471
+ # Check for stale peers periodically
472
+ self._cleanup_stale_peers()
473
+ except Exception as e:
474
+ if self.running:
475
+ logger.debug(f"Listen error: {e}")
476
+
477
+ def _handle_heartbeat(self, data: bytes, addr: Tuple[str, int]):
478
+ """Process received heartbeat packet."""
479
+ capacity = CapacityBitmask.from_bytes(data)
480
+ if not capacity:
481
+ return
482
+
483
+ # Ignore our own heartbeats
484
+ if capacity.node_id == self.node_id:
485
+ return
486
+
487
+ self.heartbeats_received += 1
488
+
489
+ with self._peer_lock:
490
+ is_new = capacity.node_id not in self.peer_capacities
491
+ self.peer_capacities[capacity.node_id] = capacity
492
+
493
+ if is_new:
494
+ self.peers_discovered += 1
495
+ logger.info(
496
+ f"Discovered peer {capacity.node_id[:8]}... "
497
+ f"layers={capacity.layer_range}, "
498
+ f"queue={capacity.queue_depth}"
499
+ )
500
+
501
+ # Add to known peers for unicast
502
+ if capacity.grpc_addr and not self.use_multicast:
503
+ self.add_known_peer(f"{addr[0]}:{self.udp_port}")
504
+
505
+ # Notify callback
506
+ if self._peer_update_callback:
507
+ try:
508
+ self._peer_update_callback(capacity)
509
+ except Exception as e:
510
+ logger.error(f"Peer update callback error: {e}")
511
+
512
+ def _cleanup_stale_peers(self):
513
+ """Remove peers that haven't sent heartbeat recently."""
514
+ now = time.time()
515
+
516
+ with self._peer_lock:
517
+ dead_peers = []
518
+
519
+ for node_id, capacity in self.peer_capacities.items():
520
+ staleness = now - capacity.timestamp
521
+
522
+ if staleness > self.DEAD_THRESHOLD:
523
+ dead_peers.append(node_id)
524
+ elif staleness > self.STALE_THRESHOLD:
525
+ # Mark as not accepting (but keep for history)
526
+ capacity.is_accepting_activations = False
527
+ capacity.is_accepting_inference = False
528
+
529
+ for node_id in dead_peers:
530
+ del self.peer_capacities[node_id]
531
+ self.peers_lost += 1
532
+ logger.info(f"Lost peer {node_id[:8]}... (no heartbeat for {self.DEAD_THRESHOLD}s)")
533
+
534
+ def get_available_peers(
535
+ self,
536
+ min_memory: int = 0,
537
+ target_layer: Optional[int] = None,
538
+ ) -> List[CapacityBitmask]:
539
+ """
540
+ Get peers with available capacity.
541
+
542
+ Args:
543
+ min_memory: Minimum required memory (MB)
544
+ target_layer: Optional layer to filter by
545
+
546
+ Returns:
547
+ List of CapacityBitmask for available peers
548
+ """
549
+ now = time.time()
550
+
551
+ with self._peer_lock:
552
+ result = []
553
+
554
+ for capacity in self.peer_capacities.values():
555
+ # Check freshness
556
+ if (now - capacity.timestamp) > self.STALE_THRESHOLD:
557
+ continue
558
+
559
+ # Check availability
560
+ if not capacity.is_accepting_activations:
561
+ continue
562
+
563
+ # Check memory
564
+ if capacity.available_memory_mb < min_memory:
565
+ continue
566
+
567
+ # Check layer coverage
568
+ if target_layer is not None:
569
+ if not (capacity.layer_range[0] <= target_layer < capacity.layer_range[1]):
570
+ continue
571
+
572
+ result.append(capacity)
573
+
574
+ return result
575
+
576
+ def get_stats(self) -> Dict[str, Any]:
577
+ """Get heartbeat service statistics."""
578
+ with self._peer_lock:
579
+ active_peers = sum(
580
+ 1 for c in self.peer_capacities.values()
581
+ if c.is_accepting_activations
582
+ )
583
+
584
+ return {
585
+ "running": self.running,
586
+ "heartbeats_sent": self.heartbeats_sent,
587
+ "heartbeats_received": self.heartbeats_received,
588
+ "peers_discovered": self.peers_discovered,
589
+ "peers_lost": self.peers_lost,
590
+ "active_peers": active_peers,
591
+ "total_known_peers": len(self.peer_capacities),
592
+ "use_multicast": self.use_multicast,
593
+ }
594
+
595
+
596
def create_capacity_callback(
    layer_range: Tuple[int, int],
    inbound_buffer: Any,  # ActivationBuffer
    outbound_buffer: Any,  # OutboundBuffer
    get_memory_fn: Optional[Callable[[], int]] = None,
    get_gpu_util_fn: Optional[Callable[[], float]] = None,
    is_training: bool = False,
) -> Callable[[], CapacityBitmask]:
    """
    Factory to create a capacity callback function.

    Args:
        layer_range: (start, end) layer range
        inbound_buffer: ActivationBuffer for queue depth
        outbound_buffer: OutboundBuffer for network saturation
        get_memory_fn: Optional function returning available memory (MB)
        get_gpu_util_fn: Optional function returning GPU utilization (0-100)
        is_training: Whether node is currently training

    Returns:
        Zero-argument callback returning a fresh CapacityBitmask. The
        node_id is left blank; SwarmHeartbeatService fills it in before
        broadcasting.
    """
    def _get_default_memory() -> int:
        """Report free CUDA memory in MB; fall back to 4096 MB."""
        try:
            import torch
            if torch.cuda.is_available():
                free, total = torch.cuda.mem_get_info()
                return int(free / 1024 / 1024)
        except Exception:
            # Narrowed from bare `except:` — a bare clause would also
            # swallow KeyboardInterrupt/SystemExit. torch may be absent
            # (ImportError) or CUDA may fail at runtime.
            pass
        return 4096

    def _get_default_gpu_util() -> float:
        """Rough GPU utilization estimate from CUDA memory pressure; 0 without a GPU."""
        try:
            import torch
            if torch.cuda.is_available():
                free, total = torch.cuda.mem_get_info()
                return 100.0 * (1 - free / total)
        except Exception:
            pass  # same rationale as _get_default_memory
        return 0.0

    memory_fn = get_memory_fn or _get_default_memory
    gpu_fn = get_gpu_util_fn or _get_default_gpu_util

    def callback() -> CapacityBitmask:
        # Queue depth: peek at the inbound buffer's internal queue.
        # NOTE(review): relies on the private `_queue` attribute of
        # ActivationBuffer — confirm against buffers.py if it changes.
        queue_depth = 0
        if hasattr(inbound_buffer, '_queue'):
            queue_depth = len(inbound_buffer._queue)

        # Network saturation: fill_rate is assumed to be a 0-1 fraction,
        # scaled here to 0-100.
        net_sat = 0.0
        if hasattr(outbound_buffer, 'fill_rate'):
            net_sat = outbound_buffer.fill_rate * 100

        return CapacityBitmask(
            node_id="",  # Set by service
            timestamp=time.time(),
            available_memory_mb=memory_fn(),
            queue_depth=queue_depth,
            layer_range=layer_range,
            gpu_utilization=gpu_fn(),
            network_saturation=net_sat,
            is_training=is_training,
            # Backpressure policy: inference pauses while training with a
            # deep queue; activations stop just before the queue saturates.
            is_accepting_inference=not is_training or queue_depth < 50,
            is_accepting_activations=queue_depth < 80,
        )

    return callback
669
+