sglang 0.5.1.post3__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (245)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/bench_one_batch_server.py +10 -1
  3. sglang/bench_serving.py +251 -26
  4. sglang/lang/interpreter.py +1 -1
  5. sglang/srt/configs/__init__.py +4 -0
  6. sglang/srt/configs/internvl.py +6 -0
  7. sglang/srt/configs/longcat_flash.py +104 -0
  8. sglang/srt/configs/model_config.py +37 -7
  9. sglang/srt/configs/qwen3_next.py +326 -0
  10. sglang/srt/connector/__init__.py +1 -1
  11. sglang/srt/connector/base_connector.py +1 -2
  12. sglang/srt/connector/redis.py +2 -2
  13. sglang/srt/connector/serde/__init__.py +1 -1
  14. sglang/srt/connector/serde/safe_serde.py +4 -3
  15. sglang/srt/custom_op.py +11 -1
  16. sglang/srt/debug_utils/dump_comparator.py +81 -44
  17. sglang/srt/debug_utils/dump_loader.py +97 -0
  18. sglang/srt/debug_utils/dumper.py +11 -3
  19. sglang/srt/debug_utils/text_comparator.py +73 -11
  20. sglang/srt/disaggregation/ascend/conn.py +75 -0
  21. sglang/srt/disaggregation/base/conn.py +1 -1
  22. sglang/srt/disaggregation/common/conn.py +15 -12
  23. sglang/srt/disaggregation/decode.py +6 -4
  24. sglang/srt/disaggregation/fake/conn.py +1 -1
  25. sglang/srt/disaggregation/mini_lb.py +6 -420
  26. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  27. sglang/srt/disaggregation/nixl/conn.py +180 -16
  28. sglang/srt/disaggregation/prefill.py +6 -4
  29. sglang/srt/disaggregation/utils.py +5 -50
  30. sglang/srt/distributed/parallel_state.py +94 -58
  31. sglang/srt/entrypoints/engine.py +34 -14
  32. sglang/srt/entrypoints/http_server.py +172 -47
  33. sglang/srt/entrypoints/openai/protocol.py +63 -3
  34. sglang/srt/entrypoints/openai/serving_base.py +6 -2
  35. sglang/srt/entrypoints/openai/serving_chat.py +34 -19
  36. sglang/srt/entrypoints/openai/serving_completions.py +10 -4
  37. sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
  38. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  39. sglang/srt/eplb/eplb_manager.py +28 -4
  40. sglang/srt/eplb/expert_distribution.py +55 -15
  41. sglang/srt/eplb/expert_location.py +8 -3
  42. sglang/srt/eplb/expert_location_updater.py +1 -1
  43. sglang/srt/function_call/ebnf_composer.py +11 -9
  44. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  45. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  46. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  47. sglang/srt/hf_transformers_utils.py +12 -0
  48. sglang/srt/layers/activation.py +44 -9
  49. sglang/srt/layers/attention/aiter_backend.py +93 -68
  50. sglang/srt/layers/attention/ascend_backend.py +250 -112
  51. sglang/srt/layers/attention/fla/chunk.py +242 -0
  52. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  53. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  54. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  55. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  56. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  57. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  58. sglang/srt/layers/attention/fla/index.py +37 -0
  59. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  60. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  61. sglang/srt/layers/attention/fla/op.py +66 -0
  62. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  63. sglang/srt/layers/attention/fla/utils.py +331 -0
  64. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  65. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  66. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  67. sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
  68. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
  69. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  70. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  71. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  72. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  73. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  74. sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
  75. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  76. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  77. sglang/srt/layers/communicator.py +45 -7
  78. sglang/srt/layers/layernorm.py +54 -12
  79. sglang/srt/layers/logits_processor.py +10 -3
  80. sglang/srt/layers/moe/__init__.py +2 -1
  81. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
  82. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  83. sglang/srt/layers/moe/ep_moe/layer.py +110 -49
  84. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  85. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  86. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  87. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  88. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
  89. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  90. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  91. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  92. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  93. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
  94. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  95. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
  96. sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
  97. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  98. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  99. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  100. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  101. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  102. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  103. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  104. sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
  105. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  106. sglang/srt/layers/moe/topk.py +43 -12
  107. sglang/srt/layers/moe/utils.py +6 -5
  108. sglang/srt/layers/quantization/awq.py +19 -7
  109. sglang/srt/layers/quantization/base_config.py +11 -6
  110. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  111. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  112. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  113. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
  114. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
  115. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  116. sglang/srt/layers/quantization/fp8.py +76 -47
  117. sglang/srt/layers/quantization/fp8_utils.py +43 -29
  118. sglang/srt/layers/quantization/gptq.py +25 -17
  119. sglang/srt/layers/quantization/modelopt_quant.py +107 -40
  120. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  121. sglang/srt/layers/quantization/mxfp4.py +77 -45
  122. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  123. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  124. sglang/srt/layers/quantization/quark/utils.py +97 -0
  125. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  126. sglang/srt/layers/quantization/unquant.py +135 -47
  127. sglang/srt/layers/quantization/utils.py +13 -0
  128. sglang/srt/layers/quantization/w4afp8.py +60 -42
  129. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  130. sglang/srt/layers/quantization/w8a8_int8.py +83 -41
  131. sglang/srt/layers/rocm_linear_utils.py +44 -0
  132. sglang/srt/layers/rotary_embedding.py +28 -19
  133. sglang/srt/layers/sampler.py +29 -5
  134. sglang/srt/lora/backend/base_backend.py +50 -8
  135. sglang/srt/lora/backend/triton_backend.py +90 -2
  136. sglang/srt/lora/layers.py +32 -0
  137. sglang/srt/lora/lora.py +4 -1
  138. sglang/srt/lora/lora_manager.py +35 -112
  139. sglang/srt/lora/mem_pool.py +24 -10
  140. sglang/srt/lora/utils.py +18 -9
  141. sglang/srt/managers/cache_controller.py +242 -278
  142. sglang/srt/managers/data_parallel_controller.py +30 -15
  143. sglang/srt/managers/detokenizer_manager.py +13 -2
  144. sglang/srt/managers/disagg_service.py +46 -0
  145. sglang/srt/managers/io_struct.py +160 -11
  146. sglang/srt/managers/mm_utils.py +6 -1
  147. sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
  148. sglang/srt/managers/schedule_batch.py +27 -44
  149. sglang/srt/managers/schedule_policy.py +4 -3
  150. sglang/srt/managers/scheduler.py +90 -115
  151. sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
  152. sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
  153. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  154. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  155. sglang/srt/managers/template_manager.py +3 -3
  156. sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
  157. sglang/srt/managers/tokenizer_manager.py +41 -477
  158. sglang/srt/managers/tp_worker.py +16 -4
  159. sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
  160. sglang/srt/mem_cache/allocator.py +1 -1
  161. sglang/srt/mem_cache/chunk_cache.py +1 -1
  162. sglang/srt/mem_cache/hicache_storage.py +24 -22
  163. sglang/srt/mem_cache/hiradix_cache.py +184 -101
  164. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  165. sglang/srt/mem_cache/memory_pool.py +324 -41
  166. sglang/srt/mem_cache/memory_pool_host.py +25 -18
  167. sglang/srt/mem_cache/radix_cache.py +5 -6
  168. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  169. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  170. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  171. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  172. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +149 -12
  173. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  174. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  175. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +74 -19
  176. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  177. sglang/srt/mem_cache/swa_radix_cache.py +1 -3
  178. sglang/srt/metrics/collector.py +484 -63
  179. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  180. sglang/srt/metrics/utils.py +48 -0
  181. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  182. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  183. sglang/srt/model_executor/forward_batch_info.py +72 -18
  184. sglang/srt/model_executor/model_runner.py +189 -31
  185. sglang/srt/model_loader/__init__.py +9 -3
  186. sglang/srt/model_loader/loader.py +33 -28
  187. sglang/srt/model_loader/utils.py +12 -0
  188. sglang/srt/model_loader/weight_utils.py +2 -1
  189. sglang/srt/models/deepseek_v2.py +311 -50
  190. sglang/srt/models/gemma3n_mm.py +1 -1
  191. sglang/srt/models/glm4_moe.py +10 -1
  192. sglang/srt/models/glm4v.py +4 -2
  193. sglang/srt/models/gpt_oss.py +5 -18
  194. sglang/srt/models/internvl.py +28 -0
  195. sglang/srt/models/llama4.py +9 -0
  196. sglang/srt/models/llama_eagle3.py +17 -0
  197. sglang/srt/models/longcat_flash.py +1026 -0
  198. sglang/srt/models/longcat_flash_nextn.py +699 -0
  199. sglang/srt/models/minicpmv.py +165 -3
  200. sglang/srt/models/mllama4.py +25 -0
  201. sglang/srt/models/opt.py +637 -0
  202. sglang/srt/models/qwen2.py +33 -3
  203. sglang/srt/models/qwen2_5_vl.py +90 -42
  204. sglang/srt/models/qwen2_moe.py +79 -14
  205. sglang/srt/models/qwen3.py +8 -2
  206. sglang/srt/models/qwen3_moe.py +39 -8
  207. sglang/srt/models/qwen3_next.py +1039 -0
  208. sglang/srt/models/qwen3_next_mtp.py +109 -0
  209. sglang/srt/models/torch_native_llama.py +1 -1
  210. sglang/srt/models/transformers.py +1 -1
  211. sglang/srt/multimodal/processors/base_processor.py +4 -2
  212. sglang/srt/multimodal/processors/glm4v.py +9 -9
  213. sglang/srt/multimodal/processors/internvl.py +141 -129
  214. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  215. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  216. sglang/srt/sampling/sampling_batch_info.py +18 -15
  217. sglang/srt/server_args.py +297 -79
  218. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  219. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  220. sglang/srt/speculative/eagle_worker.py +216 -120
  221. sglang/srt/speculative/spec_info.py +5 -0
  222. sglang/srt/speculative/standalone_worker.py +109 -0
  223. sglang/srt/utils.py +37 -2
  224. sglang/srt/weight_sync/utils.py +1 -1
  225. sglang/test/attention/test_trtllm_mla_backend.py +181 -8
  226. sglang/test/few_shot_gsm8k.py +1 -0
  227. sglang/test/runners.py +4 -0
  228. sglang/test/test_cutlass_moe.py +24 -6
  229. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  230. sglang/test/test_disaggregation_utils.py +66 -0
  231. sglang/test/test_utils.py +25 -1
  232. sglang/utils.py +5 -0
  233. sglang/version.py +1 -1
  234. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/METADATA +11 -9
  235. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/RECORD +243 -194
  236. sglang/srt/disaggregation/launch_lb.py +0 -131
  237. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  238. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  239. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  240. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  241. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  242. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  243. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
  244. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
  245. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,64 @@
+ from typing import Callable, List, Tuple
+
+ import torch
+
+ LoaderFunction = Callable[[torch.Tensor, torch.Tensor], None]
+
+
+ def mamba_v2_sharded_weight_loader(
+     shard_spec: List[Tuple[int, int, float]],
+     tp_size: int,
+     tp_rank: int,
+ ) -> LoaderFunction:
+     """Create a weight loader for mamba v2. This ensures that the projections
+     are correctly sharded so that they can be split into x, B, C. It also
+     ensures the the all the groups corresponding to a head shard is placed
+     together with it.
+     """
+
+     def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
+
+         # - track boundary of (sharded) param, and loaded_weight, respectively
+         boundary, loaded_boundary = 0, 0
+
+         # - iterate over the shard specs
+         for full_dim, extra, duplicate_groups in shard_spec:
+             # - full dim is the model dim (before TP).
+             # - extra > 0, means there is expected overall increase
+             #   of dimensions. This is so because of replication.
+             # - ratio is used map the tp_rank to the actual shard
+             #   rank. This is useful when there is replication of
+             #   groups to accompany head shards.
+
+             # - size of the loaded shard
+             shard_size = full_dim // tp_size
+
+             # - compute the rank into the loaded shard.
+             # - if there is replication, different TP shards will
+             #   take from the same rank.
+             # NOTE: currently we only support duplication
+             # in the case where num_groups == 1
+             rank = 0 if duplicate_groups else tp_rank
+
+             # - leftmost boundary index into loaded weight.
+             loaded_skip = rank * shard_size
+             loaded_start_idx = loaded_boundary + loaded_skip
+
+             # - take these many dims from the loaded weight.
+             take = min(shard_size, full_dim - extra - loaded_skip)
+
+             # - always shard on dim 0
+             # - the ignore is for a mundane mypy error as it does not
+             #   seem to handle slices well.
+             # https://github.com/python/mypy/issues/2410
+             param.data[
+                 boundary : (boundary + take), ...  # type: ignore[misc]
+             ] = loaded_weight[
+                 loaded_start_idx : (loaded_start_idx + take)  # type: ignore[misc]
+             ]  # type: ignore[misc]
+
+             # move indexing boundaries
+             boundary += shard_size
+             loaded_boundary += full_dim - extra
+
+     return loader
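
The comments above describe the sharding scheme; as a concrete illustration, here is a small hypothetical usage sketch. The shard_spec values, tensor shapes, and ranks are invented for the example; only the loader itself comes from the file above.

    import torch

    # Packed projection of three segments (x: 8 dims, B: 4 dims, C: 4 dims),
    # split across tp_size=2 ranks. Each tuple is (full_dim, extra, duplicate_groups).
    shard_spec = [(8, 0, False), (4, 0, False), (4, 0, False)]
    loader = mamba_v2_sharded_weight_loader(shard_spec, tp_size=2, tp_rank=0)

    # Rank 0 holds half of every segment: (8 + 4 + 4) // 2 = 8 rows.
    param = torch.nn.Parameter(torch.empty(8, 16))
    full_weight = torch.randn(16, 16)  # unsharded checkpoint weight (8 + 4 + 4 rows)
    loader(param, full_weight)         # copies rank 0's slice of x, B and C into param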
@@ -193,10 +193,13 @@ class TorchNativeAttnBackend(AttentionBackend):
          else:
              o = torch.empty_like(q)

+         if layer.is_cross_attention:
+             cache_loc = forward_batch.encoder_out_cache_loc
+         else:
+             cache_loc = forward_batch.out_cache_loc
+
          if save_kv_cache:
-             forward_batch.token_to_kv_pool.set_kv_buffer(
-                 layer, forward_batch.out_cache_loc, k, v
-             )
+             forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)

          use_gqa = layer.tp_q_head_num != layer.tp_k_head_num

@@ -241,10 +244,13 @@ class TorchNativeAttnBackend(AttentionBackend):
          else:
              o = torch.empty_like(q)

+         if layer.is_cross_attention:
+             cache_loc = forward_batch.encoder_out_cache_loc
+         else:
+             cache_loc = forward_batch.out_cache_loc
+
          if save_kv_cache:
-             forward_batch.token_to_kv_pool.set_kv_buffer(
-                 layer, forward_batch.out_cache_loc, k, v
-             )
+             forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)

          use_gqa = layer.tp_q_head_num != layer.tp_k_head_num

@@ -45,12 +45,21 @@ TRTLLM_BLOCK_CONSTRAINT = 128
  global_zero_init_workspace_buffer = None


+ @dataclass
+ class TRTLLMMLAPrefillMetadata:
+     """Metadata for TRTLLM MLA prefill operations."""
+
+     max_seq_len: int
+     cum_seq_lens: torch.Tensor
+     seq_lens: torch.Tensor
+
+
  @dataclass
  class TRTLLMMLADecodeMetadata:
      """Metadata for TRTLLM MLA decode operations."""

-     workspace: Optional[torch.Tensor] = None
      block_kv_indices: Optional[torch.Tensor] = None
+     max_seq_len: Optional[int] = None


  class TRTLLMMLABackend(FlashInferMLAAttnBackend):
@@ -100,7 +109,8 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
          # CUDA graph state
          self.decode_cuda_graph_metadata = {}
          self.decode_cuda_graph_kv_indices = None
-         self.forward_metadata: Union[TRTLLMMLADecodeMetadata, None] = None
+         self.forward_prefill_metadata: Optional[TRTLLMMLAPrefillMetadata] = None
+         self.forward_decode_metadata: Union[TRTLLMMLADecodeMetadata, None] = None

      def _calc_padded_blocks(self, max_seq_len: int) -> int:
          """
@@ -176,9 +186,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
          self.decode_cuda_graph_kv_indices = torch.full(
              (max_bs, max_blocks_per_seq), -1, dtype=torch.int32, device=self.device
          )
-         self.decode_cuda_graph_workspace = torch.empty(
-             self.workspace_size, dtype=torch.int8, device=self.device
-         )

          super().init_cuda_graph_state(max_bs, max_num_tokens, kv_indices_buf)

@@ -207,8 +214,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
          )

          # Custom fast-path for decode/idle.
-         max_seqlen_pad = self._calc_padded_blocks(seq_lens.max().item())
-         block_kv_indices = self.decode_cuda_graph_kv_indices[:bs, :max_seqlen_pad]
+         # Capture with full width so future longer sequences are safe during replay
+         max_blocks_per_seq = self._calc_padded_blocks(self.max_context_len)
+         block_kv_indices = self.decode_cuda_graph_kv_indices[:bs, :max_blocks_per_seq]

          create_flashmla_kv_indices_triton[(bs,)](
              self.req_to_token,
@@ -217,16 +225,22 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
              None,
              block_kv_indices,
              self.req_to_token.stride(0),
-             max_seqlen_pad,
+             max_blocks_per_seq,
              NUM_PAGE_PER_BLOCK=TRITON_PAD_NUM_PAGE_PER_BLOCK,
              PAGED_SIZE=self.page_size,
          )

+         # Record the true maximum sequence length for this capture batch so that
+         # the kernel launch path (which requires an int not a tensor) can reuse
+         # it safely during both capture and replay.
+         max_seq_len_val = int(seq_lens.max().item())
+
          metadata = TRTLLMMLADecodeMetadata(
-             self.decode_cuda_graph_workspace, block_kv_indices
+             block_kv_indices,
+             max_seq_len_val,
          )
          self.decode_cuda_graph_metadata[bs] = metadata
-         self.forward_metadata = metadata
+         self.forward_decode_metadata = metadata

      def init_forward_metadata_replay_cuda_graph(
          self,
@@ -268,6 +282,13 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
              PAGED_SIZE=self.page_size,
          )

+         # Update stored max_seq_len so subsequent kernel calls use the correct value
+         # Prefer CPU tensor to avoid GPU synchronization when available.
+         if seq_lens_cpu is not None:
+             metadata.max_seq_len = int(seq_lens_cpu.max().item())
+         else:
+             metadata.max_seq_len = int(seq_lens.max().item())
+
      def get_cuda_graph_seq_len_fill_value(self) -> int:
          """Get the fill value for sequence lengths in CUDA graph."""
          return 1
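
The replay-path change above prefers seq_lens_cpu because calling .item() on a CUDA tensor forces a device synchronization, which is what CUDA-graph replay wants to avoid. A minimal illustration (tensor values are arbitrary):

    import torch

    seq_lens = torch.tensor([7, 31, 12], device="cuda")
    seq_lens_cpu = seq_lens.cpu()  # assume the batch already carries this copy

    max_seq_len = int(seq_lens_cpu.max().item())  # no GPU sync needed
    # fallback when no CPU copy is available (forces a device sync):
    # max_seq_len = int(seq_lens.max().item())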
@@ -275,30 +296,52 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
      def init_forward_metadata(self, forward_batch: ForwardBatch):
          """Initialize the metadata for a forward pass."""
          # Delegate to parent for non-decode modes.
-         if not forward_batch.forward_mode.is_decode_or_idle():
-             return super().init_forward_metadata(forward_batch)
+         if (
+             forward_batch.forward_mode.is_extend()
+             and not forward_batch.forward_mode.is_target_verify()
+             and not forward_batch.forward_mode.is_draft_extend()
+         ):
+             seq_lens = forward_batch.seq_lens - forward_batch.extend_prefix_lens
+             cum_seq_lens_q = torch.cat(
+                 (
+                     torch.tensor([0], device=forward_batch.seq_lens.device),
+                     torch.cumsum(seq_lens, dim=0),
+                 )
+             ).int()
+             max_seq_len = max(forward_batch.extend_seq_lens_cpu)
+             self.forward_prefill_metadata = TRTLLMMLAPrefillMetadata(
+                 max_seq_len,
+                 cum_seq_lens_q,
+                 seq_lens,
+             )
+         elif forward_batch.forward_mode.is_decode_or_idle():
+             bs = forward_batch.batch_size

-         bs = forward_batch.batch_size
+             # Get maximum sequence length.
+             if getattr(forward_batch, "seq_lens_cpu", None) is not None:
+                 max_seq = forward_batch.seq_lens_cpu.max().item()
+             else:
+                 max_seq = forward_batch.seq_lens.max().item()

-         # Get maximum sequence length.
-         if getattr(forward_batch, "seq_lens_cpu", None) is not None:
-             max_seq = forward_batch.seq_lens_cpu.max().item()
+             max_seqlen_pad = self._calc_padded_blocks(max_seq)
+             block_kv_indices = self._create_block_kv_indices(
+                 bs,
+                 max_seqlen_pad,
+                 forward_batch.req_pool_indices,
+                 forward_batch.seq_lens,
+                 forward_batch.seq_lens.device,
+             )
+
+             max_seq_len_val = int(max_seq)
+             self.forward_decode_metadata = TRTLLMMLADecodeMetadata(
+                 block_kv_indices, max_seq_len_val
+             )
+             forward_batch.decode_trtllm_mla_metadata = self.forward_decode_metadata
          else:
-             max_seq = forward_batch.seq_lens.max().item()
-
-         max_seqlen_pad = self._calc_padded_blocks(max_seq)
-         block_kv_indices = self._create_block_kv_indices(
-             bs,
-             max_seqlen_pad,
-             forward_batch.req_pool_indices,
-             forward_batch.seq_lens,
-             forward_batch.seq_lens.device,
-         )
+             return super().init_forward_metadata(forward_batch)

-         self.forward_metadata = TRTLLMMLADecodeMetadata(
-             self.workspace_buffer, block_kv_indices
-         )
-         forward_batch.decode_trtllm_mla_metadata = self.forward_metadata
+     def init_mha_chunk_metadata(self, forward_batch: ForwardBatch):
+         super().init_mha_chunk_metadata(forward_batch, disable_flashinfer_ragged=True)

      def quantize_and_rope_for_fp8(
          self,
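
For the new prefill branch above, the cumulative query lengths are the usual ragged-attention offsets built from the per-request extend lengths. A small worked example with made-up lengths (not taken from the code above):

    import torch

    extend_seq_lens = torch.tensor([10, 4, 7])   # new tokens per request
    cum_seq_lens_q = torch.cat(
        (torch.tensor([0]), torch.cumsum(extend_seq_lens, dim=0))
    ).int()
    # tensor([ 0, 10, 14, 21], dtype=torch.int32) -> offsets into the packed Q tensor
    max_seq_len = int(extend_seq_lens.max())      # 10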
@@ -442,7 +485,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
          # Get metadata
          metadata = (
              getattr(forward_batch, "decode_trtllm_mla_metadata", None)
-             or self.forward_metadata
+             or self.forward_decode_metadata
          )

          # Scale computation for TRTLLM MLA kernel BMM1 operation:
@@ -465,20 +508,67 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
          raw_out = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
              query=query,
              kv_cache=kv_cache,
-             workspace_buffer=metadata.workspace,
+             workspace_buffer=self.workspace_buffer,
              qk_nope_head_dim=self.qk_nope_head_dim,
              kv_lora_rank=self.kv_lora_rank,
              qk_rope_head_dim=self.qk_rope_head_dim,
              block_tables=metadata.block_kv_indices,
              seq_lens=forward_batch.seq_lens.to(torch.int32),
-             max_seq_len=int(metadata.block_kv_indices.shape[1] * self.page_size),
+             max_seq_len=metadata.max_seq_len,
              bmm1_scale=bmm1_scale,
          )

-         # Extract value projection part and reshape
-         raw_out_v = raw_out[..., : layer.v_head_dim].contiguous()
-         output = raw_out_v.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+         # Reshape output directly without slicing
+         output = raw_out.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+         return output
+
+     def forward_extend(
+         self,
+         q: torch.Tensor,
+         k: torch.Tensor,
+         v: torch.Tensor,
+         layer: RadixAttention,
+         forward_batch: ForwardBatch,
+         save_kv_cache: bool = True,
+         q_rope: Optional[torch.Tensor] = None,
+         k_rope: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         if (
+             forward_batch.forward_mode.is_target_verify()
+             or forward_batch.forward_mode.is_draft_extend()
+         ):
+             return super().forward_extend(
+                 q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
+             )

+         if not forward_batch.attn_attend_prefix_cache:
+             q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+             k = k.view(-1, layer.tp_k_head_num, layer.head_dim)
+             v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim)
+             output = flashinfer.prefill.trtllm_ragged_attention_deepseek(
+                 query=q,
+                 key=k,
+                 value=v,
+                 workspace_buffer=self.workspace_buffer,
+                 seq_lens=self.forward_prefill_metadata.seq_lens,
+                 max_q_len=self.forward_prefill_metadata.max_seq_len,
+                 max_kv_len=self.forward_prefill_metadata.max_seq_len,
+                 bmm1_scale=layer.scaling,
+                 bmm2_scale=1.0,
+                 o_sf_scale=1.0,
+                 batch_size=forward_batch.batch_size,
+                 window_left=-1,
+                 cum_seq_lens_q=self.forward_prefill_metadata.cum_seq_lens,
+                 cum_seq_lens_kv=self.forward_prefill_metadata.cum_seq_lens,
+                 enable_pdl=False,
+                 is_causal=True,
+                 return_lse=forward_batch.mha_return_lse,
+             )
+         else:
+             # replace with trtllm ragged attention once accuracy is resolved.
+             output = super().forward_extend(
+                 q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
+             )
          return output


@@ -64,8 +64,7 @@ def get_wave_kernel(
          subs=hyperparams_0,
          canonicalize=True,
          run_bench=False,
-         use_buffer_load_ops=True,
-         use_buffer_store_ops=True,
+         use_buffer_ops=True,
          waves_per_eu=2,
          dynamic_symbols=dynamic_symbols_0,
          wave_runtime=True,
@@ -77,8 +76,7 @@ def get_wave_kernel(
          subs=hyperparams_1,
          canonicalize=True,
          run_bench=False,
-         use_buffer_load_ops=False,
-         use_buffer_store_ops=False,
+         use_buffer_ops=False,
          waves_per_eu=4,
          dynamic_symbols=dynamic_symbols_1,
          wave_runtime=True,
@@ -67,11 +67,9 @@ def get_wave_kernel(
          schedule=SchedulingType.NONE,
          use_scheduling_barriers=False,
          dynamic_symbols=dynamic_symbols,
-         use_buffer_load_ops=True,
-         use_buffer_store_ops=True,
+         use_buffer_ops=True,
          waves_per_eu=2,
          denorm_fp_math_f32="preserve-sign",
-         gpu_native_math_precision=True,
          wave_runtime=True,
      )
      options = set_default_run_config(options)
@@ -42,10 +42,24 @@ from sglang.srt.layers.moe import (
  )
  from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
- from sglang.srt.utils import is_cuda, is_flashinfer_available, is_sm100_supported
+ from sglang.srt.utils import (
+     get_bool_env_var,
+     is_cuda,
+     is_flashinfer_available,
+     is_gfx95_supported,
+     is_hip,
+     is_sm90_supported,
+     is_sm100_supported,
+ )

  _is_flashinfer_available = is_flashinfer_available()
+ _is_sm90_supported = is_cuda() and is_sm90_supported()
  _is_sm100_supported = is_cuda() and is_sm100_supported()
+ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
+ _is_gfx95_supported = is_gfx95_supported()
+
+ if _use_aiter and _is_gfx95_supported:
+     from sglang.srt.layers.quantization.rocm_mxfp4_utils import fused_rms_mxfp4_quant

  FUSE_ALLREDUCE_MAX_BATCH_SIZE = 2048

@@ -201,6 +215,7 @@ class LayerCommunicator:
          hidden_states: torch.Tensor,
          residual: torch.Tensor,
          forward_batch: ForwardBatch,
+         qaunt_format: str = "",
      ):
          if hidden_states.shape[0] == 0:
              residual = hidden_states
@@ -218,11 +233,34 @@ class LayerCommunicator:
          else:
              if residual is None:
                  residual = hidden_states
-                 hidden_states = self.input_layernorm(hidden_states)
+
+                 if _use_aiter and _is_gfx95_supported and ("mxfp4" in qaunt_format):
+                     hidden_states = fused_rms_mxfp4_quant(
+                         hidden_states,
+                         self.input_layernorm.weight,
+                         self.input_layernorm.variance_epsilon,
+                         None,
+                         None,
+                         None,
+                         None,
+                     )
+                 else:
+                     hidden_states = self.input_layernorm(hidden_states)
              else:
-                 hidden_states, residual = self.input_layernorm(
-                     hidden_states, residual
-                 )
+                 if _use_aiter and _is_gfx95_supported and ("mxfp4" in qaunt_format):
+                     hidden_states, residual = fused_rms_mxfp4_quant(
+                         hidden_states,
+                         self.input_layernorm.weight,
+                         self.input_layernorm.variance_epsilon,
+                         None,
+                         None,
+                         None,
+                         residual,
+                     )
+                 else:
+                     hidden_states, residual = self.input_layernorm(
+                         hidden_states, residual
+                     )

          hidden_states = self._communicate_simple_fn(
              hidden_states=hidden_states,
@@ -484,11 +522,11 @@ class CommunicateWithAllReduceAndLayerNormFn:
          # According to the discussion in https://github.com/flashinfer-ai/flashinfer/issues/1223#issuecomment-3047256465
          # We set the max token num to 128 for allreduce fusion with min-latency case(use_oneshot=True).
          if (
-             _is_sm100_supported
+             (_is_sm100_supported or _is_sm90_supported)
              and _is_flashinfer_available
              and hasattr(layernorm, "forward_with_allreduce_fusion")
              and global_server_args_dict["enable_flashinfer_allreduce_fusion"]
-             and hidden_states.shape[0] <= 2048
+             and hidden_states.shape[0] <= 4096
          ):
              hidden_states, residual = layernorm.forward_with_allreduce_fusion(
                  hidden_states, residual
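
Condensed, the widened gate above reads roughly as below; this is a sketch using the module-level flags imported earlier in communicator.py, and the helper name is invented for illustration.

    def _can_fuse_allreduce_layernorm(layernorm, hidden_states) -> bool:
        # Hopper (SM90) is now accepted alongside Blackwell (SM100), and the
        # token cap for the fused path was raised from 2048 to 4096.
        return (
            (_is_sm100_supported or _is_sm90_supported)
            and _is_flashinfer_available
            and hasattr(layernorm, "forward_with_allreduce_fusion")
            and global_server_args_dict["enable_flashinfer_allreduce_fusion"]
            and hidden_states.shape[0] <= 4096
        )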
@@ -18,6 +18,7 @@ from typing import Optional, Tuple, Union

  import torch
  import torch.nn as nn
+ from packaging.version import Version

  from sglang.srt.custom_op import CustomOp
  from sglang.srt.utils import (
@@ -25,35 +26,41 @@ from sglang.srt.utils import (
      get_bool_env_var,
      is_cpu,
      is_cuda,
+     is_flashinfer_available,
      is_hip,
      is_npu,
+     is_xpu,
      supports_custom_op,
  )

  _is_cuda = is_cuda()
+ _is_flashinfer_available = is_flashinfer_available()
  _is_hip = is_hip()
  _is_npu = is_npu()
  _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
  _is_cpu_amx_available = cpu_has_amx_support()
  _is_cpu = is_cpu()
+ _is_xpu = is_xpu()

  if _is_cuda:
-     from sgl_kernel import (
-         fused_add_rmsnorm,
-         gemma_fused_add_rmsnorm,
-         gemma_rmsnorm,
-         rmsnorm,
-     )
+     if _is_flashinfer_available:
+         from flashinfer.norm import fused_add_rmsnorm
+     else:
+         from sgl_kernel import fused_add_rmsnorm
+     from sgl_kernel import gemma_fused_add_rmsnorm, gemma_rmsnorm, rmsnorm

  if _use_aiter:
      from aiter import rmsnorm2d_fwd as rms_norm
      from aiter import rmsnorm2d_fwd_with_add as fused_add_rms_norm
  elif _is_hip:
+     import vllm
      from vllm._custom_ops import fused_add_rms_norm, rms_norm

+     _vllm_version = Version(vllm.__version__)
+
  logger = logging.getLogger(__name__)

- if is_npu():
+ if _is_npu:
      import torch_npu


@@ -127,8 +134,21 @@ class RMSNorm(CustomOp):
          # NOTE: Remove this if aiter kernel supports discontinuous input
          x = x.contiguous()
          if residual is not None:
-             fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
-             return x, residual
+             if _vllm_version < Version("0.9"):
+                 fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
+                 return x, residual
+             else:
+                 residual_out = torch.empty_like(x)
+                 output = torch.empty_like(x)
+                 fused_add_rms_norm(
+                     output,
+                     x,
+                     residual_out,
+                     residual,
+                     self.weight.data,
+                     self.variance_epsilon,
+                 )
+                 return output, residual_out
          out = torch.empty_like(x)
          rms_norm(out, x, self.weight.data, self.variance_epsilon)
          return out
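
Both call conventions above compute the same fused "add then RMS-normalize" operation; only the buffer handling differs between the old in-place API and the newer out-parameter API. A plain-PyTorch reference of the math (a sketch, not either kernel):

    import torch

    def fused_add_rms_norm_ref(x, residual, weight, eps):
        # Add the residual first, then RMS-normalize the sum; return both the
        # normalized output and the updated residual.
        residual_out = x + residual
        variance = residual_out.pow(2).mean(-1, keepdim=True)
        output = residual_out * torch.rsqrt(variance + eps) * weight
        return output, residual_out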
@@ -266,28 +286,50 @@ class GemmaRMSNorm(CustomOp):
          out = gemma_rmsnorm(x, self.weight.data, self.variance_epsilon)
          return out

+     def forward_npu(
+         self,
+         x: torch.Tensor,
+         residual: Optional[torch.Tensor] = None,
+     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+         if residual is not None:
+             x = x + residual
+             residual = x
+
+         x, _ = torch_npu.npu_gemma_rms_norm(x, self.weight, self.variance_epsilon)
+         return x if residual is None else (x, residual)

- class Gemma3RMSNorm(nn.Module):
+
+ class Gemma3RMSNorm(CustomOp):
      def __init__(self, dim: int, eps: float = 1e-6):
          super().__init__()
          self.eps = eps
          self.weight = nn.Parameter(torch.zeros(dim))
+         # Re-dispatch

      def _norm(self, x):
          return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

-     def forward(self, x):
+     def forward_native(self, x):
          output = self._norm(x.float())
          # Llama does x.to(float16) * w whilst Gemma3 is (x * w).to(float16)
          # See https://github.com/huggingface/transformers/pull/29402
          output = output * (1.0 + self.weight.float())
          return output.type_as(x)

+     def forward_cuda(self, x):
+         return self.forward_native(x)
+
+     def forward_npu(self, x):
+         output, _ = torch_npu.npu_gemma_rms_norm(x, self.weight, self.eps)
+         return output
+
      def extra_repr(self):
          return f"{tuple(self.weight.shape)}, eps={self.eps}"


- if not (_is_cuda or _is_hip or _is_npu or (_is_cpu and _is_cpu_amx_available)):
+ if not (
+     _is_cuda or _is_hip or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_xpu
+ ):
      logger.info(
          "sgl-kernel layernorm implementation is not available on current platform. Fallback to other kernel libraries."
      )
@@ -46,10 +46,12 @@ from sglang.srt.model_executor.forward_batch_info import (
      ForwardBatch,
      ForwardMode,
  )
- from sglang.srt.utils import dump_to_file, use_intel_amx_backend
+ from sglang.srt.utils import dump_to_file, is_npu, use_intel_amx_backend

  logger = logging.getLogger(__name__)

+ _is_npu = is_npu()
+

  @dataclasses.dataclass
  class LogitsProcessorOutput:
@@ -61,7 +63,7 @@ class LogitsProcessorOutput:
      hidden_states: Optional[torch.Tensor] = None

      ## Part 2: This part will be assigned in python/sglang/srt/layers/sampler.py::Sampler
-     # The logprobs of the next tokens. shape: [#seq]
+     # he log probs of output tokens, if RETURN_ORIGINAL_LOGPROB = True, will get the log probs before applying temperature. If False, will get the log probs before applying temperature.
      next_token_logprobs: Optional[torch.Tensor] = None
      # The logprobs and ids of the top-k tokens in output positions. shape: [#seq, k]
      next_token_top_logprobs_val: Optional[List] = None
@@ -517,7 +519,12 @@ class LogitsProcessor(nn.Module):
          logits = logits[:, : self.config.vocab_size].float()

          if self.final_logit_softcapping:
-             fused_softcap(logits, self.final_logit_softcapping)
+             if not _is_npu:
+                 fused_softcap(logits, self.final_logit_softcapping)
+             else:
+                 logits = self.final_logit_softcapping * torch.tanh(
+                     logits / self.final_logit_softcapping
+                 )

          return logits
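
The NPU branch above is simply the unfused definition of logit softcapping. A quick numeric check of the formula (the cap value and logits are arbitrary):

    import torch

    cap = 30.0
    logits = torch.tensor([5.0, 60.0, -200.0])
    capped = cap * torch.tanh(logits / cap)
    # roughly [4.95, 28.92, -30.00]: values are squashed smoothly into (-cap, cap)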
 
@@ -1,4 +1,4 @@
- from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
+ from sglang.srt.layers.moe.moe_runner import MoeRunner, MoeRunnerConfig
  from sglang.srt.layers.moe.utils import (
      DeepEPMode,
      MoeA2ABackend,
@@ -17,6 +17,7 @@ from sglang.srt.layers.moe.utils import (
  __all__ = [
      "DeepEPMode",
      "MoeA2ABackend",
+     "MoeRunner",
      "MoeRunnerConfig",
      "MoeRunnerBackend",
      "initialize_moe_config",