gpu-worker 1.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,465 @@
+ """
+ Model shard loader.
+ 
+ Supports loading only a range of a model's Transformer blocks,
+ for distributed inference scenarios.
+ 
+ Based on the from_pretrained implementation in Petals.
+ """
+ import logging
+ from typing import Dict, Any, List, Optional, Tuple
+ from dataclasses import dataclass
+ 
+ import torch
+ import torch.nn as nn
+ 
+ logger = logging.getLogger(__name__)
+ 
+ 
+ @dataclass
+ class LayerInfo:
+     """Metadata about a single layer."""
+     layer_idx: int
+     layer_name: str
+     param_count: int
+     memory_bytes: int
+ 
+ 
+ class ModelShard(nn.Module):
+     """
+     A model shard.
+ 
+     Loads only the specified layer range of a model, for distributed inference.
+     """
+ 
+     def __init__(
+         self,
+         model_id: str,
+         start_layer: int,
+         end_layer: int,
+         device: str = "cuda",
+         dtype: torch.dtype = torch.float16,
+     ):
+         super().__init__()
+         self.model_id = model_id
+         self.start_layer = start_layer
+         self.end_layer = end_layer
+         self.device = device
+         self.dtype = dtype
+ 
+         # Model components (populated after loading)
+         self.layers: nn.ModuleList = nn.ModuleList()
+         self.config = None
+         self.is_first_shard = False  # whether this shard holds the embeddings
+         self.is_last_shard = False   # whether this shard holds the lm_head
+ 
+         # Metadata
+         self.total_layers = 0
+         self.hidden_size = 0
+         self.num_heads = 0
+ 
+     @classmethod
+     def from_pretrained(
+         cls,
+         model_id: str,
+         start_layer: int,
+         end_layer: int,
+         device: str = "cuda",
+         dtype: torch.dtype = torch.float16,
+         **kwargs
+     ) -> "ModelShard":
+         """
+         Load a model shard.
+ 
+         Args:
+             model_id: HuggingFace model ID
+             start_layer: first layer (inclusive)
+             end_layer: last layer (exclusive)
+             device: target device
+             dtype: data type
+ 
+         Returns:
+             A ModelShard instance
+         """
+         from transformers import AutoConfig, AutoModelForCausalLM
+ 
+         logger.info(f"Loading model shard: {model_id} layers [{start_layer}, {end_layer})")
+ 
+         # Load the config
+         config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
+ 
+         # Determine the total number of layers
+         total_layers = getattr(config, "num_hidden_layers", None)
+         if total_layers is None:
+             total_layers = getattr(config, "n_layer", 32)  # fallback default
+ 
+         if end_layer > total_layers:
+             raise ValueError(f"end_layer ({end_layer}) > total_layers ({total_layers})")
+ 
+         # Create the shard instance
+         shard = cls(model_id, start_layer, end_layer, device, dtype)
+         shard.config = config
+         shard.total_layers = total_layers
+         shard.hidden_size = getattr(config, "hidden_size", 4096)
+         shard.num_heads = getattr(config, "num_attention_heads", 32)
+         shard.is_first_shard = (start_layer == 0)
+         shard.is_last_shard = (end_layer == total_layers)
+ 
+         # Load the full model and then extract the required layers.
+         # Note: for very large models a more memory-efficient approach should be used.
+         logger.info("Loading full model for layer extraction...")
+ 
+         with torch.device("meta"):
+             # First build the model structure on the meta device
+             full_model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
+ 
+         # Probe the layer container to learn the layer naming pattern
+         layer_module = _get_layer_module(full_model, config)
+ 
+         # Use from_pretrained's device_map feature to load only the required layers
+         device_map = _create_device_map_for_layers(
+             config,
+             start_layer,
+             end_layer,
+             device,
+             include_embeddings=shard.is_first_shard,
+             include_lm_head=shard.is_last_shard,
+         )
+ 
+         logger.info(f"Loading layers with device_map: {list(device_map.keys())[:5]}...")
+ 
+         full_model = AutoModelForCausalLM.from_pretrained(
+             model_id,
+             torch_dtype=dtype,
+             device_map=device_map,
+             trust_remote_code=True,
+             **kwargs
+         )
+ 
+         # Extract the required layers
+         shard._extract_layers(full_model, config)
+ 
+         # Release the components that are no longer needed
+         del full_model
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+ 
+         logger.info(f"Model shard loaded: {shard.get_memory_usage():.2f} GB")
+ 
+         return shard
+ 
+     def _extract_layers(self, full_model: nn.Module, config) -> None:
+         """Extract the required layers from the full model."""
+         # Locate the layer container
+         layers = _get_layer_module(full_model, config)
+ 
+         if layers is None:
+             raise RuntimeError("Cannot find layers in model architecture")
+ 
+         # Extract the requested layer range
+         for i in range(self.start_layer, self.end_layer):
+             layer = layers[i]
+             self.layers.append(layer)
+ 
+         # First shard: keep the embeddings
+         if self.is_first_shard:
+             self.embed_tokens = _get_embedding_module(full_model, config)
+             base_model = getattr(full_model, "model", full_model)
+             self.embed_positions = getattr(base_model, "embed_positions", None)
+ 
+         # Last shard: keep the lm_head and the final norm
+         if self.is_last_shard:
+             self.lm_head = full_model.lm_head if hasattr(full_model, "lm_head") else None
+             self.norm = _get_norm_module(full_model, config)
+ 
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.Tensor] = None,
+         past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
+         use_cache: bool = True,
+     ) -> Tuple[torch.Tensor, Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]:
+         """
+         Forward pass.
+ 
+         Args:
+             hidden_states: input hidden states [batch, seq_len, hidden_size]
+             attention_mask: attention mask
+             position_ids: position IDs
+             past_key_values: past KV-cache
+             use_cache: whether to return a new KV-cache
+ 
+         Returns:
+             (hidden_states, new_past_key_values)
+         """
+         # The first shard has to apply the embeddings itself
+         if self.is_first_shard and hasattr(self, 'embed_tokens'):
+             # At this point hidden_states actually holds input_ids
+             hidden_states = self.embed_tokens(hidden_states)
+             if hasattr(self, 'embed_positions') and self.embed_positions is not None:
+                 hidden_states = hidden_states + self.embed_positions(position_ids)
+ 
+         # Run through each layer in order
+         new_past_key_values = [] if use_cache else None
+         past_idx = 0
+ 
+         for layer in self.layers:
+             layer_past = past_key_values[past_idx] if past_key_values else None
+ 
+             # Call the layer's forward
+             layer_outputs = layer(
+                 hidden_states,
+                 attention_mask=attention_mask,
+                 position_ids=position_ids,
+                 past_key_value=layer_past,
+                 use_cache=use_cache,
+             )
+ 
+             hidden_states = layer_outputs[0]
+ 
+             if use_cache:
+                 new_past_key_values.append(layer_outputs[1])
+ 
+             past_idx += 1
+ 
+         # The last shard applies the final norm
+         if self.is_last_shard and hasattr(self, 'norm') and self.norm is not None:
+             hidden_states = self.norm(hidden_states)
+ 
+         return hidden_states, new_past_key_values
+ 
+     def get_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         """
+         Compute logits (last shard only).
+ 
+         Args:
+             hidden_states: hidden states
+ 
+         Returns:
+             logits
+         """
+         if not self.is_last_shard:
+             raise RuntimeError("get_logits() can only be called on the last shard")
+ 
+         if hasattr(self, 'lm_head') and self.lm_head is not None:
+             return self.lm_head(hidden_states)
+ 
+         raise RuntimeError("No lm_head available")
+ 
+     def get_memory_usage(self) -> float:
+         """Return the parameter memory footprint in GB."""
+         total_bytes = sum(
+             p.numel() * p.element_size()
+             for p in self.parameters()
+         )
+         return total_bytes / (1024 ** 3)
+ 
+     def get_layer_count(self) -> int:
+         """Return the number of layers held by this shard."""
+         return len(self.layers)
+ 
+ 
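A minimal sketch of how hidden states would flow between two such shards in a single process (illustrative only, not part of the package; `shard_a`, `shard_b`, and `input_ids` are hypothetical, and in a real deployment the hand-off happens over the network):

# Illustrative sketch: chaining two shards in one process.
# shard_a holds layers [0, 16) plus the embeddings; shard_b holds [16, 32) plus lm_head.
with torch.no_grad():
    hidden, kv_a = shard_a(input_ids, use_cache=True)    # the first shard consumes input_ids
    hidden, kv_b = shard_b(hidden, use_cache=True)        # later shards consume hidden states
    logits = shard_b.get_logits(hidden)                   # only the last shard has an lm_head
    next_token = logits[:, -1, :].argmax(dim=-1)          # greedy pick of the next token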
+ class ShardedModelLoader:
+     """
+     Sharded model loader.
+ 
+     Manages how model shards are loaded and distributed.
+     """
+ 
+     def __init__(self, model_id: str):
+         self.model_id = model_id
+         self.config = None
+         self.total_layers = 0
+ 
+     def analyze_model(self) -> Dict[str, Any]:
+         """
+         Analyze the model structure.
+ 
+         Returns:
+             A dict describing the model
+         """
+         from transformers import AutoConfig
+ 
+         self.config = AutoConfig.from_pretrained(
+             self.model_id,
+             trust_remote_code=True
+         )
+ 
+         self.total_layers = getattr(self.config, "num_hidden_layers", 32)
+         hidden_size = getattr(self.config, "hidden_size", 4096)
+         num_heads = getattr(self.config, "num_attention_heads", 32)
+         num_kv_heads = getattr(self.config, "num_key_value_heads", num_heads)
+         intermediate_size = getattr(self.config, "intermediate_size", hidden_size * 4)
+ 
+         # Rough per-layer memory estimate
+         head_dim = hidden_size // num_heads
+         attention_params = hidden_size * (num_heads + 2 * num_kv_heads) * head_dim
+         mlp_params = hidden_size * intermediate_size * 3  # gate, up, down projections
+         layer_params = attention_params + mlp_params
+         bytes_per_param = 2  # float16
+ 
+         memory_per_layer_gb = layer_params * bytes_per_param / (1024 ** 3)
+ 
+         return {
+             "model_id": self.model_id,
+             "total_layers": self.total_layers,
+             "hidden_size": hidden_size,
+             "num_attention_heads": num_heads,
+             "num_key_value_heads": num_kv_heads,
+             "intermediate_size": intermediate_size,
+             "memory_per_layer_gb": memory_per_layer_gb,
+             "total_memory_gb": memory_per_layer_gb * self.total_layers,
+         }
+ 
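To make the estimate concrete, here is the same arithmetic for a hypothetical 7B-class config (hidden_size=4096, 32 heads, 32 KV heads, intermediate_size=11008, 32 layers); the numbers are illustrative and, like the code above, cover only the q/k/v and MLP projections:

# Worked example of the per-layer estimate (assumed config values, not queried from a checkpoint).
hidden_size, num_heads, num_kv_heads, intermediate_size, total_layers = 4096, 32, 32, 11008, 32
head_dim = hidden_size // num_heads                                          # 128
attention_params = hidden_size * (num_heads + 2 * num_kv_heads) * head_dim  # 50_331_648
mlp_params = hidden_size * intermediate_size * 3                             # 135_266_304
memory_per_layer_gb = (attention_params + mlp_params) * 2 / (1024 ** 3)     # ~0.35 GB in float16
total_memory_gb = memory_per_layer_gb * total_layers                         # ~11.1 GB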
+     def create_shard_plan(
+         self,
+         worker_memory_gb: List[float],
+         reserve_ratio: float = 0.2,
+     ) -> List[Tuple[int, int]]:
+         """
+         Create a sharding plan.
+ 
+         Args:
+             worker_memory_gb: available GPU memory per worker
+             reserve_ratio: fraction of memory to reserve (for the KV-cache)
+ 
+         Returns:
+             [(start_layer, end_layer), ...] layer range for each worker
+         """
+         if not self.config:
+             self.analyze_model()
+ 
+         model_info = self.analyze_model()
+         memory_per_layer = model_info["memory_per_layer_gb"]
+ 
+         # How many layers each worker can hold
+         available_memory = [
+             mem * (1 - reserve_ratio) for mem in worker_memory_gb
+         ]
+         layers_per_worker = [
+             int(mem / memory_per_layer) for mem in available_memory
+         ]
+ 
+         # Make sure the combined capacity covers the whole model
+         total_capacity = sum(layers_per_worker)
+         if total_capacity < self.total_layers:
+             raise ValueError(
+                 f"Insufficient memory: can only fit {total_capacity} layers, "
+                 f"but model has {self.total_layers} layers"
+             )
+ 
+         # Assign layers
+         shard_plan = []
+         current_layer = 0
+ 
+         for i, capacity in enumerate(layers_per_worker):
+             # Allocate proportionally
+             if i == len(layers_per_worker) - 1:
+                 # The last worker takes all remaining layers
+                 end_layer = self.total_layers
+             else:
+                 # Proportional to this worker's capacity
+                 ratio = capacity / total_capacity
+                 num_layers = max(1, int(self.total_layers * ratio))
+                 end_layer = min(current_layer + num_layers, self.total_layers)
+ 
+             if current_layer < self.total_layers:
+                 shard_plan.append((current_layer, end_layer))
+                 current_layer = end_layer
+ 
+         return shard_plan
+ 
+ 
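For example (illustrative numbers only, assuming the ~0.35 GB-per-layer estimate from the worked example above and a 32-layer model), two workers with uneven memory would get proportional ranges, with the last worker absorbing the remainder:

# Hypothetical usage of create_shard_plan; model ID and memory figures are assumptions.
loader = ShardedModelLoader("meta-llama/Llama-2-7b-hf")
plan = loader.create_shard_plan(worker_memory_gb=[10.0, 24.0], reserve_ratio=0.2)
# At ~0.35 GB/layer the workers can hold roughly 22 and 54 layers after the reserve,
# so a 32-layer model is split proportionally as [(0, 9), (9, 32)]:
# the last worker always takes whatever layers remain.
print(plan)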
+ def get_layer_range_for_worker(
+     total_layers: int,
+     num_workers: int,
+     worker_idx: int,
+ ) -> Tuple[int, int]:
+     """
+     Compute a worker's layer range (even distribution).
+ 
+     Args:
+         total_layers: total number of layers
+         num_workers: number of workers
+         worker_idx: index of the current worker
+ 
+     Returns:
+         (start_layer, end_layer)
+     """
+     layers_per_worker = total_layers // num_workers
+     remainder = total_layers % num_workers
+ 
+     start = worker_idx * layers_per_worker + min(worker_idx, remainder)
+     end = start + layers_per_worker + (1 if worker_idx < remainder else 0)
+ 
+     return start, end
+ 
+ 
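As a quick check of the even split (illustrative), 32 layers across 3 workers come out as follows; the remainder goes to the lowest-indexed workers:

# 32 layers, 3 workers: the first two workers get the extra layers from the remainder.
assert get_layer_range_for_worker(32, 3, 0) == (0, 11)
assert get_layer_range_for_worker(32, 3, 1) == (11, 22)
assert get_layer_range_for_worker(32, 3, 2) == (22, 32)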
+ # Helper functions
+ 
+ def _get_layer_module(model: nn.Module, config) -> Optional[nn.ModuleList]:
+     """Return the model's layer container."""
+     # Try the common model architectures
+     if hasattr(model, "model"):
+         if hasattr(model.model, "layers"):
+             return model.model.layers
+         if hasattr(model.model, "decoder") and hasattr(model.model.decoder, "layers"):
+             return model.model.decoder.layers
+ 
+     if hasattr(model, "transformer"):
+         if hasattr(model.transformer, "h"):
+             return model.transformer.h
+         if hasattr(model.transformer, "layers"):
+             return model.transformer.layers
+ 
+     return None
+ 
+ 
+ def _get_embedding_module(model: nn.Module, config) -> Optional[nn.Module]:
+     """Return the token embedding module."""
+     if hasattr(model, "model") and hasattr(model.model, "embed_tokens"):
+         return model.model.embed_tokens
+     if hasattr(model, "transformer") and hasattr(model.transformer, "wte"):
+         return model.transformer.wte
+     return None
+ 
+ 
+ def _get_norm_module(model: nn.Module, config) -> Optional[nn.Module]:
+     """Return the final norm module."""
+     if hasattr(model, "model") and hasattr(model.model, "norm"):
+         return model.model.norm
+     if hasattr(model, "transformer") and hasattr(model.transformer, "ln_f"):
+         return model.transformer.ln_f
+     return None
+ 
+ 
+ def _create_device_map_for_layers(
+     config,
+     start_layer: int,
+     end_layer: int,
+     device: str,
+     include_embeddings: bool = False,
+     include_lm_head: bool = False,
+ ) -> Dict[str, str]:
+     """
+     Create a device_map for selectively loading layers.
+     """
+     device_map = {}
+ 
+     # Module prefix (may differ between architectures)
+     model_prefix = "model"
+ 
+     if include_embeddings:
+         device_map[f"{model_prefix}.embed_tokens"] = device
+ 
+     # Transformer layers
+     for i in range(start_layer, end_layer):
+         device_map[f"{model_prefix}.layers.{i}"] = device
+ 
+     if include_lm_head:
+         device_map[f"{model_prefix}.norm"] = device
+         device_map["lm_head"] = device
+ 
+     # The remaining modules should go to CPU or be dropped.
+     # Note: this is a simplification; a real implementation needs to handle them explicitly.
+ 
+     return device_map
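A minimal end-to-end sketch of how the pieces above would be combined on a single worker (illustrative only; the model ID, worker count, and worker index are assumed deployment values, and the inter-worker transport and generation loop are out of scope):

# Illustrative usage sketch, not part of the package.
MODEL_ID = "meta-llama/Llama-2-7b-hf"   # hypothetical model ID
NUM_WORKERS, WORKER_IDX = 2, 0          # hypothetical deployment values

loader = ShardedModelLoader(MODEL_ID)
info = loader.analyze_model()
start, end = get_layer_range_for_worker(info["total_layers"], NUM_WORKERS, WORKER_IDX)

shard = ModelShard.from_pretrained(MODEL_ID, start, end, device="cuda", dtype=torch.float16)
print(f"Worker {WORKER_IDX}: layers [{start}, {end}), {shard.get_memory_usage():.2f} GB")
# The first worker feeds input_ids into shard.forward(); later workers receive the hidden
# states over the network, and the last worker calls shard.get_logits() to produce tokens.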