gpu-worker 1.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +115 -0
- package/api_client.py +288 -0
- package/batch_processor.py +436 -0
- package/bin/gpu-worker.js +275 -0
- package/cli.py +729 -0
- package/config.2gb.yaml +32 -0
- package/config.8gb.yaml +29 -0
- package/config.example.yaml +72 -0
- package/config.py +213 -0
- package/direct_server.py +140 -0
- package/distributed/__init__.py +35 -0
- package/distributed/grpc_server.py +561 -0
- package/distributed/kv_cache.py +555 -0
- package/distributed/model_shard.py +465 -0
- package/distributed/session.py +455 -0
- package/engines/__init__.py +215 -0
- package/engines/base.py +57 -0
- package/engines/image_gen.py +83 -0
- package/engines/llm.py +97 -0
- package/engines/llm_base.py +216 -0
- package/engines/llm_sglang.py +489 -0
- package/engines/llm_vllm.py +539 -0
- package/engines/speculative.py +513 -0
- package/engines/vision.py +139 -0
- package/machine_id.py +200 -0
- package/main.py +521 -0
- package/package.json +64 -0
- package/requirements-sglang.txt +12 -0
- package/requirements-vllm.txt +15 -0
- package/requirements.txt +35 -0
- package/scripts/postinstall.js +60 -0
- package/setup.py +43 -0
|
@@ -0,0 +1,465 @@
|
|
|
1
|
+
"""
|
|
2
|
+
模型分片加载器
|
|
3
|
+
|
|
4
|
+
支持按 Transformer Block 范围加载模型的部分层,
|
|
5
|
+
用于分布式推理场景。
|
|
6
|
+
|
|
7
|
+
参考 Petals 的 from_pretrained 实现。
|
|
8
|
+
"""
|
|
9
|
+
import logging
|
|
10
|
+
from typing import Dict, Any, List, Optional, Tuple
|
|
11
|
+
from dataclasses import dataclass
|
|
12
|
+
|
|
13
|
+
import torch
|
|
14
|
+
import torch.nn as nn
|
|
15
|
+
|
|
16
|
+
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@dataclass
class LayerInfo:
    """Descriptive record for one model layer.

    The visible code in this module never instantiates this class; the field
    notes below are taken from the field names themselves.
    """

    # Index of the layer.
    layer_idx: int
    # Name of the layer.
    layer_name: str
    # Parameter count of the layer.
    param_count: int
    # Memory size of the layer, in bytes.
    memory_bytes: int
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class ModelShard(nn.Module):
    """
    A contiguous range of transformer blocks taken from a causal LM.

    Holds the blocks ``[start_layer, end_layer)`` of a model so that several
    shards can be chained for distributed inference. A shard starting at
    layer 0 additionally keeps the token embedding; a shard ending at the
    model's last layer additionally keeps the final norm and ``lm_head``.
    """

    def __init__(
        self,
        model_id: str,
        start_layer: int,
        end_layer: int,
        device: str = "cuda",
        dtype: torch.dtype = torch.float16,
    ):
        """
        Create an empty shard; no weights are loaded here.

        Args:
            model_id: Model identifier the shard belongs to.
            start_layer: First layer held by the shard (inclusive).
            end_layer: Layer at which the shard stops (exclusive).
            device: Target device string.
            dtype: Data type used when loading weights.
        """
        super().__init__()
        self.model_id = model_id
        self.start_layer = start_layer
        self.end_layer = end_layer
        self.device = device
        self.dtype = dtype

        # Model components (filled in after loading).
        self.layers: nn.ModuleList = nn.ModuleList()
        self.config = None
        self.is_first_shard = False  # whether the shard holds the embedding
        self.is_last_shard = False  # whether the shard holds lm_head

        # Metadata copied from the model config by from_pretrained().
        self.total_layers = 0
        self.hidden_size = 0
        self.num_heads = 0

    @classmethod
    def from_pretrained(
        cls,
        model_id: str,
        start_layer: int,
        end_layer: int,
        device: str = "cuda",
        dtype: torch.dtype = torch.float16,
        **kwargs
    ) -> "ModelShard":
        """
        Load a model shard.

        The full model is loaded through ``AutoModelForCausalLM`` with a
        device map built for the requested range, the needed modules are
        moved onto the shard, and the full model object is then dropped.

        Args:
            model_id: HuggingFace model ID.
            start_layer: First layer (inclusive).
            end_layer: Last layer (exclusive).
            device: Target device.
            dtype: Data type.
            **kwargs: Forwarded to ``AutoModelForCausalLM.from_pretrained``.

        Returns:
            The populated ModelShard instance.

        Raises:
            ValueError: If the range is negative or empty, or if
                ``end_layer`` exceeds the model's layer count.
            RuntimeError: If the layer list cannot be located in the model.
        """
        # Reject ranges that would yield an empty shard or wrap around via
        # negative indexing, before any expensive work is started.
        if start_layer < 0 or start_layer >= end_layer:
            raise ValueError(
                f"Invalid layer range [{start_layer}, {end_layer}): "
                f"expected 0 <= start_layer < end_layer"
            )

        from transformers import AutoConfig, AutoModelForCausalLM

        logger.info(
            "Loading model shard: %s layers [%s, %s)", model_id, start_layer, end_layer
        )

        # Load the config.
        config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)

        # Total layer count; fall back to `n_layer`, then to 32.
        total_layers = getattr(config, "num_hidden_layers", None)
        if total_layers is None:
            total_layers = getattr(config, "n_layer", 32)

        if end_layer > total_layers:
            raise ValueError(f"end_layer ({end_layer}) > total_layers ({total_layers})")

        # Create the shard instance and record model metadata on it.
        shard = cls(model_id, start_layer, end_layer, device, dtype)
        shard.config = config
        shard.total_layers = total_layers
        shard.hidden_size = getattr(config, "hidden_size", 4096)
        shard.num_heads = getattr(config, "num_attention_heads", 32)
        shard.is_first_shard = (start_layer == 0)
        shard.is_last_shard = (end_layer == total_layers)

        # Load the full model, then extract the layers that are needed.
        # Note: very large models call for a more efficient approach.
        logger.info("Loading full model for layer extraction...")

        device_map = _create_device_map_for_layers(
            config,
            start_layer,
            end_layer,
            device,
            include_embeddings=shard.is_first_shard,
            include_lm_head=shard.is_last_shard,
        )

        logger.info("Loading layers with device_map: %s...", list(device_map)[:5])

        full_model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=dtype,
            device_map=device_map,
            trust_remote_code=True,
            **kwargs
        )

        # Move the needed modules onto the shard.
        shard._extract_layers(full_model, config)

        # Drop the remaining components.
        del full_model
        torch.cuda.empty_cache()

        logger.info("Model shard loaded: %.2f GB", shard.get_memory_usage())

        return shard

    def _extract_layers(self, full_model: nn.Module, config) -> None:
        """
        Copy references to the needed modules from the full model.

        Args:
            full_model: The fully constructed causal LM.
            config: The model config, passed through to the lookup helpers.

        Raises:
            RuntimeError: If the model's layer list cannot be located.
        """
        layers = _get_layer_module(full_model, config)

        if layers is None:
            raise RuntimeError("Cannot find layers in model architecture")

        # Take the requested range of blocks.
        for i in range(self.start_layer, self.end_layer):
            self.layers.append(layers[i])

        # The first shard also keeps the embedding(s).
        if self.is_first_shard:
            self.embed_tokens = _get_embedding_module(full_model, config)
            # Not every supported layout has a `.model` attribute (e.g. the
            # `transformer.*` layouts), so look it up defensively.
            inner_model = getattr(full_model, "model", None)
            self.embed_positions = getattr(inner_model, "embed_positions", None)

        # The last shard also keeps lm_head and the final norm.
        if self.is_last_shard:
            self.lm_head = getattr(full_model, "lm_head", None)
            self.norm = _get_norm_module(full_model, config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        use_cache: bool = True,
    ) -> Tuple[torch.Tensor, Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]:
        """
        Run the shard's layers in order.

        Args:
            hidden_states: Input hidden states [batch, seq_len, hidden_size].
                On a first shard that holds an embedding, this is treated as
                token ids and embedded first.
            attention_mask: Attention mask, passed to every layer.
            position_ids: Position IDs, passed to every layer.
            past_key_values: Per-layer KV cache, indexed by the layer's
                position inside this shard.
            use_cache: Whether to return the new KV cache.

        Returns:
            (hidden_states, new_past_key_values); the second element is
            ``None`` when ``use_cache`` is false.
        """
        # First shard: the input is token ids and must be embedded. The
        # embedding attribute may be absent or None when no embedding module
        # was found during extraction; in that case the input is used as is.
        embed_tokens = getattr(self, "embed_tokens", None)
        if self.is_first_shard and embed_tokens is not None:
            hidden_states = embed_tokens(hidden_states)
            embed_positions = getattr(self, "embed_positions", None)
            if embed_positions is not None:
                hidden_states = hidden_states + embed_positions(position_ids)

        new_past_key_values = [] if use_cache else None

        for idx, layer in enumerate(self.layers):
            layer_past = past_key_values[idx] if past_key_values else None

            layer_outputs = layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=layer_past,
                use_cache=use_cache,
            )

            hidden_states = layer_outputs[0]

            if use_cache:
                # Assumes the layer returns its KV cache as the second
                # element when use_cache is set -- TODO confirm for the
                # transformers version in use.
                new_past_key_values.append(layer_outputs[1])

        # Last shard: apply the final norm when one was found.
        norm = getattr(self, "norm", None)
        if self.is_last_shard and norm is not None:
            hidden_states = norm(hidden_states)

        return hidden_states, new_past_key_values

    def get_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Compute logits (last shard only).

        Args:
            hidden_states: Hidden states to project through ``lm_head``.

        Returns:
            The logits.

        Raises:
            RuntimeError: If this is not the last shard, or no ``lm_head``
                is available.
        """
        if not self.is_last_shard:
            raise RuntimeError("get_logits() can only be called on the last shard")

        lm_head = getattr(self, "lm_head", None)
        if lm_head is not None:
            return lm_head(hidden_states)

        raise RuntimeError("No lm_head available")

    def get_memory_usage(self) -> float:
        """Return the total size of this module's parameters, in GB."""
        total_bytes = sum(
            p.numel() * p.element_size()
            for p in self.parameters()
        )
        return total_bytes / (1024 ** 3)

    def get_layer_count(self) -> int:
        """Return the number of layers held by this shard."""
        return len(self.layers)
|
|
259
|
+
|
|
260
|
+
|
|
261
|
+
class ShardedModelLoader:
    """
    Sharded model loader.

    Inspects a model's config and plans how its layers are split across
    workers.
    """

    def __init__(self, model_id: str):
        """
        Args:
            model_id: Model identifier passed to ``AutoConfig.from_pretrained``.
        """
        self.model_id = model_id
        self.config = None
        self.total_layers = 0

    def analyze_model(self) -> Dict[str, Any]:
        """
        Load the model config and summarize the model structure.

        The config is (re)loaded on every call and stored on ``self.config``.

        Returns:
            Dictionary with the keys ``model_id``, ``total_layers``,
            ``hidden_size``, ``num_attention_heads``, ``num_key_value_heads``,
            ``intermediate_size``, ``memory_per_layer_gb`` and
            ``total_memory_gb``.
        """
        from transformers import AutoConfig

        self.config = AutoConfig.from_pretrained(
            self.model_id,
            trust_remote_code=True
        )
        return self._summarize_config()

    def _summarize_config(self) -> Dict[str, Any]:
        """Build the model summary from the already loaded ``self.config``."""
        config = self.config

        # Same fallback chain as ModelShard.from_pretrained.
        total_layers = getattr(config, "num_hidden_layers", None)
        if total_layers is None:
            total_layers = getattr(config, "n_layer", 32)
        self.total_layers = total_layers

        hidden_size = getattr(config, "hidden_size", 4096)
        num_heads = getattr(config, "num_attention_heads", 32)
        num_kv_heads = getattr(config, "num_key_value_heads", num_heads)
        intermediate_size = getattr(config, "intermediate_size", hidden_size * 4)

        # Rough per-layer estimate: weight matrices only, stored as float16.
        head_dim = hidden_size // num_heads
        # Query and output projections use num_heads * head_dim columns each;
        # key and value projections use num_kv_heads * head_dim each.
        attention_params = hidden_size * head_dim * (2 * num_heads + 2 * num_kv_heads)
        mlp_params = hidden_size * intermediate_size * 3  # gate, up, down
        layer_params = attention_params + mlp_params
        bytes_per_param = 2  # float16

        memory_per_layer_gb = layer_params * bytes_per_param / (1024 ** 3)

        return {
            "model_id": self.model_id,
            "total_layers": self.total_layers,
            "hidden_size": hidden_size,
            "num_attention_heads": num_heads,
            "num_key_value_heads": num_kv_heads,
            "intermediate_size": intermediate_size,
            "memory_per_layer_gb": memory_per_layer_gb,
            "total_memory_gb": memory_per_layer_gb * self.total_layers,
        }

    def create_shard_plan(
        self,
        worker_memory_gb: List[float],
        reserve_ratio: float = 0.2,
    ) -> List[Tuple[int, int]]:
        """
        Create a shard plan.

        Layers are split roughly in proportion to each worker's capacity;
        leftover layers go to the last worker first and spill to earlier
        workers only when the last one is full. No worker is ever assigned
        more layers than it can hold.

        Args:
            worker_memory_gb: Available GPU memory of each worker, in GB.
            reserve_ratio: Fraction of memory held back (for the KV cache).

        Returns:
            ``[(start_layer, end_layer), ...]`` contiguous, end-exclusive
            ranges in worker order. Workers that receive no layers are
            omitted from the list.

        Raises:
            ValueError: If the workers together cannot hold every layer.
        """
        # Reuse an already loaded config instead of fetching it again.
        if self.config is None:
            model_info = self.analyze_model()
        else:
            model_info = self._summarize_config()
        memory_per_layer = model_info["memory_per_layer_gb"]

        # Number of layers each worker can hold after the reserve is removed.
        capacities = [
            max(0, int(mem * (1 - reserve_ratio) / memory_per_layer))
            for mem in worker_memory_gb
        ]

        total_capacity = sum(capacities)
        if total_capacity < self.total_layers:
            raise ValueError(
                f"Insufficient memory: can only fit {total_capacity} layers, "
                f"but model has {self.total_layers} layers"
            )

        return self._allocate_layers(capacities, self.total_layers)

    @staticmethod
    def _allocate_layers(
        capacities: List[int],
        total_layers: int,
    ) -> List[Tuple[int, int]]:
        """Split ``total_layers`` into contiguous ranges bounded by ``capacities``."""
        if total_layers <= 0:
            return []

        # The caller guarantees total_capacity >= total_layers > 0.
        total_capacity = sum(capacities)

        # Pass 1: proportional share, at least one layer where the worker can
        # hold one, never more than its capacity or than what is left.
        counts = []
        unassigned = total_layers
        for capacity in capacities:
            proportional = total_layers * capacity // total_capacity
            wanted = min(capacity, max(1, proportional))
            granted = min(wanted, unassigned)
            counts.append(granted)
            unassigned -= granted

        # Pass 2: hand out what rounding left over, last worker first.
        for idx in range(len(counts) - 1, -1, -1):
            if unassigned == 0:
                break
            extra = min(capacities[idx] - counts[idx], unassigned)
            counts[idx] += extra
            unassigned -= extra

        # Turn per-worker counts into contiguous ranges.
        shard_plan = []
        current_layer = 0
        for count in counts:
            if count:
                shard_plan.append((current_layer, current_layer + count))
                current_layer += count

        return shard_plan
