sglang 0.5.4.post1__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150)
  1. sglang/bench_one_batch.py +149 -34
  2. sglang/bench_serving.py +18 -3
  3. sglang/compile_deep_gemm.py +13 -7
  4. sglang/srt/batch_invariant_ops/__init__.py +2 -0
  5. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +120 -0
  6. sglang/srt/checkpoint_engine/__init__.py +9 -0
  7. sglang/srt/checkpoint_engine/update.py +317 -0
  8. sglang/srt/configs/__init__.py +2 -0
  9. sglang/srt/configs/deepseek_ocr.py +542 -10
  10. sglang/srt/configs/deepseekvl2.py +95 -194
  11. sglang/srt/configs/kimi_linear.py +160 -0
  12. sglang/srt/configs/mamba_utils.py +66 -0
  13. sglang/srt/configs/model_config.py +25 -2
  14. sglang/srt/constants.py +7 -0
  15. sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
  16. sglang/srt/disaggregation/decode.py +34 -6
  17. sglang/srt/disaggregation/nixl/conn.py +2 -2
  18. sglang/srt/disaggregation/prefill.py +25 -3
  19. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
  20. sglang/srt/distributed/parallel_state.py +9 -5
  21. sglang/srt/entrypoints/engine.py +13 -5
  22. sglang/srt/entrypoints/http_server.py +22 -3
  23. sglang/srt/entrypoints/openai/protocol.py +7 -1
  24. sglang/srt/entrypoints/openai/serving_chat.py +42 -0
  25. sglang/srt/entrypoints/openai/serving_completions.py +10 -0
  26. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  27. sglang/srt/environ.py +7 -0
  28. sglang/srt/eplb/expert_distribution.py +34 -1
  29. sglang/srt/eplb/expert_location.py +106 -36
  30. sglang/srt/grpc/compile_proto.py +3 -0
  31. sglang/srt/layers/attention/ascend_backend.py +233 -5
  32. sglang/srt/layers/attention/attention_registry.py +3 -0
  33. sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
  34. sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
  35. sglang/srt/layers/attention/fla/kda.py +1359 -0
  36. sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
  37. sglang/srt/layers/attention/flashattention_backend.py +7 -6
  38. sglang/srt/layers/attention/flashinfer_mla_backend.py +3 -1
  39. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  40. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
  41. sglang/srt/layers/attention/mamba/mamba.py +20 -11
  42. sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
  43. sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
  44. sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
  45. sglang/srt/layers/attention/nsa/transform_index.py +1 -1
  46. sglang/srt/layers/attention/nsa_backend.py +157 -23
  47. sglang/srt/layers/attention/triton_backend.py +4 -1
  48. sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
  49. sglang/srt/layers/attention/trtllm_mla_backend.py +10 -2
  50. sglang/srt/layers/communicator.py +23 -1
  51. sglang/srt/layers/layernorm.py +16 -2
  52. sglang/srt/layers/logits_processor.py +4 -20
  53. sglang/srt/layers/moe/ep_moe/layer.py +0 -18
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
  56. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
  57. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
  58. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
  59. sglang/srt/layers/moe/moe_runner/deep_gemm.py +53 -33
  60. sglang/srt/layers/moe/token_dispatcher/deepep.py +12 -9
  61. sglang/srt/layers/moe/topk.py +31 -6
  62. sglang/srt/layers/pooler.py +21 -2
  63. sglang/srt/layers/quantization/__init__.py +9 -78
  64. sglang/srt/layers/quantization/auto_round.py +394 -0
  65. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  66. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  67. sglang/srt/layers/quantization/modelopt_quant.py +168 -11
  68. sglang/srt/layers/rotary_embedding.py +117 -45
  69. sglang/srt/lora/lora_registry.py +9 -0
  70. sglang/srt/managers/async_mm_data_processor.py +122 -0
  71. sglang/srt/managers/data_parallel_controller.py +30 -3
  72. sglang/srt/managers/detokenizer_manager.py +3 -0
  73. sglang/srt/managers/io_struct.py +26 -4
  74. sglang/srt/managers/multi_tokenizer_mixin.py +5 -0
  75. sglang/srt/managers/schedule_batch.py +74 -15
  76. sglang/srt/managers/scheduler.py +164 -129
  77. sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
  78. sglang/srt/managers/scheduler_pp_mixin.py +7 -2
  79. sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
  80. sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
  81. sglang/srt/managers/session_controller.py +6 -5
  82. sglang/srt/managers/tokenizer_manager.py +154 -59
  83. sglang/srt/managers/tp_worker.py +24 -1
  84. sglang/srt/mem_cache/base_prefix_cache.py +23 -4
  85. sglang/srt/mem_cache/common.py +1 -0
  86. sglang/srt/mem_cache/memory_pool.py +171 -57
  87. sglang/srt/mem_cache/memory_pool_host.py +12 -5
  88. sglang/srt/mem_cache/radix_cache.py +4 -0
  89. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
  90. sglang/srt/metrics/collector.py +46 -3
  91. sglang/srt/model_executor/cuda_graph_runner.py +15 -3
  92. sglang/srt/model_executor/forward_batch_info.py +11 -11
  93. sglang/srt/model_executor/model_runner.py +76 -21
  94. sglang/srt/model_executor/npu_graph_runner.py +7 -3
  95. sglang/srt/model_loader/weight_utils.py +1 -1
  96. sglang/srt/models/bailing_moe.py +9 -2
  97. sglang/srt/models/deepseek_nextn.py +11 -2
  98. sglang/srt/models/deepseek_v2.py +149 -34
  99. sglang/srt/models/glm4.py +391 -77
  100. sglang/srt/models/glm4v.py +196 -55
  101. sglang/srt/models/glm4v_moe.py +0 -1
  102. sglang/srt/models/gpt_oss.py +1 -10
  103. sglang/srt/models/kimi_linear.py +678 -0
  104. sglang/srt/models/llama4.py +1 -1
  105. sglang/srt/models/llama_eagle3.py +11 -1
  106. sglang/srt/models/longcat_flash.py +2 -2
  107. sglang/srt/models/minimax_m2.py +1 -1
  108. sglang/srt/models/qwen2.py +1 -1
  109. sglang/srt/models/qwen2_moe.py +30 -15
  110. sglang/srt/models/qwen3.py +1 -1
  111. sglang/srt/models/qwen3_moe.py +16 -8
  112. sglang/srt/models/qwen3_next.py +7 -0
  113. sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
  114. sglang/srt/multiplex/multiplexing_mixin.py +209 -0
  115. sglang/srt/multiplex/pdmux_context.py +164 -0
  116. sglang/srt/parser/conversation.py +7 -1
  117. sglang/srt/sampling/custom_logit_processor.py +67 -1
  118. sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
  119. sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
  120. sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
  121. sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
  122. sglang/srt/server_args.py +103 -22
  123. sglang/srt/single_batch_overlap.py +4 -1
  124. sglang/srt/speculative/draft_utils.py +16 -0
  125. sglang/srt/speculative/eagle_info.py +42 -36
  126. sglang/srt/speculative/eagle_info_v2.py +68 -25
  127. sglang/srt/speculative/eagle_utils.py +261 -16
  128. sglang/srt/speculative/eagle_worker.py +11 -3
  129. sglang/srt/speculative/eagle_worker_v2.py +15 -9
  130. sglang/srt/speculative/spec_info.py +305 -31
  131. sglang/srt/speculative/spec_utils.py +44 -8
  132. sglang/srt/tracing/trace.py +121 -12
  133. sglang/srt/utils/common.py +55 -32
  134. sglang/srt/utils/hf_transformers_utils.py +38 -16
  135. sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
  136. sglang/test/kits/radix_cache_server_kit.py +50 -0
  137. sglang/test/runners.py +31 -7
  138. sglang/test/simple_eval_common.py +5 -3
  139. sglang/test/simple_eval_humaneval.py +1 -0
  140. sglang/test/simple_eval_math.py +1 -0
  141. sglang/test/simple_eval_mmlu.py +1 -0
  142. sglang/test/simple_eval_mmmu_vlm.py +1 -0
  143. sglang/test/test_utils.py +7 -1
  144. sglang/version.py +1 -1
  145. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +10 -24
  146. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +150 -136
  147. /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
  148. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
  149. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
  150. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
@@ -25,6 +25,13 @@ from sglang.srt.utils import (
     is_hip,
 )
 
+try:
+    from triton.tools.tensor_descriptor import TensorDescriptor
+
+    _support_tensor_descriptor = True
+except:
+    _support_tensor_descriptor = False
+
 _is_hip = is_hip()
 _is_cuda = is_cuda()
 _is_cpu_amx_available = cpu_has_amx_support()
@@ -41,6 +48,10 @@ elif _is_hip:
     padding_size = 128 if bool(int(os.getenv("SGLANG_MOE_PADDING", "0"))) else 0
 
 
+def support_tensor_descriptor():
+    return _support_tensor_descriptor
+
+
 @triton.jit
 def write_zeros_to_output(
     c_ptr,
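
A minimal call-site sketch for the probe above, assuming these hunks come from sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py (the import path below is an assumption, not taken from the diff):

from sglang.srt.layers.moe.fused_moe_triton.fused_moe_triton_kernels import (
    support_tensor_descriptor,  # assumed import path; see note above
)

# Gate the new TMA arguments on the probe; callers fall back to the classic
# pointer-arithmetic path on Triton builds without tensor descriptors.
a_use_tma = b_use_tma = support_tensor_descriptor()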
@@ -108,6 +119,7 @@ def fused_moe_kernel_gptq_awq(
     use_int4_w4a16: tl.constexpr,
     use_int8_w8a16: tl.constexpr,
     even_Ks: tl.constexpr,
+    filter_expert: tl.constexpr,
 ):
     """
     Implements the fused computation for a Mixture of Experts (MOE) using
@@ -161,7 +173,7 @@ def fused_moe_kernel_gptq_awq(
     token_mask = offs_token < num_valid_tokens
 
     off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
-    if off_experts == -1:
+    if filter_expert and off_experts == -1:
         # -----------------------------------------------------------
         # Write back zeros to the output when the expert is not
         # in the current expert parallel rank.
@@ -296,7 +308,9 @@ def fused_moe_kernel_gptq_awq(
 def fused_moe_kernel(
     # Pointers to matrices
     a_ptr,
+    a_desc,
     b_ptr,
+    b_desc,
     bias_ptr,
     c_ptr,
     a_scale_ptr,
@@ -344,6 +358,8 @@ def fused_moe_kernel(
     use_int8_w8a16: tl.constexpr,
     per_channel_quant: tl.constexpr,
     even_Ks: tl.constexpr,
+    c_sorted: tl.constexpr,
+    filter_expert: tl.constexpr,
 ):
     """
     Implements the fused computation for a Mixture of Experts (MOE) using
@@ -399,9 +415,10 @@ def fused_moe_kernel(
     offs_token = offs_token.to(tl.int64)
     token_mask = offs_token < num_valid_tokens
 
-    off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
+    off_experts_i32 = tl.load(expert_ids_ptr + pid_m)
+    off_experts = off_experts_i32.to(tl.int64)
 
-    if off_experts == -1:
+    if filter_expert and off_experts == -1:
         # -----------------------------------------------------------
         # Write back zeros to the output when the expert is not
         # in the current expert parallel rank.
@@ -421,15 +438,23 @@ def fused_moe_kernel(
 
     offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
     offs_k = tl.arange(0, BLOCK_SIZE_K)
-    a_ptrs = a_ptr + (
-        offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
-    )
+    if a_desc is not None:
+        assert use_fp8_w8a8 and group_n > 0 and group_k > 0
+        start_offs_m = pid_m * BLOCK_SIZE_M
+    else:
+        a_ptrs = a_ptr + (
+            offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
+        )
+
+    if b_desc is not None:
+        start_offs_n = pid_n * BLOCK_SIZE_N
+    else:
+        b_ptrs = (
+            b_ptr
+            + off_experts * stride_be
+            + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+        )
 
-    b_ptrs = (
-        b_ptr
-        + off_experts * stride_be
-        + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
-    )
 
     if bias_ptr is not None:
         bias = tl.load(
@@ -443,8 +468,14 @@ def fused_moe_kernel(
     if use_fp8_w8a8 or use_int8_w8a8:
         # block-wise
         if group_k > 0 and group_n > 0:
-            a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
-            offs_bsn = offs_bn // group_n
+            if a_desc is not None:
+                a_scale_ptrs = a_scale_ptr + offs_token_id * stride_asm
+            else:
+                a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
+            if BLOCK_SIZE_N > group_n:
+                offs_bsn = offs_bn // group_n
+            else:
+                offs_bsn = pid_n * BLOCK_SIZE_N // group_n
             b_scale_ptrs = (
                 b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
             )
@@ -469,37 +500,49 @@ def fused_moe_kernel(
     # `accumulator` will be converted back to fp16 after the loop.
     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
 
-    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+    for k_start in range(0, K, BLOCK_SIZE_K):
         # Load the next block of A and B, generate a mask by checking the
         # K dimension.
-        if even_Ks:
+        if a_desc is not None:
+            a = a_desc.load([start_offs_m, k_start])
+        elif even_Ks:
             a = tl.load(
                 a_ptrs,
                 mask=token_mask[:, None],
                 other=0.0,
             )
-            b = tl.load(b_ptrs)
         else:
             a = tl.load(
                 a_ptrs,
-                mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
+                mask=token_mask[:, None] & (offs_k[None, :] < K - k_start),
                 other=0.0,
             )
-            b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+
+        if b_desc is not None:
+            b = (
+                b_desc.load([off_experts_i32, start_offs_n, k_start])
+                .reshape(BLOCK_SIZE_N, BLOCK_SIZE_K)
+                .T
+            )
+        elif even_Ks:
+            b = tl.load(b_ptrs)
+        else:
+            b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k_start, other=0.0)
 
         # We accumulate along the K dimension.
         if use_int8_w8a16:
             accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
         elif use_fp8_w8a8 or use_int8_w8a8:
             if group_k > 0 and group_n > 0:
-                k_start = k * BLOCK_SIZE_K
                 offs_ks = k_start // group_k
                 a_scale = tl.load(
                     a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0
                 )
                 b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)
-
-                accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
+                if BLOCK_SIZE_N > group_n:
+                    accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
+                else:
+                    accumulator += tl.dot(a, b) * (a_scale[:, None] * b_scale)
             else:
                 if use_fp8_w8a8:
                     accumulator = tl.dot(a, b, acc=accumulator)
@@ -508,8 +551,10 @@ def fused_moe_kernel(
         else:
             accumulator += tl.dot(a, b)
         # Advance the ptrs to the next K block.
-        a_ptrs += BLOCK_SIZE_K * stride_ak
-        b_ptrs += BLOCK_SIZE_K * stride_bk
+        if a_desc is None:
+            a_ptrs += BLOCK_SIZE_K * stride_ak
+        if b_desc is None:
+            b_ptrs += BLOCK_SIZE_K * stride_bk
 
     if use_int8_w8a16:
         accumulator *= b_scale
@@ -528,7 +573,12 @@ def fused_moe_kernel(
     # -----------------------------------------------------------
     # Write back the block of the output
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
-    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
+    if c_sorted:
+        c_ptrs = (
+            c_ptr + stride_cm * offs_token_id[:, None] + stride_cn * offs_cn[None, :]
+        )
+    else:
+        c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
     c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
     tl.store(c_ptrs, accumulator, mask=c_mask)
 
@@ -557,6 +607,10 @@ def invoke_fused_moe_kernel(
     per_channel_quant: bool,
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
+    a_use_tma: bool = False,
+    b_use_tma: bool = False,
+    c_sorted: bool = False,
+    filter_expert: bool = True,
 ) -> None:
     assert topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1
@@ -662,14 +716,38 @@ def invoke_fused_moe_kernel(
             use_int4_w4a16=use_int4_w4a16,
             use_int8_w8a16=use_int8_w8a16,
             even_Ks=even_Ks,
+            filter_expert=filter_expert,
             **config,
         )
 
     else:
+        if a_use_tma or b_use_tma:
+            # TMA descriptors require a global memory allocation
+            def alloc_fn(size: int, alignment: int, stream: Optional[int]):
+                return torch.empty(size, device="cuda", dtype=torch.int8)
+
+            triton.set_allocator(alloc_fn)
+        if a_use_tma:
+            a_desc = TensorDescriptor(
+                A, A.shape, A.stride(), [config["BLOCK_SIZE_M"], config["BLOCK_SIZE_K"]]
+            )
+        else:
+            a_desc = None
+        if b_use_tma:
+            b_desc = TensorDescriptor(
+                B,
+                B.shape,
+                B.stride(),
+                [1, config["BLOCK_SIZE_N"], config["BLOCK_SIZE_K"]],
+            )
+        else:
+            b_desc = None
 
         fused_moe_kernel[grid](
             A,
+            a_desc,
             B,
+            b_desc,
             bias,
             C,
             A_scale,
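
A minimal, self-contained sketch (not sglang code) of the host-side tensor-descriptor pattern the block above relies on. It assumes a Triton build that exposes triton.tools.tensor_descriptor and a GPU with TMA support (Hopper or newer); all names below are illustrative:

import torch
import triton
import triton.language as tl
from triton.tools.tensor_descriptor import TensorDescriptor


@triton.jit
def copy_tile_kernel(in_desc, out_ptr, BLOCK_M: tl.constexpr, BLOCK_K: tl.constexpr):
    # Load one (BLOCK_M, BLOCK_K) tile through the TMA descriptor and store it back.
    tile = in_desc.load([0, 0])
    offs_m = tl.arange(0, BLOCK_M)[:, None]
    offs_k = tl.arange(0, BLOCK_K)[None, :]
    tl.store(out_ptr + offs_m * BLOCK_K + offs_k, tile)


# TMA descriptors need a global scratch allocator, same as in the diff above.
triton.set_allocator(
    lambda size, alignment, stream: torch.empty(size, device="cuda", dtype=torch.int8)
)

x = torch.randn(64, 64, device="cuda", dtype=torch.float32)
out = torch.empty_like(x)
in_desc = TensorDescriptor(x, x.shape, x.stride(), [64, 64])
copy_tile_kernel[(1,)](in_desc, out, BLOCK_M=64, BLOCK_K=64)
assert torch.equal(out, x)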
@@ -689,8 +767,8 @@ def invoke_fused_moe_kernel(
             B.stride(1),
             bias.stride(0) if bias is not None else 0,
             bias.stride(1) if bias is not None else 0,
-            C.stride(1),
-            C.stride(2),
+            C.stride(-2),
+            C.stride(-1),
             A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
             A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0,
             B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0,
@@ -706,6 +784,8 @@ def invoke_fused_moe_kernel(
             use_int8_w8a16=use_int8_w8a16,
             per_channel_quant=per_channel_quant,
             even_Ks=even_Ks,
+            c_sorted=c_sorted,
+            filter_expert=filter_expert,
             **config,
         )
 
@@ -39,6 +39,9 @@ if not (_is_npu or _is_hip):
     from sgl_kernel import silu_and_mul
 
 
+_MASKED_GEMM_FAST_ACT = get_bool_env_var("SGLANG_MASKED_GEMM_FAST_ACT")
+
+
 # TODO(kaixih@nvidia): ideally we should merge this logic into
 # `fill_gateup_input_triton_kernel` to directly generate e8m0 scale.
 @torch.compile
@@ -214,6 +217,9 @@ class DeepGemmRunnerCore(MoeRunnerCore):
         from sglang.srt.layers.moe.ep_moe.kernels import (
             silu_and_mul_masked_post_quant_fwd,
         )
+        from sglang.srt.layers.quantization.fp8_kernel import (
+            sglang_per_token_group_quant_8bit,
+        )
 
         hidden_states = runner_input.hidden_states
         hidden_states_scale = runner_input.hidden_states_scale
@@ -227,15 +233,16 @@ class DeepGemmRunnerCore(MoeRunnerCore):
 
         hidden_states_device = running_state["hidden_states_device"]
 
-        if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
-            b, s_mn, s_k = hidden_states_scale.shape
-            assert (
-                s_mn % 4 == 0 and s_k % 4 == 0
-            ), f"scales must be aligned to 4, but got ({b}, {s_mn}, {s_k})"
-
         # GroupGemm-0
         if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
-            hidden_states_scale = _cast_to_e8m0_with_rounding_up(hidden_states_scale)
+            if hidden_states_scale.dtype != torch.int:
+                b, s_mn, s_k = hidden_states_scale.shape
+                assert (
+                    s_mn % 4 == 0 and s_k % 4 == 0
+                ), f"scales must be aligned to 4, but got ({b}, {s_mn}, {s_k})"
+                hidden_states_scale = _cast_to_e8m0_with_rounding_up(
+                    hidden_states_scale
+                )
         else:
             hidden_states_scale = deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(
                 hidden_states_scale
@@ -257,33 +264,46 @@ class DeepGemmRunnerCore(MoeRunnerCore):
         dispose_tensor(hidden_states_scale)
 
         # Act
-        down_input = torch.empty(
-            (
-                gateup_output.shape[0],
-                gateup_output.shape[1],
-                gateup_output.shape[2] // 2,
-            ),
-            device=hidden_states_device,
-            dtype=torch.float8_e4m3fn,
-        )
         scale_block_size = 128
-        down_input_scale = torch.empty(
-            (
-                gateup_output.shape[0],
-                gateup_output.shape[1],
-                gateup_output.shape[2] // 2 // scale_block_size,
-            ),
-            device=hidden_states_device,
-            dtype=torch.float32,
-        )
-        silu_and_mul_masked_post_quant_fwd(
-            gateup_output,
-            down_input,
-            down_input_scale,
-            scale_block_size,
-            masked_m,
-            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
-        )
+        if _MASKED_GEMM_FAST_ACT:
+            down_input, down_input_scale = sglang_per_token_group_quant_8bit(
+                x=gateup_output,
+                dst_dtype=torch.float8_e4m3fn,
+                group_size=scale_block_size,
+                masked_m=masked_m,
+                column_major_scales=True,
+                scale_tma_aligned=True,
+                scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+                fuse_silu_and_mul=True,
+                enable_v2=True,
+            )
+        else:
+            down_input = torch.empty(
+                (
+                    gateup_output.shape[0],
+                    gateup_output.shape[1],
+                    gateup_output.shape[2] // 2,
+                ),
+                device=hidden_states_device,
+                dtype=torch.float8_e4m3fn,
+            )
+            down_input_scale = torch.empty(
+                (
+                    gateup_output.shape[0],
+                    gateup_output.shape[1],
+                    gateup_output.shape[2] // 2 // scale_block_size,
+                ),
+                device=hidden_states_device,
+                dtype=torch.float32,
+            )
+            silu_and_mul_masked_post_quant_fwd(
+                gateup_output,
+                down_input,
+                down_input_scale,
+                scale_block_size,
+                masked_m,
+                scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+            )
         del gateup_output
 
         # GroupGemm-1
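
The fused silu_and_mul plus per-token-group FP8 quantization path above is opt-in. A small sketch of enabling it; the flag name comes from this diff and is read once at module import via get_bool_env_var, so it has to be in the environment before sglang imports this module (for example, in the server's environment):

import os

# Must be set before sglang is imported / the server process starts.
os.environ["SGLANG_MASKED_GEMM_FAST_ACT"] = "1"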
@@ -97,7 +97,6 @@ class DeepEPNormalCombineInput(NamedTuple):
     hidden_states: torch.Tensor
     topk_ids: torch.Tensor
     topk_weights: torch.Tensor
-    overlap_args: Optional[CombineOverlapArgs] = None
 
     @property
     def format(self) -> CombineInputFormat:
@@ -110,7 +109,6 @@ class DeepEPLLCombineInput(NamedTuple):
     hidden_states: torch.Tensor
     topk_ids: torch.Tensor
     topk_weights: torch.Tensor
-    overlap_args: Optional[CombineOverlapArgs] = None
 
     @property
     def format(self) -> CombineInputFormat:
@@ -333,7 +331,7 @@ class _DeepEPDispatcherImplBase:
         hidden_states: torch.Tensor,
         topk_ids: torch.Tensor,
         topk_weights: torch.Tensor,
-        overlap_args: Optional["CombineOverlapArgs"],
+        overlap_args: Optional[CombineOverlapArgs] = None,
     ):
         raise NotImplementedError
 
@@ -463,7 +461,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         hidden_states: torch.Tensor,
         topk_ids: torch.Tensor,
         topk_weights: torch.Tensor,
-        overlap_args: Optional["CombineOverlapArgs"],
+        overlap_args: Optional[CombineOverlapArgs] = None,
     ):
 
         if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter or _is_npu:
@@ -619,7 +617,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         hidden_states: torch.Tensor,
         topk_ids: torch.Tensor,
         topk_weights: torch.Tensor,
-        overlap_args: Optional["CombineOverlapArgs"],
+        overlap_args: Optional[CombineOverlapArgs] = None,
     ):
         hidden_states, event, hook = self._combine_core(
             hidden_states,
@@ -645,7 +643,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         hidden_states: torch.Tensor,
         topk_ids: torch.Tensor,
         topk_weights: torch.Tensor,
-        overlap_args: Optional["CombineOverlapArgs"],
+        overlap_args: Optional[CombineOverlapArgs] = None,
     ):
         buffer = self._get_buffer()
 
@@ -762,16 +760,21 @@ class DeepEPDispatcher(BaseDispatcher):
         del self._dispatch_intermediate_state
         return self._get_impl().dispatch_b(*inner_state)
 
-    def combine(self, combine_input: CombineInput) -> Tuple:
-        self.combine_a(combine_input)
+    def combine(
+        self,
+        combine_input: CombineInput,
+        overlap_args: Optional[CombineOverlapArgs] = None,
+    ) -> Tuple:
+        self.combine_a(combine_input, overlap_args)
         ret = self.combine_b()
         return ret
 
     def combine_a(
         self,
         combine_input: CombineInput,
+        overlap_args: Optional[CombineOverlapArgs] = None,
    ):
-        hidden_states, topk_ids, topk_weights, overlap_args = combine_input
+        hidden_states, topk_ids, topk_weights = combine_input
         self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A)
         inner_state = self._get_impl().combine_a(
             hidden_states=hidden_states,
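
A hedged before/after sketch of the call-site change implied by these hunks (variable names are placeholders, not taken from the diff): overlap_args moves out of the *CombineInput NamedTuples and becomes an explicit keyword argument of DeepEPDispatcher.combine()/combine_a().

# 0.5.4.post1: overlap args travelled inside the NamedTuple.
#   combine_input = DeepEPLLCombineInput(hidden_states, topk_ids, topk_weights, overlap_args)
#   out = dispatcher.combine(combine_input)
# 0.5.4.post2: the NamedTuple carries only the tensors; overlap args are passed explicitly.
#   combine_input = DeepEPLLCombineInput(hidden_states, topk_ids, topk_weights)
#   out = dispatcher.combine(combine_input, overlap_args=overlap_args)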
@@ -314,16 +314,41 @@ class TopK(CustomOp):
         num_token_non_padded: Optional[torch.Tensor] = None,
         expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
     ) -> TopKOutput:
-        global_num_experts = router_logits.shape[-1]
 
-        # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
-        if global_num_experts == 256:
+        use_grouped_topk = self.topk_config.use_grouped_topk
+        torch_native = self.topk_config.torch_native
+        renormalize = self.topk_config.renormalize
 
+        if not use_grouped_topk and not torch_native:
+            topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k_softmax(
+                router_logits,
+                k=self.topk_config.top_k,
+            )
+            topk_weights = topk_weights.to(torch.float32)
+
+            if renormalize:
+                topk_weights_sum = (
+                    topk_weights.sum(dim=-1, keepdim=True)
+                    if self.topk_config.num_fused_shared_experts == 0
+                    else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
+                )
+                topk_weights = topk_weights / topk_weights_sum
+
+            if expert_location_dispatch_info is not None:
+                topk_ids = topk_ids_logical_to_physical(
+                    topk_ids, expert_location_dispatch_info
+                )
+            get_global_expert_distribution_recorder().on_select_experts(
+                topk_ids=topk_ids
+            )
+
+            return StandardTopKOutput(topk_weights, topk_ids, _)
+        if use_grouped_topk and not torch_native and router_logits.shape[-1] == 256:
+            # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
             routed_scaling_factor = self.topk_config.routed_scaling_factor or 1
-            router_logits = router_logits.to(torch.float32)
 
             topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
-                router_logits,
+                router_logits.to(torch.float32),
                 k=self.topk_config.top_k,
                 bias=self.topk_config.correction_bias.to(torch.float32),
                 k_group=self.topk_config.topk_group,
@@ -335,7 +360,7 @@ class TopK(CustomOp):
                 eps=float(1e-20),
             )
 
-            if self.topk_config.renormalize:
+            if renormalize:
                 topk_weights_sum = (
                     topk_weights.sum(dim=-1, keepdim=True)
                     if self.topk_config.num_fused_shared_experts == 0
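
A toy torch check of the renormalization rule shared by both NPU branches above (illustrative only, outside sglang): when a fused shared expert occupies the last column, it is excluded from the normalizing sum, so only the routed experts' weights are driven to sum to 1.

import torch

topk_weights = torch.tensor([[2.0, 1.0, 1.0, 1.0]])  # last column: fused shared expert
num_fused_shared_experts = 1

topk_weights_sum = (
    topk_weights.sum(dim=-1, keepdim=True)
    if num_fused_shared_experts == 0
    else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
)
topk_weights = topk_weights / topk_weights_sum
print(topk_weights)  # tensor([[0.5000, 0.2500, 0.2500, 0.2500]]); routed weights sum to 1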
@@ -20,7 +20,9 @@ class PoolingType(IntEnum):
 
 @dataclass
 class EmbeddingPoolerOutput:
-    embeddings: torch.Tensor
+    # Pooler can return list[tensor] instead of tensor if the dimension of each tensor in the batch is different
+    # due to different per-request matryoshka dim truncation
+    embeddings: torch.Tensor | list[torch.Tensor]
 
 
 class Pooler(nn.Module):
@@ -42,6 +44,7 @@ class Pooler(nn.Module):
     def forward(
         self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
     ) -> EmbeddingPoolerOutput:
+
         if self.pooling_type == PoolingType.LAST:
             last_token_indices = torch.cumsum(forward_batch.extend_seq_lens, dim=0) - 1
             pooled_data = hidden_states[last_token_indices]
@@ -53,8 +56,24 @@ class Pooler(nn.Module):
         else:
             raise ValueError(f"Invalid pooling type: {self.pooling_type}")
 
+        if forward_batch.dimensions is not None:
+            all_same_dimensions = len(set(forward_batch.dimensions)) == 1
+            if all_same_dimensions:
+                pooled_data = pooled_data[..., : forward_batch.dimensions[0]]
+            else:
+                pooled_data = [
+                    tensor[..., :dim]
+                    for tensor, dim in zip(pooled_data, forward_batch.dimensions)
+                ]
+
         if self.normalize:
-            pooled_data = nn.functional.normalize(pooled_data, p=2, dim=1)
+            if isinstance(pooled_data, list):
+                pooled_data = [
+                    nn.functional.normalize(tensor, p=2, dim=-1)
+                    for tensor in pooled_data
+                ]
+            else:
+                pooled_data = nn.functional.normalize(pooled_data, p=2, dim=-1)
 
         return EmbeddingPoolerOutput(embeddings=pooled_data)
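
A self-contained torch sketch (outside sglang, illustrative names) of the per-request matryoshka truncation plus L2 normalization that the new Pooler code implements:

import torch
import torch.nn.functional as F

pooled = torch.randn(3, 8)   # one pooled vector per request, hidden size 8
dimensions = [8, 4, 2]       # per-request matryoshka dims, like forward_batch.dimensions

if len(set(dimensions)) == 1:
    embeddings = F.normalize(pooled[..., : dimensions[0]], p=2, dim=-1)
else:
    embeddings = [
        F.normalize(vec[..., :dim], p=2, dim=-1)
        for vec, dim in zip(pooled, dimensions)
    ]

print([e.shape for e in embeddings])  # [torch.Size([8]), torch.Size([4]), torch.Size([2])]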