nexaroa-0.0.111-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. neuroshard/__init__.py +93 -0
  2. neuroshard/__main__.py +4 -0
  3. neuroshard/cli.py +466 -0
  4. neuroshard/core/__init__.py +92 -0
  5. neuroshard/core/consensus/verifier.py +252 -0
  6. neuroshard/core/crypto/__init__.py +20 -0
  7. neuroshard/core/crypto/ecdsa.py +392 -0
  8. neuroshard/core/economics/__init__.py +52 -0
  9. neuroshard/core/economics/constants.py +387 -0
  10. neuroshard/core/economics/ledger.py +2111 -0
  11. neuroshard/core/economics/market.py +975 -0
  12. neuroshard/core/economics/wallet.py +168 -0
  13. neuroshard/core/governance/__init__.py +74 -0
  14. neuroshard/core/governance/proposal.py +561 -0
  15. neuroshard/core/governance/registry.py +545 -0
  16. neuroshard/core/governance/versioning.py +332 -0
  17. neuroshard/core/governance/voting.py +453 -0
  18. neuroshard/core/model/__init__.py +30 -0
  19. neuroshard/core/model/dynamic.py +4186 -0
  20. neuroshard/core/model/llm.py +905 -0
  21. neuroshard/core/model/registry.py +164 -0
  22. neuroshard/core/model/scaler.py +387 -0
  23. neuroshard/core/model/tokenizer.py +568 -0
  24. neuroshard/core/network/__init__.py +56 -0
  25. neuroshard/core/network/connection_pool.py +72 -0
  26. neuroshard/core/network/dht.py +130 -0
  27. neuroshard/core/network/dht_plan.py +55 -0
  28. neuroshard/core/network/dht_proof_store.py +516 -0
  29. neuroshard/core/network/dht_protocol.py +261 -0
  30. neuroshard/core/network/dht_service.py +506 -0
  31. neuroshard/core/network/encrypted_channel.py +141 -0
  32. neuroshard/core/network/nat.py +201 -0
  33. neuroshard/core/network/nat_traversal.py +695 -0
  34. neuroshard/core/network/p2p.py +929 -0
  35. neuroshard/core/network/p2p_data.py +150 -0
  36. neuroshard/core/swarm/__init__.py +106 -0
  37. neuroshard/core/swarm/aggregation.py +729 -0
  38. neuroshard/core/swarm/buffers.py +643 -0
  39. neuroshard/core/swarm/checkpoint.py +709 -0
  40. neuroshard/core/swarm/compute.py +624 -0
  41. neuroshard/core/swarm/diloco.py +844 -0
  42. neuroshard/core/swarm/factory.py +1288 -0
  43. neuroshard/core/swarm/heartbeat.py +669 -0
  44. neuroshard/core/swarm/logger.py +487 -0
  45. neuroshard/core/swarm/router.py +658 -0
  46. neuroshard/core/swarm/service.py +640 -0
  47. neuroshard/core/training/__init__.py +29 -0
  48. neuroshard/core/training/checkpoint.py +600 -0
  49. neuroshard/core/training/distributed.py +1602 -0
  50. neuroshard/core/training/global_tracker.py +617 -0
  51. neuroshard/core/training/production.py +276 -0
  52. neuroshard/governance_cli.py +729 -0
  53. neuroshard/grpc_server.py +895 -0
  54. neuroshard/runner.py +3223 -0
  55. neuroshard/sdk/__init__.py +92 -0
  56. neuroshard/sdk/client.py +990 -0
  57. neuroshard/sdk/errors.py +101 -0
  58. neuroshard/sdk/types.py +282 -0
  59. neuroshard/tracker/__init__.py +0 -0
  60. neuroshard/tracker/server.py +864 -0
  61. neuroshard/ui/__init__.py +0 -0
  62. neuroshard/ui/app.py +102 -0
  63. neuroshard/ui/templates/index.html +1052 -0
  64. neuroshard/utils/__init__.py +0 -0
  65. neuroshard/utils/autostart.py +81 -0
  66. neuroshard/utils/hardware.py +121 -0
  67. neuroshard/utils/serialization.py +90 -0
  68. neuroshard/version.py +1 -0
  69. nexaroa-0.0.111.dist-info/METADATA +283 -0
  70. nexaroa-0.0.111.dist-info/RECORD +78 -0
  71. nexaroa-0.0.111.dist-info/WHEEL +5 -0
  72. nexaroa-0.0.111.dist-info/entry_points.txt +4 -0
  73. nexaroa-0.0.111.dist-info/licenses/LICENSE +190 -0
  74. nexaroa-0.0.111.dist-info/top_level.txt +2 -0
  75. protos/__init__.py +0 -0
  76. protos/neuroshard.proto +651 -0
  77. protos/neuroshard_pb2.py +160 -0
  78. protos/neuroshard_pb2_grpc.py +1298 -0
@@ -0,0 +1,600 @@
+ """
+ Checkpoint Sharding System for NeuroShard
+
+ This module enables distributing model checkpoints across multiple nodes,
+ allowing the network to scale beyond what any single node can hold.
+
+ Architecture:
+ - Full model is split into N shards (typically by layer groups)
+ - Each shard is replicated across M nodes (default M=3 for redundancy)
+ - Shards are announced via DHT for discovery
+ - Training updates are coordinated across shard holders
+
+ Sharding Strategy:
+ - Phase 1-2 (Bootstrap/Early): Full model per node (no sharding needed)
+ - Phase 3 (Growth): Optional sharding for 7B model
+ - Phase 4 (Mature): Mandatory sharding for 70B+ model
+
+ Example for 70B model (80 layers):
+ - Shard 0: Embedding + Layers 0-19 (~70GB / 4 = ~17.5GB)
+ - Shard 1: Layers 20-39 (~17.5GB)
+ - Shard 2: Layers 40-59 (~17.5GB)
+ - Shard 3: Layers 60-79 + LM Head (~17.5GB)
+ """
+
+ import hashlib
+ import json
+ import time
+ import threading
+ import logging
+ from typing import Dict, List, Optional, Tuple, Set, Any
+ from dataclasses import dataclass, field
+ from enum import Enum
+ from pathlib import Path
+
+ import torch
+
+ logger = logging.getLogger(__name__)
+
+
+ class ShardingStrategy(Enum):
+     """How to split the model."""
+     NONE = "none"                        # Full model per node (Phase 1-2)
+     LAYER_GROUPS = "layer_groups"        # Split by layer ranges (Phase 3-4)
+     TENSOR_PARALLEL = "tensor_parallel"  # Split within layers (future)
+
+
+ @dataclass
+ class ShardConfig:
+     """Configuration for a single shard."""
+     shard_id: int
+     total_shards: int
+
+     # Layer range this shard covers
+     start_layer: int
+     end_layer: int  # Exclusive
+
+     # Special flags
+     has_embedding: bool = False
+     has_lm_head: bool = False
+
+     # Size estimates
+     estimated_size_mb: float = 0.0
+
+     def __post_init__(self):
+         """Validate shard config."""
+         assert 0 <= self.shard_id < self.total_shards
+         assert self.start_layer < self.end_layer
+
+
+ @dataclass
+ class ShardState:
+     """State of a shard on this node."""
+     config: ShardConfig
+
+     # Version tracking
+     version: int = 0
+     model_hash: str = ""
+
+     # Weights (loaded or None)
+     weights: Optional[Dict[str, torch.Tensor]] = None
+
+     # Status
+     is_loaded: bool = False
+     last_updated: float = field(default_factory=time.time)
+
+     # Replication
+     replicas: Set[str] = field(default_factory=set)  # Node URLs holding copies
+
+
+ @dataclass
+ class ShardAnnouncement:
+     """Announcement of shard availability."""
+     node_id: str
+     node_url: str
+     grpc_addr: str
+
+     shard_id: int
+     total_shards: int
+     version: int
+     model_hash: str
+
+     # Capacity info
+     available_memory_mb: float
+     current_load: float  # 0-1
+
+     timestamp: float = field(default_factory=time.time)
+
+     def to_dict(self) -> Dict:
+         return {
+             "node_id": self.node_id,
+             "node_url": self.node_url,
+             "grpc_addr": self.grpc_addr,
+             "shard_id": self.shard_id,
+             "total_shards": self.total_shards,
+             "version": self.version,
+             "model_hash": self.model_hash,
+             "available_memory_mb": self.available_memory_mb,
+             "current_load": self.current_load,
+             "timestamp": self.timestamp,
+         }
+
+     @classmethod
+     def from_dict(cls, data: Dict) -> 'ShardAnnouncement':
+         return cls(**data)
+
+
+ class ShardRegistry:
+     """
+     Registry for tracking which nodes hold which shards.
+
+     Uses DHT for decentralized discovery:
+     - Key: "shard_{model_id}_{shard_id}"
+     - Value: JSON list of ShardAnnouncements
+     """
+
+     def __init__(self, dht_protocol=None, model_id: str = "neuro_llm"):
+         self.dht = dht_protocol
+         self.model_id = model_id
+
+         # Local cache
+         self.shard_holders: Dict[int, List[ShardAnnouncement]] = {}
+         self.lock = threading.Lock()
+
+         # Refresh interval
+         self.cache_ttl = 60  # seconds
+         self.last_refresh: Dict[int, float] = {}
+
+     def announce_shard(self, announcement: ShardAnnouncement) -> bool:
+         """Announce that we hold a shard."""
+         if not self.dht:
+             logger.warning("No DHT available for shard announcement")
+             return False
+
+         try:
+             key = f"shard_{self.model_id}_{announcement.shard_id}"
+
+             # Get existing announcements
+             existing = self._get_shard_holders_from_dht(announcement.shard_id)
+
+             # Add/update our announcement
+             updated = False
+             for i, ann in enumerate(existing):
+                 if ann.node_id == announcement.node_id:
+                     existing[i] = announcement
+                     updated = True
+                     break
+
+             if not updated:
+                 existing.append(announcement)
+
+             # Store back to DHT
+             value = json.dumps([a.to_dict() for a in existing])
+             self.dht.store(key, value)
+
+             # Update local cache
+             with self.lock:
+                 self.shard_holders[announcement.shard_id] = existing
+
+             logger.info(f"Announced shard {announcement.shard_id} "
+                         f"(version={announcement.version}, holders={len(existing)})")
+             return True
+
+         except Exception as e:
+             logger.error(f"Failed to announce shard: {e}")
+             return False
+
+     def get_shard_holders(self, shard_id: int, refresh: bool = False) -> List[ShardAnnouncement]:
+         """Get nodes holding a specific shard."""
+         with self.lock:
+             # Check cache
+             if not refresh and shard_id in self.shard_holders:
+                 if time.time() - self.last_refresh.get(shard_id, 0) < self.cache_ttl:
+                     return self.shard_holders[shard_id]
+
+         # Refresh from DHT
+         holders = self._get_shard_holders_from_dht(shard_id)
+
+         with self.lock:
+             self.shard_holders[shard_id] = holders
+             self.last_refresh[shard_id] = time.time()
+
+         return holders
+
+     def _get_shard_holders_from_dht(self, shard_id: int) -> List[ShardAnnouncement]:
+         """Query DHT for shard holders."""
+         if not self.dht:
+             return []
+
+         try:
+             key = f"shard_{self.model_id}_{shard_id}"
+             value = self.dht.lookup_value(key)
+
+             if not value:
+                 return []
+
+             data = json.loads(value)
+             return [ShardAnnouncement.from_dict(d) for d in data]
+
+         except Exception as e:
+             logger.debug(f"DHT lookup failed for shard {shard_id}: {e}")
+             return []
+
+     def find_best_holder(self, shard_id: int) -> Optional[ShardAnnouncement]:
+         """Find the best node to request a shard from."""
+         holders = self.get_shard_holders(shard_id)
+
+         if not holders:
+             return None
+
+         # Sort by: highest version, lowest load, most recent
+         holders.sort(key=lambda h: (-h.version, h.current_load, -h.timestamp))
+
+         return holders[0]
+
+     def get_all_shards_status(self, total_shards: int) -> Dict[int, Dict]:
+         """Get status of all shards."""
+         status = {}
+
+         for shard_id in range(total_shards):
+             holders = self.get_shard_holders(shard_id)
+
+             status[shard_id] = {
+                 "shard_id": shard_id,
+                 "holder_count": len(holders),
+                 "is_available": len(holders) > 0,
+                 "is_redundant": len(holders) >= 3,
+                 "max_version": max((h.version for h in holders), default=0),
+                 "holders": [h.node_url for h in holders],
+             }
+
+         return status
+
+
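# --- Editor's illustrative sketch (not part of the package source) -----------
# How a node might use ShardRegistry, assuming a `dht` object that exposes the
# store()/lookup_value() methods called above; the node names, URLs, and
# numbers below are hypothetical.
#
#     registry = ShardRegistry(dht_protocol=dht, model_id="neuro_llm")
#     registry.announce_shard(ShardAnnouncement(
#         node_id="node-a", node_url="http://node-a:8000", grpc_addr="node-a:50051",
#         shard_id=0, total_shards=4, version=7, model_hash="ab12cd34ef56ab78",
#         available_memory_mb=24_000.0, current_load=0.2,
#     ))
#     best = registry.find_best_holder(0)   # highest version first, then lowest load
# ------------------------------------------------------------------------------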
+ class CheckpointShardManager:
+     """
+     Manages sharded checkpoints for a NeuroLLM model.
+
+     Responsibilities:
+     1. Determine sharding strategy based on model size and node capacity
+     2. Load/save individual shards
+     3. Coordinate with ShardRegistry for discovery
+     4. Handle shard replication
+     """
+
+     # Minimum replicas per shard
+     MIN_REPLICAS = 3
+
+     def __init__(
+         self,
+         model,  # NeuroLLMForCausalLM
+         node_id: str,
+         node_url: str,
+         registry: Optional[ShardRegistry] = None,
+         checkpoint_dir: Optional[Path] = None
+     ):
+         self.model = model
+         self.node_id = node_id
+         self.node_url = node_url
+         self.registry = registry
+         self.checkpoint_dir = checkpoint_dir or Path.home() / ".neuroshard" / "shards"
+         self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
+
+         # Current shard state
+         self.my_shards: Dict[int, ShardState] = {}
+         self.shard_configs: List[ShardConfig] = []
+
+         # Threading
+         self.lock = threading.Lock()
+
+     def compute_sharding_strategy(
+         self,
+         total_params: int,
+         available_memory_mb: float,
+         num_layers: int
+     ) -> Tuple[ShardingStrategy, List[ShardConfig]]:
+         """
+         Determine optimal sharding strategy.
+
+         Args:
+             total_params: Total model parameters
+             available_memory_mb: Available memory on this node
+             num_layers: Number of transformer layers
+
+         Returns:
+             (strategy, list of shard configs)
+         """
+         # Estimate model size (4 bytes per param for float32)
+         model_size_mb = (total_params * 4) / (1024 * 1024)
+
+         # Add overhead for optimizer state, activations (2x)
+         required_memory_mb = model_size_mb * 2
+
+         logger.info(f"Model size: {model_size_mb:.1f}MB, "
+                     f"Required: {required_memory_mb:.1f}MB, "
+                     f"Available: {available_memory_mb:.1f}MB")
+
+         # Strategy decision
+         if required_memory_mb <= available_memory_mb:
+             # Can hold full model
+             return ShardingStrategy.NONE, [
+                 ShardConfig(
+                     shard_id=0,
+                     total_shards=1,
+                     start_layer=0,
+                     end_layer=num_layers,
+                     has_embedding=True,
+                     has_lm_head=True,
+                     estimated_size_mb=model_size_mb
+                 )
+             ]
+
+         # Need to shard - calculate number of shards needed
+         shard_memory_target = available_memory_mb * 0.8  # Leave 20% headroom
+         num_shards = max(2, int(required_memory_mb / shard_memory_target) + 1)
+
+         # Round up to power of 2 for cleaner division
+         num_shards = 2 ** (num_shards - 1).bit_length()
+         num_shards = min(num_shards, num_layers)  # Can't have more shards than layers
+
+         # Create shard configs
+         layers_per_shard = num_layers // num_shards
+         shard_size_mb = model_size_mb / num_shards
+
+         configs = []
+         for i in range(num_shards):
+             start = i * layers_per_shard
+             end = (i + 1) * layers_per_shard if i < num_shards - 1 else num_layers
+
+             configs.append(ShardConfig(
+                 shard_id=i,
+                 total_shards=num_shards,
+                 start_layer=start,
+                 end_layer=end,
+                 has_embedding=(i == 0),
+                 has_lm_head=(i == num_shards - 1),
+                 estimated_size_mb=shard_size_mb
+             ))
+
+         logger.info(f"Sharding strategy: {num_shards} shards, "
+                     f"{layers_per_shard} layers each, "
+                     f"~{shard_size_mb:.1f}MB per shard")
+
+         return ShardingStrategy.LAYER_GROUPS, configs
+
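    # --- Editor's worked example (assumed figures, not from the package) ---
    # For a 7B-parameter model with 32 layers on a node with 24,000 MB free:
    #     model_size_mb       = 7e9 * 4 / (1024 * 1024)           ~ 26,703 MB
    #     required_memory_mb  = 2 * 26,703                        ~ 53,406 MB  (> 24,000, so shard)
    #     shard_memory_target = 24,000 * 0.8                      = 19,200 MB
    #     num_shards          = max(2, int(53,406 / 19,200) + 1)  = 3 -> rounded up to 4 (power of 2)
    #     layers_per_shard    = 32 // 4 = 8, ~6,676 MB of weights per shard
    # ------------------------------------------------------------------------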
+     def extract_shard_weights(self, shard_config: ShardConfig) -> Dict[str, torch.Tensor]:
+         """Extract weights for a specific shard from the full model."""
+         state_dict = self.model.state_dict()
+         shard_weights = {}
+
+         for name, param in state_dict.items():
+             # Check if this parameter belongs to this shard
+             if self._param_belongs_to_shard(name, shard_config):
+                 shard_weights[name] = param.clone()
+
+         return shard_weights
+
+     def _param_belongs_to_shard(self, param_name: str, config: ShardConfig) -> bool:
+         """Check if a parameter belongs to a shard."""
+         # Embedding layer
+         if "embed" in param_name.lower() or "wte" in param_name.lower():
+             return config.has_embedding
+
+         # LM head
+         if "lm_head" in param_name.lower() or "output" in param_name.lower():
+             return config.has_lm_head
+
+         # Transformer layers - extract layer number
+         import re
+         match = re.search(r'layers?[._](\d+)', param_name.lower())
+         if match:
+             layer_num = int(match.group(1))
+             return config.start_layer <= layer_num < config.end_layer
+
+         # Final norm (belongs to last shard)
+         if "final" in param_name.lower() or "ln_f" in param_name.lower():
+             return config.has_lm_head
+
+         # Default: include in first shard
+         return config.shard_id == 0
+
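    # --- Editor's illustration (hypothetical parameter names) ---
    # With 4 shards over 32 layers (shard 0 = layers 0-7, ..., shard 3 = layers 24-31):
    #     "model.embed_tokens.weight"               -> contains "embed"    -> shard with has_embedding (shard 0)
    #     "model.layers.5.self_attn.q_proj.weight"  -> regex layer 5       -> shard 0 (0 <= 5 < 8)
    #     "model.layers.30.mlp.up_proj.weight"      -> regex layer 30      -> shard 3 (24 <= 30 < 32)
    #     "lm_head.weight"                          -> contains "lm_head"  -> shard with has_lm_head (shard 3)
    # -------------------------------------------------------------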
+     def save_shard(self, shard_id: int, version: int) -> Optional[Path]:
+         """Save a shard to disk."""
+         if shard_id not in self.my_shards:
+             logger.error(f"Shard {shard_id} not loaded")
+             return None
+
+         shard_state = self.my_shards[shard_id]
+
+         if not shard_state.weights:
+             # Extract from model
+             shard_state.weights = self.extract_shard_weights(shard_state.config)
+
+         # Compute hash
+         model_hash = self._compute_weights_hash(shard_state.weights)
+
+         # Save
+         filename = f"shard_{shard_id}_v{version}.pt"
+         path = self.checkpoint_dir / filename
+
+         torch.save({
+             "shard_id": shard_id,
+             "config": {
+                 "shard_id": shard_state.config.shard_id,
+                 "total_shards": shard_state.config.total_shards,
+                 "start_layer": shard_state.config.start_layer,
+                 "end_layer": shard_state.config.end_layer,
+                 "has_embedding": shard_state.config.has_embedding,
+                 "has_lm_head": shard_state.config.has_lm_head,
+             },
+             "weights": shard_state.weights,
+             "version": version,
+             "model_hash": model_hash,
+             "timestamp": time.time(),
+         }, path)
+
+         # Update state
+         shard_state.version = version
+         shard_state.model_hash = model_hash
+         shard_state.last_updated = time.time()
+
+         logger.info(f"Saved shard {shard_id} v{version} to {path}")
+
+         return path
+
+     def load_shard(self, shard_id: int, path: Optional[Path] = None) -> bool:
+         """Load a shard from disk."""
+         if path is None:
+             # Find latest version
+             pattern = f"shard_{shard_id}_v*.pt"
+             files = list(self.checkpoint_dir.glob(pattern))
+             if not files:
+                 logger.warning(f"No shard {shard_id} found on disk")
+                 return False
+             path = max(files, key=lambda p: p.stat().st_mtime)
+
+         try:
+             data = torch.load(path, map_location="cpu", weights_only=False)
+
+             config = ShardConfig(**data["config"])
+
+             self.my_shards[shard_id] = ShardState(
+                 config=config,
+                 version=data["version"],
+                 model_hash=data["model_hash"],
+                 weights=data["weights"],
+                 is_loaded=True,
+                 last_updated=data.get("timestamp", time.time()),
+             )
+
+             logger.info(f"Loaded shard {shard_id} v{data['version']} from {path}")
+             return True
+
+         except Exception as e:
+             logger.error(f"Failed to load shard {shard_id}: {e}")
+             return False
+
+     def load_shard_into_model(self, shard_id: int) -> bool:
+         """Load shard weights into the model."""
+         if shard_id not in self.my_shards:
+             logger.error(f"Shard {shard_id} not loaded")
+             return False
+
+         shard_state = self.my_shards[shard_id]
+
+         if not shard_state.weights:
+             logger.error(f"Shard {shard_id} has no weights")
+             return False
+
+         # Load into model
+         current_state = self.model.state_dict()
+
+         for name, param in shard_state.weights.items():
+             if name in current_state:
+                 current_state[name].copy_(param)
+             else:
+                 logger.warning(f"Parameter {name} not found in model")
+
+         logger.info(f"Loaded shard {shard_id} weights into model")
+         return True
+
+     def _compute_weights_hash(self, weights: Dict[str, torch.Tensor]) -> str:
+         """Compute hash of weights."""
+         hasher = hashlib.sha256()
+
+         for name in sorted(weights.keys()):
+             hasher.update(name.encode())
+             # Note: only the first 1000 bytes of each tensor are hashed, so this
+             # is a cheap fingerprint rather than a full integrity check.
+             hasher.update(weights[name].cpu().numpy().tobytes()[:1000])
+
+         return hasher.hexdigest()[:16]
+
+     def get_shard_status(self) -> Dict[str, Any]:
+         """Get status of all shards on this node."""
+         return {
+             "node_id": self.node_id,
+             "shards": {
+                 shard_id: {
+                     "shard_id": shard_id,
+                     "version": state.version,
+                     "model_hash": state.model_hash,
+                     "is_loaded": state.is_loaded,
+                     "layers": f"{state.config.start_layer}-{state.config.end_layer}",
+                     "has_embedding": state.config.has_embedding,
+                     "has_lm_head": state.config.has_lm_head,
+                     "replicas": len(state.replicas),
+                 }
+                 for shard_id, state in self.my_shards.items()
+             },
+             "total_shards": len(self.shard_configs),
+         }
+
+
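# --- Editor's illustrative workflow (assumed values, not part of the package) ---
# A plausible save/load cycle on a single node, using only names defined above:
#
#     manager = CheckpointShardManager(model, node_id="node-a",
#                                      node_url="http://node-a:8000", registry=registry)
#     strategy, configs = manager.compute_sharding_strategy(
#         total_params=7_000_000_000, available_memory_mb=24_000.0, num_layers=32)
#     manager.shard_configs = configs
#     manager.my_shards[0] = ShardState(config=configs[0])  # track the shard we hold
#     manager.save_shard(0, version=1)      # extracts weights, hashes, writes shard_0_v1.pt
#     manager.load_shard(0)                 # later: reload the newest shard_0_v*.pt from disk
#     manager.load_shard_into_model(0)      # copy the shard's weights back into the model
# ---------------------------------------------------------------------------------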
+ def compute_shard_assignment(
+     node_capacities: Dict[str, float],  # node_id -> available_memory_mb
+     model_size_mb: float,
+     num_layers: int,
+     min_replicas: int = 3
+ ) -> Dict[str, List[int]]:
+     """
+     Compute optimal shard assignment across nodes.
+
+     This is a simplified greedy algorithm. A production system would use
+     more sophisticated optimization (e.g., constraint satisfaction, ILP).
+
+     Args:
+         node_capacities: Available memory per node
+         model_size_mb: Total model size
+         num_layers: Number of transformer layers
+         min_replicas: Minimum replicas per shard
+
+     Returns:
+         Dict mapping node_id -> list of shard_ids to hold
+     """
+     # Sort nodes by capacity (largest first)
+     sorted_nodes = sorted(node_capacities.items(), key=lambda x: -x[1])
+
+     if not sorted_nodes:
+         return {}
+
+     # Determine number of shards based on largest node
+     max_capacity = sorted_nodes[0][1]
+     shard_size = max_capacity * 0.8  # Leave headroom
+     num_shards = max(1, int(model_size_mb / shard_size) + 1)
+     num_shards = min(num_shards, num_layers)
+
+     actual_shard_size = model_size_mb / num_shards
+
+     # Assign shards to nodes (round-robin with capacity check)
+     assignments: Dict[str, List[int]] = {node_id: [] for node_id, _ in sorted_nodes}
+     shard_counts = {shard_id: 0 for shard_id in range(num_shards)}
+
+     # First pass: give each shard one holder (ensure coverage)
+     for shard_id in range(num_shards):
+         for node_id, capacity in sorted_nodes:
+             current_load = len(assignments[node_id]) * actual_shard_size
+             if current_load + actual_shard_size <= capacity:
+                 assignments[node_id].append(shard_id)
+                 shard_counts[shard_id] += 1
+                 break
+
+     # Second pass: add replicas until min_replicas reached
+     for shard_id in range(num_shards):
+         while shard_counts[shard_id] < min_replicas:
+             # Find node with capacity that doesn't already have this shard
+             assigned = False
+             for node_id, capacity in sorted_nodes:
+                 if shard_id in assignments[node_id]:
+                     continue
+                 current_load = len(assignments[node_id]) * actual_shard_size
+                 if current_load + actual_shard_size <= capacity:
+                     assignments[node_id].append(shard_id)
+                     shard_counts[shard_id] += 1
+                     assigned = True
+                     break
+
+             if not assigned:
+                 logger.warning(f"Could not assign {min_replicas} replicas for shard {shard_id}")
+                 break
+
+     return assignments
+
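As a rough usage sketch of compute_shard_assignment (editor's example with made-up node names and capacities, assuming the module above is in scope), four nodes with ~64 GB free sharing a ~100 GB model at min_replicas=2 come out as two ~50,000 MB shards, each held by two nodes:

    capacities = {"node-a": 64_000.0, "node-b": 64_000.0, "node-c": 64_000.0, "node-d": 64_000.0}
    plan = compute_shard_assignment(capacities, model_size_mb=100_000.0,
                                    num_layers=40, min_replicas=2)
    # Greedy result for these inputs:
    #   {"node-a": [0], "node-b": [1], "node-c": [0], "node-d": [1]}
    print(plan)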