nexaroa-0.0.111-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. neuroshard/__init__.py +93 -0
  2. neuroshard/__main__.py +4 -0
  3. neuroshard/cli.py +466 -0
  4. neuroshard/core/__init__.py +92 -0
  5. neuroshard/core/consensus/verifier.py +252 -0
  6. neuroshard/core/crypto/__init__.py +20 -0
  7. neuroshard/core/crypto/ecdsa.py +392 -0
  8. neuroshard/core/economics/__init__.py +52 -0
  9. neuroshard/core/economics/constants.py +387 -0
  10. neuroshard/core/economics/ledger.py +2111 -0
  11. neuroshard/core/economics/market.py +975 -0
  12. neuroshard/core/economics/wallet.py +168 -0
  13. neuroshard/core/governance/__init__.py +74 -0
  14. neuroshard/core/governance/proposal.py +561 -0
  15. neuroshard/core/governance/registry.py +545 -0
  16. neuroshard/core/governance/versioning.py +332 -0
  17. neuroshard/core/governance/voting.py +453 -0
  18. neuroshard/core/model/__init__.py +30 -0
  19. neuroshard/core/model/dynamic.py +4186 -0
  20. neuroshard/core/model/llm.py +905 -0
  21. neuroshard/core/model/registry.py +164 -0
  22. neuroshard/core/model/scaler.py +387 -0
  23. neuroshard/core/model/tokenizer.py +568 -0
  24. neuroshard/core/network/__init__.py +56 -0
  25. neuroshard/core/network/connection_pool.py +72 -0
  26. neuroshard/core/network/dht.py +130 -0
  27. neuroshard/core/network/dht_plan.py +55 -0
  28. neuroshard/core/network/dht_proof_store.py +516 -0
  29. neuroshard/core/network/dht_protocol.py +261 -0
  30. neuroshard/core/network/dht_service.py +506 -0
  31. neuroshard/core/network/encrypted_channel.py +141 -0
  32. neuroshard/core/network/nat.py +201 -0
  33. neuroshard/core/network/nat_traversal.py +695 -0
  34. neuroshard/core/network/p2p.py +929 -0
  35. neuroshard/core/network/p2p_data.py +150 -0
  36. neuroshard/core/swarm/__init__.py +106 -0
  37. neuroshard/core/swarm/aggregation.py +729 -0
  38. neuroshard/core/swarm/buffers.py +643 -0
  39. neuroshard/core/swarm/checkpoint.py +709 -0
  40. neuroshard/core/swarm/compute.py +624 -0
  41. neuroshard/core/swarm/diloco.py +844 -0
  42. neuroshard/core/swarm/factory.py +1288 -0
  43. neuroshard/core/swarm/heartbeat.py +669 -0
  44. neuroshard/core/swarm/logger.py +487 -0
  45. neuroshard/core/swarm/router.py +658 -0
  46. neuroshard/core/swarm/service.py +640 -0
  47. neuroshard/core/training/__init__.py +29 -0
  48. neuroshard/core/training/checkpoint.py +600 -0
  49. neuroshard/core/training/distributed.py +1602 -0
  50. neuroshard/core/training/global_tracker.py +617 -0
  51. neuroshard/core/training/production.py +276 -0
  52. neuroshard/governance_cli.py +729 -0
  53. neuroshard/grpc_server.py +895 -0
  54. neuroshard/runner.py +3223 -0
  55. neuroshard/sdk/__init__.py +92 -0
  56. neuroshard/sdk/client.py +990 -0
  57. neuroshard/sdk/errors.py +101 -0
  58. neuroshard/sdk/types.py +282 -0
  59. neuroshard/tracker/__init__.py +0 -0
  60. neuroshard/tracker/server.py +864 -0
  61. neuroshard/ui/__init__.py +0 -0
  62. neuroshard/ui/app.py +102 -0
  63. neuroshard/ui/templates/index.html +1052 -0
  64. neuroshard/utils/__init__.py +0 -0
  65. neuroshard/utils/autostart.py +81 -0
  66. neuroshard/utils/hardware.py +121 -0
  67. neuroshard/utils/serialization.py +90 -0
  68. neuroshard/version.py +1 -0
  69. nexaroa-0.0.111.dist-info/METADATA +283 -0
  70. nexaroa-0.0.111.dist-info/RECORD +78 -0
  71. nexaroa-0.0.111.dist-info/WHEEL +5 -0
  72. nexaroa-0.0.111.dist-info/entry_points.txt +4 -0
  73. nexaroa-0.0.111.dist-info/licenses/LICENSE +190 -0
  74. nexaroa-0.0.111.dist-info/top_level.txt +2 -0
  75. protos/__init__.py +0 -0
  76. protos/neuroshard.proto +651 -0
  77. protos/neuroshard_pb2.py +160 -0
  78. protos/neuroshard_pb2_grpc.py +1298 -0
neuroshard/core/model/registry.py
@@ -0,0 +1,164 @@
+"""
+Dynamic Tokenizer Registry
+
+Allows tokenizer vocabulary to grow as network needs evolve.
+
+FUTURE ENHANCEMENT:
+- Start with 32k vocabulary (English-focused)
+- Grow to 64k when multilingual support needed
+- Grow to 128k when code/math support needed
+
+Just like architecture, tokenizer can be upgraded!
+"""
+
+import logging
+from dataclasses import dataclass
+from typing import Optional
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class TokenizerConfig:
+    """Configuration for NeuroShard tokenizer."""
+    vocab_size: int
+    tokenizer_version: int
+    model_path: str
+    supported_languages: list
+
+    def estimate_embedding_memory_mb(self, hidden_dim: int) -> float:
+        """Calculate memory for embedding layer with this vocab size."""
+        # Embedding: vocab_size × hidden_dim
+        # Plus gradients and optimizer states (×4 total)
+        params = self.vocab_size * hidden_dim
+        return (params * 4 * 4) / (1024 * 1024)
+
+    def to_dict(self):
+        return {
+            "vocab_size": self.vocab_size,
+            "tokenizer_version": self.tokenizer_version,
+            "model_path": self.model_path,
+            "supported_languages": self.supported_languages,
+        }
+
+    @classmethod
+    def from_dict(cls, data: dict):
+        return cls(**data)
+
+
+# Default tokenizer configurations
+TOKENIZER_CONFIGS = {
+    # Version 1: English-focused (bootstrap)
+    1: TokenizerConfig(
+        vocab_size=32000,
+        tokenizer_version=1,
+        model_path="neuroshard_tokenizer_v1.model",
+        supported_languages=["en"],
+    ),
+
+    # Version 2: Multilingual (future upgrade)
+    # Triggered when network has >1000 nodes from diverse regions
+    2: TokenizerConfig(
+        vocab_size=64000,
+        tokenizer_version=2,
+        model_path="neuroshard_tokenizer_v2.model",
+        supported_languages=["en", "es", "fr", "de", "zh", "ja"],
+    ),
+
+    # Version 3: Code + Math specialized (future upgrade)
+    # Triggered by community vote or when code training data > 30%
+    3: TokenizerConfig(
+        vocab_size=100000,
+        tokenizer_version=3,
+        model_path="neuroshard_tokenizer_v3.model",
+        supported_languages=["en", "code", "math"],
+    ),
+}
+
+
+def get_current_tokenizer_config(network_size: int = 1) -> TokenizerConfig:
+    """
+    Get appropriate tokenizer config for current network state.
+
+    Upgrade triggers:
+    - v1 → v2: Network has >1000 nodes (multilingual needed)
+    - v2 → v3: Community votes for code specialization
+
+    Args:
+        network_size: Number of active nodes
+
+    Returns:
+        TokenizerConfig for current network state
+    """
+    if network_size < 1000:
+        # Bootstrap: English-focused 32k vocab
+        return TOKENIZER_CONFIGS[1]
+    elif network_size < 5000:
+        # Growth: Multilingual 64k vocab
+        # TODO: Implement migration from v1 → v2
+        logger.warning("Network ready for multilingual tokenizer upgrade (v1 → v2)")
+        return TOKENIZER_CONFIGS[1]  # Stay on v1 until migration implemented
+    else:
+        # Maturity: Specialized 100k vocab
+        logger.warning("Network ready for code-specialized tokenizer (v2 → v3)")
+        return TOKENIZER_CONFIGS[2]  # Stay on v2 until migration implemented
+
+
+def should_upgrade_tokenizer(
+    current: TokenizerConfig,
+    new: TokenizerConfig,
+) -> tuple[bool, str]:
+    """
+    Determine if tokenizer upgrade is worthwhile.
+
+    Tokenizer upgrades are EXPENSIVE (require retraining):
+    - All embeddings must be expanded (vocab_size × hidden_dim)
+    - All existing checkpoints incompatible
+    - Requires community vote (unlike architecture, which is automatic)
+
+    Returns:
+        (should_upgrade, reason)
+    """
+    if new.tokenizer_version <= current.tokenizer_version:
+        return False, "Not a newer version"
+
+    # Tokenizer upgrades require community governance vote
+    # Unlike architecture (automatic), tokenizer affects all model outputs
+    reason = (f"Tokenizer upgrade available: v{current.tokenizer_version} → v{new.tokenizer_version} "
+              f"({current.vocab_size} → {new.vocab_size} tokens). "
+              f"Requires NeuroDAO vote and coordinated upgrade.")
+
+    return True, reason
+
+
+# INTEGRATION WITH ARCHITECTURE
+
+def adjust_architecture_for_tokenizer(
+    arch: 'ModelArchitecture',  # type: ignore
+    tokenizer: TokenizerConfig
+) -> 'ModelArchitecture':  # type: ignore
+    """
+    Adjust architecture to account for tokenizer vocab size.
+
+    Larger vocab → larger embedding → less memory for layers.
+    """
+    from neuroshard.core.model.scaler import ModelArchitecture
+
+    # Estimate embedding memory
+    embedding_mem_mb = tokenizer.estimate_embedding_memory_mb(arch.hidden_dim)
+
+    # If embedding is too large, we might need to reduce num_layers slightly
+    # (This is handled in calculate_optimal_architecture already, but good to verify)
+
+    return ModelArchitecture(
+        hidden_dim=arch.hidden_dim,
+        intermediate_dim=arch.intermediate_dim,
+        num_layers=arch.num_layers,
+        num_heads=arch.num_heads,
+        num_kv_heads=arch.num_kv_heads,
+        vocab_size=tokenizer.vocab_size,  # Use tokenizer's vocab_size!
+        max_seq_len=arch.max_seq_len,
+        dropout=arch.dropout,
+        rope_theta=arch.rope_theta,
+    )
+
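Illustrative usage (not part of the diff): a minimal sketch of how this registry would be driven, assuming the import path from the file list above (neuroshard/core/model/registry.py); the network size, hidden dimension, and printed values are hypothetical.

from neuroshard.core.model.registry import (
    TOKENIZER_CONFIGS,
    get_current_tokenizer_config,
    should_upgrade_tokenizer,
)

# A hypothetical network of 800 nodes is still in the bootstrap tier, so v1 (32k vocab) is returned.
current = get_current_tokenizer_config(network_size=800)

# Proposing the multilingual v2 config is flagged as worthwhile, but the reason string
# notes it needs a NeuroDAO vote and a coordinated upgrade.
proposed = TOKENIZER_CONFIGS[2]
ok, reason = should_upgrade_tokenizer(current, proposed)
print(ok, reason)  # True, "Tokenizer upgrade available: v1 → v2 (32000 → 64000 tokens). ..."

# Embedding cost of the larger vocab at a hypothetical hidden_dim of 1024:
# 64000 × 1024 params × 16 bytes ≈ 1000 MB (weights + grads + Adam states).
print(f"{proposed.estimate_embedding_memory_mb(hidden_dim=1024):.0f} MB")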
neuroshard/core/model/scaler.py
@@ -0,0 +1,387 @@
+"""
+Dynamic Architecture Scaler - Core of NeuroShard's Unlimited Scaling
+
+This module calculates optimal model architecture (width + depth) based on
+total network capacity, following empirical scaling laws from GPT-3 and Chinchilla.
+
+Key Principles:
+1. Width grows faster than depth (empirically more efficient)
+2. Both dimensions scale with network size
+3. Architecture updates are gradual and automated
+4. No fixed model sizes - purely capacity-driven
+"""
+
+import math
+import logging
+from typing import Dict, Optional, Tuple
+from dataclasses import dataclass
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class ModelArchitecture:
+    """Configuration for a dynamically-sized model."""
+    hidden_dim: int
+    intermediate_dim: int
+    num_layers: int
+    num_heads: int
+    num_kv_heads: int
+    vocab_size: int = 32000  # Initial size - expands automatically as tokenizer grows (unlimited)
+    max_seq_len: int = 2048
+    dropout: float = 0.0
+    rope_theta: float = 10000.0
+
+    def __post_init__(self):
+        """Validate architecture constraints."""
+        assert self.hidden_dim % self.num_heads == 0, \
+            f"hidden_dim ({self.hidden_dim}) must be divisible by num_heads ({self.num_heads})"
+        assert self.num_heads % self.num_kv_heads == 0, \
+            f"num_heads ({self.num_heads}) must be divisible by num_kv_heads ({self.num_kv_heads})"
+        assert self.hidden_dim > 0 and self.num_layers > 0, \
+            "Dimensions must be positive"
+        # head_dim must be even for Rotary Position Embeddings (RoPE)
+        head_dim = self.hidden_dim // self.num_heads
+        assert head_dim % 2 == 0, \
+            f"head_dim ({head_dim}) must be even for RoPE. Adjust hidden_dim or num_heads."
+
+    def estimate_params(self) -> int:
+        """Estimate total parameters in this architecture."""
+        # Embedding
+        embed_params = self.vocab_size * self.hidden_dim
+
+        # Per-layer params (attention + FFN + norms)
+        # Attention: Q, K, V, O projections
+        attn_params = 4 * self.hidden_dim * self.hidden_dim
+        # FFN: gate, up, down (SwiGLU)
+        ffn_params = 3 * self.hidden_dim * self.intermediate_dim
+        # Norms (RMSNorm has 1 param per dim)
+        norm_params = 2 * self.hidden_dim
+
+        layer_params = attn_params + ffn_params + norm_params
+
+        # LM head (tied with embedding, so don't count twice)
+        # Final norm
+        final_norm_params = self.hidden_dim
+
+        total = embed_params + (layer_params * self.num_layers) + final_norm_params
+        return total
+
+    def estimate_memory_mb(self) -> float:
+        """
+        Estimate memory needed for this architecture.
+        Includes weights + gradients + optimizer states (Adam).
+        """
+        params = self.estimate_params()
+        # 4 bytes per param (FP32) × 4 (weights + grads + Adam m + v)
+        return (params * 4 * 4) / (1024 * 1024)
+
+    def to_dict(self) -> Dict:
+        """Serialize for checkpointing."""
+        return {
+            "hidden_dim": self.hidden_dim,
+            "intermediate_dim": self.intermediate_dim,
+            "num_layers": self.num_layers,
+            "num_heads": self.num_heads,
+            "num_kv_heads": self.num_kv_heads,
+            "vocab_size": self.vocab_size,
+            "max_seq_len": self.max_seq_len,
+            "dropout": self.dropout,
+            "rope_theta": self.rope_theta,
+        }
+
+    @classmethod
+    def from_dict(cls, data: Dict):
+        """Deserialize from checkpoint."""
+        return cls(**data)
+
+
+def calculate_optimal_architecture(
+    total_network_memory_mb: float,
+    utilization_factor: float = 0.6
+) -> ModelArchitecture:
+    """
+    Calculate optimal model architecture based on total network capacity.
+
+    Follows empirical scaling laws:
+    - Chinchilla: Optimal compute scales as params^1.0 × tokens^1.0
+    - GPT-3: Width grows faster than depth for efficiency
+    - Empirical: depth ∝ params^0.4, width ∝ params^0.6
+
+    Args:
+        total_network_memory_mb: Total available memory across all nodes
+        utilization_factor: Fraction of memory to use (default 0.6 for safety)
+
+    Returns:
+        ModelArchitecture optimized for this capacity
+    """
+    # STABILITY FIX: Round memory to nearest 500MB tier to prevent architecture
+    # changes due to small memory fluctuations between runs.
+    # This ensures checkpoints remain compatible across restarts.
+    MEMORY_TIER_MB = 500
+    rounded_memory_mb = math.floor(total_network_memory_mb / MEMORY_TIER_MB) * MEMORY_TIER_MB
+    # Ensure at least 500MB (minimum tier)
+    rounded_memory_mb = max(MEMORY_TIER_MB, rounded_memory_mb)
+
+    # Calculate parameter budget
+    # Memory per param: 16 bytes (FP32 weight + grad + 2× Adam states)
+    usable_memory_mb = rounded_memory_mb * utilization_factor
+    max_params = int((usable_memory_mb * 1024 * 1024) / 16)
+    params_millions = max_params / 1e6
+
+    logger.info(f"Network capacity: {total_network_memory_mb:.0f}MB (rounded to {rounded_memory_mb}MB tier) → {params_millions:.0f}M params budget")
+
+    # Scaling formulas based on parameter count
+    # These are empirically derived from successful LLM architectures
+
+    if params_millions < 50:  # < 50M params (tiny network)
+        # Example: 1-5 nodes with 2GB each
+        # Architecture: Similar to GPT-2 Small
+        hidden_dim = 384
+        num_layers = max(6, int(12 * (params_millions / 50) ** 0.5))
+        num_heads = 6
+
+    elif params_millions < 150:  # 50M - 150M params (small network)
+        # Example: 5-20 nodes
+        # Architecture: GPT-2 Medium scale
+        hidden_dim = 512
+        num_layers = max(8, int(16 * (params_millions / 150) ** 0.5))
+        num_heads = 8
+
+    elif params_millions < 500:  # 150M - 500M params (medium network)
+        # Example: 20-50 nodes
+        # Architecture: GPT-2 Large → GPT-2 XL
+        hidden_dim = int(640 + 384 * ((params_millions - 150) / 350))
+        num_layers = max(12, int(24 * (params_millions / 500) ** 0.5))
+        num_heads = max(8, hidden_dim // 64)
+
+    elif params_millions < 2000:  # 500M - 2B params (large network)
+        # Example: 50-200 nodes
+        # Architecture: GPT-3 Small → Medium
+        hidden_dim = int(1024 * (params_millions / 500) ** 0.6)
+        num_layers = max(18, int(32 * (params_millions / 2000) ** 0.4))
+        num_heads = max(12, hidden_dim // 64)
+
+    elif params_millions < 10000:  # 2B - 10B params (very large)
+        # Example: 200-1000 nodes
+        # Architecture: GPT-3 Medium → Large
+        hidden_dim = int(1536 * (params_millions / 2000) ** 0.6)
+        num_layers = max(24, int(48 * (params_millions / 10000) ** 0.4))
+        num_heads = max(16, hidden_dim // 64)
+
+    elif params_millions < 100000:  # 10B - 100B params (frontier)
+        # Example: 1000-10000 nodes
+        # Architecture: GPT-3 XL → GPT-4 scale
+        hidden_dim = int(2048 * (params_millions / 10000) ** 0.6)
+        num_layers = max(32, int(64 * (params_millions / 100000) ** 0.35))
+        num_heads = max(20, hidden_dim // 96)
+
+    else:  # > 100B params (mega-scale)
+        # Example: 10000+ nodes
+        # Architecture: Beyond GPT-4
+        hidden_dim = int(4096 * (params_millions / 100000) ** 0.5)
+        num_layers = max(48, int(80 * (params_millions / 100000) ** 0.3))
+        num_heads = max(32, hidden_dim // 128)
+
+    # Ensure hidden_dim is divisible by num_heads AND head_dim is even (for RoPE)
+    # head_dim = hidden_dim // num_heads must be even for rotary embeddings
+    head_dim = hidden_dim // num_heads
+    if head_dim % 2 != 0:
+        head_dim = ((head_dim + 1) // 2) * 2  # Round up to nearest even
+    hidden_dim = head_dim * num_heads
+
+    # GQA: Use 1/3 to 1/4 the number of KV heads
+    # IMPORTANT: num_heads must be divisible by num_kv_heads
+    num_kv_heads = max(1, num_heads // 3)
+    # Round to ensure divisibility
+    while num_heads % num_kv_heads != 0:
+        num_kv_heads -= 1
+        if num_kv_heads < 1:
+            num_kv_heads = 1
+            break
+
+    # SwiGLU intermediate dimension (standard: 8/3 × hidden for gated FFN)
+    intermediate_dim = int(hidden_dim * 8 / 3)
+    # Round to nearest multiple of 64 for efficiency
+    intermediate_dim = ((intermediate_dim + 63) // 64) * 64
+
+    arch = ModelArchitecture(
+        hidden_dim=hidden_dim,
+        intermediate_dim=intermediate_dim,
+        num_layers=num_layers,
+        num_heads=num_heads,
+        num_kv_heads=num_kv_heads,
+    )
+
+    # Verify we're within budget
+    actual_memory = arch.estimate_memory_mb()
+    if actual_memory > usable_memory_mb:
+        # Scale down slightly
+        scale_factor = math.sqrt(usable_memory_mb / actual_memory)
+        hidden_dim = int(hidden_dim * scale_factor)
+        # Ensure head_dim is even (for RoPE)
+        head_dim = hidden_dim // num_heads
+        if head_dim % 2 != 0:
+            head_dim = (head_dim // 2) * 2  # Round down to nearest even
+        hidden_dim = head_dim * num_heads
+
+        arch = ModelArchitecture(
+            hidden_dim=hidden_dim,
+            intermediate_dim=int(hidden_dim * 8 / 3),
+            num_layers=num_layers,
+            num_heads=num_heads,
+            num_kv_heads=num_kv_heads,
+        )
+
+    logger.info(f"Calculated architecture: {arch.num_layers}L × {arch.hidden_dim}H "
+                f"({arch.estimate_params()/1e6:.0f}M params, {arch.estimate_memory_mb():.0f}MB)")
+
+    return arch
+
+
+def should_upgrade_architecture(
+    current: ModelArchitecture,
+    new: ModelArchitecture,
+    min_improvement: float = 1.3
+) -> Tuple[bool, str]:
+    """
+    Determine if architecture upgrade is worthwhile.
+
+    Args:
+        current: Current architecture
+        new: Proposed new architecture
+        min_improvement: Minimum parameter ratio to trigger upgrade (default 1.3 = 30% improvement)
+
+    Returns:
+        (should_upgrade, reason)
+    """
+    current_params = current.estimate_params()
+    new_params = new.estimate_params()
+    improvement_ratio = new_params / current_params
+
+    if improvement_ratio < min_improvement:
+        return False, f"Improvement {improvement_ratio:.2f}x < threshold {min_improvement}x"
+
+    # Additional checks: ensure balanced scaling
+    width_ratio = new.hidden_dim / current.hidden_dim
+    depth_ratio = new.num_layers / current.num_layers
+
+    if width_ratio < 1.0 or depth_ratio < 1.0:
+        return False, "Architecture would shrink (not allowed)"
+
+    # Warn if extremely imbalanced
+    if width_ratio > 2.0 and depth_ratio < 1.1:
+        logger.warning(f"Width growing much faster than depth ({width_ratio:.1f}x vs {depth_ratio:.1f}x)")
+
+    reason = (f"Upgrade justified: {current_params/1e6:.0f}M → {new_params/1e6:.0f}M params "
+              f"({improvement_ratio:.2f}x, width {width_ratio:.2f}x, depth {depth_ratio:.2f}x)")
+
+    return True, reason
+
+
+def estimate_memory_per_layer(arch: ModelArchitecture) -> float:
+    """
+    Estimate memory needed per layer (useful for layer assignment).
+
+    Returns memory in MB.
+    """
+    # Per-layer params
+    attn_params = 4 * arch.hidden_dim * arch.hidden_dim
+    ffn_params = 3 * arch.hidden_dim * arch.intermediate_dim
+    norm_params = 2 * arch.hidden_dim
+
+    layer_params = attn_params + ffn_params + norm_params
+
+    # Memory: params × 16 bytes (weights + grads + Adam states)
+    return (layer_params * 16) / (1024 * 1024)
+
+
+def estimate_embedding_memory_mb(hidden_dim: int, vocab_capacity: int) -> float:
+    """
+    Estimate memory for embedding and LM head based on vocab capacity.
+
+    Args:
+        hidden_dim: Model hidden dimension
+        vocab_capacity: Current vocabulary capacity (NOT tokenizer vocab_size)
+
+    Returns:
+        Memory in MB for embedding + lm_head (weights + gradients + optimizer states)
+    """
+    # Embedding: vocab_capacity × hidden_dim params
+    # LM head: vocab_capacity × hidden_dim params (not tied)
+    # Total: 2 × vocab_capacity × hidden_dim
+    embed_params = 2 * vocab_capacity * hidden_dim
+
+    # Memory: params × 16 bytes (weights 4B + grads 4B + Adam m 4B + Adam v 4B)
+    return (embed_params * 16) / (1024 * 1024)
+
+
+def calculate_layer_assignment(
+    available_memory_mb: float,
+    arch: ModelArchitecture,
+    safety_factor: float = 0.6,
+    vocab_capacity: int = 32000,
+    training_mode: bool = True,
+    needs_embedding: bool = True
+) -> int:
+    """
+    Calculate how many layers a node can hold.
+
+    Args:
+        available_memory_mb: Node's available memory
+        arch: Current network architecture
+        safety_factor: Use only 60% of memory for safety (GPU) or 30% (CPU)
+        vocab_capacity: Current vocab capacity for embedding/LM head (dynamic!)
+        training_mode: If True, reserve more memory for gradients/optimizer
+        needs_embedding: If True, reserve memory for embedding/LM head (Driver/Validator only!)
+            Workers (middle layers) don't need embedding and can hold MORE layers!
+
+    Returns:
+        Number of layers this node can hold
+    """
+    usable_memory = available_memory_mb * safety_factor
+    memory_per_layer = estimate_memory_per_layer(arch)
+
+    # DYNAMIC: Calculate actual embedding/LM head memory based on vocab capacity
+    # This is critical for dynamic vocabulary - as vocab grows, less memory for layers!
+    # NOTE: Only Drivers (Layer 0) and Validators (Last Layer) need embedding memory!
+    # Workers (middle layers) skip this and can hold MORE layers!
+    if needs_embedding:
+        embedding_memory = estimate_embedding_memory_mb(arch.hidden_dim, vocab_capacity)
+    else:
+        embedding_memory = 0  # Workers don't need embedding/LM head!
+
+    # TRAINING vs INFERENCE memory overhead
+    # Training needs: forward activations + backward gradients + optimizer peak
+    # With gradient checkpointing (always on for >16 layers), activations are minimal
+    if training_mode:
+        # Reserve 20% for training overhead (checkpointed activations, optimizer peaks)
+        # Reduced from 35% since gradient checkpointing is now always enabled
+        training_overhead_ratio = 0.20
+    else:
+        # Inference only needs 5% buffer
+        training_overhead_ratio = 0.05
+
+    training_overhead = usable_memory * training_overhead_ratio
+    usable_for_layers = usable_memory - embedding_memory - training_overhead
+
+    if usable_for_layers <= 0:
+        # Not enough memory even for embedding - return minimum
+        logger.warning(f"[MEMORY] Very limited memory: {usable_memory:.0f}MB usable, "
+                       f"{embedding_memory:.0f}MB for vocab, {training_overhead:.0f}MB overhead")
+        return 1
+
+    max_layers = max(1, int(usable_for_layers / memory_per_layer))
+
+    # Log the calculation for transparency
+    role = "Driver/Validator" if needs_embedding else "Worker"
+    logger.debug(f"[MEMORY] Layer calc ({role}): {available_memory_mb:.0f}MB × {safety_factor} = {usable_memory:.0f}MB usable")
+    if needs_embedding:
+        logger.debug(f"[MEMORY] - Embedding ({vocab_capacity:,} vocab): {embedding_memory:.0f}MB")
+    else:
+        logger.debug(f"[MEMORY] - No embedding needed (Worker)")
+    logger.debug(f"[MEMORY] - Training overhead ({training_overhead_ratio*100:.0f}%): {training_overhead:.0f}MB")
+    logger.debug(f"[MEMORY] - Available for layers: {usable_for_layers:.0f}MB → {max_layers} layers")
+
+    return max_layers
+
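Illustrative usage (not part of the diff): a minimal sketch tying the scaler's two entry points together, assuming the import path from the file list above (neuroshard/core/model/scaler.py); the 16 GB network and 8 GB node figures are hypothetical and the quoted sizes are approximate.

from neuroshard.core.model.scaler import (
    calculate_optimal_architecture,
    calculate_layer_assignment,
)

# 1) Size the shared model from total network capacity (hypothetical 16 GB).
#    16000 MB × 0.6 / 16 bytes ≈ 629M param budget → the 500M-2B branch,
#    which works out to roughly 20 layers × 1188 hidden (~380M params, under budget).
arch = calculate_optimal_architecture(total_network_memory_mb=16_000)

# 2) Ask how many of those layers one node can host. A middle-layer Worker
#    skips the embedding/LM head reservation, so more memory goes to layers.
layers = calculate_layer_assignment(
    available_memory_mb=8_000,   # hypothetical 8 GB node
    arch=arch,
    safety_factor=0.6,
    vocab_capacity=32_000,
    training_mode=True,
    needs_embedding=False,       # Worker role
)
# ≈260 MB per layer for this architecture, so after the 20% training reserve
# a Worker like this fits roughly 14 layers.
print(arch.num_layers, arch.hidden_dim, layers)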