nexaroa-0.0.111-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. neuroshard/__init__.py +93 -0
  2. neuroshard/__main__.py +4 -0
  3. neuroshard/cli.py +466 -0
  4. neuroshard/core/__init__.py +92 -0
  5. neuroshard/core/consensus/verifier.py +252 -0
  6. neuroshard/core/crypto/__init__.py +20 -0
  7. neuroshard/core/crypto/ecdsa.py +392 -0
  8. neuroshard/core/economics/__init__.py +52 -0
  9. neuroshard/core/economics/constants.py +387 -0
  10. neuroshard/core/economics/ledger.py +2111 -0
  11. neuroshard/core/economics/market.py +975 -0
  12. neuroshard/core/economics/wallet.py +168 -0
  13. neuroshard/core/governance/__init__.py +74 -0
  14. neuroshard/core/governance/proposal.py +561 -0
  15. neuroshard/core/governance/registry.py +545 -0
  16. neuroshard/core/governance/versioning.py +332 -0
  17. neuroshard/core/governance/voting.py +453 -0
  18. neuroshard/core/model/__init__.py +30 -0
  19. neuroshard/core/model/dynamic.py +4186 -0
  20. neuroshard/core/model/llm.py +905 -0
  21. neuroshard/core/model/registry.py +164 -0
  22. neuroshard/core/model/scaler.py +387 -0
  23. neuroshard/core/model/tokenizer.py +568 -0
  24. neuroshard/core/network/__init__.py +56 -0
  25. neuroshard/core/network/connection_pool.py +72 -0
  26. neuroshard/core/network/dht.py +130 -0
  27. neuroshard/core/network/dht_plan.py +55 -0
  28. neuroshard/core/network/dht_proof_store.py +516 -0
  29. neuroshard/core/network/dht_protocol.py +261 -0
  30. neuroshard/core/network/dht_service.py +506 -0
  31. neuroshard/core/network/encrypted_channel.py +141 -0
  32. neuroshard/core/network/nat.py +201 -0
  33. neuroshard/core/network/nat_traversal.py +695 -0
  34. neuroshard/core/network/p2p.py +929 -0
  35. neuroshard/core/network/p2p_data.py +150 -0
  36. neuroshard/core/swarm/__init__.py +106 -0
  37. neuroshard/core/swarm/aggregation.py +729 -0
  38. neuroshard/core/swarm/buffers.py +643 -0
  39. neuroshard/core/swarm/checkpoint.py +709 -0
  40. neuroshard/core/swarm/compute.py +624 -0
  41. neuroshard/core/swarm/diloco.py +844 -0
  42. neuroshard/core/swarm/factory.py +1288 -0
  43. neuroshard/core/swarm/heartbeat.py +669 -0
  44. neuroshard/core/swarm/logger.py +487 -0
  45. neuroshard/core/swarm/router.py +658 -0
  46. neuroshard/core/swarm/service.py +640 -0
  47. neuroshard/core/training/__init__.py +29 -0
  48. neuroshard/core/training/checkpoint.py +600 -0
  49. neuroshard/core/training/distributed.py +1602 -0
  50. neuroshard/core/training/global_tracker.py +617 -0
  51. neuroshard/core/training/production.py +276 -0
  52. neuroshard/governance_cli.py +729 -0
  53. neuroshard/grpc_server.py +895 -0
  54. neuroshard/runner.py +3223 -0
  55. neuroshard/sdk/__init__.py +92 -0
  56. neuroshard/sdk/client.py +990 -0
  57. neuroshard/sdk/errors.py +101 -0
  58. neuroshard/sdk/types.py +282 -0
  59. neuroshard/tracker/__init__.py +0 -0
  60. neuroshard/tracker/server.py +864 -0
  61. neuroshard/ui/__init__.py +0 -0
  62. neuroshard/ui/app.py +102 -0
  63. neuroshard/ui/templates/index.html +1052 -0
  64. neuroshard/utils/__init__.py +0 -0
  65. neuroshard/utils/autostart.py +81 -0
  66. neuroshard/utils/hardware.py +121 -0
  67. neuroshard/utils/serialization.py +90 -0
  68. neuroshard/version.py +1 -0
  69. nexaroa-0.0.111.dist-info/METADATA +283 -0
  70. nexaroa-0.0.111.dist-info/RECORD +78 -0
  71. nexaroa-0.0.111.dist-info/WHEEL +5 -0
  72. nexaroa-0.0.111.dist-info/entry_points.txt +4 -0
  73. nexaroa-0.0.111.dist-info/licenses/LICENSE +190 -0
  74. nexaroa-0.0.111.dist-info/top_level.txt +2 -0
  75. protos/__init__.py +0 -0
  76. protos/neuroshard.proto +651 -0
  77. protos/neuroshard_pb2.py +160 -0
  78. protos/neuroshard_pb2_grpc.py +1298 -0
neuroshard/core/swarm/compute.py
@@ -0,0 +1,624 @@
+ """
+ Compute Engine - Decoupled GPU Worker with Soft Overflow
+
+ Implements async compute loop with:
+ - Priority-based activation processing
+ - Interleaved 1F1B schedule (forward/backward interleaving)
+ - Soft overflow handling ("Don't Stop" logic)
+ - DiLoCo-style local gradient accumulation during congestion
+
+ Key Directive: "If outbound.full(): Do not await. Instead:
+ accumulate_gradient_locally() and discard the activation.
+ Treat it as a DiLoCo-style local-only training step."
+
+ CRITICAL: GPU must NEVER wait for network. Compute loop must always make progress.
+ """
+
+ import asyncio
+ import logging
+ import time
+ import torch
+ from dataclasses import dataclass
+ from enum import Enum
+ from typing import Dict, Optional, Any, List, TYPE_CHECKING
+
+ from neuroshard.core.swarm.buffers import (
+     ActivationBuffer,
+     OutboundBuffer,
+     ActivationPacket,
+     ActivationPriority,
+ )
+
+ if TYPE_CHECKING:
+     from neuroshard.core.swarm.diloco import DiLoCoTrainer
+
+ logger = logging.getLogger(__name__)
+
+
+ class StepOutcome(Enum):
+     """Outcome of a compute step."""
+     SENT = "sent"              # Normal: activation sent to next peer
+     LOCAL_ONLY = "local_only"  # Soft overflow: accumulated locally, activation discarded
+     DROPPED = "dropped"        # Critical overflow: couldn't even accumulate
+
+
+ @dataclass
+ class ComputeStats:
+     """Statistics for compute engine."""
+     total_steps: int = 0
+     forward_count: int = 0
+     backward_count: int = 0
+     local_only_steps: int = 0
+     dropped_steps: int = 0
+     total_compute_time_ms: float = 0.0
+     total_queue_time_ms: float = 0.0
+
+     @property
+     def local_only_rate(self) -> float:
+         """Fraction of steps that were local-only due to overflow."""
+         if self.total_steps == 0:
+             return 0.0
+         return self.local_only_steps / self.total_steps
+
+     @property
+     def drop_rate(self) -> float:
+         """Fraction of steps that were completely dropped."""
+         if self.total_steps == 0:
+             return 0.0
+         return self.dropped_steps / self.total_steps
+
+     @property
+     def avg_compute_time_ms(self) -> float:
+         """Average compute time per step."""
+         if self.total_steps == 0:
+             return 0.0
+         return self.total_compute_time_ms / self.total_steps
+
+     def to_dict(self) -> Dict[str, Any]:
+         """Convert to dictionary."""
+         return {
+             "total_steps": self.total_steps,
+             "forward_count": self.forward_count,
+             "backward_count": self.backward_count,
+             "local_only_steps": self.local_only_steps,
+             "dropped_steps": self.dropped_steps,
+             "local_only_rate": self.local_only_rate,
+             "drop_rate": self.drop_rate,
+             "avg_compute_time_ms": self.avg_compute_time_ms,
+         }
+
+
+ class ComputeEngine:
+     """
+     Decoupled GPU compute worker with Soft Overflow handling.
+
+     Architecture:
+
+         ┌─────────────┐     ┌─────────────┐     ┌─────────────┐
+         │ InboundQueue│ ──→ │ GPU Compute │ ──→ │OutboundQueue│
+         │ (Priority)  │     │ (Never Wait)│     │ (Async Send)│
+         └─────────────┘     └─────────────┘     └─────────────┘
+
+     Key Behaviors:
+
+     1. PULLS from inbound buffer (priority queue)
+     2. COMPUTES forward/backward pass on GPU
+     3. PUSHES to outbound buffer (if not congested)
+
+     CRITICAL: Never waits for network - GPU must never stall!
+
+     Soft Overflow Logic ("Don't Stop" Mechanism):
+     =============================================
+     When outbound buffer is full (network congestion):
+
+     1. DO NOT await outbound.put() - this would stall GPU
+     2. Instead, accumulate gradients locally (DiLoCo style)
+     3. Discard activation (don't try to send)
+     4. Continue processing next packet
+
+     Rationale: Better to treat a step as "local-only training" than to
+     halt the GPU waiting for network. DiLoCo outer optimizer syncs later.
+
+     Interleaved 1F1B Schedule:
+     ==========================
+     For 4 micro-batches: F0 F1 F2 F3 B0 F4 B1 F5 B2 F6 B3 ...
+
+     Start backward passes BEFORE all forwards complete, overlapping
+     backward compute with forward network latency.
+     """
+
+     # Soft overflow thresholds
+     OUTBOUND_SOFT_LIMIT = 0.9   # Start soft overflow at 90% full
+     OUTBOUND_HARD_LIMIT = 0.99  # Hard limit - must discard
+
+     # Scheduling parameters
+     DEFAULT_NUM_MICRO_BATCHES = 4
+     DEFAULT_WARMUP_STEPS = 4  # Forward steps before interleaving
+
+     def __init__(
+         self,
+         model: Any,  # DynamicNeuroLLM
+         inbound: ActivationBuffer,
+         outbound: OutboundBuffer,
+         diloco_trainer: Optional['DiLoCoTrainer'] = None,
+         num_micro_batches: int = DEFAULT_NUM_MICRO_BATCHES,
+         node_id: str = "",
+     ):
+         """
+         Initialize compute engine.
+
+         Args:
+             model: Neural network model (DynamicNeuroLLM)
+             inbound: Buffer for incoming activations
+             outbound: Buffer for outgoing activations
+             diloco_trainer: Optional DiLoCo trainer for local accumulation
+             num_micro_batches: Number of micro-batches for 1F1B schedule
+             node_id: This node's identifier
+         """
+         self.model = model
+         self.inbound = inbound
+         self.outbound = outbound
+         self.diloco = diloco_trainer
+         self.num_micro_batches = num_micro_batches
+         self.node_id = node_id
+
+         # Device handling
+         self.device = getattr(model, 'device', 'cpu')
+         if hasattr(model, 'device'):
+             self.device = model.device
+         elif torch.cuda.is_available():
+             self.device = torch.device('cuda')
+         else:
+             self.device = torch.device('cpu')
+
+         # Interleaved 1F1B state
+         self.pending_backwards: Dict[int, ActivationPacket] = {}
+         self.saved_activations: Dict[int, torch.Tensor] = {}  # For backward pass
+
+         # Soft overflow state
+         self.local_gradient_buffer: Dict[str, torch.Tensor] = {}
+
+         # Statistics
+         self.stats = ComputeStats()
+
+         # State
+         self.running = False
+         self._task: Optional[asyncio.Task] = None
+
+         # Callbacks
+         self._on_forward_complete: Optional[callable] = None
+         self._on_backward_complete: Optional[callable] = None
+
+     def _check_outbound_pressure(self) -> str:
+         """
+         Check outbound buffer pressure level.
+
+         Returns:
+             "ok" - can send normally
+             "soft_overflow" - buffer almost full, use local accumulation
+             "hard_overflow" - buffer completely full, must discard
+         """
+         fill_rate = self.outbound.fill_rate
+
+         if fill_rate >= self.OUTBOUND_HARD_LIMIT:
+             return "hard_overflow"
+         elif fill_rate >= self.OUTBOUND_SOFT_LIMIT:
+             return "soft_overflow"
+         else:
+             return "ok"
+
+     async def start(self):
+         """Start the compute loop."""
+         if self.running:
+             return
+
+         self.running = True
+         self._task = asyncio.create_task(self.run())
+         logger.info(f"ComputeEngine started on {self.device}")
+
+     async def stop(self):
+         """Stop the compute loop gracefully."""
+         self.running = False
+
+         if self._task:
+             self._task.cancel()
+             try:
+                 await self._task
+             except asyncio.CancelledError:
+                 pass
+             self._task = None
+
+         logger.info("ComputeEngine stopped")
+
+     async def run(self):
+         """
+         Main compute loop with Interleaved 1F1B schedule and Soft Overflow.
+
+         This is the heart of the async engine. It:
+         1. Pulls packets from inbound queue
+         2. Processes forward/backward with interleaving
+         3. Handles network congestion gracefully
+         """
+         logger.info("ComputeEngine run loop started")
+
+         while self.running:
+             try:
+                 # Non-blocking get from inbound buffer
+                 packet = self.inbound.get(timeout=0.01)
+
+                 if packet is None:
+                     # Buffer empty - GPU potentially starved
+                     # Small sleep to avoid busy-wait
+                     await asyncio.sleep(0.001)
+                     continue
+
+                 # Process the packet
+                 await self._process_packet(packet)
+
+                 # Interleaved 1F1B: After warmup, interleave backwards
+                 if self.stats.forward_count >= self.num_micro_batches:
+                     await self._try_interleaved_backward()
+
+             except asyncio.CancelledError:
+                 break
+             except Exception as e:
+                 logger.error(f"ComputeEngine error: {e}", exc_info=True)
+                 await asyncio.sleep(0.1)  # Prevent tight error loop
+
+         # Cleanup
+         await self._flush_pending()
+         logger.info("ComputeEngine run loop ended")
+
+     async def _process_packet(self, packet: ActivationPacket):
+         """Process a single activation packet."""
+         self.stats.total_steps += 1
+         start_time = time.time()
+
+         # Track queue wait time
+         queue_time = (start_time - packet.timestamp) * 1000
+         self.stats.total_queue_time_ms += queue_time
+
+         if packet.is_backward:
+             outcome = await self._process_backward(packet)
+         else:
+             outcome = await self._process_forward_with_overflow(packet)
+
+         # Track compute time
+         compute_time = (time.time() - start_time) * 1000
+         self.stats.total_compute_time_ms += compute_time
+
+         # Update stats based on outcome
+         if outcome == StepOutcome.LOCAL_ONLY:
+             self.stats.local_only_steps += 1
+         elif outcome == StepOutcome.DROPPED:
+             self.stats.dropped_steps += 1
+
+         # Periodic logging
+         if self.stats.total_steps % 100 == 0:
+             self._log_stats()
+
+     def _log_stats(self):
+         """Log current statistics."""
+         logger.info(
+             f"ComputeEngine: steps={self.stats.total_steps}, "
+             f"forward={self.stats.forward_count}, backward={self.stats.backward_count}, "
+             f"local_only={self.stats.local_only_rate:.1%}, "
+             f"dropped={self.stats.drop_rate:.1%}, "
+             f"avg_compute={self.stats.avg_compute_time_ms:.1f}ms"
+         )
+
+     async def _process_forward_with_overflow(self, packet: ActivationPacket) -> StepOutcome:
+         """
+         Process forward pass with soft overflow handling.
+
+         The "Don't Stop" Logic:
+         1. Always compute forward pass (GPU never waits)
+         2. Check outbound pressure AFTER compute
+         3. If congested: accumulate locally, skip sending
+         4. If ok: queue for outbound
+         """
+         # ALWAYS compute forward - GPU must never stall
+         try:
+             with torch.no_grad() if not packet.requires_grad else torch.enable_grad():
+                 input_tensor = packet.tensor_data.to(self.device)
+
+                 # Forward through model layers
+                 if hasattr(self.model, 'forward_my_layers'):
+                     output = self.model.forward_my_layers(input_tensor)
+                 else:
+                     output = self.model(input_tensor)
+
+                 # Save activation for potential backward pass
+                 if packet.requires_grad:
+                     self.saved_activations[packet.micro_batch_id] = output.detach().clone()
+
+         except Exception as e:
+             logger.error(f"Forward pass error: {e}")
+             return StepOutcome.DROPPED
+
+         self.stats.forward_count += 1
+
+         # Check backpressure AFTER compute
+         pressure = self._check_outbound_pressure()
+
+         if pressure == "ok":
+             # Normal path: queue activation for sending
+             return await self._queue_forward_output(packet, output)
+         elif pressure == "soft_overflow":
+             # SOFT OVERFLOW: Network congested
+             return self._handle_soft_overflow(packet, output)
+         else:
+             # HARD OVERFLOW: Critical congestion
+             return self._handle_hard_overflow(packet, output)
+
+     async def _queue_forward_output(
+         self,
+         packet: ActivationPacket,
+         output: torch.Tensor
+     ) -> StepOutcome:
+         """Queue forward output for sending."""
+         # Determine next layer
+         if hasattr(self.model, 'my_layer_ids'):
+             next_layer = max(self.model.my_layer_ids) + 1
+         else:
+             next_layer = packet.target_layer + 1
+
+         outbound_packet = ActivationPacket(
+             priority=packet.priority,
+             session_id=packet.session_id,
+             micro_batch_id=packet.micro_batch_id,
+             tensor_data=output.cpu(),
+             source_node=self.node_id,
+             target_layer=next_layer,
+             requires_grad=packet.requires_grad,
+         )
+
+         # Non-blocking put with short timeout
+         try:
+             await asyncio.wait_for(
+                 self.outbound.put(outbound_packet),
+                 timeout=0.01  # 10ms max wait
+             )
+             return StepOutcome.SENT
+         except asyncio.TimeoutError:
+             # Couldn't send in time - fall through to soft overflow
+             return self._handle_soft_overflow(packet, output)
+         except asyncio.QueueFull:
+             return self._handle_soft_overflow(packet, output)
+
+     def _handle_soft_overflow(
+         self,
+         packet: ActivationPacket,
+         output: torch.Tensor
+     ) -> StepOutcome:
+         """
+         Handle soft overflow: accumulate locally, discard activation.
+
+         This implements DiLoCo "local training" behavior during congestion.
+         """
+         logger.debug(
+             f"Soft overflow at step {self.stats.total_steps}: "
+             f"accumulating locally (outbound: {self.outbound.fill_rate:.1%})"
+         )
+
+         self.outbound.soft_overflow_count += 1
+
+         # Accumulate gradient locally if training
+         if packet.requires_grad and self.model.training:
+             self._accumulate_local_gradient(output, packet)
+
+         # Discard activation - don't try to send
+         del output
+
+         return StepOutcome.LOCAL_ONLY
+
+     def _handle_hard_overflow(
+         self,
+         packet: ActivationPacket,
+         output: torch.Tensor
+     ) -> StepOutcome:
+         """
+         Handle hard overflow: must drop step entirely.
+         """
+         logger.warning(
+             f"Hard overflow at step {self.stats.total_steps}: "
+             f"dropping step (outbound: {self.outbound.fill_rate:.1%})"
+         )
+
+         self.outbound.hard_overflow_count += 1
+
+         del output
+         return StepOutcome.DROPPED
+
+     def _accumulate_local_gradient(
+         self,
+         output: torch.Tensor,
+         packet: ActivationPacket
+     ):
+         """
+         Accumulate gradient locally during soft overflow.
+
+         This is the DiLoCo "local training" behavior - gradients are
+         accumulated locally and will be synced during next outer step.
+         """
+         if not self.model.training:
+             return
+
+         # If we have upstream gradient, compute local gradient
+         if packet.grad_output is not None:
+             try:
+                 output.backward(packet.grad_output.to(self.device))
+
+                 # Track in DiLoCo trainer if available
+                 if self.diloco:
+                     self.diloco.inner_step_count += 1
+             except Exception as e:
+                 logger.debug(f"Local gradient accumulation error: {e}")
+
+     async def _process_backward(self, packet: ActivationPacket) -> StepOutcome:
+         """
+         Process backward pass.
+
+         Backward passes must always be processed - they contain gradients
+         that need to be applied. However, gradient sending respects overflow.
+         """
+         try:
+             # Get saved activation for this micro-batch
+             saved_act = self.saved_activations.pop(packet.micro_batch_id, None)
+
+             if saved_act is not None and packet.grad_output is not None:
+                 grad_output = packet.grad_output.to(self.device)
+
+                 # Recompute forward and backward
+                 saved_act.requires_grad_(True)
+
+                 # Backward through model
+                 if hasattr(self.model, 'backward_my_layers'):
+                     grad_input = self.model.backward_my_layers(saved_act, grad_output)
+                 else:
+                     # Standard backward
+                     saved_act.backward(grad_output)
+                     grad_input = saved_act.grad
+
+         except Exception as e:
+             logger.error(f"Backward pass error: {e}")
+             return StepOutcome.DROPPED
+
+         self.stats.backward_count += 1
+
+         # Backward passes also respect soft overflow for gradient sending
+         pressure = self._check_outbound_pressure()
+
+         if pressure != "ok":
+             logger.debug(f"Backward gradient local accumulation")
+             return StepOutcome.LOCAL_ONLY
+
+         return StepOutcome.SENT
+
+     async def _try_interleaved_backward(self):
+         """
+         Interleaved 1F1B: Process oldest pending backward if available.
+
+         This interleaves backward passes with forward passes to hide
+         network latency and maintain GPU utilization.
+         """
+         if not self.pending_backwards:
+             return
+
+         # Get oldest micro-batch that needs backward
+         oldest_mb = min(self.pending_backwards.keys())
+         packet = self.pending_backwards.pop(oldest_mb)
+
+         await self._process_backward(packet)
+
+     async def _flush_pending(self):
+         """Process any remaining pending backward passes."""
+         while self.pending_backwards:
+             oldest_mb = min(self.pending_backwards.keys())
+             packet = self.pending_backwards.pop(oldest_mb)
+             await self._process_backward(packet)
+
+     def get_stats(self) -> Dict[str, Any]:
+         """Get comprehensive engine statistics."""
+         return {
+             **self.stats.to_dict(),
+             "outbound_fill_rate": self.outbound.fill_rate,
+             "inbound_fill_rate": self.inbound.fill_rate,
+             "pending_backwards": len(self.pending_backwards),
+             "saved_activations": len(self.saved_activations),
+             "device": str(self.device),
+         }
+
+
+ class InferenceEngine:
+     """
+     Simplified compute engine for inference-only workloads.
+
+     No gradient handling, no backward passes, no soft overflow.
+     Just fast forward pass processing.
+     """
+
+     def __init__(
+         self,
+         model: Any,
+         inbound: ActivationBuffer,
+         outbound: OutboundBuffer,
+         node_id: str = "",
+     ):
+         self.model = model
+         self.inbound = inbound
+         self.outbound = outbound
+         self.node_id = node_id
+
+         # Device
+         self.device = getattr(model, 'device', torch.device('cpu'))
+
+         # Stats
+         self.requests_processed = 0
+         self.total_latency_ms = 0.0
+
+         self.running = False
+         self._task: Optional[asyncio.Task] = None
+
+     async def start(self):
+         """Start inference loop."""
+         self.running = True
+         self._task = asyncio.create_task(self._run())
+
+     async def stop(self):
+         """Stop inference loop."""
+         self.running = False
+         if self._task:
+             self._task.cancel()
+             try:
+                 await self._task
+             except asyncio.CancelledError:
+                 pass
+
+     async def _run(self):
+         """Main inference loop."""
+         while self.running:
+             try:
+                 packet = self.inbound.get(timeout=0.01)
+                 if packet is None:
+                     await asyncio.sleep(0.001)
+                     continue
+
+                 start = time.time()
+
+                 with torch.no_grad():
+                     input_tensor = packet.tensor_data.to(self.device)
+                     output = self.model(input_tensor)
+
+                 # Queue output
+                 outbound_packet = ActivationPacket(
+                     priority=packet.priority,
+                     session_id=packet.session_id,
+                     micro_batch_id=packet.micro_batch_id,
+                     tensor_data=output.cpu(),
+                     source_node=self.node_id,
+                     target_layer=packet.target_layer + 1,
+                 )
+
+                 await self.outbound.put(outbound_packet)
+
+                 self.requests_processed += 1
+                 self.total_latency_ms += (time.time() - start) * 1000
+
+             except asyncio.CancelledError:
+                 break
+             except Exception as e:
+                 logger.error(f"Inference error: {e}")
+
+     def get_stats(self) -> Dict[str, Any]:
+         """Get inference statistics."""
+         avg_latency = (
+             self.total_latency_ms / max(1, self.requests_processed)
+         )
+         return {
+             "requests_processed": self.requests_processed,
+             "avg_latency_ms": avg_latency,
+             "inbound_fill": self.inbound.fill_rate,
+             "outbound_fill": self.outbound.fill_rate,
+         }
+
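
Editor's note: the soft-overflow ("Don't Stop") path in ComputeEngine chooses between sending, a DiLoCo-style local-only step, and dropping, based on the outbound fill rate. The sketch below is illustrative only and not part of the package: a plain asyncio.Queue stands in for OutboundBuffer, string outcomes stand in for StepOutcome, and the thresholds simply mirror the class constants above.

    # Illustrative sketch only -- not from the nexaroa/neuroshard package.
    import asyncio

    SOFT_LIMIT = 0.9   # mirrors ComputeEngine.OUTBOUND_SOFT_LIMIT
    HARD_LIMIT = 0.99  # mirrors ComputeEngine.OUTBOUND_HARD_LIMIT


    def pressure(outbound: asyncio.Queue) -> str:
        """Classify congestion the way _check_outbound_pressure does."""
        fill_rate = outbound.qsize() / outbound.maxsize
        if fill_rate >= HARD_LIMIT:
            return "hard_overflow"
        if fill_rate >= SOFT_LIMIT:
            return "soft_overflow"
        return "ok"


    async def route_activation(outbound: asyncio.Queue, activation) -> str:
        """Send when there is room; otherwise fall back to a local-only step."""
        level = pressure(outbound)
        if level == "ok":
            try:
                # Bounded wait so the compute loop never blocks on the network.
                await asyncio.wait_for(outbound.put(activation), timeout=0.01)
                return "sent"
            except asyncio.TimeoutError:
                level = "soft_overflow"
        if level == "soft_overflow":
            # DiLoCo-style local-only step: keep the gradient, discard the activation.
            return "local_only"
        return "dropped"


    async def demo():
        outbound: asyncio.Queue = asyncio.Queue(maxsize=10)
        for i in range(12):
            print(i, await route_activation(outbound, f"activation-{i}"))


    if __name__ == "__main__":
        asyncio.run(demo())

With maxsize=10, the first nine activations are "sent"; once the queue reaches 90% full the remaining ones come back as "local_only", which is the point where the real engine increments outbound.soft_overflow_count and calls _accumulate_local_gradient instead of waiting on the network.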