sglang 0.3.6.post2__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +55 -2
- sglang/bench_one_batch.py +7 -6
- sglang/bench_one_batch_server.py +4 -3
- sglang/bench_serving.py +13 -0
- sglang/check_env.py +1 -1
- sglang/launch_server.py +3 -2
- sglang/srt/_custom_ops.py +118 -0
- sglang/srt/configs/device_config.py +17 -0
- sglang/srt/configs/load_config.py +84 -0
- sglang/srt/configs/model_config.py +161 -4
- sglang/srt/configs/qwen2vl.py +5 -8
- sglang/srt/constrained/outlines_backend.py +6 -1
- sglang/srt/constrained/outlines_jump_forward.py +8 -1
- sglang/srt/distributed/__init__.py +3 -0
- sglang/srt/distributed/communication_op.py +34 -0
- sglang/srt/distributed/device_communicators/__init__.py +0 -0
- sglang/srt/distributed/device_communicators/cuda_wrapper.py +182 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +352 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +291 -0
- sglang/srt/distributed/device_communicators/hpu_communicator.py +48 -0
- sglang/srt/distributed/device_communicators/pynccl.py +204 -0
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +362 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +568 -0
- sglang/srt/distributed/device_communicators/xpu_communicator.py +47 -0
- sglang/srt/distributed/parallel_state.py +1275 -0
- sglang/srt/distributed/utils.py +223 -0
- sglang/srt/hf_transformers_utils.py +37 -1
- sglang/srt/layers/attention/flashinfer_backend.py +13 -15
- sglang/srt/layers/attention/torch_native_backend.py +285 -0
- sglang/srt/layers/fused_moe_patch.py +20 -11
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/logits_processor.py +17 -3
- sglang/srt/layers/quantization/__init__.py +34 -0
- sglang/srt/layers/vocab_parallel_embedding.py +1 -0
- sglang/srt/lora/lora.py +1 -1
- sglang/srt/managers/data_parallel_controller.py +7 -11
- sglang/srt/managers/detokenizer_manager.py +7 -4
- sglang/srt/managers/image_processor.py +1 -1
- sglang/srt/managers/io_struct.py +48 -12
- sglang/srt/managers/schedule_batch.py +42 -36
- sglang/srt/managers/schedule_policy.py +7 -4
- sglang/srt/managers/scheduler.py +111 -46
- sglang/srt/managers/session_controller.py +0 -3
- sglang/srt/managers/tokenizer_manager.py +169 -100
- sglang/srt/managers/tp_worker.py +36 -3
- sglang/srt/managers/tp_worker_overlap_thread.py +32 -5
- sglang/srt/model_executor/cuda_graph_runner.py +16 -7
- sglang/srt/model_executor/forward_batch_info.py +9 -4
- sglang/srt/model_executor/model_runner.py +136 -150
- sglang/srt/model_loader/__init__.py +34 -0
- sglang/srt/model_loader/loader.py +1139 -0
- sglang/srt/model_loader/utils.py +41 -0
- sglang/srt/model_loader/weight_utils.py +640 -0
- sglang/srt/models/baichuan.py +9 -10
- sglang/srt/models/chatglm.py +6 -15
- sglang/srt/models/commandr.py +2 -3
- sglang/srt/models/dbrx.py +2 -3
- sglang/srt/models/deepseek.py +4 -11
- sglang/srt/models/deepseek_v2.py +3 -11
- sglang/srt/models/exaone.py +2 -3
- sglang/srt/models/gemma.py +2 -6
- sglang/srt/models/gemma2.py +3 -14
- sglang/srt/models/gemma2_reward.py +0 -1
- sglang/srt/models/gpt2.py +5 -12
- sglang/srt/models/gpt_bigcode.py +6 -22
- sglang/srt/models/grok.py +14 -51
- sglang/srt/models/internlm2.py +2 -3
- sglang/srt/models/internlm2_reward.py +0 -1
- sglang/srt/models/llama.py +97 -27
- sglang/srt/models/llama_classification.py +1 -2
- sglang/srt/models/llama_embedding.py +1 -2
- sglang/srt/models/llama_reward.py +2 -3
- sglang/srt/models/llava.py +10 -12
- sglang/srt/models/llavavid.py +1 -2
- sglang/srt/models/minicpm.py +4 -7
- sglang/srt/models/minicpm3.py +6 -19
- sglang/srt/models/mixtral.py +12 -5
- sglang/srt/models/mixtral_quant.py +2 -3
- sglang/srt/models/mllama.py +3 -7
- sglang/srt/models/olmo.py +2 -8
- sglang/srt/models/olmo2.py +391 -0
- sglang/srt/models/olmoe.py +3 -5
- sglang/srt/models/phi3_small.py +8 -8
- sglang/srt/models/qwen.py +2 -3
- sglang/srt/models/qwen2.py +10 -9
- sglang/srt/models/qwen2_moe.py +4 -11
- sglang/srt/models/qwen2_vl.py +12 -9
- sglang/srt/models/registry.py +99 -0
- sglang/srt/models/stablelm.py +2 -3
- sglang/srt/models/torch_native_llama.py +6 -12
- sglang/srt/models/xverse.py +2 -4
- sglang/srt/models/xverse_moe.py +4 -11
- sglang/srt/models/yivl.py +2 -3
- sglang/srt/openai_api/adapter.py +10 -6
- sglang/srt/openai_api/protocol.py +1 -0
- sglang/srt/server.py +303 -204
- sglang/srt/server_args.py +65 -31
- sglang/srt/utils.py +253 -48
- sglang/test/test_utils.py +27 -7
- sglang/utils.py +2 -2
- sglang/version.py +1 -1
- {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/METADATA +2 -1
- sglang-0.4.0.dist-info/RECORD +184 -0
- sglang/srt/layers/fused_moe_grok/__init__.py +0 -1
- sglang/srt/layers/fused_moe_grok/fused_moe.py +0 -692
- sglang/srt/layers/fused_moe_grok/layer.py +0 -630
- sglang-0.3.6.post2.dist-info/RECORD +0 -164
- {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/LICENSE +0 -0
- {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/WHEEL +0 -0
- {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/top_level.txt +0 -0
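For readers who want to reproduce a summary like the one above locally, a minimal sketch using only the standard library is shown below. It assumes both wheels have already been downloaded (for example with `pip download sglang==0.3.6.post2 --no-deps` and `pip download sglang==0.4.0 --no-deps`); the file names are assumptions, and the per-file counts will not exactly match the registry's own diff renderer.

import difflib
import zipfile

# Paths are assumptions; substitute whatever `pip download` produced locally.
OLD_WHEEL = "sglang-0.3.6.post2-py3-none-any.whl"
NEW_WHEEL = "sglang-0.4.0-py3-none-any.whl"


def read_wheel(path):
    """Return {member_name: list_of_lines} for all .py files inside a wheel (a zip archive)."""
    with zipfile.ZipFile(path) as zf:
        return {
            name: zf.read(name).decode("utf-8", errors="replace").splitlines()
            for name in zf.namelist()
            if name.endswith(".py")
        }


old_files, new_files = read_wheel(OLD_WHEEL), read_wheel(NEW_WHEEL)

for name in sorted(set(old_files) | set(new_files)):
    old, new = old_files.get(name, []), new_files.get(name, [])
    diff = list(difflib.unified_diff(old, new, fromfile=name, tofile=name, lineterm=""))
    if diff:
        added = sum(1 for l in diff if l.startswith("+") and not l.startswith("+++"))
        removed = sum(1 for l in diff if l.startswith("-") and not l.startswith("---"))
        print(f"{name} +{added} -{removed}")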
sglang/srt/models/gpt_bigcode.py
CHANGED
@@ -21,9 +21,7 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import GPTBigCodeConfig
-from vllm.config import LoRAConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.activation import get_act_fn
 from sglang.srt.layers.linear import (
@@ -36,6 +34,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader


 class GPTBigCodeAttention(nn.Module):
@@ -44,7 +43,6 @@ class GPTBigCodeAttention(nn.Module):
         self,
         layer_id: int,
         config: GPTBigCodeConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -145,7 +143,6 @@ class GPTBigCodeBlock(nn.Module):
         self,
         layer_id: int,
         config: GPTBigCodeConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -153,7 +150,7 @@ class GPTBigCodeBlock(nn.Module):
         inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size

         self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-        self.attn = GPTBigCodeAttention(layer_id, config,
+        self.attn = GPTBigCodeAttention(layer_id, config, quant_config)
         self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
         self.mlp = GPTBigMLP(inner_dim, config, quant_config)

@@ -183,20 +180,14 @@ class GPTBigCodeModel(nn.Module):
     def __init__(
         self,
         config: GPTBigCodeConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
-        lora_config: Optional[LoRAConfig] = None,
     ):
         super().__init__()
         self.config = config
         assert not config.add_cross_attention

         self.embed_dim = config.hidden_size
-        lora_vocab =
-            (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
-            if lora_config
-            else 0
-        )
+        lora_vocab = 0
         self.vocab_size = config.vocab_size + lora_vocab
         self.wte = VocabParallelEmbedding(
             self.vocab_size, self.embed_dim, org_num_embeddings=config.vocab_size
@@ -204,7 +195,7 @@ class GPTBigCodeModel(nn.Module):
         self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
         self.h = nn.ModuleList(
             [
-                GPTBigCodeBlock(i, config,
+                GPTBigCodeBlock(i, config, quant_config)
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -243,23 +234,16 @@ class GPTBigCodeForCausalLM(nn.Module):
     def __init__(
         self,
         config: GPTBigCodeConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
-        lora_config: Optional[LoRAConfig] = None,
     ):
         super().__init__()

         self.config = config
-        self.lora_config = lora_config

         self.quant_config = quant_config
-        self.transformer = GPTBigCodeModel(
-            config, cache_config, quant_config, lora_config
-        )
+        self.transformer = GPTBigCodeModel(config, quant_config)
         self.lm_head = self.transformer.wte
         self.unpadded_vocab_size = config.vocab_size
-        if lora_config:
-            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
         self.logits_processor = LogitsProcessor(config)

     @torch.no_grad()
@@ -271,7 +255,7 @@ class GPTBigCodeForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, forward_batch)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head
+            input_ids, hidden_states, self.lm_head, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
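The gpt_bigcode.py hunks show the two changes that repeat across most model files in this release: constructors drop the vLLM-style `cache_config` / `lora_config` parameters, and the logits processor is now called with the forward batch as an explicit fourth argument. The sketch below mirrors that calling convention with stand-in `Dummy*` classes; every name here is invented for illustration, and the real types live in `sglang.srt.layers.logits_processor` and `sglang.srt.model_executor.forward_batch_info`.

from dataclasses import dataclass, field
from typing import Any, Optional


@dataclass
class DummyConfig:
    """Stand-in for a HF model config; only the fields the sketch touches."""
    vocab_size: int = 32
    hidden_size: int = 8


@dataclass
class DummyForwardBatch:
    """Stand-in for sglang's ForwardBatch."""
    input_ids: list = field(default_factory=list)


class DummyLogitsProcessor:
    """Mimics the 0.4.0 call shape: the forward batch is now an explicit argument."""

    def __init__(self, config: DummyConfig):
        self.config = config

    def __call__(self, input_ids, hidden_states, lm_head, forward_batch):
        # The real implementation computes logits; this just echoes the inputs.
        return {"num_tokens": len(input_ids), "batch": forward_batch}


class DummyModelForCausalLM:
    # 0.4.0-style constructor: no cache_config / lora_config parameters.
    def __init__(self, config: DummyConfig, quant_config: Optional[Any] = None):
        self.config = config
        self.quant_config = quant_config
        self.lm_head = object()  # placeholder for the tied embedding / LM head
        self.logits_processor = DummyLogitsProcessor(config)

    def forward(self, input_ids, forward_batch: DummyForwardBatch):
        hidden_states = input_ids  # placeholder for the transformer output
        return self.logits_processor(input_ids, hidden_states, self.lm_head, forward_batch)


model = DummyModelForCausalLM(DummyConfig())
print(model.forward([1, 2, 3], DummyForwardBatch(input_ids=[1, 2, 3])))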
sglang/srt/models/grok.py
CHANGED
@@ -16,22 +16,16 @@
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
 """Inference-only Grok1 model."""

-import
-from typing import Iterable, List, Optional, Tuple
+from typing import Iterable, Optional, Tuple

 import torch
 import torch.nn.functional as F
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.distributed import
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.loader import DefaultModelLoader
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader

-from sglang.srt.layers.
+from sglang.srt.layers.fused_moe_triton import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     QKVParallelLinear,
@@ -41,11 +35,15 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.torchao_utils import apply_torchao_config_
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.loader import DefaultModelLoader
+from sglang.srt.model_loader.weight_utils import default_weight_loader


 class Grok1MoE(nn.Module):
@@ -288,22 +286,15 @@ class Grok1ForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
+        self.torchao_config = global_server_args_dict["torchao_config"]
         self.model = Grok1Model(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)

-        # Monkey patch _prepare_weights to load pre-sharded weights
-        setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
-
-        self.use_presharded_weights = True
-
-        warnings.filterwarnings("ignore", category=FutureWarning)
-
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -313,7 +304,7 @@ class Grok1ForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head
+            input_ids, hidden_states, self.lm_head, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
@@ -357,28 +348,23 @@ class Grok1ForCausalLM(nn.Module):
                     continue
                 name = name.replace(weight_name, param_name)

-                if self.use_presharded_weights:
-                    extra_kwargs = {
-                        "use_presharded_weights": self.use_presharded_weights
-                    }
-                else:
-                    extra_kwargs = {}
-
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(
                     param,
                     loaded_weight,
-
+                    name,
                     shard_id=shard_id,
                     expert_id=expert_id,
-                    **extra_kwargs,
                 )
                 break
             else:
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                # Skip loading kv_scale from ckpts towards new design.
+                if name.endswith(".kv_scale") and name not in params_dict:
+                    continue
                 if name is None:
                     continue

@@ -388,30 +374,7 @@ class Grok1ForCausalLM(nn.Module):
                 )
                 weight_loader(param, loaded_weight)

-
-old_prepare_weights = getattr(DefaultModelLoader, "_prepare_weights")
-
-
-def _prepare_presharded_weights(
-    self, model_name_or_path: str, revision: Optional[str], fall_back_to_pt: bool
-) -> Tuple[str, List[str], bool]:
-    import glob
-    import os
-
-    if get_tensor_model_parallel_world_size() == 1:
-        return old_prepare_weights(self, model_name_or_path, revision, fall_back_to_pt)
-
-    tp_rank = get_tensor_model_parallel_rank()
-    allow_patterns = [f"*-{tp_rank:03d}.bin"]
-
-    hf_folder = model_name_or_path
-
-    hf_weights_files: List[str] = []
-    for pattern in allow_patterns:
-        hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
-    use_safetensors = False
-
-    return hf_folder, hf_weights_files, use_safetensors
+        apply_torchao_config_(self, params_dict, set(["proj.weight"]))


 class Grok1ModelForCausalLM(Grok1ForCausalLM):
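Besides the import moves, grok.py stops monkey-patching `DefaultModelLoader._prepare_weights` for pre-sharded checkpoints and instead adds one more name filter to the `load_weights` loop, then applies the torchao config after loading. Below is a small self-contained sketch of that filtering pattern only; the parameter names are made up and stand in for `dict(model.named_parameters())`.

import torch

# Toy parameter table standing in for dict(model.named_parameters()).
params_dict = {
    "model.layers.0.attn.qkv_proj.weight": torch.zeros(4, 4),
    "model.layers.0.mlp.down_proj.weight": torch.zeros(4, 4),
}

checkpoint = [
    ("model.layers.0.attn.qkv_proj.weight", torch.ones(4, 4)),
    ("model.layers.0.attn.attn.kv_scale", torch.tensor(1.0)),  # skipped: not in params_dict
    ("model.layers.0.mlp.down_proj.bias", torch.zeros(4)),     # skipped: GPTQ-style extra bias
]

for name, loaded_weight in checkpoint:
    # Skip loading extra bias for GPTQ models.
    if name.endswith(".bias") and name not in params_dict:
        continue
    # Skip loading kv_scale from checkpoints (mirrors the guard added in 0.4.0).
    if name.endswith(".kv_scale") and name not in params_dict:
        continue
    params_dict[name].copy_(loaded_weight)

print(params_dict["model.layers.0.attn.qkv_proj.weight"].sum())  # tensor(16.)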
sglang/srt/models/internlm2.py
CHANGED
@@ -21,7 +21,6 @@ from torch import nn
 from transformers import PretrainedConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
@@ -38,6 +37,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader


 class InternLM2MLP(nn.Module):
@@ -251,7 +251,6 @@ class InternLM2ForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -270,7 +269,7 @@ class InternLM2ForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.output
+            input_ids, hidden_states, self.output, forward_batch
        )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/llama.py
CHANGED
@@ -16,6 +16,7 @@
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1
 """Inference-only LLaMA model compatible with HuggingFace weights."""

+import logging
 from typing import Any, Dict, Iterable, Optional, Tuple

 import torch
@@ -23,7 +24,6 @@ from torch import nn
 from transformers import LlamaConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
@@ -43,7 +43,11 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import make_layers
+from sglang.utils import get_exception_traceback
+
+logger = logging.getLogger(__name__)


 class LlamaMLP(nn.Module):
@@ -255,6 +259,7 @@ class LlamaModel(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
+            quant_config=quant_config,
         )
         self.layers = make_layers(
             config.num_hidden_layers,
@@ -295,16 +300,30 @@ class LlamaForCausalLM(nn.Module):
         self,
         config: LlamaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
         self.torchao_config = global_server_args_dict["torchao_config"]
         self.model = LlamaModel(config, quant_config=quant_config)
-
+        # Llama 3.2 1B Insturct set tie_word_embeddings to True
+        # Llama 3.1 8B Insturct set tie_word_embeddings to False
+        if self.config.tie_word_embeddings:
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size, config.hidden_size, quant_config=quant_config
+            )
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+        self.stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            (".gate_up_proj", ".gate_proj", 0),
+            (".gate_up_proj", ".up_proj", 1),
+        ]

     @torch.no_grad()
     def forward(
@@ -318,7 +337,7 @@ class LlamaForCausalLM(nn.Module):
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         if not get_embedding:
             return self.logits_processor(
-                input_ids, hidden_states, self.lm_head
+                input_ids, hidden_states, self.lm_head, forward_batch
             )
         else:
             return self.pooler(hidden_states, forward_batch)
@@ -349,15 +368,7 @@ class LlamaForCausalLM(nn.Module):
         return params_mapping.get(name, name)

     def get_module_name_from_weight_name(self, name):
-
-            # (param_name, shard_name, shard_id, num_shard)
-            ("qkv_proj", "q_proj", "q", 3),
-            ("qkv_proj", "k_proj", "k", 3),
-            ("qkv_proj", "v_proj", "v", 3),
-            ("gate_up_proj", "gate_proj", 0, 2),
-            ("gate_up_proj", "up_proj", 1, 2),
-        ]
-        for param_name, weight_name, shard_id, num_shard in stacked_params_mapping:
+        for param_name, weight_name, shard_id, num_shard in self.stacked_params_mapping:
             if weight_name in name:
                 return (
                     name.replace(weight_name, param_name)[: -len(".weight")],
@@ -378,13 +389,8 @@ class LlamaForCausalLM(nn.Module):
             (".gate_up_proj", ".gate_proj", 0),
             (".gate_up_proj", ".up_proj", 1),
         ]
-        params_dict = dict(self.named_parameters())

-
-            hasattr(self.config, "tie_word_embeddings")
-            and self.config.tie_word_embeddings
-            and "lm_head.weight" in params_dict
-        )
+        params_dict = dict(self.named_parameters())

         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name or "projector" in name:
@@ -418,16 +424,80 @@ class LlamaForCausalLM(nn.Module):
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)

-
-                embed_tokens_weight = loaded_weight
+        apply_torchao_config_(self, params_dict, set(["proj.weight"]))

-
-
-
-
-
+    def get_weights_by_name(
+        self, name: str, truncate_size: int = 100, tp_size: int = 1
+    ) -> Optional[torch.Tensor]:
+        """Get the weights of the parameter by its name. Similar to `get_parameter` in Hugging Face.
+
+        Only used for unit test with an unoptimized performance.
+        For optimized performance, please use torch.save and torch.load.
+        """
+        try:
+            if name == "lm_head.weight" and self.config.tie_word_embeddings:
+                logger.info(
+                    "word embedding is tied for this model, return embed_tokens.weight as lm_head.weight."
+                )
+                return (
+                    self.model.embed_tokens.weight.cpu()
+                    .to(torch.float32)
+                    .numpy()
+                    .tolist()[:truncate_size]
+                )

-
+            mapped_name = name
+            mapped_shard_id = None
+            for param_name, weight_name, shard_id in self.stacked_params_mapping:
+                if weight_name in name:
+                    mapped_name = name.replace(weight_name, param_name)
+                    mapped_shard_id = shard_id
+                    break
+            params_dict = dict(self.named_parameters())
+            param = params_dict[mapped_name]
+            if mapped_shard_id is not None:
+                if mapped_shard_id in ["q", "k", "v"]:
+                    num_heads = self.config.num_attention_heads // tp_size
+                    num_kv_heads = self.config.num_key_value_heads // tp_size
+                    head_dim = (
+                        self.config.hidden_size // self.config.num_attention_heads
+                    )
+                    if mapped_shard_id == "q":
+                        offset = 0
+                        size = num_heads * head_dim
+                    elif mapped_shard_id == "k":
+                        offset = num_heads * head_dim
+                        size = num_kv_heads * head_dim
+                    elif mapped_shard_id == "v":
+                        offset = (num_heads + num_kv_heads) * head_dim
+                        size = num_kv_heads * head_dim
+                    weight = param.data.narrow(0, offset, size)
+                elif mapped_shard_id in [0, 1]:
+                    intermediate_size = self.config.intermediate_size
+                    slice_size = intermediate_size // tp_size
+                    if mapped_shard_id == 0:  # gate_proj
+                        offset = 0
+                        size = slice_size
+                    elif mapped_shard_id == 1:  # up_proj
+                        offset = slice_size
+                        size = slice_size
+
+                    weight = param.data.narrow(0, offset, size)
+                else:
+                    weight = param.data
+            else:
+                weight = param.data
+            if tp_size > 1 and ("o_proj" in name or "down_proj" in name):
+                gathered_weights = [torch.zeros_like(weight) for _ in range(tp_size)]
+                torch.distributed.all_gather(gathered_weights, weight)
+                weight = torch.cat(gathered_weights, dim=1)
+            return weight.cpu().to(torch.float32).numpy().tolist()[:truncate_size]
+
+        except Exception:
+            logger.error(
+                f"Error getting weights by name {name} in LlamaForCausalLM: {get_exception_traceback()}"
+            )
+            return None


 class Phi3ForCausalLM(LlamaForCausalLM):
sglang/srt/models/llama_classification.py
CHANGED
@@ -17,11 +17,11 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import LlamaConfig
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel


@@ -30,7 +30,6 @@ class LlamaForClassification(nn.Module):
         self,
         config: LlamaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
sglang/srt/models/llama_embedding.py
CHANGED
@@ -3,10 +3,10 @@ from typing import Iterable, Tuple
 import torch
 from torch import nn
 from transformers import LlamaConfig
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
 from sglang.srt.model_executor.model_runner import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaModel


@@ -15,7 +15,6 @@ class LlamaEmbeddingModel(nn.Module):
         self,
         config: LlamaConfig,
         quant_config=None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.model = LlamaModel(config, quant_config=quant_config)
sglang/srt/models/llama_reward.py
CHANGED
@@ -21,6 +21,7 @@ from transformers import LlamaConfig
 from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel


@@ -29,7 +30,6 @@ class LlamaForSequenceClassification(nn.Module):
         self,
         config: LlamaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -84,9 +84,8 @@ class LlamaForSequenceClassificationWithNormal_Weights(LlamaForSequenceClassific
         self,
         config: LlamaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
-        super().__init__(config, quant_config
+        super().__init__(config, quant_config)
         self.weights = self.Weights(config.hidden_size, self.num_labels)

     @torch.no_grad()
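The largest llama.py addition is `get_weights_by_name`, which recovers an original checkpoint tensor from the fused `qkv_proj` / `gate_up_proj` parameters by narrowing along dim 0 at the right offset. A toy version of the q/k/v offset arithmetic is sketched below; the shapes are invented, and the real method additionally reads sizes from the model config, handles `gate_proj` / `up_proj`, all-gathers row-parallel weights, and returns a truncated Python list.

import torch

hidden_size = 16
num_attention_heads = 4
num_key_value_heads = 2
head_dim = hidden_size // num_attention_heads  # 4

# Fused qkv_proj.weight rows are laid out as [q | k | v].
q_rows = num_attention_heads * head_dim   # 16
kv_rows = num_key_value_heads * head_dim  # 8
qkv_proj = torch.randn(q_rows + 2 * kv_rows, hidden_size)


def narrow_shard(shard_id: str) -> torch.Tensor:
    """Return the q, k, or v slice of the fused qkv_proj weight."""
    if shard_id == "q":
        offset, size = 0, q_rows
    elif shard_id == "k":
        offset, size = q_rows, kv_rows
    else:  # "v"
        offset, size = q_rows + kv_rows, kv_rows
    return qkv_proj.narrow(0, offset, size)


assert narrow_shard("q").shape == (16, 16)
assert narrow_shard("k").shape == (8, 16)
assert torch.equal(narrow_shard("v"), qkv_proj[q_rows + kv_rows:])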
sglang/srt/models/llava.py
CHANGED
@@ -29,7 +29,6 @@ from transformers import (
     SiglipVisionModel,
 )
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.managers.schedule_batch import ImageInputs
@@ -39,6 +38,7 @@ from sglang.srt.mm_utils import (
     unpad_image_shape,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM
 from sglang.srt.models.mistral import MistralForCausalLM
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
@@ -57,7 +57,7 @@ class LlavaBaseForCausalLM(nn.Module):
         else:
             image_aspect_ratio = "anyres"
         offset_list = []
-        for image_s in image_sizes:
+        for image_idx, image_s in enumerate(image_sizes):
             if len(image_sizes) > 16:
                 # 2x2 pooling with stride 2
                 new_image_feature_len = (
@@ -92,10 +92,6 @@ class LlavaBaseForCausalLM(nn.Module):
                 new_w = int(new_w // times)
                 new_image_feature_len += new_h * (new_w + 1)

-            pad_ids = pad_values * (
-                (new_image_feature_len + len(pad_values)) // len(pad_values)
-            )
-            # print("calculated new_image_feature_len: ", new_image_feature_len)
             try:
                 offset = input_ids.index(self.config.image_token_index)
             except ValueError:
@@ -103,7 +99,7 @@ class LlavaBaseForCausalLM(nn.Module):
             # old_len + pad_len - 1, because we need to remove image_token_id
             input_ids = (
                 input_ids[:offset]
-                +
+                + [pad_values[image_idx]] * new_image_feature_len
                 + input_ids[offset + 1 :]
             )
             offset_list.append(offset)
@@ -138,7 +134,6 @@ class LlavaBaseForCausalLM(nn.Module):
         image_inputs = forward_batch.image_inputs

         if forward_batch.forward_mode.is_extend():
-            bs = forward_batch.batch_size
             # Got List[List[str]] extend it to List[str]
             # The length of the List should be equal to batch size
             modalities_list = []
@@ -146,11 +141,16 @@ class LlavaBaseForCausalLM(nn.Module):
             for im in image_inputs:
                 if im and im.modalities is not None:
                     modalities_list.extend(im.modalities)
-                if im and im.image_offsets
+                if im and im.image_offsets:
                     max_image_offset.append(max(im.image_offsets))
                 else:
                     max_image_offset.append(-1)

+            # Clamp input ids. This is because the input_ids for the image tokens are
+            # filled with the hash values of the image for the prefix matching in the radix attention.
+            # There values are useless because their embeddings will be replaced by vision embeddings anyway.
+            input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
+
             # Embed text inputs
             input_embeds = self.language_model.model.embed_tokens(input_ids)

@@ -158,6 +158,7 @@ class LlavaBaseForCausalLM(nn.Module):
             need_vision = start_positions <= np.array(max_image_offset)

             if need_vision.any():
+                bs = forward_batch.batch_size
                 pixel_values = [
                     image_inputs[i].pixel_values for i in range(bs) if need_vision[i]
                 ]
@@ -450,7 +451,6 @@ class LlavaLlamaForCausalLM(LlavaBaseForCausalLM):
         self,
         config: LlavaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()

@@ -472,7 +472,6 @@ class LlavaQwenForCausalLM(LlavaBaseForCausalLM):
         self,
         config: LlavaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()

@@ -505,7 +504,6 @@ class LlavaMistralForCausalLM(LlavaBaseForCausalLM):
         self,
         config: LlavaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()

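The llava.py change worth calling out is the new clamp before the text-embedding lookup: image positions in `input_ids` carry per-image hash values (used for radix-cache prefix matching) that lie outside the vocabulary, and each image now pads with its own value (`pad_values[image_idx]`). A minimal sketch of why the clamp is needed, with an invented vocab size and hash values:

import torch

vocab_size = 32000
pad_values = [2_654_435_761, 40_503_559]  # pretend per-image hash values
input_ids = torch.tensor([1, 15, pad_values[0], pad_values[0], 7, pad_values[1], 2])

embed_tokens = torch.nn.Embedding(vocab_size, 8)

# Without clamping, the hash values would index past the embedding table.
input_ids.clamp_(min=0, max=vocab_size - 1)
input_embeds = embed_tokens(input_ids)  # later overwritten at image positions
print(input_ids.tolist(), input_embeds.shape)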
sglang/srt/models/llavavid.py
CHANGED
@@ -20,11 +20,11 @@ import torch
 from torch import nn
 from transformers import CLIPVisionModel, LlavaConfig
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.managers.schedule_batch import ImageInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM


@@ -33,7 +33,6 @@ class LlavaVidForCausalLM(nn.Module):
         self,
         config: LlavaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config