|
|
370
|
+
|
|
371
|
+
|
|
372
|
+
def get_layer_range_for_worker(
    total_layers: int,
    num_workers: int,
    worker_idx: int,
) -> Tuple[int, int]:
    """
    Compute a worker's layer range (even split).

    Every worker gets ``total_layers // num_workers`` layers and the first
    ``total_layers % num_workers`` workers get one extra, so the ranges of
    consecutive workers are contiguous and together cover all layers.

    Args:
        total_layers: Total number of layers.
        num_workers: Number of workers.
        worker_idx: Zero-based index of the current worker.

    Returns:
        (start_layer, end_layer), end-exclusive.

    Raises:
        ValueError: If ``total_layers`` is negative, ``num_workers`` is not
            positive, or ``worker_idx`` is outside ``[0, num_workers)``.
    """
    if total_layers < 0:
        raise ValueError(f"total_layers must be >= 0, got {total_layers}")
    if num_workers <= 0:
        raise ValueError(f"num_workers must be positive, got {num_workers}")
    if not 0 <= worker_idx < num_workers:
        raise ValueError(
            f"worker_idx must be in [0, {num_workers}), got {worker_idx}"
        )

    layers_per_worker, remainder = divmod(total_layers, num_workers)

    # Workers before this one contribute `layers_per_worker` each, plus one
    # extra layer for each of them that is among the first `remainder`.
    start = worker_idx * layers_per_worker + min(worker_idx, remainder)
    end = start + layers_per_worker + (1 if worker_idx < remainder else 0)

    return start, end
