sglang 0.4.4.post3__py3-none-any.whl → 0.4.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99)
  1. sglang/bench_serving.py +49 -7
  2. sglang/lang/chat_template.py +24 -0
  3. sglang/srt/_custom_ops.py +59 -92
  4. sglang/srt/configs/model_config.py +5 -0
  5. sglang/srt/constrained/base_grammar_backend.py +5 -1
  6. sglang/srt/conversation.py +29 -4
  7. sglang/srt/custom_op.py +5 -0
  8. sglang/srt/distributed/device_communicators/custom_all_reduce.py +27 -79
  9. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  10. sglang/srt/entrypoints/engine.py +0 -5
  11. sglang/srt/layers/attention/flashattention_backend.py +678 -83
  12. sglang/srt/layers/attention/flashinfer_backend.py +5 -7
  13. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -3
  14. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  15. sglang/srt/layers/moe/ep_moe/kernels.py +142 -0
  16. sglang/srt/layers/moe/ep_moe/layer.py +79 -80
  17. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +382 -199
  18. sglang/srt/layers/moe/fused_moe_native.py +5 -0
  19. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  20. sglang/srt/layers/moe/fused_moe_triton/configs/E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  21. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  22. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H200.json +146 -0
  23. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  24. sglang/srt/layers/moe/fused_moe_triton/configs/E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  25. sglang/srt/layers/moe/fused_moe_triton/configs/E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json +146 -0
  27. sglang/srt/layers/moe/fused_moe_triton/configs/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +416 -50
  30. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  31. sglang/srt/layers/moe/topk.py +49 -3
  32. sglang/srt/layers/quantization/__init__.py +5 -1
  33. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  34. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +2 -1
  35. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +34 -10
  36. sglang/srt/layers/quantization/fp8.py +3 -1
  37. sglang/srt/layers/quantization/fp8_utils.py +1 -4
  38. sglang/srt/layers/quantization/moe_wna16.py +503 -0
  39. sglang/srt/layers/quantization/utils.py +1 -1
  40. sglang/srt/layers/quantization/w8a8_int8.py +2 -0
  41. sglang/srt/layers/radix_attention.py +2 -0
  42. sglang/srt/layers/rotary_embedding.py +63 -12
  43. sglang/srt/managers/cache_controller.py +34 -11
  44. sglang/srt/managers/mm_utils.py +202 -156
  45. sglang/srt/managers/multimodal_processor.py +0 -2
  46. sglang/srt/managers/multimodal_processors/base_processor.py +45 -77
  47. sglang/srt/managers/multimodal_processors/clip.py +7 -26
  48. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +17 -58
  49. sglang/srt/managers/multimodal_processors/gemma3.py +12 -27
  50. sglang/srt/managers/multimodal_processors/janus_pro.py +21 -47
  51. sglang/srt/managers/multimodal_processors/llava.py +34 -14
  52. sglang/srt/managers/multimodal_processors/minicpm.py +35 -38
  53. sglang/srt/managers/multimodal_processors/mlama.py +10 -23
  54. sglang/srt/managers/multimodal_processors/mllama4.py +161 -0
  55. sglang/srt/managers/multimodal_processors/qwen_vl.py +22 -45
  56. sglang/srt/managers/schedule_batch.py +185 -128
  57. sglang/srt/managers/scheduler.py +4 -4
  58. sglang/srt/managers/tokenizer_manager.py +1 -1
  59. sglang/srt/managers/utils.py +1 -6
  60. sglang/srt/mem_cache/hiradix_cache.py +62 -52
  61. sglang/srt/mem_cache/memory_pool.py +72 -6
  62. sglang/srt/mem_cache/paged_allocator.py +39 -0
  63. sglang/srt/metrics/collector.py +23 -53
  64. sglang/srt/model_executor/cuda_graph_runner.py +8 -6
  65. sglang/srt/model_executor/forward_batch_info.py +10 -10
  66. sglang/srt/model_executor/model_runner.py +60 -57
  67. sglang/srt/model_loader/loader.py +8 -0
  68. sglang/srt/models/clip.py +12 -7
  69. sglang/srt/models/deepseek_janus_pro.py +10 -15
  70. sglang/srt/models/deepseek_v2.py +212 -121
  71. sglang/srt/models/deepseek_vl2.py +105 -104
  72. sglang/srt/models/gemma3_mm.py +14 -80
  73. sglang/srt/models/llama.py +16 -5
  74. sglang/srt/models/llama4.py +420 -0
  75. sglang/srt/models/llava.py +31 -19
  76. sglang/srt/models/llavavid.py +16 -7
  77. sglang/srt/models/minicpmo.py +63 -147
  78. sglang/srt/models/minicpmv.py +17 -27
  79. sglang/srt/models/mllama.py +29 -14
  80. sglang/srt/models/mllama4.py +154 -0
  81. sglang/srt/models/qwen2.py +9 -6
  82. sglang/srt/models/qwen2_5_vl.py +21 -31
  83. sglang/srt/models/qwen2_vl.py +20 -21
  84. sglang/srt/openai_api/adapter.py +18 -6
  85. sglang/srt/platforms/interface.py +371 -0
  86. sglang/srt/server_args.py +99 -14
  87. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -5
  88. sglang/srt/speculative/eagle_utils.py +140 -28
  89. sglang/srt/speculative/eagle_worker.py +93 -24
  90. sglang/srt/utils.py +104 -51
  91. sglang/test/test_custom_ops.py +55 -0
  92. sglang/test/test_utils.py +13 -26
  93. sglang/utils.py +2 -2
  94. sglang/version.py +1 -1
  95. {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/METADATA +4 -3
  96. {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/RECORD +99 -84
  97. {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/WHEEL +0 -0
  98. {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/licenses/LICENSE +0 -0
  99. {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/model_runner.py CHANGED
@@ -75,6 +75,7 @@ from sglang.srt.utils import (
      get_available_gpu_memory,
      init_custom_process_group,
      is_cuda,
+     is_flashinfer_available,
      is_hip,
      monkey_patch_p2p_access_check,
      monkey_patch_vllm_gguf_config,
@@ -123,6 +124,11 @@ class ModelRunner:
      self.page_size = server_args.page_size
      self.req_to_token_pool = req_to_token_pool
      self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
+     self.use_mla_backend = (
+         self.model_config.attention_arch == AttentionArch.MLA
+         and not server_args.disable_mla
+     )
+     self.attention_chunk_size = model_config.attention_chunk_size

      # Model-specific adjustment
      self.model_specific_adjustment()
@@ -147,15 +153,18 @@
          "enable_dp_attention": server_args.enable_dp_attention,
          "enable_ep_moe": server_args.enable_ep_moe,
          "enable_deepep_moe": server_args.enable_deepep_moe,
+         "deepep_mode": server_args.deepep_mode,
          "device": server_args.device,
          "speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
          "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
-         "enable_flashinfer_mla": server_args.enable_flashinfer_mla,
          "enable_flashmla": server_args.enable_flashmla,
          "disable_radix_cache": server_args.disable_radix_cache,
          "flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
          "debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
          "debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
+         "n_share_experts_fusion": server_args.n_share_experts_fusion,
+         "disable_shared_experts_fusion": server_args.disable_shared_experts_fusion,
+         "use_mla_backend": self.use_mla_backend,
      }
  )

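The new entries in this hunk (`deepep_mode`, `n_share_experts_fusion`, `disable_shared_experts_fusion`, `use_mla_backend`) extend the set of flags that `ModelRunner` publishes to the rest of the runtime through `global_server_args_dict`, which (per the import change further down) lives in `sglang/srt/managers/schedule_batch.py`. A minimal, self-contained sketch of that publish/read pattern; `FakeServerArgs` and the helper functions below are illustrative stand-ins, not sglang code:

```python
from dataclasses import dataclass

# Local stand-in for the module-level dict that, in sglang, lives in
# sglang/srt/managers/schedule_batch.py as `global_server_args_dict`.
global_server_args_dict = {}


@dataclass
class FakeServerArgs:
    # Illustrative subset of ServerArgs; field names follow the keys added above.
    deepep_mode: str = "auto"
    n_share_experts_fusion: int = 0
    disable_shared_experts_fusion: bool = False


def publish_runner_flags(args: FakeServerArgs, use_mla_backend: bool) -> None:
    # Mirrors ModelRunner's global_server_args_dict.update({...}) call.
    global_server_args_dict.update(
        {
            "deepep_mode": args.deepep_mode,
            "n_share_experts_fusion": args.n_share_experts_fusion,
            "disable_shared_experts_fusion": args.disable_shared_experts_fusion,
            "use_mla_backend": use_mla_backend,
        }
    )


def consumer_reads_flag() -> bool:
    # Downstream modules read with a default so they work even if nothing was published.
    return bool(global_server_args_dict.get("use_mla_backend", False))


publish_runner_flags(FakeServerArgs(), use_mla_backend=True)
assert consumer_reads_flag()
```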
@@ -216,27 +225,38 @@ class ModelRunner:
  def model_specific_adjustment(self):
      server_args = self.server_args

-     if (
-         self.model_config.attention_arch == AttentionArch.MLA
-         and not server_args.disable_mla
-     ):
+     if server_args.enable_flashinfer_mla:
+         # TODO: remove this branch after enable_flashinfer_mla is deprecated
+         logger.info("MLA optimization is turned on. Use flashinfer backend.")
+         server_args.attention_backend = "flashinfer"
+     elif server_args.enable_flashmla:
+         # TODO: remove this branch after enable_flashmla is deprecated
+         logger.info("MLA optimization is turned on. Use flashmla decode.")
+         server_args.attention_backend = "flashmla"
+     elif server_args.attention_backend is None:
+         # By default, use flashinfer for non-mla attention and triton for mla attention
+         if not self.use_mla_backend:
+             server_args.attention_backend = (
+                 "flashinfer" if is_flashinfer_available() else "triton"
+             )
+         else:
+             server_args.attention_backend = "triton"
+         logger.info(
+             f"Attention backend not set. Use {server_args.attention_backend} backend by default."
+         )
+     elif self.use_mla_backend:
          # TODO: add MLA optimization on CPU
          if server_args.device != "cpu":
-             if server_args.enable_flashinfer_mla:
-                 logger.info(
-                     "MLA optimization is turned on. Use flashinfer mla backend."
-                 )
-                 server_args.attention_backend = "flashinfer_mla"
-             elif server_args.enable_flashmla:
-                 logger.info("MLA optimization is turned on. Use flashmla decode.")
-                 server_args.attention_backend = "flashmla"
-             elif server_args.attention_backend == "fa3":
+             if server_args.attention_backend in ["flashinfer", "fa3", "triton"]:
                  logger.info(
-                     f"MLA optimization is turned on. Use flash attention 3 backend."
+                     f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
                  )
              else:
-                 logger.info("MLA optimization is turned on. Use triton backend.")
-                 server_args.attention_backend = "triton"
+                 raise ValueError(
+                     f"Invalid attention backend for MLA: {server_args.attention_backend}"
+                 )
+         else:
+             raise ValueError(f"MLA optimization not supported on CPU.")

      if server_args.enable_double_sparsity:
          logger.info(
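The rewritten branch above changes how the attention backend is resolved: the deprecated `--enable-flashinfer-mla` / `--enable-flashmla` flags now just map onto generic `--attention-backend` values, an unset backend gets a default that depends on whether MLA is active, and an explicitly requested backend is validated against MLA support. A pure-function sketch of that precedence, with `flashinfer_available` standing in for `is_flashinfer_available()`:

```python
from typing import Optional


def select_attention_backend(
    requested: Optional[str],
    use_mla_backend: bool,
    enable_flashinfer_mla: bool = False,
    enable_flashmla: bool = False,
    device: str = "cuda",
    flashinfer_available: bool = True,
) -> str:
    # Deprecated convenience flags win first, exactly like the two elif branches above.
    if enable_flashinfer_mla:
        return "flashinfer"
    if enable_flashmla:
        return "flashmla"
    # No explicit backend: flashinfer for regular attention, triton once MLA is active.
    if requested is None:
        if not use_mla_backend:
            return "flashinfer" if flashinfer_available else "triton"
        return "triton"
    # Explicit backend plus MLA: only a known-good subset is accepted, and GPU only.
    if use_mla_backend:
        if device == "cpu":
            raise ValueError("MLA optimization not supported on CPU.")
        if requested not in ("flashinfer", "fa3", "triton"):
            raise ValueError(f"Invalid attention backend for MLA: {requested}")
    return requested


assert select_attention_backend(None, use_mla_backend=True) == "triton"
assert select_attention_backend("fa3", use_mla_backend=True) == "fa3"
```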
@@ -251,17 +271,16 @@
          self.init_double_sparsity_channel_config(server_args.ds_heavy_channel_type)

      if self.is_multimodal:
-         self.mem_fraction_static *= 0.95
+         self.mem_fraction_static *= 0.90
          logger.info(
              f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
              f"because this is a multimodal model."
          )

-     if self.model_config.hf_config.architectures == [
-         "MllamaForConditionalGeneration"
-     ]:
-         logger.info("Automatically turn off --chunked-prefill-size for mllama.")
-         server_args.chunked_prefill_size = -1
+         logger.info(
+             "Automatically turn off --chunked-prefill-size for multimodal model."
+         )
+         server_args.chunked_prefill_size = -1

      if self.model_config.hf_config.architectures == [
          "Qwen2VLForConditionalGeneration"
@@ -269,22 +288,11 @@
          "Qwen2_5_VLForConditionalGeneration"
      ]:
          # TODO: qwen2-vl series does not support radix cache now, set disable_radix_cache=True automatically
-         logger.info(
-             "Automatically turn off --chunked-prefill-size and disable radix cache for qwen-vl series."
-         )
-         server_args.chunked_prefill_size = -1
-         server_args.disable_radix_cache = True
-
-     if self.model_config.hf_config.architectures == ["DeepseekVL2ForCausalLM"]:
-         # TODO: deepseek-vl2 does not support radix cache now, set disable_radix_cache=True automatically
-         logger.info(
-             "Automatically turn off --chunked-prefill-size and disable radix cache for deepseek-vl2."
-         )
-         server_args.chunked_prefill_size = -1
+         logger.info("Automatically disable radix cache for qwen-vl series.")
          server_args.disable_radix_cache = True

      if server_args.enable_deepep_moe:
-         logger.info("DeepEP is turned on.")
+         logger.info(f"DeepEP is turned on. DeepEP mode: {server_args.deepep_mode}")

  def init_torch_distributed(self):
      logger.info("Init torch distributed begin.")
@@ -646,10 +654,7 @@ class ModelRunner:
      available_gpu_memory = get_available_gpu_memory(
          self.device, self.gpu_id, distributed=self.tp_size > 1
      )
-     if (
-         self.model_config.attention_arch == AttentionArch.MLA
-         and not self.server_args.disable_mla
-     ):
+     if self.use_mla_backend:
          cell_size = (
              (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
              * self.model_config.num_hidden_layers
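This hunk reuses the precomputed `self.use_mla_backend` flag when estimating how many tokens fit in the KV cache: under MLA, each token stores a single compressed latent of width `kv_lora_rank + qk_rope_head_dim` per layer instead of full per-head K/V. A rough sketch of that per-token cell size; the non-MLA formula, the dtype-size factor, and the default numbers are illustrative assumptions, not values taken from the diff:

```python
def kv_cell_size_bytes(
    use_mla_backend: bool,
    num_hidden_layers: int,
    kv_lora_rank: int = 512,
    qk_rope_head_dim: int = 64,
    num_kv_heads: int = 8,
    head_dim: int = 128,
    dtype_bytes: int = 2,  # e.g. fp16 / bf16
) -> int:
    """Per-token KV-cache footprint used when estimating the token budget."""
    if use_mla_backend:
        # One compressed latent per token per layer (matches the expression in the hunk).
        per_layer = kv_lora_rank + qk_rope_head_dim
    else:
        # Plain attention keeps separate K and V for every KV head (illustrative formula).
        per_layer = 2 * num_kv_heads * head_dim
    return per_layer * num_hidden_layers * dtype_bytes


# With DeepSeek-V2-like shapes the MLA layout is dramatically smaller per token.
print(kv_cell_size_bytes(True, 60), "vs", kv_cell_size_bytes(False, 60))
```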
@@ -760,10 +765,7 @@ class ModelRunner:
      # Draft worker shares req_to_token_pool with the target worker.
      assert self.is_draft_worker

-     if (
-         self.model_config.attention_arch == AttentionArch.MLA
-         and not self.server_args.disable_mla
-     ):
+     if self.use_mla_backend:
          self.token_to_kv_pool = MLATokenToKVPool(
              self.max_total_num_tokens,
              page_size=self.page_size,
@@ -834,14 +836,21 @@ class ModelRunner:
  def init_attention_backend(self):
      """Init attention kernel backend."""
      if self.server_args.attention_backend == "flashinfer":
-         from sglang.srt.layers.attention.flashinfer_backend import (
-             FlashInferAttnBackend,
-         )
+         if not self.use_mla_backend:
+             from sglang.srt.layers.attention.flashinfer_backend import (
+                 FlashInferAttnBackend,
+             )

-         # Init streams
-         if self.server_args.speculative_algorithm == "EAGLE":
-             self.plan_stream_for_flashinfer = torch.cuda.Stream()
-         self.attn_backend = FlashInferAttnBackend(self)
+             # Init streams
+             if self.server_args.speculative_algorithm == "EAGLE":
+                 self.plan_stream_for_flashinfer = torch.cuda.Stream()
+             self.attn_backend = FlashInferAttnBackend(self)
+         else:
+             from sglang.srt.layers.attention.flashinfer_mla_backend import (
+                 FlashInferMLAAttnBackend,
+             )
+
+             self.attn_backend = FlashInferMLAAttnBackend(self)
      elif self.server_args.attention_backend == "triton":
          assert self.sliding_window_size is None, (
              "Window attention is not supported in the triton attention backend. "
@@ -867,12 +876,6 @@ class ModelRunner:
          )

          self.attn_backend = TorchNativeAttnBackend(self)
-     elif self.server_args.attention_backend == "flashinfer_mla":
-         from sglang.srt.layers.attention.flashinfer_mla_backend import (
-             FlashInferMLAAttnBackend,
-         )
-
-         self.attn_backend = FlashInferMLAAttnBackend(self)
      elif self.server_args.attention_backend == "flashmla":
          from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend

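Taken together, the two `init_attention_backend` hunks fold the old standalone `flashinfer_mla` backend value into the `flashinfer` branch, which now chooses between `FlashInferAttnBackend` and `FlashInferMLAAttnBackend` based on `use_mla_backend`. A simplified sketch of that dispatch (imports stay inside the branches so an unused backend never has to be installed); the triton, torch_native, and fa3 branches are omitted here, so this is schematic rather than a drop-in replacement:

```python
def init_attention_backend_sketch(runner):
    backend = runner.server_args.attention_backend
    if backend == "flashinfer":
        if not runner.use_mla_backend:
            from sglang.srt.layers.attention.flashinfer_backend import (
                FlashInferAttnBackend,
            )

            return FlashInferAttnBackend(runner)
        from sglang.srt.layers.attention.flashinfer_mla_backend import (
            FlashInferMLAAttnBackend,
        )

        return FlashInferMLAAttnBackend(runner)
    if backend == "flashmla":
        from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend

        return FlashMLABackend(runner)
    raise ValueError(f"Unsupported attention backend in this sketch: {backend}")
```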
sglang/srt/model_loader/loader.py CHANGED
@@ -489,6 +489,14 @@ class DummyModelLoader(BaseModelLoader):
      # NOTE(woosuk): For accurate performance evaluation, we assign
      # random values to the weights.
      initialize_dummy_weights(model)
+
+     # Model weight loading consists of two stages:
+     # 1. Initial weight loading.
+     # 2. Post-processing of weights, including assigning specific member variables.
+     # For `dummy_init`, only the second stage is required.
+     if hasattr(model, "post_load_weights"):
+         model.post_load_weights()
+
      return model.eval()

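The `DummyModelLoader` change above makes dummy initialization run the same optional second stage as real weight loading: if the model defines `post_load_weights()`, it is invoked after the (random) weights exist. A toy, self-contained illustration of that hook pattern; the class and function names below are made up for the example:

```python
from typing import Optional

import torch
import torch.nn as nn


class TinyModel(nn.Module):
    """Made-up model that needs a post-processing stage after its weights exist."""

    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(8, 8)
        self.w_t: Optional[torch.Tensor] = None

    def post_load_weights(self) -> None:
        # Stage 2: derive member variables from the (possibly dummy) weights.
        self.w_t = self.proj.weight.t().contiguous()


def finalize(model: nn.Module) -> nn.Module:
    # Same guard as the loader hunk above: only call the hook if the model defines it.
    if hasattr(model, "post_load_weights"):
        model.post_load_weights()
    return model.eval()


m = finalize(TinyModel())
assert m.w_t is not None
```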
sglang/srt/models/clip.py CHANGED
@@ -17,7 +17,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.managers.schedule_batch import MultimodalInputs
  from sglang.srt.model_executor.model_runner import ForwardBatch
  from sglang.srt.model_loader.weight_utils import default_weight_loader
- from sglang.srt.utils import add_prefix
+ from sglang.srt.utils import add_prefix, flatten_nested_list


  class CLIPVisionEmbeddings(nn.Module):
@@ -368,7 +368,6 @@ class CLIPVisionTransformer(nn.Module):
      self,
      pixel_values: torch.Tensor,
  ) -> torch.Tensor:
-
      hidden_states = self.embeddings(pixel_values.to(self.device))
      hidden_states = self.pre_layrnorm(hidden_states)

@@ -456,12 +455,18 @@ class CLIPModel(nn.Module):
      get_embedding: bool = True,
  ):
      assert get_embedding, "CLIPEmbeddingModel is only used for embedding"
-     image_inputs = None
+     mm_inputs = []
      if forward_batch.mm_inputs is not None:
-         image_inputs = forward_batch.mm_inputs
-
-     if image_inputs is not None and image_inputs[0] is not None:
-         vision_outputs = self.vision_model(image_inputs[0].pixel_values)
+         mm_inputs = forward_batch.mm_inputs
+     pixel_values_list = [
+         item.pixel_values
+         for item in flatten_nested_list(
+             [mm_input.mm_items for mm_input in mm_inputs if mm_input is not None]
+         )
+     ]
+     if len(pixel_values_list) != 0:
+         pixel_values = torch.concat(pixel_values_list)
+         vision_outputs = self.vision_model(pixel_values)
          pooled_output = vision_outputs[:, 0, :]
          image_embeds = self.visual_projection(pooled_output)
          image_embeds = nn.functional.normalize(image_embeds, p=2, dim=1)
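In the CLIP embedding path above, pixel values are no longer read from only the first request; every `MultimodalDataItem` in the batch is flattened into one list and concatenated before a single vision-tower call. A small sketch of that gathering step, with a local stand-in for `flatten_nested_list` (the real helper is imported from `sglang.srt.utils` in this same hunk) and namespace objects standing in for the batch structures:

```python
from types import SimpleNamespace
from typing import Any, List, Optional

import torch


def flatten_nested_list(nested: List[Any]) -> List[Any]:
    # Stand-in for sglang.srt.utils.flatten_nested_list: flattens lists of lists.
    flat: List[Any] = []
    for x in nested:
        flat.extend(flatten_nested_list(x) if isinstance(x, list) else [x])
    return flat


def gather_pixel_values(mm_inputs: List[Any]) -> Optional[torch.Tensor]:
    # Collect pixel_values from every MultimodalDataItem across the whole batch.
    items = flatten_nested_list(
        [mm.mm_items for mm in mm_inputs if mm is not None]
    )
    pixel_values_list = [item.pixel_values for item in items]
    if not pixel_values_list:
        return None  # text-only batch: skip the vision tower entirely
    return torch.concat(pixel_values_list)


batch = [
    SimpleNamespace(mm_items=[SimpleNamespace(pixel_values=torch.zeros(1, 3, 224, 224))]),
    None,
]
assert gather_pixel_values(batch).shape[0] == 1
```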
sglang/srt/models/deepseek_janus_pro.py CHANGED
@@ -51,7 +51,7 @@ from sglang.srt.managers.mm_utils import (
      MultiModalityDataPaddingPatternTokenPairs,
      general_mm_embed_routine,
  )
- from sglang.srt.managers.schedule_batch import MultimodalInputs, global_server_args_dict
+ from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
  from sglang.srt.model_loader.weight_utils import default_weight_loader
  from sglang.srt.models.llama import LlamaForCausalLM
@@ -1959,8 +1959,8 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
      )
      self.logits_processor = LogitsProcessor(config)

- def get_image_feature(self, image_input: MultimodalInputs) -> torch.Tensor:
-     pixel_values = image_input.pixel_values
+ def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+     pixel_values = torch.concat([item.pixel_values for item in items], dim=0)
      bs, n = pixel_values.shape[0:2]
      pixel_values = pixel_values.to(
          device=self.vision_model.device, dtype=self.vision_model.dtype
@@ -1976,7 +1976,7 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
      return images_embeds

  def get_input_embeddings(self) -> nn.Embedding:
-     return self.language_model.model.embed_tokens
+     return self.language_model.get_input_embeddings()

  @torch.no_grad()
  def forward(
@@ -1984,23 +1984,18 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
      input_ids: torch.LongTensor,
      positions: torch.Tensor,
      forward_batch: ForwardBatch,
+     get_embedding: bool = False,
  ) -> torch.Tensor:
-
-     inputs_embeds = general_mm_embed_routine(
+     hidden_states = general_mm_embed_routine(
          input_ids=input_ids,
          forward_batch=forward_batch,
-         embed_tokens=self.get_input_embeddings(),
-         mm_data_embedding_func=self.get_image_feature,
-     )
-
-     return self.language_model(
-         input_ids=None,
+         image_data_embedding_func=self.get_image_feature,
+         language_model=self.language_model,
          positions=positions,
-         forward_batch=forward_batch,
-         input_embeds=inputs_embeds,
-         get_embedding=False,
      )

+     return hidden_states
+
  def prepare_gen_img_embeds(self, image_ids: torch.LongTensor):
      return self.gen_aligner(self.gen_embed(image_ids))
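One easy-to-miss change in this file is that `get_input_embeddings()` now delegates to `self.language_model.get_input_embeddings()` instead of reaching into `.model.embed_tokens`. A toy illustration of why the accessor is the more robust contract; all class names here are made up for the example:

```python
import torch.nn as nn


class TinyLM(nn.Module):
    # Made-up language model; like sglang's models it exposes get_input_embeddings().
    def __init__(self, vocab: int = 16, dim: int = 4) -> None:
        super().__init__()
        self.model = nn.Module()
        self.model.embed_tokens = nn.Embedding(vocab, dim)

    def get_input_embeddings(self) -> nn.Embedding:
        return self.model.embed_tokens


class TinyMultimodalWrapper(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.language_model = TinyLM()

    def get_input_embeddings(self) -> nn.Embedding:
        # Delegate through the accessor (as the diff now does) instead of hard-coding
        # the inner attribute path .model.embed_tokens, which may differ per model.
        return self.language_model.get_input_embeddings()


assert isinstance(TinyMultimodalWrapper().get_input_embeddings(), nn.Embedding)
```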