sglang 0.4.7__py3-none-any.whl → 0.4.8__py3-none-any.whl

Files changed (152)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +7 -0
  3. sglang/bench_one_batch.py +8 -6
  4. sglang/bench_serving.py +1 -1
  5. sglang/lang/interpreter.py +40 -1
  6. sglang/lang/ir.py +27 -0
  7. sglang/math_utils.py +8 -0
  8. sglang/srt/_custom_ops.py +2 -2
  9. sglang/srt/code_completion_parser.py +2 -44
  10. sglang/srt/configs/model_config.py +6 -0
  11. sglang/srt/constants.py +3 -0
  12. sglang/srt/conversation.py +19 -3
  13. sglang/srt/custom_op.py +5 -1
  14. sglang/srt/disaggregation/base/__init__.py +1 -1
  15. sglang/srt/disaggregation/base/conn.py +25 -11
  16. sglang/srt/disaggregation/common/__init__.py +5 -1
  17. sglang/srt/disaggregation/common/utils.py +42 -0
  18. sglang/srt/disaggregation/decode.py +211 -72
  19. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
  20. sglang/srt/disaggregation/fake/__init__.py +1 -1
  21. sglang/srt/disaggregation/fake/conn.py +15 -9
  22. sglang/srt/disaggregation/mini_lb.py +34 -4
  23. sglang/srt/disaggregation/mooncake/__init__.py +1 -1
  24. sglang/srt/disaggregation/mooncake/conn.py +30 -29
  25. sglang/srt/disaggregation/nixl/__init__.py +6 -1
  26. sglang/srt/disaggregation/nixl/conn.py +17 -12
  27. sglang/srt/disaggregation/prefill.py +144 -55
  28. sglang/srt/disaggregation/utils.py +155 -123
  29. sglang/srt/distributed/parallel_state.py +12 -4
  30. sglang/srt/entrypoints/engine.py +37 -29
  31. sglang/srt/entrypoints/http_server.py +153 -72
  32. sglang/srt/entrypoints/http_server_engine.py +0 -3
  33. sglang/srt/entrypoints/openai/__init__.py +0 -0
  34. sglang/srt/{openai_api → entrypoints/openai}/protocol.py +84 -10
  35. sglang/srt/entrypoints/openai/serving_base.py +149 -0
  36. sglang/srt/entrypoints/openai/serving_chat.py +921 -0
  37. sglang/srt/entrypoints/openai/serving_completions.py +424 -0
  38. sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
  39. sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
  40. sglang/srt/entrypoints/openai/serving_score.py +61 -0
  41. sglang/srt/entrypoints/openai/usage_processor.py +81 -0
  42. sglang/srt/entrypoints/openai/utils.py +72 -0
  43. sglang/srt/eplb_simulator/__init__.py +1 -0
  44. sglang/srt/eplb_simulator/reader.py +51 -0
  45. sglang/srt/function_call/base_format_detector.py +7 -4
  46. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  47. sglang/srt/function_call/ebnf_composer.py +64 -10
  48. sglang/srt/function_call/function_call_parser.py +6 -6
  49. sglang/srt/function_call/llama32_detector.py +1 -1
  50. sglang/srt/function_call/mistral_detector.py +1 -1
  51. sglang/srt/function_call/pythonic_detector.py +1 -1
  52. sglang/srt/function_call/qwen25_detector.py +1 -1
  53. sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
  54. sglang/srt/layers/activation.py +40 -3
  55. sglang/srt/layers/attention/aiter_backend.py +20 -4
  56. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  57. sglang/srt/layers/attention/cutlass_mla_backend.py +39 -15
  58. sglang/srt/layers/attention/flashattention_backend.py +71 -72
  59. sglang/srt/layers/attention/flashinfer_backend.py +10 -8
  60. sglang/srt/layers/attention/flashinfer_mla_backend.py +29 -28
  61. sglang/srt/layers/attention/flashmla_backend.py +7 -12
  62. sglang/srt/layers/attention/tbo_backend.py +3 -3
  63. sglang/srt/layers/attention/triton_backend.py +138 -130
  64. sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
  65. sglang/srt/layers/attention/vision.py +51 -24
  66. sglang/srt/layers/communicator.py +28 -10
  67. sglang/srt/layers/dp_attention.py +11 -2
  68. sglang/srt/layers/layernorm.py +29 -2
  69. sglang/srt/layers/linear.py +0 -4
  70. sglang/srt/layers/logits_processor.py +2 -14
  71. sglang/srt/layers/moe/ep_moe/kernels.py +165 -7
  72. sglang/srt/layers/moe/ep_moe/layer.py +249 -33
  73. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
  74. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  75. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -4
  76. sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
  77. sglang/srt/layers/moe/topk.py +107 -12
  78. sglang/srt/layers/pooler.py +56 -0
  79. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
  80. sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
  81. sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
  82. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
  83. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
  84. sglang/srt/layers/quantization/fp8.py +25 -17
  85. sglang/srt/layers/quantization/fp8_kernel.py +44 -15
  86. sglang/srt/layers/quantization/fp8_utils.py +87 -22
  87. sglang/srt/layers/quantization/modelopt_quant.py +62 -8
  88. sglang/srt/layers/quantization/utils.py +5 -2
  89. sglang/srt/layers/radix_attention.py +2 -3
  90. sglang/srt/layers/rotary_embedding.py +42 -2
  91. sglang/srt/layers/sampler.py +1 -1
  92. sglang/srt/lora/lora_manager.py +249 -105
  93. sglang/srt/lora/mem_pool.py +53 -50
  94. sglang/srt/lora/utils.py +1 -1
  95. sglang/srt/managers/cache_controller.py +33 -14
  96. sglang/srt/managers/io_struct.py +31 -10
  97. sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
  98. sglang/srt/managers/multimodal_processors/vila.py +85 -0
  99. sglang/srt/managers/schedule_batch.py +79 -37
  100. sglang/srt/managers/schedule_policy.py +70 -56
  101. sglang/srt/managers/scheduler.py +220 -79
  102. sglang/srt/managers/template_manager.py +226 -0
  103. sglang/srt/managers/tokenizer_manager.py +40 -10
  104. sglang/srt/managers/tp_worker.py +12 -2
  105. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  106. sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
  107. sglang/srt/mem_cache/base_prefix_cache.py +52 -8
  108. sglang/srt/mem_cache/chunk_cache.py +11 -15
  109. sglang/srt/mem_cache/hiradix_cache.py +38 -25
  110. sglang/srt/mem_cache/memory_pool.py +213 -505
  111. sglang/srt/mem_cache/memory_pool_host.py +380 -0
  112. sglang/srt/mem_cache/radix_cache.py +56 -28
  113. sglang/srt/model_executor/cuda_graph_runner.py +198 -100
  114. sglang/srt/model_executor/forward_batch_info.py +32 -10
  115. sglang/srt/model_executor/model_runner.py +28 -12
  116. sglang/srt/model_loader/loader.py +16 -2
  117. sglang/srt/model_loader/weight_utils.py +11 -2
  118. sglang/srt/models/bert.py +113 -13
  119. sglang/srt/models/deepseek_nextn.py +29 -27
  120. sglang/srt/models/deepseek_v2.py +213 -173
  121. sglang/srt/models/glm4.py +312 -0
  122. sglang/srt/models/internvl.py +46 -102
  123. sglang/srt/models/mimo_mtp.py +2 -18
  124. sglang/srt/models/roberta.py +117 -9
  125. sglang/srt/models/vila.py +305 -0
  126. sglang/srt/reasoning_parser.py +21 -11
  127. sglang/srt/sampling/sampling_batch_info.py +24 -0
  128. sglang/srt/sampling/sampling_params.py +2 -0
  129. sglang/srt/server_args.py +351 -238
  130. sglang/srt/speculative/build_eagle_tree.py +1 -1
  131. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -9
  132. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +130 -14
  133. sglang/srt/speculative/eagle_utils.py +468 -116
  134. sglang/srt/speculative/eagle_worker.py +258 -84
  135. sglang/srt/torch_memory_saver_adapter.py +19 -15
  136. sglang/srt/two_batch_overlap.py +4 -2
  137. sglang/srt/utils.py +235 -11
  138. sglang/test/attention/test_prefix_chunk_info.py +2 -0
  139. sglang/test/runners.py +38 -3
  140. sglang/test/test_block_fp8.py +1 -0
  141. sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
  142. sglang/test/test_block_fp8_ep.py +2 -0
  143. sglang/test/test_utils.py +4 -1
  144. sglang/utils.py +9 -0
  145. sglang/version.py +1 -1
  146. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/METADATA +8 -14
  147. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/RECORD +150 -128
  148. sglang/srt/entrypoints/verl_engine.py +0 -179
  149. sglang/srt/openai_api/adapter.py +0 -1990
  150. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
  151. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
  152. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
sglang/srt/layers/communicator.py
@@ -28,9 +28,9 @@ from sglang.srt.layers.dp_attention import (
     attn_tp_reduce_scatter,
     dp_gather_partial,
     dp_scatter,
+    get_attention_dp_size,
     get_attention_tp_rank,
     get_attention_tp_size,
-    get_local_attention_dp_size,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -226,31 +226,32 @@ class CommunicateContext:
 
 @dataclass
 class CommunicateContext:
-    process_group_sizes: Dict["ScatterMode", int]
+    process_group_sizes: Dict[ScatterMode, int]
     attn_tp_rank: int
     attn_tp_size: int
-    local_attn_dp_size: int
+    attn_dp_size: int
     tp_size: int
 
-    def is_same_group_size(self, a: "ScatterMode", b: "ScatterMode"):
+    def is_same_group_size(self, a: ScatterMode, b: ScatterMode):
         return self.process_group_sizes[a] == self.process_group_sizes[b]
 
     @classmethod
     def init_new(cls):
         attn_tp_rank = get_attention_tp_rank()
         attn_tp_size = get_attention_tp_size()
-        local_attn_dp_size = get_local_attention_dp_size()
+        attn_dp_size = get_attention_dp_size()
         tp_size = get_tensor_model_parallel_world_size()
         process_group_sizes = {
             ScatterMode.SCATTERED: 1,
             ScatterMode.TP_ATTN_FULL: attn_tp_size,
+            # TODO: support --moe-dense-tp-size > 1
             ScatterMode.FULL: tp_size,
         }
         return cls(
             process_group_sizes=process_group_sizes,
             attn_tp_rank=attn_tp_rank,
             attn_tp_size=attn_tp_size,
-            local_attn_dp_size=local_attn_dp_size,
+            attn_dp_size=attn_dp_size,
             tp_size=tp_size,
         )
 
@@ -323,11 +324,16 @@ class CommunicateWithAllReduceAndLayerNormFn:
 
         if (
             (hidden_states_input_mode == ScatterMode.TP_ATTN_FULL)
-            and (residual_input_mode == ScatterMode.TP_ATTN_FULL)
+            and (
+                residual_input_mode in [ScatterMode.SCATTERED, ScatterMode.TP_ATTN_FULL]
+            )
             and (hidden_states_output_mode == ScatterMode.FULL)
             and (residual_output_mode == ScatterMode.TP_ATTN_FULL)
         ):
-            return CommunicateWithAllReduceAndLayerNormFn._gather_hidden_states
+            return partial(
+                CommunicateWithAllReduceAndLayerNormFn._gather_hidden_states_and_residual,
+                residual_input_mode=residual_input_mode,
+            )
 
         if (
             (hidden_states_input_mode == ScatterMode.TP_ATTN_FULL)
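The dispatcher now returns a `functools.partial` instead of a bare static method, so the selected handler keeps the uniform `(hidden_states, residual, forward_batch, layernorm, context)` call signature while `residual_input_mode` rides along as a bound keyword argument. The pattern in isolation (a minimal sketch; the names here are illustrative, not sglang's API):

    from functools import partial

    def handler(hidden_states, residual, *, residual_input_mode):
        return f"{hidden_states}/{residual} handled with {residual_input_mode}"

    bound = partial(handler, residual_input_mode="SCATTERED")
    print(bound("h", "r"))  # positional args still flow through unchanged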
@@ -360,14 +366,26 @@ class CommunicateWithAllReduceAndLayerNormFn:
         return hidden_states, residual
 
     @staticmethod
-    def _gather_hidden_states(
+    def _gather_hidden_states_and_residual(
         hidden_states: torch.Tensor,
         residual: torch.Tensor,
         forward_batch: ForwardBatch,
         layernorm: torch.nn.Module,
         context: CommunicateContext,
+        *,
+        residual_input_mode,
     ):
-        if context.local_attn_dp_size != 1:
+        if residual_input_mode == ScatterMode.SCATTERED and context.attn_tp_size > 1:
+            residual, local_residual = (
+                forward_batch.gathered_buffer[
+                    : forward_batch.input_ids.shape[0]
+                ].clone(),
+                residual,
+            )
+            attn_tp_all_gather(
+                list(residual.tensor_split(context.attn_tp_size)), local_residual
+            )
+        if context.attn_dp_size != 1:
             if context.attn_tp_rank == 0:
                 hidden_states += residual
             hidden_states, local_hidden_states = (
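The new residual branch reuses the buffer-splitting idiom of the hidden-states path: a preallocated buffer is carved into one view per attention-TP rank via `tensor_split`, and the all-gather writes each rank's shard through its view. A single-process sketch of that idiom, with `fake_all_gather` standing in for the real collective (shapes are illustrative):

    import torch

    def fake_all_gather(output_list, inputs_by_rank):
        # Stand-in for the collective: each rank's shard lands in its view.
        for view, shard in zip(output_list, inputs_by_rank):
            view.copy_(shard)

    attn_tp_size, rows_per_rank, hidden = 4, 2, 8
    gathered = torch.empty(attn_tp_size * rows_per_rank, hidden)
    shards = [torch.full((rows_per_rank, hidden), float(r)) for r in range(attn_tp_size)]
    # tensor_split returns views into `gathered`, so the copies fill it in place.
    fake_all_gather(list(gathered.tensor_split(attn_tp_size)), shards)
    assert gathered[2:4].eq(1.0).all()  # rank 1's rows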
sglang/srt/layers/dp_attention.py
@@ -165,7 +165,8 @@ def disable_dp_size():
 
 
 def get_dp_local_info(forward_batch: ForwardBatch):
-    dp_rank = get_local_attention_dp_rank()
+    # `get_dp_local_info` is only called in global DP gather and scatter. We use global DP rank here.
+    dp_rank = get_attention_dp_rank()
 
     if forward_batch.dp_local_start_pos is None:
         cumtokens = torch.cumsum(forward_batch.global_num_tokens_gpu, dim=0)
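For context, `get_dp_local_info` turns per-rank token counts into each rank's slice of the globally gathered tensor via a cumulative sum, which is why it must index with the global DP rank rather than a node-local one. A toy illustration with made-up counts:

    import torch

    global_num_tokens = torch.tensor([5, 3, 7])  # tokens per DP rank
    cumtokens = torch.cumsum(global_num_tokens, dim=0)
    for dp_rank in range(len(global_num_tokens)):
        start = 0 if dp_rank == 0 else int(cumtokens[dp_rank - 1])
        end = start + int(global_num_tokens[dp_rank])
        print(f"rank {dp_rank}: rows [{start}, {end})")
    # rank 0: rows [0, 5)   rank 1: rows [5, 8)   rank 2: rows [8, 15)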
@@ -238,6 +239,10 @@ def _dp_gather(
     assert (
         local_tokens.untyped_storage() is not global_tokens.untyped_storage()
     ), "aliasing between global_tokens and local_tokens not allowed"
+    if forward_batch.forward_mode.is_draft_extend():
+        shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
+        local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)
+
     memcpy_triton(
         global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False
     )
@@ -288,6 +293,10 @@ def dp_scatter(
     assert (
         local_tokens.untyped_storage() is not global_tokens.untyped_storage()
     ), "aliasing between local_tokens and global_tokens not allowed"
+    if forward_batch.forward_mode.is_draft_extend():
+        shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
+        local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)
+
     memcpy_triton(
         local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True
     )
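Both `_dp_gather` and `dp_scatter` now clamp the on-GPU token count to the tensor's actual length in draft-extend mode, where the recorded count can exceed the truncated buffer; `new_full((), n)` builds the scalar on the right device and dtype without a host sync. The idiom in isolation (sizes are made up):

    import torch

    local_tokens = torch.randn(48, 16)   # actual (truncated) buffer
    local_num_tokens = torch.tensor(64)  # recorded count may exceed it
    shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
    local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)
    print(local_num_tokens.item())       # 48 -- the copy length is clamped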
@@ -301,4 +310,4 @@ attn_tp_reduce_scatter(
 
 
 def attn_tp_all_gather(output_list: List[torch.Tensor], input_: torch.Tensor):
-    return get_attention_tp_group().all_gather(input_, tensor_list=output_list)
+    return get_attention_tp_group().all_gather(input_, output_tensor_list=output_list)
sglang/srt/layers/layernorm.py
@@ -20,11 +20,21 @@ import torch
 import torch.nn as nn
 
 from sglang.srt.custom_op import CustomOp
-from sglang.srt.utils import get_bool_env_var, is_cuda, is_hip
+from sglang.srt.utils import (
+    cpu_has_amx_support,
+    get_bool_env_var,
+    is_cpu,
+    is_cuda,
+    is_hip,
+    is_npu,
+)
 
 _is_cuda = is_cuda()
 _is_hip = is_hip()
+_is_npu = is_npu()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
 
 if _is_cuda:
     from sgl_kernel import (
@@ -121,6 +131,23 @@ class RMSNorm(CustomOp):
         else:
             return x, residual
 
+    def forward_cpu(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if _is_cpu_amx_available:
+            if residual is not None:
+                torch.ops.sgl_kernel.fused_add_rmsnorm_cpu(
+                    x, residual, self.weight.data, self.variance_epsilon
+                )
+                return x, residual
+            return torch.ops.sgl_kernel.rmsnorm_cpu(
+                x, self.weight.data, self.variance_epsilon
+            )
+        else:
+            return self.forward_native(x, residual)
+
 
 class GemmaRMSNorm(CustomOp):
     def __init__(
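The new `forward_cpu` dispatches to fused sgl-kernel CPU ops when AMX is available and otherwise falls back to `forward_native`. As a reminder of the semantics both paths must agree on, here is a minimal reference for the fused add+RMSNorm case (a sketch of the math, not the kernel):

    import torch

    def fused_add_rmsnorm_ref(x, residual, weight, eps):
        # The residual output is x + residual; x becomes its RMSNorm.
        residual = x + residual
        var = residual.float().pow(2).mean(-1, keepdim=True)
        x = (residual.float() * torch.rsqrt(var + eps)).to(x.dtype) * weight
        return x, residual

    x, res = torch.randn(2, 8), torch.randn(2, 8)
    out, new_res = fused_add_rmsnorm_ref(x, res, torch.ones(8), eps=1e-6)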
@@ -187,7 +214,7 @@ class Gemma3RMSNorm(nn.Module):
         return f"{tuple(self.weight.shape)}, eps={self.eps}"
 
 
-if not (_is_cuda or _is_hip):
+if not (_is_cuda or _is_hip or _is_npu or (_is_cpu and _is_cpu_amx_available)):
     logger.info(
         "sgl-kernel layernorm implementation is not available on current platform. Fallback to other kernel libraries."
     )
sglang/srt/layers/linear.py
@@ -546,8 +546,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             param.shard_id.append(loaded_shard_id)
             param.shard_id_map[loaded_shard_id] = len(param.data_container)
             param.data_container.append(loaded_weight)
-            if len(param.data_container) == 2:
-                self.qweight = param.materialize_nested()
             return
 
 
@@ -961,8 +959,6 @@ class QKVParallelLinear(ColumnParallelLinear):
             param.shard_id.append(loaded_shard_id)
             param.shard_id_map[loaded_shard_id] = len(param.data_container)
             param.data_container.append(loaded_weight)
-            if len(param.data_container) == 3:
-                self.qweight = param.materialize_nested()
             return
 
 
sglang/srt/layers/logits_processor.py
@@ -30,9 +30,9 @@ from sglang.srt.layers.dp_attention import (
     attn_tp_all_gather,
     dp_gather_replicate,
     dp_scatter,
+    get_attention_dp_rank,
     get_attention_dp_size,
     get_attention_tp_size,
-    get_local_attention_dp_rank,
     get_local_attention_dp_size,
 )
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
@@ -47,18 +47,6 @@ from sglang.srt.utils import dump_to_file
 logger = logging.getLogger(__name__)
 
 
-from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
-from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import (
-    CaptureHiddenMode,
-    ForwardBatch,
-    ForwardMode,
-)
-from sglang.srt.utils import dump_to_file
-
-logger = logging.getLogger(__name__)
-
-
 @dataclasses.dataclass
 class LogitsProcessorOutput:
     ## Part 1: This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
@@ -183,7 +171,7 @@ class LogitsMetadata:
             return
 
         cumtokens = torch.cumsum(self.global_num_tokens_for_logprob_gpu, dim=0)
-        dp_rank = get_local_attention_dp_rank()
+        dp_rank = get_attention_dp_rank()
        if dp_rank == 0:
            dp_local_start_pos = torch.zeros_like(
                self.global_num_tokens_for_logprob_gpu[0]
sglang/srt/layers/moe/ep_moe/kernels.py
@@ -4,6 +4,7 @@ from typing import List, Optional
 import torch
 import triton
 
+from sglang.math_utils import ceil_div
 from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
 from sglang.srt.utils import dispose_tensor, is_cuda
 
@@ -15,11 +16,6 @@ if _is_cuda:
         sglang_per_token_group_quant_fp8 as per_token_group_quant_fp8,
     )
 
-    try:
-        from deep_gemm import ceil_div
-    except ImportError:
-        logger.error(f"Failed to import ceil_div from deep_gemm.")
-
 import triton.language as tl
 
 
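With this change `ceil_div` comes from the new first-party `sglang/math_utils.py` (see the file list above) rather than an optional `deep_gemm` import, so these kernels no longer need DeepGEMM installed for plain integer math. The helper is presumably the standard idiom:

    def ceil_div(a: int, b: int) -> int:
        # Smallest integer >= a / b, for positive b.
        return (a + b - 1) // b

    assert ceil_div(10, 4) == 3 and ceil_div(8, 4) == 2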
@@ -278,6 +274,7 @@ def _silu_and_mul_post_quant_kernel(
     fp8_min,
     BLOCK_N: tl.constexpr,
     NUM_STAGE: tl.constexpr,
+    SCALE_UE8M0: tl.constexpr,
 ):
     expert_id = tl.program_id(2)
     token_id = tl.program_id(1)
@@ -319,6 +316,8 @@ def _silu_and_mul_post_quant_kernel(
         gate_up = up * gate
         _absmax = tl.maximum(tl.max(tl.abs(gate_up)), 1e-10)
         output_s = _absmax / fp8_max
+        if SCALE_UE8M0:
+            output_s = tl.exp2(tl.ceil(tl.log2(tl.abs(output_s))))
         output_q = tl.clamp(gate_up / output_s, fp8_min, fp8_max).to(
             output_ptr.dtype.element_ty
         )
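The `SCALE_UE8M0` branch rounds each group's scale up to the next power of two, so the scale can be stored as a bare 8-bit exponent (UE8M0); rounding up rather than to nearest keeps the rescaled values inside the fp8 range. The same rounding in plain PyTorch (a sketch of the formula, not the kernel):

    import torch

    def round_scale_ue8m0(scale: torch.Tensor) -> torch.Tensor:
        # 2^ceil(log2(|s|)): the smallest power of two >= |s|.
        return torch.exp2(torch.ceil(torch.log2(scale.abs())))

    s = torch.tensor([0.3, 1.0, 5.0])
    print(round_scale_ue8m0(s))  # tensor([0.5000, 1.0000, 8.0000])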
@@ -339,6 +338,7 @@ silu_and_mul_masked_post_quant_fwd(
     output_scale: torch.Tensor,
     quant_group_size: int,
     masked_m: torch.Tensor,
+    scale_ue8m0: bool = False,
 ):
     """
     input shape [expert_num, token_num_padded, hidden_dim]
@@ -395,6 +395,7 @@ silu_and_mul_masked_post_quant_fwd(
         BLOCK_N=BLOCK_N,
         NUM_STAGE=NUM_STAGES,
         num_warps=num_warps,
+        SCALE_UE8M0=scale_ue8m0,
     )
     return
 
@@ -477,11 +478,13 @@ post_reorder_triton_kernel(
     end_expert_id,
     topk,
     hidden_size,
+    dst_start,
     BLOCK_SIZE: tl.constexpr,
 ):
     InDtype = down_output_ptr.dtype.element_ty
 
-    src_idx = tl.program_id(0)
+    src_idx_int32 = tl.program_id(0)
+    src_idx = src_idx_int32.to(tl.int64)
     src2dst_ptr = src2dst_ptr + src_idx * topk
     topk_ids_ptr = topk_ids_ptr + src_idx * topk
     topk_weights_ptr = topk_weights_ptr + src_idx * topk
@@ -500,7 +503,9 @@ post_reorder_triton_kernel(
         expert_id = tl.load(topk_ids_ptr + idx)
         if expert_id >= start_expert_id and expert_id <= end_expert_id:
             computed = True
-            dst_idx = tl.load(src2dst_ptr + idx)
+            dst_idx_int32 = tl.load(src2dst_ptr + idx)
+            dst_idx = dst_idx_int32.to(tl.int64)
+            dst_idx = dst_idx - dst_start
             weigh_scale = tl.load(topk_weights_ptr + idx).to(InDtype)
             load_ptr = down_output_ptr + dst_idx * hidden_size
             in_data = tl.load(load_ptr + offset, mask=mask)
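Promoting `src_idx` and `dst_idx` to `tl.int64` before the pointer arithmetic matters because products like `dst_idx * hidden_size` can exceed INT32_MAX on large batches, silently wrapping and corrupting addresses. A quick illustration of the wraparound the cast avoids (the numbers are made up but plausible):

    import numpy as np

    hidden_size = 7168                # a large model's hidden dimension
    dst_idx = 400_000                 # destination row at a high batch size
    off32 = np.multiply(dst_idx, hidden_size, dtype=np.int32)  # wraps
    off64 = np.multiply(dst_idx, hidden_size, dtype=np.int64)  # correct
    print(int(off32), int(off64))     # -1427767296 vs 2867200000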
@@ -1085,3 +1090,156 @@ def tma_align_input_scale(input_scale: torch.Tensor):
         BLOCK_SIZE_K=BLOCK_SIZE_K,
     )
     return output.t()[:m]
+
+
+@triton.jit
+def compute_masked_m_triton_kernel(seg_indptr, masked_m):
+    expert_id = tl.program_id(0)
+    start = tl.load(seg_indptr + expert_id)
+    end = tl.load(seg_indptr + expert_id + 1)
+    tl.store(masked_m + expert_id, (end - start))
+
+
+@triton.jit
+def deepgemm_compute_src2dst_triton_kernel(
+    topk_ids,
+    reorder_ids,
+    seg_indptr,
+    src2dst,
+    m_max,
+    num_toks,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+    dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    mask = dst_id < num_toks
+    src_id = tl.load(reorder_ids + dst_id, mask=mask)
+    expert_id = tl.load(topk_ids + src_id, mask=(src_id < num_toks))
+    expert_dst_start = tl.load(seg_indptr + expert_id)
+    expert_dst_offset = dst_id - expert_dst_start
+    dst_id = expert_id * m_max + expert_dst_offset
+    tl.store(src2dst + src_id, dst_id, mask=mask)
+
+
+@triton.jit
+def fill_gateup_input_triton_kernel(
+    input_ptr,
+    scale_ptr,
+    gateup_input_ptr,
+    gateup_input_scale_ptr,
+    src2dst_ptr,
+    topk_ids_ptr,
+    start_expert_id,
+    end_expert_id,
+    topk,
+    m_max,
+    hidden_size,
+    scale_size,
+    BLOCK_SIZE: tl.constexpr,
+):
+
+    src_idx_int32 = tl.program_id(0)
+    src_idx = src_idx_int32.to(tl.int64)
+    src2dst_ptr = src2dst_ptr + src_idx * topk
+    topk_ids_ptr = topk_ids_ptr + src_idx * topk
+    src_ptr = input_ptr + src_idx * hidden_size
+    scale_src_ptr = scale_ptr + src_idx * scale_size
+
+    vec = tl.arange(0, BLOCK_SIZE)
+    for idx in range(topk):
+        expert_id = tl.load(topk_ids_ptr + idx)
+        if expert_id >= start_expert_id and expert_id <= end_expert_id:
+            dst_idx_int32 = tl.load(src2dst_ptr + idx)
+            dst_idx = dst_idx_int32.to(tl.int64)
+            dst_idx = dst_idx - start_expert_id * m_max
+            dst_ptr = gateup_input_ptr + dst_idx * hidden_size
+            for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
+                offset = start_offset + vec
+                mask = offset < hidden_size
+                in_data = tl.load(src_ptr + offset, mask=mask)
+                tl.store(dst_ptr + offset, in_data, mask=mask)
+            scale_dst_ptr = gateup_input_scale_ptr + dst_idx * scale_size
+            for start_offset in tl.range(0, scale_size, BLOCK_SIZE):
+                offset = start_offset + vec
+                mask = offset < scale_size
+                in_scale = tl.load(scale_src_ptr + offset, mask=mask)
+                tl.store(scale_dst_ptr + offset, in_scale, mask=mask)
+
+
+def moe_ep_deepgemm_preprocess(
+    topk_ids: torch.Tensor,
+    num_experts: int,
+    hidden_states: torch.Tensor,
+    top_k: int,
+    start_expert_id,
+    end_expert_id,
+    block_shape,
+    output_dtype: torch.dtype = torch.float8_e4m3fn,
+):
+    reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
+    seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
+    src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32)
+    masked_m = torch.zeros(num_experts, device=topk_ids.device, dtype=torch.int32)
+
+    compute_seg_indptr_triton_kernel[(num_experts,)](
+        reorder_topk_ids, seg_indptr, topk_ids.numel()
+    )
+
+    grid = lambda meta: (triton.cdiv(topk_ids.numel(), meta["BLOCK_SIZE"]),)
+    compute_masked_m_triton_kernel[(num_experts,)](seg_indptr, masked_m)
+
+    # For masked grouped GEMM, shape M should be multiple of the block M (current block M: {block_m}) https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/jit_kernels/m_grouped_gemm.py#L165
+    m_max = (hidden_states.size(0) + 255) // 256 * 256
+    expected_m = (topk_ids.numel() + num_experts - 1) // num_experts
+    gateup_input = torch.empty(
+        (int(end_expert_id - start_expert_id + 1), m_max, hidden_states.size(1)),
+        device=hidden_states.device,
+        dtype=output_dtype,
+    )
+
+    deepgemm_compute_src2dst_triton_kernel[grid](
+        topk_ids,
+        reorder_ids,
+        seg_indptr,
+        src2dst,
+        m_max,
+        topk_ids.numel(),
+        BLOCK_SIZE=256,
+    )
+
+    if block_shape is None:
+        block_shape = [128, 128]
+    assert len(block_shape) == 2
+    block_n, block_k = block_shape[0], block_shape[1]
+    hidden_states, scale = per_token_group_quant_fp8(hidden_states, block_k)
+
+    gateup_input_scale = torch.empty(
+        (gateup_input.size(0), gateup_input.size(1), scale.size(1)),
+        device=hidden_states.device,
+        dtype=scale.dtype,
+    )
+
+    fill_gateup_input_triton_kernel[(hidden_states.shape[0],)](
+        hidden_states,
+        scale,
+        gateup_input,
+        gateup_input_scale,
+        src2dst,
+        topk_ids,
+        start_expert_id,
+        end_expert_id,
+        top_k,
+        m_max,
+        hidden_states.size(1),
+        scale.size(1),
+        BLOCK_SIZE=1024,
+    )
+
+    return (
+        m_max,
+        masked_m[start_expert_id : (end_expert_id + 1)],
+        expected_m,
+        src2dst,
+        gateup_input,
+        gateup_input_scale,
+    )
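Taken together, `moe_ep_deepgemm_preprocess` lays tokens out as one fixed-height slab of `m_max` rows per local expert, so DeepGEMM's masked grouped GEMM can treat the batch as a dense `[num_local_experts, m_max, hidden]` tensor while `masked_m` records how many rows of each slab are real (the rest is padding, sized to the block-M requirement noted in the inline comment). A small CPU sketch of the same layout computation, with Python loops in place of the Triton kernels (illustrative only):

    import torch

    def layout_per_expert(topk_ids, num_experts, m_max):
        # Stable-sort the flattened (token, expert) assignments by expert id.
        sorted_ids, reorder = torch.sort(topk_ids.view(-1), stable=True)
        seg_indptr = torch.zeros(num_experts + 1, dtype=torch.int64)
        for e in range(num_experts):
            seg_indptr[e + 1] = seg_indptr[e] + int((sorted_ids == e).sum())
        masked_m = (seg_indptr[1:] - seg_indptr[:-1]).to(torch.int32)
        # Assignment #dst within expert e's run lands at row e*m_max + offset.
        src2dst = torch.empty_like(reorder, dtype=torch.int32)
        for dst, src in enumerate(reorder.tolist()):
            e = int(sorted_ids[dst])
            src2dst[src] = e * m_max + (dst - int(seg_indptr[e]))
        return masked_m, src2dst

    masked_m, src2dst = layout_per_expert(torch.tensor([[0, 1], [1, 1]]), 2, m_max=4)
    print(masked_m.tolist())  # [1, 3]: real rows per expert slab
    print(src2dst.tolist())   # [0, 4, 5, 6]: rows in the padded layout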