sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. sglang/bench_offline_throughput.py +10 -8
  2. sglang/bench_one_batch.py +7 -6
  3. sglang/bench_one_batch_server.py +157 -21
  4. sglang/bench_serving.py +137 -59
  5. sglang/compile_deep_gemm.py +5 -5
  6. sglang/eval/loogle_eval.py +157 -0
  7. sglang/lang/chat_template.py +78 -78
  8. sglang/lang/tracer.py +1 -1
  9. sglang/srt/code_completion_parser.py +1 -1
  10. sglang/srt/configs/deepseekvl2.py +2 -2
  11. sglang/srt/configs/model_config.py +40 -28
  12. sglang/srt/constrained/base_grammar_backend.py +55 -72
  13. sglang/srt/constrained/llguidance_backend.py +25 -21
  14. sglang/srt/constrained/outlines_backend.py +27 -26
  15. sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
  16. sglang/srt/constrained/xgrammar_backend.py +69 -43
  17. sglang/srt/conversation.py +49 -44
  18. sglang/srt/disaggregation/base/conn.py +1 -0
  19. sglang/srt/disaggregation/decode.py +129 -135
  20. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
  21. sglang/srt/disaggregation/fake/conn.py +3 -13
  22. sglang/srt/disaggregation/kv_events.py +357 -0
  23. sglang/srt/disaggregation/mini_lb.py +57 -24
  24. sglang/srt/disaggregation/mooncake/conn.py +238 -122
  25. sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
  26. sglang/srt/disaggregation/nixl/conn.py +10 -19
  27. sglang/srt/disaggregation/prefill.py +132 -47
  28. sglang/srt/disaggregation/utils.py +123 -6
  29. sglang/srt/distributed/utils.py +3 -3
  30. sglang/srt/entrypoints/EngineBase.py +5 -0
  31. sglang/srt/entrypoints/engine.py +44 -9
  32. sglang/srt/entrypoints/http_server.py +23 -6
  33. sglang/srt/entrypoints/http_server_engine.py +5 -2
  34. sglang/srt/function_call/base_format_detector.py +250 -0
  35. sglang/srt/function_call/core_types.py +34 -0
  36. sglang/srt/function_call/deepseekv3_detector.py +157 -0
  37. sglang/srt/function_call/ebnf_composer.py +234 -0
  38. sglang/srt/function_call/function_call_parser.py +175 -0
  39. sglang/srt/function_call/llama32_detector.py +74 -0
  40. sglang/srt/function_call/mistral_detector.py +84 -0
  41. sglang/srt/function_call/pythonic_detector.py +163 -0
  42. sglang/srt/function_call/qwen25_detector.py +67 -0
  43. sglang/srt/function_call/utils.py +35 -0
  44. sglang/srt/hf_transformers_utils.py +46 -7
  45. sglang/srt/layers/attention/aiter_backend.py +513 -0
  46. sglang/srt/layers/attention/flashattention_backend.py +64 -18
  47. sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
  48. sglang/srt/layers/attention/flashmla_backend.py +340 -78
  49. sglang/srt/layers/attention/triton_backend.py +3 -0
  50. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
  51. sglang/srt/layers/attention/utils.py +6 -4
  52. sglang/srt/layers/attention/vision.py +1 -1
  53. sglang/srt/layers/communicator.py +451 -0
  54. sglang/srt/layers/dp_attention.py +61 -21
  55. sglang/srt/layers/layernorm.py +1 -1
  56. sglang/srt/layers/logits_processor.py +46 -11
  57. sglang/srt/layers/moe/cutlass_moe.py +207 -0
  58. sglang/srt/layers/moe/ep_moe/kernels.py +34 -12
  59. sglang/srt/layers/moe/ep_moe/layer.py +105 -51
  60. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
  63. sglang/srt/layers/moe/topk.py +67 -10
  64. sglang/srt/layers/multimodal.py +70 -0
  65. sglang/srt/layers/quantization/__init__.py +8 -3
  66. sglang/srt/layers/quantization/blockwise_int8.py +2 -2
  67. sglang/srt/layers/quantization/deep_gemm.py +77 -74
  68. sglang/srt/layers/quantization/fp8.py +92 -2
  69. sglang/srt/layers/quantization/fp8_kernel.py +3 -3
  70. sglang/srt/layers/quantization/fp8_utils.py +6 -0
  71. sglang/srt/layers/quantization/gptq.py +298 -6
  72. sglang/srt/layers/quantization/int8_kernel.py +20 -7
  73. sglang/srt/layers/quantization/qoq.py +244 -0
  74. sglang/srt/layers/sampler.py +0 -4
  75. sglang/srt/layers/vocab_parallel_embedding.py +18 -7
  76. sglang/srt/lora/lora_manager.py +2 -4
  77. sglang/srt/lora/mem_pool.py +4 -4
  78. sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
  79. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  80. sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
  81. sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
  82. sglang/srt/lora/utils.py +1 -1
  83. sglang/srt/managers/data_parallel_controller.py +3 -3
  84. sglang/srt/managers/deepseek_eplb.py +278 -0
  85. sglang/srt/managers/detokenizer_manager.py +21 -8
  86. sglang/srt/managers/eplb_manager.py +55 -0
  87. sglang/srt/managers/expert_distribution.py +704 -56
  88. sglang/srt/managers/expert_location.py +394 -0
  89. sglang/srt/managers/expert_location_dispatch.py +91 -0
  90. sglang/srt/managers/io_struct.py +19 -4
  91. sglang/srt/managers/mm_utils.py +294 -140
  92. sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
  93. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
  94. sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
  95. sglang/srt/managers/multimodal_processors/internvl.py +14 -5
  96. sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
  97. sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
  98. sglang/srt/managers/multimodal_processors/llava.py +46 -0
  99. sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
  100. sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
  101. sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
  102. sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
  103. sglang/srt/managers/schedule_batch.py +122 -42
  104. sglang/srt/managers/schedule_policy.py +1 -5
  105. sglang/srt/managers/scheduler.py +205 -138
  106. sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
  107. sglang/srt/managers/session_controller.py +1 -1
  108. sglang/srt/managers/tokenizer_manager.py +232 -58
  109. sglang/srt/managers/tp_worker.py +12 -9
  110. sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
  111. sglang/srt/mem_cache/base_prefix_cache.py +3 -0
  112. sglang/srt/mem_cache/chunk_cache.py +3 -1
  113. sglang/srt/mem_cache/hiradix_cache.py +4 -4
  114. sglang/srt/mem_cache/memory_pool.py +76 -52
  115. sglang/srt/mem_cache/multimodal_cache.py +45 -0
  116. sglang/srt/mem_cache/radix_cache.py +58 -5
  117. sglang/srt/metrics/collector.py +314 -39
  118. sglang/srt/mm_utils.py +10 -0
  119. sglang/srt/model_executor/cuda_graph_runner.py +29 -19
  120. sglang/srt/model_executor/expert_location_updater.py +422 -0
  121. sglang/srt/model_executor/forward_batch_info.py +5 -1
  122. sglang/srt/model_executor/model_runner.py +163 -68
  123. sglang/srt/model_loader/loader.py +10 -6
  124. sglang/srt/models/clip.py +5 -1
  125. sglang/srt/models/deepseek_janus_pro.py +2 -2
  126. sglang/srt/models/deepseek_v2.py +308 -351
  127. sglang/srt/models/exaone.py +8 -3
  128. sglang/srt/models/gemma3_mm.py +70 -33
  129. sglang/srt/models/llama.py +2 -0
  130. sglang/srt/models/llama4.py +15 -8
  131. sglang/srt/models/llava.py +258 -7
  132. sglang/srt/models/mimo_mtp.py +220 -0
  133. sglang/srt/models/minicpmo.py +5 -12
  134. sglang/srt/models/mistral.py +71 -1
  135. sglang/srt/models/mixtral.py +98 -34
  136. sglang/srt/models/mllama.py +3 -3
  137. sglang/srt/models/pixtral.py +467 -0
  138. sglang/srt/models/qwen2.py +95 -26
  139. sglang/srt/models/qwen2_5_vl.py +8 -0
  140. sglang/srt/models/qwen2_moe.py +330 -60
  141. sglang/srt/models/qwen2_vl.py +6 -0
  142. sglang/srt/models/qwen3.py +52 -10
  143. sglang/srt/models/qwen3_moe.py +411 -48
  144. sglang/srt/models/roberta.py +1 -1
  145. sglang/srt/models/siglip.py +294 -0
  146. sglang/srt/models/torch_native_llama.py +1 -1
  147. sglang/srt/openai_api/adapter.py +58 -20
  148. sglang/srt/openai_api/protocol.py +6 -8
  149. sglang/srt/operations.py +154 -0
  150. sglang/srt/operations_strategy.py +31 -0
  151. sglang/srt/reasoning_parser.py +3 -3
  152. sglang/srt/sampling/custom_logit_processor.py +18 -3
  153. sglang/srt/sampling/sampling_batch_info.py +4 -56
  154. sglang/srt/sampling/sampling_params.py +2 -2
  155. sglang/srt/server_args.py +162 -22
  156. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  157. sglang/srt/speculative/eagle_utils.py +138 -7
  158. sglang/srt/speculative/eagle_worker.py +69 -21
  159. sglang/srt/utils.py +74 -17
  160. sglang/test/few_shot_gsm8k.py +2 -2
  161. sglang/test/few_shot_gsm8k_engine.py +2 -2
  162. sglang/test/run_eval.py +2 -2
  163. sglang/test/runners.py +8 -1
  164. sglang/test/send_one.py +13 -3
  165. sglang/test/simple_eval_common.py +1 -1
  166. sglang/test/simple_eval_humaneval.py +1 -1
  167. sglang/test/test_cutlass_moe.py +278 -0
  168. sglang/test/test_programs.py +5 -5
  169. sglang/test/test_utils.py +55 -14
  170. sglang/utils.py +3 -3
  171. sglang/version.py +1 -1
  172. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +23 -13
  173. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +178 -149
  174. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
  175. sglang/srt/function_call_parser.py +0 -858
  176. sglang/srt/platforms/interface.py +0 -371
  177. /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
  178. /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
  179. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
  180. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
@@ -23,15 +23,17 @@ import triton.language as tl
 from torch import nn
 
 from sglang.srt.distributed import (
-    get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_gather,
 )
 from sglang.srt.layers.dp_attention import (
+    attn_tp_all_gather,
     dp_gather_replicate,
     dp_scatter,
-    get_attention_dp_rank,
     get_attention_dp_size,
+    get_attention_tp_size,
+    get_local_attention_dp_rank,
+    get_local_attention_dp_size,
 )
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -45,6 +47,18 @@ from sglang.srt.utils import dump_to_file
 logger = logging.getLogger(__name__)
 
 
+from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_executor.forward_batch_info import (
+    CaptureHiddenMode,
+    ForwardBatch,
+    ForwardMode,
+)
+from sglang.srt.utils import dump_to_file
+
+logger = logging.getLogger(__name__)
+
+
 @dataclasses.dataclass
 class LogitsProcessorOutput:
     ## Part 1: This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
@@ -169,7 +183,7 @@ class LogitsMetadata:
             return
 
         cumtokens = torch.cumsum(self.global_num_tokens_for_logprob_gpu, dim=0)
-        dp_rank = get_attention_dp_rank()
+        dp_rank = get_local_attention_dp_rank()
         if dp_rank == 0:
             dp_local_start_pos = torch.zeros_like(
                 self.global_num_tokens_for_logprob_gpu[0]
@@ -198,12 +212,20 @@ class LogitsProcessor(nn.Module):
         super().__init__()
         self.config = config
         self.logit_scale = logit_scale
-        self.do_tensor_parallel_all_gather = (
-            not skip_all_gather and get_tensor_model_parallel_world_size() > 1
-        )
-        self.do_tensor_parallel_all_gather_dp_attn = (
-            self.do_tensor_parallel_all_gather and get_attention_dp_size() != 1
-        )
+        self.use_attn_tp_group = global_server_args_dict["enable_dp_lm_head"]
+        if self.use_attn_tp_group:
+            self.attn_tp_size = get_attention_tp_size()
+            self.do_tensor_parallel_all_gather = (
+                not skip_all_gather and self.attn_tp_size > 1
+            )
+            self.do_tensor_parallel_all_gather_dp_attn = False
+        else:
+            self.do_tensor_parallel_all_gather = (
+                not skip_all_gather and get_tensor_model_parallel_world_size() > 1
+            )
+            self.do_tensor_parallel_all_gather_dp_attn = (
+                self.do_tensor_parallel_all_gather and get_attention_dp_size() != 1
+            )
         self.final_logit_softcapping = getattr(
             self.config, "final_logit_softcapping", None
         )
@@ -315,7 +337,8 @@ class LogitsProcessor(nn.Module):
 
         if self.debug_tensor_dump_output_folder:
             assert (
-                not self.do_tensor_parallel_all_gather or get_attention_dp_size() == 1
+                not self.do_tensor_parallel_all_gather
+                or get_local_attention_dp_size() == 1
             ), "dp attention + sharded lm_head doesn't support full logits"
             full_logits = self._get_logits(hidden_states, lm_head, logits_metadata)
             dump_to_file(self.debug_tensor_dump_output_folder, "logits", full_logits)
@@ -442,7 +465,19 @@ class LogitsProcessor(nn.Module):
             logits.mul_(self.logit_scale)
 
         if self.do_tensor_parallel_all_gather:
-            logits = tensor_model_parallel_all_gather(logits)
+            if self.use_attn_tp_group:
+                global_logits = torch.empty(
+                    (self.config.vocab_size, logits.shape[0]),
+                    device=logits.device,
+                    dtype=logits.dtype,
+                )
+                global_logits = global_logits.T
+                attn_tp_all_gather(
+                    list(global_logits.tensor_split(self.attn_tp_size, dim=-1)), logits
+                )
+                logits = global_logits
+            else:
+                logits = tensor_model_parallel_all_gather(logits)
 
         if self.do_tensor_parallel_all_gather_dp_attn:
             logits, global_logits = (
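
Note on the new `use_attn_tp_group` branch: the gather buffer is allocated as `(vocab_size, m)` and addressed through its transpose, so that splitting the `(m, vocab_size)` view along the vocab dimension yields one contiguous slab per rank of the attention TP group. The snippet below is a minimal single-process sketch of that layout trick; `attn_tp_all_gather` is a collective, so the plain `copy_` loop and the sizes here are only stand-ins.

import torch

attn_tp_size, m, vocab = 4, 3, 16          # hypothetical sizes
shard = vocab // attn_tp_size

# Per-rank logits shards of shape (m, vocab // attn_tp_size).
shards = [torch.randn(m, shard) for _ in range(attn_tp_size)]

# Allocate (vocab, m), view its transpose as (m, vocab), and split along the
# vocab dim so each destination view is a contiguous slab of the buffer.
global_logits = torch.empty((vocab, m)).T
dests = list(global_logits.tensor_split(attn_tp_size, dim=-1))

# attn_tp_all_gather would fill one destination per rank; a local copy stands in.
for dst, src in zip(dests, shards):
    dst.copy_(src)

assert torch.equal(global_logits, torch.cat(shards, dim=-1))
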
@@ -0,0 +1,207 @@
+"""Cutlass MoE kernel."""
+
+import functools
+import json
+import logging
+import os
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+import torch
+
+from sglang.srt.utils import is_cuda
+
+_is_cuda = is_cuda()
+if _is_cuda:
+    import sgl_kernel
+    from sgl_kernel import (
+        fp8_blockwise_scaled_grouped_mm,
+        prepare_moe_input,
+        silu_and_mul,
+    )
+
+
+def cutlass_fused_experts(
+    a: torch.Tensor,
+    w1_q: torch.Tensor,
+    w2_q: torch.Tensor,
+    w1_scale: torch.Tensor,
+    w2_scale: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    a1_strides: torch.Tensor,
+    c1_strides: torch.Tensor,
+    a2_strides: torch.Tensor,
+    c2_strides: torch.Tensor,
+    workspace: torch.Tensor,
+    a_ptrs: torch.Tensor,
+    b_ptrs: torch.Tensor,
+    out_ptrs: torch.Tensor,
+    a_scales_ptrs: torch.Tensor,
+    b_scales_ptrs: torch.Tensor,
+    expert_offsets: torch.Tensor,
+    problem_sizes1: torch.Tensor,
+    problem_sizes2: torch.Tensor,
+    use_fp8_blockscale: bool = True,
+) -> torch.Tensor:
+    """Performs Fused MoE computation using CUTLASS-like kernels with FP8 weights and activations.
+
+    This function implements a Mixture of Experts (MoE) layer with a SwiGLU/SiLU
+    activation, leveraging custom kernels likely derived from CUTLASS principles
+    for grouped matrix multiplication (`fp8_blockwise_scaled_grouped_mm`) and
+    data preparation (`prepare_moe_input`, `silu_and_mul`).
+
+    It handles per-token routing, quantizes input activations to FP8 with
+    per-token scales, performs the expert computations using FP8 GEMMs with
+    pre-quantized FP8 weights (per-block scales), applies the SiLU activation,
+    and combines the results weighted by the router scores.
+
+    Args:
+        a (torch.Tensor): Input activations. Shape: `(m, k)`, where `m` is the total
+            number of tokens and `k` is the hidden size. Expected dtype: `torch.half`
+            or `torch.bfloat16`.
+        w1_q (torch.Tensor): Pre-quantized FP8 weight tensor for the first GEMM
+            (up-projection part of SwiGLU). Expected shape: `(E, k, n*2)`, where
+            `E` is the number of experts, `k` is the hidden size, and `n*2` is the
+            intermediate size (`I`). Expected dtype: `torch.float8_e4m3fn`.
+            Note: This shape implies weights are stored as (num_experts, hidden_size, intermediate_size).
+        w2_q (torch.Tensor): Pre-quantized FP8 weight tensor for the second GEMM
+            (down-projection). Expected shape: `(E, n, k)`, where `n` is half the
+            intermediate size (`I // 2`). Expected dtype: `torch.float8_e4m3fn`.
+            Note: This shape implies weights are stored as (num_experts, intermediate_size // 2, hidden_size).
+        w1_scale (torch.Tensor): Scales corresponding to `w1_q` (per-block scales).
+            Shape: `(E, num_blocks_n, num_blocks_k)`. Dtype: `torch.float32`.
+        w2_scale (torch.Tensor): Scales corresponding to `w2_q` (per-block scales).
+            Shape: `(E, num_blocks_k, num_blocks_n)`. Dtype: `torch.float32`.
+        topk_weights (torch.Tensor): Router weights for the selected top-k experts
+            for each token. Shape: `(m, topk)`. Dtype should ideally match `a`.
+        topk_ids (torch.Tensor): Indices of the selected top-k experts for each token.
+            Shape: `(m, topk)`. Dtype: `torch.int32`.
+        a1_strides (torch.Tensor): Stride information for the first GEMM's 'a' input.
+            Passed directly to the underlying kernel. Expected shape `(E,)`, dtype `torch.int64`.
+            Note: Its exact usage within `fp8_blockwise_scaled_grouped_mm` needs clarification
+            as it's passed as both a_stride and b_stride in the first call.
+        c1_strides (torch.Tensor): Stride information for the first GEMM's 'c' output.
+            Passed directly to the underlying kernel. Expected shape `(E,)`, dtype `torch.int64`.
+        a2_strides (torch.Tensor): Stride information for the second GEMM's 'a' input.
+            Passed directly to the underlying kernel. Expected shape `(E,)`, dtype `torch.int64`.
+            Note: Its exact usage within `fp8_blockwise_scaled_grouped_mm` needs clarification
+            as it's passed as both a_stride and b_stride in the second call.
+        c2_strides (torch.Tensor): Stride information for the second GEMM's 'c' output.
+            Passed directly to the underlying kernel. Expected shape `(E,)`, dtype `torch.int64`.
+        workspace (torch.Tensor): Reusable workspace for the underlying kernel.
+        a_ptrs (torch.Tensor): Pointers container for calculating offsets of the input activations for each expert.
+        b_ptrs (torch.Tensor): Pointers container for calculating offsets of the input weights for each expert.
+        out_ptrs (torch.Tensor): Pointers container for calculating offsets of the output activations for each expert.
+        a_scales_ptrs (torch.Tensor): Pointers container for calculating offsets of the input scales for each expert.
+        b_scales_ptrs (torch.Tensor): Pointers container for calculating offsets of the input scales for each expert.
+        use_fp8_blockscale (bool, optional): Flag indicating usage of FP8 with
+            block scaling. Currently, only `True` is supported. Defaults to `True`.
+
+    Returns:
+        torch.Tensor: The computed MoE layer output. Shape: `(m, k)`, dtype matches `a`.
+
+    Raises:
+        AssertionError: If input shapes, dtypes, or flags are inconsistent or unsupported.
+        NotImplementedError: If CUDA is not available or `sgl_kernel` is not properly installed.
+    """
+    assert use_fp8_blockscale, "Only support fp8 blockscale for now"
+    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
+    assert w1_q.dtype == torch.float8_e4m3fn
+    assert w2_q.dtype == torch.float8_e4m3fn
+    assert a.shape[1] == w1_q.shape[1], "Hidden size mismatch w1"
+    assert w1_q.shape[2] == w2_q.shape[1] * 2, "Hidden size mismatch w2"
+    assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch"
+    assert w1_q.shape[0] == w2_q.shape[0], "Weights expert number mismatch"
+    assert w1_q.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch"
+    assert w1_q.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch"
+    assert a.dtype in [torch.half, torch.bfloat16], "Invalid output dtype"
+
+    if is_cuda:
+        from sglang.srt.layers.quantization.fp8_kernel import (
+            sglang_per_token_group_quant_fp8,
+        )
+
+    out_dtype = a.dtype
+    num_experts = w1_q.size(0)
+    m = a.size(0)
+    k = w1_q.size(1)
+    n = w2_q.size(1)
+
+    topk = topk_ids.size(1)
+
+    a_q, a1_scale = sglang_per_token_group_quant_fp8(a, 128)
+    device = a_q.device
+
+    a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
+    c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
+
+    prepare_moe_input(
+        topk_ids,
+        expert_offsets,
+        problem_sizes1,
+        problem_sizes2,
+        a_map,
+        c_map,
+        num_experts,
+        n,
+        k,
+    )
+
+    rep_a_q = a_q.view(dtype=torch.uint8)[a_map].view(dtype=a_q.dtype)
+    rep_a1_scales = a1_scale[a_map]
+
+    c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype)
+    c2 = torch.empty((m * topk, k), device=device, dtype=out_dtype)
+
+    a_sf_layout = torch.empty((num_experts, 5), device=device, dtype=torch.int)
+    w_sf_layout = torch.empty((num_experts, 5), device=device, dtype=torch.int)
+
+    fp8_blockwise_scaled_grouped_mm(
+        c1,
+        a_ptrs,
+        b_ptrs,
+        out_ptrs,
+        a_scales_ptrs,
+        b_scales_ptrs,
+        rep_a_q,
+        w1_q,
+        rep_a1_scales,
+        w1_scale,
+        a1_strides,
+        a1_strides,
+        c1_strides,
+        a_sf_layout,
+        w_sf_layout,
+        problem_sizes1,
+        expert_offsets[:-1],
+        workspace,
+    )
+
+    intermediate = torch.empty((m * topk, n), device=device, dtype=out_dtype)
+    silu_and_mul(c1, intermediate)
+
+    intemediate_q, a2_scale = sglang_per_token_group_quant_fp8(intermediate, 128)
+
+    fp8_blockwise_scaled_grouped_mm(
+        c2,
+        a_ptrs,
+        b_ptrs,
+        out_ptrs,
+        a_scales_ptrs,
+        b_scales_ptrs,
+        intemediate_q,
+        w2_q,
+        a2_scale,
+        w2_scale,
+        a2_strides,
+        a2_strides,
+        c2_strides,
+        a_sf_layout,
+        w_sf_layout,
+        problem_sizes2,
+        expert_offsets[:-1],
+        workspace,
+    )
+    return (
+        c2[c_map].view(m, topk, k) * topk_weights.view(m, topk, 1).to(out_dtype)
+    ).sum(dim=1)
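
For orientation, the sketch below is a plain PyTorch reference for the math this new kernel fuses (top-k routing, SwiGLU activation via silu-and-mul, weighted combine), using the shapes described in the docstring. It is an illustrative, assumption-laden reference, not the FP8 grouped-GEMM path, and the helper name `moe_reference` is hypothetical.

import torch

def moe_reference(a, w1, w2, topk_weights, topk_ids):
    # a: (m, k); w1: (E, k, 2n); w2: (E, n, k); all in the activation dtype.
    out = torch.zeros_like(a)
    for e in range(w1.shape[0]):
        token_idx, slot_idx = (topk_ids == e).nonzero(as_tuple=True)
        if token_idx.numel() == 0:
            continue
        up = a[token_idx] @ w1[e]                 # (t, 2n): gate and up projections
        n = up.shape[-1] // 2
        h = torch.nn.functional.silu(up[:, :n]) * up[:, n:]   # silu_and_mul
        y = h @ w2[e]                             # (t, k): down projection
        w = topk_weights[token_idx, slot_idx, None].to(y.dtype)
        out.index_add_(0, token_idx, y * w)       # weighted combine per token
    return out

# Tiny smoke test with random routing (hypothetical sizes).
m, k, n, E, topk = 8, 16, 32, 4, 2
a = torch.randn(m, k)
w1 = torch.randn(E, k, 2 * n) * 0.02
w2 = torch.randn(E, n, k) * 0.02
topk_weights = torch.rand(m, topk)
topk_ids = torch.stack([torch.randperm(E)[:topk] for _ in range(m)])
print(moe_reference(a, w1, w2, topk_weights, topk_ids).shape)  # torch.Size([8, 16])
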
@@ -3,10 +3,9 @@ from typing import List, Optional
 
 import torch
 import triton
-import triton.language as tl
 
 from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
-from sglang.srt.utils import is_cuda
+from sglang.srt.utils import dispose_tensor, is_cuda
 
 logger = logging.getLogger(__name__)
 
@@ -116,7 +115,7 @@ def deepep_run_moe_deep_preprocess(topk_ids: torch.Tensor, num_experts: int):
     seg_indptr = torch.empty(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
     src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int64)
 
-    # Find offet
+    # Find offset
     expert_ids = torch.arange(
         num_experts + 1, device=topk_ids.device, dtype=reorder_topk_ids.dtype
     )
@@ -653,12 +652,15 @@ def grouped_gemm_triton(
     scale_a: torch.Tensor = None,
     scale_b: torch.Tensor = None,
     block_shape: Optional[List[int]] = None,
+    c_dtype=None,
 ):
     assert weight_column_major == True  # TODO: more
     if use_fp8_w8a8 and block_shape is None:
         assert scale_a is not None and scale_b is not None
 
     if block_shape is not None:
+        a_original = a
+
         assert len(block_shape) == 2
         block_n, block_k = block_shape[0], block_shape[1]
         a, scale_a = per_token_group_quant_fp8(a, block_k)
@@ -667,6 +669,8 @@
         assert triton.cdiv(b.shape[-2], block_n) == scale_b.shape[-2]
         assert triton.cdiv(b.shape[-1], block_k) == scale_b.shape[-1]
 
+        dispose_tensor(a_original)
+
     # TODO: adjust config or tune kernel
     # Reduce block size to prevent L40 shared memory overflow.
     config = {
@@ -680,6 +684,10 @@
         m_num_tiles_indptr, seg_indptr, batch_size, config["BLOCK_SIZE_M"]
     )
 
+    if c is None:
+        assert c_dtype is not None
+        c = torch.empty(a.shape[0], b.shape[1], device=a.device, dtype=c_dtype)
+
     grid = lambda META: (
         triton.cdiv(a.size(0), META["BLOCK_SIZE_M"]) + batch_size,
         triton.cdiv(b.size(1), META["BLOCK_SIZE_N"]),
@@ -783,19 +791,23 @@ def _fwd_kernel_ep_scatter_2(
     offset_in_s = tl.arange(0, SCALE_HIDDEN_SIZE_PAD)
     mask_s = offset_in_s < SCALE_HIDDEN_SIZE
 
-    for token_id in range(start_token_id, total_token_num, grid_num):
+    for token_id_int32 in range(start_token_id, total_token_num, grid_num):
+        token_id = token_id_int32.to(tl.int64)
         to_copy = tl.load(recv_x + token_id * recv_x_stride0 + offset_in, mask=mask)
         to_copy_s = tl.load(
            recv_x_scale + token_id * recv_x_scale_stride0 + offset_in_s, mask=mask_s
        )
 
-        for topk_index in tl.range(0, topk_num, 1, num_stages=4):
+        for topk_idx_int32 in tl.range(0, topk_num, 1, num_stages=4):
+            topk_index = topk_idx_int32.to(tl.int64)
             expert_id = tl.load(recv_topk + token_id * recv_topk_stride0 + topk_index)
             if expert_id >= 0:
-                dest_token_index = tl.atomic_add(expert_start_loc + expert_id, 1)
+                dest_token_index_int32 = tl.atomic_add(expert_start_loc + expert_id, 1)
+                dest_token_index = dest_token_index_int32.to(tl.int64)
+
                 tl.store(
                     output_index + token_id * output_index_stride0 + topk_index,
-                    dest_token_index,
+                    dest_token_index_int32,
                 )
                 output_tensor_ptr = (
                     output_tensor + dest_token_index * output_tensor_stride0
@@ -894,21 +906,31 @@ def _fwd_kernel_ep_gather(
     topk_num: tl.constexpr,
     BLOCK_D: tl.constexpr,
 ):
-    cur_block = tl.program_id(0)
-    start_cur_token = tl.program_id(1)
+    cur_block_int32 = tl.program_id(0)
+    cur_block = cur_block_int32.to(tl.int64)
+
+    start_cur_token_int32 = tl.program_id(1)
+
     grid_num = tl.num_programs(1)
 
-    for cur_token in range(start_cur_token, total_token_num, grid_num):
+    for cur_token_int32 in range(start_cur_token_int32, total_token_num, grid_num):
+        cur_token = cur_token_int32.to(tl.int64)
+
         off_d = tl.arange(0, BLOCK_D)
         accumulator = tl.zeros([BLOCK_D], dtype=tl.float32)
-        for topk_index in range(0, topk_num):
+
+        for topk_index_int32 in range(0, topk_num):
+            topk_index = topk_index_int32.to(tl.int64)
+
             expert_id = tl.load(
                 recv_topk_ids + cur_token * recv_topk_ids_stride0 + topk_index
             )
             if expert_id >= 0:
-                source_token_index = tl.load(
+                source_token_index_int32 = tl.load(
                     input_index + cur_token * input_index_stride0 + topk_index
                 )
+                source_token_index = source_token_index_int32.to(tl.int64)
+
                 acc_weight = tl.load(
                     recv_topk_weight + cur_token * recv_topk_weight_stride0 + topk_index
                 )
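
The int32-to-int64 promotions in `_fwd_kernel_ep_scatter_2` and `_fwd_kernel_ep_gather` guard the pointer arithmetic: flat offsets such as `token_id * recv_x_stride0 + offset_in` can exceed the int32 range on large dispatch buffers. A back-of-the-envelope check, with hypothetical sizes not taken from the diff:

# Illustrative only: why the kernel indices are widened to tl.int64 before
# being multiplied by strides.
hidden_size = 7168            # elements per token row (hypothetical)
num_tokens = 400_000          # tokens landing in one scatter buffer (hypothetical)
max_offset = num_tokens * hidden_size
print(max_offset)             # 2867200000
print(max_offset > 2**31 - 1) # True -> a 32-bit offset would wrap around
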