sglang 0.5.2rc2-py3-none-any.whl → 0.5.3rc0-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (238)
  1. sglang/bench_one_batch_server.py +10 -1
  2. sglang/bench_serving.py +257 -29
  3. sglang/srt/configs/__init__.py +4 -0
  4. sglang/srt/configs/device_config.py +3 -1
  5. sglang/srt/configs/dots_vlm.py +139 -0
  6. sglang/srt/configs/load_config.py +1 -0
  7. sglang/srt/configs/model_config.py +50 -6
  8. sglang/srt/configs/qwen3_next.py +326 -0
  9. sglang/srt/connector/__init__.py +8 -1
  10. sglang/srt/connector/remote_instance.py +82 -0
  11. sglang/srt/constrained/base_grammar_backend.py +48 -12
  12. sglang/srt/constrained/llguidance_backend.py +0 -1
  13. sglang/srt/constrained/outlines_backend.py +0 -1
  14. sglang/srt/constrained/xgrammar_backend.py +28 -9
  15. sglang/srt/custom_op.py +11 -1
  16. sglang/srt/debug_utils/dump_comparator.py +81 -44
  17. sglang/srt/debug_utils/dump_loader.py +97 -0
  18. sglang/srt/debug_utils/dumper.py +11 -3
  19. sglang/srt/debug_utils/text_comparator.py +73 -11
  20. sglang/srt/disaggregation/base/conn.py +1 -1
  21. sglang/srt/disaggregation/common/conn.py +15 -12
  22. sglang/srt/disaggregation/decode.py +21 -10
  23. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
  24. sglang/srt/disaggregation/fake/conn.py +1 -1
  25. sglang/srt/disaggregation/mini_lb.py +6 -445
  26. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  27. sglang/srt/disaggregation/nixl/conn.py +180 -16
  28. sglang/srt/disaggregation/prefill.py +5 -3
  29. sglang/srt/disaggregation/utils.py +5 -50
  30. sglang/srt/distributed/parallel_state.py +24 -3
  31. sglang/srt/entrypoints/engine.py +38 -17
  32. sglang/srt/entrypoints/grpc_request_manager.py +580 -0
  33. sglang/srt/entrypoints/grpc_server.py +680 -0
  34. sglang/srt/entrypoints/http_server.py +85 -54
  35. sglang/srt/entrypoints/openai/protocol.py +4 -1
  36. sglang/srt/entrypoints/openai/serving_base.py +46 -3
  37. sglang/srt/entrypoints/openai/serving_chat.py +36 -16
  38. sglang/srt/entrypoints/openai/serving_completions.py +12 -3
  39. sglang/srt/entrypoints/openai/serving_embedding.py +8 -3
  40. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  41. sglang/srt/entrypoints/openai/serving_responses.py +6 -3
  42. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  43. sglang/srt/eplb/eplb_manager.py +2 -2
  44. sglang/srt/eplb/expert_distribution.py +26 -13
  45. sglang/srt/eplb/expert_location.py +8 -3
  46. sglang/srt/eplb/expert_location_updater.py +1 -1
  47. sglang/srt/function_call/base_format_detector.py +3 -6
  48. sglang/srt/function_call/ebnf_composer.py +11 -9
  49. sglang/srt/function_call/function_call_parser.py +6 -0
  50. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  51. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  52. sglang/srt/grpc/__init__.py +1 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
  56. sglang/srt/hf_transformers_utils.py +4 -0
  57. sglang/srt/layers/activation.py +142 -9
  58. sglang/srt/layers/attention/ascend_backend.py +11 -4
  59. sglang/srt/layers/attention/fla/chunk.py +242 -0
  60. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  61. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  62. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  63. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  66. sglang/srt/layers/attention/fla/index.py +37 -0
  67. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  68. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  69. sglang/srt/layers/attention/fla/op.py +66 -0
  70. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  71. sglang/srt/layers/attention/fla/utils.py +331 -0
  72. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  73. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  74. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  75. sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
  76. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  77. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  78. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  79. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  80. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  81. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  82. sglang/srt/layers/attention/triton_backend.py +18 -1
  83. sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
  84. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  85. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  86. sglang/srt/layers/dp_attention.py +30 -1
  87. sglang/srt/layers/layernorm.py +32 -15
  88. sglang/srt/layers/linear.py +34 -3
  89. sglang/srt/layers/logits_processor.py +29 -10
  90. sglang/srt/layers/moe/__init__.py +2 -1
  91. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  92. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  93. sglang/srt/layers/moe/ep_moe/layer.py +182 -62
  94. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
  95. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  96. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  97. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  98. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  99. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  100. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  101. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  102. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  103. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  104. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  105. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  106. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  107. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
  108. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  109. sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
  110. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  111. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  112. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  113. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  114. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  115. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  116. sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
  117. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  118. sglang/srt/layers/moe/topk.py +30 -9
  119. sglang/srt/layers/moe/utils.py +12 -6
  120. sglang/srt/layers/quantization/awq.py +19 -7
  121. sglang/srt/layers/quantization/base_config.py +11 -6
  122. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  123. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  124. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  125. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  126. sglang/srt/layers/quantization/fp8.py +76 -47
  127. sglang/srt/layers/quantization/fp8_utils.py +50 -31
  128. sglang/srt/layers/quantization/gptq.py +25 -17
  129. sglang/srt/layers/quantization/modelopt_quant.py +147 -47
  130. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  131. sglang/srt/layers/quantization/mxfp4.py +64 -40
  132. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  133. sglang/srt/layers/quantization/unquant.py +135 -47
  134. sglang/srt/layers/quantization/w4afp8.py +30 -17
  135. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  136. sglang/srt/layers/quantization/w8a8_int8.py +76 -38
  137. sglang/srt/layers/sampler.py +162 -18
  138. sglang/srt/lora/backend/base_backend.py +50 -8
  139. sglang/srt/lora/backend/triton_backend.py +90 -2
  140. sglang/srt/lora/layers.py +32 -0
  141. sglang/srt/lora/lora.py +4 -1
  142. sglang/srt/lora/lora_manager.py +35 -112
  143. sglang/srt/lora/mem_pool.py +24 -10
  144. sglang/srt/lora/utils.py +18 -9
  145. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  146. sglang/srt/managers/cache_controller.py +158 -160
  147. sglang/srt/managers/data_parallel_controller.py +105 -35
  148. sglang/srt/managers/detokenizer_manager.py +8 -4
  149. sglang/srt/managers/disagg_service.py +46 -0
  150. sglang/srt/managers/io_struct.py +199 -12
  151. sglang/srt/managers/mm_utils.py +1 -0
  152. sglang/srt/managers/multi_tokenizer_mixin.py +350 -400
  153. sglang/srt/managers/schedule_batch.py +77 -56
  154. sglang/srt/managers/schedule_policy.py +1 -1
  155. sglang/srt/managers/scheduler.py +187 -39
  156. sglang/srt/managers/scheduler_metrics_mixin.py +4 -3
  157. sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
  158. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  159. sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
  160. sglang/srt/managers/tokenizer_manager.py +259 -519
  161. sglang/srt/managers/tp_worker.py +53 -4
  162. sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
  163. sglang/srt/mem_cache/hicache_storage.py +3 -23
  164. sglang/srt/mem_cache/hiradix_cache.py +103 -43
  165. sglang/srt/mem_cache/memory_pool.py +347 -48
  166. sglang/srt/mem_cache/memory_pool_host.py +105 -46
  167. sglang/srt/mem_cache/radix_cache.py +0 -2
  168. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  169. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  170. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +86 -4
  171. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  172. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  173. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +49 -7
  174. sglang/srt/mem_cache/swa_radix_cache.py +0 -2
  175. sglang/srt/metrics/collector.py +493 -76
  176. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  177. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  178. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  179. sglang/srt/model_executor/forward_batch_info.py +59 -2
  180. sglang/srt/model_executor/model_runner.py +356 -29
  181. sglang/srt/model_loader/__init__.py +9 -3
  182. sglang/srt/model_loader/loader.py +128 -4
  183. sglang/srt/model_loader/weight_utils.py +2 -1
  184. sglang/srt/models/apertus.py +686 -0
  185. sglang/srt/models/bailing_moe.py +798 -218
  186. sglang/srt/models/bailing_moe_nextn.py +168 -0
  187. sglang/srt/models/deepseek_v2.py +109 -15
  188. sglang/srt/models/dots_vlm.py +174 -0
  189. sglang/srt/models/dots_vlm_vit.py +337 -0
  190. sglang/srt/models/ernie4.py +1 -1
  191. sglang/srt/models/gemma3n_mm.py +1 -1
  192. sglang/srt/models/glm4_moe.py +1 -1
  193. sglang/srt/models/glm4v.py +4 -2
  194. sglang/srt/models/glm4v_moe.py +3 -0
  195. sglang/srt/models/gpt_oss.py +1 -1
  196. sglang/srt/models/llama4.py +9 -0
  197. sglang/srt/models/llama_eagle3.py +13 -0
  198. sglang/srt/models/longcat_flash.py +2 -2
  199. sglang/srt/models/mllama4.py +25 -0
  200. sglang/srt/models/opt.py +637 -0
  201. sglang/srt/models/qwen2.py +7 -0
  202. sglang/srt/models/qwen2_5_vl.py +27 -3
  203. sglang/srt/models/qwen2_moe.py +56 -12
  204. sglang/srt/models/qwen3_moe.py +1 -1
  205. sglang/srt/models/qwen3_next.py +1042 -0
  206. sglang/srt/models/qwen3_next_mtp.py +112 -0
  207. sglang/srt/models/step3_vl.py +1 -1
  208. sglang/srt/multimodal/processors/dots_vlm.py +99 -0
  209. sglang/srt/multimodal/processors/glm4v.py +9 -9
  210. sglang/srt/multimodal/processors/internvl.py +141 -129
  211. sglang/srt/multimodal/processors/qwen_vl.py +15 -5
  212. sglang/srt/offloader.py +27 -3
  213. sglang/srt/remote_instance_weight_loader_utils.py +69 -0
  214. sglang/srt/sampling/sampling_batch_info.py +18 -15
  215. sglang/srt/server_args.py +276 -35
  216. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  217. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  218. sglang/srt/speculative/eagle_utils.py +0 -2
  219. sglang/srt/speculative/eagle_worker.py +43 -4
  220. sglang/srt/speculative/spec_info.py +5 -0
  221. sglang/srt/speculative/standalone_worker.py +109 -0
  222. sglang/srt/tracing/trace.py +552 -0
  223. sglang/srt/utils.py +34 -3
  224. sglang/srt/weight_sync/utils.py +1 -1
  225. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  226. sglang/test/runners.py +4 -0
  227. sglang/test/test_cutlass_moe.py +24 -6
  228. sglang/test/test_disaggregation_utils.py +66 -0
  229. sglang/test/test_fp4_moe.py +370 -1
  230. sglang/test/test_utils.py +28 -1
  231. sglang/utils.py +11 -0
  232. sglang/version.py +1 -1
  233. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
  234. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +237 -178
  235. sglang/srt/disaggregation/launch_lb.py +0 -118
  236. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
  237. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
  238. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,64 @@
+from typing import Callable, List, Tuple
+
+import torch
+
+LoaderFunction = Callable[[torch.Tensor, torch.Tensor], None]
+
+
+def mamba_v2_sharded_weight_loader(
+    shard_spec: List[Tuple[int, int, float]],
+    tp_size: int,
+    tp_rank: int,
+) -> LoaderFunction:
+    """Create a weight loader for mamba v2. This ensures that the projections
+    are correctly sharded so that they can be split into x, B, C. It also
+    ensures the the all the groups corresponding to a head shard is placed
+    together with it.
+    """
+
+    def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
+
+        # - track boundary of (sharded) param, and loaded_weight, respectively
+        boundary, loaded_boundary = 0, 0
+
+        # - iterate over the shard specs
+        for full_dim, extra, duplicate_groups in shard_spec:
+            # - full dim is the model dim (before TP).
+            # - extra > 0, means there is expected overall increase
+            #   of dimensions. This is so because of replication.
+            # - ratio is used map the tp_rank to the actual shard
+            #   rank. This is useful when there is replication of
+            #   groups to accompany head shards.
+
+            # - size of the loaded shard
+            shard_size = full_dim // tp_size
+
+            # - compute the rank into the loaded shard.
+            # - if there is replication, different TP shards will
+            #   take from the same rank.
+            # NOTE: currently we only support duplication
+            # in the case where num_groups == 1
+            rank = 0 if duplicate_groups else tp_rank
+
+            # - leftmost boundary index into loaded weight.
+            loaded_skip = rank * shard_size
+            loaded_start_idx = loaded_boundary + loaded_skip
+
+            # - take these many dims from the loaded weight.
+            take = min(shard_size, full_dim - extra - loaded_skip)
+
+            # - always shard on dim 0
+            # - the ignore is for a mundane mypy error as it does not
+            #   seem to handle slices well.
+            #   https://github.com/python/mypy/issues/2410
+            param.data[
+                boundary : (boundary + take), ...  # type: ignore[misc]
+            ] = loaded_weight[
+                loaded_start_idx : (loaded_start_idx + take)  # type: ignore[misc]
+            ]  # type: ignore[misc]
+
+            # move indexing boundaries
+            boundary += shard_size
+            loaded_boundary += full_dim - extra
+
+    return loader
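The toy sketch below is not part of the diff; the weight shape, shard_spec, and TP degree are made up. It shows what the new loader does: head rows are split across ranks, while the single replicated group segment is read from rank 0's slice on every rank.

    import torch

    # Assumes mamba_v2_sharded_weight_loader from the hunk above is in scope.
    full_weight = torch.arange(20.0).unsqueeze(1).repeat(1, 2)  # 20 rows, 2 cols
    shard_spec = [(16, 0, False), (4, 0, True)]  # (full_dim, extra, duplicate_groups)
    tp_size = 2

    for tp_rank in range(tp_size):
        loader = mamba_v2_sharded_weight_loader(shard_spec, tp_size, tp_rank)
        param = torch.empty(10, 2)  # 16/2 head rows + 4/2 group rows per rank
        loader(param, full_weight)
        # Rank 0 gets head rows 0-7, rank 1 gets rows 8-15; both ranks copy group
        # rows 16-17 because duplicate_groups forces rank = 0 for that segment.
        print(tp_rank, param[:, 0].tolist())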
@@ -193,10 +193,13 @@ class TorchNativeAttnBackend(AttentionBackend):
         else:
             o = torch.empty_like(q)
 
+        if layer.is_cross_attention:
+            cache_loc = forward_batch.encoder_out_cache_loc
+        else:
+            cache_loc = forward_batch.out_cache_loc
+
         if save_kv_cache:
-            forward_batch.token_to_kv_pool.set_kv_buffer(
-                layer, forward_batch.out_cache_loc, k, v
-            )
+            forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
 
         use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
 
@@ -241,10 +244,13 @@ class TorchNativeAttnBackend(AttentionBackend):
         else:
             o = torch.empty_like(q)
 
+        if layer.is_cross_attention:
+            cache_loc = forward_batch.encoder_out_cache_loc
+        else:
+            cache_loc = forward_batch.out_cache_loc
+
         if save_kv_cache:
-            forward_batch.token_to_kv_pool.set_kv_buffer(
-                layer, forward_batch.out_cache_loc, k, v
-            )
+            forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
 
         use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
 
@@ -80,7 +80,13 @@ class TritonAttnBackend(AttentionBackend):
         self.num_kv_head = model_runner.model_config.get_num_kv_heads(
             get_attention_tp_size()
         )
-        self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
+        if model_runner.is_hybrid_gdn:
+            # For hybrid linear models, layer_id = 0 may not be full attention
+            self.v_head_dim = model_runner.token_to_kv_pool.get_v_head_dim()
+        else:
+            self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[
+                -1
+            ]
         self.max_context_len = model_runner.model_config.context_len
         self.device = model_runner.device
         self.device_core_count = get_device_core_count(model_runner.gpu_id)
@@ -88,6 +94,11 @@ class TritonAttnBackend(AttentionBackend):
            "SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS", "false"
         )
         self.max_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
+        self.split_tile_size = model_runner.server_args.triton_attention_split_tile_size
+        if self.split_tile_size is not None:
+            self.max_kv_splits = (
+                self.max_context_len + self.split_tile_size - 1
+            ) // self.split_tile_size
 
         # Check arguments
         assert not (
@@ -147,6 +158,12 @@ class TritonAttnBackend(AttentionBackend):
            num_kv_splits.fill_(self.max_kv_splits)
            return
 
+        if self.split_tile_size is not None:
+            num_kv_splits[:] = (
+                seq_lens + self.split_tile_size - 1
+            ) // self.split_tile_size
+            return
+
         if num_seq < 256:
             SCHEDULE_SEQ = 256
         else:
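The new triton_attention_split_tile_size option fixes how much KV each split handles instead of fixing the number of splits; both uses in the hunk above reduce to the same ceiling division. A standalone sketch (names local to this example):

    import torch

    def kv_splits_from_tile_size(seq_lens: torch.Tensor, split_tile_size: int) -> torch.Tensor:
        # ceil(seq_len / split_tile_size): one KV split per tile of the sequence.
        return (seq_lens + split_tile_size - 1) // split_tile_size

    seq_lens = torch.tensor([1, 256, 257, 4096])
    print(kv_splits_from_tile_size(seq_lens, 256))  # tensor([ 1,  1,  2, 16])

With the option set, max_kv_splits is derived the same way from max_context_len rather than taken from triton_attention_num_kv_splits.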
@@ -20,6 +20,7 @@ from sglang.srt.layers.attention.utils import (
     create_flashmla_kv_indices_triton,
 )
 from sglang.srt.layers.dp_attention import get_attention_tp_size
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.utils import is_flashinfer_available
 
@@ -45,11 +46,19 @@ TRTLLM_BLOCK_CONSTRAINT = 128
 global_zero_init_workspace_buffer = None
 
 
+@dataclass
+class TRTLLMMLAPrefillMetadata:
+    """Metadata for TRTLLM MLA prefill operations."""
+
+    max_seq_len: int
+    cum_seq_lens: torch.Tensor
+    seq_lens: torch.Tensor
+
+
 @dataclass
 class TRTLLMMLADecodeMetadata:
     """Metadata for TRTLLM MLA decode operations."""
 
-    workspace: Optional[torch.Tensor] = None
     block_kv_indices: Optional[torch.Tensor] = None
     max_seq_len: Optional[int] = None
 
@@ -64,7 +73,12 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         kv_indptr_buf: Optional[torch.Tensor] = None,
         q_indptr_decode_buf: Optional[torch.Tensor] = None,
     ):
-        super().__init__(model_runner, skip_prefill, kv_indptr_buf, q_indptr_decode_buf)
+        super().__init__(
+            model_runner,
+            skip_prefill,
+            kv_indptr_buf,
+            q_indptr_decode_buf,
+        )
 
         config = model_runner.model_config
 
@@ -101,7 +115,12 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         # CUDA graph state
         self.decode_cuda_graph_metadata = {}
         self.decode_cuda_graph_kv_indices = None
-        self.forward_metadata: Union[TRTLLMMLADecodeMetadata, None] = None
+        self.forward_prefill_metadata: Optional[TRTLLMMLAPrefillMetadata] = None
+        self.forward_decode_metadata: Union[TRTLLMMLADecodeMetadata, None] = None
+
+        self.disable_chunked_prefix_cache = global_server_args_dict[
+            "disable_chunked_prefix_cache"
+        ]
 
     def _calc_padded_blocks(self, max_seq_len: int) -> int:
         """
@@ -177,9 +196,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         self.decode_cuda_graph_kv_indices = torch.full(
             (max_bs, max_blocks_per_seq), -1, dtype=torch.int32, device=self.device
         )
-        self.decode_cuda_graph_workspace = torch.empty(
-            self.workspace_size, dtype=torch.int8, device=self.device
-        )
 
         super().init_cuda_graph_state(max_bs, max_num_tokens, kv_indices_buf)
 
@@ -230,12 +246,11 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         max_seq_len_val = int(seq_lens.max().item())
 
         metadata = TRTLLMMLADecodeMetadata(
-            self.decode_cuda_graph_workspace,
             block_kv_indices,
             max_seq_len_val,
         )
         self.decode_cuda_graph_metadata[bs] = metadata
-        self.forward_metadata = metadata
+        self.forward_decode_metadata = metadata
 
     def init_forward_metadata_replay_cuda_graph(
         self,
@@ -291,31 +306,55 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Initialize the metadata for a forward pass."""
         # Delegate to parent for non-decode modes.
-        if not forward_batch.forward_mode.is_decode_or_idle():
-            return super().init_forward_metadata(forward_batch)
+        if (
+            forward_batch.forward_mode.is_extend()
+            and not forward_batch.forward_mode.is_target_verify()
+            and not forward_batch.forward_mode.is_draft_extend()
+        ):
+            if self.disable_chunked_prefix_cache:
+                super().init_forward_metadata(forward_batch)
+
+            seq_lens = forward_batch.seq_lens - forward_batch.extend_prefix_lens
+            cum_seq_lens_q = torch.cat(
+                (
+                    torch.tensor([0], device=forward_batch.seq_lens.device),
+                    torch.cumsum(seq_lens, dim=0),
+                )
+            ).int()
+            max_seq_len = max(forward_batch.extend_seq_lens_cpu)
+            self.forward_prefill_metadata = TRTLLMMLAPrefillMetadata(
+                max_seq_len,
+                cum_seq_lens_q,
+                seq_lens,
+            )
+        elif forward_batch.forward_mode.is_decode_or_idle():
+            bs = forward_batch.batch_size
 
-        bs = forward_batch.batch_size
+            # Get maximum sequence length.
+            if getattr(forward_batch, "seq_lens_cpu", None) is not None:
+                max_seq = forward_batch.seq_lens_cpu.max().item()
+            else:
+                max_seq = forward_batch.seq_lens.max().item()
+
+            max_seqlen_pad = self._calc_padded_blocks(max_seq)
+            block_kv_indices = self._create_block_kv_indices(
+                bs,
+                max_seqlen_pad,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                forward_batch.seq_lens.device,
+            )
 
-        # Get maximum sequence length.
-        if getattr(forward_batch, "seq_lens_cpu", None) is not None:
-            max_seq = forward_batch.seq_lens_cpu.max().item()
+            max_seq_len_val = int(max_seq)
+            self.forward_decode_metadata = TRTLLMMLADecodeMetadata(
+                block_kv_indices, max_seq_len_val
+            )
+            forward_batch.decode_trtllm_mla_metadata = self.forward_decode_metadata
         else:
-            max_seq = forward_batch.seq_lens.max().item()
-
-        max_seqlen_pad = self._calc_padded_blocks(max_seq)
-        block_kv_indices = self._create_block_kv_indices(
-            bs,
-            max_seqlen_pad,
-            forward_batch.req_pool_indices,
-            forward_batch.seq_lens,
-            forward_batch.seq_lens.device,
-        )
+            return super().init_forward_metadata(forward_batch)
 
-        max_seq_len_val = int(max_seq)
-        self.forward_metadata = TRTLLMMLADecodeMetadata(
-            self.workspace_buffer, block_kv_indices, max_seq_len_val
-        )
-        forward_batch.decode_trtllm_mla_metadata = self.forward_metadata
+    def init_mha_chunk_metadata(self, forward_batch: ForwardBatch):
+        super().init_mha_chunk_metadata(forward_batch, disable_flashinfer_ragged=True)
 
     def quantize_and_rope_for_fp8(
         self,
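The prefill metadata introduced above is just the ragged layout of the new-token segments: per-request query lengths (seq_lens minus extend_prefix_lens) plus their cumulative offsets with a leading zero. A small sketch of that computation with made-up lengths:

    import torch

    seq_lens = torch.tensor([7, 5, 9])            # total tokens per request
    extend_prefix_lens = torch.tensor([3, 0, 4])  # tokens already in the KV cache
    q_lens = seq_lens - extend_prefix_lens        # tokens to prefill: [4, 5, 5]
    cum_seq_lens_q = torch.cat(
        (torch.zeros(1, dtype=q_lens.dtype), torch.cumsum(q_lens, dim=0))
    ).int()
    print(q_lens.tolist(), cum_seq_lens_q.tolist())  # [4, 5, 5] [0, 4, 9, 14]

max_seq_len in the metadata is simply the largest of the per-request extend lengths.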
@@ -459,7 +498,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         # Get metadata
         metadata = (
             getattr(forward_batch, "decode_trtllm_mla_metadata", None)
-            or self.forward_metadata
+            or self.forward_decode_metadata
         )
 
         # Scale computation for TRTLLM MLA kernel BMM1 operation:
@@ -482,7 +521,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         raw_out = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
             query=query,
             kv_cache=kv_cache,
-            workspace_buffer=metadata.workspace,
+            workspace_buffer=self.workspace_buffer,
             qk_nope_head_dim=self.qk_nope_head_dim,
             kv_lora_rank=self.kv_lora_rank,
             qk_rope_head_dim=self.qk_rope_head_dim,
@@ -496,6 +535,60 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         output = raw_out.view(-1, layer.tp_q_head_num * layer.v_head_dim)
         return output
 
+    def forward_extend(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if (
+            forward_batch.forward_mode.is_target_verify()
+            or forward_batch.forward_mode.is_draft_extend()
+        ):
+            return super().forward_extend(
+                q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
+            )
+        # chunked prefix cache is not enabled, use Flashinfer MLA prefill kernel
+        if forward_batch.attn_attend_prefix_cache is None:
+            return super().forward_extend(
+                q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
+            )
+
+        if not forward_batch.attn_attend_prefix_cache:
+            q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+            k = k.view(-1, layer.tp_k_head_num, layer.head_dim)
+            v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim)
+            output = flashinfer.prefill.trtllm_ragged_attention_deepseek(
+                query=q,
+                key=k,
+                value=v,
+                workspace_buffer=self.workspace_buffer,
+                seq_lens=self.forward_prefill_metadata.seq_lens,
+                max_q_len=self.forward_prefill_metadata.max_seq_len,
+                max_kv_len=self.forward_prefill_metadata.max_seq_len,
+                bmm1_scale=layer.scaling,
+                bmm2_scale=1.0,
+                o_sf_scale=1.0,
+                batch_size=forward_batch.batch_size,
+                window_left=-1,
+                cum_seq_lens_q=self.forward_prefill_metadata.cum_seq_lens,
+                cum_seq_lens_kv=self.forward_prefill_metadata.cum_seq_lens,
+                enable_pdl=False,
+                is_causal=True,
+                return_lse=forward_batch.mha_return_lse,
+            )
+        else:
+            # replace with trtllm ragged attention once accuracy is resolved.
+            output = super().forward_extend(
+                q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
+            )
+        return output
+
 
 class TRTLLMMLAMultiStepDraftBackend(FlashInferMLAMultiStepDraftBackend):
     """Multi-step draft backend for TRT-LLM MLA used by EAGLE."""
@@ -64,8 +64,7 @@ def get_wave_kernel(
         subs=hyperparams_0,
         canonicalize=True,
         run_bench=False,
-        use_buffer_load_ops=True,
-        use_buffer_store_ops=True,
+        use_buffer_ops=True,
         waves_per_eu=2,
         dynamic_symbols=dynamic_symbols_0,
         wave_runtime=True,
@@ -77,8 +76,7 @@ def get_wave_kernel(
         subs=hyperparams_1,
         canonicalize=True,
         run_bench=False,
-        use_buffer_load_ops=False,
-        use_buffer_store_ops=False,
+        use_buffer_ops=False,
         waves_per_eu=4,
         dynamic_symbols=dynamic_symbols_1,
         wave_runtime=True,
@@ -67,11 +67,9 @@ def get_wave_kernel(
         schedule=SchedulingType.NONE,
         use_scheduling_barriers=False,
         dynamic_symbols=dynamic_symbols,
-        use_buffer_load_ops=True,
-        use_buffer_store_ops=True,
+        use_buffer_ops=True,
         waves_per_eu=2,
         denorm_fp_math_f32="preserve-sign",
-        gpu_native_math_precision=True,
         wave_runtime=True,
     )
     options = set_default_run_config(options)
@@ -51,7 +51,12 @@ class DpPaddingMode(IntEnum):
         return self == DpPaddingMode.SUM_LEN
 
     @classmethod
-    def get_dp_padding_mode(cls, global_num_tokens: List[int]) -> DpPaddingMode:
+    def get_dp_padding_mode(
+        cls, is_extend_in_batch, global_num_tokens: List[int]
+    ) -> DpPaddingMode:
+        if is_extend_in_batch:
+            return DpPaddingMode.SUM_LEN
+
         # we choose the mode that minimizes the communication cost
         max_len = max(global_num_tokens)
         sum_len = sum(global_num_tokens)
@@ -119,6 +124,18 @@ class _DpGatheredBufferWrapper:
     def get_dp_global_num_tokens(cls) -> List[int]:
         return cls._global_num_tokens
 
+    @classmethod
+    def get_dp_hidden_size(cls) -> int:
+        return cls._hidden_size
+
+    @classmethod
+    def get_dp_dtype(cls) -> torch.dtype:
+        return cls._dtype
+
+    @classmethod
+    def get_dp_device(cls) -> torch.device:
+        return cls._device
+
 
 def set_dp_buffer_len(
     global_dp_buffer_len: int,
@@ -150,6 +167,18 @@ def get_dp_global_num_tokens() -> List[int]:
     return _DpGatheredBufferWrapper.get_dp_global_num_tokens()
 
 
+def get_dp_hidden_size() -> int:
+    return _DpGatheredBufferWrapper.get_dp_hidden_size()
+
+
+def get_dp_dtype() -> torch.dtype:
+    return _DpGatheredBufferWrapper.get_dp_dtype()
+
+
+def get_dp_device() -> torch.device:
+    return _DpGatheredBufferWrapper.get_dp_device()
+
+
 def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
     if not enable_dp_attention:
         return tp_rank, tp_size, 0
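With the extra argument, extend batches are now forced to length-sum padding, while the decode path still picks whichever gathered-buffer layout is cheaper to communicate. A rough sketch of the two candidate buffer sizes that choice weighs (the MAX_LEN formula below is my reading of a padded all-gather across DP ranks, not code from the diff):

    from typing import Dict, List

    def candidate_dp_buffer_lens(global_num_tokens: List[int]) -> Dict[str, int]:
        # SUM_LEN packs every rank's tokens back to back; MAX_LEN pads each of
        # the len(global_num_tokens) ranks up to the longest one.
        return {
            "SUM_LEN": sum(global_num_tokens),
            "MAX_LEN": max(global_num_tokens) * len(global_num_tokens),
        }

    print(candidate_dp_buffer_lens([5, 9, 7]))  # {'SUM_LEN': 21, 'MAX_LEN': 27}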
@@ -18,6 +18,7 @@ from typing import Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
+from packaging.version import Version
 
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.utils import (
@@ -25,32 +26,38 @@ from sglang.srt.utils import (
     get_bool_env_var,
     is_cpu,
     is_cuda,
+    is_flashinfer_available,
     is_hip,
     is_npu,
+    is_xpu,
     supports_custom_op,
 )
 
 _is_cuda = is_cuda()
+_is_flashinfer_available = is_flashinfer_available()
 _is_hip = is_hip()
 _is_npu = is_npu()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
+_is_xpu = is_xpu()
 
 if _is_cuda:
-    from sgl_kernel import (
-        fused_add_rmsnorm,
-        gemma_fused_add_rmsnorm,
-        gemma_rmsnorm,
-        rmsnorm,
-    )
+    if _is_flashinfer_available:
+        from flashinfer.norm import fused_add_rmsnorm
+    else:
+        from sgl_kernel import fused_add_rmsnorm
+    from sgl_kernel import gemma_fused_add_rmsnorm, gemma_rmsnorm, rmsnorm
 
 if _use_aiter:
     from aiter import rmsnorm2d_fwd as rms_norm
     from aiter import rmsnorm2d_fwd_with_add as fused_add_rms_norm
 elif _is_hip:
+    import vllm
     from vllm._custom_ops import fused_add_rms_norm, rms_norm
 
+    _vllm_version = Version(vllm.__version__)
+
 logger = logging.getLogger(__name__)
 
 if _is_npu:
@@ -127,8 +134,21 @@ class RMSNorm(CustomOp):
         # NOTE: Remove this if aiter kernel supports discontinuous input
         x = x.contiguous()
         if residual is not None:
-            fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
-            return x, residual
+            if _vllm_version < Version("0.9"):
+                fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
+                return x, residual
+            else:
+                residual_out = torch.empty_like(x)
+                output = torch.empty_like(x)
+                fused_add_rms_norm(
+                    output,
+                    x,
+                    residual_out,
+                    residual,
+                    self.weight.data,
+                    self.variance_epsilon,
+                )
+                return output, residual_out
         out = torch.empty_like(x)
         rms_norm(out, x, self.weight.data, self.variance_epsilon)
         return out
@@ -271,16 +291,11 @@ class GemmaRMSNorm(CustomOp):
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-        orig_dtype = x.dtype
         if residual is not None:
             x = x + residual
             residual = x
 
-        x = x.float()
-        variance = torch_npu.mean(torch_npu.pow(x, 2), dim=-1, keepdim=True)
-        x = x * torch_npu.rsqrt(variance + self.variance_epsilon)
-        x = x * (1.0 + self.weight.float())
-        x = x.to(orig_dtype)
+        x, _ = torch_npu.npu_gemma_rms_norm(x, self.weight, self.variance_epsilon)
         return x if residual is None else (x, residual)
 
 
@@ -312,7 +327,9 @@ class Gemma3RMSNorm(CustomOp):
         return f"{tuple(self.weight.shape)}, eps={self.eps}"
 
 
-if not (_is_cuda or _is_hip or _is_npu or (_is_cpu and _is_cpu_amx_available)):
+if not (
+    _is_cuda or _is_hip or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_xpu
+):
     logger.info(
         "sgl-kernel layernorm implementation is not available on current platform. Fallback to other kernel libraries."
     )
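On ROCm the module now has to handle two calling conventions for vLLM's fused_add_rms_norm: in-place before 0.9, and preallocated output/residual_out tensors from 0.9 on, gated by an ordinary packaging version compare. The reference below is my paraphrase of what either convention computes, not the kernel itself:

    import torch
    from packaging.version import Version

    def fused_add_rms_norm_reference(x, residual, weight, eps):
        # residual_out = x + residual; output = RMSNorm(residual_out) * weight
        residual_out = x + residual
        inv_rms = torch.rsqrt(residual_out.pow(2).mean(-1, keepdim=True) + eps)
        return residual_out * inv_rms * weight, residual_out

    assert Version("0.8.5") < Version("0.9")  # the guard used in the hunk above
    out, res = fused_add_rms_norm_reference(
        torch.randn(4, 8), torch.randn(4, 8), torch.ones(8), 1e-6
    )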
@@ -235,9 +235,8 @@ class ReplicatedLinear(LinearBase):
                 loaded_weight = loaded_weight[:1]
             else:
                 raise ValueError(f"{loaded_weight} are not all equal")
-        assert (
-            param.size() == loaded_weight.size()
-        ), f"Loading weight error: param: {param.size()}, loaded_weight: {loaded_weight.size()}"
+
+        assert param.size() == loaded_weight.size()
         param.data.copy_(loaded_weight)
 
     def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
@@ -894,6 +893,35 @@ class QKVParallelLinear(ColumnParallelLinear):
             )
             self.weight_loader_v2(param, loaded_weight_shard, shard_id)
 
+    def _load_qkv_block_scale(
+        self, param: BasevLLMParameter, loaded_weight: torch.Tensor
+    ):
+        block_n, _ = self.quant_method.quant_config.weight_block_size
+        q_size = self.total_num_heads * self.head_size // block_n
+        k_size = self.total_num_kv_heads * self.head_size // block_n
+        v_size = self.total_num_kv_heads * self.head_size // block_n
+        shard_offsets = [
+            # (shard_id, shard_offset, shard_size)
+            ("q", 0, q_size),
+            ("k", q_size, k_size),
+            ("v", q_size + k_size, v_size),
+        ]
+        for shard_id, shard_offset, shard_size in shard_offsets:
+            loaded_weight_shard = loaded_weight.narrow(
+                param.output_dim, shard_offset, shard_size
+            )
+            rank_shard_offset = self._get_shard_offset_mapping(shard_id) // block_n
+            rank_shard_size = self._get_shard_size_mapping(shard_id) // block_n
+            param.load_qkv_weight(
+                loaded_weight=loaded_weight_shard,
+                num_heads=self.num_kv_head_replicas,
+                shard_id=shard_id,
+                shard_offset=rank_shard_offset,
+                shard_size=rank_shard_size,
+                tp_rank=self.tp_rank,
+                use_presharded_weights=self.use_presharded_weights,
+            )
+
     def weight_loader_v2(
         self,
@@ -907,6 +935,9 @@ class QKVParallelLinear(ColumnParallelLinear):
         elif type(param) in (RowvLLMParameter, BasevLLMParameter):
             param.load_qkv_weight(loaded_weight=loaded_weight)
             return
+        elif isinstance(param, BlockQuantScaleParameter):
+            self._load_qkv_block_scale(param, loaded_weight)
+            return
         # TODO: @dsikka - move to parameter.py
         self._load_fused_module_from_checkpoint(param, loaded_weight)
         return
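For block-quantized checkpoints (e.g. FP8 W8A8 with 128x128 blocks), the per-block weight scales are block_n times smaller along the output dimension than the weights themselves, so the fused q/k/v scale tensor is split with every offset and size divided by block_n. A worked example of that arithmetic with made-up head counts (not taken from any particular model):

    total_num_heads, total_num_kv_heads, head_size, block_n = 32, 8, 128, 128

    q_size = total_num_heads * head_size // block_n     # 32 rows of scales
    k_size = total_num_kv_heads * head_size // block_n  # 8
    v_size = total_num_kv_heads * head_size // block_n  # 8

    shard_offsets = [("q", 0, q_size), ("k", q_size, k_size), ("v", q_size + k_size, v_size)]
    print(shard_offsets)  # [('q', 0, 32), ('k', 32, 8), ('v', 40, 8)]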