sglang 0.5.0rc1__py3-none-any.whl → 0.5.1__py3-none-any.whl
- sglang/bench_one_batch.py +0 -7
- sglang/bench_one_batch_server.py +7 -2
- sglang/bench_serving.py +3 -3
- sglang/eval/llama3_eval.py +0 -1
- sglang/srt/configs/model_config.py +25 -9
- sglang/srt/configs/update_config.py +40 -5
- sglang/srt/constrained/xgrammar_backend.py +23 -11
- sglang/srt/conversation.py +2 -15
- sglang/srt/disaggregation/ascend/conn.py +1 -3
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +1 -2
- sglang/srt/disaggregation/launch_lb.py +7 -1
- sglang/srt/disaggregation/mini_lb.py +11 -5
- sglang/srt/disaggregation/mooncake/conn.py +141 -47
- sglang/srt/disaggregation/prefill.py +261 -5
- sglang/srt/disaggregation/utils.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/device_communicators/pynccl.py +68 -18
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
- sglang/srt/distributed/naive_distributed.py +112 -0
- sglang/srt/distributed/parallel_state.py +90 -4
- sglang/srt/entrypoints/context.py +20 -1
- sglang/srt/entrypoints/engine.py +29 -4
- sglang/srt/entrypoints/http_server.py +76 -0
- sglang/srt/entrypoints/openai/protocol.py +4 -2
- sglang/srt/entrypoints/openai/serving_chat.py +23 -6
- sglang/srt/entrypoints/openai/serving_completions.py +10 -1
- sglang/srt/entrypoints/openai/serving_responses.py +2 -2
- sglang/srt/eplb/expert_distribution.py +2 -3
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +24 -0
- sglang/srt/host_shared_memory.py +83 -0
- sglang/srt/layers/attention/ascend_backend.py +132 -22
- sglang/srt/layers/attention/flashattention_backend.py +24 -17
- sglang/srt/layers/attention/flashinfer_backend.py +14 -3
- sglang/srt/layers/attention/flashinfer_mla_backend.py +227 -76
- sglang/srt/layers/attention/triton_backend.py +109 -73
- sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
- sglang/srt/layers/attention/trtllm_mha_backend.py +398 -36
- sglang/srt/layers/attention/trtllm_mla_backend.py +49 -19
- sglang/srt/layers/attention/utils.py +94 -15
- sglang/srt/layers/attention/vision.py +40 -13
- sglang/srt/layers/attention/vision_utils.py +65 -0
- sglang/srt/layers/communicator.py +58 -10
- sglang/srt/layers/dp_attention.py +137 -27
- sglang/srt/layers/elementwise.py +94 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
- sglang/srt/layers/layernorm.py +8 -1
- sglang/srt/layers/linear.py +24 -0
- sglang/srt/layers/logits_processor.py +16 -18
- sglang/srt/layers/moe/__init__.py +31 -0
- sglang/srt/layers/moe/ep_moe/layer.py +37 -33
- sglang/srt/layers/moe/fused_moe_native.py +14 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
- sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
- sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
- sglang/srt/layers/moe/moe_runner/base.py +13 -0
- sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
- sglang/srt/layers/moe/router.py +15 -9
- sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
- sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
- sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
- sglang/srt/layers/moe/topk.py +167 -83
- sglang/srt/layers/moe/utils.py +159 -18
- sglang/srt/layers/multimodal.py +156 -40
- sglang/srt/layers/quantization/__init__.py +18 -46
- sglang/srt/layers/quantization/awq.py +22 -23
- sglang/srt/layers/quantization/base_config.py +2 -6
- sglang/srt/layers/quantization/blockwise_int8.py +4 -12
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -29
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
- sglang/srt/layers/quantization/fp8.py +127 -119
- sglang/srt/layers/quantization/fp8_kernel.py +195 -24
- sglang/srt/layers/quantization/fp8_utils.py +34 -9
- sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
- sglang/srt/layers/quantization/gptq.py +17 -21
- sglang/srt/layers/quantization/marlin_utils.py +26 -8
- sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
- sglang/srt/layers/quantization/modelopt_quant.py +217 -98
- sglang/srt/layers/quantization/moe_wna16.py +10 -15
- sglang/srt/layers/quantization/mxfp4.py +222 -39
- sglang/srt/layers/quantization/quark/quark.py +390 -0
- sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
- sglang/srt/layers/quantization/unquant.py +34 -70
- sglang/srt/layers/quantization/utils.py +77 -2
- sglang/srt/layers/quantization/w4afp8.py +7 -8
- sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
- sglang/srt/layers/quantization/w8a8_int8.py +5 -13
- sglang/srt/layers/radix_attention.py +6 -0
- sglang/srt/layers/rotary_embedding.py +1 -0
- sglang/srt/layers/sampler.py +5 -2
- sglang/srt/lora/layers.py +6 -2
- sglang/srt/lora/lora_manager.py +21 -22
- sglang/srt/lora/lora_registry.py +3 -3
- sglang/srt/lora/mem_pool.py +26 -24
- sglang/srt/lora/utils.py +10 -12
- sglang/srt/managers/cache_controller.py +80 -19
- sglang/srt/managers/detokenizer_manager.py +10 -2
- sglang/srt/managers/io_struct.py +23 -0
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/schedule_batch.py +22 -48
- sglang/srt/managers/scheduler.py +28 -20
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/template_manager.py +7 -5
- sglang/srt/managers/tokenizer_manager.py +88 -39
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/managers/utils.py +59 -1
- sglang/srt/mem_cache/allocator.py +10 -157
- sglang/srt/mem_cache/allocator_ascend.py +147 -0
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +14 -4
- sglang/srt/mem_cache/memory_pool.py +3 -3
- sglang/srt/mem_cache/memory_pool_host.py +35 -2
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
- sglang/srt/model_executor/cuda_graph_runner.py +33 -33
- sglang/srt/model_executor/forward_batch_info.py +11 -10
- sglang/srt/model_executor/model_runner.py +93 -78
- sglang/srt/model_executor/npu_graph_runner.py +94 -0
- sglang/srt/model_loader/loader.py +24 -6
- sglang/srt/models/dbrx.py +12 -6
- sglang/srt/models/deepseek.py +2 -1
- sglang/srt/models/deepseek_nextn.py +5 -2
- sglang/srt/models/deepseek_v2.py +226 -223
- sglang/srt/models/ernie4.py +2 -2
- sglang/srt/models/glm4_moe.py +27 -65
- sglang/srt/models/glm4_moe_nextn.py +2 -1
- sglang/srt/models/glm4v.py +52 -1
- sglang/srt/models/glm4v_moe.py +8 -11
- sglang/srt/models/gpt_oss.py +41 -76
- sglang/srt/models/granitemoe.py +0 -1
- sglang/srt/models/grok.py +376 -48
- sglang/srt/models/interns1.py +12 -47
- sglang/srt/models/internvl.py +6 -51
- sglang/srt/models/llama.py +10 -2
- sglang/srt/models/llama4.py +18 -7
- sglang/srt/models/minicpm3.py +0 -1
- sglang/srt/models/mixtral.py +0 -2
- sglang/srt/models/nemotron_nas.py +435 -0
- sglang/srt/models/olmoe.py +0 -1
- sglang/srt/models/phi4mm.py +3 -21
- sglang/srt/models/qwen2.py +2 -2
- sglang/srt/models/qwen2_5_vl.py +2 -0
- sglang/srt/models/qwen2_moe.py +23 -23
- sglang/srt/models/qwen3.py +2 -2
- sglang/srt/models/qwen3_classification.py +84 -0
- sglang/srt/models/qwen3_moe.py +27 -43
- sglang/srt/models/step3_vl.py +8 -3
- sglang/srt/models/xverse_moe.py +11 -5
- sglang/srt/multimodal/processors/base_processor.py +3 -3
- sglang/srt/multimodal/processors/internvl.py +7 -2
- sglang/srt/multimodal/processors/llava.py +11 -7
- sglang/srt/offloader.py +433 -0
- sglang/srt/operations.py +22 -2
- sglang/srt/reasoning_parser.py +4 -3
- sglang/srt/sampling/sampling_batch_info.py +7 -4
- sglang/srt/server_args.py +264 -105
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +8 -21
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
- sglang/srt/speculative/eagle_utils.py +36 -13
- sglang/srt/speculative/eagle_worker.py +56 -3
- sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
- sglang/srt/two_batch_overlap.py +20 -19
- sglang/srt/utils.py +68 -70
- sglang/test/runners.py +8 -5
- sglang/test/test_block_fp8.py +5 -6
- sglang/test/test_block_fp8_ep.py +13 -19
- sglang/test/test_cutlass_moe.py +4 -6
- sglang/test/test_cutlass_w4a8_moe.py +4 -3
- sglang/test/test_fp4_moe.py +4 -3
- sglang/test/test_marlin_moe.py +1 -1
- sglang/test/test_marlin_utils.py +1 -1
- sglang/test/test_utils.py +7 -0
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/METADATA +11 -11
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/RECORD +201 -171
- sglang/srt/layers/quantization/fp4.py +0 -557
- sglang/srt/layers/quantization/scalar_type.py +0 -352
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
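
The version bump itself is the one-line change in sglang/version.py. After installing the new wheel, the upgrade can be sanity-checked from Python; a minimal sketch, reading the version string straight from the module that this release touches:

    from sglang.version import __version__

    # sglang/version.py is the +1/-1 change listed above; this should print
    # "0.5.1" once the new wheel is installed.
    print(__version__)
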
sglang/srt/models/interns1.py
CHANGED
@@ -4,8 +4,9 @@ import torch
 from torch import nn
 from transformers import PretrainedConfig
 
-from sglang.srt.
+from sglang.srt.layers.attention import vision_utils
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
@@ -20,6 +21,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.internvl import InternVisionModel
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
+from sglang.srt.models.qwen3 import Qwen3ForCausalLM
 from sglang.srt.models.qwen3_moe import Qwen3MoeForCausalLM
 from sglang.utils import logger
 
@@ -34,7 +36,7 @@ class InternS1ForConditionalGeneration(nn.Module):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.
+        vision_utils.update_vit_attn_dummy_heads_config(self.config)
         image_size = (
             getattr(config, "force_image_size", None) or config.vision_config.image_size
         )
@@ -69,6 +71,10 @@ class InternS1ForConditionalGeneration(nn.Module):
             self.language_model = Qwen3MoeForCausalLM(
                 config=config.text_config, quant_config=quant_config
             )
+        elif config.text_config.architectures[0] == "Qwen3ForCausalLM":
+            self.language_model = Qwen3ForCausalLM(
+                config=config.text_config, quant_config=quant_config
+            )
         else:
             raise NotImplementedError(
                 f"{config.text_config.architectures[0]} is not implemented."
@@ -86,21 +92,6 @@ class InternS1ForConditionalGeneration(nn.Module):
             nn.Linear(llm_hidden_size, llm_hidden_size),
         )
 
-    def _update_hf_config(self):
-        """update hf config to support tp"""
-        world_size = parallel_state.get_tensor_model_parallel_world_size()
-        num_heads = self.config.vision_config.num_attention_heads
-        head_dim = self.config.vision_config.hidden_size // num_heads
-        num_dummy_heads = 0
-
-        if num_heads % world_size != 0:
-            num_dummy_heads = (
-                (num_heads + world_size) // world_size
-            ) * world_size - num_heads
-
-        setattr(self.config.vision_config, "head_dim", head_dim)
-        setattr(self.config.vision_config, "num_dummy_heads", num_dummy_heads)
-
     def pixel_shuffle(self, x, scale_factor=0.5):
         n, w, h, c = x.size()
         # N, W, H, C --> N, W, H * scale, C // scale
@@ -183,34 +174,6 @@ class InternS1ForConditionalGeneration(nn.Module):
 
         return helper.pad_input_tokens(input_ids, mm_inputs)
 
-    def _pad_vit_attn_dummy_heads(self, name: str, loaded_weight: torch.Tensor):
-        """pad attn qkv weights for dummy heads"""
-        num_dummy_heads = self.config.vision_config.num_dummy_heads
-        if num_dummy_heads == 0:
-            return loaded_weight
-        head_dim = self.config.vision_config.head_dim
-
-        if any([_ in name for _ in ["attn.q_proj", "attn.k_proj", "attn.v_proj"]]):
-            if name.endswith(".weight"):
-                dummy_shape = [num_dummy_heads, head_dim, loaded_weight.shape[-1]]
-            elif name.endswith(".bias"):
-                dummy_shape = [num_dummy_heads, head_dim]
-            else:
-                raise RuntimeError(f"Unsupported weight with name={name}")
-            padded_weight = loaded_weight.new_zeros(dummy_shape)
-            loaded_weight = torch.cat(
-                [loaded_weight.unflatten(0, (-1, head_dim)), padded_weight], dim=0
-            ).flatten(0, 1)
-        if "attn.proj.weight" in name:
-            padded_weight = loaded_weight.new_zeros(
-                loaded_weight.shape[0], head_dim * num_dummy_heads
-            )
-            loaded_weight = torch.cat([loaded_weight, padded_weight], dim=-1)
-        if "attn.q_norm.weight" in name or "attn.k_norm.weight" in name:
-            padded_weight = loaded_weight.new_zeros(head_dim * num_dummy_heads)
-            loaded_weight = torch.cat([loaded_weight, padded_weight], dim=0)
-        return loaded_weight
-
     def _mapping_interns1_name(self, name):
         names_map = {
             "lm_head.weight": "language_model.lm_head.weight",
@@ -254,7 +217,7 @@ class InternS1ForConditionalGeneration(nn.Module):
         ]
         expert_params_mapping = []
         if "Qwen3MoeForCausalLM" in self.config.text_config.architectures:
-            expert_params_mapping =
+            expert_params_mapping = FusedMoE.make_expert_params_mapping(
                 ckpt_gate_proj_name="gate_proj",
                 ckpt_down_proj_name="down_proj",
                 ckpt_up_proj_name="up_proj",
@@ -269,7 +232,9 @@ class InternS1ForConditionalGeneration(nn.Module):
                 continue
             name = self._mapping_interns1_name(name)
             if "vision_model" in name:
-                loaded_weight =
+                loaded_weight = vision_utils.pad_vit_attn_dummy_heads(
+                    self.config, name, loaded_weight
+                )
 
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
sglang/srt/models/internvl.py
CHANGED
@@ -10,9 +10,9 @@ from transformers import PretrainedConfig, PreTrainedModel
 from transformers.activations import ACT2FN
 from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
 
-from sglang.srt.
+from sglang.srt.layers.attention import vision_utils
 from sglang.srt.layers.attention.vision import SingletonCache, VisionAttention
-from sglang.srt.layers.moe.
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
@@ -412,7 +412,7 @@ class InternVLChatModel(nn.Module):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.
+        vision_utils.update_vit_attn_dummy_heads_config(self.config)
         image_size = config.force_image_size or config.vision_config.image_size
         patch_size = config.vision_config.patch_size
         self.patch_size = patch_size
@@ -462,21 +462,6 @@ class InternVLChatModel(nn.Module):
             nn.Linear(llm_hidden_size, llm_hidden_size),
         )
 
-    def _update_vision_config(self):
-        """update vision config to support tp"""
-        world_size = parallel_state.get_tensor_model_parallel_world_size()
-        num_heads = self.config.vision_config.num_attention_heads
-        head_dim = self.config.vision_config.hidden_size // num_heads
-        num_dummy_heads = 0
-
-        if num_heads % world_size != 0:
-            num_dummy_heads = (
-                (num_heads + world_size) // world_size
-            ) * world_size - num_heads
-
-        setattr(self.config.vision_config, "head_dim", head_dim)
-        setattr(self.config.vision_config, "num_dummy_heads", num_dummy_heads)
-
     def pixel_shuffle(self, x, scale_factor=0.5):
         n, w, h, c = x.size()
         # N, W, H, C --> N, W, H * scale, C // scale
@@ -559,36 +544,6 @@ class InternVLChatModel(nn.Module):
 
         return helper.pad_input_tokens(input_ids, mm_inputs)
 
-    def _pad_vit_attn_dummy_heads(self, name: str, loaded_weight: torch.Tensor):
-        """pad attn qkv weights for dummy heads"""
-        num_dummy_heads = self.config.vision_config.num_dummy_heads
-        if num_dummy_heads == 0:
-            return loaded_weight
-        head_dim = self.config.vision_config.head_dim
-
-        if "attn.qkv_proj" in name:
-            wq, wk, wv = loaded_weight.chunk(3, dim=0)
-            if name.endswith(".weight"):
-                dummy_shape = [num_dummy_heads, head_dim, wq.shape[-1]]
-            elif name.endswith(".bias"):
-                dummy_shape = [num_dummy_heads, head_dim]
-            else:
-                raise RuntimeError(f"Unsupported weight with name={name}")
-            pad_func = lambda x: torch.cat(
-                [x.unflatten(0, (-1, head_dim)), x.new_zeros(dummy_shape)], dim=0
-            ).flatten(0, 1)
-            wq, wk, wv = pad_func(wq), pad_func(wk), pad_func(wv)
-            loaded_weight = torch.cat([wq, wk, wv], dim=0)
-        if "attn.proj.weight" in name:
-            padded_weight = loaded_weight.new_zeros(
-                loaded_weight.shape[0], head_dim * num_dummy_heads
-            )
-            loaded_weight = torch.cat([loaded_weight, padded_weight], dim=-1)
-        if "attn.q_norm.weight" in name or "attn.k_norm.weight" in name:
-            padded_weight = loaded_weight.new_zeros(head_dim * num_dummy_heads)
-            loaded_weight = torch.cat([loaded_weight, padded_weight], dim=0)
-        return loaded_weight
-
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         expert_params_mapping = []
         if "InternLM2ForCausalLM" in self.config.llm_config.architectures:
@@ -616,7 +571,7 @@ class InternVLChatModel(nn.Module):
                 ("gate_up_proj", "up_proj", 1),
             ]
 
-            expert_params_mapping =
+            expert_params_mapping = FusedMoE.make_expert_params_mapping(
                 ckpt_gate_proj_name="gate_proj",
                 ckpt_down_proj_name="down_proj",
                 ckpt_up_proj_name="up_proj",
@@ -699,8 +654,8 @@ class InternVLChatModel(nn.Module):
                     param, "weight_loader", default_weight_loader
                 )
                 if "vision_model" in name:
-                    loaded_weight =
-                        name, loaded_weight
+                    loaded_weight = vision_utils.pad_vit_attn_dummy_heads(
+                        self.config, name, loaded_weight
                     )
                 weight_loader(param, loaded_weight)
 
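
The deleted _pad_vit_attn_dummy_heads shows the weight-side counterpart of that config change: the fused qkv_proj checkpoint tensor is split into q/k/v, each part is zero-padded with the dummy heads, and the parts are re-fused before sharding. A small self-contained sketch of the weight padding step (toy shapes and a hypothetical pad_heads helper; the consolidated logic now lives in vision_utils.pad_vit_attn_dummy_heads):

    import torch

    def pad_heads(w: torch.Tensor, head_dim: int, num_dummy: int) -> torch.Tensor:
        # Append zero-valued dummy heads along dim 0 of a
        # (num_heads * head_dim, in_dim) projection weight.
        dummy = w.new_zeros(num_dummy, head_dim, w.shape[-1])
        return torch.cat([w.unflatten(0, (-1, head_dim)), dummy], dim=0).flatten(0, 1)

    qkv = torch.randn(3 * 25 * 128, 3200)  # fused q/k/v weight: 25 heads, head_dim 128
    wq, wk, wv = qkv.chunk(3, dim=0)        # split, pad each part, then re-fuse
    padded = torch.cat([pad_heads(x, 128, 7) for x in (wq, wk, wv)], dim=0)
    assert padded.shape == (3 * 32 * 128, 3200)
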
sglang/srt/models/llama.py
CHANGED
@@ -91,10 +91,18 @@ class LlamaMLP(nn.Module):
         )
         self.act_fn = SiluAndMul()
 
-    def forward(
+    def forward(
+        self,
+        x,
+        forward_batch=None,
+        use_reduce_scatter: bool = False,
+    ):
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
-        x, _ = self.down_proj(
+        x, _ = self.down_proj(
+            x,
+            skip_all_reduce=use_reduce_scatter,
+        )
         return x
 
 
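
LlamaMLP.forward now accepts a use_reduce_scatter flag and forwards it as skip_all_reduce to the row-parallel down_proj, so the per-layer all-reduce can be skipped when the layer communicator will reduce-scatter the output instead. The identity being relied on, that an all-reduce equals a reduce-scatter followed by an all-gather, can be checked on plain tensors; a toy single-process sketch (not the NCCL code path):

    import torch

    world_size, hidden = 4, 8
    # One partial MLP output per TP rank, i.e. what down_proj holds pre-reduction.
    partials = [torch.randn(hidden) for _ in range(world_size)]

    all_reduced = torch.stack(partials).sum(dim=0)

    # Reduce-scatter: rank r reduces only its 1/world_size shard ...
    shards = [
        torch.stack([p.chunk(world_size)[r] for p in partials]).sum(dim=0)
        for r in range(world_size)
    ]
    # ... and all-gathering the shards reproduces the all-reduced tensor.
    assert torch.allclose(torch.cat(shards), all_reduced)
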
sglang/srt/models/llama4.py
CHANGED
@@ -31,7 +31,7 @@ from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
 from sglang.srt.layers.dp_attention import (
     get_attention_tp_rank,
     get_attention_tp_size,
-
+    is_dp_attention_enabled,
 )
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -45,7 +45,6 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import (
     ForwardBatch,
     ForwardMode,
@@ -131,14 +130,19 @@ class Llama4MoE(nn.Module):
             reduce_results=False, # We need to do scatter before reduce
         )
 
-    def forward(
+    def forward(
+        self,
+        hidden_states,
+        forward_batch: ForwardBatch,
+        use_reduce_scatter: bool = False,
+    ):
         shared_out, routed_out = self._forward_core(
             hidden_states, forward_batch.forward_mode
         )
 
         out_aD = routed_out + shared_out
 
-        if self.tp_size > 1:
+        if self.tp_size > 1 and not use_reduce_scatter:
             out_aD = tensor_model_parallel_all_reduce(out_aD)
 
         return out_aD
@@ -359,7 +363,6 @@ class Llama4DecoderLayer(nn.Module):
         rope_theta = config.rope_theta
         rope_scaling = config.rope_scaling
         max_position_embeddings = config.max_position_embeddings
-        self.local_dp_size = get_local_attention_dp_size()
         self.attn_tp_size = get_attention_tp_size()
         self.attn_tp_rank = get_attention_tp_rank()
 
@@ -412,6 +415,7 @@ class Llama4DecoderLayer(nn.Module):
             layer_scatter_modes=self.layer_scatter_modes,
             input_layernorm=self.input_layernorm,
             post_attention_layernorm=self.post_attention_layernorm,
+            allow_reduce_scatter=True,
         )
 
     def _is_moe_layer(self, layer_id: int) -> bool:
@@ -441,8 +445,15 @@ class Llama4DecoderLayer(nn.Module):
             hidden_states, residual, forward_batch
         )
 
+        # For DP with padding, reduce scatter can be used instead of all-reduce.
+        use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
+            forward_batch
+        )
+
         # Fully Connected
-        hidden_states = self.feed_forward(
+        hidden_states = self.feed_forward(
+            hidden_states, forward_batch, use_reduce_scatter
+        )
         hidden_states, residual = self.layer_communicator.postprocess_layer(
             hidden_states, residual, forward_batch
         )
@@ -466,7 +477,7 @@ class Llama4Model(nn.Module):
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("embed_tokens", prefix),
-            enable_tp=not
+            enable_tp=not is_dp_attention_enabled(),
         )
         self.layers = make_layers(
             config.num_hidden_layers,
sglang/srt/models/minicpm3.py
CHANGED
@@ -37,7 +37,6 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import add_prefix, is_cuda
sglang/srt/models/mixtral.py
CHANGED
@@ -47,7 +47,6 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import add_prefix, make_layers
@@ -104,7 +103,6 @@ class MixtralMoE(nn.Module):
             intermediate_size=intermediate_size,
             params_dtype=params_dtype,
             quant_config=quant_config,
-            tp_size=tp_size,
             prefix=add_prefix("experts", prefix),
         )
 