|
|
395
|
+
|
|
396
|
+
|
|
397
|
+
# Helper functions
|
|
398
|
+
|
|
399
|
+
def _get_layer_module(model: nn.Module, config) -> Optional[nn.ModuleList]:
    """
    Locate the model's list of transformer blocks.

    The attribute paths below are tried in order and the first one that
    resolves completely is returned: ``model.layers``,
    ``model.decoder.layers``, ``transformer.h``, ``transformer.layers``.

    Args:
        model: The model to inspect.
        config: Unused by this lookup.

    Returns:
        The value found at the first matching path, or ``None`` when none of
        the paths exists.
    """
    not_found = object()
    candidate_paths = (
        ("model", "layers"),
        ("model", "decoder", "layers"),
        ("transformer", "h"),
        ("transformer", "layers"),
    )

    for path in candidate_paths:
        node = model
        for attr in path:
            node = getattr(node, attr, not_found)
            if node is not_found:
                break
        else:
            # Every attribute along the path exists.
            return node

    return None
|
|
415
|
+
|
|
416
|
+
|
|
417
|
+
def _get_embedding_module(model: nn.Module, config) -> Optional[nn.Module]:
    """
    Locate the model's token embedding module.

    Tried in order: ``model.embed_tokens``, ``transformer.wte`` and, as a
    last resort, ``model.decoder.embed_tokens`` so that models whose layers
    are found under ``model.decoder.layers`` are covered as well.

    Args:
        model: The model to inspect.
        config: Unused by this lookup.

    Returns:
        The embedding module, or ``None`` when none of the paths exists.
    """
    if hasattr(model, "model") and hasattr(model.model, "embed_tokens"):
        return model.model.embed_tokens
    if hasattr(model, "transformer") and hasattr(model.transformer, "wte"):
        return model.transformer.wte

    # Checked last so that every lookup that succeeded before is unchanged.
    decoder = getattr(getattr(model, "model", None), "decoder", None)
    if hasattr(decoder, "embed_tokens"):
        return decoder.embed_tokens

    return None
