sglang 0.5.0rc1__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (203)
  1. sglang/bench_one_batch.py +0 -7
  2. sglang/bench_one_batch_server.py +7 -2
  3. sglang/bench_serving.py +3 -3
  4. sglang/eval/llama3_eval.py +0 -1
  5. sglang/srt/configs/model_config.py +25 -9
  6. sglang/srt/configs/update_config.py +40 -5
  7. sglang/srt/constrained/xgrammar_backend.py +23 -11
  8. sglang/srt/conversation.py +2 -15
  9. sglang/srt/disaggregation/ascend/conn.py +1 -3
  10. sglang/srt/disaggregation/base/conn.py +1 -0
  11. sglang/srt/disaggregation/decode.py +1 -2
  12. sglang/srt/disaggregation/launch_lb.py +7 -1
  13. sglang/srt/disaggregation/mini_lb.py +11 -5
  14. sglang/srt/disaggregation/mooncake/conn.py +141 -47
  15. sglang/srt/disaggregation/prefill.py +261 -5
  16. sglang/srt/disaggregation/utils.py +2 -1
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  18. sglang/srt/distributed/device_communicators/pynccl.py +68 -18
  19. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
  20. sglang/srt/distributed/naive_distributed.py +112 -0
  21. sglang/srt/distributed/parallel_state.py +90 -4
  22. sglang/srt/entrypoints/context.py +20 -1
  23. sglang/srt/entrypoints/engine.py +29 -4
  24. sglang/srt/entrypoints/http_server.py +76 -0
  25. sglang/srt/entrypoints/openai/protocol.py +4 -2
  26. sglang/srt/entrypoints/openai/serving_chat.py +23 -6
  27. sglang/srt/entrypoints/openai/serving_completions.py +10 -1
  28. sglang/srt/entrypoints/openai/serving_responses.py +2 -2
  29. sglang/srt/eplb/expert_distribution.py +2 -3
  30. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  31. sglang/srt/hf_transformers_utils.py +24 -0
  32. sglang/srt/host_shared_memory.py +83 -0
  33. sglang/srt/layers/attention/ascend_backend.py +132 -22
  34. sglang/srt/layers/attention/flashattention_backend.py +24 -17
  35. sglang/srt/layers/attention/flashinfer_backend.py +14 -3
  36. sglang/srt/layers/attention/flashinfer_mla_backend.py +227 -76
  37. sglang/srt/layers/attention/triton_backend.py +109 -73
  38. sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
  39. sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
  40. sglang/srt/layers/attention/trtllm_mha_backend.py +398 -36
  41. sglang/srt/layers/attention/trtllm_mla_backend.py +49 -19
  42. sglang/srt/layers/attention/utils.py +94 -15
  43. sglang/srt/layers/attention/vision.py +40 -13
  44. sglang/srt/layers/attention/vision_utils.py +65 -0
  45. sglang/srt/layers/communicator.py +58 -10
  46. sglang/srt/layers/dp_attention.py +137 -27
  47. sglang/srt/layers/elementwise.py +94 -0
  48. sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
  49. sglang/srt/layers/layernorm.py +8 -1
  50. sglang/srt/layers/linear.py +24 -0
  51. sglang/srt/layers/logits_processor.py +16 -18
  52. sglang/srt/layers/moe/__init__.py +31 -0
  53. sglang/srt/layers/moe/ep_moe/layer.py +37 -33
  54. sglang/srt/layers/moe/fused_moe_native.py +14 -25
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  63. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  64. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  65. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  68. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  69. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  70. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
  71. sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
  72. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
  73. sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
  74. sglang/srt/layers/moe/moe_runner/base.py +13 -0
  75. sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
  76. sglang/srt/layers/moe/router.py +15 -9
  77. sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
  78. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
  79. sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
  80. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  81. sglang/srt/layers/moe/topk.py +167 -83
  82. sglang/srt/layers/moe/utils.py +159 -18
  83. sglang/srt/layers/multimodal.py +156 -40
  84. sglang/srt/layers/quantization/__init__.py +18 -46
  85. sglang/srt/layers/quantization/awq.py +22 -23
  86. sglang/srt/layers/quantization/base_config.py +2 -6
  87. sglang/srt/layers/quantization/blockwise_int8.py +4 -12
  88. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -29
  89. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
  90. sglang/srt/layers/quantization/fp8.py +127 -119
  91. sglang/srt/layers/quantization/fp8_kernel.py +195 -24
  92. sglang/srt/layers/quantization/fp8_utils.py +34 -9
  93. sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
  94. sglang/srt/layers/quantization/gptq.py +17 -21
  95. sglang/srt/layers/quantization/marlin_utils.py +26 -8
  96. sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
  97. sglang/srt/layers/quantization/modelopt_quant.py +217 -98
  98. sglang/srt/layers/quantization/moe_wna16.py +10 -15
  99. sglang/srt/layers/quantization/mxfp4.py +222 -39
  100. sglang/srt/layers/quantization/quark/quark.py +390 -0
  101. sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
  102. sglang/srt/layers/quantization/unquant.py +34 -70
  103. sglang/srt/layers/quantization/utils.py +77 -2
  104. sglang/srt/layers/quantization/w4afp8.py +7 -8
  105. sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
  106. sglang/srt/layers/quantization/w8a8_int8.py +5 -13
  107. sglang/srt/layers/radix_attention.py +6 -0
  108. sglang/srt/layers/rotary_embedding.py +1 -0
  109. sglang/srt/layers/sampler.py +5 -2
  110. sglang/srt/lora/layers.py +6 -2
  111. sglang/srt/lora/lora_manager.py +21 -22
  112. sglang/srt/lora/lora_registry.py +3 -3
  113. sglang/srt/lora/mem_pool.py +26 -24
  114. sglang/srt/lora/utils.py +10 -12
  115. sglang/srt/managers/cache_controller.py +80 -19
  116. sglang/srt/managers/detokenizer_manager.py +10 -2
  117. sglang/srt/managers/io_struct.py +23 -0
  118. sglang/srt/managers/mm_utils.py +1 -1
  119. sglang/srt/managers/schedule_batch.py +22 -48
  120. sglang/srt/managers/scheduler.py +28 -20
  121. sglang/srt/managers/session_controller.py +1 -1
  122. sglang/srt/managers/template_manager.py +7 -5
  123. sglang/srt/managers/tokenizer_manager.py +88 -39
  124. sglang/srt/managers/tp_worker.py +1 -0
  125. sglang/srt/managers/utils.py +59 -1
  126. sglang/srt/mem_cache/allocator.py +10 -157
  127. sglang/srt/mem_cache/allocator_ascend.py +147 -0
  128. sglang/srt/mem_cache/chunk_cache.py +1 -1
  129. sglang/srt/mem_cache/hicache_storage.py +14 -4
  130. sglang/srt/mem_cache/memory_pool.py +3 -3
  131. sglang/srt/mem_cache/memory_pool_host.py +35 -2
  132. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
  133. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
  134. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
  135. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
  136. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
  137. sglang/srt/model_executor/cuda_graph_runner.py +33 -33
  138. sglang/srt/model_executor/forward_batch_info.py +11 -10
  139. sglang/srt/model_executor/model_runner.py +93 -78
  140. sglang/srt/model_executor/npu_graph_runner.py +94 -0
  141. sglang/srt/model_loader/loader.py +24 -6
  142. sglang/srt/models/dbrx.py +12 -6
  143. sglang/srt/models/deepseek.py +2 -1
  144. sglang/srt/models/deepseek_nextn.py +5 -2
  145. sglang/srt/models/deepseek_v2.py +226 -223
  146. sglang/srt/models/ernie4.py +2 -2
  147. sglang/srt/models/glm4_moe.py +27 -65
  148. sglang/srt/models/glm4_moe_nextn.py +2 -1
  149. sglang/srt/models/glm4v.py +52 -1
  150. sglang/srt/models/glm4v_moe.py +8 -11
  151. sglang/srt/models/gpt_oss.py +41 -76
  152. sglang/srt/models/granitemoe.py +0 -1
  153. sglang/srt/models/grok.py +376 -48
  154. sglang/srt/models/interns1.py +12 -47
  155. sglang/srt/models/internvl.py +6 -51
  156. sglang/srt/models/llama.py +10 -2
  157. sglang/srt/models/llama4.py +18 -7
  158. sglang/srt/models/minicpm3.py +0 -1
  159. sglang/srt/models/mixtral.py +0 -2
  160. sglang/srt/models/nemotron_nas.py +435 -0
  161. sglang/srt/models/olmoe.py +0 -1
  162. sglang/srt/models/phi4mm.py +3 -21
  163. sglang/srt/models/qwen2.py +2 -2
  164. sglang/srt/models/qwen2_5_vl.py +2 -0
  165. sglang/srt/models/qwen2_moe.py +23 -23
  166. sglang/srt/models/qwen3.py +2 -2
  167. sglang/srt/models/qwen3_classification.py +84 -0
  168. sglang/srt/models/qwen3_moe.py +27 -43
  169. sglang/srt/models/step3_vl.py +8 -3
  170. sglang/srt/models/xverse_moe.py +11 -5
  171. sglang/srt/multimodal/processors/base_processor.py +3 -3
  172. sglang/srt/multimodal/processors/internvl.py +7 -2
  173. sglang/srt/multimodal/processors/llava.py +11 -7
  174. sglang/srt/offloader.py +433 -0
  175. sglang/srt/operations.py +22 -2
  176. sglang/srt/reasoning_parser.py +4 -3
  177. sglang/srt/sampling/sampling_batch_info.py +7 -4
  178. sglang/srt/server_args.py +264 -105
  179. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +8 -21
  180. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
  181. sglang/srt/speculative/eagle_utils.py +36 -13
  182. sglang/srt/speculative/eagle_worker.py +56 -3
  183. sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
  184. sglang/srt/two_batch_overlap.py +20 -19
  185. sglang/srt/utils.py +68 -70
  186. sglang/test/runners.py +8 -5
  187. sglang/test/test_block_fp8.py +5 -6
  188. sglang/test/test_block_fp8_ep.py +13 -19
  189. sglang/test/test_cutlass_moe.py +4 -6
  190. sglang/test/test_cutlass_w4a8_moe.py +4 -3
  191. sglang/test/test_fp4_moe.py +4 -3
  192. sglang/test/test_marlin_moe.py +1 -1
  193. sglang/test/test_marlin_utils.py +1 -1
  194. sglang/test/test_utils.py +7 -0
  195. sglang/utils.py +0 -1
  196. sglang/version.py +1 -1
  197. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/METADATA +11 -11
  198. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/RECORD +201 -171
  199. sglang/srt/layers/quantization/fp4.py +0 -557
  200. sglang/srt/layers/quantization/scalar_type.py +0 -352
  201. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
  202. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
  203. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
@@ -4,8 +4,9 @@ import torch
 from torch import nn
 from transformers import PretrainedConfig

-from sglang.srt.distributed import parallel_state
+from sglang.srt.layers.attention import vision_utils
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
@@ -20,6 +21,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.internvl import InternVisionModel
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
+from sglang.srt.models.qwen3 import Qwen3ForCausalLM
 from sglang.srt.models.qwen3_moe import Qwen3MoeForCausalLM
 from sglang.utils import logger

@@ -34,7 +36,7 @@ class InternS1ForConditionalGeneration(nn.Module):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self._update_hf_config()
+        vision_utils.update_vit_attn_dummy_heads_config(self.config)
         image_size = (
             getattr(config, "force_image_size", None) or config.vision_config.image_size
         )
@@ -69,6 +71,10 @@ class InternS1ForConditionalGeneration(nn.Module):
             self.language_model = Qwen3MoeForCausalLM(
                 config=config.text_config, quant_config=quant_config
             )
+        elif config.text_config.architectures[0] == "Qwen3ForCausalLM":
+            self.language_model = Qwen3ForCausalLM(
+                config=config.text_config, quant_config=quant_config
+            )
         else:
             raise NotImplementedError(
                 f"{config.text_config.architectures[0]} is not implemented."
@@ -86,21 +92,6 @@ class InternS1ForConditionalGeneration(nn.Module):
             nn.Linear(llm_hidden_size, llm_hidden_size),
         )

-    def _update_hf_config(self):
-        """update hf config to support tp"""
-        world_size = parallel_state.get_tensor_model_parallel_world_size()
-        num_heads = self.config.vision_config.num_attention_heads
-        head_dim = self.config.vision_config.hidden_size // num_heads
-        num_dummy_heads = 0
-
-        if num_heads % world_size != 0:
-            num_dummy_heads = (
-                (num_heads + world_size) // world_size
-            ) * world_size - num_heads
-
-        setattr(self.config.vision_config, "head_dim", head_dim)
-        setattr(self.config.vision_config, "num_dummy_heads", num_dummy_heads)
-
     def pixel_shuffle(self, x, scale_factor=0.5):
         n, w, h, c = x.size()
         # N, W, H, C --> N, W, H * scale, C // scale
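The hunks above (sglang/srt/models/interns1.py, going by the InternS1ForConditionalGeneration context) drop the per-model _update_hf_config helper in favor of a shared vision_utils.update_vit_attn_dummy_heads_config(self.config) call; the same swap appears for InternVL further down, and the new shared module shows up in the file table as sglang/srt/layers/attention/vision_utils.py (+65 lines). Only the call site is visible in this diff, so the following is a sketch of what the shared helper presumably does, reconstructed from the removed method:

# Sketch only: reconstructed from the removed InternS1 `_update_hf_config`; the real
# helper lives in sglang/srt/layers/attention/vision_utils.py and may differ in detail.
from sglang.srt.distributed import parallel_state


def update_vit_attn_dummy_heads_config(config):
    """Pad the ViT head count so it splits evenly across tensor-parallel ranks."""
    world_size = parallel_state.get_tensor_model_parallel_world_size()
    num_heads = config.vision_config.num_attention_heads
    head_dim = config.vision_config.hidden_size // num_heads
    num_dummy_heads = 0

    if num_heads % world_size != 0:
        # Round up to the next multiple of world_size, e.g. 25 heads on TP=8 -> 7 dummy heads.
        num_dummy_heads = ((num_heads + world_size) // world_size) * world_size - num_heads

    config.vision_config.head_dim = head_dim
    config.vision_config.num_dummy_heads = num_dummy_heads

For example, a 25-head ViT on tensor-parallel size 8 gets 7 dummy heads so the padded 32 heads divide evenly across ranks; the companion pad_vit_attn_dummy_heads(config, name, loaded_weight) call seen in the later hunks zero-pads the q/k/v projections, the output projection, and the q/k norms to match, mirroring the removed _pad_vit_attn_dummy_heads methods.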
@@ -183,34 +174,6 @@ class InternS1ForConditionalGeneration(nn.Module):

         return helper.pad_input_tokens(input_ids, mm_inputs)

-    def _pad_vit_attn_dummy_heads(self, name: str, loaded_weight: torch.Tensor):
-        """pad attn qkv weights for dummy heads"""
-        num_dummy_heads = self.config.vision_config.num_dummy_heads
-        if num_dummy_heads == 0:
-            return loaded_weight
-        head_dim = self.config.vision_config.head_dim
-
-        if any([_ in name for _ in ["attn.q_proj", "attn.k_proj", "attn.v_proj"]]):
-            if name.endswith(".weight"):
-                dummy_shape = [num_dummy_heads, head_dim, loaded_weight.shape[-1]]
-            elif name.endswith(".bias"):
-                dummy_shape = [num_dummy_heads, head_dim]
-            else:
-                raise RuntimeError(f"Unsupported weight with name={name}")
-            padded_weight = loaded_weight.new_zeros(dummy_shape)
-            loaded_weight = torch.cat(
-                [loaded_weight.unflatten(0, (-1, head_dim)), padded_weight], dim=0
-            ).flatten(0, 1)
-        if "attn.proj.weight" in name:
-            padded_weight = loaded_weight.new_zeros(
-                loaded_weight.shape[0], head_dim * num_dummy_heads
-            )
-            loaded_weight = torch.cat([loaded_weight, padded_weight], dim=-1)
-        if "attn.q_norm.weight" in name or "attn.k_norm.weight" in name:
-            padded_weight = loaded_weight.new_zeros(head_dim * num_dummy_heads)
-            loaded_weight = torch.cat([loaded_weight, padded_weight], dim=0)
-        return loaded_weight
-
     def _mapping_interns1_name(self, name):
         names_map = {
             "lm_head.weight": "language_model.lm_head.weight",
@@ -254,7 +217,7 @@ class InternS1ForConditionalGeneration(nn.Module):
         ]
         expert_params_mapping = []
         if "Qwen3MoeForCausalLM" in self.config.text_config.architectures:
-            expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
+            expert_params_mapping = FusedMoE.make_expert_params_mapping(
                 ckpt_gate_proj_name="gate_proj",
                 ckpt_down_proj_name="down_proj",
                 ckpt_up_proj_name="up_proj",
@@ -269,7 +232,9 @@ class InternS1ForConditionalGeneration(nn.Module):
                 continue
             name = self._mapping_interns1_name(name)
             if "vision_model" in name:
-                loaded_weight = self._pad_vit_attn_dummy_heads(name, loaded_weight)
+                loaded_weight = vision_utils.pad_vit_attn_dummy_heads(
+                    self.config, name, loaded_weight
+                )

             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
@@ -10,9 +10,9 @@ from transformers import PretrainedConfig, PreTrainedModel
 from transformers.activations import ACT2FN
 from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling

-from sglang.srt.distributed import parallel_state
+from sglang.srt.layers.attention import vision_utils
 from sglang.srt.layers.attention.vision import SingletonCache, VisionAttention
-from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
@@ -412,7 +412,7 @@ class InternVLChatModel(nn.Module):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self._update_vision_config()
+        vision_utils.update_vit_attn_dummy_heads_config(self.config)
         image_size = config.force_image_size or config.vision_config.image_size
         patch_size = config.vision_config.patch_size
         self.patch_size = patch_size
@@ -462,21 +462,6 @@ class InternVLChatModel(nn.Module):
             nn.Linear(llm_hidden_size, llm_hidden_size),
         )

-    def _update_vision_config(self):
-        """update vision config to support tp"""
-        world_size = parallel_state.get_tensor_model_parallel_world_size()
-        num_heads = self.config.vision_config.num_attention_heads
-        head_dim = self.config.vision_config.hidden_size // num_heads
-        num_dummy_heads = 0
-
-        if num_heads % world_size != 0:
-            num_dummy_heads = (
-                (num_heads + world_size) // world_size
-            ) * world_size - num_heads
-
-        setattr(self.config.vision_config, "head_dim", head_dim)
-        setattr(self.config.vision_config, "num_dummy_heads", num_dummy_heads)
-
     def pixel_shuffle(self, x, scale_factor=0.5):
         n, w, h, c = x.size()
         # N, W, H, C --> N, W, H * scale, C // scale
@@ -559,36 +544,6 @@ class InternVLChatModel(nn.Module):

         return helper.pad_input_tokens(input_ids, mm_inputs)

-    def _pad_vit_attn_dummy_heads(self, name: str, loaded_weight: torch.Tensor):
-        """pad attn qkv weights for dummy heads"""
-        num_dummy_heads = self.config.vision_config.num_dummy_heads
-        if num_dummy_heads == 0:
-            return loaded_weight
-        head_dim = self.config.vision_config.head_dim
-
-        if "attn.qkv_proj" in name:
-            wq, wk, wv = loaded_weight.chunk(3, dim=0)
-            if name.endswith(".weight"):
-                dummy_shape = [num_dummy_heads, head_dim, wq.shape[-1]]
-            elif name.endswith(".bias"):
-                dummy_shape = [num_dummy_heads, head_dim]
-            else:
-                raise RuntimeError(f"Unsupported weight with name={name}")
-            pad_func = lambda x: torch.cat(
-                [x.unflatten(0, (-1, head_dim)), x.new_zeros(dummy_shape)], dim=0
-            ).flatten(0, 1)
-            wq, wk, wv = pad_func(wq), pad_func(wk), pad_func(wv)
-            loaded_weight = torch.cat([wq, wk, wv], dim=0)
-        if "attn.proj.weight" in name:
-            padded_weight = loaded_weight.new_zeros(
-                loaded_weight.shape[0], head_dim * num_dummy_heads
-            )
-            loaded_weight = torch.cat([loaded_weight, padded_weight], dim=-1)
-        if "attn.q_norm.weight" in name or "attn.k_norm.weight" in name:
-            padded_weight = loaded_weight.new_zeros(head_dim * num_dummy_heads)
-            loaded_weight = torch.cat([loaded_weight, padded_weight], dim=0)
-        return loaded_weight
-
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         expert_params_mapping = []
         if "InternLM2ForCausalLM" in self.config.llm_config.architectures:
@@ -616,7 +571,7 @@ class InternVLChatModel(nn.Module):
                 ("gate_up_proj", "up_proj", 1),
             ]

-            expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
+            expert_params_mapping = FusedMoE.make_expert_params_mapping(
                 ckpt_gate_proj_name="gate_proj",
                 ckpt_down_proj_name="down_proj",
                 ckpt_up_proj_name="up_proj",
@@ -699,8 +654,8 @@ class InternVLChatModel(nn.Module):
                         param, "weight_loader", default_weight_loader
                     )
                     if "vision_model" in name:
-                        loaded_weight = self._pad_vit_attn_dummy_heads(
-                            name, loaded_weight
+                        loaded_weight = vision_utils.pad_vit_attn_dummy_heads(
+                            self.config, name, loaded_weight
                         )
                     weight_loader(param, loaded_weight)

@@ -91,10 +91,18 @@ class LlamaMLP(nn.Module):
         )
         self.act_fn = SiluAndMul()

-    def forward(self, x, forward_batch=None):
+    def forward(
+        self,
+        x,
+        forward_batch=None,
+        use_reduce_scatter: bool = False,
+    ):
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
-        x, _ = self.down_proj(x)
+        x, _ = self.down_proj(
+            x,
+            skip_all_reduce=use_reduce_scatter,
+        )
         return x

@@ -31,7 +31,7 @@ from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
 from sglang.srt.layers.dp_attention import (
     get_attention_tp_rank,
     get_attention_tp_size,
-    get_local_attention_dp_size,
+    is_dp_attention_enabled,
 )
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -45,7 +45,6 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import (
     ForwardBatch,
     ForwardMode,
@@ -131,14 +130,19 @@ class Llama4MoE(nn.Module):
             reduce_results=False,  # We need to do scatter before reduce
         )

-    def forward(self, hidden_states, forward_batch: ForwardBatch):
+    def forward(
+        self,
+        hidden_states,
+        forward_batch: ForwardBatch,
+        use_reduce_scatter: bool = False,
+    ):
         shared_out, routed_out = self._forward_core(
             hidden_states, forward_batch.forward_mode
         )

         out_aD = routed_out + shared_out

-        if self.tp_size > 1:
+        if self.tp_size > 1 and not use_reduce_scatter:
             out_aD = tensor_model_parallel_all_reduce(out_aD)

         return out_aD
@@ -359,7 +363,6 @@ class Llama4DecoderLayer(nn.Module):
         rope_theta = config.rope_theta
         rope_scaling = config.rope_scaling
         max_position_embeddings = config.max_position_embeddings
-        self.local_dp_size = get_local_attention_dp_size()
         self.attn_tp_size = get_attention_tp_size()
         self.attn_tp_rank = get_attention_tp_rank()

@@ -412,6 +415,7 @@ class Llama4DecoderLayer(nn.Module):
             layer_scatter_modes=self.layer_scatter_modes,
             input_layernorm=self.input_layernorm,
             post_attention_layernorm=self.post_attention_layernorm,
+            allow_reduce_scatter=True,
         )

     def _is_moe_layer(self, layer_id: int) -> bool:
@@ -441,8 +445,15 @@ class Llama4DecoderLayer(nn.Module):
             hidden_states, residual, forward_batch
         )

+        # For DP with padding, reduce scatter can be used instead of all-reduce.
+        use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
+            forward_batch
+        )
+
         # Fully Connected
-        hidden_states = self.feed_forward(hidden_states, forward_batch)
+        hidden_states = self.feed_forward(
+            hidden_states, forward_batch, use_reduce_scatter
+        )
         hidden_states, residual = self.layer_communicator.postprocess_layer(
             hidden_states, residual, forward_batch
         )
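The LlamaMLP and Llama4MoE/Llama4DecoderLayer hunks above thread a use_reduce_scatter flag from LayerCommunicator.should_use_reduce_scatter(forward_batch) down into the MLP, so the final projection can skip its tensor-parallel all-reduce when the communicator will instead reduce-scatter the padded DP-attention batch. A small numerical sketch of why the two are interchangeable on a padded batch, with plain tensors standing in for ranks (not sglang code):

import torch

# Conceptual sketch only: after its shard of down_proj, every rank holds a partial sum
# of the same padded activation tensor. An all-reduce followed by taking this rank's
# token slice equals a reduce-scatter of that tensor, which skips the all-gather half
# of the all-reduce.
world_size, tokens_per_rank, hidden = 4, 2, 8
rank = 1  # illustrative rank id

partials = [torch.randn(world_size * tokens_per_rank, hidden) for _ in range(world_size)]

full = torch.stack(partials).sum(dim=0)  # what an all-reduce would leave on every rank
sl = slice(rank * tokens_per_rank, (rank + 1) * tokens_per_rank)
assert torch.allclose(full[sl], sum(p[sl] for p in partials))  # reduce-scatter result for this rank

Since reduce-scatter leaves each rank with only its own token slice, it avoids the all-gather half of an all-reduce; the allow_reduce_scatter=True flag and the should_use_reduce_scatter check above gate this path to batches where the DP padding makes the slicing valid.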
@@ -466,7 +477,7 @@ class Llama4Model(nn.Module):
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("embed_tokens", prefix),
-            enable_tp=not global_server_args_dict["enable_dp_attention"],
+            enable_tp=not is_dp_attention_enabled(),
         )
         self.layers = make_layers(
             config.num_hidden_layers,
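A smaller recurring change in these model files: direct reads of global_server_args_dict["enable_dp_attention"] are replaced by the is_dp_attention_enabled() helper from sglang.srt.layers.dp_attention (which grows by +137/-27 in the table above). As a migration sketch, assuming the helper reports the same server flag:

# Old spelling (0.5.0rc1), as removed above:
#   from sglang.srt.managers.schedule_batch import global_server_args_dict
#   enable_tp = not global_server_args_dict["enable_dp_attention"]

# New spelling (0.5.1), as added above:
from sglang.srt.layers.dp_attention import is_dp_attention_enabled

enable_tp = not is_dp_attention_enabled()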
@@ -37,7 +37,6 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import add_prefix, is_cuda
@@ -47,7 +47,6 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import add_prefix, make_layers
@@ -104,7 +103,6 @@ class MixtralMoE(nn.Module):
             intermediate_size=intermediate_size,
             params_dtype=params_dtype,
             quant_config=quant_config,
-            tp_size=tp_size,
             prefix=add_prefix("experts", prefix),
         )
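The remaining hunks drop the now-unused global_server_args_dict imports and stop passing an explicit tp_size when constructing the MixtralMoE experts. Rounding out the release, sglang/version.py changes by one line for the version bump itself; a quick, hypothetical check that an environment actually picked up the new wheel:

from importlib.metadata import version

print(version("sglang"))  # expect "0.5.1" after upgrading from 0.5.0rc1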