sglang 0.3.6.post3__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff compares the contents of two package versions publicly released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in that registry.
- sglang/bench_one_batch.py +4 -0
- sglang/bench_serving.py +13 -0
- sglang/check_env.py +1 -1
- sglang/srt/_custom_ops.py +118 -0
- sglang/srt/configs/device_config.py +17 -0
- sglang/srt/configs/load_config.py +84 -0
- sglang/srt/configs/model_config.py +161 -4
- sglang/srt/configs/qwen2vl.py +5 -8
- sglang/srt/constrained/outlines_backend.py +6 -1
- sglang/srt/constrained/outlines_jump_forward.py +8 -1
- sglang/srt/distributed/__init__.py +3 -0
- sglang/srt/distributed/communication_op.py +34 -0
- sglang/srt/distributed/device_communicators/__init__.py +0 -0
- sglang/srt/distributed/device_communicators/cuda_wrapper.py +182 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +352 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +291 -0
- sglang/srt/distributed/device_communicators/hpu_communicator.py +48 -0
- sglang/srt/distributed/device_communicators/pynccl.py +204 -0
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +362 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +568 -0
- sglang/srt/distributed/device_communicators/xpu_communicator.py +47 -0
- sglang/srt/distributed/parallel_state.py +1275 -0
- sglang/srt/distributed/utils.py +223 -0
- sglang/srt/hf_transformers_utils.py +37 -1
- sglang/srt/layers/attention/flashinfer_backend.py +13 -15
- sglang/srt/layers/attention/torch_native_backend.py +285 -0
- sglang/srt/layers/fused_moe_patch.py +20 -11
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/logits_processor.py +17 -3
- sglang/srt/layers/quantization/__init__.py +34 -0
- sglang/srt/layers/vocab_parallel_embedding.py +1 -0
- sglang/srt/lora/lora.py +1 -1
- sglang/srt/managers/io_struct.py +48 -2
- sglang/srt/managers/schedule_batch.py +18 -14
- sglang/srt/managers/schedule_policy.py +7 -4
- sglang/srt/managers/scheduler.py +76 -20
- sglang/srt/managers/tokenizer_manager.py +166 -68
- sglang/srt/managers/tp_worker.py +36 -3
- sglang/srt/managers/tp_worker_overlap_thread.py +21 -3
- sglang/srt/model_executor/cuda_graph_runner.py +16 -7
- sglang/srt/model_executor/forward_batch_info.py +9 -4
- sglang/srt/model_executor/model_runner.py +136 -150
- sglang/srt/model_loader/__init__.py +34 -0
- sglang/srt/model_loader/loader.py +1139 -0
- sglang/srt/model_loader/utils.py +41 -0
- sglang/srt/model_loader/weight_utils.py +640 -0
- sglang/srt/models/baichuan.py +9 -10
- sglang/srt/models/chatglm.py +6 -15
- sglang/srt/models/commandr.py +2 -3
- sglang/srt/models/dbrx.py +2 -3
- sglang/srt/models/deepseek.py +4 -11
- sglang/srt/models/deepseek_v2.py +3 -11
- sglang/srt/models/exaone.py +2 -3
- sglang/srt/models/gemma.py +2 -6
- sglang/srt/models/gemma2.py +3 -14
- sglang/srt/models/gemma2_reward.py +0 -1
- sglang/srt/models/gpt2.py +5 -12
- sglang/srt/models/gpt_bigcode.py +6 -22
- sglang/srt/models/grok.py +3 -3
- sglang/srt/models/internlm2.py +2 -3
- sglang/srt/models/internlm2_reward.py +0 -1
- sglang/srt/models/llama.py +97 -27
- sglang/srt/models/llama_classification.py +1 -2
- sglang/srt/models/llama_embedding.py +1 -2
- sglang/srt/models/llama_reward.py +2 -3
- sglang/srt/models/llava.py +1 -4
- sglang/srt/models/llavavid.py +1 -2
- sglang/srt/models/minicpm.py +4 -7
- sglang/srt/models/minicpm3.py +6 -19
- sglang/srt/models/mixtral.py +12 -5
- sglang/srt/models/mixtral_quant.py +2 -3
- sglang/srt/models/mllama.py +3 -7
- sglang/srt/models/olmo.py +2 -8
- sglang/srt/models/olmo2.py +0 -1
- sglang/srt/models/olmoe.py +3 -5
- sglang/srt/models/phi3_small.py +8 -8
- sglang/srt/models/qwen.py +2 -3
- sglang/srt/models/qwen2.py +10 -9
- sglang/srt/models/qwen2_moe.py +4 -11
- sglang/srt/models/qwen2_vl.py +2 -6
- sglang/srt/models/registry.py +99 -0
- sglang/srt/models/stablelm.py +2 -3
- sglang/srt/models/torch_native_llama.py +6 -12
- sglang/srt/models/xverse.py +2 -4
- sglang/srt/models/xverse_moe.py +4 -11
- sglang/srt/models/yivl.py +2 -3
- sglang/srt/openai_api/adapter.py +9 -5
- sglang/srt/openai_api/protocol.py +1 -0
- sglang/srt/server.py +267 -170
- sglang/srt/server_args.py +65 -31
- sglang/srt/utils.py +245 -28
- sglang/test/test_utils.py +7 -0
- sglang/version.py +1 -1
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.dist-info}/METADATA +1 -1
- sglang-0.4.0.dist-info/RECORD +184 -0
- sglang-0.3.6.post3.dist-info/RECORD +0 -162
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.dist-info}/LICENSE +0 -0
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.dist-info}/WHEEL +0 -0
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.dist-info}/top_level.txt +0 -0
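
Three interface changes repeat across the per-file diffs below: `default_weight_loader` is now imported from the new `sglang.srt.model_loader.weight_utils` module instead of `vllm.model_executor.model_loader.weight_utils`, the model constructors drop the `cache_config` (and, where present, `lora_config`) parameters, and the logits processor is now called with the additional `forward_batch` argument. A minimal sketch of the new weight-loading import, assuming sglang 0.4.0 is installed; the `ToyModule` class and the dummy tensor are illustrative, not part of the package:

    import torch
    from torch import nn

    # 0.4.0: the vendored loader utilities replace the vllm import path
    from sglang.srt.model_loader.weight_utils import default_weight_loader


    class ToyModule(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.proj = nn.Linear(8, 8, bias=False)

        def load_weights(self, weights):
            params_dict = dict(self.named_parameters())
            for name, loaded_weight in weights:
                param = params_dict[name]
                # Parameters may carry a custom weight_loader; otherwise fall back to
                # default_weight_loader, which copies the tensor into the parameter.
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)


    ToyModule().load_weights([("proj.weight", torch.zeros(8, 8))])
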
sglang/srt/models/gpt_bigcode.py
CHANGED
@@ -21,9 +21,7 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import GPTBigCodeConfig
-from vllm.config import LoRAConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.activation import get_act_fn
 from sglang.srt.layers.linear import (
@@ -36,6 +34,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 
 
 class GPTBigCodeAttention(nn.Module):
@@ -44,7 +43,6 @@ class GPTBigCodeAttention(nn.Module):
         self,
         layer_id: int,
         config: GPTBigCodeConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -145,7 +143,6 @@ class GPTBigCodeBlock(nn.Module):
         self,
         layer_id: int,
         config: GPTBigCodeConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -153,7 +150,7 @@ class GPTBigCodeBlock(nn.Module):
         inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
 
         self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-        self.attn = GPTBigCodeAttention(layer_id, config,
+        self.attn = GPTBigCodeAttention(layer_id, config, quant_config)
         self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
         self.mlp = GPTBigMLP(inner_dim, config, quant_config)
 
@@ -183,20 +180,14 @@ class GPTBigCodeModel(nn.Module):
     def __init__(
         self,
         config: GPTBigCodeConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
-        lora_config: Optional[LoRAConfig] = None,
     ):
         super().__init__()
         self.config = config
         assert not config.add_cross_attention
 
         self.embed_dim = config.hidden_size
-        lora_vocab =
-            (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
-            if lora_config
-            else 0
-        )
+        lora_vocab = 0
         self.vocab_size = config.vocab_size + lora_vocab
         self.wte = VocabParallelEmbedding(
             self.vocab_size, self.embed_dim, org_num_embeddings=config.vocab_size
@@ -204,7 +195,7 @@ class GPTBigCodeModel(nn.Module):
         self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
         self.h = nn.ModuleList(
             [
-                GPTBigCodeBlock(i, config,
+                GPTBigCodeBlock(i, config, quant_config)
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -243,23 +234,16 @@ class GPTBigCodeForCausalLM(nn.Module):
     def __init__(
         self,
         config: GPTBigCodeConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
-        lora_config: Optional[LoRAConfig] = None,
     ):
         super().__init__()
 
         self.config = config
-        self.lora_config = lora_config
 
         self.quant_config = quant_config
-        self.transformer = GPTBigCodeModel(
-            config, cache_config, quant_config, lora_config
-        )
+        self.transformer = GPTBigCodeModel(config, quant_config)
         self.lm_head = self.transformer.wte
         self.unpadded_vocab_size = config.vocab_size
-        if lora_config:
-            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
         self.logits_processor = LogitsProcessor(config)
 
     @torch.no_grad()
@@ -271,7 +255,7 @@ class GPTBigCodeForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, forward_batch)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head
+            input_ids, hidden_states, self.lm_head, forward_batch
         )
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
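
The net effect on gpt_bigcode.py is a slimmer constructor chain: `GPTBigCodeForCausalLM(config, quant_config)` with no `cache_config` or `lora_config`, plus the `forward_batch`-aware logits call. A quick, import-only way to check the new signature after upgrading; this snippet is ours, not part of the package, and assumes the wheel and its vllm dependency import cleanly on the machine:

    import inspect

    from sglang.srt.models.gpt_bigcode import GPTBigCodeForCausalLM

    sig = inspect.signature(GPTBigCodeForCausalLM.__init__)
    print(sig)  # 0.4.0 keeps only config and quant_config
    assert "cache_config" not in sig.parameters
    assert "lora_config" not in sig.parameters
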
sglang/srt/models/grok.py
CHANGED
@@ -24,7 +24,6 @@ from torch import nn
 from transformers import PretrainedConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.fused_moe_triton import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
@@ -43,6 +42,8 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.loader import DefaultModelLoader
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 
 
 class Grok1MoE(nn.Module):
@@ -285,7 +286,6 @@ class Grok1ForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -304,7 +304,7 @@ class Grok1ForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head
+            input_ids, hidden_states, self.lm_head, forward_batch
         )
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/internlm2.py
CHANGED
@@ -21,7 +21,6 @@ from torch import nn
 from transformers import PretrainedConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
@@ -38,6 +37,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 
 
 class InternLM2MLP(nn.Module):
@@ -251,7 +251,6 @@ class InternLM2ForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -270,7 +269,7 @@ class InternLM2ForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.output
+            input_ids, hidden_states, self.output, forward_batch
         )
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/llama.py
CHANGED
@@ -16,6 +16,7 @@
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1
 """Inference-only LLaMA model compatible with HuggingFace weights."""
 
+import logging
 from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
@@ -23,7 +24,6 @@ from torch import nn
 from transformers import LlamaConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
@@ -43,7 +43,11 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import make_layers
+from sglang.utils import get_exception_traceback
+
+logger = logging.getLogger(__name__)
 
 
 class LlamaMLP(nn.Module):
@@ -255,6 +259,7 @@ class LlamaModel(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
+            quant_config=quant_config,
         )
         self.layers = make_layers(
             config.num_hidden_layers,
@@ -295,16 +300,30 @@ class LlamaForCausalLM(nn.Module):
         self,
         config: LlamaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
         self.torchao_config = global_server_args_dict["torchao_config"]
         self.model = LlamaModel(config, quant_config=quant_config)
-
+        # Llama 3.2 1B Insturct set tie_word_embeddings to True
+        # Llama 3.1 8B Insturct set tie_word_embeddings to False
+        if self.config.tie_word_embeddings:
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size, config.hidden_size, quant_config=quant_config
+            )
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+        self.stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            (".gate_up_proj", ".gate_proj", 0),
+            (".gate_up_proj", ".up_proj", 1),
+        ]
 
     @torch.no_grad()
     def forward(
@@ -318,7 +337,7 @@ class LlamaForCausalLM(nn.Module):
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         if not get_embedding:
             return self.logits_processor(
-                input_ids, hidden_states, self.lm_head
+                input_ids, hidden_states, self.lm_head, forward_batch
             )
         else:
             return self.pooler(hidden_states, forward_batch)
@@ -349,15 +368,7 @@ class LlamaForCausalLM(nn.Module):
         return params_mapping.get(name, name)
 
     def get_module_name_from_weight_name(self, name):
-
-            # (param_name, shard_name, shard_id, num_shard)
-            ("qkv_proj", "q_proj", "q", 3),
-            ("qkv_proj", "k_proj", "k", 3),
-            ("qkv_proj", "v_proj", "v", 3),
-            ("gate_up_proj", "gate_proj", 0, 2),
-            ("gate_up_proj", "up_proj", 1, 2),
-        ]
-        for param_name, weight_name, shard_id, num_shard in stacked_params_mapping:
+        for param_name, weight_name, shard_id, num_shard in self.stacked_params_mapping:
             if weight_name in name:
                 return (
                     name.replace(weight_name, param_name)[: -len(".weight")],
@@ -378,13 +389,8 @@ class LlamaForCausalLM(nn.Module):
             (".gate_up_proj", ".gate_proj", 0),
             (".gate_up_proj", ".up_proj", 1),
         ]
-        params_dict = dict(self.named_parameters())
 
-
-            hasattr(self.config, "tie_word_embeddings")
-            and self.config.tie_word_embeddings
-            and "lm_head.weight" in params_dict
-        )
+        params_dict = dict(self.named_parameters())
 
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name or "projector" in name:
@@ -418,16 +424,80 @@ class LlamaForCausalLM(nn.Module):
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
 
-
-                embed_tokens_weight = loaded_weight
+        apply_torchao_config_(self, params_dict, set(["proj.weight"]))
 
-
-
-
-
-
+    def get_weights_by_name(
+        self, name: str, truncate_size: int = 100, tp_size: int = 1
+    ) -> Optional[torch.Tensor]:
+        """Get the weights of the parameter by its name. Similar to `get_parameter` in Hugging Face.
+
+        Only used for unit test with an unoptimized performance.
+        For optimized performance, please use torch.save and torch.load.
+        """
+        try:
+            if name == "lm_head.weight" and self.config.tie_word_embeddings:
+                logger.info(
+                    "word embedding is tied for this model, return embed_tokens.weight as lm_head.weight."
+                )
+                return (
+                    self.model.embed_tokens.weight.cpu()
+                    .to(torch.float32)
+                    .numpy()
+                    .tolist()[:truncate_size]
+                )
 
-
+            mapped_name = name
+            mapped_shard_id = None
+            for param_name, weight_name, shard_id in self.stacked_params_mapping:
+                if weight_name in name:
+                    mapped_name = name.replace(weight_name, param_name)
+                    mapped_shard_id = shard_id
+                    break
+            params_dict = dict(self.named_parameters())
+            param = params_dict[mapped_name]
+            if mapped_shard_id is not None:
+                if mapped_shard_id in ["q", "k", "v"]:
+                    num_heads = self.config.num_attention_heads // tp_size
+                    num_kv_heads = self.config.num_key_value_heads // tp_size
+                    head_dim = (
+                        self.config.hidden_size // self.config.num_attention_heads
+                    )
+                    if mapped_shard_id == "q":
+                        offset = 0
+                        size = num_heads * head_dim
+                    elif mapped_shard_id == "k":
+                        offset = num_heads * head_dim
+                        size = num_kv_heads * head_dim
+                    elif mapped_shard_id == "v":
+                        offset = (num_heads + num_kv_heads) * head_dim
+                        size = num_kv_heads * head_dim
+                    weight = param.data.narrow(0, offset, size)
+                elif mapped_shard_id in [0, 1]:
+                    intermediate_size = self.config.intermediate_size
+                    slice_size = intermediate_size // tp_size
+                    if mapped_shard_id == 0:  # gate_proj
+                        offset = 0
+                        size = slice_size
+                    elif mapped_shard_id == 1:  # up_proj
+                        offset = slice_size
+                        size = slice_size
+
+                    weight = param.data.narrow(0, offset, size)
+                else:
+                    weight = param.data
+            else:
+                weight = param.data
+            if tp_size > 1 and ("o_proj" in name or "down_proj" in name):
+                gathered_weights = [torch.zeros_like(weight) for _ in range(tp_size)]
+                torch.distributed.all_gather(gathered_weights, weight)
+                weight = torch.cat(gathered_weights, dim=1)
+            return weight.cpu().to(torch.float32).numpy().tolist()[:truncate_size]
+
+        except Exception:
+            logger.error(
+                f"Error getting weights by name {name} in LlamaForCausalLM: {get_exception_traceback()}"
+            )
+            return None
 
 
 class Phi3ForCausalLM(LlamaForCausalLM):
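
llama.py gains `get_weights_by_name`, which resolves HuggingFace-style parameter names onto the fused `qkv_proj` / `gate_up_proj` parameters via the new `stacked_params_mapping` attribute and returns the selected shard as a truncated Python list (the docstring marks it as unit-test-only). A hedged usage sketch; `model` is assumed to be an already loaded `LlamaForCausalLM`, which requires the running sglang model runner and is outside this diff:

    # q_proj is mapped onto the fused qkv_proj parameter and the "q" shard is sliced out.
    q_rows = model.get_weights_by_name(
        "model.layers.0.self_attn.q_proj.weight", truncate_size=8, tp_size=1
    )

    # With tie_word_embeddings=True this returns embed_tokens.weight instead.
    head_rows = model.get_weights_by_name("lm_head.weight")
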
sglang/srt/models/llama_classification.py
CHANGED
@@ -17,11 +17,11 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import LlamaConfig
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel
 
 
@@ -30,7 +30,6 @@ class LlamaForClassification(nn.Module):
         self,
         config: LlamaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
sglang/srt/models/llama_embedding.py
CHANGED
@@ -3,10 +3,10 @@ from typing import Iterable, Tuple
 import torch
 from torch import nn
 from transformers import LlamaConfig
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
 from sglang.srt.model_executor.model_runner import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaModel
 
 
@@ -15,7 +15,6 @@ class LlamaEmbeddingModel(nn.Module):
         self,
         config: LlamaConfig,
         quant_config=None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.model = LlamaModel(config, quant_config=quant_config)
sglang/srt/models/llama_reward.py
CHANGED
@@ -21,6 +21,7 @@ from transformers import LlamaConfig
 from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel
 
 
@@ -29,7 +30,6 @@ class LlamaForSequenceClassification(nn.Module):
         self,
         config: LlamaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -84,9 +84,8 @@ class LlamaForSequenceClassificationWithNormal_Weights(LlamaForSequenceClassific
         self,
         config: LlamaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
-        super().__init__(config, quant_config
+        super().__init__(config, quant_config)
         self.weights = self.Weights(config.hidden_size, self.num_labels)
 
     @torch.no_grad()
sglang/srt/models/llava.py
CHANGED
@@ -29,7 +29,6 @@ from transformers import (
     SiglipVisionModel,
 )
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.managers.schedule_batch import ImageInputs
@@ -39,6 +38,7 @@ from sglang.srt.mm_utils import (
     unpad_image_shape,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM
 from sglang.srt.models.mistral import MistralForCausalLM
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
@@ -451,7 +451,6 @@ class LlavaLlamaForCausalLM(LlavaBaseForCausalLM):
         self,
         config: LlavaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
 
@@ -473,7 +472,6 @@ class LlavaQwenForCausalLM(LlavaBaseForCausalLM):
         self,
         config: LlavaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
 
@@ -506,7 +504,6 @@ class LlavaMistralForCausalLM(LlavaBaseForCausalLM):
         self,
         config: LlavaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
 
sglang/srt/models/llavavid.py
CHANGED
@@ -20,11 +20,11 @@ import torch
 from torch import nn
 from transformers import CLIPVisionModel, LlavaConfig
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.managers.schedule_batch import ImageInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM
 
 
@@ -33,7 +33,6 @@ class LlavaVidForCausalLM(nn.Module):
         self,
         config: LlavaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
sglang/srt/models/minicpm.py
CHANGED
@@ -20,7 +20,6 @@ import torch
 from torch import nn
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
@@ -37,6 +36,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 
 
 class MiniCPMMLP(nn.Module):
@@ -275,7 +275,6 @@ class MiniCPMForCausalLM(nn.Module):
         self,
         config,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -308,12 +307,10 @@ class MiniCPMForCausalLM(nn.Module):
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         hidden_states = hidden_states / self.scale_width
         if self.config.tie_word_embeddings:
-
+            lm_head = self.model.embed_tokens
         else:
-
-        return self.logits_processor(
-            input_ids, hidden_states, lm_head_weight, forward_batch
-        )
+            lm_head = self.lm_head
+        return self.logits_processor(input_ids, hidden_states, lm_head, forward_batch)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
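
MiniCPM's forward() now resolves the output head module at call time: with tied word embeddings it reuses `model.embed_tokens`, otherwise it uses `self.lm_head`, and the chosen module is handed to the logits processor together with `forward_batch`. The same branch appears in minicpm3.py below and, at construction time, in llama.py above. A standalone restatement of the rule; the `select_lm_head` helper name is ours, not part of the package:

    from torch import nn


    def select_lm_head(
        tie_word_embeddings: bool, embed_tokens: nn.Module, lm_head: nn.Module
    ) -> nn.Module:
        # With tied word embeddings the input embedding module doubles as the
        # output head, so its weight matrix is reused for the logits projection.
        return embed_tokens if tie_word_embeddings else lm_head


    # In the models above this becomes roughly:
    #   lm_head = select_lm_head(config.tie_word_embeddings, self.model.embed_tokens, self.lm_head)
    #   return self.logits_processor(input_ids, hidden_states, lm_head, forward_batch)
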
sglang/srt/models/minicpm3.py
CHANGED
@@ -27,7 +27,6 @@ from vllm.model_executor.layers.linear import (
     RowParallelLinear,
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
@@ -40,6 +39,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import is_flashinfer_available
 
 if is_flashinfer_available():
@@ -105,7 +105,6 @@ class MiniCPM3Attention(nn.Module):
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
         layer_id=None,
     ) -> None:
@@ -249,7 +248,6 @@ class MiniCPM3AttentionMLA(nn.Module):
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
         layer_id=None,
     ) -> None:
@@ -406,7 +404,6 @@ class MiniCPM3DecoderLayer(nn.Module):
         self,
         config: PretrainedConfig,
         layer_id: int,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -430,7 +427,6 @@ class MiniCPM3DecoderLayer(nn.Module):
             rope_theta=rope_theta,
             rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
-            cache_config=cache_config,
             quant_config=quant_config,
             layer_id=layer_id,
         )
@@ -449,7 +445,6 @@ class MiniCPM3DecoderLayer(nn.Module):
             rope_theta=rope_theta,
             rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
-            cache_config=cache_config,
             quant_config=quant_config,
             layer_id=layer_id,
         )
@@ -498,7 +493,6 @@ class MiniCPM3Model(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -512,9 +506,7 @@ class MiniCPM3Model(nn.Module):
         )
         self.layers = nn.ModuleList(
             [
-                MiniCPM3DecoderLayer(
-                    config, i, cache_config=cache_config, quant_config=quant_config
-                )
+                MiniCPM3DecoderLayer(config, i, quant_config=quant_config)
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -549,7 +541,6 @@ class MiniCPM3ForCausalLM(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -557,9 +548,7 @@ class MiniCPM3ForCausalLM(nn.Module):
 
         self.num_experts = getattr(self.config, "num_experts", 0)
         self.quant_config = quant_config
-        self.model = MiniCPM3Model(
-            config, cache_config=cache_config, quant_config=quant_config
-        )
+        self.model = MiniCPM3Model(config, quant_config=quant_config)
         # self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         if not self.config.tie_word_embeddings:
             self.lm_head = ParallelLMHead(
@@ -585,12 +574,10 @@ class MiniCPM3ForCausalLM(nn.Module):
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         hidden_states = hidden_states / self.scale_width
         if self.config.tie_word_embeddings:
-
+            lm_head = self.model.embed_tokens
         else:
-
-        return self.logits_processor(
-            input_ids, hidden_states, lm_head_weight, forward_batch
-        )
+            lm_head = self.lm_head
+        return self.logits_processor(input_ids, hidden_states, lm_head, forward_batch)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [