sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175)
  1. sglang/__init__.py +8 -3
  2. sglang/bench_one_batch.py +119 -17
  3. sglang/lang/chat_template.py +18 -0
  4. sglang/srt/bench_utils.py +137 -0
  5. sglang/srt/configs/model_config.py +42 -7
  6. sglang/srt/conversation.py +9 -5
  7. sglang/srt/disaggregation/base/conn.py +5 -2
  8. sglang/srt/disaggregation/decode.py +14 -4
  9. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
  10. sglang/srt/disaggregation/mooncake/conn.py +286 -160
  11. sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
  12. sglang/srt/disaggregation/prefill.py +2 -0
  13. sglang/srt/distributed/parallel_state.py +15 -11
  14. sglang/srt/entrypoints/context.py +227 -0
  15. sglang/srt/entrypoints/engine.py +15 -9
  16. sglang/srt/entrypoints/harmony_utils.py +372 -0
  17. sglang/srt/entrypoints/http_server.py +74 -4
  18. sglang/srt/entrypoints/openai/protocol.py +218 -1
  19. sglang/srt/entrypoints/openai/serving_chat.py +41 -11
  20. sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
  21. sglang/srt/entrypoints/openai/tool_server.py +175 -0
  22. sglang/srt/entrypoints/tool.py +87 -0
  23. sglang/srt/eplb/expert_location.py +5 -1
  24. sglang/srt/function_call/ebnf_composer.py +1 -0
  25. sglang/srt/function_call/function_call_parser.py +2 -0
  26. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  27. sglang/srt/function_call/gpt_oss_detector.py +331 -0
  28. sglang/srt/function_call/kimik2_detector.py +3 -3
  29. sglang/srt/function_call/qwen3_coder_detector.py +219 -9
  30. sglang/srt/hf_transformers_utils.py +30 -3
  31. sglang/srt/jinja_template_utils.py +14 -1
  32. sglang/srt/layers/attention/aiter_backend.py +375 -115
  33. sglang/srt/layers/attention/ascend_backend.py +3 -0
  34. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
  35. sglang/srt/layers/attention/flashattention_backend.py +18 -0
  36. sglang/srt/layers/attention/flashinfer_backend.py +52 -13
  37. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  38. sglang/srt/layers/attention/triton_backend.py +85 -14
  39. sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
  40. sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
  41. sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
  42. sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
  43. sglang/srt/layers/attention/vision.py +22 -6
  44. sglang/srt/layers/attention/wave_backend.py +627 -0
  45. sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
  46. sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
  47. sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
  48. sglang/srt/layers/communicator.py +29 -14
  49. sglang/srt/layers/dp_attention.py +12 -0
  50. sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
  51. sglang/srt/layers/linear.py +3 -7
  52. sglang/srt/layers/moe/cutlass_moe.py +12 -3
  53. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
  54. sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
  55. sglang/srt/layers/moe/ep_moe/layer.py +135 -73
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
  59. sglang/srt/layers/moe/fused_moe_triton/layer.py +412 -33
  60. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
  61. sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
  62. sglang/srt/layers/moe/topk.py +16 -4
  63. sglang/srt/layers/moe/utils.py +16 -0
  64. sglang/srt/layers/quantization/__init__.py +27 -3
  65. sglang/srt/layers/quantization/fp4.py +557 -0
  66. sglang/srt/layers/quantization/fp8.py +3 -6
  67. sglang/srt/layers/quantization/fp8_kernel.py +277 -0
  68. sglang/srt/layers/quantization/fp8_utils.py +51 -10
  69. sglang/srt/layers/quantization/modelopt_quant.py +258 -68
  70. sglang/srt/layers/quantization/mxfp4.py +654 -0
  71. sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
  72. sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
  73. sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
  74. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
  75. sglang/srt/layers/quantization/quark/utils.py +107 -0
  76. sglang/srt/layers/quantization/unquant.py +60 -6
  77. sglang/srt/layers/quantization/w4afp8.py +21 -12
  78. sglang/srt/layers/quantization/w8a8_int8.py +48 -34
  79. sglang/srt/layers/rotary_embedding.py +506 -3
  80. sglang/srt/layers/utils.py +9 -0
  81. sglang/srt/layers/vocab_parallel_embedding.py +8 -3
  82. sglang/srt/lora/backend/base_backend.py +3 -23
  83. sglang/srt/lora/layers.py +60 -114
  84. sglang/srt/lora/lora.py +17 -62
  85. sglang/srt/lora/lora_manager.py +82 -62
  86. sglang/srt/lora/lora_registry.py +23 -11
  87. sglang/srt/lora/mem_pool.py +63 -68
  88. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  89. sglang/srt/lora/utils.py +25 -58
  90. sglang/srt/managers/cache_controller.py +75 -58
  91. sglang/srt/managers/detokenizer_manager.py +1 -1
  92. sglang/srt/managers/io_struct.py +20 -8
  93. sglang/srt/managers/mm_utils.py +6 -13
  94. sglang/srt/managers/multimodal_processor.py +1 -1
  95. sglang/srt/managers/schedule_batch.py +61 -25
  96. sglang/srt/managers/schedule_policy.py +6 -6
  97. sglang/srt/managers/scheduler.py +41 -19
  98. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
  99. sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
  100. sglang/srt/managers/scheduler_recv_skipper.py +37 -0
  101. sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
  102. sglang/srt/managers/template_manager.py +35 -1
  103. sglang/srt/managers/tokenizer_manager.py +47 -30
  104. sglang/srt/managers/tp_worker.py +3 -0
  105. sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
  106. sglang/srt/mem_cache/allocator.py +61 -87
  107. sglang/srt/mem_cache/hicache_storage.py +1 -1
  108. sglang/srt/mem_cache/hiradix_cache.py +80 -22
  109. sglang/srt/mem_cache/lora_radix_cache.py +421 -0
  110. sglang/srt/mem_cache/memory_pool_host.py +34 -36
  111. sglang/srt/mem_cache/multimodal_cache.py +33 -13
  112. sglang/srt/mem_cache/radix_cache.py +2 -5
  113. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
  114. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
  115. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
  116. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
  117. sglang/srt/model_executor/cuda_graph_runner.py +29 -9
  118. sglang/srt/model_executor/forward_batch_info.py +61 -19
  119. sglang/srt/model_executor/model_runner.py +148 -37
  120. sglang/srt/model_loader/loader.py +18 -6
  121. sglang/srt/model_loader/weight_utils.py +10 -0
  122. sglang/srt/models/bailing_moe.py +425 -0
  123. sglang/srt/models/deepseek_v2.py +137 -59
  124. sglang/srt/models/ernie4.py +426 -0
  125. sglang/srt/models/ernie4_eagle.py +203 -0
  126. sglang/srt/models/gemma2.py +0 -34
  127. sglang/srt/models/gemma3n_mm.py +38 -0
  128. sglang/srt/models/glm4.py +6 -0
  129. sglang/srt/models/glm4_moe.py +28 -16
  130. sglang/srt/models/glm4v.py +589 -0
  131. sglang/srt/models/glm4v_moe.py +400 -0
  132. sglang/srt/models/gpt_oss.py +1251 -0
  133. sglang/srt/models/granite.py +0 -25
  134. sglang/srt/models/llama.py +0 -25
  135. sglang/srt/models/llama4.py +1 -1
  136. sglang/srt/models/qwen2.py +6 -0
  137. sglang/srt/models/qwen2_5_vl.py +7 -3
  138. sglang/srt/models/qwen2_audio.py +10 -9
  139. sglang/srt/models/qwen2_moe.py +6 -0
  140. sglang/srt/models/qwen3.py +0 -24
  141. sglang/srt/models/qwen3_moe.py +32 -6
  142. sglang/srt/models/registry.py +1 -1
  143. sglang/srt/models/step3_vl.py +9 -0
  144. sglang/srt/models/torch_native_llama.py +0 -24
  145. sglang/srt/models/transformers.py +2 -5
  146. sglang/srt/multimodal/processors/base_processor.py +23 -13
  147. sglang/srt/multimodal/processors/glm4v.py +132 -0
  148. sglang/srt/multimodal/processors/qwen_audio.py +4 -2
  149. sglang/srt/multimodal/processors/step3_vl.py +3 -1
  150. sglang/srt/reasoning_parser.py +332 -37
  151. sglang/srt/server_args.py +186 -75
  152. sglang/srt/speculative/eagle_worker.py +16 -0
  153. sglang/srt/two_batch_overlap.py +169 -9
  154. sglang/srt/utils.py +41 -5
  155. sglang/srt/weight_sync/tensor_bucket.py +106 -0
  156. sglang/test/attention/test_trtllm_mla_backend.py +186 -36
  157. sglang/test/doc_patch.py +59 -0
  158. sglang/test/few_shot_gsm8k.py +1 -1
  159. sglang/test/few_shot_gsm8k_engine.py +1 -1
  160. sglang/test/run_eval.py +4 -1
  161. sglang/test/runners.py +2 -2
  162. sglang/test/simple_eval_common.py +6 -0
  163. sglang/test/simple_eval_gpqa.py +2 -0
  164. sglang/test/test_fp4_moe.py +118 -36
  165. sglang/test/test_utils.py +1 -1
  166. sglang/utils.py +1 -1
  167. sglang/version.py +1 -1
  168. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +36 -38
  169. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +174 -141
  170. sglang/srt/lora/backend/flashinfer_backend.py +0 -131
  171. /sglang/{api.py → lang/api.py} +0 -0
  172. /sglang/{lang/backend → srt/layers/quantization/quark}/__init__.py +0 -0
  173. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
  174. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
  175. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
sglang/srt/models/granite.py
@@ -363,31 +363,6 @@ class GraniteForCausalLM(nn.Module):
         else:
             return self.pooler(hidden_states, forward_batch)
 
-    def get_hidden_dim(self, module_name):
-        # return input_dim, output_dim
-        if module_name in ["q_proj", "o_proj", "qkv_proj"]:
-            return self.config.hidden_size, self.config.hidden_size
-        elif module_name in ["kv_proj"]:
-            return self.config.hidden_size, self.config.hidden_size // (
-                self.config.num_attention_heads // self.config.num_key_value_heads
-            )
-        elif module_name == "gate_up_proj":
-            return self.config.hidden_size, self.config.intermediate_size
-        elif module_name == "down_proj":
-            return self.config.intermediate_size, self.config.hidden_size
-        else:
-            raise NotImplementedError()
-
-    def get_module_name(self, name):
-        params_mapping = {
-            "q_proj": "qkv_proj",
-            "k_proj": "qkv_proj",
-            "v_proj": "qkv_proj",
-            "gate_proj": "gate_up_proj",
-            "up_proj": "gate_up_proj",
-        }
-        return params_mapping.get(name, name)
-
     def get_module_name_from_weight_name(self, name):
         for param_name, weight_name, shard_id, num_shard in self.stacked_params_mapping:
             if weight_name in name:
sglang/srt/models/llama.py
@@ -532,31 +532,6 @@ class LlamaForCausalLM(nn.Module):
     def get_input_embeddings(self) -> nn.Embedding:
         return self.model.embed_tokens
 
-    def get_hidden_dim(self, module_name):
-        # return input_dim, output_dim
-        if module_name in ["q_proj", "o_proj", "qkv_proj"]:
-            return self.config.hidden_size, self.config.hidden_size
-        elif module_name in ["kv_proj"]:
-            return self.config.hidden_size, self.config.hidden_size // (
-                self.config.num_attention_heads // self.config.num_key_value_heads
-            )
-        elif module_name == "gate_up_proj":
-            return self.config.hidden_size, self.config.intermediate_size
-        elif module_name == "down_proj":
-            return self.config.intermediate_size, self.config.hidden_size
-        else:
-            raise NotImplementedError()
-
-    def get_module_name(self, name):
-        params_mapping = {
-            "q_proj": "qkv_proj",
-            "k_proj": "qkv_proj",
-            "v_proj": "qkv_proj",
-            "gate_proj": "gate_up_proj",
-            "up_proj": "gate_up_proj",
-        }
-        return params_mapping.get(name, name)
-
     def get_module_name_from_weight_name(self, name):
         for param_name, weight_name, shard_id, num_shard in self.stacked_params_mapping:
             if weight_name in name:
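Note: the get_hidden_dim helpers removed here and in granite.py above mapped a projection module name to its (input_dim, output_dim) pair for LoRA shape bookkeeping; the LoRA files listed above (sglang/srt/lora/layers.py, lora.py, utils.py) change in the same release, which presumably absorbs this logic. A minimal sketch of the arithmetic the removed code performed, using illustrative config values rather than any real checkpoint:

    # Sketch of the removed get_hidden_dim arithmetic (illustrative values only).
    hidden_size = 4096
    intermediate_size = 14336
    num_attention_heads = 32
    num_key_value_heads = 8

    dims = {
        "qkv_proj": (hidden_size, hidden_size),
        # kv_proj output shrinks by the GQA factor (query heads per KV head)
        "kv_proj": (
            hidden_size,
            hidden_size // (num_attention_heads // num_key_value_heads),
        ),
        "gate_up_proj": (hidden_size, intermediate_size),
        "down_proj": (intermediate_size, hidden_size),
    }
    print(dims["kv_proj"])  # (4096, 1024)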
sglang/srt/models/llama4.py
@@ -204,7 +204,7 @@ class Llama4Attention(nn.Module):
         super().__init__()
         self.layer_id = layer_id
         self.hidden_size = hidden_size
-        self.use_rope = int((layer_id + 1) % 4 != 0)
+        self.use_rope = (layer_id + 1) % 4 != 0
         self.use_qk_norm = config.use_qk_norm and self.use_rope
 
         attn_tp_rank = get_attention_tp_rank()
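Note: dropping the int() cast does not change the layer pattern; the expression is already boolean. A quick check of which layer_ids set use_rope under (layer_id + 1) % 4 != 0:

    # Every fourth layer (layer_id 3, 7, 11, ...) ends up with use_rope == False.
    for layer_id in range(8):
        print(layer_id, (layer_id + 1) % 4 != 0)
    # 0 True, 1 True, 2 True, 3 False, 4 True, 5 True, 6 True, 7 False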
sglang/srt/models/qwen2.py
@@ -107,6 +107,7 @@ class Qwen2Attention(nn.Module):
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 32768,
         quant_config: Optional[QuantizationConfig] = None,
+        dual_chunk_attention_config: Optional[dict[str, Any]] = None,
         prefix: str = "",
     ) -> None:
         super().__init__()
@@ -158,6 +159,7 @@
             max_position=max_position_embeddings,
             base=rope_theta,
             rope_scaling=rope_scaling,
+            dual_chunk_attention_config=dual_chunk_attention_config,
         )
         self.attn = RadixAttention(
             self.num_heads,
@@ -198,6 +200,9 @@ class Qwen2DecoderLayer(nn.Module):
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 32768)
         head_dim = getattr(config, "head_dim", None)
+        dual_chunk_attention_config = getattr(
+            config, "dual_chunk_attention_config", None
+        )
         self.self_attn = Qwen2Attention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
@@ -208,6 +213,7 @@
             rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
             quant_config=quant_config,
+            dual_chunk_attention_config=dual_chunk_attention_config,
             prefix=add_prefix("self_attn", prefix),
         )
         self.mlp = Qwen2MLP(
sglang/srt/models/qwen2_5_vl.py
@@ -114,7 +114,7 @@ class Qwen2_5_VisionBlock(nn.Module):
         num_heads: int,
         hidden_act="silu",
         norm_layer: Type[nn.Module] = None,
-        attn_implementation: Optional[str] = "sdpa",
+        attn_implementation: Optional[str] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ) -> None:
@@ -123,7 +123,12 @@
             norm_layer = partial(nn.LayerNorm, eps=1e-6)
         self.norm1 = Qwen2RMSNorm(dim, eps=1e-6)
         self.norm2 = Qwen2RMSNorm(dim, eps=1e-6)
-        if attn_implementation == "sdpa":
+
+        if attn_implementation is None:
+            softmax_in_single_precision = False
+            qkv_backend = None
+            flatten_batch = True
+        elif attn_implementation == "sdpa":
             softmax_in_single_precision = False
             qkv_backend = "sdpa"
             flatten_batch = True
@@ -268,7 +273,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
                 num_heads=num_heads,
                 hidden_act=vision_config.hidden_act,
                 norm_layer=norm_layer,
-                attn_implementation="sdpa",
                 quant_config=quant_config,
                 prefix=add_prefix(f"blocks.{i}", prefix),
             )
sglang/srt/models/qwen2_audio.py
@@ -52,7 +52,11 @@ from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternMultimodalTokens,
     general_mm_embed_routine,
 )
-from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
+from sglang.srt.managers.schedule_batch import (
+    Modality,
+    MultimodalDataItem,
+    MultimodalInputs,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
@@ -106,15 +110,10 @@ class Qwen2AudioForConditionalGeneration(nn.Module):
         self.language_model = Qwen2ForCausalLM(
             config.text_config, quant_config, prefix=add_prefix("model", prefix)
         )
+        self.pattern = MultiModalityDataPaddingPatternMultimodalTokens()
 
     def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
-        # Get all special token IDs for audio
-        audio_token_id: int = getattr(
-            mm_inputs, "audio_token_id", mm_inputs.im_token_id
-        )
-
-        pattern = MultiModalityDataPaddingPatternMultimodalTokens([audio_token_id])
-        return pattern.pad_input_tokens(input_ids, mm_inputs)
+        return self.pattern.pad_input_tokens(input_ids, mm_inputs)
 
     def get_audio_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
         # Extract audio features from input items
@@ -143,7 +142,9 @@
             input_ids=input_ids,
             forward_batch=forward_batch,
             language_model=self.language_model,
-            audio_data_embedding_func=self.get_audio_feature,
+            data_embedding_funcs={
+                Modality.AUDIO: self.get_audio_feature,
+            },
             positions=positions,
         )
 
sglang/srt/models/qwen2_moe.py
@@ -210,6 +210,7 @@ class Qwen2MoeAttention(nn.Module):
         max_position_embeddings: int = 8192,
         qkv_bias: int = True,
         quant_config: Optional[QuantizationConfig] = None,
+        dual_chunk_attention_config: Optional[dict[str, Any]] = None,
         prefix: str = "",
     ) -> None:
         super().__init__()
@@ -267,6 +268,7 @@
             max_position=max_position_embeddings,
             base=rope_theta,
             rope_scaling=rope_scaling,
+            dual_chunk_attention_config=dual_chunk_attention_config,
         )
         self.attn = RadixAttention(
             self.num_heads,
@@ -308,6 +310,9 @@ class Qwen2MoeDecoderLayer(nn.Module):
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
         qkv_bias = getattr(config, "qkv_bias", True)
+        dual_chunk_attention_config = getattr(
+            config, "dual_chunk_attention_config", None
+        )
         self.self_attn = Qwen2MoeAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
@@ -317,6 +322,7 @@
             rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
             quant_config=quant_config,
+            dual_chunk_attention_config=dual_chunk_attention_config,
             qkv_bias=qkv_bias,
             prefix=add_prefix("self_attn", prefix),
         )
sglang/srt/models/qwen3.py
@@ -330,30 +330,6 @@ class Qwen3ForCausalLM(nn.Module):
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)
 
-    def get_hidden_dim(self, module_name: str) -> Tuple[int]:
-        # return input_dim, output_dim
-        if module_name in ["q_proj", "qkv_proj"]:
-            return (
-                self.config.hidden_size,
-                self.config.head_dim * self.config.num_attention_heads,
-            )
-        elif module_name in ["o_proj"]:
-            return (
-                self.config.head_dim * self.config.num_attention_heads,
-                self.config.hidden_size,
-            )
-        elif module_name in ["kv_proj"]:
-            return (
-                self.config.hidden_size,
-                self.config.head_dim * self.config.num_key_value_heads,
-            )
-        elif module_name == "gate_up_proj":
-            return self.config.hidden_size, self.config.intermediate_size
-        elif module_name == "down_proj":
-            return self.config.intermediate_size, self.config.hidden_size
-        else:
-            raise NotImplementedError()
-
     @torch.no_grad()
     def forward(
         self,
sglang/srt/models/qwen3_moe.py
@@ -295,6 +295,7 @@ class Qwen3MoeAttention(nn.Module):
         attention_bias: bool = False,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        dual_chunk_attention_config: Optional[dict[str, Any]] = None,
         alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
@@ -353,6 +354,7 @@
             max_position=max_position_embeddings,
             base=rope_theta,
             rope_scaling=rope_scaling,
+            dual_chunk_attention_config=dual_chunk_attention_config,
         )
         self.attn = RadixAttention(
             self.num_heads,
@@ -458,6 +460,9 @@ class Qwen3MoeDecoderLayer(nn.Module):
         )
         rms_norm_eps = config.rms_norm_eps
         attention_bias = config.attention_bias
+        dual_chunk_attention_config = getattr(
+            config, "dual_chunk_attention_config", None
+        )
         self.self_attn = Qwen3MoeAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
@@ -471,6 +476,7 @@
             attention_bias=attention_bias,
             quant_config=quant_config,
             prefix=add_prefix("self_attn", prefix),
+            dual_chunk_attention_config=dual_chunk_attention_config,
             alt_stream=alt_stream,
         )
 
@@ -766,7 +772,10 @@ class Qwen3MoeForCausalLM(nn.Module):
             num_experts=self.config.num_experts,
         )
 
-        params_dict = dict(self.named_parameters())
+        # Cache params_dict to avoid repeated expensive traversal of model parameters
+        if not hasattr(self, "_cached_params_dict"):
+            self._cached_params_dict = dict(self.named_parameters())
+        params_dict = self._cached_params_dict
         for name, loaded_weight in weights:
             layer_id = get_layer_id(name)
             if (
@@ -805,11 +814,22 @@
                     weight_loader(param, loaded_weight, shard_id)
                    break
                else:
+                    # Track if this is an expert weight to enable early skipping
+                    is_expert_weight = False
+
                    for mapping in expert_params_mapping:
                        param_name, weight_name, expert_id, shard_id = mapping
                        if weight_name not in name:
                            continue
+
+                        # Mark as expert weight regardless of whether we can process it
+                        is_expert_weight = True
+
                        name = name.replace(weight_name, param_name)
+                        if name not in params_dict:
+                            # Expert weight not on this rank, will be skipped below
+                            continue
+
                        param = params_dict[name]
                        weight_loader = param.weight_loader
                        weight_loader(
@@ -821,6 +841,10 @@
                        )
                        break
                    else:
+                        if is_expert_weight:
+                            # This is an expert weight but not mapped to this rank, skip all remaining processing
+                            continue
+
                        # Skip loading extra bias for GPTQ models.
                        if name.endswith(".bias") and name not in params_dict:
                            continue
@@ -837,11 +861,13 @@
                            logger.warning(f"Parameter {name} not found in params_dict")
 
        # TODO mimic deepseek
-        self.routed_experts_weights_of_layer = {
-            layer_id: self.model.layers[layer_id].mlp.get_moe_weights()
-            for layer_id in range(self.start_layer, self.end_layer)
-            if isinstance(self.model.layers[layer_id].mlp, Qwen3MoeSparseMoeBlock)
-        }
+        # Lazy initialization of expert weights cache to avoid slowing down load_weights
+        if not hasattr(self, "routed_experts_weights_of_layer"):
+            self.routed_experts_weights_of_layer = {
+                layer_id: self.model.layers[layer_id].mlp.get_moe_weights()
+                for layer_id in range(self.start_layer, self.end_layer)
+                if isinstance(self.model.layers[layer_id].mlp, Qwen3MoeSparseMoeBlock)
+            }
 
     @classmethod
     def get_model_config_for_expert_location(cls, config):
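Note: the new is_expert_weight flag leans on Python's for/else semantics in load_weights: the else branch of the expert-mapping loop runs only when the loop exhausts without break, so an expert weight whose parameter is not hosted on this rank must be skipped explicitly instead of falling through to the generic loader (where it would trigger the "not found in params_dict" warning). A standalone sketch of the same routing decisions, using early returns instead of for/else and illustrative names rather than sglang's actual API:

    def route_weight(name, params_dict, expert_params_mapping):
        """Describe how load_weights-style code would handle one checkpoint tensor."""
        is_expert_weight = False
        for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
            if weight_name not in name:
                continue
            is_expert_weight = True                    # matched an expert mapping
            mapped = name.replace(weight_name, param_name)
            if mapped not in params_dict:              # expert not hosted on this rank
                continue
            return f"load {mapped} (expert {expert_id}, shard {shard_id})"
        if is_expert_weight:
            return "skip: expert weight owned by another rank"
        return "fall through to the generic (non-expert) loader"

    mapping = [("experts.w13_weight", "experts.0.gate_proj.weight", 0, "w1")]
    hosted = {"model.layers.0.mlp.experts.w13_weight": object()}
    print(route_weight("model.layers.0.mlp.experts.0.gate_proj.weight", hosted, mapping))
    print(route_weight("model.layers.0.mlp.experts.0.gate_proj.weight", {}, mapping))
    print(route_weight("model.layers.0.input_layernorm.weight", hosted, mapping))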
sglang/srt/models/registry.py
@@ -83,7 +83,7 @@ def import_model_classes():
             try:
                 module = importlib.import_module(name)
             except Exception as e:
-                logger.warning(f"Ignore import error when loading {name}. " f"{e}")
+                logger.warning(f"Ignore import error when loading {name}: {e}")
                 continue
             if hasattr(module, "EntryClass"):
                 entry = module.EntryClass
sglang/srt/models/step3_vl.py
@@ -531,11 +531,18 @@ class Step3VisionMLP(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+        # Since this is a dense model,
+        # the MLP component likewise adopts a DP-MLP approach modeled after DP Attention.
+        # This choice may not represent the optimal solution and remains open to further deliberation.
+        attn_tp_rank = get_attention_tp_rank()
+        attn_tp_size = get_attention_tp_size()
         self.fc1 = ColumnParallelLinear(
             dim,
             intermediate_size,
             bias=bias,
             quant_config=quant_config,
+            tp_rank=attn_tp_rank,
+            tp_size=attn_tp_size,
             prefix=add_prefix("gate_proj", prefix),
         )
         self.act = ACT2FN[hidden_act]  # quick_gelu
@@ -544,6 +551,8 @@
             dim,
             bias=bias,
             quant_config=quant_config,
+            tp_rank=attn_tp_rank,
+            tp_size=attn_tp_size,
             prefix=add_prefix("down_proj", prefix),
         )
 
sglang/srt/models/torch_native_llama.py
@@ -416,30 +416,6 @@ class TorchNativeLlamaForCausalLM(nn.Module):
             input_ids, hidden_states, self.lm_head, forward_batch
         )
 
-    def get_hidden_dim(self, module_name):
-        if module_name in ["q_proj", "o_proj", "qkv_proj"]:
-            return self.config.hidden_size, self.config.hidden_size
-        elif module_name in ["kv_proj"]:
-            return self.config.hidden_size, self.config.hidden_size // (
-                self.config.num_attention_heads // self.config.num_key_value_heads
-            )
-        elif module_name == "gate_up_proj":
-            return self.config.hidden_size, self.config.intermediate_size
-        elif module_name == "down_proj":
-            return self.config.intermediate_size, self.config.hidden_size
-        else:
-            raise NotImplementedError()
-
-    def get_module_name(self, name):
-        params_mapping = {
-            "q_proj": "qkv_proj",
-            "k_proj": "qkv_proj",
-            "v_proj": "qkv_proj",
-            "gate_proj": "gate_up_proj",
-            "up_proj": "gate_up_proj",
-        }
-        return params_mapping.get(name, name)
-
     def get_module_name_from_weight_name(self, name):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id, num_shard)
sglang/srt/models/transformers.py
@@ -211,16 +211,13 @@ class TransformersForCausalLM(nn.Module):
         Apply the model's tensor parallelization plan.
         Currently only supports linear layers.
         """
-        if not self.model.supports_tp_plan:
-            if tp_size <= 1:
-                return
+        tp_plan = getattr(self.model.config, "base_model_tp_plan", None) or {}
 
+        if not tp_plan and self.tp_size > 1:
             raise ValueError(
                 f"{type(self.model)} does not support tensor parallel yet!"
             )
 
-        tp_plan = self.model._tp_plan
-
         def _tensor_parallel(module: nn.Module, prefix: str = ""):
             for child_name, child_module in module.named_children():
                 qual_name = maybe_prefix(prefix, child_name)
sglang/srt/multimodal/processors/base_processor.py
@@ -22,13 +22,19 @@ class BaseMultiModalProcessorOutput:
     input_text: str
 
     # frames loaded from image, in given order
-    images: Optional[list[Union[Image.Image, dict]]] = None
+    images: Optional[list[Union[Image.Image, dict]]] = dataclasses.field(
+        default_factory=list
+    )
 
     # videos
-    videos: Optional[list[Union[torch.Tensor, dict]]] = None
+    videos: Optional[list[Union[torch.Tensor, dict]]] = dataclasses.field(
+        default_factory=list
+    )
 
     # audios
-    audios: Optional[list[Union[np.ndarray, dict]]] = None
+    audios: Optional[list[Union[np.ndarray, dict]]] = dataclasses.field(
+        default_factory=list
+    )
 
     def organize_results(self) -> List[Tuple[Modality, Any]]:
         """
@@ -202,7 +208,7 @@ class BaseMultimodalProcessor(ABC):
 
     def process_mm_data(
         self, input_text, images=None, videos=None, audios=None, **kwargs
-    ):
+    ) -> dict:
         """
         process multimodal data with transformers AutoProcessor
         """
@@ -211,10 +217,14 @@
         if videos:
             kwargs["videos"] = videos
         if audios:
-            kwargs["audios"] = audios
-            if self.__class__.__name__ == "Gemma3nSGLangProcessor":
+            if self.arch in {
+                "Gemma3nForConditionalGeneration",
+                "Qwen2AudioForConditionalGeneration",
+            }:
                 # Note(Xinyuan): for gemma3n, ref: https://github.com/huggingface/transformers/blob/ccf2ca162e33f381e454cdb74bf4b41a51ab976d/src/transformers/models/gemma3n/processing_gemma3n.py#L107
                 kwargs["audio"] = audios
+            else:
+                kwargs["audios"] = audios
 
         processor = self._processor
         if (
@@ -601,12 +611,6 @@
         all_collected_items: list[MultimodalDataItem] = []
         input_ids = None
 
-        # Handle dict items (already processed)
-        for dict_item in dict_items:
-            all_collected_items.extend(
-                self.collect_mm_items_from_processor_output(dict_item)
-            )
-
         # Handle raw items (need processing)
         if raw_images or raw_audios or raw_videos:
             collected_items, input_ids, ret = self._process_and_collect_mm_items(
@@ -616,10 +620,16 @@
                 videos=raw_videos,
                 **kwargs,
             )
-            all_collected_items.extend(collected_items)
+            all_collected_items = collected_items
         else:
             ret = None
 
+        # Handle dict items (already processed)
+        for dict_item in dict_items:
+            all_collected_items.extend(
+                self.collect_mm_items_from_processor_output(dict_item)
+            )
+
         # Fallback tokenization if no raw items were processed
         if input_ids is None:
             input_ids = self._processor.tokenizer(
sglang/srt/multimodal/processors/glm4v.py
@@ -0,0 +1,132 @@
+import re
+from typing import List, Union
+
+from decord import VideoReader
+from transformers.video_utils import VideoMetadata
+
+from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
+from sglang.srt.models.glm4v import Glm4vForConditionalGeneration
+from sglang.srt.models.glm4v_moe import Glm4vMoeForConditionalGeneration
+from sglang.srt.multimodal.processors.base_processor import (
+    BaseMultimodalProcessor as SGLangBaseProcessor,
+)
+from sglang.srt.multimodal.processors.base_processor import (
+    BaseMultiModalProcessorOutput,
+    MultimodalSpecialTokens,
+)
+
+
+class Glm4vImageProcessor(SGLangBaseProcessor):
+    models = [Glm4vForConditionalGeneration, Glm4vMoeForConditionalGeneration]
+
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
+
+        # GLM-4.1V and GLM-4.5V specific tokens
+        self.IMAGE_TOKEN = "<|image|>"
+        self.VIDEO_TOKEN = "<|video|>"
+        self.IMAGE_START_TOKEN = "<|begin_of_image|>"
+        self.IMAGE_END_TOKEN = "<|end_of_image|>"
+        self.VIDEO_START_TOKEN = "<|begin_of_video|>"
+        self.VIDEO_END_TOKEN = "<|end_of_video|>"
+
+        # Token IDs
+        self.IM_TOKEN_ID = hf_config.image_token_id
+        self.VIDEO_TOKEN_ID = hf_config.video_token_id
+        self.IMAGE_START_TOKEN_ID = hf_config.image_start_token_id
+        self.IMAGE_END_TOKEN_ID = hf_config.image_end_token_id
+        self.VIDEO_START_TOKEN_ID = hf_config.video_start_token_id
+        self.VIDEO_END_TOKEN_ID = hf_config.video_end_token_id
+
+        # Vision config
+        self.IMAGE_FACTOR = 28
+        self.MIN_PIXELS = 112 * 112
+        self.MAX_PIXELS = 30000 * 28 * 28 * 2
+
+        self.mm_tokens = MultimodalSpecialTokens(
+            image_token=self.IMAGE_TOKEN,
+            image_token_id=self.IM_TOKEN_ID,
+            video_token=self.VIDEO_TOKEN,
+            # Note: For GLM4v videos, it uses the video token before tokenization but uses image token after tokenization
+            video_token_id=self.IM_TOKEN_ID,
+        ).build(_processor)
+
+    # adapted from https://github.com/huggingface/transformers/blob/369c99d0cea403b77bd0aef818527106453fd9fc/src/transformers/video_utils.py#L312
+    async def preprocess_video(self, vr: VideoReader):
+        """
+        Preprocess video using VideoReader from Decord backend.
+
+        Args:
+            vr (VideoReader): VideoReader object from decord
+
+        Returns:
+            tuple: A tuple containing processed frames and metadata
+        """
+        video_fps = vr.get_avg_fps()
+        total_num_frames = len(vr)
+        duration = total_num_frames / video_fps if video_fps else 0
+
+        metadata = VideoMetadata(
+            total_num_frames=int(total_num_frames),
+            fps=float(video_fps),
+            duration=float(duration),
+            video_backend="decord",
+        )
+
+        # Extract all frames
+        indices = list(range(total_num_frames))
+        frames = vr.get_batch(indices).asnumpy()
+        metadata.frames_indices = indices
+
+        return frames, metadata
+
+    async def process_mm_data_async(
+        self,
+        image_data: List[Union[str, bytes]],
+        input_text,
+        request_obj,
+        *args,
+        **kwargs,
+    ):
+        base_output = self.load_mm_data(
+            prompt=input_text,
+            image_data=image_data,
+            video_data=request_obj.video_data,
+            multimodal_tokens=self.mm_tokens,
+        )
+
+        video_metadata = None
+
+        if base_output.videos:
+            videos_processed = [
+                await self.preprocess_video(video) for video in base_output.videos
+            ]
+            base_output.videos, video_metadata = map(list, zip(*videos_processed))
+            # transformer requires the video inputs to be under this format
+            base_output.videos = [base_output.videos]
+            video_metadata = [video_metadata]
+
+        mm_items, input_ids, ret = self.process_and_combine_mm_data(
+            base_output, self.mm_tokens, video_metadata=video_metadata
+        )
+
+        input_ids = input_ids.flatten()
+        mrope_positions, mrope_position_delta = MRotaryEmbedding.get_rope_index_glm4v(
+            input_ids=input_ids.unsqueeze(0),
+            hf_config=self.hf_config,
+            image_grid_thw=getattr(ret, "image_grid_thw", None),
+            video_grid_thw=getattr(ret, "video_grid_thw", None),
+            attention_mask=getattr(ret, "attention_mask", None),
+        )
+        mrope_positions = mrope_positions.squeeze(1)
+
+        mm_inputs = {
+            "input_ids": input_ids.tolist(),
+            "mm_items": mm_items,
+            "im_token_id": self.mm_tokens.image_token_id,
+            "video_token_id": self.mm_tokens.video_token_id,
+            "mrope_positions": mrope_positions,
+            "mrope_position_delta": mrope_position_delta,
+        }
+
+        return mm_inputs
sglang/srt/multimodal/processors/qwen_audio.py
@@ -1,6 +1,6 @@
 import re
 
-from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
+from sglang.srt.managers.schedule_batch import Modality
 from sglang.srt.models.qwen2_audio import Qwen2AudioForConditionalGeneration
 from sglang.srt.multimodal.processors.base_processor import (
     BaseMultimodalProcessor,
@@ -29,6 +29,8 @@ class Qwen2AudioMultimodalProcessor(BaseMultimodalProcessor):
             audio_token_id=self.audio_token_id,
         ).build(_processor)
 
+        self.ATTR_NAME_TO_MODALITY.update({"feature_attention_mask": Modality.AUDIO})
+
     async def process_mm_data_async(
         self,
         audio_data,
@@ -54,7 +56,7 @@
         input_lengths = (input_lengths - 1) // 2 + 1
         output_lengths = (input_lengths - 2) // 2 + 1
 
-        mm_items[0].model_specific_data["audio_feature_lens"] = output_lengths
+        mm_items[0].audio_feature_lens = output_lengths
 
         return {
             "mm_items": mm_items,