|
|
424
|
+
|
|
425
|
+
|
|
426
|
+
def _get_norm_module(model: nn.Module, config) -> Optional[nn.Module]:
    """
    Locate the model's final norm module.

    Tried in order: ``model.norm``, ``transformer.ln_f`` and, as a last
    resort, ``model.decoder.final_layer_norm`` so that models whose layers
    are found under ``model.decoder.layers`` are covered as well.

    Args:
        model: The model to inspect.
        config: Unused by this lookup.

    Returns:
        The final norm module, or ``None`` when none of the paths exists.
    """
    if hasattr(model, "model") and hasattr(model.model, "norm"):
        return model.model.norm
    if hasattr(model, "transformer") and hasattr(model.transformer, "ln_f"):
        return model.transformer.ln_f

    # Checked last so that every lookup that succeeded before is unchanged.
    decoder = getattr(getattr(model, "model", None), "decoder", None)
    if hasattr(decoder, "final_layer_norm"):
        return decoder.final_layer_norm

    return None
|
|
433
|
+
|
|
434
|
+
|
|
435
|
+
def _create_device_map_for_layers(
    config,
    start_layer: int,
    end_layer: int,
    device: str,
    include_embeddings: bool = False,
    include_lm_head: bool = False,
) -> Dict[str, str]:
    """
    Create the device_map used to load a layer range selectively.

    Modules selected for the shard are mapped to ``device``; every other
    module named here is mapped to ``"cpu"`` so that the map assigns a device
    to the whole model rather than only to the selected part.

    Module names are hard-coded to the ``model.embed_tokens`` /
    ``model.layers.N`` / ``model.norm`` / ``lm_head`` layout; other layouts
    are not handled.

    Args:
        config: Model config; ``num_hidden_layers`` (or ``n_layer``) gives the
            total layer count. When neither is present only layers below
            ``end_layer`` are listed.
        start_layer: First selected layer (inclusive).
        end_layer: Last selected layer (exclusive).
        device: Device for the selected modules.
        include_embeddings: Place the token embedding on ``device``.
        include_lm_head: Place the final norm and ``lm_head`` on ``device``.

    Returns:
        Mapping from module name to device string.
    """
    # Model prefix (may differ between model families).
    model_prefix = "model"
    # Where modules that are not part of this shard are placed.
    offload_device = "cpu"

    total_layers = getattr(config, "num_hidden_layers", None)
    if total_layers is None:
        total_layers = getattr(config, "n_layer", None)
    if total_layers is None or total_layers < end_layer:
        total_layers = end_layer

    device_map = {}

    device_map[f"{model_prefix}.embed_tokens"] = (
        device if include_embeddings else offload_device
    )

    for i in range(total_layers):
        selected = start_layer <= i < end_layer
        device_map[f"{model_prefix}.layers.{i}"] = device if selected else offload_device

    head_device = device if include_lm_head else offload_device
    device_map[f"{model_prefix}.norm"] = head_device
    device_map["lm_head"] = head_device

    return device_map
|