sglang 0.5.4__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195)
  1. sglang/bench_one_batch.py +149 -34
  2. sglang/bench_serving.py +73 -14
  3. sglang/compile_deep_gemm.py +13 -7
  4. sglang/launch_server.py +2 -0
  5. sglang/srt/batch_invariant_ops/__init__.py +2 -0
  6. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +221 -4
  7. sglang/srt/checkpoint_engine/__init__.py +9 -0
  8. sglang/srt/checkpoint_engine/update.py +317 -0
  9. sglang/srt/compilation/backend.py +1 -1
  10. sglang/srt/configs/__init__.py +2 -0
  11. sglang/srt/configs/deepseek_ocr.py +542 -10
  12. sglang/srt/configs/deepseekvl2.py +95 -194
  13. sglang/srt/configs/kimi_linear.py +160 -0
  14. sglang/srt/configs/mamba_utils.py +66 -0
  15. sglang/srt/configs/model_config.py +30 -7
  16. sglang/srt/constants.py +7 -0
  17. sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
  18. sglang/srt/disaggregation/decode.py +34 -6
  19. sglang/srt/disaggregation/nixl/conn.py +2 -2
  20. sglang/srt/disaggregation/prefill.py +25 -3
  21. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
  22. sglang/srt/distributed/parallel_state.py +9 -12
  23. sglang/srt/entrypoints/engine.py +31 -20
  24. sglang/srt/entrypoints/grpc_server.py +0 -1
  25. sglang/srt/entrypoints/http_server.py +94 -94
  26. sglang/srt/entrypoints/openai/protocol.py +7 -1
  27. sglang/srt/entrypoints/openai/serving_chat.py +42 -0
  28. sglang/srt/entrypoints/openai/serving_completions.py +10 -0
  29. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  30. sglang/srt/environ.py +23 -2
  31. sglang/srt/eplb/expert_distribution.py +64 -1
  32. sglang/srt/eplb/expert_location.py +106 -36
  33. sglang/srt/function_call/function_call_parser.py +2 -0
  34. sglang/srt/function_call/minimax_m2.py +367 -0
  35. sglang/srt/grpc/compile_proto.py +3 -0
  36. sglang/srt/layers/activation.py +6 -0
  37. sglang/srt/layers/attention/ascend_backend.py +233 -5
  38. sglang/srt/layers/attention/attention_registry.py +3 -0
  39. sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
  40. sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
  41. sglang/srt/layers/attention/fla/kda.py +1359 -0
  42. sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
  43. sglang/srt/layers/attention/flashattention_backend.py +19 -8
  44. sglang/srt/layers/attention/flashinfer_backend.py +10 -1
  45. sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -11
  46. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  47. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
  48. sglang/srt/layers/attention/mamba/mamba.py +20 -11
  49. sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
  50. sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
  51. sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
  52. sglang/srt/layers/attention/nsa/transform_index.py +1 -1
  53. sglang/srt/layers/attention/nsa_backend.py +157 -23
  54. sglang/srt/layers/attention/triton_backend.py +4 -1
  55. sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
  56. sglang/srt/layers/attention/trtllm_mla_backend.py +11 -15
  57. sglang/srt/layers/attention/utils.py +78 -0
  58. sglang/srt/layers/communicator.py +24 -1
  59. sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
  60. sglang/srt/layers/layernorm.py +35 -6
  61. sglang/srt/layers/logits_processor.py +9 -20
  62. sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
  63. sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
  64. sglang/srt/layers/moe/ep_moe/layer.py +78 -289
  65. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
  67. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
  68. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
  69. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
  70. sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
  71. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
  72. sglang/srt/layers/moe/moe_runner/deep_gemm.py +340 -55
  73. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  74. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  75. sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
  76. sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
  77. sglang/srt/layers/moe/token_dispatcher/deepep.py +25 -18
  78. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  79. sglang/srt/layers/moe/topk.py +35 -10
  80. sglang/srt/layers/moe/utils.py +3 -4
  81. sglang/srt/layers/pooler.py +21 -2
  82. sglang/srt/layers/quantization/__init__.py +13 -84
  83. sglang/srt/layers/quantization/auto_round.py +394 -0
  84. sglang/srt/layers/quantization/awq.py +0 -3
  85. sglang/srt/layers/quantization/base_config.py +7 -0
  86. sglang/srt/layers/quantization/fp8.py +68 -63
  87. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  88. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  89. sglang/srt/layers/quantization/gguf.py +566 -0
  90. sglang/srt/layers/quantization/modelopt_quant.py +168 -11
  91. sglang/srt/layers/quantization/mxfp4.py +30 -38
  92. sglang/srt/layers/quantization/unquant.py +23 -45
  93. sglang/srt/layers/quantization/w4afp8.py +38 -2
  94. sglang/srt/layers/radix_attention.py +5 -2
  95. sglang/srt/layers/rotary_embedding.py +130 -46
  96. sglang/srt/layers/sampler.py +12 -1
  97. sglang/srt/lora/lora_registry.py +9 -0
  98. sglang/srt/managers/async_mm_data_processor.py +122 -0
  99. sglang/srt/managers/data_parallel_controller.py +30 -3
  100. sglang/srt/managers/detokenizer_manager.py +3 -0
  101. sglang/srt/managers/io_struct.py +29 -4
  102. sglang/srt/managers/multi_tokenizer_mixin.py +22 -1
  103. sglang/srt/managers/schedule_batch.py +74 -15
  104. sglang/srt/managers/scheduler.py +185 -144
  105. sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
  106. sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
  107. sglang/srt/managers/scheduler_pp_mixin.py +7 -2
  108. sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
  109. sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
  110. sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
  111. sglang/srt/managers/session_controller.py +6 -5
  112. sglang/srt/managers/tokenizer_manager.py +165 -78
  113. sglang/srt/managers/tp_worker.py +24 -1
  114. sglang/srt/mem_cache/base_prefix_cache.py +23 -4
  115. sglang/srt/mem_cache/common.py +1 -0
  116. sglang/srt/mem_cache/hicache_storage.py +7 -1
  117. sglang/srt/mem_cache/memory_pool.py +253 -57
  118. sglang/srt/mem_cache/memory_pool_host.py +12 -5
  119. sglang/srt/mem_cache/radix_cache.py +4 -0
  120. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  121. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
  122. sglang/srt/metrics/collector.py +46 -3
  123. sglang/srt/model_executor/cuda_graph_runner.py +15 -3
  124. sglang/srt/model_executor/forward_batch_info.py +55 -14
  125. sglang/srt/model_executor/model_runner.py +77 -170
  126. sglang/srt/model_executor/npu_graph_runner.py +7 -3
  127. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
  128. sglang/srt/model_loader/weight_utils.py +1 -1
  129. sglang/srt/models/bailing_moe.py +9 -2
  130. sglang/srt/models/deepseek_nextn.py +11 -2
  131. sglang/srt/models/deepseek_v2.py +296 -78
  132. sglang/srt/models/glm4.py +391 -77
  133. sglang/srt/models/glm4_moe.py +322 -354
  134. sglang/srt/models/glm4_moe_nextn.py +4 -14
  135. sglang/srt/models/glm4v.py +196 -55
  136. sglang/srt/models/glm4v_moe.py +29 -197
  137. sglang/srt/models/gpt_oss.py +1 -10
  138. sglang/srt/models/kimi_linear.py +678 -0
  139. sglang/srt/models/llama4.py +1 -1
  140. sglang/srt/models/llama_eagle3.py +11 -1
  141. sglang/srt/models/longcat_flash.py +2 -2
  142. sglang/srt/models/minimax_m2.py +922 -0
  143. sglang/srt/models/nvila.py +355 -0
  144. sglang/srt/models/nvila_lite.py +184 -0
  145. sglang/srt/models/qwen2.py +23 -2
  146. sglang/srt/models/qwen2_moe.py +30 -15
  147. sglang/srt/models/qwen3.py +35 -5
  148. sglang/srt/models/qwen3_moe.py +18 -12
  149. sglang/srt/models/qwen3_next.py +7 -0
  150. sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
  151. sglang/srt/multimodal/processors/base_processor.py +1 -0
  152. sglang/srt/multimodal/processors/glm4v.py +1 -1
  153. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  154. sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
  155. sglang/srt/multiplex/multiplexing_mixin.py +209 -0
  156. sglang/srt/multiplex/pdmux_context.py +164 -0
  157. sglang/srt/parser/conversation.py +7 -1
  158. sglang/srt/parser/reasoning_parser.py +28 -1
  159. sglang/srt/sampling/custom_logit_processor.py +67 -1
  160. sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
  161. sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
  162. sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
  163. sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
  164. sglang/srt/server_args.py +459 -199
  165. sglang/srt/single_batch_overlap.py +2 -4
  166. sglang/srt/speculative/draft_utils.py +16 -0
  167. sglang/srt/speculative/eagle_info.py +42 -36
  168. sglang/srt/speculative/eagle_info_v2.py +68 -25
  169. sglang/srt/speculative/eagle_utils.py +261 -16
  170. sglang/srt/speculative/eagle_worker.py +11 -3
  171. sglang/srt/speculative/eagle_worker_v2.py +15 -9
  172. sglang/srt/speculative/spec_info.py +305 -31
  173. sglang/srt/speculative/spec_utils.py +44 -8
  174. sglang/srt/tracing/trace.py +121 -12
  175. sglang/srt/utils/common.py +142 -74
  176. sglang/srt/utils/hf_transformers_utils.py +38 -12
  177. sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
  178. sglang/test/kits/radix_cache_server_kit.py +50 -0
  179. sglang/test/runners.py +31 -7
  180. sglang/test/simple_eval_common.py +5 -3
  181. sglang/test/simple_eval_humaneval.py +1 -0
  182. sglang/test/simple_eval_math.py +1 -0
  183. sglang/test/simple_eval_mmlu.py +1 -0
  184. sglang/test/simple_eval_mmmu_vlm.py +1 -0
  185. sglang/test/test_deterministic.py +235 -12
  186. sglang/test/test_deterministic_utils.py +2 -1
  187. sglang/test/test_utils.py +7 -1
  188. sglang/version.py +1 -1
  189. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +15 -28
  190. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +194 -175
  191. sglang/srt/models/vila.py +0 -306
  192. /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
  193. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
  194. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
  195. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
