sglang 0.5.4.post1__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150)
  1. sglang/bench_one_batch.py +149 -34
  2. sglang/bench_serving.py +18 -3
  3. sglang/compile_deep_gemm.py +13 -7
  4. sglang/srt/batch_invariant_ops/__init__.py +2 -0
  5. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +120 -0
  6. sglang/srt/checkpoint_engine/__init__.py +9 -0
  7. sglang/srt/checkpoint_engine/update.py +317 -0
  8. sglang/srt/configs/__init__.py +2 -0
  9. sglang/srt/configs/deepseek_ocr.py +542 -10
  10. sglang/srt/configs/deepseekvl2.py +95 -194
  11. sglang/srt/configs/kimi_linear.py +160 -0
  12. sglang/srt/configs/mamba_utils.py +66 -0
  13. sglang/srt/configs/model_config.py +25 -2
  14. sglang/srt/constants.py +7 -0
  15. sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
  16. sglang/srt/disaggregation/decode.py +34 -6
  17. sglang/srt/disaggregation/nixl/conn.py +2 -2
  18. sglang/srt/disaggregation/prefill.py +25 -3
  19. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
  20. sglang/srt/distributed/parallel_state.py +9 -5
  21. sglang/srt/entrypoints/engine.py +13 -5
  22. sglang/srt/entrypoints/http_server.py +22 -3
  23. sglang/srt/entrypoints/openai/protocol.py +7 -1
  24. sglang/srt/entrypoints/openai/serving_chat.py +42 -0
  25. sglang/srt/entrypoints/openai/serving_completions.py +10 -0
  26. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  27. sglang/srt/environ.py +7 -0
  28. sglang/srt/eplb/expert_distribution.py +34 -1
  29. sglang/srt/eplb/expert_location.py +106 -36
  30. sglang/srt/grpc/compile_proto.py +3 -0
  31. sglang/srt/layers/attention/ascend_backend.py +233 -5
  32. sglang/srt/layers/attention/attention_registry.py +3 -0
  33. sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
  34. sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
  35. sglang/srt/layers/attention/fla/kda.py +1359 -0
  36. sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
  37. sglang/srt/layers/attention/flashattention_backend.py +7 -6
  38. sglang/srt/layers/attention/flashinfer_mla_backend.py +3 -1
  39. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  40. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
  41. sglang/srt/layers/attention/mamba/mamba.py +20 -11
  42. sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
  43. sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
  44. sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
  45. sglang/srt/layers/attention/nsa/transform_index.py +1 -1
  46. sglang/srt/layers/attention/nsa_backend.py +157 -23
  47. sglang/srt/layers/attention/triton_backend.py +4 -1
  48. sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
  49. sglang/srt/layers/attention/trtllm_mla_backend.py +10 -2
  50. sglang/srt/layers/communicator.py +23 -1
  51. sglang/srt/layers/layernorm.py +16 -2
  52. sglang/srt/layers/logits_processor.py +4 -20
  53. sglang/srt/layers/moe/ep_moe/layer.py +0 -18
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
  56. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
  57. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
  58. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
  59. sglang/srt/layers/moe/moe_runner/deep_gemm.py +53 -33
  60. sglang/srt/layers/moe/token_dispatcher/deepep.py +12 -9
  61. sglang/srt/layers/moe/topk.py +31 -6
  62. sglang/srt/layers/pooler.py +21 -2
  63. sglang/srt/layers/quantization/__init__.py +9 -78
  64. sglang/srt/layers/quantization/auto_round.py +394 -0
  65. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  66. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  67. sglang/srt/layers/quantization/modelopt_quant.py +168 -11
  68. sglang/srt/layers/rotary_embedding.py +117 -45
  69. sglang/srt/lora/lora_registry.py +9 -0
  70. sglang/srt/managers/async_mm_data_processor.py +122 -0
  71. sglang/srt/managers/data_parallel_controller.py +30 -3
  72. sglang/srt/managers/detokenizer_manager.py +3 -0
  73. sglang/srt/managers/io_struct.py +26 -4
  74. sglang/srt/managers/multi_tokenizer_mixin.py +5 -0
  75. sglang/srt/managers/schedule_batch.py +74 -15
  76. sglang/srt/managers/scheduler.py +164 -129
  77. sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
  78. sglang/srt/managers/scheduler_pp_mixin.py +7 -2
  79. sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
  80. sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
  81. sglang/srt/managers/session_controller.py +6 -5
  82. sglang/srt/managers/tokenizer_manager.py +154 -59
  83. sglang/srt/managers/tp_worker.py +24 -1
  84. sglang/srt/mem_cache/base_prefix_cache.py +23 -4
  85. sglang/srt/mem_cache/common.py +1 -0
  86. sglang/srt/mem_cache/memory_pool.py +171 -57
  87. sglang/srt/mem_cache/memory_pool_host.py +12 -5
  88. sglang/srt/mem_cache/radix_cache.py +4 -0
  89. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
  90. sglang/srt/metrics/collector.py +46 -3
  91. sglang/srt/model_executor/cuda_graph_runner.py +15 -3
  92. sglang/srt/model_executor/forward_batch_info.py +11 -11
  93. sglang/srt/model_executor/model_runner.py +76 -21
  94. sglang/srt/model_executor/npu_graph_runner.py +7 -3
  95. sglang/srt/model_loader/weight_utils.py +1 -1
  96. sglang/srt/models/bailing_moe.py +9 -2
  97. sglang/srt/models/deepseek_nextn.py +11 -2
  98. sglang/srt/models/deepseek_v2.py +149 -34
  99. sglang/srt/models/glm4.py +391 -77
  100. sglang/srt/models/glm4v.py +196 -55
  101. sglang/srt/models/glm4v_moe.py +0 -1
  102. sglang/srt/models/gpt_oss.py +1 -10
  103. sglang/srt/models/kimi_linear.py +678 -0
  104. sglang/srt/models/llama4.py +1 -1
  105. sglang/srt/models/llama_eagle3.py +11 -1
  106. sglang/srt/models/longcat_flash.py +2 -2
  107. sglang/srt/models/minimax_m2.py +1 -1
  108. sglang/srt/models/qwen2.py +1 -1
  109. sglang/srt/models/qwen2_moe.py +30 -15
  110. sglang/srt/models/qwen3.py +1 -1
  111. sglang/srt/models/qwen3_moe.py +16 -8
  112. sglang/srt/models/qwen3_next.py +7 -0
  113. sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
  114. sglang/srt/multiplex/multiplexing_mixin.py +209 -0
  115. sglang/srt/multiplex/pdmux_context.py +164 -0
  116. sglang/srt/parser/conversation.py +7 -1
  117. sglang/srt/sampling/custom_logit_processor.py +67 -1
  118. sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
  119. sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
  120. sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
  121. sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
  122. sglang/srt/server_args.py +103 -22
  123. sglang/srt/single_batch_overlap.py +4 -1
  124. sglang/srt/speculative/draft_utils.py +16 -0
  125. sglang/srt/speculative/eagle_info.py +42 -36
  126. sglang/srt/speculative/eagle_info_v2.py +68 -25
  127. sglang/srt/speculative/eagle_utils.py +261 -16
  128. sglang/srt/speculative/eagle_worker.py +11 -3
  129. sglang/srt/speculative/eagle_worker_v2.py +15 -9
  130. sglang/srt/speculative/spec_info.py +305 -31
  131. sglang/srt/speculative/spec_utils.py +44 -8
  132. sglang/srt/tracing/trace.py +121 -12
  133. sglang/srt/utils/common.py +55 -32
  134. sglang/srt/utils/hf_transformers_utils.py +38 -16
  135. sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
  136. sglang/test/kits/radix_cache_server_kit.py +50 -0
  137. sglang/test/runners.py +31 -7
  138. sglang/test/simple_eval_common.py +5 -3
  139. sglang/test/simple_eval_humaneval.py +1 -0
  140. sglang/test/simple_eval_math.py +1 -0
  141. sglang/test/simple_eval_mmlu.py +1 -0
  142. sglang/test/simple_eval_mmmu_vlm.py +1 -0
  143. sglang/test/test_utils.py +7 -1
  144. sglang/version.py +1 -1
  145. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +10 -24
  146. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +150 -136
  147. /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
  148. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
  149. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
  150. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
--- sglang/srt/layers/quantization/modelopt_quant.py
+++ sglang/srt/layers/quantization/modelopt_quant.py
@@ -25,10 +25,11 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
 from sglang.srt.layers.quantization.fp8_utils import (
     apply_fp8_linear,
     cutlass_fp8_supported,
-    is_sm100_supported,
+    is_blackwell_supported,
 )
 from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
 from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
@@ -49,8 +50,10 @@ if TYPE_CHECKING:
     )
    from sglang.srt.single_batch_overlap import DownGemmOverlapArgs

-if is_cuda():
-    from sgl_kernel import scaled_fp4_quant
+try:
+    from flashinfer import fp4_quantize
+except ImportError:
+    fp4_quantize = None

 try:
     from flashinfer import mm_fp4 as fp4_gemm
@@ -466,8 +469,6 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         # Fp8 moe kernel needs single weight scale for w13 per expert.
         # We take the max of the w1 and w3 scales then dequant and requant each expert.
         if layer.w13_weight_scale.dim() == 2:  # Shape: (num_experts, 2)
-            from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
-
             # Get the maximum scale across w1 and w3 for each expert
             max_w13_scales = layer.w13_weight_scale.max(dim=1).values

@@ -515,6 +516,84 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
             layer.w2_input_scale.max(), requires_grad=False
         )

+        # Align FP8 weights to FlashInfer per-tensor kernel layout if enabled
+        if should_use_flashinfer_trtllm_moe():
+            from flashinfer import reorder_rows_for_gated_act_gemm, shuffle_matrix_a
+
+            # 1) Swap W13 halves: [Up, Gate] -> [Gate, Up] expected by FI
+            num_experts, two_n, hidden = layer.w13_weight.shape
+            inter = two_n // 2
+            w13_swapped = (
+                layer.w13_weight.reshape(num_experts, 2, inter, hidden)
+                .flip(dims=[1])
+                .reshape(num_experts, two_n, hidden)
+            )
+
+            # 2) Reorder rows for fused gated activation (W13)
+            w13_interleaved = [
+                reorder_rows_for_gated_act_gemm(w13_swapped[i])
+                for i in range(num_experts)
+            ]
+            w13_interleaved = torch.stack(w13_interleaved).reshape(
+                num_experts, two_n, hidden
+            )
+
+            # 3) Shuffle weights for transposed MMA output (both W13, W2)
+            epilogue_tile_m = 128
+            w13_shuffled = [
+                shuffle_matrix_a(w13_interleaved[i].view(torch.uint8), epilogue_tile_m)
+                for i in range(num_experts)
+            ]
+            w2_shuffled = [
+                shuffle_matrix_a(layer.w2_weight[i].view(torch.uint8), epilogue_tile_m)
+                for i in range(num_experts)
+            ]
+
+            layer.w13_weight = Parameter(
+                torch.stack(w13_shuffled).view(torch.float8_e4m3fn),
+                requires_grad=False,
+            )
+            layer.w2_weight = Parameter(
+                torch.stack(w2_shuffled).view(torch.float8_e4m3fn),
+                requires_grad=False,
+            )
+
+        # Precompute and register per-expert output scaling factors for FI MoE
+        if should_use_flashinfer_trtllm_moe():
+            # Note: w13_input_scale and w2_input_scale are scalar Parameters post-reduction
+            assert (
+                hasattr(layer, "w13_input_scale") and layer.w13_input_scale is not None
+            )
+            assert hasattr(layer, "w2_input_scale") and layer.w2_input_scale is not None
+            assert (
+                hasattr(layer, "w13_weight_scale")
+                and layer.w13_weight_scale is not None
+            )
+            assert (
+                hasattr(layer, "w2_weight_scale") and layer.w2_weight_scale is not None
+            )
+
+            input_scale = layer.w13_input_scale.to(torch.float32)
+            activation_scale = layer.w2_input_scale.to(torch.float32)
+            w13_weight_scale = layer.w13_weight_scale.to(torch.float32)
+            w2_weight_scale = layer.w2_weight_scale.to(torch.float32)
+
+            output1_scales_scalar = (
+                w13_weight_scale * input_scale * (1.0 / activation_scale)
+            )
+            output1_scales_gate_scalar = w13_weight_scale * input_scale
+            output2_scales_scalar = activation_scale * w2_weight_scale
+
+            layer.output1_scales_scalar = Parameter(
+                output1_scales_scalar, requires_grad=False
+            )
+            layer.output1_scales_gate_scalar = Parameter(
+                output1_scales_gate_scalar, requires_grad=False
+            )
+            layer.output2_scales_scalar = Parameter(
+                output2_scales_scalar, requires_grad=False
+            )
+
     def create_moe_runner(
         self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
     ):
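The three scalars registered above fold the whole FP8 dequant/requant chain into per-expert constants that the TRT-LLM kernel can apply as GEMM epilogues. A minimal numeric sketch of the same arithmetic (the scale values are hypothetical; the names mirror the parameters registered in the hunk):

import torch

# Hypothetical per-tensor scales for a single expert (float32 scalars).
w13_weight_scale = torch.tensor(0.02)  # dequant scale of the fused gate/up weights
w13_input_scale = torch.tensor(0.01)  # dequant scale of the FP8 input activations
w2_input_scale = torch.tensor(0.015)  # quant scale expected at the input of GEMM2
w2_weight_scale = torch.tensor(0.03)  # dequant scale of the down-projection weights

# GEMM1's output must be dequantized by (weight_scale * input_scale), then
# requantized by 1 / w2_input_scale so it can feed GEMM2 directly in FP8.
output1_scales_scalar = w13_weight_scale * w13_input_scale * (1.0 / w2_input_scale)

# The gate branch only needs the dequant half: the activation runs in
# higher precision before requantization.
output1_scales_gate_scalar = w13_weight_scale * w13_input_scale

# GEMM2's output is dequantized by its own (input_scale * weight_scale) pair.
output2_scales_scalar = w2_input_scale * w2_weight_scale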
@@ -526,6 +605,81 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         dispatch_output: StandardDispatchOutput,
     ) -> CombineInput:
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+
+        # Fast path: TRT-LLM FP8 per-tensor MoE using BYPASSED TopK routing
+        from sglang.srt.layers.moe.topk import TopKOutputChecker
+
+        if should_use_flashinfer_trtllm_moe() and TopKOutputChecker.format_is_bypassed(
+            topk_output
+        ):
+            router_logits = topk_output.router_logits
+            topk_config = topk_output.topk_config
+
+            # Constraints
+            assert (
+                self.moe_runner_config.activation == "silu"
+            ), "Only silu is supported for flashinfer fp8 moe"
+
+            from flashinfer import RoutingMethodType
+            from flashinfer.fused_moe import trtllm_fp8_per_tensor_scale_moe
+
+            correction_bias = (
+                None
+                if topk_config.correction_bias is None
+                else topk_config.correction_bias
+            )
+            # Pre-quantize activations to FP8 per-tensor using provided input scale
+            x_fp8, _ = scaled_fp8_quant(x, layer.w13_input_scale)
+
+            use_routing_scales_on_input = True
+            routed_scaling_factor = self.moe_runner_config.routed_scaling_factor
+
+            # Enforce Llama4 routing for ModelOpt FP8 MoE for now.
+            # TODO(brayden): support other routing methods
+            assert topk_config.top_k == 1, "ModelOpt FP8 MoE requires top_k==1"
+            assert (
+                not topk_config.num_expert_group
+            ), "ModelOpt FP8 MoE does not support expert grouping"
+            assert (
+                not topk_config.topk_group
+            ), "ModelOpt FP8 MoE does not support grouped top-k"
+            routing_method_type = RoutingMethodType.Llama4
+
+            # FlashInfer TRTLLM requires routing_logits (and bias) to be bfloat16
+            routing_logits_cast = router_logits.to(torch.bfloat16)
+            routing_bias_cast = (
+                None if correction_bias is None else correction_bias.to(torch.bfloat16)
+            )
+
+            output = trtllm_fp8_per_tensor_scale_moe(
+                routing_logits=routing_logits_cast,
+                routing_bias=routing_bias_cast,
+                hidden_states=x_fp8,
+                gemm1_weights=layer.w13_weight,
+                output1_scales_scalar=layer.output1_scales_scalar,
+                output1_scales_gate_scalar=layer.output1_scales_gate_scalar,
+                gemm2_weights=layer.w2_weight,
+                output2_scales_scalar=layer.output2_scales_scalar,
+                num_experts=layer.num_experts,
+                top_k=topk_config.top_k,
+                n_group=0,
+                topk_group=0,
+                intermediate_size=layer.w2_weight.shape[2],
+                local_expert_offset=layer.moe_ep_rank * layer.num_local_experts,
+                local_num_experts=layer.num_local_experts,
+                routed_scaling_factor=(
+                    routed_scaling_factor if routed_scaling_factor is not None else 1.0
+                ),
+                use_routing_scales_on_input=use_routing_scales_on_input,
+                tile_tokens_dim=8,  # TODO(brayden): use the FI tile calculation
+                routing_method_type=routing_method_type,
+            )
+
+            from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+            return StandardCombineInput(hidden_states=output)

         quant_info = TritonMoeQuantInfo(
             w13_weight=layer.w13_weight,
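The fast path above hands pre-quantized FP8 activations to `trtllm_fp8_per_tensor_scale_moe`. As a rough functional sketch of what the per-tensor `scaled_fp8_quant` step computes (this is not the sgl_kernel implementation, just the math it is expected to follow):

import torch

def scaled_fp8_quant_sketch(x: torch.Tensor, scale: torch.Tensor):
    # Divide by the per-tensor scale and saturate to the e4m3 range,
    # mirroring the (quantized_tensor, scale) pair returned by scaled_fp8_quant.
    finfo = torch.finfo(torch.float8_e4m3fn)
    x_q = (x.float() / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return x_q, scale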
@@ -867,10 +1021,9 @@ class ModelOptFp4LinearMethod(LinearMethodBase):
         output_shape = [x_m, w_n]

         # Quantize BF16 or FP16 to (FP4 and interleaved block scale)
-        x_fp4, x_scale_interleaved = scaled_fp4_quant(x, layer.input_scale_inv)
+        x_fp4, x_scale_interleaved = fp4_quantize(x, layer.input_scale_inv)

         assert x_fp4.dtype == torch.uint8
-        assert x_scale_interleaved.dtype == torch.float8_e4m3fn
         assert layer.weight.dtype == torch.uint8
         assert layer.weight_scale_interleaved.dtype == torch.float8_e4m3fn
         assert layer.alpha.dtype == torch.float32
@@ -903,7 +1056,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):

     def __init__(self, quant_config: ModelOptFp4Config):
         self.quant_config = quant_config
-        if not is_sm100_supported():
+        if not is_blackwell_supported():
             raise ValueError(
                 "Current platform does not support NVFP4"
                 " quantization. Please use Blackwell and"
@@ -1383,8 +1536,6 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         alt_stream=None,
     ) -> CombineInput:

-        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
-
         x = dispatch_output.hidden_states
         topk_output = dispatch_output.topk_output
@@ -1397,6 +1548,8 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         # Check if this is a FlashInferFP4MoE layer that should handle its own forward
         if hasattr(layer, "gemm1_weights_fp4_shuffled"):
             # This layer was processed with flashinfer TRTLLM - delegate to its own forward
+            from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
             return StandardCombineInput(hidden_states=layer.forward(x, topk_output))

         if self.enable_flashinfer_cutlass_moe:
@@ -1410,7 +1563,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             output_dtype = x.dtype
             x_sf = None
             if should_use_flashinfer_cutlass_moe_fp4_allgather():
-                from flashinfer import fp4_quantize, nvfp4_block_scale_interleave
+                from flashinfer import nvfp4_block_scale_interleave

                 # Quantize before comm, swizzle after.
                 if x.shape[0] > 0:
@@ -1465,6 +1618,8 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             if forward_shared_experts is not None:
                 torch.cuda.current_stream().wait_stream(alt_stream)

+            from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
             return StandardCombineInput(hidden_states=output)

         from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
@@ -1486,6 +1641,8 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             apply_router_weight_on_input=moe_runner_config.apply_router_weight_on_input,
         ).to(x.dtype)
         # Scale by routed_scaling_factor is fused into select_experts.
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
         return StandardCombineInput(hidden_states=output)

     def apply_without_routing_weights(
--- sglang/srt/layers/rotary_embedding.py
+++ sglang/srt/layers/rotary_embedding.py
@@ -125,8 +125,13 @@ class RotaryEmbedding(CustomOp):
         self.cos_sin_cache: torch.Tensor
         self.register_buffer("cos_sin_cache", cache, persistent=False)

+        self._apply_rotary_emb_wrapped = _apply_rotary_emb
+
         if get_global_server_args().rl_on_policy_target == "fsdp":
             self._forward_method = self.forward_native
+            self._apply_rotary_emb_wrapped = torch.compile(dynamic=True)(
+                self._apply_rotary_emb_wrapped
+            )

     def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
         """Compute the inverse frequency."""
@@ -185,14 +190,16 @@ class RotaryEmbedding(CustomOp):
         query = query.view(num_tokens, -1, self.head_size)
         query_rot = query[..., : self.rotary_dim]
         query_pass = query[..., self.rotary_dim :]
-        query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
+        query_rot = self._apply_rotary_emb_wrapped(
+            query_rot, cos, sin, self.is_neox_style
+        )
         query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)

         key_shape = key.shape
         key = key.view(num_tokens, -1, self.head_size)
         key_rot = key[..., : self.rotary_dim]
         key_pass = key[..., self.rotary_dim :]
-        key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
+        key_rot = self._apply_rotary_emb_wrapped(key_rot, cos, sin, self.is_neox_style)
         key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
         return query, key

@@ -312,10 +319,20 @@ class RotaryEmbedding(CustomOp):
         query: torch.Tensor,
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
+        fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        # TODO: make a wrapper, and XPU will implement this kernel later.
-        self.cos_sin_cache = self.cos_sin_cache.to(query.device)
-        return self.forward_native(positions, query, key, offsets)
+        assert (
+            fused_set_kv_buffer_arg is None
+        ), "fused_set_kv_buffer_arg is not supported for xpu implementation"
+        positions = torch.add(positions, offsets) if offsets is not None else positions
+        return torch.ops.sgl_kernel.rotary_embedding(
+            positions,
+            query,
+            key,
+            self.head_size,
+            self.cos_sin_cache,
+            self.is_neox_style,
+        )


 class LinearScalingRotaryEmbedding(RotaryEmbedding):
@@ -1070,6 +1087,7 @@ def _triton_mrope_forward(
     mrope_section_h: tl.constexpr,
     mrope_section_w: tl.constexpr,
     is_interleaved: tl.constexpr,
+    is_neox_style: tl.constexpr,
 ):
     # Adapted from
     # https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/qwen2vl_mrope.py
@@ -1124,51 +1142,99 @@ def _triton_mrope_forward(
     # program instance (i.e. for the current token) separately
     # ####################################################################
     # left half of the head
-    first_half_q_offsets = (
-        tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
-    )
-    first_half_k_offsets = (
-        tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
-    )
-    first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (
-        tl.arange(0, pad_hd // 2)[None, :] < rd // 2
-    )
-    first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (
-        tl.arange(0, pad_hd // 2)[None, :] < rd // 2
-    )
+    if is_neox_style:
+        first_half_q_offsets = (
+            tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
+        )
+        first_half_k_offsets = (
+            tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
+        )
+        first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (
+            tl.arange(0, pad_hd // 2)[None, :] < rd // 2
+        )
+        first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (
+            tl.arange(0, pad_hd // 2)[None, :] < rd // 2
+        )

-    q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(
-        sin_row.dtype
-    )
-    k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(
-        sin_row.dtype
-    )
+        q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(
+            sin_row.dtype
+        )
+        k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(
+            sin_row.dtype
+        )

-    # right half of the head
-    second_half_q_offsets = first_half_q_offsets + (rd // 2)
-    second_half_k_offsets = first_half_k_offsets + (rd // 2)
-    second_q_mask = first_q_mask
-    second_k_mask = first_k_mask
+        # right half of the head
+        second_half_q_offsets = first_half_q_offsets + (rd // 2)
+        second_half_k_offsets = first_half_k_offsets + (rd // 2)
+        second_q_mask = first_q_mask
+        second_k_mask = first_k_mask
+
+        q_tile_2 = tl.load(
+            q_ptr + second_half_q_offsets, mask=second_q_mask, other=0
+        ).to(sin_row.dtype)
+        k_tile_2 = tl.load(
+            k_ptr + second_half_k_offsets, mask=second_k_mask, other=0
+        ).to(sin_row.dtype)
+
+        # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
+        # Since cos and sin are now half-size,
+        # we use the same cos_row and sin_row for both halves
+        new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row
+        tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
+        new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row
+        tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)
+
+        new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row
+        tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
+        new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row
+        tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
+    else:
+        base_q = tl.arange(0, pad_n_qh)[:, None] * hd
+        base_k = tl.arange(0, pad_n_kh)[:, None] * hd
+        even_idx = 2 * tl.arange(0, pad_hd // 2)[None, :]
+        odd_idx = even_idx + 1
+
+        even_q_offsets = base_q + even_idx
+        odd_q_offsets = base_q + odd_idx
+        even_k_offsets = base_k + even_idx
+        odd_k_offsets = base_k + odd_idx
+
+        idx_mask = tl.arange(0, pad_hd // 2)[None, :] < (rd // 2)
+        qn_mask = tl.arange(0, pad_n_qh)[:, None] < n_qh
+        kn_mask = tl.arange(0, pad_n_kh)[:, None] < n_kh
+
+        even_q_mask = qn_mask & idx_mask
+        odd_q_mask = qn_mask & idx_mask
+        even_k_mask = kn_mask & idx_mask
+        odd_k_mask = kn_mask & idx_mask
+
+        q_tile_1 = tl.load(q_ptr + even_q_offsets, mask=even_q_mask, other=0).to(
+            sin_row.dtype
+        )
+        k_tile_1 = tl.load(k_ptr + even_k_offsets, mask=even_k_mask, other=0).to(
+            sin_row.dtype
+        )

-    q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(
-        sin_row.dtype
-    )
-    k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(
-        sin_row.dtype
-    )
+        q_tile_2 = tl.load(q_ptr + odd_q_offsets, mask=odd_q_mask, other=0).to(
+            sin_row.dtype
+        )
+        k_tile_2 = tl.load(k_ptr + odd_k_offsets, mask=odd_k_mask, other=0).to(
+            sin_row.dtype
+        )

-    # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
-    # Since cos and sin are now half-size,
-    # we use the same cos_row and sin_row for both halves
-    new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row
-    tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
-    new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row
-    tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)
+        # y = [x_even, x_odd] * [cos, cos] + [-x_odd, x_even] * [sin, sin]
+        # Interleaved (non-NeoX) rotary embedding:
+        # each (even, odd) channel pair forms one rotation arm.
+        # cos_row and sin_row each have length rd//2, shared across all (even, odd) pairs.
+        new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row
+        tl.store(q_ptr + even_q_offsets, new_q_tile_1, mask=even_q_mask)
+        new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row
+        tl.store(q_ptr + odd_q_offsets, new_q_tile_2, mask=odd_q_mask)

-    new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row
-    tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
-    new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row
-    tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
+        new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row
+        tl.store(k_ptr + even_k_offsets, new_k_tile_1, mask=even_k_mask)
+        new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row
+        tl.store(k_ptr + odd_k_offsets, new_k_tile_2, mask=odd_k_mask)


 def triton_mrope(
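The new `else` branch above implements the interleaved (non-NeoX) channel layout. A plain-PyTorch sketch of the same pairing, under the same shape assumptions as the NeoX sketch earlier:

import torch

def apply_rotary_emb_interleaved_sketch(x, cos, sin):
    # Channels (0,1), (2,3), ... form the rotation pairs, matching the
    # even/odd offsets used by the Triton kernel's non-NeoX path.
    x1, x2 = x[..., 0::2], x[..., 1::2]
    cos = cos.unsqueeze(-2)  # broadcast over the heads dimension
    sin = sin.unsqueeze(-2)
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x2 * cos + x1 * sin
    return out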
@@ -1180,6 +1246,7 @@ def triton_mrope(
     head_size: int,
     rotary_dim: int,
     mrope_interleaved: bool,
+    is_neox_style: bool,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """The mrope triton kernel.

@@ -1230,6 +1297,7 @@ def triton_mrope(
         mrope_section[1],
         mrope_section[2],
         mrope_interleaved,
+        is_neox_style,
     )
     return q, k

@@ -1373,6 +1441,7 @@ class MRotaryEmbedding(RotaryEmbedding):
         else:
             return self._forward_native(positions, query, key)

+    @torch.compile(dynamic=True, backend=get_compiler_backend())
     def _forward_triton(
         self,
         positions: torch.Tensor,
@@ -1391,6 +1460,7 @@ class MRotaryEmbedding(RotaryEmbedding):
         if positions.ndim == 2:
             assert self.mrope_section

+        torch._dynamo.graph_break()
         q, k = triton_mrope(
             query,
             key,
@@ -1400,7 +1470,9 @@ class MRotaryEmbedding(RotaryEmbedding):
             self.head_size,
             self.rotary_dim,
             self.mrope_interleaved,
+            self.is_neox_style,
         )
+        torch._dynamo.graph_break()

         return q.reshape(query_shape), k.reshape(key_shape)

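With `_forward_triton` now compiled, the explicit `torch._dynamo.graph_break()` calls fence the hand-written Triton launch out of the traced graph: the surrounding tensor ops compile, while `triton_mrope` runs eagerly between the two compiled subgraphs. A minimal sketch of that pattern (illustrative only, with the default backend; `eager_only_kernel` is a stand-in for the custom launch):

import torch

def eager_only_kernel(t: torch.Tensor) -> torch.Tensor:
    # Stand-in for a hand-written kernel that Dynamo should not trace.
    return t.sqrt()

@torch.compile(dynamic=True)
def compiled_with_eager_island(x: torch.Tensor) -> torch.Tensor:
    y = x * 2 + 1  # traced and compiled
    torch._dynamo.graph_break()  # end the first compiled graph
    y = eager_only_kernel(y)  # executes eagerly between the graphs
    torch._dynamo.graph_break()  # start a fresh graph for the tail ops
    return y - 1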
--- sglang/srt/lora/lora_registry.py
+++ sglang/srt/lora/lora_registry.py
@@ -205,3 +205,12 @@ class LoRARegistry:
         Returns the total number of LoRA adapters currently registered.
         """
         return len(self._registry)
+
+    def get_all_adapters(self) -> Dict[str, LoRARef]:
+        """
+        Returns a dictionary of all registered LoRA adapters.
+
+        Returns:
+            Dict[str, LoRARef]: A dictionary mapping LoRA names to LoRARef objects.
+        """
+        return dict(self._registry)
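A possible call site for the new accessor (hypothetical; `registry` stands for an already-populated `LoRARegistry`). Note that `dict(self._registry)` hands back a shallow copy, so iterating callers cannot mutate the registry's internal state:

# Hypothetical usage: snapshot all registered adapters for reporting.
adapters = registry.get_all_adapters()
for name, lora_ref in adapters.items():
    print(f"{name}: {lora_ref}")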
--- /dev/null
+++ sglang/srt/managers/async_mm_data_processor.py
@@ -0,0 +1,122 @@
+import asyncio
+import logging
+from concurrent.futures import ThreadPoolExecutor
+from functools import partial
+from typing import Any, Dict, List, Optional, Union
+
+logger = logging.getLogger(__name__)
+
+
+class AsyncMMDataProcessor:
+    """
+    Async wrapper for a multimodal processor.
+
+    Behavior:
+    - If the underlying processor exposes `process_mm_data_async`, call/await it directly.
+    - Otherwise, fall back to running a synchronous `process_mm_data` in a thread pool.
+    - Optionally guard per-call concurrency via an asyncio.Semaphore.
+    - Optionally enforce per-call timeout via asyncio.wait_for.
+    """
+
+    def __init__(
+        self,
+        mm_processor: Any,
+        *,
+        max_concurrent_calls: Optional[int] = None,
+        timeout_s: Optional[float] = None,
+    ) -> None:
+        """
+        Args:
+            mm_processor: An object exposing either
+                - async def process_mm_data_async(...): -> Dict[str, Any]
+                or
+                - def process_mm_data(...): -> Dict[str, Any]
+            max_concurrent_calls: Optional concurrency cap for per-call execution.
+            timeout_s: Optional timeout (seconds) for each `process()` call.
+        """
+        self.mm_processor = mm_processor
+        self.timeout_s = timeout_s
+
+        # Concurrency guard (None -> unlimited)
+        self.semaphore = (
+            asyncio.Semaphore(max_concurrent_calls) if max_concurrent_calls else None
+        )
+
+        # Detect async path; if missing, prepare a fallback executor for sync path
+        self._proc_async = getattr(mm_processor, "process_mm_data_async", None)
+        self.is_async = asyncio.iscoroutinefunction(self._proc_async)
+        self.fallback_exec: Optional[ThreadPoolExecutor] = (
+            ThreadPoolExecutor(max_workers=max_concurrent_calls)
+            if not self.is_async
+            else None
+        )
+
+    async def process(
+        self,
+        *,
+        image_data: Optional[List[Union[str, bytes]]] = None,
+        audio_data: Optional[List[Union[str, bytes]]] = None,
+        input_text_or_ids: Union[str, List[int], None] = None,
+        request_obj: Any,
+        **kwargs: Any,
+    ) -> Dict[str, Any]:
+        """
+        Public entrypoint: process a single multimodal request without blocking the event loop.
+        """
+
+        async def _invoke() -> Dict[str, Any]:
+            if self.is_async:
+                # Native async implementation
+                return await self._proc_async(
+                    image_data=image_data,
+                    audio_data=audio_data,
+                    input_text=input_text_or_ids,
+                    request_obj=request_obj,
+                    **kwargs,
+                )
+
+            # Synchronous fallback
+            sync_fn = getattr(self.mm_processor, "process_mm_data", None)
+            if not callable(sync_fn):
+                raise RuntimeError(
+                    "mm_processor has neither 'process_mm_data_async' nor 'process_mm_data'."
+                )
+            loop = asyncio.get_running_loop()
+            fn = partial(
+                sync_fn,
+                image_data=image_data,
+                audio_data=audio_data,
+                input_text=input_text_or_ids,
+                request_obj=request_obj,
+                **kwargs,
+            )
+            return await loop.run_in_executor(self.fallback_exec, fn)
+
+        # Apply optional concurrency guard
+        if self.semaphore is not None:
+            async with self.semaphore:
+                if self.timeout_s is not None:
+                    return await asyncio.wait_for(_invoke(), timeout=self.timeout_s)
+                return await _invoke()
+
+        # No concurrency guard
+        if self.timeout_s is not None:
+            return await asyncio.wait_for(_invoke(), timeout=self.timeout_s)
+        return await _invoke()
+
+    def shutdown(self) -> None:
+        """Gracefully shutdown resources owned by this wrapper."""
+        try:
+            if self.fallback_exec:
+                self.fallback_exec.shutdown(wait=False)
+        except Exception:
+            logger.exception(
+                "Error while shutting down fallback executor in AsyncMMDataProcessor"
+            )
+
+    def __del__(self):
+        # Best-effort shutdown
+        try:
+            self.shutdown()
+        except Exception:
+            pass
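A minimal usage sketch of the new wrapper, assuming a synchronous processor so the thread-pool fallback path is exercised (`DummyProcessor` and the argument values are illustrative, not part of the package):

import asyncio

class DummyProcessor:
    # Only the sync entrypoint is defined, so AsyncMMDataProcessor
    # will run it in its ThreadPoolExecutor fallback.
    def process_mm_data(self, *, image_data, audio_data, input_text, request_obj, **kwargs):
        return {"text": input_text, "num_images": len(image_data or [])}

async def main():
    proc = AsyncMMDataProcessor(DummyProcessor(), max_concurrent_calls=4, timeout_s=30.0)
    try:
        out = await proc.process(
            image_data=[b"<raw image bytes>"],
            input_text_or_ids="describe the image",
            request_obj=None,
        )
        print(out)  # {'text': 'describe the image', 'num_images': 1}
    finally:
        proc.shutdown()

asyncio.run(main())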