nexaroa 0.0.111__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- neuroshard/__init__.py +93 -0
- neuroshard/__main__.py +4 -0
- neuroshard/cli.py +466 -0
- neuroshard/core/__init__.py +92 -0
- neuroshard/core/consensus/verifier.py +252 -0
- neuroshard/core/crypto/__init__.py +20 -0
- neuroshard/core/crypto/ecdsa.py +392 -0
- neuroshard/core/economics/__init__.py +52 -0
- neuroshard/core/economics/constants.py +387 -0
- neuroshard/core/economics/ledger.py +2111 -0
- neuroshard/core/economics/market.py +975 -0
- neuroshard/core/economics/wallet.py +168 -0
- neuroshard/core/governance/__init__.py +74 -0
- neuroshard/core/governance/proposal.py +561 -0
- neuroshard/core/governance/registry.py +545 -0
- neuroshard/core/governance/versioning.py +332 -0
- neuroshard/core/governance/voting.py +453 -0
- neuroshard/core/model/__init__.py +30 -0
- neuroshard/core/model/dynamic.py +4186 -0
- neuroshard/core/model/llm.py +905 -0
- neuroshard/core/model/registry.py +164 -0
- neuroshard/core/model/scaler.py +387 -0
- neuroshard/core/model/tokenizer.py +568 -0
- neuroshard/core/network/__init__.py +56 -0
- neuroshard/core/network/connection_pool.py +72 -0
- neuroshard/core/network/dht.py +130 -0
- neuroshard/core/network/dht_plan.py +55 -0
- neuroshard/core/network/dht_proof_store.py +516 -0
- neuroshard/core/network/dht_protocol.py +261 -0
- neuroshard/core/network/dht_service.py +506 -0
- neuroshard/core/network/encrypted_channel.py +141 -0
- neuroshard/core/network/nat.py +201 -0
- neuroshard/core/network/nat_traversal.py +695 -0
- neuroshard/core/network/p2p.py +929 -0
- neuroshard/core/network/p2p_data.py +150 -0
- neuroshard/core/swarm/__init__.py +106 -0
- neuroshard/core/swarm/aggregation.py +729 -0
- neuroshard/core/swarm/buffers.py +643 -0
- neuroshard/core/swarm/checkpoint.py +709 -0
- neuroshard/core/swarm/compute.py +624 -0
- neuroshard/core/swarm/diloco.py +844 -0
- neuroshard/core/swarm/factory.py +1288 -0
- neuroshard/core/swarm/heartbeat.py +669 -0
- neuroshard/core/swarm/logger.py +487 -0
- neuroshard/core/swarm/router.py +658 -0
- neuroshard/core/swarm/service.py +640 -0
- neuroshard/core/training/__init__.py +29 -0
- neuroshard/core/training/checkpoint.py +600 -0
- neuroshard/core/training/distributed.py +1602 -0
- neuroshard/core/training/global_tracker.py +617 -0
- neuroshard/core/training/production.py +276 -0
- neuroshard/governance_cli.py +729 -0
- neuroshard/grpc_server.py +895 -0
- neuroshard/runner.py +3223 -0
- neuroshard/sdk/__init__.py +92 -0
- neuroshard/sdk/client.py +990 -0
- neuroshard/sdk/errors.py +101 -0
- neuroshard/sdk/types.py +282 -0
- neuroshard/tracker/__init__.py +0 -0
- neuroshard/tracker/server.py +864 -0
- neuroshard/ui/__init__.py +0 -0
- neuroshard/ui/app.py +102 -0
- neuroshard/ui/templates/index.html +1052 -0
- neuroshard/utils/__init__.py +0 -0
- neuroshard/utils/autostart.py +81 -0
- neuroshard/utils/hardware.py +121 -0
- neuroshard/utils/serialization.py +90 -0
- neuroshard/version.py +1 -0
- nexaroa-0.0.111.dist-info/METADATA +283 -0
- nexaroa-0.0.111.dist-info/RECORD +78 -0
- nexaroa-0.0.111.dist-info/WHEEL +5 -0
- nexaroa-0.0.111.dist-info/entry_points.txt +4 -0
- nexaroa-0.0.111.dist-info/licenses/LICENSE +190 -0
- nexaroa-0.0.111.dist-info/top_level.txt +2 -0
- protos/__init__.py +0 -0
- protos/neuroshard.proto +651 -0
- protos/neuroshard_pb2.py +160 -0
- protos/neuroshard_pb2_grpc.py +1298 -0
neuroshard/core/model/llm.py
@@ -0,0 +1,905 @@
"""
NeuroLLM - The People's Language Model

A transformer-based LLM designed from the ground up for decentralized training
and inference. This is not a wrapper around GPT or LLaMA - it's a completely
new model that grows smarter as more nodes contribute compute and data.

Key Design Principles:
1. Modular Architecture: Easy to distribute layers/shards across nodes
2. Efficient Gradients: Optimized for gradient compression and gossip
3. Privacy-First: Supports differential privacy and federated learning
4. Reward-Aligned: Training contributions earn NEURO tokens
5. Self-Improving: Model improves as the network grows

Architecture:
- RMSNorm (more stable for distributed training)
- Rotary Position Embeddings (RoPE) - no absolute position limits
- Grouped Query Attention (GQA) - memory efficient
- SwiGLU activation - better performance
- Mixture of Experts (optional) - scalability

The model starts small and grows with the network:
- Phase 1: 125M params - Bootstrap
- Phase 2: 1B params - Early adoption
- Phase 3: 7B params - Growth
- Phase 4: 70B+ params - Maturity
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import logging
from typing import Optional, Tuple, Dict, List, Any
from dataclasses import dataclass, field
from enum import Enum

logger = logging.getLogger(__name__)


class NeuroLLMPhase(Enum):
    """Growth phases of NeuroLLM."""
    BOOTSTRAP = "bootstrap"  # 125M - Initial training
    EARLY = "early"          # 1B - Early adoption
    GROWTH = "growth"        # 7B - Network growth
    MATURE = "mature"        # 70B+ - Full scale


@dataclass
class NeuroLLMConfig:
    """
    Configuration for NeuroLLM.

    Designed to be easily scalable as the network grows.
    """
    # Model identity
    version: str = "0.1.0"
    phase: NeuroLLMPhase = NeuroLLMPhase.BOOTSTRAP

    # Architecture - Bootstrap phase (125M params)
    vocab_size: int = 32000        # Initial vocab - expands as the tokenizer learns new tokens
    hidden_dim: int = 768          # Hidden dimension
    num_layers: int = 12           # Transformer layers
    num_heads: int = 12            # Attention heads
    num_kv_heads: int = 4          # KV heads (GQA)
    intermediate_dim: int = 2048   # FFN intermediate
    max_seq_len: int = 2048        # Maximum sequence length

    # Regularization
    dropout: float = 0.0           # Dropout (0 for inference)
    attention_dropout: float = 0.0

    # Training
    tie_word_embeddings: bool = True
    use_cache: bool = True

    # Distributed training
    gradient_checkpointing: bool = False
    gradient_accumulation_steps: int = 1

    # RoPE settings
    rope_theta: float = 10000.0
    rope_scaling: Optional[Dict] = None

    # MoE settings (for future scaling)
    num_experts: int = 0           # 0 = dense model
    num_experts_per_tok: int = 2

    @classmethod
    def bootstrap(cls) -> 'NeuroLLMConfig':
        """Bootstrap phase config (~125M params)."""
        return cls(
            phase=NeuroLLMPhase.BOOTSTRAP,
            hidden_dim=768,
            num_layers=12,
            num_heads=12,
            num_kv_heads=4,
            intermediate_dim=2048,
        )

    @classmethod
    def early(cls) -> 'NeuroLLMConfig':
        """Early adoption config (~1B params)."""
        return cls(
            phase=NeuroLLMPhase.EARLY,
            hidden_dim=2048,
            num_layers=24,
            num_heads=16,
            num_kv_heads=4,
            intermediate_dim=5632,
        )

    @classmethod
    def growth(cls) -> 'NeuroLLMConfig':
        """Growth phase config (~7B params)."""
        return cls(
            phase=NeuroLLMPhase.GROWTH,
            hidden_dim=4096,
            num_layers=32,
            num_heads=32,
            num_kv_heads=8,
            intermediate_dim=11008,
        )

    @classmethod
    def mature(cls) -> 'NeuroLLMConfig':
        """Mature phase config (~70B params)."""
        return cls(
            phase=NeuroLLMPhase.MATURE,
            hidden_dim=8192,
            num_layers=80,
            num_heads=64,
            num_kv_heads=8,
            intermediate_dim=28672,
        )

    @property
    def head_dim(self) -> int:
        return self.hidden_dim // self.num_heads

    @property
    def num_params(self) -> int:
        """Estimate total parameters."""
        # Embeddings
        embed_params = self.vocab_size * self.hidden_dim

        # Per layer
        # Attention: Q, K, V, O projections
        attn_params = (
            self.hidden_dim * self.hidden_dim +  # Q
            self.hidden_dim * (self.hidden_dim // self.num_heads * self.num_kv_heads) * 2 +  # K, V
            self.hidden_dim * self.hidden_dim  # O
        )
        # FFN: up, gate, down
        ffn_params = 3 * self.hidden_dim * self.intermediate_dim
        # Norms
        norm_params = 2 * self.hidden_dim

        layer_params = attn_params + ffn_params + norm_params

        # Total
        total = embed_params + self.num_layers * layer_params + self.hidden_dim  # final norm

        if not self.tie_word_embeddings:
            total += self.vocab_size * self.hidden_dim

        return total
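

# ---------------------------------------------------------------------------
# Editor's sketch (not part of the original wheel): a quick sanity check of
# the num_params estimate above. With the default tied embeddings the
# bootstrap estimate lands near 100M; the ~125M figure in the module
# docstring corresponds to counting the LM head separately (untied).
def _check_param_estimate():  # hypothetical helper, for illustration only
    cfg = NeuroLLMConfig.bootstrap()
    # 32000*768 embeddings + 12 * (~1.57M attn + ~4.72M FFN + norms) + final norm
    print(f"bootstrap estimate: {cfg.num_params / 1e6:.1f}M")   # ~100.1M
    cfg.tie_word_embeddings = False
    print(f"with untied LM head: {cfg.num_params / 1e6:.1f}M")  # ~124.7M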


class RMSNorm(nn.Module):
    """
    Root Mean Square Layer Normalization.

    More stable than LayerNorm for distributed training.
    """
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # rsqrt of the mean square, i.e. the reciprocal RMS
        rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * rms * self.weight


class RotaryEmbedding(nn.Module):
    """
    Rotary Position Embeddings (RoPE).

    Allows the model to generalize to longer sequences than seen during training.
    """
    def __init__(self, dim: int, max_seq_len: int = 2048, theta: float = 10000.0):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.theta = theta

        # Precompute frequencies
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Build cache
        self._build_cache(max_seq_len)

    def _build_cache(self, seq_len: int):
        """Build cos/sin cache for given sequence length."""
        t = torch.arange(seq_len, device=self.inv_freq.device)
        freqs = torch.outer(t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos())
        self.register_buffer("sin_cached", emb.sin())

    def forward(self, x: torch.Tensor, seq_len: int) -> Tuple[torch.Tensor, torch.Tensor]:
        if seq_len > self.max_seq_len:
            self._build_cache(seq_len)
            self.max_seq_len = seq_len

        return (
            self.cos_cached[:seq_len],
            self.sin_cached[:seq_len]
        )


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Rotate half the hidden dims."""
    x1 = x[..., :x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor,
                         cos: torch.Tensor, sin: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary position embeddings to Q and K.

    Args:
        q: Query tensor [batch, heads, seq_len, head_dim]
        k: Key tensor [batch, heads, seq_len, head_dim]
        cos: Cosine embeddings [seq_len, head_dim]
        sin: Sine embeddings [seq_len, head_dim]

    Returns:
        Tuple of rotated Q and K tensors
    """
    # Ensure cos/sin are properly shaped for broadcasting with 4D tensors
    # cos/sin: [seq_len, head_dim] -> [1, 1, seq_len, head_dim]
    if cos.dim() == 2:
        cos = cos.unsqueeze(0).unsqueeze(0)
        sin = sin.unsqueeze(0).unsqueeze(0)

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
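

# ---------------------------------------------------------------------------
# Editor's sketch (not in the original wheel): RoPE applies a 2D rotation to
# each (x_i, x_{i+dim/2}) pair, so it encodes position in the angles between
# vectors while preserving per-position vector norms. A minimal check,
# assuming only the helpers defined above:
def _demo_rope():  # hypothetical helper, for illustration only
    rope = RotaryEmbedding(dim=64, max_seq_len=16)
    q = torch.randn(1, 1, 16, 64)
    k = torch.randn(1, 1, 16, 64)
    cos, sin = rope(q, seq_len=16)
    q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
    assert q_rot.shape == q.shape and k_rot.shape == k.shape
    # Rotations are norm-preserving (up to float error)
    assert torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-4)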


class NeuroAttention(nn.Module):
    """
    Grouped Query Attention (GQA) for NeuroLLM.

    Uses fewer KV heads than query heads for memory efficiency.
    """
    def __init__(self, config: NeuroLLMConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx

        self.hidden_dim = config.hidden_dim
        self.num_heads = config.num_heads
        self.num_kv_heads = config.num_kv_heads
        self.head_dim = config.head_dim
        self.num_key_value_groups = self.num_heads // self.num_kv_heads

        # Projections
        self.q_proj = nn.Linear(self.hidden_dim, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_dim, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_dim, self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_dim, bias=False)

        # RoPE
        self.rotary_emb = RotaryEmbedding(
            self.head_dim,
            max_seq_len=config.max_seq_len,
            theta=config.rope_theta
        )

        self.attention_dropout = nn.Dropout(config.attention_dropout)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        batch_size, seq_len, _ = hidden_states.shape

        # Project Q, K, V
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Reshape for multi-head attention
        query_states = query_states.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)

        # Apply RoPE
        cos, sin = self.rotary_emb(hidden_states, seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # Handle KV cache
        if past_key_value is not None:
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None

        # Repeat KV heads for GQA
        key_states = key_states.repeat_interleave(self.num_key_value_groups, dim=1)
        value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=1)

        # Attention
        attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_weights = self.attention_dropout(attn_weights)

        attn_output = torch.matmul(attn_weights, value_states)

        # Reshape and project output
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        attn_output = self.o_proj(attn_output)

        return attn_output, past_key_value
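

# ---------------------------------------------------------------------------
# Editor's sketch (not in the original wheel): with GQA the cache stores
# num_kv_heads (4) rather than num_heads (12) per layer, so at bootstrap
# scale the KV cache is 3x smaller than full multi-head attention. A minimal
# shape check, assuming only the classes defined above:
def _demo_gqa_cache():  # hypothetical helper, for illustration only
    cfg = NeuroLLMConfig.bootstrap()
    attn = NeuroAttention(cfg, layer_idx=0)
    x = torch.randn(1, 8, cfg.hidden_dim)
    out, kv = attn(x, use_cache=True)
    assert out.shape == (1, 8, cfg.hidden_dim)
    assert kv[0].shape == (1, cfg.num_kv_heads, 8, cfg.head_dim)  # K cache
    assert kv[1].shape == (1, cfg.num_kv_heads, 8, cfg.head_dim)  # V cache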


class NeuroMLP(nn.Module):
    """
    SwiGLU-based MLP for NeuroLLM.

    SwiGLU provides better performance than standard GELU.
    """
    def __init__(self, config: NeuroLLMConfig):
        super().__init__()
        self.hidden_dim = config.hidden_dim
        self.intermediate_dim = config.intermediate_dim

        # SwiGLU: silu(gate(x)) * up(x)
        self.gate_proj = nn.Linear(self.hidden_dim, self.intermediate_dim, bias=False)
        self.up_proj = nn.Linear(self.hidden_dim, self.intermediate_dim, bias=False)
        self.down_proj = nn.Linear(self.intermediate_dim, self.hidden_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))


class NeuroDecoderLayer(nn.Module):
    """
    Single transformer decoder layer for NeuroLLM.
    """
    def __init__(self, config: NeuroLLMConfig, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx

        self.self_attn = NeuroAttention(config, layer_idx)
        self.mlp = NeuroMLP(config)

        self.input_layernorm = RMSNorm(config.hidden_dim)
        self.post_attention_layernorm = RMSNorm(config.hidden_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        # Self attention with residual (pre-norm)
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states, present_key_value = self.self_attn(
            hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            use_cache=use_cache,
        )
        hidden_states = residual + hidden_states

        # MLP with residual
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states, present_key_value


class NeuroLLMModel(nn.Module):
    """
    The core NeuroLLM transformer model.

    This is the base model without the LM head.
    """
    def __init__(self, config: NeuroLLMConfig):
        super().__init__()
        self.config = config

        # Token embeddings
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_dim)

        # Transformer layers
        self.layers = nn.ModuleList([
            NeuroDecoderLayer(config, layer_idx)
            for layer_idx in range(config.num_layers)
        ])

        # Final norm
        self.norm = RMSNorm(config.hidden_dim)

        # Gradient checkpointing
        self.gradient_checkpointing = config.gradient_checkpointing

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize weights with small values for stable training."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]:
        batch_size, seq_len = input_ids.shape

        # Get embeddings
        hidden_states = self.embed_tokens(input_ids)

        # Default padding mask (note: only the causal mask below is passed to
        # the layers; the padding mask is not folded in)
        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_len), device=input_ids.device)

        # Create causal attention mask
        causal_mask = torch.triu(
            torch.full((seq_len, seq_len), float('-inf'), device=input_ids.device),
            diagonal=1
        )

        # Process through layers
        present_key_values = [] if use_cache else None

        for idx, layer in enumerate(self.layers):
            past_key_value = past_key_values[idx] if past_key_values else None

            if self.gradient_checkpointing and self.training:
                hidden_states, present_key_value = torch.utils.checkpoint.checkpoint(
                    layer,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_value,
                    use_cache,
                )
            else:
                hidden_states, present_key_value = layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    use_cache=use_cache,
                )

            if use_cache:
                present_key_values.append(present_key_value)

        # Final norm
        hidden_states = self.norm(hidden_states)

        return hidden_states, present_key_values
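

# ---------------------------------------------------------------------------
# Editor's sketch (not in the original wheel): the base model returns one
# cached (K, V) pair per layer when use_cache=True. A minimal check on a
# deliberately tiny config (the config values here are arbitrary, chosen
# only to keep the example fast):
def _demo_kv_cache():  # hypothetical helper, for illustration only
    cfg = NeuroLLMConfig(hidden_dim=64, num_layers=2, num_heads=4,
                         num_kv_heads=2, intermediate_dim=128, vocab_size=300)
    base = NeuroLLMModel(cfg)
    ids = torch.randint(0, cfg.vocab_size, (1, 5))
    h, cache = base(ids, use_cache=True)
    assert h.shape == (1, 5, cfg.hidden_dim)
    assert len(cache) == cfg.num_layers
    assert cache[0][0].shape == (1, cfg.num_kv_heads, 5, cfg.head_dim)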


class NeuroLLMForCausalLM(nn.Module):
    """
    NeuroLLM with language modeling head.

    This is the full model for both training and inference.
    """
    def __init__(self, config: NeuroLLMConfig):
        super().__init__()
        self.config = config

        self.model = NeuroLLMModel(config)

        # LM head (optionally tied to embeddings)
        self.lm_head = nn.Linear(config.hidden_dim, config.vocab_size, bias=False)

        if config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight

        # Loss function
        self.loss_fn = nn.CrossEntropyLoss(ignore_index=-100)

        logger.info(f"NeuroLLM initialized: {config.num_params / 1e6:.1f}M parameters, "
                    f"phase={config.phase.value}")

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        use_cache: bool = False,
    ) -> Dict[str, torch.Tensor]:
        # Get hidden states
        hidden_states, present_key_values = self.model(
            input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
        )

        # Get logits
        logits = self.lm_head(hidden_states)

        # Calculate loss if labels provided
        loss = None
        if labels is not None:
            # Shift for next token prediction
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            # Flatten
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)

            loss = self.loss_fn(shift_logits, shift_labels)

        return {
            "loss": loss,
            "logits": logits,
            "past_key_values": present_key_values,
        }

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 100,
        temperature: float = 1.0,
        top_k: int = 50,
        top_p: float = 0.9,
        do_sample: bool = True,
        valid_vocab_size: int = 266,  # Mask tokens beyond this (default: byte tokens only)
    ) -> torch.Tensor:
        """Generate text autoregressively."""
        self.eval()

        past_key_values = None
        generated = input_ids

        for _ in range(max_new_tokens):
            # Forward pass (only the newest token once the KV cache is primed)
            outputs = self.forward(
                input_ids=generated if past_key_values is None else generated[:, -1:],
                past_key_values=past_key_values,
                use_cache=True,
            )

            logits = outputs["logits"][:, -1, :]
            past_key_values = outputs["past_key_values"]

            # Constrain to valid vocabulary (standard BPE tokenizer behavior)
            # Tokens beyond valid_vocab_size don't exist in the tokenizer yet
            if valid_vocab_size < logits.size(-1):
                logits[:, valid_vocab_size:] = float('-inf')

            # Apply temperature
            if temperature != 1.0:
                logits = logits / temperature

            # Apply top-k
            if top_k > 0:
                indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
                logits[indices_to_remove] = float('-inf')

            # Apply top-p (nucleus sampling)
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0

                indices_to_remove = sorted_indices_to_remove.scatter(
                    1, sorted_indices, sorted_indices_to_remove
                )
                logits[indices_to_remove] = float('-inf')

            # Sample or greedy
            if do_sample:
                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                next_token = torch.argmax(logits, dim=-1, keepdim=True)

            generated = torch.cat([generated, next_token], dim=-1)

            # Check for EOS (assumes token 2 is EOS and batch size 1)
            if next_token.item() == 2:
                break

        return generated

    def get_num_params(self) -> int:
        """Get actual number of parameters."""
        return sum(p.numel() for p in self.parameters())

    def save_checkpoint(self, path: str, extra_state: Dict = None):
        """Save model checkpoint."""
        state = {
            "config": self.config.__dict__,
            "model_state_dict": self.state_dict(),
            "version": self.config.version,
        }
        if extra_state:
            state.update(extra_state)
        torch.save(state, path)
        logger.info(f"Saved checkpoint to {path}")

    @classmethod
    def load_checkpoint(cls, path: str, device: str = "cpu") -> 'NeuroLLMForCausalLM':
        """Load model from checkpoint."""
        # Use weights_only=False for full checkpoint loading (includes config)
        # This is safe because we only load our own checkpoints
        state = torch.load(path, map_location=device, weights_only=False)

        # Reconstruct config
        config_dict = state["config"]
        config_dict["phase"] = NeuroLLMPhase(config_dict["phase"])
        config = NeuroLLMConfig(**config_dict)

        # Create model and load weights
        model = cls(config)
        model.load_state_dict(state["model_state_dict"])

        logger.info(f"Loaded checkpoint from {path}")
        return model

    # ==================== SHARDED MODEL SUPPORT ====================

    def get_layer_names(self, layer_idx: int) -> List[str]:
        """Get parameter names for a specific layer."""
        prefix = f"model.layers.{layer_idx}."
        return [name for name in self.state_dict().keys() if name.startswith(prefix)]

    def extract_shard(
        self,
        start_layer: int,
        end_layer: int,
        include_embedding: bool = False,
        include_lm_head: bool = False
    ) -> Dict[str, torch.Tensor]:
        """
        Extract weights for a shard (subset of layers).

        Args:
            start_layer: First layer to include
            end_layer: Last layer to include (exclusive)
            include_embedding: Include embedding layer
            include_lm_head: Include LM head

        Returns:
            Dict of parameter names to tensors
        """
        import re

        state_dict = self.state_dict()
        shard_weights = {}

        for name, param in state_dict.items():
            include = False

            # Embedding
            if include_embedding and ("embed" in name.lower() or "wte" in name.lower()):
                include = True

            # LM head
            if include_lm_head and ("lm_head" in name.lower()):
                include = True

            # Final norm (goes with LM head)
            if include_lm_head and ("final" in name.lower() or "ln_f" in name.lower() or "model.norm" in name.lower()):
                include = True

            # Transformer layers
            match = re.search(r'layers\.(\d+)\.', name)
            if match:
                layer_num = int(match.group(1))
                if start_layer <= layer_num < end_layer:
                    include = True

            if include:
                shard_weights[name] = param.clone()

        logger.info(f"Extracted shard: layers {start_layer}-{end_layer}, "
                    f"embed={include_embedding}, head={include_lm_head}, "
                    f"params={len(shard_weights)}")

        return shard_weights

    def load_shard(
        self,
        shard_weights: Dict[str, torch.Tensor],
        strict: bool = False
    ) -> Tuple[List[str], List[str]]:
        """
        Load weights from a shard into the model.

        Args:
            shard_weights: Dict of parameter names to tensors
            strict: If True, raise error on missing/unexpected keys

        Returns:
            (missing_keys, unexpected_keys)
        """
        current_state = self.state_dict()

        missing_keys = []
        unexpected_keys = []
        loaded_keys = []

        for name, param in shard_weights.items():
            if name in current_state:
                if current_state[name].shape == param.shape:
                    current_state[name].copy_(param)
                    loaded_keys.append(name)
                else:
                    logger.warning(f"Shape mismatch for {name}: "
                                   f"model={current_state[name].shape}, shard={param.shape}")
            else:
                unexpected_keys.append(name)

        # Check for missing keys in the shard range
        for name in current_state.keys():
            if name not in shard_weights and name not in loaded_keys:
                # Only report as missing if it should have been in the shard
                pass  # Shards are partial by design

        if strict and unexpected_keys:
            raise RuntimeError(f"Unexpected keys in shard: {unexpected_keys}")

        logger.info(f"Loaded shard: {len(loaded_keys)} parameters")

        return missing_keys, unexpected_keys

    def save_shard(
        self,
        path: str,
        start_layer: int,
        end_layer: int,
        include_embedding: bool = False,
        include_lm_head: bool = False,
        extra_state: Dict = None
    ):
        """
        Save a shard to disk.

        Args:
            path: Save path
            start_layer: First layer
            end_layer: Last layer (exclusive)
            include_embedding: Include embedding
            include_lm_head: Include LM head
            extra_state: Additional state to save
        """
        shard_weights = self.extract_shard(
            start_layer, end_layer, include_embedding, include_lm_head
        )

        state = {
            "shard_config": {
                "start_layer": start_layer,
                "end_layer": end_layer,
                "include_embedding": include_embedding,
                "include_lm_head": include_lm_head,
            },
            "model_config": self.config.__dict__,
            "shard_weights": shard_weights,
            "num_params": sum(p.numel() for p in shard_weights.values()),
        }

        if extra_state:
            state.update(extra_state)

        torch.save(state, path)
        logger.info(f"Saved shard to {path}: layers {start_layer}-{end_layer}, "
                    f"{state['num_params']} params")

    @classmethod
    def load_shard_file(cls, path: str, device: str = "cpu") -> Tuple[Dict[str, torch.Tensor], Dict]:
        """
        Load a shard file.

        Args:
            path: Shard file path
            device: Device to load to

        Returns:
            (shard_weights, shard_config)
        """
        state = torch.load(path, map_location=device, weights_only=False)

        return state["shard_weights"], state["shard_config"]

    def forward_shard(
        self,
        hidden_states: torch.Tensor,
        start_layer: int,
        end_layer: int,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]:
        """
        Forward pass through a subset of layers (for pipeline parallelism).

        This is used when the model is sharded across nodes:
        - Node A processes layers 0-3
        - Node B processes layers 4-7
        - etc.

        Args:
            hidden_states: Input hidden states [batch, seq, hidden]
            start_layer: First layer to process
            end_layer: Last layer to process (exclusive)
            attention_mask: Attention mask
            position_ids: Position IDs
            past_key_values: KV cache
            use_cache: Whether to return KV cache

        Returns:
            (output_hidden_states, new_past_key_values)
        """
        new_past_key_values = [] if use_cache else None

        for idx in range(start_layer, end_layer):
            layer = self.model.layers[idx]

            past_kv = past_key_values[idx] if past_key_values else None

            hidden_states, new_kv = layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_kv,
                use_cache=use_cache,
            )

            if use_cache:
                new_past_key_values.append(new_kv)

        return hidden_states, new_past_key_values
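

# ---------------------------------------------------------------------------
# Editor's sketch (not in the original wheel): forward_shard lets two nodes
# split the layer stack. Run the embedding, the first half of the layers,
# then the second half, then the final norm and LM head, and compare against
# the monolithic forward pass. The tiny config values are arbitrary, chosen
# only to keep the example fast; note the causal mask must be passed
# explicitly, since forward_shard does not build one itself.
def _demo_pipeline_split():  # hypothetical helper, for illustration only
    cfg = NeuroLLMConfig(hidden_dim=64, num_layers=4, num_heads=4,
                         num_kv_heads=2, intermediate_dim=128, vocab_size=300)
    model = NeuroLLMForCausalLM(cfg)
    model.eval()
    ids = torch.randint(0, cfg.vocab_size, (1, 8))
    seq = ids.shape[1]
    mask = torch.triu(torch.full((seq, seq), float('-inf')), diagonal=1)
    with torch.no_grad():
        h = model.model.embed_tokens(ids)                         # "node A" input
        h, _ = model.forward_shard(h, 0, 2, attention_mask=mask)  # "node A"
        h, _ = model.forward_shard(h, 2, 4, attention_mask=mask)  # "node B"
        logits_split = model.lm_head(model.model.norm(h))
        logits_full = model(ids)["logits"]
    assert torch.allclose(logits_split, logits_full, atol=1e-5)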


# Convenience function
def create_neuro_llm(phase: str = "bootstrap") -> NeuroLLMForCausalLM:
    """
    Create a NeuroLLM model for the specified phase.

    Args:
        phase: One of "bootstrap", "early", "growth", "mature"

    Returns:
        NeuroLLMForCausalLM instance
    """
    config_map = {
        "bootstrap": NeuroLLMConfig.bootstrap,
        "early": NeuroLLMConfig.early,
        "growth": NeuroLLMConfig.growth,
        "mature": NeuroLLMConfig.mature,
    }

    if phase not in config_map:
        raise ValueError(f"Unknown phase: {phase}. Choose from {list(config_map.keys())}")

    config = config_map[phase]()
    return NeuroLLMForCausalLM(config)
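

# ---------------------------------------------------------------------------
# Editor's sketch (not in the original wheel): minimal end-to-end usage of
# the pieces above - build the bootstrap model, compute a training loss with
# teacher forcing, and sample a short greedy continuation. Token ids are
# arbitrary values below the default byte-level vocabulary (266).
if __name__ == "__main__":
    model = create_neuro_llm("bootstrap")
    ids = torch.randint(0, 266, (1, 16))
    out = model(ids, labels=ids)  # labels are shifted internally
    print(f"loss: {out['loss'].item():.3f}")
    sample = model.generate(ids[:, :4], max_new_tokens=8, do_sample=False)
    print(f"generated {sample.shape[1] - 4} new tokens")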