nexaroa-0.0.111-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. neuroshard/__init__.py +93 -0
  2. neuroshard/__main__.py +4 -0
  3. neuroshard/cli.py +466 -0
  4. neuroshard/core/__init__.py +92 -0
  5. neuroshard/core/consensus/verifier.py +252 -0
  6. neuroshard/core/crypto/__init__.py +20 -0
  7. neuroshard/core/crypto/ecdsa.py +392 -0
  8. neuroshard/core/economics/__init__.py +52 -0
  9. neuroshard/core/economics/constants.py +387 -0
  10. neuroshard/core/economics/ledger.py +2111 -0
  11. neuroshard/core/economics/market.py +975 -0
  12. neuroshard/core/economics/wallet.py +168 -0
  13. neuroshard/core/governance/__init__.py +74 -0
  14. neuroshard/core/governance/proposal.py +561 -0
  15. neuroshard/core/governance/registry.py +545 -0
  16. neuroshard/core/governance/versioning.py +332 -0
  17. neuroshard/core/governance/voting.py +453 -0
  18. neuroshard/core/model/__init__.py +30 -0
  19. neuroshard/core/model/dynamic.py +4186 -0
  20. neuroshard/core/model/llm.py +905 -0
  21. neuroshard/core/model/registry.py +164 -0
  22. neuroshard/core/model/scaler.py +387 -0
  23. neuroshard/core/model/tokenizer.py +568 -0
  24. neuroshard/core/network/__init__.py +56 -0
  25. neuroshard/core/network/connection_pool.py +72 -0
  26. neuroshard/core/network/dht.py +130 -0
  27. neuroshard/core/network/dht_plan.py +55 -0
  28. neuroshard/core/network/dht_proof_store.py +516 -0
  29. neuroshard/core/network/dht_protocol.py +261 -0
  30. neuroshard/core/network/dht_service.py +506 -0
  31. neuroshard/core/network/encrypted_channel.py +141 -0
  32. neuroshard/core/network/nat.py +201 -0
  33. neuroshard/core/network/nat_traversal.py +695 -0
  34. neuroshard/core/network/p2p.py +929 -0
  35. neuroshard/core/network/p2p_data.py +150 -0
  36. neuroshard/core/swarm/__init__.py +106 -0
  37. neuroshard/core/swarm/aggregation.py +729 -0
  38. neuroshard/core/swarm/buffers.py +643 -0
  39. neuroshard/core/swarm/checkpoint.py +709 -0
  40. neuroshard/core/swarm/compute.py +624 -0
  41. neuroshard/core/swarm/diloco.py +844 -0
  42. neuroshard/core/swarm/factory.py +1288 -0
  43. neuroshard/core/swarm/heartbeat.py +669 -0
  44. neuroshard/core/swarm/logger.py +487 -0
  45. neuroshard/core/swarm/router.py +658 -0
  46. neuroshard/core/swarm/service.py +640 -0
  47. neuroshard/core/training/__init__.py +29 -0
  48. neuroshard/core/training/checkpoint.py +600 -0
  49. neuroshard/core/training/distributed.py +1602 -0
  50. neuroshard/core/training/global_tracker.py +617 -0
  51. neuroshard/core/training/production.py +276 -0
  52. neuroshard/governance_cli.py +729 -0
  53. neuroshard/grpc_server.py +895 -0
  54. neuroshard/runner.py +3223 -0
  55. neuroshard/sdk/__init__.py +92 -0
  56. neuroshard/sdk/client.py +990 -0
  57. neuroshard/sdk/errors.py +101 -0
  58. neuroshard/sdk/types.py +282 -0
  59. neuroshard/tracker/__init__.py +0 -0
  60. neuroshard/tracker/server.py +864 -0
  61. neuroshard/ui/__init__.py +0 -0
  62. neuroshard/ui/app.py +102 -0
  63. neuroshard/ui/templates/index.html +1052 -0
  64. neuroshard/utils/__init__.py +0 -0
  65. neuroshard/utils/autostart.py +81 -0
  66. neuroshard/utils/hardware.py +121 -0
  67. neuroshard/utils/serialization.py +90 -0
  68. neuroshard/version.py +1 -0
  69. nexaroa-0.0.111.dist-info/METADATA +283 -0
  70. nexaroa-0.0.111.dist-info/RECORD +78 -0
  71. nexaroa-0.0.111.dist-info/WHEEL +5 -0
  72. nexaroa-0.0.111.dist-info/entry_points.txt +4 -0
  73. nexaroa-0.0.111.dist-info/licenses/LICENSE +190 -0
  74. nexaroa-0.0.111.dist-info/top_level.txt +2 -0
  75. protos/__init__.py +0 -0
  76. protos/neuroshard.proto +651 -0
  77. protos/neuroshard_pb2.py +160 -0
  78. protos/neuroshard_pb2_grpc.py +1298 -0
neuroshard/core/model/llm.py
@@ -0,0 +1,905 @@
"""
NeuroLLM - The People's Language Model

A transformer-based LLM designed from the ground up for decentralized training
and inference. This is not a wrapper around GPT or LLaMA - it's a completely
new model that grows smarter as more nodes contribute compute and data.

Key Design Principles:
1. Modular Architecture: Easy to distribute layers/shards across nodes
2. Efficient Gradients: Optimized for gradient compression and gossip
3. Privacy-First: Supports differential privacy and federated learning
4. Reward-Aligned: Training contributions earn NEURO tokens
5. Self-Improving: Model improves as network grows

Architecture:
- RMSNorm (more stable for distributed training)
- Rotary Position Embeddings (RoPE) - no absolute position limits
- Grouped Query Attention (GQA) - memory efficient
- SwiGLU activation - better performance
- Mixture of Experts (optional) - scalability

The model starts small and grows with the network:
- Phase 1: 125M params - Bootstrap
- Phase 2: 1B params - Early adoption
- Phase 3: 7B params - Growth
- Phase 4: 70B+ params - Maturity
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import logging
from typing import Optional, Tuple, Dict, List, Any
from dataclasses import dataclass, field
from enum import Enum

logger = logging.getLogger(__name__)


class NeuroLLMPhase(Enum):
    """Growth phases of NeuroLLM."""
    BOOTSTRAP = "bootstrap"  # 125M - Initial training
    EARLY = "early"          # 1B - Early adoption
    GROWTH = "growth"        # 7B - Network growth
    MATURE = "mature"        # 70B+ - Full scale


@dataclass
class NeuroLLMConfig:
    """
    Configuration for NeuroLLM.

    Designed to be easily scalable as the network grows.
    """
    # Model identity
    version: str = "0.1.0"
    phase: NeuroLLMPhase = NeuroLLMPhase.BOOTSTRAP

    # Architecture - Bootstrap phase (125M params)
    vocab_size: int = 32000       # Initial vocab - auto-expands to unlimited as tokenizer learns
    hidden_dim: int = 768         # Hidden dimension
    num_layers: int = 12          # Transformer layers
    num_heads: int = 12           # Attention heads
    num_kv_heads: int = 4         # KV heads (GQA)
    intermediate_dim: int = 2048  # FFN intermediate
    max_seq_len: int = 2048       # Maximum sequence length

    # Regularization
    dropout: float = 0.0          # Dropout (0 for inference)
    attention_dropout: float = 0.0

    # Training
    tie_word_embeddings: bool = True
    use_cache: bool = True

    # Distributed training
    gradient_checkpointing: bool = False
    gradient_accumulation_steps: int = 1

    # RoPE settings
    rope_theta: float = 10000.0
    rope_scaling: Optional[Dict] = None

    # MoE settings (for future scaling)
    num_experts: int = 0          # 0 = dense model
    num_experts_per_tok: int = 2

    @classmethod
    def bootstrap(cls) -> 'NeuroLLMConfig':
        """Bootstrap phase config (~125M params)."""
        return cls(
            phase=NeuroLLMPhase.BOOTSTRAP,
            hidden_dim=768,
            num_layers=12,
            num_heads=12,
            num_kv_heads=4,
            intermediate_dim=2048,
        )

    @classmethod
    def early(cls) -> 'NeuroLLMConfig':
        """Early adoption config (~1B params)."""
        return cls(
            phase=NeuroLLMPhase.EARLY,
            hidden_dim=2048,
            num_layers=24,
            num_heads=16,
            num_kv_heads=4,
            intermediate_dim=5632,
        )

    @classmethod
    def growth(cls) -> 'NeuroLLMConfig':
        """Growth phase config (~7B params)."""
        return cls(
            phase=NeuroLLMPhase.GROWTH,
            hidden_dim=4096,
            num_layers=32,
            num_heads=32,
            num_kv_heads=8,
            intermediate_dim=11008,
        )

    @classmethod
    def mature(cls) -> 'NeuroLLMConfig':
        """Mature phase config (~70B params)."""
        return cls(
            phase=NeuroLLMPhase.MATURE,
            hidden_dim=8192,
            num_layers=80,
            num_heads=64,
            num_kv_heads=8,
            intermediate_dim=28672,
        )

    @property
    def head_dim(self) -> int:
        return self.hidden_dim // self.num_heads

    @property
    def num_params(self) -> int:
        """Estimate total parameters."""
        # Embeddings
        embed_params = self.vocab_size * self.hidden_dim

        # Per layer
        # Attention: Q, K, V, O projections
        attn_params = (
            self.hidden_dim * self.hidden_dim +  # Q
            self.hidden_dim * (self.hidden_dim // self.num_heads * self.num_kv_heads) * 2 +  # K, V
            self.hidden_dim * self.hidden_dim  # O
        )
        # FFN: up, gate, down
        ffn_params = 3 * self.hidden_dim * self.intermediate_dim
        # Norms
        norm_params = 2 * self.hidden_dim

        layer_params = attn_params + ffn_params + norm_params

        # Total
        total = embed_params + self.num_layers * layer_params + self.hidden_dim  # final norm

        if not self.tie_word_embeddings:
            total += self.vocab_size * self.hidden_dim

        return total

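The phase presets above follow the growth roadmap in the module docstring. As a rough illustration (not part of the packaged file, and assuming the module is importable as neuroshard.core.model.llm), the num_params estimate can be printed for each preset:

from neuroshard.core.model.llm import NeuroLLMConfig

for name, preset in [("bootstrap", NeuroLLMConfig.bootstrap),
                     ("early", NeuroLLMConfig.early),
                     ("growth", NeuroLLMConfig.growth),
                     ("mature", NeuroLLMConfig.mature)]:
    cfg = preset()
    print(f"{name:9s} ~{cfg.num_params / 1e9:.2f}B params "
          f"({cfg.num_layers} layers, hidden_dim={cfg.hidden_dim}, "
          f"{cfg.num_heads}/{cfg.num_kv_heads} Q/KV heads)")
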
class RMSNorm(nn.Module):
    """
    Root Mean Square Layer Normalization.

    More stable than LayerNorm for distributed training.
    """
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Calculate RMS
        rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * rms * self.weight

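A quick sanity check (illustrative, not part of the package) of what RMSNorm does: each vector is rescaled to roughly unit root-mean-square before the learned weight (initialized to ones) is applied.

import torch
from neuroshard.core.model.llm import RMSNorm

norm = RMSNorm(dim=768)
x = torch.randn(2, 4, 768) * 5.0
y = norm(x)
print(y.pow(2).mean(-1).sqrt())  # ~1.0 everywhere while weight stays at its initial ones
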
class RotaryEmbedding(nn.Module):
    """
    Rotary Position Embeddings (RoPE).

    Allows the model to generalize to longer sequences than seen during training.
    """
    def __init__(self, dim: int, max_seq_len: int = 2048, theta: float = 10000.0):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.theta = theta

        # Precompute frequencies
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Build cache
        self._build_cache(max_seq_len)

    def _build_cache(self, seq_len: int):
        """Build cos/sin cache for given sequence length."""
        t = torch.arange(seq_len, device=self.inv_freq.device)
        freqs = torch.outer(t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos())
        self.register_buffer("sin_cached", emb.sin())

    def forward(self, x: torch.Tensor, seq_len: int) -> Tuple[torch.Tensor, torch.Tensor]:
        if seq_len > self.max_seq_len:
            self._build_cache(seq_len)
            self.max_seq_len = seq_len

        return (
            self.cos_cached[:seq_len],
            self.sin_cached[:seq_len]
        )


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Rotate half the hidden dims."""
    x1 = x[..., :x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor,
                         cos: torch.Tensor, sin: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary position embeddings to Q and K.

    Args:
        q: Query tensor [batch, heads, seq_len, head_dim]
        k: Key tensor [batch, heads, seq_len, head_dim]
        cos: Cosine embeddings [seq_len, head_dim]
        sin: Sine embeddings [seq_len, head_dim]

    Returns:
        Tuple of rotated Q and K tensors
    """
    # Ensure cos/sin are properly shaped for broadcasting with 4D tensors
    # cos/sin: [seq_len, head_dim] -> [1, 1, seq_len, head_dim]
    if cos.dim() == 2:
        cos = cos.unsqueeze(0).unsqueeze(0)
        sin = sin.unsqueeze(0).unsqueeze(0)

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

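Since RoPE only rotates pairs of dimensions by a position-dependent angle, it leaves per-position vector norms unchanged; a short check (illustrative, not part of the package):

import torch
from neuroshard.core.model.llm import RotaryEmbedding, apply_rotary_pos_emb

head_dim, seq_len = 64, 16
rope = RotaryEmbedding(head_dim, max_seq_len=seq_len)
q = torch.randn(1, 12, seq_len, head_dim)  # [batch, heads, seq_len, head_dim]
k = torch.randn(1, 4, seq_len, head_dim)
cos, sin = rope(q, seq_len)
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
print(torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-4))  # True
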
class NeuroAttention(nn.Module):
    """
    Grouped Query Attention (GQA) for NeuroLLM.

    Uses fewer KV heads than query heads for memory efficiency.
    """
    def __init__(self, config: NeuroLLMConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx

        self.hidden_dim = config.hidden_dim
        self.num_heads = config.num_heads
        self.num_kv_heads = config.num_kv_heads
        self.head_dim = config.head_dim
        self.num_key_value_groups = self.num_heads // self.num_kv_heads

        # Projections
        self.q_proj = nn.Linear(self.hidden_dim, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_dim, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_dim, self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_dim, bias=False)

        # RoPE
        self.rotary_emb = RotaryEmbedding(
            self.head_dim,
            max_seq_len=config.max_seq_len,
            theta=config.rope_theta
        )

        self.attention_dropout = nn.Dropout(config.attention_dropout)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        batch_size, seq_len, _ = hidden_states.shape

        # Project Q, K, V
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Reshape for multi-head attention
        query_states = query_states.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)

        # Apply RoPE
        cos, sin = self.rotary_emb(hidden_states, seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # Handle KV cache
        if past_key_value is not None:
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None

        # Repeat KV heads for GQA
        key_states = key_states.repeat_interleave(self.num_key_value_groups, dim=1)
        value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=1)

        # Attention
        attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_weights = self.attention_dropout(attn_weights)

        attn_output = torch.matmul(attn_weights, value_states)

        # Reshape and project output
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        attn_output = self.o_proj(attn_output)

        return attn_output, past_key_value

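The memory saving from GQA comes from the KV cache holding num_kv_heads rather than num_heads tensors per layer. A back-of-envelope sketch (illustrative, not part of the package) for the bootstrap preset at fp16:

from neuroshard.core.model.llm import NeuroLLMConfig

cfg = NeuroLLMConfig.bootstrap()
bytes_per_value = 2  # fp16
per_head = 2 * cfg.max_seq_len * cfg.head_dim * bytes_per_value  # K and V, one sequence
gqa_cache = per_head * cfg.num_kv_heads * cfg.num_layers
mha_cache = per_head * cfg.num_heads * cfg.num_layers
print(f"KV cache per sequence: {gqa_cache / 2**20:.0f} MiB (GQA) "
      f"vs {mha_cache / 2**20:.0f} MiB (full MHA)")
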
class NeuroMLP(nn.Module):
    """
    SwiGLU-based MLP for NeuroLLM.

    SwiGLU provides better performance than standard GELU.
    """
    def __init__(self, config: NeuroLLMConfig):
        super().__init__()
        self.hidden_dim = config.hidden_dim
        self.intermediate_dim = config.intermediate_dim

        # SwiGLU: silu(gate) * up
        self.gate_proj = nn.Linear(self.hidden_dim, self.intermediate_dim, bias=False)
        self.up_proj = nn.Linear(self.hidden_dim, self.intermediate_dim, bias=False)
        self.down_proj = nn.Linear(self.intermediate_dim, self.hidden_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))


class NeuroDecoderLayer(nn.Module):
    """
    Single transformer decoder layer for NeuroLLM.
    """
    def __init__(self, config: NeuroLLMConfig, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx

        self.self_attn = NeuroAttention(config, layer_idx)
        self.mlp = NeuroMLP(config)

        self.input_layernorm = RMSNorm(config.hidden_dim)
        self.post_attention_layernorm = RMSNorm(config.hidden_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        # Self attention with residual
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states, present_key_value = self.self_attn(
            hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            use_cache=use_cache,
        )
        hidden_states = residual + hidden_states

        # MLP with residual
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states, present_key_value

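A decoder layer maps hidden states to hidden states of the same shape, which is what makes layer-wise sharding straightforward; a minimal shape check (illustrative, not part of the package):

import torch
from neuroshard.core.model.llm import NeuroLLMConfig, NeuroDecoderLayer

cfg = NeuroLLMConfig.bootstrap()
layer = NeuroDecoderLayer(cfg, layer_idx=0)
h = torch.randn(1, 8, cfg.hidden_dim)
out, kv = layer(h, use_cache=True)
print(out.shape, kv[0].shape)  # torch.Size([1, 8, 768]) torch.Size([1, 4, 8, 64])
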
class NeuroLLMModel(nn.Module):
    """
    The core NeuroLLM transformer model.

    This is the base model without the LM head.
    """
    def __init__(self, config: NeuroLLMConfig):
        super().__init__()
        self.config = config

        # Token embeddings
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_dim)

        # Transformer layers
        self.layers = nn.ModuleList([
            NeuroDecoderLayer(config, layer_idx)
            for layer_idx in range(config.num_layers)
        ])

        # Final norm
        self.norm = RMSNorm(config.hidden_dim)

        # Gradient checkpointing
        self.gradient_checkpointing = config.gradient_checkpointing

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize weights with small values for stable training."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]:
        batch_size, seq_len = input_ids.shape

        # Get embeddings
        hidden_states = self.embed_tokens(input_ids)

        # Create causal mask
        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_len), device=input_ids.device)

        # Create causal attention mask
        causal_mask = torch.triu(
            torch.full((seq_len, seq_len), float('-inf'), device=input_ids.device),
            diagonal=1
        )

        # Process through layers
        present_key_values = [] if use_cache else None

        for idx, layer in enumerate(self.layers):
            past_key_value = past_key_values[idx] if past_key_values else None

            if self.gradient_checkpointing and self.training:
                hidden_states, present_key_value = torch.utils.checkpoint.checkpoint(
                    layer,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_value,
                    use_cache,
                )
            else:
                hidden_states, present_key_value = layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    use_cache=use_cache,
                )

            if use_cache:
                present_key_values.append(present_key_value)

        # Final norm
        hidden_states = self.norm(hidden_states)

        return hidden_states, present_key_values

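The base model stops at the final hidden states; the LM head lives in NeuroLLMForCausalLM below. A minimal forward pass (illustrative, not part of the package):

import torch
from neuroshard.core.model.llm import NeuroLLMConfig, NeuroLLMModel

cfg = NeuroLLMConfig.bootstrap()
base = NeuroLLMModel(cfg)
input_ids = torch.randint(0, cfg.vocab_size, (2, 16))
hidden, _ = base(input_ids)
print(hidden.shape)  # torch.Size([2, 16, 768])
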
class NeuroLLMForCausalLM(nn.Module):
    """
    NeuroLLM with language modeling head.

    This is the full model for both training and inference.
    """
    def __init__(self, config: NeuroLLMConfig):
        super().__init__()
        self.config = config

        self.model = NeuroLLMModel(config)

        # LM head (optionally tied to embeddings)
        self.lm_head = nn.Linear(config.hidden_dim, config.vocab_size, bias=False)

        if config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight

        # Loss function
        self.loss_fn = nn.CrossEntropyLoss(ignore_index=-100)

        logger.info(f"NeuroLLM initialized: {config.num_params / 1e6:.1f}M parameters, "
                    f"phase={config.phase.value}")

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        use_cache: bool = False,
    ) -> Dict[str, torch.Tensor]:
        # Get hidden states
        hidden_states, present_key_values = self.model(
            input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
        )

        # Get logits
        logits = self.lm_head(hidden_states)

        # Calculate loss if labels provided
        loss = None
        if labels is not None:
            # Shift for next token prediction
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            # Flatten
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)

            loss = self.loss_fn(shift_logits, shift_labels)

        return {
            "loss": loss,
            "logits": logits,
            "past_key_values": present_key_values,
        }

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 100,
        temperature: float = 1.0,
        top_k: int = 50,
        top_p: float = 0.9,
        do_sample: bool = True,
        valid_vocab_size: int = 266,  # Mask tokens beyond this (default: byte tokens only)
    ) -> torch.Tensor:
        """Generate text autoregressively."""
        self.eval()

        past_key_values = None
        generated = input_ids

        for _ in range(max_new_tokens):
            # Forward pass
            outputs = self.forward(
                input_ids=generated if past_key_values is None else generated[:, -1:],
                past_key_values=past_key_values,
                use_cache=True,
            )

            logits = outputs["logits"][:, -1, :]
            past_key_values = outputs["past_key_values"]

            # Constrain to valid vocabulary (standard BPE tokenizer behavior)
            # Tokens beyond valid_vocab_size don't exist in the tokenizer yet
            if valid_vocab_size < logits.size(-1):
                logits[:, valid_vocab_size:] = float('-inf')

            # Apply temperature
            if temperature != 1.0:
                logits = logits / temperature

            # Apply top-k
            if top_k > 0:
                indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
                logits[indices_to_remove] = float('-inf')

            # Apply top-p (nucleus sampling)
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0

                indices_to_remove = sorted_indices_to_remove.scatter(
                    1, sorted_indices, sorted_indices_to_remove
                )
                logits[indices_to_remove] = float('-inf')

            # Sample or greedy
            if do_sample:
                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                next_token = torch.argmax(logits, dim=-1, keepdim=True)

            generated = torch.cat([generated, next_token], dim=-1)

            # Check for EOS (assuming token 2 is EOS)
            if next_token.item() == 2:
                break

        return generated
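Putting the two entry points together, a sketch (illustrative, not part of the package) of a training-style forward pass, which shifts the labels internally, followed by sampling a short continuation:

import torch
from neuroshard.core.model.llm import NeuroLLMConfig, NeuroLLMForCausalLM

model = NeuroLLMForCausalLM(NeuroLLMConfig.bootstrap())

batch = torch.randint(0, 266, (1, 32))      # byte-level ids, cf. valid_vocab_size above
out = model(input_ids=batch, labels=batch)  # next-token loss over the batch
print(f"loss: {out['loss'].item():.3f}")

sample = model.generate(batch[:, :8], max_new_tokens=16, temperature=0.8)
print(sample.shape)  # up to [1, 24]; shorter if token id 2 (EOS) is sampled
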

    def get_num_params(self) -> int:
        """Get actual number of parameters."""
        return sum(p.numel() for p in self.parameters())

    def save_checkpoint(self, path: str, extra_state: Dict = None):
        """Save model checkpoint."""
        state = {
            "config": self.config.__dict__,
            "model_state_dict": self.state_dict(),
            "version": self.config.version,
        }
        if extra_state:
            state.update(extra_state)
        torch.save(state, path)
        logger.info(f"Saved checkpoint to {path}")

    @classmethod
    def load_checkpoint(cls, path: str, device: str = "cpu") -> 'NeuroLLMForCausalLM':
        """Load model from checkpoint."""
        # Use weights_only=False for full checkpoint loading (includes config)
        # This is safe because we only load our own checkpoints
        state = torch.load(path, map_location=device, weights_only=False)

        # Reconstruct config
        config_dict = state["config"]
        config_dict["phase"] = NeuroLLMPhase(config_dict["phase"])
        config = NeuroLLMConfig(**config_dict)

        # Create model and load weights
        model = cls(config)
        model.load_state_dict(state["model_state_dict"])

        logger.info(f"Loaded checkpoint from {path}")
        return model
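Because the config dict is stored alongside the weights, load_checkpoint can rebuild the model without knowing its phase in advance; a round-trip sketch (illustrative, not part of the package):

import os, tempfile
from neuroshard.core.model.llm import NeuroLLMConfig, NeuroLLMForCausalLM

model = NeuroLLMForCausalLM(NeuroLLMConfig.bootstrap())
path = os.path.join(tempfile.mkdtemp(), "neurollm.pt")
model.save_checkpoint(path, extra_state={"global_step": 0})

restored = NeuroLLMForCausalLM.load_checkpoint(path)
print(restored.get_num_params() == model.get_num_params())  # True
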

    # ==================== SHARDED MODEL SUPPORT ====================

    def get_layer_names(self, layer_idx: int) -> List[str]:
        """Get parameter names for a specific layer."""
        prefix = f"model.layers.{layer_idx}."
        return [name for name in self.state_dict().keys() if name.startswith(prefix)]

    def extract_shard(
        self,
        start_layer: int,
        end_layer: int,
        include_embedding: bool = False,
        include_lm_head: bool = False
    ) -> Dict[str, torch.Tensor]:
        """
        Extract weights for a shard (subset of layers).

        Args:
            start_layer: First layer to include
            end_layer: Last layer to include (exclusive)
            include_embedding: Include embedding layer
            include_lm_head: Include LM head

        Returns:
            Dict of parameter names to tensors
        """
        state_dict = self.state_dict()
        shard_weights = {}

        for name, param in state_dict.items():
            include = False

            # Embedding
            if include_embedding and ("embed" in name.lower() or "wte" in name.lower()):
                include = True

            # LM head
            if include_lm_head and ("lm_head" in name.lower()):
                include = True

            # Final norm (goes with LM head)
            if include_lm_head and ("final" in name.lower() or "ln_f" in name.lower() or "model.norm" in name.lower()):
                include = True

            # Transformer layers
            import re
            match = re.search(r'layers\.(\d+)\.', name)
            if match:
                layer_num = int(match.group(1))
                if start_layer <= layer_num < end_layer:
                    include = True

            if include:
                shard_weights[name] = param.clone()

        logger.info(f"Extracted shard: layers {start_layer}-{end_layer}, "
                    f"embed={include_embedding}, head={include_lm_head}, "
                    f"params={len(shard_weights)}")

        return shard_weights

    def load_shard(
        self,
        shard_weights: Dict[str, torch.Tensor],
        strict: bool = False
    ) -> Tuple[List[str], List[str]]:
        """
        Load weights from a shard into the model.

        Args:
            shard_weights: Dict of parameter names to tensors
            strict: If True, raise error on missing/unexpected keys

        Returns:
            (missing_keys, unexpected_keys)
        """
        current_state = self.state_dict()

        missing_keys = []
        unexpected_keys = []
        loaded_keys = []

        for name, param in shard_weights.items():
            if name in current_state:
                if current_state[name].shape == param.shape:
                    current_state[name].copy_(param)
                    loaded_keys.append(name)
                else:
                    logger.warning(f"Shape mismatch for {name}: "
                                   f"model={current_state[name].shape}, shard={param.shape}")
            else:
                unexpected_keys.append(name)

        # Check for missing keys in the shard range
        for name in current_state.keys():
            if name not in shard_weights and name not in loaded_keys:
                # Only report as missing if it should have been in the shard
                pass  # Shards are partial by design

        if strict and unexpected_keys:
            raise RuntimeError(f"Unexpected keys in shard: {unexpected_keys}")

        logger.info(f"Loaded shard: {len(loaded_keys)} parameters")

        return missing_keys, unexpected_keys
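A sketch (illustrative, not part of the package) of moving the first four decoder layers plus the embedding from one replica to another; shards are partial by design, so strict stays False:

from neuroshard.core.model.llm import NeuroLLMConfig, NeuroLLMForCausalLM

donor = NeuroLLMForCausalLM(NeuroLLMConfig.bootstrap())
receiver = NeuroLLMForCausalLM(NeuroLLMConfig.bootstrap())

shard = donor.extract_shard(start_layer=0, end_layer=4, include_embedding=True)
missing, unexpected = receiver.load_shard(shard)
print(len(shard), unexpected)  # tensor count for 4 layers + embedding, [] expected
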

    def save_shard(
        self,
        path: str,
        start_layer: int,
        end_layer: int,
        include_embedding: bool = False,
        include_lm_head: bool = False,
        extra_state: Dict = None
    ):
        """
        Save a shard to disk.

        Args:
            path: Save path
            start_layer: First layer
            end_layer: Last layer (exclusive)
            include_embedding: Include embedding
            include_lm_head: Include LM head
            extra_state: Additional state to save
        """
        shard_weights = self.extract_shard(
            start_layer, end_layer, include_embedding, include_lm_head
        )

        state = {
            "shard_config": {
                "start_layer": start_layer,
                "end_layer": end_layer,
                "include_embedding": include_embedding,
                "include_lm_head": include_lm_head,
            },
            "model_config": self.config.__dict__,
            "shard_weights": shard_weights,
            "num_params": sum(p.numel() for p in shard_weights.values()),
        }

        if extra_state:
            state.update(extra_state)

        torch.save(state, path)
        logger.info(f"Saved shard to {path}: layers {start_layer}-{end_layer}, "
                    f"{state['num_params']} params")

    @classmethod
    def load_shard_file(cls, path: str, device: str = "cpu") -> Tuple[Dict[str, torch.Tensor], Dict]:
        """
        Load a shard file.

        Args:
            path: Shard file path
            device: Device to load to

        Returns:
            (shard_weights, shard_config)
        """
        state = torch.load(path, map_location=device, weights_only=False)

        return state["shard_weights"], state["shard_config"]

    def forward_shard(
        self,
        hidden_states: torch.Tensor,
        start_layer: int,
        end_layer: int,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]:
        """
        Forward pass through a subset of layers (for pipeline parallelism).

        This is used when the model is sharded across nodes:
        - Node A processes layers 0-3
        - Node B processes layers 4-7
        - etc.

        Args:
            hidden_states: Input hidden states [batch, seq, hidden]
            start_layer: First layer to process
            end_layer: Last layer to process (exclusive)
            attention_mask: Attention mask
            position_ids: Position IDs
            past_key_values: KV cache
            use_cache: Whether to return KV cache

        Returns:
            (output_hidden_states, new_past_key_values)
        """
        new_past_key_values = [] if use_cache else None

        for idx in range(start_layer, end_layer):
            layer = self.model.layers[idx]

            past_kv = past_key_values[idx] if past_key_values else None

            hidden_states, new_kv = layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_kv,
                use_cache=use_cache,
            )

            if use_cache:
                new_past_key_values.append(new_kv)

        return hidden_states, new_past_key_values

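A pipeline-style sketch (illustrative, not part of the package) that splits the bootstrap model across two in-process "nodes": node A runs layers 0-5, node B runs layers 6-11 and applies the final norm and LM head. In a real deployment each node would hold only its shard, obtained via extract_shard/load_shard.

import torch
from neuroshard.core.model.llm import NeuroLLMConfig, NeuroLLMForCausalLM

cfg = NeuroLLMConfig.bootstrap()
node_a = NeuroLLMForCausalLM(cfg)
node_b = NeuroLLMForCausalLM(cfg)

input_ids = torch.randint(0, cfg.vocab_size, (1, 16))
seq_len = input_ids.shape[1]
causal = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)

hidden = node_a.model.embed_tokens(input_ids)
hidden, _ = node_a.forward_shard(hidden, start_layer=0, end_layer=6, attention_mask=causal)
hidden, _ = node_b.forward_shard(hidden, start_layer=6, end_layer=cfg.num_layers, attention_mask=causal)
logits = node_b.lm_head(node_b.model.norm(hidden))
print(logits.shape)  # torch.Size([1, 16, 32000])
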
# Convenience function
def create_neuro_llm(phase: str = "bootstrap") -> NeuroLLMForCausalLM:
    """
    Create a NeuroLLM model for the specified phase.

    Args:
        phase: One of "bootstrap", "early", "growth", "mature"

    Returns:
        NeuroLLMForCausalLM instance
    """
    config_map = {
        "bootstrap": NeuroLLMConfig.bootstrap,
        "early": NeuroLLMConfig.early,
        "growth": NeuroLLMConfig.growth,
        "mature": NeuroLLMConfig.mature,
    }

    if phase not in config_map:
        raise ValueError(f"Unknown phase: {phase}. Choose from {list(config_map.keys())}")

    config = config_map[phase]()
    return NeuroLLMForCausalLM(config)
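End-to-end, the convenience factory is the shortest path to a working model; a final sketch (illustrative, not part of the package):

import torch
from neuroshard.core.model.llm import create_neuro_llm

model = create_neuro_llm("bootstrap")
print(f"{model.get_num_params() / 1e6:.1f}M parameters")

prompt = torch.randint(0, 266, (1, 4))
completion = model.generate(prompt, max_new_tokens=8, do_sample=False)
print(completion.tolist())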