sglang 0.4.4.post3__py3-none-any.whl → 0.4.4.post4__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (81)
  1. sglang/bench_serving.py +49 -7
  2. sglang/srt/_custom_ops.py +59 -92
  3. sglang/srt/configs/model_config.py +1 -0
  4. sglang/srt/constrained/base_grammar_backend.py +5 -1
  5. sglang/srt/custom_op.py +5 -0
  6. sglang/srt/distributed/device_communicators/custom_all_reduce.py +27 -79
  7. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  8. sglang/srt/entrypoints/engine.py +0 -5
  9. sglang/srt/layers/attention/flashattention_backend.py +394 -76
  10. sglang/srt/layers/attention/flashinfer_backend.py +5 -7
  11. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -3
  12. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  13. sglang/srt/layers/moe/ep_moe/kernels.py +142 -0
  14. sglang/srt/layers/moe/ep_moe/layer.py +79 -80
  15. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +382 -199
  16. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json +146 -0
  17. sglang/srt/layers/moe/fused_moe_triton/configs/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  18. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  19. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +403 -47
  20. sglang/srt/layers/moe/topk.py +49 -3
  21. sglang/srt/layers/quantization/__init__.py +4 -1
  22. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +2 -1
  23. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +34 -10
  24. sglang/srt/layers/quantization/fp8_utils.py +1 -4
  25. sglang/srt/layers/quantization/moe_wna16.py +501 -0
  26. sglang/srt/layers/quantization/utils.py +1 -1
  27. sglang/srt/layers/rotary_embedding.py +0 -12
  28. sglang/srt/managers/cache_controller.py +34 -11
  29. sglang/srt/managers/mm_utils.py +202 -156
  30. sglang/srt/managers/multimodal_processor.py +0 -2
  31. sglang/srt/managers/multimodal_processors/base_processor.py +45 -77
  32. sglang/srt/managers/multimodal_processors/clip.py +7 -26
  33. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +17 -58
  34. sglang/srt/managers/multimodal_processors/gemma3.py +12 -27
  35. sglang/srt/managers/multimodal_processors/janus_pro.py +21 -47
  36. sglang/srt/managers/multimodal_processors/llava.py +34 -14
  37. sglang/srt/managers/multimodal_processors/minicpm.py +35 -38
  38. sglang/srt/managers/multimodal_processors/mlama.py +10 -23
  39. sglang/srt/managers/multimodal_processors/qwen_vl.py +22 -45
  40. sglang/srt/managers/schedule_batch.py +185 -128
  41. sglang/srt/managers/scheduler.py +4 -4
  42. sglang/srt/managers/tokenizer_manager.py +1 -1
  43. sglang/srt/managers/utils.py +1 -6
  44. sglang/srt/mem_cache/hiradix_cache.py +62 -52
  45. sglang/srt/mem_cache/memory_pool.py +72 -6
  46. sglang/srt/mem_cache/paged_allocator.py +39 -0
  47. sglang/srt/metrics/collector.py +23 -53
  48. sglang/srt/model_executor/cuda_graph_runner.py +8 -6
  49. sglang/srt/model_executor/forward_batch_info.py +10 -10
  50. sglang/srt/model_executor/model_runner.py +59 -57
  51. sglang/srt/model_loader/loader.py +8 -0
  52. sglang/srt/models/clip.py +12 -7
  53. sglang/srt/models/deepseek_janus_pro.py +10 -15
  54. sglang/srt/models/deepseek_v2.py +212 -121
  55. sglang/srt/models/deepseek_vl2.py +105 -104
  56. sglang/srt/models/gemma3_mm.py +14 -80
  57. sglang/srt/models/llama.py +4 -1
  58. sglang/srt/models/llava.py +31 -19
  59. sglang/srt/models/llavavid.py +16 -7
  60. sglang/srt/models/minicpmo.py +63 -147
  61. sglang/srt/models/minicpmv.py +17 -27
  62. sglang/srt/models/mllama.py +29 -14
  63. sglang/srt/models/qwen2.py +9 -6
  64. sglang/srt/models/qwen2_5_vl.py +21 -31
  65. sglang/srt/models/qwen2_vl.py +20 -21
  66. sglang/srt/openai_api/adapter.py +18 -6
  67. sglang/srt/platforms/interface.py +371 -0
  68. sglang/srt/server_args.py +99 -14
  69. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -5
  70. sglang/srt/speculative/eagle_utils.py +140 -28
  71. sglang/srt/speculative/eagle_worker.py +93 -24
  72. sglang/srt/utils.py +104 -51
  73. sglang/test/test_custom_ops.py +55 -0
  74. sglang/test/test_utils.py +13 -26
  75. sglang/utils.py +2 -2
  76. sglang/version.py +1 -1
  77. {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/METADATA +4 -3
  78. {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/RECORD +81 -76
  79. {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/WHEEL +0 -0
  80. {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/licenses/LICENSE +0 -0
  81. {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/model_runner.py CHANGED
@@ -75,6 +75,7 @@ from sglang.srt.utils import (
     get_available_gpu_memory,
     init_custom_process_group,
     is_cuda,
+    is_flashinfer_available,
     is_hip,
     monkey_patch_p2p_access_check,
     monkey_patch_vllm_gguf_config,
@@ -123,6 +124,10 @@ class ModelRunner:
         self.page_size = server_args.page_size
         self.req_to_token_pool = req_to_token_pool
         self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
+        self.use_mla_backend = (
+            self.model_config.attention_arch == AttentionArch.MLA
+            and not server_args.disable_mla
+        )
 
         # Model-specific adjustment
         self.model_specific_adjustment()
@@ -147,15 +152,18 @@ class ModelRunner:
                 "enable_dp_attention": server_args.enable_dp_attention,
                 "enable_ep_moe": server_args.enable_ep_moe,
                 "enable_deepep_moe": server_args.enable_deepep_moe,
+                "deepep_mode": server_args.deepep_mode,
                 "device": server_args.device,
                 "speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
                 "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
-                "enable_flashinfer_mla": server_args.enable_flashinfer_mla,
                 "enable_flashmla": server_args.enable_flashmla,
                 "disable_radix_cache": server_args.disable_radix_cache,
                 "flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
                 "debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
                 "debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
+                "n_share_experts_fusion": server_args.n_share_experts_fusion,
+                "disable_shared_experts_fusion": server_args.disable_shared_experts_fusion,
+                "use_mla_backend": self.use_mla_backend,
             }
         )
 
@@ -216,27 +224,38 @@ class ModelRunner:
     def model_specific_adjustment(self):
         server_args = self.server_args
 
-        if (
-            self.model_config.attention_arch == AttentionArch.MLA
-            and not server_args.disable_mla
-        ):
+        if server_args.enable_flashinfer_mla:
+            # TODO: remove this branch after enable_flashinfer_mla is deprecated
+            logger.info("MLA optimization is turned on. Use flashinfer backend.")
+            server_args.attention_backend = "flashinfer"
+        elif server_args.enable_flashmla:
+            # TODO: remove this branch after enable_flashmla is deprecated
+            logger.info("MLA optimization is turned on. Use flashmla decode.")
+            server_args.attention_backend = "flashmla"
+        elif server_args.attention_backend is None:
+            # By default, use flashinfer for non-mla attention and triton for mla attention
+            if not self.use_mla_backend:
+                server_args.attention_backend = (
+                    "flashinfer" if is_flashinfer_available() else "triton"
+                )
+            else:
+                server_args.attention_backend = "triton"
+            logger.info(
+                f"Attention backend not set. Use {server_args.attention_backend} backend by default."
+            )
+        elif self.use_mla_backend:
             # TODO: add MLA optimization on CPU
             if server_args.device != "cpu":
-                if server_args.enable_flashinfer_mla:
-                    logger.info(
-                        "MLA optimization is turned on. Use flashinfer mla backend."
-                    )
-                    server_args.attention_backend = "flashinfer_mla"
-                elif server_args.enable_flashmla:
-                    logger.info("MLA optimization is turned on. Use flashmla decode.")
-                    server_args.attention_backend = "flashmla"
-                elif server_args.attention_backend == "fa3":
+                if server_args.attention_backend in ["flashinfer", "fa3", "triton"]:
                     logger.info(
-                        f"MLA optimization is turned on. Use flash attention 3 backend."
+                        f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
                     )
                 else:
-                    logger.info("MLA optimization is turned on. Use triton backend.")
-                    server_args.attention_backend = "triton"
+                    raise ValueError(
+                        f"Invalid attention backend for MLA: {server_args.attention_backend}"
+                    )
+            else:
+                raise ValueError(f"MLA optimization not supported on CPU.")
 
         if server_args.enable_double_sparsity:
             logger.info(
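For readers tracking the behavior change rather than the code: attention backend selection is now driven by --attention-backend, with the two MLA-specific flags kept only as deprecated aliases, and an explicitly chosen backend is validated against the MLA requirements. A minimal standalone sketch of the new resolution order (the resolve_attention_backend helper below is hypothetical and simplified, not sglang code):

    # Hypothetical helper mirroring the new selection logic in model_specific_adjustment.
    def resolve_attention_backend(
        attention_backend, use_mla_backend, flashinfer_available, device="cuda"
    ):
        if attention_backend is None:
            # Default: flashinfer for non-MLA models when installed, triton otherwise.
            if not use_mla_backend:
                return "flashinfer" if flashinfer_available else "triton"
            return "triton"
        if use_mla_backend:
            if device == "cpu":
                raise ValueError("MLA optimization not supported on CPU.")
            if attention_backend not in ("flashinfer", "fa3", "triton"):
                raise ValueError(f"Invalid attention backend for MLA: {attention_backend}")
        return attention_backend

    # An MLA model with no explicit backend falls back to triton;
    # an explicit "flashinfer" is now accepted for MLA as well.
    assert resolve_attention_backend(None, True, True) == "triton"
    assert resolve_attention_backend("flashinfer", True, True) == "flashinfer"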
@@ -251,17 +270,16 @@
             self.init_double_sparsity_channel_config(server_args.ds_heavy_channel_type)
 
         if self.is_multimodal:
-            self.mem_fraction_static *= 0.95
+            self.mem_fraction_static *= 0.90
             logger.info(
                 f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
                 f"because this is a multimodal model."
             )
 
-            if self.model_config.hf_config.architectures == [
-                "MllamaForConditionalGeneration"
-            ]:
-                logger.info("Automatically turn off --chunked-prefill-size for mllama.")
-                server_args.chunked_prefill_size = -1
+            logger.info(
+                "Automatically turn off --chunked-prefill-size for multimodal model."
+            )
+            server_args.chunked_prefill_size = -1
 
             if self.model_config.hf_config.architectures == [
                 "Qwen2VLForConditionalGeneration"
@@ -269,22 +287,11 @@
                 "Qwen2_5_VLForConditionalGeneration"
             ]:
                 # TODO: qwen2-vl series does not support radix cache now, set disable_radix_cache=True automatically
-                logger.info(
-                    "Automatically turn off --chunked-prefill-size and disable radix cache for qwen-vl series."
-                )
-                server_args.chunked_prefill_size = -1
-                server_args.disable_radix_cache = True
-
-            if self.model_config.hf_config.architectures == ["DeepseekVL2ForCausalLM"]:
-                # TODO: deepseek-vl2 does not support radix cache now, set disable_radix_cache=True automatically
-                logger.info(
-                    "Automatically turn off --chunked-prefill-size and disable radix cache for deepseek-vl2."
-                )
-                server_args.chunked_prefill_size = -1
+                logger.info("Automatically disable radix cache for qwen-vl series.")
                 server_args.disable_radix_cache = True
 
         if server_args.enable_deepep_moe:
-            logger.info("DeepEP is turned on.")
+            logger.info(f"DeepEP is turned on. DeepEP mode: {server_args.deepep_mode}")
 
     def init_torch_distributed(self):
         logger.info("Init torch distributed begin.")
@@ -646,10 +653,7 @@ class ModelRunner:
         available_gpu_memory = get_available_gpu_memory(
             self.device, self.gpu_id, distributed=self.tp_size > 1
         )
-        if (
-            self.model_config.attention_arch == AttentionArch.MLA
-            and not self.server_args.disable_mla
-        ):
+        if self.use_mla_backend:
             cell_size = (
                 (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
                 * self.model_config.num_hidden_layers
@@ -760,10 +764,7 @@ class ModelRunner:
             # Draft worker shares req_to_token_pool with the target worker.
             assert self.is_draft_worker
 
-        if (
-            self.model_config.attention_arch == AttentionArch.MLA
-            and not self.server_args.disable_mla
-        ):
+        if self.use_mla_backend:
             self.token_to_kv_pool = MLATokenToKVPool(
                 self.max_total_num_tokens,
                 page_size=self.page_size,
@@ -834,14 +835,21 @@ class ModelRunner:
     def init_attention_backend(self):
         """Init attention kernel backend."""
         if self.server_args.attention_backend == "flashinfer":
-            from sglang.srt.layers.attention.flashinfer_backend import (
-                FlashInferAttnBackend,
-            )
+            if not self.use_mla_backend:
+                from sglang.srt.layers.attention.flashinfer_backend import (
+                    FlashInferAttnBackend,
+                )
 
-            # Init streams
-            if self.server_args.speculative_algorithm == "EAGLE":
-                self.plan_stream_for_flashinfer = torch.cuda.Stream()
-            self.attn_backend = FlashInferAttnBackend(self)
+                # Init streams
+                if self.server_args.speculative_algorithm == "EAGLE":
+                    self.plan_stream_for_flashinfer = torch.cuda.Stream()
+                self.attn_backend = FlashInferAttnBackend(self)
+            else:
+                from sglang.srt.layers.attention.flashinfer_mla_backend import (
+                    FlashInferMLAAttnBackend,
+                )
+
+                self.attn_backend = FlashInferMLAAttnBackend(self)
         elif self.server_args.attention_backend == "triton":
             assert self.sliding_window_size is None, (
                 "Window attention is not supported in the triton attention backend. "
@@ -867,12 +875,6 @@ class ModelRunner:
             )
 
             self.attn_backend = TorchNativeAttnBackend(self)
-        elif self.server_args.attention_backend == "flashinfer_mla":
-            from sglang.srt.layers.attention.flashinfer_mla_backend import (
-                FlashInferMLAAttnBackend,
-            )
-
-            self.attn_backend = FlashInferMLAAttnBackend(self)
         elif self.server_args.attention_backend == "flashmla":
            from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend
 
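Taken together, the two hunks above mean the flashinfer backend value now serves both regular attention and MLA (dispatching to FlashInferMLAAttnBackend), and the separate flashinfer_mla value is removed. An MLA model previously launched with --enable-flashinfer-mla should now be launched with something along these lines (model path shown only as an example; the deprecated flag still works for the time being):

    python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --attention-backend flashinfer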
sglang/srt/model_loader/loader.py CHANGED
@@ -489,6 +489,14 @@ class DummyModelLoader(BaseModelLoader):
             # NOTE(woosuk): For accurate performance evaluation, we assign
             # random values to the weights.
             initialize_dummy_weights(model)
+
+            # Model weight loading consists of two stages:
+            # 1. Initial weight loading.
+            # 2. Post-processing of weights, including assigning specific member variables.
+            # For `dummy_init`, only the second stage is required.
+            if hasattr(model, "post_load_weights"):
+                model.post_load_weights()
+
         return model.eval()
 
 
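With the new hook, a model can rely on its weight post-processing step even when it is instantiated with dummy weights. A minimal sketch of a model that opts in (the class below is hypothetical; only the post_load_weights method name and the hasattr check come from this diff):

    import torch
    from torch import nn

    class TinyModel(nn.Module):
        # Hypothetical model: the dummy loader only checks hasattr(model, "post_load_weights").
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(8, 8)
            self.w_scale = None  # derived in stage 2

        def post_load_weights(self):
            # Stage 2: derive member variables from the (possibly dummy-initialized) weights.
            self.w_scale = self.proj.weight.detach().abs().max()

    model = TinyModel()
    if hasattr(model, "post_load_weights"):
        model.post_load_weights()
    assert model.w_scale is not None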
sglang/srt/models/clip.py CHANGED
@@ -17,7 +17,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.managers.schedule_batch import MultimodalInputs
 from sglang.srt.model_executor.model_runner import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import add_prefix
+from sglang.srt.utils import add_prefix, flatten_nested_list
 
 
 class CLIPVisionEmbeddings(nn.Module):
@@ -368,7 +368,6 @@ class CLIPVisionTransformer(nn.Module):
         self,
         pixel_values: torch.Tensor,
     ) -> torch.Tensor:
-
         hidden_states = self.embeddings(pixel_values.to(self.device))
         hidden_states = self.pre_layrnorm(hidden_states)
 
@@ -456,12 +455,18 @@ class CLIPModel(nn.Module):
         get_embedding: bool = True,
     ):
         assert get_embedding, "CLIPEmbeddingModel is only used for embedding"
-        image_inputs = None
+        mm_inputs = []
         if forward_batch.mm_inputs is not None:
-            image_inputs = forward_batch.mm_inputs
-
-        if image_inputs is not None and image_inputs[0] is not None:
-            vision_outputs = self.vision_model(image_inputs[0].pixel_values)
+            mm_inputs = forward_batch.mm_inputs
+        pixel_values_list = [
+            item.pixel_values
+            for item in flatten_nested_list(
+                [mm_input.mm_items for mm_input in mm_inputs if mm_input is not None]
+            )
+        ]
+        if len(pixel_values_list) != 0:
+            pixel_values = torch.concat(pixel_values_list)
+            vision_outputs = self.vision_model(pixel_values)
             pooled_output = vision_outputs[:, 0, :]
             image_embeds = self.visual_projection(pooled_output)
             image_embeds = nn.functional.normalize(image_embeds, p=2, dim=1)
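The rewritten CLIP embedding path now gathers pixel_values from every multimodal item across the whole batch and runs the vision tower once on the concatenated tensor, instead of reading only the first request's input. A rough shape-level sketch of that gathering step, using plain tensors in place of MultimodalDataItem objects and assuming flatten_nested_list simply flattens the nested per-request lists:

    import torch

    def flatten_one_level(nested):
        # Assumed behavior of sglang's flatten_nested_list for this illustration.
        return [item for sub in nested for item in sub]

    # Two requests in the batch, holding 2 and 1 image items respectively.
    mm_items_per_request = [
        [torch.randn(1, 3, 224, 224), torch.randn(1, 3, 224, 224)],
        [torch.randn(1, 3, 224, 224)],
    ]
    pixel_values_list = flatten_one_level(mm_items_per_request)
    pixel_values = torch.concat(pixel_values_list)  # a single (3, 3, 224, 224) batch
    assert pixel_values.shape[0] == 3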
sglang/srt/models/deepseek_janus_pro.py CHANGED
@@ -51,7 +51,7 @@ from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
     general_mm_embed_routine,
 )
-from sglang.srt.managers.schedule_batch import MultimodalInputs, global_server_args_dict
+from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM
@@ -1959,8 +1959,8 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
         )
         self.logits_processor = LogitsProcessor(config)
 
-    def get_image_feature(self, image_input: MultimodalInputs) -> torch.Tensor:
-        pixel_values = image_input.pixel_values
+    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        pixel_values = torch.concat([item.pixel_values for item in items], dim=0)
         bs, n = pixel_values.shape[0:2]
         pixel_values = pixel_values.to(
             device=self.vision_model.device, dtype=self.vision_model.dtype
@@ -1976,7 +1976,7 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
         return images_embeds
 
     def get_input_embeddings(self) -> nn.Embedding:
-        return self.language_model.model.embed_tokens
+        return self.language_model.get_input_embeddings()
 
     @torch.no_grad()
     def forward(
@@ -1984,23 +1984,18 @@
         input_ids: torch.LongTensor,
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
+        get_embedding: bool = False,
     ) -> torch.Tensor:
-
-        inputs_embeds = general_mm_embed_routine(
+        hidden_states = general_mm_embed_routine(
             input_ids=input_ids,
             forward_batch=forward_batch,
-            embed_tokens=self.get_input_embeddings(),
-            mm_data_embedding_func=self.get_image_feature,
-        )
-
-        return self.language_model(
-            input_ids=None,
+            image_data_embedding_func=self.get_image_feature,
+            language_model=self.language_model,
             positions=positions,
-            forward_batch=forward_batch,
-            input_embeds=inputs_embeds,
-            get_embedding=False,
         )
 
+        return hidden_states
+
     def prepare_gen_img_embeds(self, image_ids: torch.LongTensor):
         return self.gen_aligner(self.gen_embed(image_ids))