sglang-0.5.4-py3-none-any.whl → sglang-0.5.4.post2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195)
  1. sglang/bench_one_batch.py +149 -34
  2. sglang/bench_serving.py +73 -14
  3. sglang/compile_deep_gemm.py +13 -7
  4. sglang/launch_server.py +2 -0
  5. sglang/srt/batch_invariant_ops/__init__.py +2 -0
  6. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +221 -4
  7. sglang/srt/checkpoint_engine/__init__.py +9 -0
  8. sglang/srt/checkpoint_engine/update.py +317 -0
  9. sglang/srt/compilation/backend.py +1 -1
  10. sglang/srt/configs/__init__.py +2 -0
  11. sglang/srt/configs/deepseek_ocr.py +542 -10
  12. sglang/srt/configs/deepseekvl2.py +95 -194
  13. sglang/srt/configs/kimi_linear.py +160 -0
  14. sglang/srt/configs/mamba_utils.py +66 -0
  15. sglang/srt/configs/model_config.py +30 -7
  16. sglang/srt/constants.py +7 -0
  17. sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
  18. sglang/srt/disaggregation/decode.py +34 -6
  19. sglang/srt/disaggregation/nixl/conn.py +2 -2
  20. sglang/srt/disaggregation/prefill.py +25 -3
  21. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
  22. sglang/srt/distributed/parallel_state.py +9 -12
  23. sglang/srt/entrypoints/engine.py +31 -20
  24. sglang/srt/entrypoints/grpc_server.py +0 -1
  25. sglang/srt/entrypoints/http_server.py +94 -94
  26. sglang/srt/entrypoints/openai/protocol.py +7 -1
  27. sglang/srt/entrypoints/openai/serving_chat.py +42 -0
  28. sglang/srt/entrypoints/openai/serving_completions.py +10 -0
  29. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  30. sglang/srt/environ.py +23 -2
  31. sglang/srt/eplb/expert_distribution.py +64 -1
  32. sglang/srt/eplb/expert_location.py +106 -36
  33. sglang/srt/function_call/function_call_parser.py +2 -0
  34. sglang/srt/function_call/minimax_m2.py +367 -0
  35. sglang/srt/grpc/compile_proto.py +3 -0
  36. sglang/srt/layers/activation.py +6 -0
  37. sglang/srt/layers/attention/ascend_backend.py +233 -5
  38. sglang/srt/layers/attention/attention_registry.py +3 -0
  39. sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
  40. sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
  41. sglang/srt/layers/attention/fla/kda.py +1359 -0
  42. sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
  43. sglang/srt/layers/attention/flashattention_backend.py +19 -8
  44. sglang/srt/layers/attention/flashinfer_backend.py +10 -1
  45. sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -11
  46. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  47. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
  48. sglang/srt/layers/attention/mamba/mamba.py +20 -11
  49. sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
  50. sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
  51. sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
  52. sglang/srt/layers/attention/nsa/transform_index.py +1 -1
  53. sglang/srt/layers/attention/nsa_backend.py +157 -23
  54. sglang/srt/layers/attention/triton_backend.py +4 -1
  55. sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
  56. sglang/srt/layers/attention/trtllm_mla_backend.py +11 -15
  57. sglang/srt/layers/attention/utils.py +78 -0
  58. sglang/srt/layers/communicator.py +24 -1
  59. sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
  60. sglang/srt/layers/layernorm.py +35 -6
  61. sglang/srt/layers/logits_processor.py +9 -20
  62. sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
  63. sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
  64. sglang/srt/layers/moe/ep_moe/layer.py +78 -289
  65. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
  67. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
  68. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
  69. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
  70. sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
  71. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
  72. sglang/srt/layers/moe/moe_runner/deep_gemm.py +340 -55
  73. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  74. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  75. sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
  76. sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
  77. sglang/srt/layers/moe/token_dispatcher/deepep.py +25 -18
  78. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  79. sglang/srt/layers/moe/topk.py +35 -10
  80. sglang/srt/layers/moe/utils.py +3 -4
  81. sglang/srt/layers/pooler.py +21 -2
  82. sglang/srt/layers/quantization/__init__.py +13 -84
  83. sglang/srt/layers/quantization/auto_round.py +394 -0
  84. sglang/srt/layers/quantization/awq.py +0 -3
  85. sglang/srt/layers/quantization/base_config.py +7 -0
  86. sglang/srt/layers/quantization/fp8.py +68 -63
  87. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  88. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  89. sglang/srt/layers/quantization/gguf.py +566 -0
  90. sglang/srt/layers/quantization/modelopt_quant.py +168 -11
  91. sglang/srt/layers/quantization/mxfp4.py +30 -38
  92. sglang/srt/layers/quantization/unquant.py +23 -45
  93. sglang/srt/layers/quantization/w4afp8.py +38 -2
  94. sglang/srt/layers/radix_attention.py +5 -2
  95. sglang/srt/layers/rotary_embedding.py +130 -46
  96. sglang/srt/layers/sampler.py +12 -1
  97. sglang/srt/lora/lora_registry.py +9 -0
  98. sglang/srt/managers/async_mm_data_processor.py +122 -0
  99. sglang/srt/managers/data_parallel_controller.py +30 -3
  100. sglang/srt/managers/detokenizer_manager.py +3 -0
  101. sglang/srt/managers/io_struct.py +29 -4
  102. sglang/srt/managers/multi_tokenizer_mixin.py +22 -1
  103. sglang/srt/managers/schedule_batch.py +74 -15
  104. sglang/srt/managers/scheduler.py +185 -144
  105. sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
  106. sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
  107. sglang/srt/managers/scheduler_pp_mixin.py +7 -2
  108. sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
  109. sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
  110. sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
  111. sglang/srt/managers/session_controller.py +6 -5
  112. sglang/srt/managers/tokenizer_manager.py +165 -78
  113. sglang/srt/managers/tp_worker.py +24 -1
  114. sglang/srt/mem_cache/base_prefix_cache.py +23 -4
  115. sglang/srt/mem_cache/common.py +1 -0
  116. sglang/srt/mem_cache/hicache_storage.py +7 -1
  117. sglang/srt/mem_cache/memory_pool.py +253 -57
  118. sglang/srt/mem_cache/memory_pool_host.py +12 -5
  119. sglang/srt/mem_cache/radix_cache.py +4 -0
  120. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  121. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
  122. sglang/srt/metrics/collector.py +46 -3
  123. sglang/srt/model_executor/cuda_graph_runner.py +15 -3
  124. sglang/srt/model_executor/forward_batch_info.py +55 -14
  125. sglang/srt/model_executor/model_runner.py +77 -170
  126. sglang/srt/model_executor/npu_graph_runner.py +7 -3
  127. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
  128. sglang/srt/model_loader/weight_utils.py +1 -1
  129. sglang/srt/models/bailing_moe.py +9 -2
  130. sglang/srt/models/deepseek_nextn.py +11 -2
  131. sglang/srt/models/deepseek_v2.py +296 -78
  132. sglang/srt/models/glm4.py +391 -77
  133. sglang/srt/models/glm4_moe.py +322 -354
  134. sglang/srt/models/glm4_moe_nextn.py +4 -14
  135. sglang/srt/models/glm4v.py +196 -55
  136. sglang/srt/models/glm4v_moe.py +29 -197
  137. sglang/srt/models/gpt_oss.py +1 -10
  138. sglang/srt/models/kimi_linear.py +678 -0
  139. sglang/srt/models/llama4.py +1 -1
  140. sglang/srt/models/llama_eagle3.py +11 -1
  141. sglang/srt/models/longcat_flash.py +2 -2
  142. sglang/srt/models/minimax_m2.py +922 -0
  143. sglang/srt/models/nvila.py +355 -0
  144. sglang/srt/models/nvila_lite.py +184 -0
  145. sglang/srt/models/qwen2.py +23 -2
  146. sglang/srt/models/qwen2_moe.py +30 -15
  147. sglang/srt/models/qwen3.py +35 -5
  148. sglang/srt/models/qwen3_moe.py +18 -12
  149. sglang/srt/models/qwen3_next.py +7 -0
  150. sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
  151. sglang/srt/multimodal/processors/base_processor.py +1 -0
  152. sglang/srt/multimodal/processors/glm4v.py +1 -1
  153. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  154. sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
  155. sglang/srt/multiplex/multiplexing_mixin.py +209 -0
  156. sglang/srt/multiplex/pdmux_context.py +164 -0
  157. sglang/srt/parser/conversation.py +7 -1
  158. sglang/srt/parser/reasoning_parser.py +28 -1
  159. sglang/srt/sampling/custom_logit_processor.py +67 -1
  160. sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
  161. sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
  162. sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
  163. sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
  164. sglang/srt/server_args.py +459 -199
  165. sglang/srt/single_batch_overlap.py +2 -4
  166. sglang/srt/speculative/draft_utils.py +16 -0
  167. sglang/srt/speculative/eagle_info.py +42 -36
  168. sglang/srt/speculative/eagle_info_v2.py +68 -25
  169. sglang/srt/speculative/eagle_utils.py +261 -16
  170. sglang/srt/speculative/eagle_worker.py +11 -3
  171. sglang/srt/speculative/eagle_worker_v2.py +15 -9
  172. sglang/srt/speculative/spec_info.py +305 -31
  173. sglang/srt/speculative/spec_utils.py +44 -8
  174. sglang/srt/tracing/trace.py +121 -12
  175. sglang/srt/utils/common.py +142 -74
  176. sglang/srt/utils/hf_transformers_utils.py +38 -12
  177. sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
  178. sglang/test/kits/radix_cache_server_kit.py +50 -0
  179. sglang/test/runners.py +31 -7
  180. sglang/test/simple_eval_common.py +5 -3
  181. sglang/test/simple_eval_humaneval.py +1 -0
  182. sglang/test/simple_eval_math.py +1 -0
  183. sglang/test/simple_eval_mmlu.py +1 -0
  184. sglang/test/simple_eval_mmmu_vlm.py +1 -0
  185. sglang/test/test_deterministic.py +235 -12
  186. sglang/test/test_deterministic_utils.py +2 -1
  187. sglang/test/test_utils.py +7 -1
  188. sglang/version.py +1 -1
  189. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +15 -28
  190. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +194 -175
  191. sglang/srt/models/vila.py +0 -306
  192. /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
  193. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
  194. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
  195. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/token_dispatcher/__init__.py:

@@ -12,9 +12,9 @@ from sglang.srt.layers.moe.token_dispatcher.deepep import (
     DeepEPConfig,
     DeepEPDispatcher,
     DeepEPLLCombineInput,
-    DeepEPLLOutput,
+    DeepEPLLDispatchOutput,
     DeepEPNormalCombineInput,
-    DeepEPNormalOutput,
+    DeepEPNormalDispatchOutput,
 )
 from sglang.srt.layers.moe.token_dispatcher.mooncake import (
     MooncakeCombineInput,
@@ -44,8 +44,8 @@ __all__ = [
     "StandardCombineInput",
     "DeepEPConfig",
     "DeepEPDispatcher",
-    "DeepEPNormalOutput",
-    "DeepEPLLOutput",
+    "DeepEPNormalDispatchOutput",
+    "DeepEPLLDispatchOutput",
     "DeepEPLLCombineInput",
     "DeepEPNormalCombineInput",
 ]
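The dispatch-output classes gain a `DispatchOutput` suffix, so downstream imports need updating. A minimal sketch of the new imports (the package path is taken from the hunk above; everything else is illustrative):

```python
# New names after the rename; DeepEPLLOutput / DeepEPNormalOutput no longer exist.
from sglang.srt.layers.moe.token_dispatcher import (
    DeepEPLLDispatchOutput,
    DeepEPNormalDispatchOutput,
)
```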
sglang/srt/layers/moe/token_dispatcher/base.py:

@@ -9,9 +9,9 @@ import torch
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.token_dispatcher import (
         DeepEPLLCombineInput,
-        DeepEPLLOutput,
+        DeepEPLLDispatchOutput,
         DeepEPNormalCombineInput,
-        DeepEPNormalOutput,
+        DeepEPNormalDispatchOutput,
         StandardCombineInput,
         StandardDispatchOutput,
     )
@@ -28,22 +28,28 @@ class DispatchOutputChecker:
     ) -> TypeGuard[StandardDispatchOutput]:
         return dispatch_output.format.is_standard()

+    @staticmethod
+    def format_is_triton_kernels(
+        dispatch_output: DispatchOutput,
+    ) -> TypeGuard[StandardDispatchOutput]:
+        return dispatch_output.format.is_standard()
+
     @staticmethod
     def format_is_deepep_normal(
         dispatch_output: DispatchOutput,
-    ) -> TypeGuard[DeepEPNormalOutput]:
+    ) -> TypeGuard[DeepEPNormalDispatchOutput]:
         return dispatch_output.format.is_deepep_normal()

     @staticmethod
     def format_is_deepep_ll(
         dispatch_output: DispatchOutput,
-    ) -> TypeGuard[DeepEPLLOutput]:
+    ) -> TypeGuard[DeepEPLLDispatchOutput]:
         return dispatch_output.format.is_deepep_ll()

     @staticmethod
     def format_is_deepep(
         dispatch_output: DispatchOutput,
-    ) -> TypeGuard[Union[DeepEPNormalOutput, DeepEPLLOutput]]:
+    ) -> TypeGuard[Union[DeepEPNormalDispatchOutput, DeepEPLLDispatchOutput]]:
         return dispatch_output.format.is_deepep()
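These checker methods are `TypeGuard` helpers, so a caller that branches on them gets the narrowed dispatch-output type. A sketch of such a branch, assuming the import path `sglang.srt.layers.moe.token_dispatcher.base` and a `dispatch_output` produced elsewhere (both assumptions, not shown in this diff):

```python
# Illustrative only: branch on the renamed TypeGuard checkers.
from sglang.srt.layers.moe.token_dispatcher.base import DispatchOutputChecker


def describe_format(dispatch_output) -> str:
    if DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
        # Narrowed to DeepEPLLDispatchOutput by the TypeGuard.
        return "deepep_ll"
    if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
        # Narrowed to DeepEPNormalDispatchOutput.
        return "deepep_normal"
    if DispatchOutputChecker.format_is_standard(dispatch_output):
        return "standard"
    raise ValueError("unknown dispatch output format")
```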
sglang/srt/layers/moe/token_dispatcher/deepep.py:

@@ -58,7 +58,7 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
 logger = logging.getLogger(__name__)


-class DeepEPNormalOutput(NamedTuple):
+class DeepEPNormalDispatchOutput(NamedTuple):
     """DeepEP normal dispatch output."""

     hidden_states: torch.Tensor
@@ -72,7 +72,7 @@ class DeepEPNormalOutput(NamedTuple):
         return DispatchOutputFormat.DEEPEP_NORMAL


-class DeepEPLLOutput(NamedTuple):
+class DeepEPLLDispatchOutput(NamedTuple):
     """DeepEP low latency dispatch output."""

     hidden_states: torch.Tensor
@@ -87,14 +87,16 @@ class DeepEPLLOutput(NamedTuple):
         return DispatchOutputFormat.DEEPEP_LL


-assert isinstance(DeepEPNormalOutput, DispatchOutput)
-assert isinstance(DeepEPLLOutput, DispatchOutput)
+assert isinstance(DeepEPNormalDispatchOutput, DispatchOutput)
+assert isinstance(DeepEPLLDispatchOutput, DispatchOutput)


 class DeepEPNormalCombineInput(NamedTuple):
     """DeepEP normal combine input."""

-    pass
+    hidden_states: torch.Tensor
+    topk_ids: torch.Tensor
+    topk_weights: torch.Tensor

     @property
     def format(self) -> CombineInputFormat:
@@ -104,7 +106,9 @@ class DeepEPNormalCombineInput(NamedTuple):
 class DeepEPLLCombineInput(NamedTuple):
     """DeepEP low latency combine input."""

-    pass
+    hidden_states: torch.Tensor
+    topk_ids: torch.Tensor
+    topk_weights: torch.Tensor

     @property
     def format(self) -> CombineInputFormat:
@@ -327,7 +331,7 @@ class _DeepEPDispatcherImplBase:
         hidden_states: torch.Tensor,
         topk_ids: torch.Tensor,
         topk_weights: torch.Tensor,
-        overlap_args: Optional["CombineOverlapArgs"],
+        overlap_args: Optional[CombineOverlapArgs] = None,
     ):
         raise NotImplementedError

@@ -383,7 +387,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         else:
             hidden_states_scale = None

-        return DeepEPNormalOutput(
+        return DeepEPNormalDispatchOutput(
             hidden_states,
             hidden_states_scale,
             topk_ids,
@@ -457,7 +461,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         hidden_states: torch.Tensor,
         topk_ids: torch.Tensor,
         topk_weights: torch.Tensor,
-        overlap_args: Optional["CombineOverlapArgs"],
+        overlap_args: Optional[CombineOverlapArgs] = None,
     ):

         if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter or _is_npu:
@@ -562,7 +566,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         else:
             hidden_states_scale = None

-        deepep_output = DeepEPLLOutput(
+        deepep_output = DeepEPLLDispatchOutput(
             hidden_states,
             hidden_states_scale,
             topk_ids,
@@ -613,7 +617,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         hidden_states: torch.Tensor,
         topk_ids: torch.Tensor,
         topk_weights: torch.Tensor,
-        overlap_args: Optional["CombineOverlapArgs"],
+        overlap_args: Optional[CombineOverlapArgs] = None,
     ):
         hidden_states, event, hook = self._combine_core(
             hidden_states,
@@ -639,7 +643,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         hidden_states: torch.Tensor,
         topk_ids: torch.Tensor,
         topk_weights: torch.Tensor,
-        overlap_args: Optional["CombineOverlapArgs"],
+        overlap_args: Optional[CombineOverlapArgs] = None,
     ):
         buffer = self._get_buffer()

@@ -756,18 +760,21 @@ class DeepEPDispatcher(BaseDispatcher):
         del self._dispatch_intermediate_state
         return self._get_impl().dispatch_b(*inner_state)

-    def combine(self, *args, **kwargs) -> Tuple:
-        self.combine_a(*args, **kwargs)
+    def combine(
+        self,
+        combine_input: CombineInput,
+        overlap_args: Optional[CombineOverlapArgs] = None,
+    ) -> Tuple:
+        self.combine_a(combine_input, overlap_args)
         ret = self.combine_b()
         return ret

     def combine_a(
         self,
-        hidden_states: torch.Tensor,
-        topk_ids: torch.Tensor,
-        topk_weights: torch.Tensor,
-        overlap_args: Optional["CombineOverlapArgs"] = None,
+        combine_input: CombineInput,
+        overlap_args: Optional[CombineOverlapArgs] = None,
     ):
+        hidden_states, topk_ids, topk_weights = combine_input
         self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A)
         inner_state = self._get_impl().combine_a(
             hidden_states=hidden_states,
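The hunks above change `DeepEPDispatcher.combine()` from three positional tensors to a single `CombineInput`, and `DeepEPLLCombineInput` / `DeepEPNormalCombineInput` now carry `hidden_states`, `topk_ids`, and `topk_weights`. A hedged sketch of the new call shape; the dispatcher and tensors are assumed to be set up elsewhere and are not part of this diff:

```python
# Sketch of the 0.5.4.post2 combine() call shape; illustrative only.
from sglang.srt.layers.moe.token_dispatcher import DeepEPLLCombineInput


def combine_low_latency(dispatcher, hidden_states, topk_ids, topk_weights):
    combine_input = DeepEPLLCombineInput(
        hidden_states=hidden_states,
        topk_ids=topk_ids,
        topk_weights=topk_weights,
    )
    # overlap_args now defaults to None, so it can simply be omitted.
    return dispatcher.combine(combine_input)
```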
sglang/srt/layers/moe/token_dispatcher/standard.py:

@@ -88,7 +88,7 @@ class StandardDispatcher(BaseDispatcher):
             topk_output = topk_output._replace(
                 topk_ids=self.local_expert_mapping[topk_output.topk_ids]
             )
-        elif TopKOutputChecker.format_is_triton_kernel(topk_output):
+        elif TopKOutputChecker.format_is_triton_kernels(topk_output):
             raise NotImplementedError()

         return StandardDispatchOutput(
sglang/srt/layers/moe/topk.py:

@@ -111,10 +111,10 @@ class TopKOutputChecker:
         return topk_output.format.is_standard()

     @staticmethod
-    def format_is_triton_kernel(
+    def format_is_triton_kernels(
         topk_output: TopKOutput,
     ) -> TypeGuard[TritonKernelTopKOutput]:
-        return topk_output.format.is_triton_kernel()
+        return topk_output.format.is_triton_kernels()

     @staticmethod
     def format_is_bypassed(topk_output: TopKOutput) -> TypeGuard[BypassedTopKOutput]:
@@ -129,7 +129,7 @@ class TopKOutputFormat(Enum):
     def is_standard(self) -> bool:
         return self == TopKOutputFormat.STANDARD

-    def is_triton_kernel(self) -> bool:
+    def is_triton_kernels(self) -> bool:
         return self == TopKOutputFormat.TRITON_KERNEL

     def is_bypassed(self) -> bool:
@@ -254,7 +254,7 @@ class TopK(CustomOp):
     ) -> TopKOutput:
         if self.topk_config.output_format is not None:
             output_format = self.topk_config.output_format
-        elif get_moe_runner_backend().is_triton_kernel():
+        elif get_moe_runner_backend().is_triton_kernels():
             output_format = TopKOutputFormat.TRITON_KERNEL
         elif (
             should_use_flashinfer_trtllm_moe()
@@ -314,16 +314,41 @@ class TopK(CustomOp):
         num_token_non_padded: Optional[torch.Tensor] = None,
         expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
     ) -> TopKOutput:
-        global_num_experts = router_logits.shape[-1]

-        # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
-        if global_num_experts == 256:
+        use_grouped_topk = self.topk_config.use_grouped_topk
+        torch_native = self.topk_config.torch_native
+        renormalize = self.topk_config.renormalize

+        if not use_grouped_topk and not torch_native:
+            topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k_softmax(
+                router_logits,
+                k=self.topk_config.top_k,
+            )
+            topk_weights = topk_weights.to(torch.float32)
+
+            if renormalize:
+                topk_weights_sum = (
+                    topk_weights.sum(dim=-1, keepdim=True)
+                    if self.topk_config.num_fused_shared_experts == 0
+                    else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
+                )
+                topk_weights = topk_weights / topk_weights_sum
+
+            if expert_location_dispatch_info is not None:
+                topk_ids = topk_ids_logical_to_physical(
+                    topk_ids, expert_location_dispatch_info
+                )
+            get_global_expert_distribution_recorder().on_select_experts(
+                topk_ids=topk_ids
+            )
+
+            return StandardTopKOutput(topk_weights, topk_ids, _)
+        if use_grouped_topk and not torch_native and router_logits.shape[-1] == 256:
+            # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
             routed_scaling_factor = self.topk_config.routed_scaling_factor or 1
-            router_logits = router_logits.to(torch.float32)

             topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
-                router_logits,
+                router_logits.to(torch.float32),
                 k=self.topk_config.top_k,
                 bias=self.topk_config.correction_bias.to(torch.float32),
                 k_group=self.topk_config.topk_group,
@@ -335,7 +360,7 @@ class TopK(CustomOp):
                 eps=float(1e-20),
             )

-            if self.topk_config.renormalize:
+            if renormalize:
                 topk_weights_sum = (
                     topk_weights.sum(dim=-1, keepdim=True)
                     if self.topk_config.num_fused_shared_experts == 0
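The renormalization step that appears in both branches above divides by the sum of the routed weights only, when fused shared experts are present. A self-contained sketch of that arithmetic with made-up numbers, plain PyTorch, no NPU required:

```python
# Toy example of the renormalization logic above; values are invented.
import torch

topk_weights = torch.tensor([[0.5, 0.3, 0.2], [0.4, 0.4, 0.2]])
num_fused_shared_experts = 1  # last column plays the fused shared expert

topk_weights_sum = (
    topk_weights.sum(dim=-1, keepdim=True)
    if num_fused_shared_experts == 0
    else topk_weights[:, :-1].sum(dim=-1, keepdim=True)  # exclude the shared column
)
topk_weights = topk_weights / topk_weights_sum
# Routed (non-shared) weights now sum to 1 in each row:
# tensor([[0.6250, 0.3750, 0.2500],
#         [0.5000, 0.5000, 0.2500]])
```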
sglang/srt/layers/moe/utils.py:

@@ -51,7 +51,7 @@ class MoeRunnerBackend(Enum):
     AUTO = "auto"
     DEEP_GEMM = "deep_gemm"
     TRITON = "triton"
-    TRITON_KERNEL = "triton_kernel"
+    TRITON_KERNELS = "triton_kernel"
     FLASHINFER_TRTLLM = "flashinfer_trtllm"
     FLASHINFER_CUTLASS = "flashinfer_cutlass"
     FLASHINFER_MXFP4 = "flashinfer_mxfp4"
@@ -67,8 +67,8 @@ class MoeRunnerBackend(Enum):
     def is_triton(self):
         return self == MoeRunnerBackend.TRITON

-    def is_triton_kernel(self):
-        return self == MoeRunnerBackend.TRITON_KERNEL
+    def is_triton_kernels(self):
+        return self == MoeRunnerBackend.TRITON_KERNELS

     def is_flashinfer_trtllm(self):
         return self == MoeRunnerBackend.FLASHINFER_TRTLLM
@@ -152,7 +152,6 @@ def initialize_moe_config(server_args: ServerArgs):
 def get_moe_a2a_backend() -> MoeA2ABackend:
     global MOE_A2A_BACKEND
     if MOE_A2A_BACKEND is None:
-        logger.warning("MOE_A2A_BACKEND is not initialized, using default backend")
         MOE_A2A_BACKEND = MoeA2ABackend.NONE
     return MOE_A2A_BACKEND
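`MoeRunnerBackend.TRITON_KERNEL` becomes `TRITON_KERNELS` (its string value "triton_kernel" is unchanged), and the accessor is now `is_triton_kernels()`. A small sketch, assuming the enum lives in `sglang.srt.layers.moe.utils` as the file list suggests:

```python
# Illustrative lookup of the renamed enum member; import path is an assumption.
from sglang.srt.layers.moe.utils import MoeRunnerBackend

backend = MoeRunnerBackend("triton_kernel")   # value string unchanged by the rename
assert backend is MoeRunnerBackend.TRITON_KERNELS
assert backend.is_triton_kernels()
```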
sglang/srt/layers/pooler.py:

@@ -20,7 +20,9 @@ class PoolingType(IntEnum):

 @dataclass
 class EmbeddingPoolerOutput:
-    embeddings: torch.Tensor
+    # Pooler can return list[tensor] instead of tensor if the dimension of each tensor in the batch is different
+    # due to different per-request matryoshka dim truncation
+    embeddings: torch.Tensor | list[torch.Tensor]


 class Pooler(nn.Module):
@@ -42,6 +44,7 @@ class Pooler(nn.Module):
     def forward(
         self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
     ) -> EmbeddingPoolerOutput:
+
         if self.pooling_type == PoolingType.LAST:
             last_token_indices = torch.cumsum(forward_batch.extend_seq_lens, dim=0) - 1
             pooled_data = hidden_states[last_token_indices]
@@ -53,8 +56,24 @@ class Pooler(nn.Module):
         else:
             raise ValueError(f"Invalid pooling type: {self.pooling_type}")

+        if forward_batch.dimensions is not None:
+            all_same_dimensions = len(set(forward_batch.dimensions)) == 1
+            if all_same_dimensions:
+                pooled_data = pooled_data[..., : forward_batch.dimensions[0]]
+            else:
+                pooled_data = [
+                    tensor[..., :dim]
+                    for tensor, dim in zip(pooled_data, forward_batch.dimensions)
+                ]
+
         if self.normalize:
-            pooled_data = nn.functional.normalize(pooled_data, p=2, dim=1)
+            if isinstance(pooled_data, list):
+                pooled_data = [
+                    nn.functional.normalize(tensor, p=2, dim=-1)
+                    for tensor in pooled_data
+                ]
+            else:
+                pooled_data = nn.functional.normalize(pooled_data, p=2, dim=-1)

         return EmbeddingPoolerOutput(embeddings=pooled_data)
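The new block above truncates each pooled embedding to a per-request matryoshka dimension and normalizes along the last axis, returning a list when the requested dimensions differ. A self-contained illustration of that behavior in plain PyTorch; the batch and dimensions here are invented:

```python
# Stand-alone sketch of the per-request matryoshka truncation; data is made up.
import torch
import torch.nn.functional as F

pooled_data = torch.randn(3, 8)   # 3 pooled embeddings of dim 8
dimensions = [8, 4, 2]            # per-request truncation targets

if len(set(dimensions)) == 1:
    pooled_data = pooled_data[..., : dimensions[0]]
else:
    pooled_data = [t[..., :d] for t, d in zip(pooled_data, dimensions)]

# Normalize each (possibly differently sized) embedding along its last dim.
if isinstance(pooled_data, list):
    pooled_data = [F.normalize(t, p=2, dim=-1) for t in pooled_data]
else:
    pooled_data = F.normalize(pooled_data, p=2, dim=-1)

print([t.shape for t in pooled_data])  # [torch.Size([8]), torch.Size([4]), torch.Size([2])]
```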
sglang/srt/layers/quantization/__init__.py:

@@ -7,36 +7,16 @@ from typing import TYPE_CHECKING, Dict, Optional, Type

 import torch

-try:
-    from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
-    from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
-    from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
-    from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
-    from vllm.model_executor.layers.quantization.gguf import GGUFConfig
-    from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
-        GPTQMarlin24Config,
-    )
-    from vllm.model_executor.layers.quantization.marlin import MarlinConfig
-    from vllm.model_executor.layers.quantization.qqq import QQQConfig
-    from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
-
-    VLLM_AVAILABLE = True
-except ImportError as e:
-    VLLM_AVAILABLE = False
-    VLLM_IMPORT_ERROR = e
-
-    # Define empty classes as placeholders when vllm is not available
-    class DummyConfig:
-        def override_quantization_method(self, *args, **kwargs):
-            return None
-
-    AQLMConfig = BitsAndBytesConfig = CompressedTensorsConfig = DeepSpeedFPConfig = (
-        ExpertsInt8Config
-    ) = GGUFConfig = GPTQMarlin24Config = MarlinConfig = QQQConfig = Int8TpuConfig = (
-        DummyConfig
-    )
+
+# Define empty classes as placeholders when vllm is not available
+class DummyConfig:
+    def override_quantization_method(self, *args, **kwargs):
+        return None


+CompressedTensorsConfig = DummyConfig
+
+from sglang.srt.layers.quantization.auto_round import AutoRoundConfig
 from sglang.srt.layers.quantization.awq import AWQConfig, AWQMarlinConfig
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
@@ -45,6 +25,7 @@ from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
 )
 from sglang.srt.layers.quantization.fp8 import Fp8Config
 from sglang.srt.layers.quantization.fpgemm_fp8 import FBGEMMFp8Config
+from sglang.srt.layers.quantization.gguf import GGUFConfig
 from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
 from sglang.srt.layers.quantization.modelopt_quant import (
     ModelOptFp4Config,
@@ -64,7 +45,7 @@ _is_mxfp_supported = mxfp_supported()
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.topk import TopKOutput

-# Base quantization methods that don't depend on vllm
+# Base quantization methods
 BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "fp8": Fp8Config,
     "blockwise_int8": BlockInt8Config,
@@ -75,6 +56,7 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "w8a8_fp8": W8A8Fp8Config,
     "awq": AWQConfig,
     "awq_marlin": AWQMarlinConfig,
+    "gguf": GGUFConfig,
     "gptq": GPTQConfig,
     "gptq_marlin": GPTQMarlinConfig,
     "moe_wna16": MoeWNA16Config,
@@ -83,6 +65,7 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "w4afp8": W4AFp8Config,
     "petit_nvfp4": PetitNvFp4Config,
     "fbgemm_fp8": FBGEMMFp8Config,
+    "auto-round": AutoRoundConfig,
 }

@@ -102,20 +85,8 @@ elif _is_mxfp_supported and is_hip():
             "mxfp4": Mxfp4Config,
         }
     )
-# VLLM-dependent quantization methods
-VLLM_QUANTIZATION_METHODS = {
-    "aqlm": AQLMConfig,
-    "deepspeedfp": DeepSpeedFPConfig,
-    "tpu_int8": Int8TpuConfig,
-    "marlin": MarlinConfig,
-    "gguf": GGUFConfig,
-    "gptq_marlin_24": GPTQMarlin24Config,
-    "bitsandbytes": BitsAndBytesConfig,
-    "qqq": QQQConfig,
-    "experts_int8": ExpertsInt8Config,
-}

-QUANTIZATION_METHODS = {**BASE_QUANTIZATION_METHODS, **VLLM_QUANTIZATION_METHODS}
+QUANTIZATION_METHODS = {**BASE_QUANTIZATION_METHODS}


 def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
@@ -124,50 +95,8 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
             f"Invalid quantization method: {quantization}. "
             f"Available methods: {list(QUANTIZATION_METHODS.keys())}"
         )
-    if quantization in VLLM_QUANTIZATION_METHODS and not VLLM_AVAILABLE:
-        raise ValueError(
-            f"{quantization} quantization requires some operators from vllm. "
-            f"Please install vllm by `pip install vllm==0.9.0.1`\n"
-            f"Import error: {VLLM_IMPORT_ERROR}"
-        )

     return QUANTIZATION_METHODS[quantization]


 original_isinstance = builtins.isinstance
-
-
-def monkey_patch_isinstance_for_vllm_base_layer(reverse: bool = False):
-    """
-    Patch isinstance so that the `get_quant_method` in vllm's QuantizationConfig
-    can recognize sglang layers
-    """
-    if not VLLM_AVAILABLE:
-        return
-
-    if reverse:
-        builtins.isinstance = original_isinstance
-        return
-
-    from vllm.model_executor.layers.fused_moe import FusedMoE
-    from vllm.model_executor.layers.linear import LinearBase
-    from vllm.model_executor.layers.vocab_parallel_embedding import (
-        VocabParallelEmbedding,
-    )
-
-    from sglang.srt.layers.linear import LinearBase as PatchedLinearBase
-    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE as PatchedFusedMoE
-    from sglang.srt.layers.vocab_parallel_embedding import (
-        VocabParallelEmbedding as PatchedVocabParallelEmbedding,
-    )
-
-    def patched_isinstance(obj, classinfo):
-        if classinfo is LinearBase:
-            return original_isinstance(obj, PatchedLinearBase)
-        if classinfo is FusedMoE:
-            return original_isinstance(obj, PatchedFusedMoE)
-        if classinfo is VocabParallelEmbedding:
-            return original_isinstance(obj, PatchedVocabParallelEmbedding)
-        return original_isinstance(obj, classinfo)
-
-    builtins.isinstance = patched_isinstance
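With the vllm-backed registry removed, `QUANTIZATION_METHODS` is built from sglang's own configs, and "gguf" plus "auto-round" become native entries. A hedged sketch of looking them up; only the keys and the `get_quantization_config` name come from this diff, the rest is illustrative:

```python
# Illustrative registry lookups after the change above.
from sglang.srt.layers.quantization import get_quantization_config

gguf_cls = get_quantization_config("gguf")               # now sglang's own GGUFConfig
auto_round_cls = get_quantization_config("auto-round")   # new AutoRoundConfig entry
print(gguf_cls.__name__, auto_round_cls.__name__)
```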