sglang 0.3.6.post3__py3-none-any.whl → 0.4.0.post1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registry.
- sglang/__init__.py +1 -1
- sglang/bench_one_batch.py +4 -0
- sglang/bench_serving.py +13 -0
- sglang/check_env.py +1 -1
- sglang/srt/_custom_ops.py +118 -0
- sglang/srt/configs/device_config.py +17 -0
- sglang/srt/configs/load_config.py +84 -0
- sglang/srt/configs/model_config.py +161 -4
- sglang/srt/configs/qwen2vl.py +5 -8
- sglang/srt/constrained/outlines_backend.py +11 -1
- sglang/srt/constrained/outlines_jump_forward.py +8 -1
- sglang/srt/constrained/xgrammar_backend.py +5 -5
- sglang/srt/distributed/__init__.py +3 -0
- sglang/srt/distributed/communication_op.py +34 -0
- sglang/srt/distributed/device_communicators/__init__.py +0 -0
- sglang/srt/distributed/device_communicators/cuda_wrapper.py +182 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +352 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +291 -0
- sglang/srt/distributed/device_communicators/hpu_communicator.py +48 -0
- sglang/srt/distributed/device_communicators/pynccl.py +204 -0
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +362 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +568 -0
- sglang/srt/distributed/device_communicators/xpu_communicator.py +47 -0
- sglang/srt/distributed/parallel_state.py +1275 -0
- sglang/srt/distributed/utils.py +223 -0
- sglang/srt/hf_transformers_utils.py +37 -1
- sglang/srt/layers/attention/__init__.py +5 -2
- sglang/srt/layers/attention/double_sparsity_backend.py +22 -8
- sglang/srt/layers/attention/flashinfer_backend.py +33 -20
- sglang/srt/layers/attention/torch_native_backend.py +299 -0
- sglang/srt/layers/attention/triton_backend.py +22 -8
- sglang/srt/layers/attention/triton_ops/extend_attention.py +3 -0
- sglang/srt/layers/ep_moe/__init__.py +0 -0
- sglang/srt/layers/ep_moe/kernels.py +349 -0
- sglang/srt/layers/ep_moe/layer.py +661 -0
- sglang/srt/layers/fused_moe_patch.py +20 -11
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/logits_processor.py +17 -3
- sglang/srt/layers/quantization/__init__.py +36 -2
- sglang/srt/layers/quantization/fp8.py +559 -0
- sglang/srt/layers/quantization/fp8_utils.py +27 -0
- sglang/srt/layers/radix_attention.py +4 -2
- sglang/srt/layers/sampler.py +2 -0
- sglang/srt/layers/torchao_utils.py +23 -45
- sglang/srt/layers/vocab_parallel_embedding.py +1 -0
- sglang/srt/lora/lora.py +1 -1
- sglang/srt/managers/io_struct.py +48 -2
- sglang/srt/managers/schedule_batch.py +19 -14
- sglang/srt/managers/schedule_policy.py +7 -4
- sglang/srt/managers/scheduler.py +145 -85
- sglang/srt/managers/tokenizer_manager.py +166 -68
- sglang/srt/managers/tp_worker.py +36 -3
- sglang/srt/managers/tp_worker_overlap_thread.py +28 -8
- sglang/srt/mem_cache/memory_pool.py +5 -1
- sglang/srt/model_executor/cuda_graph_runner.py +30 -7
- sglang/srt/model_executor/forward_batch_info.py +9 -4
- sglang/srt/model_executor/model_runner.py +146 -153
- sglang/srt/model_loader/__init__.py +34 -0
- sglang/srt/model_loader/loader.py +1139 -0
- sglang/srt/model_loader/utils.py +41 -0
- sglang/srt/model_loader/weight_utils.py +640 -0
- sglang/srt/model_parallel.py +1 -5
- sglang/srt/models/baichuan.py +9 -10
- sglang/srt/models/chatglm.py +6 -15
- sglang/srt/models/commandr.py +4 -5
- sglang/srt/models/dbrx.py +2 -3
- sglang/srt/models/deepseek.py +4 -11
- sglang/srt/models/deepseek_v2.py +90 -18
- sglang/srt/models/exaone.py +2 -3
- sglang/srt/models/gemma.py +2 -6
- sglang/srt/models/gemma2.py +3 -14
- sglang/srt/models/gemma2_reward.py +0 -1
- sglang/srt/models/gpt2.py +5 -12
- sglang/srt/models/gpt_bigcode.py +6 -22
- sglang/srt/models/grok.py +3 -8
- sglang/srt/models/internlm2.py +2 -3
- sglang/srt/models/internlm2_reward.py +0 -1
- sglang/srt/models/llama.py +96 -31
- sglang/srt/models/llama_classification.py +1 -2
- sglang/srt/models/llama_embedding.py +1 -2
- sglang/srt/models/llama_reward.py +2 -3
- sglang/srt/models/llava.py +1 -4
- sglang/srt/models/llavavid.py +1 -2
- sglang/srt/models/minicpm.py +4 -7
- sglang/srt/models/minicpm3.py +6 -19
- sglang/srt/models/mixtral.py +24 -14
- sglang/srt/models/mixtral_quant.py +2 -3
- sglang/srt/models/mllama.py +3 -7
- sglang/srt/models/olmo.py +2 -8
- sglang/srt/models/olmo2.py +0 -1
- sglang/srt/models/olmoe.py +3 -5
- sglang/srt/models/phi3_small.py +8 -13
- sglang/srt/models/qwen.py +2 -3
- sglang/srt/models/qwen2.py +10 -9
- sglang/srt/models/qwen2_moe.py +4 -16
- sglang/srt/models/qwen2_vl.py +2 -6
- sglang/srt/models/registry.py +99 -0
- sglang/srt/models/stablelm.py +2 -3
- sglang/srt/models/torch_native_llama.py +6 -17
- sglang/srt/models/xverse.py +2 -4
- sglang/srt/models/xverse_moe.py +4 -11
- sglang/srt/models/yivl.py +2 -3
- sglang/srt/openai_api/adapter.py +9 -5
- sglang/srt/openai_api/protocol.py +1 -0
- sglang/srt/sampling/sampling_batch_info.py +9 -8
- sglang/srt/server.py +270 -173
- sglang/srt/server_args.py +102 -29
- sglang/srt/utils.py +295 -28
- sglang/test/test_utils.py +7 -0
- sglang/version.py +1 -1
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/METADATA +5 -4
- sglang-0.4.0.post1.dist-info/RECORD +189 -0
- sglang-0.3.6.post3.dist-info/RECORD +0 -162
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/LICENSE +0 -0
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/WHEEL +0 -0
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/top_level.txt +0 -0
sglang/srt/models/llava.py
CHANGED
@@ -29,7 +29,6 @@ from transformers import (
     SiglipVisionModel,
 )
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.managers.schedule_batch import ImageInputs
@@ -39,6 +38,7 @@ from sglang.srt.mm_utils import (
     unpad_image_shape,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM
 from sglang.srt.models.mistral import MistralForCausalLM
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
@@ -451,7 +451,6 @@ class LlavaLlamaForCausalLM(LlavaBaseForCausalLM):
         self,
         config: LlavaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()

@@ -473,7 +472,6 @@ class LlavaQwenForCausalLM(LlavaBaseForCausalLM):
         self,
         config: LlavaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()

@@ -506,7 +504,6 @@ class LlavaMistralForCausalLM(LlavaBaseForCausalLM):
         self,
         config: LlavaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()

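Two mechanical changes repeat across most of the model files in this diff: default_weight_loader is now imported from the new in-tree sglang.srt.model_loader.weight_utils module instead of vllm.model_executor.model_loader.weight_utils, and the unused cache_config argument is dropped from model constructors. A minimal, self-contained sketch of the resulting shape (a toy class and a simplified loader written for this note, not the real sglang code):

from typing import Iterable, Optional, Tuple

import torch
import torch.nn as nn


def default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
    # Simplified stand-in for sglang.srt.model_loader.weight_utils.default_weight_loader:
    # copy the checkpoint tensor into the parameter in place.
    assert param.shape == loaded_weight.shape
    param.data.copy_(loaded_weight)


class ToyModel(nn.Module):
    # 0.4.0-style constructor: quant_config only, no cache_config argument.
    def __init__(self, quant_config: Optional[object] = None) -> None:
        super().__init__()
        self.proj = nn.Linear(4, 4, bias=False)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> None:
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)


model = ToyModel()
model.load_weights([("proj.weight", torch.zeros(4, 4))])

The remaining per-model hunks below are mostly this import swap and constructor cleanup, plus the logits-processor and MoE changes noted further down.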
sglang/srt/models/llavavid.py
CHANGED
@@ -20,11 +20,11 @@ import torch
 from torch import nn
 from transformers import CLIPVisionModel, LlavaConfig
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.managers.schedule_batch import ImageInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM


@@ -33,7 +33,6 @@ class LlavaVidForCausalLM(nn.Module):
         self,
         config: LlavaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
sglang/srt/models/minicpm.py
CHANGED
@@ -20,7 +20,6 @@ import torch
 from torch import nn
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
@@ -37,6 +36,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader


 class MiniCPMMLP(nn.Module):
@@ -275,7 +275,6 @@ class MiniCPMForCausalLM(nn.Module):
         self,
         config,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -308,12 +307,10 @@
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         hidden_states = hidden_states / self.scale_width
         if self.config.tie_word_embeddings:
-            lm_head_weight = self.model.embed_tokens.weight
+            lm_head = self.model.embed_tokens
         else:
-            lm_head_weight = self.lm_head.weight
-        return self.logits_processor(
-            input_ids, hidden_states, lm_head_weight, forward_batch
-        )
+            lm_head = self.lm_head
+        return self.logits_processor(input_ids, hidden_states, lm_head, forward_batch)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
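MiniCPM here (and MiniCPM3 just below) also changes what is handed to the logits processor: instead of resolving a weight tensor (lm_head_weight) up front, the forward pass now passes the head module itself, picking the tied embedding module when tie_word_embeddings is set. A toy, self-contained sketch of that selection (illustrative names, not the sglang API):

import torch
import torch.nn as nn


class TinyLM(nn.Module):
    # Toy model: the head is selected as a module, not as a weight tensor.
    def __init__(self, vocab: int = 16, hidden: int = 8, tie_word_embeddings: bool = True):
        super().__init__()
        self.tie_word_embeddings = tie_word_embeddings
        self.embed_tokens = nn.Embedding(vocab, hidden)
        if not tie_word_embeddings:
            self.lm_head = nn.Linear(hidden, vocab, bias=False)

    def logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Mirror of the 0.4.0 call style: choose the module (embed_tokens or
        # lm_head) and let the logits step read the weight off that module.
        lm_head = self.embed_tokens if self.tie_word_embeddings else self.lm_head
        return hidden_states @ lm_head.weight.t()


print(TinyLM().logits(torch.randn(2, 8)).shape)  # torch.Size([2, 16])

Passing the module rather than a bare weight tensor presumably lets one call signature cover tied, untied, and quantized heads.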
sglang/srt/models/minicpm3.py
CHANGED
@@ -27,7 +27,6 @@ from vllm.model_executor.layers.linear import (
     RowParallelLinear,
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
@@ -40,6 +39,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import is_flashinfer_available

 if is_flashinfer_available():
@@ -105,7 +105,6 @@ class MiniCPM3Attention(nn.Module):
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
         layer_id=None,
     ) -> None:
@@ -249,7 +248,6 @@ class MiniCPM3AttentionMLA(nn.Module):
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
         layer_id=None,
     ) -> None:
@@ -406,7 +404,6 @@ class MiniCPM3DecoderLayer(nn.Module):
         self,
         config: PretrainedConfig,
         layer_id: int,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -430,7 +427,6 @@
                 rope_theta=rope_theta,
                 rope_scaling=rope_scaling,
                 max_position_embeddings=max_position_embeddings,
-                cache_config=cache_config,
                 quant_config=quant_config,
                 layer_id=layer_id,
             )
@@ -449,7 +445,6 @@
                 rope_theta=rope_theta,
                 rope_scaling=rope_scaling,
                 max_position_embeddings=max_position_embeddings,
-                cache_config=cache_config,
                 quant_config=quant_config,
                 layer_id=layer_id,
             )
@@ -498,7 +493,6 @@ class MiniCPM3Model(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -512,9 +506,7 @@
         )
         self.layers = nn.ModuleList(
             [
-                MiniCPM3DecoderLayer(
-                    config, i, cache_config=cache_config, quant_config=quant_config
-                )
+                MiniCPM3DecoderLayer(config, i, quant_config=quant_config)
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -549,7 +541,6 @@ class MiniCPM3ForCausalLM(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -557,9 +548,7 @@

         self.num_experts = getattr(self.config, "num_experts", 0)
         self.quant_config = quant_config
-        self.model = MiniCPM3Model(
-            config, cache_config=cache_config, quant_config=quant_config
-        )
+        self.model = MiniCPM3Model(config, quant_config=quant_config)
         # self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         if not self.config.tie_word_embeddings:
             self.lm_head = ParallelLMHead(
@@ -585,12 +574,10 @@
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         hidden_states = hidden_states / self.scale_width
         if self.config.tie_word_embeddings:
-            lm_head_weight = self.model.embed_tokens.weight
+            lm_head = self.model.embed_tokens
        else:
-            lm_head_weight = self.lm_head.weight
-        return self.logits_processor(
-            input_ids, hidden_states, lm_head_weight, forward_batch
-        )
+            lm_head = self.lm_head
+        return self.logits_processor(input_ids, hidden_states, lm_head, forward_batch)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
sglang/srt/models/mixtral.py
CHANGED
@@ -21,10 +21,13 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import MixtralConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed import (
+    get_tensor_model_parallel_world_size,
+    tensor_model_parallel_all_reduce,
+)
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader

+from sglang.srt.layers.ep_moe.layer import EPMoE
 from sglang.srt.layers.fused_moe_triton import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -35,13 +38,13 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.torchao_utils import apply_torchao_config_
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader


 class MixtralMoE(nn.Module):
@@ -65,6 +68,7 @@ class MixtralMoE(nn.Module):
         prefix: str = "",
     ):
         super().__init__()
+        self.tp_size = get_tensor_model_parallel_world_size()
         self.hidden_size = hidden_size

         # Gate always runs at half / full precision for now.
@@ -76,14 +80,13 @@
             quant_config=None,
             prefix=f"{prefix}.gate",
         )
-
-        self.experts = FusedMoE(
+        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
+        self.experts = MoEImpl(
             num_experts=num_experts,
             top_k=top_k,
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
             params_dtype=params_dtype,
-            reduce_results=True,
             renormalize=True,
             quant_config=quant_config,
             tp_size=tp_size,
@@ -97,6 +100,8 @@
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
         final_hidden_states = self.experts(hidden_states, router_logits)
+        if self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states.view(orig_shape)


@@ -291,12 +296,10 @@ class MixtralForCausalLM(nn.Module):
         self,
         config: MixtralConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.torchao_config = global_server_args_dict["torchao_config"]
         self.model = MixtralModel(config, quant_config=quant_config, prefix="model")
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
@@ -310,7 +313,7 @@
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
@@ -323,7 +326,8 @@

         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
-        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
+        expert_params_mapping = MoEImpl.make_expert_params_mapping(
             ckpt_gate_proj_name="w1",
             ckpt_down_proj_name="w2",
             ckpt_up_proj_name="w3",
@@ -340,7 +344,9 @@
                     continue
                 name = name.replace(weight_name, param_name)
                 # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
+                if (
+                    name.endswith(".bias") or name.endswith("_bias")
+                ) and name not in params_dict:
                     continue

                 param = params_dict[name]
@@ -354,6 +360,10 @@
                         continue
                     name = name.replace(weight_name, param_name)

+                    if (
+                        name.endswith(".bias") or name.endswith("_bias")
+                    ) and name not in params_dict:
+                        continue
                     param = params_dict[name]
                     weight_loader = param.weight_loader
                     weight_loader(
@@ -366,7 +376,9 @@
                     break
                 else:
                     # Skip loading extra bias for GPTQ models.
-                    if name.endswith(".bias") and name not in params_dict:
+                    if (
+                        name.endswith(".bias") or name.endswith("_bias")
+                    ) and name not in params_dict:
                         continue
                     # Skip loading kv_scale from ckpts towards new design.
                     if name.endswith(".kv_scale") and name not in params_dict:
@@ -380,7 +392,5 @@
                     )
                     weight_loader(param, loaded_weight)

-        apply_torchao_config_(self, params_dict, set(["proj.weight"]))
-

 EntryClass = MixtralForCausalLM
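The Mixtral change above introduces a switchable MoE backend: MoEImpl resolves to the new EPMoE (expert-parallel) layer when global_server_args_dict["enable_ep_moe"] is set and to the Triton FusedMoE otherwise, and the tensor-parallel all-reduce that reduce_results=True used to trigger inside the expert layer is now done explicitly in MixtralMoE.forward. A standalone illustration of the selection pattern (the two classes below are stand-ins written for this note, not the real sglang layers):

class FusedMoE:
    # Stand-in for sglang.srt.layers.fused_moe_triton.FusedMoE.
    def __init__(self, num_experts: int, top_k: int, **kwargs):
        self.kind = "fused (tensor-parallel experts)"


class EPMoE:
    # Stand-in for sglang.srt.layers.ep_moe.layer.EPMoE.
    def __init__(self, num_experts: int, top_k: int, **kwargs):
        self.kind = "expert-parallel"


server_args = {"enable_ep_moe": False}  # mirrors global_server_args_dict

MoEImpl = EPMoE if server_args["enable_ep_moe"] else FusedMoE
experts = MoEImpl(num_experts=8, top_k=2)
print(experts.kind)  # "fused (tensor-parallel experts)" unless EP-MoE is enabled

Doing the all-reduce in the caller keeps the reduction in one place regardless of which backend is selected.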
sglang/srt/models/mixtral_quant.py
CHANGED
@@ -29,7 +29,6 @@ from vllm.distributed import (
     tensor_model_parallel_all_reduce,
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -45,6 +44,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader


 class MixtralMLP(nn.Module):
@@ -324,7 +324,6 @@ class QuantMixtralForCausalLM(nn.Module):
         self,
         config: MixtralConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -343,7 +342,7 @@
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/mllama.py
CHANGED
@@ -15,7 +15,6 @@ from transformers.models.mllama.modeling_mllama import (
     _prepare_aspect_ratio_attention_mask,
 )
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.activation import get_act_fn
 from sglang.srt.layers.layernorm import RMSNorm
@@ -34,6 +33,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 )
 from sglang.srt.managers.schedule_batch import ImageInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaDecoderLayer, LlamaMLP


@@ -654,7 +654,6 @@ class MllamaTextModel(nn.Module):
         self,
         config: config_mllama.MllamaTextConfig,
         quant_config: Optional[QuantizationConfig],
-        cache_config=None,
     ):
         super().__init__()
         self.padding_id = config.pad_token_id
@@ -732,11 +731,10 @@ class MllamaForCausalLM(nn.Module):
         self,
         config: config_mllama.MllamaTextConfig,
         quant_config: Optional[QuantizationConfig],
-        cache_config=None,
     ):
         super().__init__()
         self.vocab_size = config.vocab_size
-        self.model = MllamaTextModel(config, cache_config, quant_config)
+        self.model = MllamaTextModel(config, quant_config)
         self.lm_head = ParallelLMHead(
             config.vocab_size,
             config.hidden_size,
@@ -772,7 +770,6 @@ class MllamaForConditionalGeneration(nn.Module):
         self,
         config: config_mllama.MllamaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ):
         super().__init__()
         self.vocab_size = config.text_config.vocab_size
@@ -787,7 +784,6 @@
         self.vision_model = MllamaVisionModel(config.vision_config)
         self.language_model = MllamaForCausalLM(
             config.text_config,
-            cache_config=cache_config,
             quant_config=quant_config,
         )
         self.multi_modal_projector = nn.Linear(
@@ -966,7 +962,7 @@
             skip_cross_attention=skip_cross_attention,
         )
         return self.logits_processor(
-            input_ids, hidden_states, self.language_model.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.language_model.lm_head, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/olmo.py
CHANGED
@@ -22,7 +22,6 @@ from torch import nn
 from transformers import OlmoConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.linear import (
@@ -38,6 +37,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import make_layers


@@ -274,7 +274,6 @@ class OlmoForCausalLM(nn.Module):
     def __init__(
         self,
         config: OlmoConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -306,7 +305,7 @@
             input_embeds=input_embeds,
         )
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
@@ -326,11 +325,6 @@
                 # Models trained using ColossalAI may include these tensors in
                 # the checkpoint. Skip them.
                 continue
-            # With tie_word_embeddings, we can skip lm_head.weight
-            # The weight might appear unnecessarily in the files if the model is
-            # processed with quantization, LoRA, fine-tuning, etc.
-            if self.config.tie_word_embeddings and "lm_head.weight" in name:
-                continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
sglang/srt/models/olmo2.py
CHANGED
sglang/srt/models/olmoe.py
CHANGED
@@ -34,8 +34,6 @@ from vllm.model_executor.layers.linear import (
     RowParallelLinear,
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.utils import print_warning_once

 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.fused_moe_triton import FusedMoE
@@ -48,7 +46,8 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import make_layers
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import make_layers, print_warning_once


 class OlmoeMoE(nn.Module):
@@ -300,7 +299,6 @@ class OlmoeForCausalLM(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -321,7 +319,7 @@
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/phi3_small.py
CHANGED
@@ -7,8 +7,6 @@ from transformers import Phi3Config
 from transformers.configuration_utils import PretrainedConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.models.utils import make_layers

 from sglang.srt.layers.linear import (
     MergedColumnParallelLinear,
@@ -19,14 +17,14 @@ from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorO
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.torchao_utils import apply_torchao_config_
 from sglang.srt.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE,
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import make_layers


 @torch.jit.script
@@ -235,7 +233,6 @@ class Phi3SmallDecoderLayer(nn.Module):
         self,
         config: PretrainedConfig,
         layer_id: int,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -286,7 +283,6 @@ class Phi3SmallModel(nn.Module):
         super().__init__()

         self.config = config
-        cache_config = None
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size, config.hidden_size
         )
@@ -294,7 +290,7 @@
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
             lambda prefix: Phi3SmallDecoderLayer(
-                config, int(prefix.split(".")[-1]), cache_config, quant_config
+                config, int(prefix.split(".")[-1]), quant_config
             ),
             prefix=f"{prefix}.layers",
         )
@@ -339,7 +335,6 @@ class Phi3SmallForCausalLM(nn.Module):
         self,
         config: Phi3Config,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ):

         super().__init__()
@@ -351,7 +346,6 @@
             quant_config=quant_config,
             prefix="model",
         )
-        self.torchao_config = global_server_args_dict["torchao_config"]
         self.vocab_size = config.vocab_size
         self.mup_width_multiplier = config.mup_width_multiplier
         self.lm_head = ParallelLMHead(
@@ -397,10 +391,13 @@

     def compute_logits(
         self,
+        input_ids: torch.LongTensor,
         hidden_states: torch.Tensor,
         sampling_metadata,
     ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata)
+        logits = self.logits_processor(
+            input_ids, self.lm_head, hidden_states, sampling_metadata
+        )
         if self.dummy_token_indices is not None and logits is not None:
             logits.index_fill_(-1, self.dummy_token_indices, -torch.inf)
         return logits
@@ -422,7 +419,7 @@

         if not get_embedding:
             return self.logits_processor(
-                input_ids, hidden_states, self.lm_head.weight, forward_batch
+                input_ids, hidden_states, self.lm_head, forward_batch
             )

         else:
@@ -441,7 +438,5 @@
             weight_loader = getattr(param, "weight_loader", default_weight_loader)
             weight_loader(param, loaded_weight)

-        apply_torchao_config_(self, params_dict, set(["proj.weight"]))
-

 EntryClass = Phi3SmallForCausalLM
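The Phi3Small diff above also swaps make_layers from vllm.model_executor.models.utils to sglang.srt.utils while keeping the same prefix-driven factory call, where the lambda recovers the layer index via int(prefix.split(".")[-1]). A toy sketch of a helper with that calling convention (written for this note, not sglang.srt.utils.make_layers itself):

from typing import Callable, Tuple

import torch.nn as nn


def make_layers_sketch(
    num_hidden_layers: int,
    layer_fn: Callable[[str], nn.Module],
    prefix: str = "",
) -> Tuple[int, int, nn.ModuleList]:
    # Hand each layer its dotted name (e.g. "model.layers.3") so the factory can
    # recover the index, as the lambda in the diff above does.
    layers = nn.ModuleList(
        layer_fn(f"{prefix}.{i}") for i in range(num_hidden_layers)
    )
    return 0, num_hidden_layers, layers


start, end, layers = make_layers_sketch(4, lambda p: nn.Linear(8, 8), prefix="model.layers")
print(start, end, len(layers))  # 0 4 4

The dotted prefix is what lets the factory recover the layer index without a separate counter.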
sglang/srt/models/qwen.py
CHANGED
@@ -22,7 +22,6 @@ from torch import nn
 from transformers import PretrainedConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
@@ -39,6 +38,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader


 class QWenMLP(nn.Module):
@@ -242,7 +242,6 @@ class QWenLMHeadModel(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ):
         super().__init__()
         self.config = config
@@ -260,7 +259,7 @@
     ):
         hidden_states = self.transformer(input_ids, positions, forward_batch)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|