@@ -25,6 +25,13 @@ from sglang.srt.utils import (
25
25
  is_hip,
26
26
  )
27
27
 
28
+ try:
29
+ from triton.tools.tensor_descriptor import TensorDescriptor
30
+
31
+ _support_tensor_descriptor = True
32
+ except:
33
+ _support_tensor_descriptor = False
34
+
28
35
  _is_hip = is_hip()
29
36
  _is_cuda = is_cuda()
30
37
  _is_cpu_amx_available = cpu_has_amx_support()
@@ -41,6 +48,10 @@ elif _is_hip:
41
48
  padding_size = 128 if bool(int(os.getenv("SGLANG_MOE_PADDING", "0"))) else 0
42
49
 
43
50
 
51
+ def support_tensor_descriptor():
52
+ return _support_tensor_descriptor
53
+
54
+
44
55
  @triton.jit
45
56
  def write_zeros_to_output(
46
57
  c_ptr,
@@ -108,6 +119,7 @@ def fused_moe_kernel_gptq_awq(
108
119
  use_int4_w4a16: tl.constexpr,
109
120
  use_int8_w8a16: tl.constexpr,
110
121
  even_Ks: tl.constexpr,
122
+ filter_expert: tl.constexpr,
111
123
  ):
112
124
  """
113
125
  Implements the fused computation for a Mixture of Experts (MOE) using
@@ -161,7 +173,7 @@ def fused_moe_kernel_gptq_awq(
161
173
  token_mask = offs_token < num_valid_tokens
162
174
 
163
175
  off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
164
- if off_experts == -1:
176
+ if filter_expert and off_experts == -1:
165
177
  # -----------------------------------------------------------
166
178
  # Write back zeros to the output when the expert is not
167
179
  # in the current expert parallel rank.
@@ -296,7 +308,9 @@ def fused_moe_kernel_gptq_awq(
296
308
  def fused_moe_kernel(
297
309
  # Pointers to matrices
298
310
  a_ptr,
311
+ a_desc,
299
312
  b_ptr,
313
+ b_desc,
300
314
  bias_ptr,
301
315
  c_ptr,
302
316
  a_scale_ptr,
@@ -344,6 +358,8 @@ def fused_moe_kernel(
344
358
  use_int8_w8a16: tl.constexpr,
345
359
  per_channel_quant: tl.constexpr,
346
360
  even_Ks: tl.constexpr,
361
+ c_sorted: tl.constexpr,
362
+ filter_expert: tl.constexpr,
347
363
  ):
348
364
  """
349
365
  Implements the fused computation for a Mixture of Experts (MOE) using
@@ -399,9 +415,10 @@ def fused_moe_kernel(
399
415
  offs_token = offs_token.to(tl.int64)
400
416
  token_mask = offs_token < num_valid_tokens
401
417
 
402
- off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
418
+ off_experts_i32 = tl.load(expert_ids_ptr + pid_m)
419
+ off_experts = off_experts_i32.to(tl.int64)
403
420
 
404
- if off_experts == -1:
421
+ if filter_expert and off_experts == -1:
405
422
  # -----------------------------------------------------------
406
423
  # Write back zeros to the output when the expert is not
407
424
  # in the current expert parallel rank.
@@ -421,15 +438,23 @@ def fused_moe_kernel(
421
438
 
422
439
  offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
423
440
  offs_k = tl.arange(0, BLOCK_SIZE_K)
424
- a_ptrs = a_ptr + (
425
- offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
426
- )
441
+ if a_desc is not None:
442
+ assert use_fp8_w8a8 and group_n > 0 and group_k > 0
443
+ start_offs_m = pid_m * BLOCK_SIZE_M
444
+ else:
445
+ a_ptrs = a_ptr + (
446
+ offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
447
+ )
448
+
449
+ if b_desc is not None:
450
+ start_offs_n = pid_n * BLOCK_SIZE_N
451
+ else:
452
+ b_ptrs = (
453
+ b_ptr
454
+ + off_experts * stride_be
455
+ + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
456
+ )
427
457
 
428
- b_ptrs = (
429
- b_ptr
430
- + off_experts * stride_be
431
- + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
432
- )
433
458
  if bias_ptr is not None:
434
459
  bias = tl.load(
435
460
  bias_ptr + off_experts * stride_bias_e + offs_bn[None, :] * stride_bias_n
@@ -443,8 +468,14 @@ def fused_moe_kernel(
443
468
  if use_fp8_w8a8 or use_int8_w8a8:
444
469
  # block-wise
445
470
  if group_k > 0 and group_n > 0:
446
- a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
447
- offs_bsn = offs_bn // group_n
471
+ if a_desc is not None:
472
+ a_scale_ptrs = a_scale_ptr + offs_token_id * stride_asm
473
+ else:
474
+ a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
475
+ if BLOCK_SIZE_N > group_n:
476
+ offs_bsn = offs_bn // group_n
477
+ else:
478
+ offs_bsn = pid_n * BLOCK_SIZE_N // group_n
448
479
  b_scale_ptrs = (
449
480
  b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
450
481
  )
@@ -469,37 +500,49 @@ def fused_moe_kernel(
469
500
  # `accumulator` will be converted back to fp16 after the loop.
470
501
  accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
471
502
 
472
- for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
503
+ for k_start in range(0, K, BLOCK_SIZE_K):
473
504
  # Load the next block of A and B, generate a mask by checking the
474
505
  # K dimension.
475
- if even_Ks:
506
+ if a_desc is not None:
507
+ a = a_desc.load([start_offs_m, k_start])
508
+ elif even_Ks:
476
509
  a = tl.load(
477
510
  a_ptrs,
478
511
  mask=token_mask[:, None],
479
512
  other=0.0,
480
513
  )
481
- b = tl.load(b_ptrs)
482
514
  else:
483
515
  a = tl.load(
484
516
  a_ptrs,
485
- mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
517
+ mask=token_mask[:, None] & (offs_k[None, :] < K - k_start),
486
518
  other=0.0,
487
519
  )
488
- b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
520
+
521
+ if b_desc is not None:
522
+ b = (
523
+ b_desc.load([off_experts_i32, start_offs_n, k_start])
524
+ .reshape(BLOCK_SIZE_N, BLOCK_SIZE_K)
525
+ .T
526
+ )
527
+ elif even_Ks:
528
+ b = tl.load(b_ptrs)
529
+ else:
530
+ b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k_start, other=0.0)
489
531
 
490
532
  # We accumulate along the K dimension.
491
533
  if use_int8_w8a16:
492
534
  accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
493
535
  elif use_fp8_w8a8 or use_int8_w8a8:
494
536
  if group_k > 0 and group_n > 0:
495
- k_start = k * BLOCK_SIZE_K
496
537
  offs_ks = k_start // group_k
497
538
  a_scale = tl.load(
498
539
  a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0
499
540
  )
500
541
  b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)
501
-
502
- accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
542
+ if BLOCK_SIZE_N > group_n:
543
+ accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
544
+ else:
545
+ accumulator += tl.dot(a, b) * (a_scale[:, None] * b_scale)
503
546
  else:
504
547
  if use_fp8_w8a8:
505
548
  accumulator = tl.dot(a, b, acc=accumulator)
@@ -508,8 +551,10 @@ def fused_moe_kernel(
508
551
  else:
509
552
  accumulator += tl.dot(a, b)
510
553
  # Advance the ptrs to the next K block.
511
- a_ptrs += BLOCK_SIZE_K * stride_ak
512
- b_ptrs += BLOCK_SIZE_K * stride_bk
554
+ if a_desc is None:
555
+ a_ptrs += BLOCK_SIZE_K * stride_ak
556
+ if b_desc is None:
557
+ b_ptrs += BLOCK_SIZE_K * stride_bk
513
558
 
514
559
  if use_int8_w8a16:
515
560
  accumulator *= b_scale
@@ -528,7 +573,12 @@ def fused_moe_kernel(
528
573
  # -----------------------------------------------------------
529
574
  # Write back the block of the output
530
575
  offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
531
- c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
576
+ if c_sorted:
577
+ c_ptrs = (
578
+ c_ptr + stride_cm * offs_token_id[:, None] + stride_cn * offs_cn[None, :]
579
+ )
580
+ else:
581
+ c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
532
582
  c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
533
583
  tl.store(c_ptrs, accumulator, mask=c_mask)
534
584
 
@@ -557,6 +607,10 @@ def invoke_fused_moe_kernel(
557
607
  per_channel_quant: bool,
558
608
  block_shape: Optional[List[int]] = None,
559
609
  no_combine: bool = False,
610
+ a_use_tma: bool = False,
611
+ b_use_tma: bool = False,
612
+ c_sorted: bool = False,
613
+ filter_expert: bool = True,
560
614
  ) -> None:
561
615
  assert topk_weights.stride(1) == 1
562
616
  assert sorted_token_ids.stride(0) == 1
@@ -662,14 +716,38 @@ def invoke_fused_moe_kernel(
662
716
  use_int4_w4a16=use_int4_w4a16,
663
717
  use_int8_w8a16=use_int8_w8a16,
664
718
  even_Ks=even_Ks,
719
+ filter_expert=filter_expert,
665
720
  **config,
666
721
  )
667
722
 
668
723
  else:
724
+ if a_use_tma or b_use_tma:
725
+ # TMA descriptors require a global memory allocation
726
+ def alloc_fn(size: int, alignment: int, stream: Optional[int]):
727
+ return torch.empty(size, device="cuda", dtype=torch.int8)
728
+
729
+ triton.set_allocator(alloc_fn)
730
+ if a_use_tma:
731
+ a_desc = TensorDescriptor(
732
+ A, A.shape, A.stride(), [config["BLOCK_SIZE_M"], config["BLOCK_SIZE_K"]]
733
+ )
734
+ else:
735
+ a_desc = None
736
+ if b_use_tma:
737
+ b_desc = TensorDescriptor(
738
+ B,
739
+ B.shape,
740
+ B.stride(),
741
+ [1, config["BLOCK_SIZE_N"], config["BLOCK_SIZE_K"]],
742
+ )
743
+ else:
744
+ b_desc = None
669
745
 
670
746
  fused_moe_kernel[grid](
671
747
  A,
748
+ a_desc,
672
749
  B,
750
+ b_desc,
673
751
  bias,
674
752
  C,
675
753
  A_scale,
@@ -689,8 +767,8 @@ def invoke_fused_moe_kernel(
689
767
  B.stride(1),
690
768
  bias.stride(0) if bias is not None else 0,
691
769
  bias.stride(1) if bias is not None else 0,
692
- C.stride(1),
693
- C.stride(2),
770
+ C.stride(-2),
771
+ C.stride(-1),
694
772
  A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
695
773
  A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0,
696
774
  B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0,
@@ -706,6 +784,8 @@ def invoke_fused_moe_kernel(
706
784
  use_int8_w8a16=use_int8_w8a16,
707
785
  per_channel_quant=per_channel_quant,
708
786
  even_Ks=even_Ks,
787
+ c_sorted=c_sorted,
788
+ filter_expert=filter_expert,
709
789
  **config,
710
790
  )
711
791
 
@@ -172,7 +172,7 @@ class FusedMoE(torch.nn.Module):
172
172
  self.reduce_results = reduce_results
173
173
  self.use_presharded_weights = use_presharded_weights
174
174
 
175
- self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
175
+ self.use_triton_kernels = get_moe_runner_backend().is_triton_kernels()
176
176
 
177
177
  self.quant_config = quant_config
178
178
  self.use_flashinfer_mxfp4_moe = get_moe_runner_backend().is_flashinfer_mxfp4()
@@ -232,7 +232,7 @@ class FusedMoE(torch.nn.Module):
232
232
  self.quant_method, ModelOptNvFp4FusedMoEMethod
233
233
  ) or (
234
234
  isinstance(self.quant_method, Fp8MoEMethod)
235
- and self.quant_method.use_cutlass_fused_experts_fp8
235
+ and self.quant_method._should_use_cutlass_fused_experts()
236
236
  )
237
237
 
238
238
  def _load_per_tensor_weight_scale(
@@ -839,7 +839,7 @@ class FusedMoE(torch.nn.Module):
839
839
  dispatch_output=dispatch_output,
840
840
  **kwargs,
841
841
  )
842
- final_hidden_states = self.dispatcher.combine(combine_input)
842
+ final_hidden_states = self.dispatcher.combine(combine_input=combine_input)
843
843
 
844
844
  # TODO: should we add some conditions here?
845
845
  final_hidden_states = final_hidden_states[
@@ -47,7 +47,7 @@ def triton_kernel_moe_forward(
47
47
 
48
48
  from sglang.srt.layers.moe.topk import TopKOutputChecker
49
49
 
50
- assert TopKOutputChecker.format_is_triton_kernel(topk_output)
50
+ assert TopKOutputChecker.format_is_triton_kernels(topk_output)
51
51
 
52
52
  routing_data, gather_idx, scatter_idx = topk_output
53
53
 
@@ -172,6 +172,7 @@ def triton_kernel_moe_with_bias_forward(
172
172
  b2: torch.Tensor,
173
173
  topk_output: TopKOutput,
174
174
  moe_runner_config: MoeRunnerConfig,
175
+ apply_router_weight_on_input: bool = False,
175
176
  use_fp8_w8a8: bool = False,
176
177
  per_channel_quant: bool = False,
177
178
  global_num_experts: int = -1,
@@ -184,7 +185,7 @@ def triton_kernel_moe_with_bias_forward(
184
185
  ) -> torch.Tensor:
185
186
  from sglang.srt.layers.moe.topk import TopKOutputChecker
186
187
 
187
- assert TopKOutputChecker.format_is_triton_kernel(topk_output)
188
+ assert TopKOutputChecker.format_is_triton_kernels(topk_output)
188
189
 
189
190
  routing_data, gather_idx, scatter_idx = topk_output
190
191
 
@@ -201,6 +202,7 @@ def triton_kernel_moe_with_bias_forward(
201
202
  scatter_indx=scatter_idx,
202
203
  inplace=False, # triton kernel doesn't support inplace
203
204
  activation=moe_runner_config.activation,
205
+ apply_router_weight_on_input=apply_router_weight_on_input,
204
206
  use_fp8_w8a8=use_fp8_w8a8,
205
207
  per_channel_quant=per_channel_quant,
206
208
  global_num_experts=global_num_experts,
@@ -228,6 +230,7 @@ def triton_kernel_fused_experts_with_bias(
228
230
  scatter_indx: ScatterIndx,
229
231
  inplace: bool = False,
230
232
  activation: str = "silu",
233
+ apply_router_weight_on_input: bool = False,
231
234
  use_fp8_w8a8: bool = False,
232
235
  per_channel_quant: bool = False,
233
236
  global_num_experts: int = -1,
@@ -296,7 +299,7 @@ def triton_kernel_fused_experts_with_bias(
296
299
  routing_data,
297
300
  gather_indx=gather_indx,
298
301
  precision_config=w1_pcg,
299
- gammas=None,
302
+ gammas=routing_data.gate_scal if apply_router_weight_on_input else None,
300
303
  fused_activation=act,
301
304
  )
302
305
 
@@ -307,5 +310,5 @@ def triton_kernel_fused_experts_with_bias(
307
310
  routing_data,
308
311
  scatter_indx=scatter_indx,
309
312
  precision_config=w2_pcg,
310
- gammas=routing_data.gate_scal,
313
+ gammas=None if apply_router_weight_on_input else routing_data.gate_scal,
311
314
  )