nexaroa 0.0.111__py3-none-any.whl
This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- neuroshard/__init__.py +93 -0
- neuroshard/__main__.py +4 -0
- neuroshard/cli.py +466 -0
- neuroshard/core/__init__.py +92 -0
- neuroshard/core/consensus/verifier.py +252 -0
- neuroshard/core/crypto/__init__.py +20 -0
- neuroshard/core/crypto/ecdsa.py +392 -0
- neuroshard/core/economics/__init__.py +52 -0
- neuroshard/core/economics/constants.py +387 -0
- neuroshard/core/economics/ledger.py +2111 -0
- neuroshard/core/economics/market.py +975 -0
- neuroshard/core/economics/wallet.py +168 -0
- neuroshard/core/governance/__init__.py +74 -0
- neuroshard/core/governance/proposal.py +561 -0
- neuroshard/core/governance/registry.py +545 -0
- neuroshard/core/governance/versioning.py +332 -0
- neuroshard/core/governance/voting.py +453 -0
- neuroshard/core/model/__init__.py +30 -0
- neuroshard/core/model/dynamic.py +4186 -0
- neuroshard/core/model/llm.py +905 -0
- neuroshard/core/model/registry.py +164 -0
- neuroshard/core/model/scaler.py +387 -0
- neuroshard/core/model/tokenizer.py +568 -0
- neuroshard/core/network/__init__.py +56 -0
- neuroshard/core/network/connection_pool.py +72 -0
- neuroshard/core/network/dht.py +130 -0
- neuroshard/core/network/dht_plan.py +55 -0
- neuroshard/core/network/dht_proof_store.py +516 -0
- neuroshard/core/network/dht_protocol.py +261 -0
- neuroshard/core/network/dht_service.py +506 -0
- neuroshard/core/network/encrypted_channel.py +141 -0
- neuroshard/core/network/nat.py +201 -0
- neuroshard/core/network/nat_traversal.py +695 -0
- neuroshard/core/network/p2p.py +929 -0
- neuroshard/core/network/p2p_data.py +150 -0
- neuroshard/core/swarm/__init__.py +106 -0
- neuroshard/core/swarm/aggregation.py +729 -0
- neuroshard/core/swarm/buffers.py +643 -0
- neuroshard/core/swarm/checkpoint.py +709 -0
- neuroshard/core/swarm/compute.py +624 -0
- neuroshard/core/swarm/diloco.py +844 -0
- neuroshard/core/swarm/factory.py +1288 -0
- neuroshard/core/swarm/heartbeat.py +669 -0
- neuroshard/core/swarm/logger.py +487 -0
- neuroshard/core/swarm/router.py +658 -0
- neuroshard/core/swarm/service.py +640 -0
- neuroshard/core/training/__init__.py +29 -0
- neuroshard/core/training/checkpoint.py +600 -0
- neuroshard/core/training/distributed.py +1602 -0
- neuroshard/core/training/global_tracker.py +617 -0
- neuroshard/core/training/production.py +276 -0
- neuroshard/governance_cli.py +729 -0
- neuroshard/grpc_server.py +895 -0
- neuroshard/runner.py +3223 -0
- neuroshard/sdk/__init__.py +92 -0
- neuroshard/sdk/client.py +990 -0
- neuroshard/sdk/errors.py +101 -0
- neuroshard/sdk/types.py +282 -0
- neuroshard/tracker/__init__.py +0 -0
- neuroshard/tracker/server.py +864 -0
- neuroshard/ui/__init__.py +0 -0
- neuroshard/ui/app.py +102 -0
- neuroshard/ui/templates/index.html +1052 -0
- neuroshard/utils/__init__.py +0 -0
- neuroshard/utils/autostart.py +81 -0
- neuroshard/utils/hardware.py +121 -0
- neuroshard/utils/serialization.py +90 -0
- neuroshard/version.py +1 -0
- nexaroa-0.0.111.dist-info/METADATA +283 -0
- nexaroa-0.0.111.dist-info/RECORD +78 -0
- nexaroa-0.0.111.dist-info/WHEEL +5 -0
- nexaroa-0.0.111.dist-info/entry_points.txt +4 -0
- nexaroa-0.0.111.dist-info/licenses/LICENSE +190 -0
- nexaroa-0.0.111.dist-info/top_level.txt +2 -0
- protos/__init__.py +0 -0
- protos/neuroshard.proto +651 -0
- protos/neuroshard_pb2.py +160 -0
- protos/neuroshard_pb2_grpc.py +1298 -0
@@ -0,0 +1,624 @@
"""
Compute Engine - Decoupled GPU Worker with Soft Overflow

Implements an async compute loop with:
- Priority-based activation processing
- Interleaved 1F1B schedule (forward/backward interleaving)
- Soft overflow handling ("Don't Stop" logic)
- DiLoCo-style local gradient accumulation during congestion

Key Directive: "If outbound.full(): Do not await. Instead:
accumulate_gradient_locally() and discard the activation.
Treat it as a DiLoCo-style local-only training step."

CRITICAL: GPU must NEVER wait for network. The compute loop must always make progress.
"""

import asyncio
import logging
import time
from dataclasses import dataclass
from enum import Enum
from typing import Any, Callable, Dict, Optional, TYPE_CHECKING

import torch

from neuroshard.core.swarm.buffers import (
    ActivationBuffer,
    OutboundBuffer,
    ActivationPacket,
    ActivationPriority,
)

if TYPE_CHECKING:
    from neuroshard.core.swarm.diloco import DiLoCoTrainer

logger = logging.getLogger(__name__)

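The buffer types imported here live in `neuroshard/core/swarm/buffers.py` (+643 lines, not shown in this section). As an aid to reading the engine below, here is a minimal sketch of the interfaces `compute.py` actually exercises — the field names, defaults, and queue choices are inferred from the call sites and are assumptions, not the package's real definitions:

```python
import asyncio
import time
from dataclasses import dataclass, field
from enum import IntEnum
from queue import Empty, PriorityQueue
from typing import Optional

import torch


class ActivationPriority(IntEnum):
    """Assumed priority levels (the real enum is defined in buffers.py)."""
    HIGH = 0
    NORMAL = 1
    LOW = 2


@dataclass(order=True)
class ActivationPacket:
    # Only `priority` participates in ordering; everything else is payload.
    priority: int
    session_id: str = field(compare=False, default="")
    micro_batch_id: int = field(compare=False, default=0)
    tensor_data: Optional[torch.Tensor] = field(compare=False, default=None)
    source_node: str = field(compare=False, default="")
    target_layer: int = field(compare=False, default=0)
    requires_grad: bool = field(compare=False, default=False)
    grad_output: Optional[torch.Tensor] = field(compare=False, default=None)
    is_backward: bool = field(compare=False, default=False)
    timestamp: float = field(compare=False, default_factory=time.time)


class ActivationBuffer:
    """Inbound priority queue; get() returns None instead of blocking forever."""

    def __init__(self, maxsize: int = 256):
        self._q: "PriorityQueue[ActivationPacket]" = PriorityQueue(maxsize)

    def put(self, pkt: ActivationPacket) -> None:
        self._q.put_nowait(pkt)

    def get(self, timeout: float = 0.01) -> Optional[ActivationPacket]:
        try:
            return self._q.get(timeout=timeout)
        except Empty:
            return None

    @property
    def fill_rate(self) -> float:
        return self._q.qsize() / self._q.maxsize


class OutboundBuffer:
    """Outbound async queue drained by a network sender; tracks overflow events."""

    def __init__(self, maxsize: int = 256):
        self._q: "asyncio.Queue[ActivationPacket]" = asyncio.Queue(maxsize)
        self.soft_overflow_count = 0
        self.hard_overflow_count = 0

    async def put(self, pkt: ActivationPacket) -> None:
        await self._q.put(pkt)

    @property
    def fill_rate(self) -> float:
        return self._q.qsize() / self._q.maxsize
```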
class StepOutcome(Enum):
    """Outcome of a compute step."""
    SENT = "sent"              # Normal: activation sent to next peer
    LOCAL_ONLY = "local_only"  # Soft overflow: accumulated locally, activation discarded
    DROPPED = "dropped"        # Critical overflow: couldn't even accumulate

@dataclass
class ComputeStats:
    """Statistics for compute engine."""
    total_steps: int = 0
    forward_count: int = 0
    backward_count: int = 0
    local_only_steps: int = 0
    dropped_steps: int = 0
    total_compute_time_ms: float = 0.0
    total_queue_time_ms: float = 0.0

    @property
    def local_only_rate(self) -> float:
        """Fraction of steps that were local-only due to overflow."""
        if self.total_steps == 0:
            return 0.0
        return self.local_only_steps / self.total_steps

    @property
    def drop_rate(self) -> float:
        """Fraction of steps that were completely dropped."""
        if self.total_steps == 0:
            return 0.0
        return self.dropped_steps / self.total_steps

    @property
    def avg_compute_time_ms(self) -> float:
        """Average compute time per step."""
        if self.total_steps == 0:
            return 0.0
        return self.total_compute_time_ms / self.total_steps

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary."""
        return {
            "total_steps": self.total_steps,
            "forward_count": self.forward_count,
            "backward_count": self.backward_count,
            "local_only_steps": self.local_only_steps,
            "dropped_steps": self.dropped_steps,
            "local_only_rate": self.local_only_rate,
            "drop_rate": self.drop_rate,
            "avg_compute_time_ms": self.avg_compute_time_ms,
        }

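To make the derived rates concrete, a small usage example (the numbers are invented for illustration):

```python
stats = ComputeStats(
    total_steps=1000, forward_count=880, backward_count=120,
    local_only_steps=150, dropped_steps=5, total_compute_time_ms=12_500.0,
)
assert stats.local_only_rate == 0.15    # 15% of steps fell back to local-only
assert stats.drop_rate == 0.005         # 0.5% of steps were dropped outright
assert stats.avg_compute_time_ms == 12.5
```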
class ComputeEngine:
    """
    Decoupled GPU compute worker with Soft Overflow handling.

    Architecture:

    ┌─────────────┐      ┌─────────────┐      ┌─────────────┐
    │ InboundQueue│ ──→  │ GPU Compute │ ──→  │OutboundQueue│
    │ (Priority)  │      │ (Never Wait)│      │ (Async Send)│
    └─────────────┘      └─────────────┘      └─────────────┘

    Key Behaviors:

    1. PULLS from inbound buffer (priority queue)
    2. COMPUTES forward/backward pass on GPU
    3. PUSHES to outbound buffer (if not congested)

    CRITICAL: Never waits for network - GPU must never stall!

    Soft Overflow Logic ("Don't Stop" Mechanism):
    =============================================
    When outbound buffer is full (network congestion):

    1. DO NOT await outbound.put() - this would stall GPU
    2. Instead, accumulate gradients locally (DiLoCo style)
    3. Discard activation (don't try to send)
    4. Continue processing next packet

    Rationale: Better to treat a step as "local-only training" than to
    halt the GPU waiting for network. DiLoCo outer optimizer syncs later.

    Interleaved 1F1B Schedule:
    ==========================
    For 4 micro-batches: F0 F1 F2 F3 B0 F4 B1 F5 B2 F6 B3 ...

    Start backward passes BEFORE all forwards complete, overlapping
    backward compute with forward network latency.
    """

    # Soft overflow thresholds
    OUTBOUND_SOFT_LIMIT = 0.9   # Start soft overflow at 90% full
    OUTBOUND_HARD_LIMIT = 0.99  # Hard limit - must discard

    # Scheduling parameters
    DEFAULT_NUM_MICRO_BATCHES = 4
    DEFAULT_WARMUP_STEPS = 4  # Forward steps before interleaving

    def __init__(
        self,
        model: Any,  # DynamicNeuroLLM
        inbound: ActivationBuffer,
        outbound: OutboundBuffer,
        diloco_trainer: Optional['DiLoCoTrainer'] = None,
        num_micro_batches: int = DEFAULT_NUM_MICRO_BATCHES,
        node_id: str = "",
    ):
        """
        Initialize compute engine.

        Args:
            model: Neural network model (DynamicNeuroLLM)
            inbound: Buffer for incoming activations
            outbound: Buffer for outgoing activations
            diloco_trainer: Optional DiLoCo trainer for local accumulation
            num_micro_batches: Number of micro-batches for 1F1B schedule
            node_id: This node's identifier
        """
        self.model = model
        self.inbound = inbound
        self.outbound = outbound
        self.diloco = diloco_trainer
        self.num_micro_batches = num_micro_batches
        self.node_id = node_id

        # Device handling: prefer the model's device, then CUDA, then CPU
        if hasattr(model, 'device'):
            self.device = model.device
        elif torch.cuda.is_available():
            self.device = torch.device('cuda')
        else:
            self.device = torch.device('cpu')

        # Interleaved 1F1B state
        self.pending_backwards: Dict[int, ActivationPacket] = {}
        self.saved_activations: Dict[int, torch.Tensor] = {}  # For backward pass

        # Soft overflow state
        self.local_gradient_buffer: Dict[str, torch.Tensor] = {}

        # Statistics
        self.stats = ComputeStats()

        # State
        self.running = False
        self._task: Optional[asyncio.Task] = None

        # Callbacks
        self._on_forward_complete: Optional[Callable] = None
        self._on_backward_complete: Optional[Callable] = None

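A possible end-to-end wiring of the engine, reusing the sketched buffers above; `nn.Linear` stands in for a shard of the package's `DynamicNeuroLLM` (defined in neuroshard/core/model/dynamic.py, not shown here):

```python
import asyncio
import torch.nn as nn

async def main() -> None:
    inbound = ActivationBuffer(maxsize=256)
    outbound = OutboundBuffer(maxsize=256)

    engine = ComputeEngine(
        model=nn.Linear(512, 512),  # stand-in for a DynamicNeuroLLM shard
        inbound=inbound,
        outbound=outbound,
        num_micro_batches=4,
        node_id="node-a",
    )
    await engine.start()

    # ... network tasks would feed `inbound` and drain `outbound` here ...
    await asyncio.sleep(1.0)

    await engine.stop()
    print(engine.get_stats())

asyncio.run(main())
```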
    def _check_outbound_pressure(self) -> str:
        """
        Check outbound buffer pressure level.

        Returns:
            "ok" - can send normally
            "soft_overflow" - buffer almost full, use local accumulation
            "hard_overflow" - buffer completely full, must discard
        """
        fill_rate = self.outbound.fill_rate

        if fill_rate >= self.OUTBOUND_HARD_LIMIT:
            return "hard_overflow"
        elif fill_rate >= self.OUTBOUND_SOFT_LIMIT:
            return "soft_overflow"
        else:
            return "ok"

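Translated into packet counts for a hypothetical 256-slot outbound queue, the two class-level thresholds work out as follows:

```python
import math

capacity = 256
soft_at = math.ceil(capacity * ComputeEngine.OUTBOUND_SOFT_LIMIT)  # 231 packets
hard_at = math.ceil(capacity * ComputeEngine.OUTBOUND_HARD_LIMIT)  # 254 packets
# 0..230 packets queued -> "ok"
# 231..253              -> "soft_overflow" (accumulate locally, discard activation)
# 254..256              -> "hard_overflow" (drop the step entirely)
```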
    async def start(self):
        """Start the compute loop."""
        if self.running:
            return

        self.running = True
        self._task = asyncio.create_task(self.run())
        logger.info(f"ComputeEngine started on {self.device}")

    async def stop(self):
        """Stop the compute loop gracefully."""
        self.running = False

        if self._task:
            self._task.cancel()
            try:
                await self._task
            except asyncio.CancelledError:
                pass
            self._task = None

        logger.info("ComputeEngine stopped")

    async def run(self):
        """
        Main compute loop with Interleaved 1F1B schedule and Soft Overflow.

        This is the heart of the async engine. It:
        1. Pulls packets from inbound queue
        2. Processes forward/backward with interleaving
        3. Handles network congestion gracefully
        """
        logger.info("ComputeEngine run loop started")

        while self.running:
            try:
                # Short-timeout get from inbound buffer (waits at most 10ms)
                packet = self.inbound.get(timeout=0.01)

                if packet is None:
                    # Buffer empty - GPU potentially starved
                    # Small sleep to avoid busy-wait
                    await asyncio.sleep(0.001)
                    continue

                # Process the packet
                await self._process_packet(packet)

                # Interleaved 1F1B: After warmup, interleave backwards
                if self.stats.forward_count >= self.num_micro_batches:
                    await self._try_interleaved_backward()

            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"ComputeEngine error: {e}", exc_info=True)
                await asyncio.sleep(0.1)  # Prevent tight error loop

        # Cleanup
        await self._flush_pending()
        logger.info("ComputeEngine run loop ended")

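The warmup test above (`forward_count >= num_micro_batches`) is what yields the interleaving sketched in the class docstring. An illustrative, standalone generator of that order — not part of the package — makes the schedule explicit:

```python
def one_f_one_b_order(num_micro_batches: int, total_forwards: int) -> list:
    """Warm up with `num_micro_batches` forwards, then alternate 1B/1F."""
    order = [f"F{i}" for i in range(min(num_micro_batches, total_forwards))]
    next_f, next_b = len(order), 0
    while next_f < total_forwards:
        order += [f"B{next_b}", f"F{next_f}"]
        next_b, next_f = next_b + 1, next_f + 1
    while next_b < total_forwards:  # drain the remaining backwards
        order.append(f"B{next_b}")
        next_b += 1
    return order

print(" ".join(one_f_one_b_order(4, 7)))
# F0 F1 F2 F3 B0 F4 B1 F5 B2 F6 B3 B4 B5 B6  (matches the docstring example)
```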
    async def _process_packet(self, packet: ActivationPacket):
        """Process a single activation packet."""
        self.stats.total_steps += 1
        start_time = time.time()

        # Track queue wait time
        queue_time = (start_time - packet.timestamp) * 1000
        self.stats.total_queue_time_ms += queue_time

        if packet.is_backward:
            outcome = await self._process_backward(packet)
        else:
            outcome = await self._process_forward_with_overflow(packet)

        # Track compute time
        compute_time = (time.time() - start_time) * 1000
        self.stats.total_compute_time_ms += compute_time

        # Update stats based on outcome
        if outcome == StepOutcome.LOCAL_ONLY:
            self.stats.local_only_steps += 1
        elif outcome == StepOutcome.DROPPED:
            self.stats.dropped_steps += 1

        # Periodic logging
        if self.stats.total_steps % 100 == 0:
            self._log_stats()

    def _log_stats(self):
        """Log current statistics."""
        logger.info(
            f"ComputeEngine: steps={self.stats.total_steps}, "
            f"forward={self.stats.forward_count}, backward={self.stats.backward_count}, "
            f"local_only={self.stats.local_only_rate:.1%}, "
            f"dropped={self.stats.drop_rate:.1%}, "
            f"avg_compute={self.stats.avg_compute_time_ms:.1f}ms"
        )

    async def _process_forward_with_overflow(self, packet: ActivationPacket) -> StepOutcome:
        """
        Process forward pass with soft overflow handling.

        The "Don't Stop" Logic:
        1. Always compute forward pass (GPU never waits)
        2. Check outbound pressure AFTER compute
        3. If congested: accumulate locally, skip sending
        4. If ok: queue for outbound
        """
        # ALWAYS compute forward - GPU must never stall
        try:
            with torch.no_grad() if not packet.requires_grad else torch.enable_grad():
                input_tensor = packet.tensor_data.to(self.device)

                # Forward through model layers
                if hasattr(self.model, 'forward_my_layers'):
                    output = self.model.forward_my_layers(input_tensor)
                else:
                    output = self.model(input_tensor)

                # Save activation for potential backward pass
                if packet.requires_grad:
                    self.saved_activations[packet.micro_batch_id] = output.detach().clone()

        except Exception as e:
            logger.error(f"Forward pass error: {e}")
            return StepOutcome.DROPPED

        self.stats.forward_count += 1

        # Check backpressure AFTER compute
        pressure = self._check_outbound_pressure()

        if pressure == "ok":
            # Normal path: queue activation for sending
            return await self._queue_forward_output(packet, output)
        elif pressure == "soft_overflow":
            # SOFT OVERFLOW: Network congested
            return self._handle_soft_overflow(packet, output)
        else:
            # HARD OVERFLOW: Critical congestion
            return self._handle_hard_overflow(packet, output)

    async def _queue_forward_output(
        self,
        packet: ActivationPacket,
        output: torch.Tensor
    ) -> StepOutcome:
        """Queue forward output for sending."""
        # Determine next layer
        if hasattr(self.model, 'my_layer_ids'):
            next_layer = max(self.model.my_layer_ids) + 1
        else:
            next_layer = packet.target_layer + 1

        outbound_packet = ActivationPacket(
            priority=packet.priority,
            session_id=packet.session_id,
            micro_batch_id=packet.micro_batch_id,
            tensor_data=output.cpu(),
            source_node=self.node_id,
            target_layer=next_layer,
            requires_grad=packet.requires_grad,
        )

        # Near-non-blocking put with short timeout
        try:
            await asyncio.wait_for(
                self.outbound.put(outbound_packet),
                timeout=0.01  # 10ms max wait
            )
            return StepOutcome.SENT
        except asyncio.TimeoutError:
            # Couldn't send in time - fall through to soft overflow
            return self._handle_soft_overflow(packet, output)
        except asyncio.QueueFull:
            # Defensive: only raised if the buffer enqueues with put_nowait internally
            return self._handle_soft_overflow(packet, output)

    def _handle_soft_overflow(
        self,
        packet: ActivationPacket,
        output: torch.Tensor
    ) -> StepOutcome:
        """
        Handle soft overflow: accumulate locally, discard activation.

        This implements DiLoCo "local training" behavior during congestion.
        """
        logger.debug(
            f"Soft overflow at step {self.stats.total_steps}: "
            f"accumulating locally (outbound: {self.outbound.fill_rate:.1%})"
        )

        self.outbound.soft_overflow_count += 1

        # Accumulate gradient locally if training
        if packet.requires_grad and self.model.training:
            self._accumulate_local_gradient(output, packet)

        # Discard activation - don't try to send
        del output

        return StepOutcome.LOCAL_ONLY

    def _handle_hard_overflow(
        self,
        packet: ActivationPacket,
        output: torch.Tensor
    ) -> StepOutcome:
        """
        Handle hard overflow: must drop step entirely.
        """
        logger.warning(
            f"Hard overflow at step {self.stats.total_steps}: "
            f"dropping step (outbound: {self.outbound.fill_rate:.1%})"
        )

        self.outbound.hard_overflow_count += 1

        del output
        return StepOutcome.DROPPED

    def _accumulate_local_gradient(
        self,
        output: torch.Tensor,
        packet: ActivationPacket
    ):
        """
        Accumulate gradient locally during soft overflow.

        This is the DiLoCo "local training" behavior - gradients are
        accumulated locally and will be synced during the next outer step.
        """
        if not self.model.training:
            return

        # If we have an upstream gradient, compute the local gradient
        if packet.grad_output is not None:
            try:
                output.backward(packet.grad_output.to(self.device))

                # Track in DiLoCo trainer if available
                if self.diloco:
                    self.diloco.inner_step_count += 1
            except Exception as e:
                logger.debug(f"Local gradient accumulation error: {e}")

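For context, the outer synchronization that later reconciles these local-only steps follows the usual DiLoCo recipe. The package's actual implementation is in neuroshard/core/swarm/diloco.py (+844 lines, not shown in this section), so this is only a simplified sketch of that outer step:

```python
import torch

@torch.no_grad()
def diloco_outer_step(global_params, local_params, outer_opt) -> None:
    """Standard DiLoCo outer step: apply the local drift as a pseudo-gradient."""
    for gp, lp in zip(global_params, local_params):
        gp.grad = gp.data - lp.data         # pseudo-gradient = global - local
    outer_opt.step()                        # e.g. SGD with Nesterov momentum
    outer_opt.zero_grad(set_to_none=True)
    for gp, lp in zip(global_params, local_params):
        lp.data.copy_(gp.data)              # resync the local replica
```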
    async def _process_backward(self, packet: ActivationPacket) -> StepOutcome:
        """
        Process backward pass.

        Backward passes must always be processed - they contain gradients
        that need to be applied. However, gradient sending respects overflow.
        """
        try:
            # Get saved activation for this micro-batch
            saved_act = self.saved_activations.pop(packet.micro_batch_id, None)

            if saved_act is not None and packet.grad_output is not None:
                grad_output = packet.grad_output.to(self.device)

                # Recompute forward and backward
                saved_act.requires_grad_(True)

                # Backward through model
                if hasattr(self.model, 'backward_my_layers'):
                    grad_input = self.model.backward_my_layers(saved_act, grad_output)
                else:
                    # Standard backward
                    saved_act.backward(grad_output)
                    grad_input = saved_act.grad
                # NOTE: grad_input is computed here but not yet routed upstream.

        except Exception as e:
            logger.error(f"Backward pass error: {e}")
            return StepOutcome.DROPPED

        self.stats.backward_count += 1

        # Backward passes also respect soft overflow for gradient sending
        pressure = self._check_outbound_pressure()

        if pressure != "ok":
            logger.debug("Backward pass done; outbound congested - keeping gradient local")
            return StepOutcome.LOCAL_ONLY

        return StepOutcome.SENT

    async def _try_interleaved_backward(self):
        """
        Interleaved 1F1B: Process oldest pending backward if available.

        This interleaves backward passes with forward passes to hide
        network latency and maintain GPU utilization.
        """
        if not self.pending_backwards:
            return

        # Get oldest micro-batch that needs backward
        oldest_mb = min(self.pending_backwards.keys())
        packet = self.pending_backwards.pop(oldest_mb)

        await self._process_backward(packet)

    async def _flush_pending(self):
        """Process any remaining pending backward passes."""
        while self.pending_backwards:
            oldest_mb = min(self.pending_backwards.keys())
            packet = self.pending_backwards.pop(oldest_mb)
            await self._process_backward(packet)

    def get_stats(self) -> Dict[str, Any]:
        """Get comprehensive engine statistics."""
        return {
            **self.stats.to_dict(),
            "outbound_fill_rate": self.outbound.fill_rate,
            "inbound_fill_rate": self.inbound.fill_rate,
            "pending_backwards": len(self.pending_backwards),
            "saved_activations": len(self.saved_activations),
            "device": str(self.device),
        }

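One way these stats might be consumed (a hypothetical health probe, not part of the package): a sustained `local_only_rate` says the network, not the GPU, is the bottleneck.

```python
async def watch(engine: ComputeEngine, period: float = 30.0) -> None:
    """Periodically surface congestion: a high local_only_rate means network-bound."""
    while True:
        s = engine.get_stats()
        if s["local_only_rate"] > 0.5:
            logger.warning(
                "Node is network-bound: %.0f%% of steps were local-only "
                "(outbound fill %.0f%%)",
                s["local_only_rate"] * 100, s["outbound_fill_rate"] * 100,
            )
        await asyncio.sleep(period)
```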
class InferenceEngine:
    """
    Simplified compute engine for inference-only workloads.

    No gradient handling, no backward passes, no soft overflow.
    Just fast forward pass processing.
    """

    def __init__(
        self,
        model: Any,
        inbound: ActivationBuffer,
        outbound: OutboundBuffer,
        node_id: str = "",
    ):
        self.model = model
        self.inbound = inbound
        self.outbound = outbound
        self.node_id = node_id

        # Device
        self.device = getattr(model, 'device', torch.device('cpu'))

        # Stats
        self.requests_processed = 0
        self.total_latency_ms = 0.0

        self.running = False
        self._task: Optional[asyncio.Task] = None

    async def start(self):
        """Start inference loop."""
        self.running = True
        self._task = asyncio.create_task(self._run())

    async def stop(self):
        """Stop inference loop."""
        self.running = False
        if self._task:
            self._task.cancel()
            try:
                await self._task
            except asyncio.CancelledError:
                pass

    async def _run(self):
        """Main inference loop."""
        while self.running:
            try:
                packet = self.inbound.get(timeout=0.01)
                if packet is None:
                    await asyncio.sleep(0.001)
                    continue

                start = time.time()

                with torch.no_grad():
                    input_tensor = packet.tensor_data.to(self.device)
                    output = self.model(input_tensor)

                # Queue output
                outbound_packet = ActivationPacket(
                    priority=packet.priority,
                    session_id=packet.session_id,
                    micro_batch_id=packet.micro_batch_id,
                    tensor_data=output.cpu(),
                    source_node=self.node_id,
                    target_layer=packet.target_layer + 1,
                )

                await self.outbound.put(outbound_packet)

                self.requests_processed += 1
                self.total_latency_ms += (time.time() - start) * 1000

            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Inference error: {e}")
                await asyncio.sleep(0.1)  # Prevent tight error loop

    def get_stats(self) -> Dict[str, Any]:
        """Get inference statistics."""
        avg_latency = self.total_latency_ms / max(1, self.requests_processed)
        return {
            "requests_processed": self.requests_processed,
            "avg_latency_ms": avg_latency,
            "inbound_fill": self.inbound.fill_rate,
            "outbound_fill": self.outbound.fill_rate,
        }
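Finally, a quick smoke test of the inference path. It reuses the buffer sketch from the top of this review and pokes that sketch's internals (`outbound._q`), which the real buffers may not expose:

```python
import asyncio
import torch
import torch.nn as nn

async def smoke_test() -> None:
    inbound, outbound = ActivationBuffer(16), OutboundBuffer(16)
    engine = InferenceEngine(model=nn.Identity(), inbound=inbound,
                             outbound=outbound, node_id="node-a")
    await engine.start()

    x = torch.randn(1, 8)
    inbound.put(ActivationPacket(priority=0, tensor_data=x))
    result = await outbound._q.get()   # sketch internal, for demo only

    assert torch.equal(result.tensor_data, x)  # nn.Identity echoes the input
    assert result.target_layer == 1            # bumped from 0 by the engine
    await engine.stop()

asyncio.run(smoke_test())
```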