sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
Files changed (282)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +321 -31
  3. sglang/bench_serving.py +10 -3
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +4 -0
  11. sglang/srt/configs/dots_ocr.py +64 -0
  12. sglang/srt/configs/falcon_h1.py +360 -0
  13. sglang/srt/configs/load_config.py +8 -0
  14. sglang/srt/configs/model_config.py +160 -105
  15. sglang/srt/configs/qwen3_vl.py +586 -0
  16. sglang/srt/constrained/base_grammar_backend.py +1 -0
  17. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  18. sglang/srt/constrained/xgrammar_backend.py +6 -4
  19. sglang/srt/debug_utils/dumper.py +10 -3
  20. sglang/srt/disaggregation/ascend/conn.py +2 -2
  21. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  22. sglang/srt/disaggregation/common/conn.py +266 -98
  23. sglang/srt/disaggregation/decode.py +50 -9
  24. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  25. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  26. sglang/srt/disaggregation/mooncake/conn.py +51 -541
  27. sglang/srt/disaggregation/nixl/conn.py +148 -39
  28. sglang/srt/disaggregation/prefill.py +31 -14
  29. sglang/srt/disaggregation/utils.py +36 -5
  30. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  31. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  32. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  33. sglang/srt/distributed/parallel_state.py +135 -80
  34. sglang/srt/entrypoints/engine.py +23 -3
  35. sglang/srt/entrypoints/grpc_request_manager.py +330 -55
  36. sglang/srt/entrypoints/grpc_server.py +232 -102
  37. sglang/srt/entrypoints/http_server.py +49 -9
  38. sglang/srt/entrypoints/openai/protocol.py +110 -5
  39. sglang/srt/entrypoints/openai/serving_base.py +25 -6
  40. sglang/srt/entrypoints/openai/serving_chat.py +178 -49
  41. sglang/srt/entrypoints/openai/serving_completions.py +5 -3
  42. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  43. sglang/srt/entrypoints/openai/serving_responses.py +42 -0
  44. sglang/srt/environ.py +285 -0
  45. sglang/srt/eplb/expert_location.py +30 -5
  46. sglang/srt/function_call/function_call_parser.py +3 -2
  47. sglang/srt/function_call/glm4_moe_detector.py +3 -3
  48. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  49. sglang/srt/function_call/json_array_parser.py +63 -0
  50. sglang/srt/function_call/kimik2_detector.py +17 -4
  51. sglang/srt/function_call/utils.py +96 -5
  52. sglang/srt/grpc/compile_proto.py +245 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
  56. sglang/srt/layers/activation.py +7 -6
  57. sglang/srt/layers/attention/aiter_backend.py +14 -15
  58. sglang/srt/layers/attention/ascend_backend.py +108 -9
  59. sglang/srt/layers/attention/attention_registry.py +206 -0
  60. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  61. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  62. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  63. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  66. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  67. sglang/srt/layers/attention/flashinfer_backend.py +112 -194
  68. sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
  69. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  70. sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
  71. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
  72. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
  73. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
  74. sglang/srt/layers/attention/mamba/mamba.py +566 -1
  75. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  76. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  77. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  78. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  79. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  80. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  81. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  82. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  83. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  84. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  85. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  86. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  87. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  88. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  89. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  90. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  91. sglang/srt/layers/attention/nsa/utils.py +24 -0
  92. sglang/srt/layers/attention/nsa_backend.py +887 -0
  93. sglang/srt/layers/attention/tbo_backend.py +6 -6
  94. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  95. sglang/srt/layers/attention/triton_backend.py +42 -9
  96. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  97. sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
  98. sglang/srt/layers/attention/vision.py +58 -0
  99. sglang/srt/layers/attention/wave_backend.py +4 -4
  100. sglang/srt/layers/communicator.py +8 -0
  101. sglang/srt/layers/dp_attention.py +11 -1
  102. sglang/srt/layers/elementwise.py +3 -1
  103. sglang/srt/layers/layernorm.py +2 -0
  104. sglang/srt/layers/linear.py +21 -4
  105. sglang/srt/layers/logits_processor.py +15 -2
  106. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  107. sglang/srt/layers/moe/ep_moe/layer.py +147 -74
  108. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  109. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  110. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  111. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  112. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
  113. sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
  114. sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
  115. sglang/srt/layers/moe/utils.py +10 -0
  116. sglang/srt/layers/parameter.py +23 -6
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  119. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  121. sglang/srt/layers/quantization/fp8.py +2 -2
  122. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  123. sglang/srt/layers/quantization/modelopt_quant.py +44 -9
  124. sglang/srt/layers/quantization/mxfp4.py +12 -4
  125. sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
  126. sglang/srt/layers/quantization/w4afp8.py +0 -4
  127. sglang/srt/layers/quantization/w8a8_int8.py +15 -3
  128. sglang/srt/layers/rotary_embedding.py +78 -31
  129. sglang/srt/layers/sampler.py +52 -4
  130. sglang/srt/layers/utils.py +23 -0
  131. sglang/srt/lora/backend/base_backend.py +3 -3
  132. sglang/srt/lora/backend/chunked_backend.py +348 -0
  133. sglang/srt/lora/backend/triton_backend.py +10 -4
  134. sglang/srt/lora/lora.py +7 -5
  135. sglang/srt/lora/lora_manager.py +17 -6
  136. sglang/srt/lora/mem_pool.py +1 -1
  137. sglang/srt/lora/triton_ops/__init__.py +4 -0
  138. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  139. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  140. sglang/srt/lora/utils.py +7 -5
  141. sglang/srt/managers/cache_controller.py +42 -142
  142. sglang/srt/managers/data_parallel_controller.py +11 -46
  143. sglang/srt/managers/detokenizer_manager.py +11 -11
  144. sglang/srt/managers/io_struct.py +162 -118
  145. sglang/srt/managers/mm_utils.py +43 -6
  146. sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
  147. sglang/srt/managers/multimodal_processor.py +1 -2
  148. sglang/srt/managers/overlap_utils.py +53 -0
  149. sglang/srt/managers/schedule_batch.py +167 -86
  150. sglang/srt/managers/schedule_policy.py +143 -16
  151. sglang/srt/managers/scheduler.py +359 -214
  152. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  153. sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
  154. sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
  155. sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
  156. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  157. sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
  158. sglang/srt/managers/tokenizer_manager.py +84 -136
  159. sglang/srt/managers/tp_worker.py +39 -29
  160. sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
  161. sglang/srt/managers/utils.py +1 -45
  162. sglang/srt/mem_cache/allocator.py +14 -20
  163. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  164. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  165. sglang/srt/mem_cache/chunk_cache.py +8 -1
  166. sglang/srt/mem_cache/evict_policy.py +23 -0
  167. sglang/srt/mem_cache/hicache_storage.py +40 -1
  168. sglang/srt/mem_cache/hiradix_cache.py +119 -32
  169. sglang/srt/mem_cache/memory_pool.py +188 -10
  170. sglang/srt/mem_cache/memory_pool_host.py +134 -182
  171. sglang/srt/mem_cache/radix_cache.py +222 -71
  172. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  173. sglang/srt/mem_cache/storage/__init__.py +10 -0
  174. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  175. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  176. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  177. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  178. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  179. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
  180. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
  181. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
  182. sglang/srt/mem_cache/swa_radix_cache.py +25 -34
  183. sglang/srt/metrics/collector.py +82 -120
  184. sglang/srt/metrics/func_timer.py +2 -7
  185. sglang/srt/metrics/utils.py +8 -1
  186. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  187. sglang/srt/model_executor/cuda_graph_runner.py +39 -32
  188. sglang/srt/model_executor/forward_batch_info.py +23 -38
  189. sglang/srt/model_executor/model_runner.py +131 -183
  190. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  191. sglang/srt/model_loader/loader.py +14 -10
  192. sglang/srt/model_loader/weight_utils.py +156 -2
  193. sglang/srt/models/bailing_moe.py +27 -4
  194. sglang/srt/models/deepseek_nextn.py +6 -1
  195. sglang/srt/models/deepseek_v2.py +536 -153
  196. sglang/srt/models/dots_ocr.py +173 -0
  197. sglang/srt/models/falcon_h1.py +576 -0
  198. sglang/srt/models/gemma3_causal.py +0 -2
  199. sglang/srt/models/gemma3_mm.py +1 -1
  200. sglang/srt/models/gemma3n_mm.py +1 -1
  201. sglang/srt/models/glm4_moe.py +3 -3
  202. sglang/srt/models/glm4_moe_nextn.py +2 -2
  203. sglang/srt/models/glm4v.py +1 -1
  204. sglang/srt/models/glm4v_moe.py +1 -1
  205. sglang/srt/models/gpt_oss.py +7 -30
  206. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  207. sglang/srt/models/llama.py +4 -0
  208. sglang/srt/models/longcat_flash.py +1 -1
  209. sglang/srt/models/longcat_flash_nextn.py +1 -1
  210. sglang/srt/models/mllama4.py +15 -4
  211. sglang/srt/models/qwen2.py +0 -7
  212. sglang/srt/models/qwen2_5_vl.py +2 -2
  213. sglang/srt/models/qwen2_audio.py +1 -1
  214. sglang/srt/models/qwen2_moe.py +64 -1
  215. sglang/srt/models/qwen2_vl.py +1 -1
  216. sglang/srt/models/qwen3.py +18 -3
  217. sglang/srt/models/qwen3_moe.py +31 -3
  218. sglang/srt/models/qwen3_next.py +36 -9
  219. sglang/srt/models/qwen3_vl.py +787 -0
  220. sglang/srt/models/qwen3_vl_moe.py +471 -0
  221. sglang/srt/models/registry.py +15 -3
  222. sglang/srt/models/sarashina2_vision.py +269 -0
  223. sglang/srt/models/solar.py +505 -0
  224. sglang/srt/models/starcoder2.py +357 -0
  225. sglang/srt/models/torch_native_llama.py +9 -2
  226. sglang/srt/models/utils.py +51 -0
  227. sglang/srt/multimodal/processors/base_processor.py +15 -7
  228. sglang/srt/multimodal/processors/dots_vlm.py +2 -3
  229. sglang/srt/multimodal/processors/internvl.py +20 -8
  230. sglang/srt/multimodal/processors/qwen_vl.py +8 -1
  231. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  232. sglang/srt/parser/jinja_template_utils.py +6 -0
  233. sglang/srt/sampling/sampling_batch_info.py +20 -2
  234. sglang/srt/sampling/sampling_params.py +7 -0
  235. sglang/srt/server_args.py +753 -295
  236. sglang/srt/server_args_config_parser.py +146 -0
  237. sglang/srt/single_batch_overlap.py +151 -0
  238. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  239. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  240. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  241. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  242. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  243. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  244. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
  245. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
  246. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
  247. sglang/srt/speculative/eagle_worker.py +57 -25
  248. sglang/srt/speculative/ngram_utils.py +428 -0
  249. sglang/srt/speculative/ngram_worker.py +245 -0
  250. sglang/srt/speculative/spec_info.py +47 -0
  251. sglang/srt/speculative/spec_utils.py +606 -0
  252. sglang/srt/torch_memory_saver_adapter.py +5 -7
  253. sglang/srt/tracing/trace.py +32 -6
  254. sglang/srt/two_batch_overlap.py +8 -5
  255. sglang/srt/utils/__init__.py +2 -0
  256. sglang/srt/{utils.py → utils/common.py} +399 -74
  257. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
  258. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  259. sglang/srt/utils/rpd_utils.py +452 -0
  260. sglang/srt/utils/slow_rank_detector.py +71 -0
  261. sglang/srt/warmup.py +8 -4
  262. sglang/srt/weight_sync/utils.py +1 -1
  263. sglang/test/get_logits_ut.py +57 -0
  264. sglang/test/run_eval.py +79 -11
  265. sglang/test/runners.py +1 -1
  266. sglang/test/simple_eval_common.py +5 -2
  267. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  268. sglang/test/test_block_fp8.py +2 -2
  269. sglang/test/test_deterministic.py +297 -0
  270. sglang/test/test_disaggregation_utils.py +12 -1
  271. sglang/test/test_programs.py +1 -1
  272. sglang/test/test_utils.py +355 -4
  273. sglang/utils.py +10 -1
  274. sglang/version.py +1 -1
  275. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
  276. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
  277. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  278. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  279. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  280. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  281. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  282. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/hybrid_attn_backend.py
@@ -3,10 +3,11 @@ from typing import Optional, Union
 import torch
 
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.layers.attention.nsa.nsa_indexer import BaseIndexerMetadata
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.model_executor.model_runner import ModelRunner
-from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+from sglang.srt.speculative.spec_info import SpecInput
 
 
 class HybridAttnBackend(AttentionBackend):
@@ -21,6 +22,7 @@ class HybridAttnBackend(AttentionBackend):
         self.model_runner = model_runner
         self.prefill_backend = prefill_backend
         self.decode_backend = decode_backend
+        self.data_type = model_runner.kv_cache_dtype
 
     def _select_backend(self, forward_mode: ForwardMode) -> AttentionBackend:
         """
@@ -70,7 +72,7 @@
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
     ):
         backend = self._select_backend(forward_mode)
         backend.init_forward_metadata_capture_cuda_graph(
@@ -91,7 +93,7 @@
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         backend = self._select_backend(forward_mode)
@@ -137,3 +139,9 @@
         return backend.forward_extend(
             q, k, v, layer, forward_batch, save_kv_cache, **kwargs
         )
+
+    def get_indexer_metadata(
+        self, layer_id: int, forward_batch: ForwardBatch
+    ) -> Optional[BaseIndexerMetadata]:
+        backend = self._select_backend(forward_batch.forward_mode)
+        return backend.get_indexer_metadata(layer_id, forward_batch)
sglang/srt/layers/attention/hybrid_linear_attn_backend.py
@@ -21,11 +21,17 @@ from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.model_executor.model_runner import ModelRunner
-from sglang.srt.models.qwen3_next import Qwen3HybridLinearDecoderLayer, fused_gdn_gating
-from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
-from sglang.srt.utils import is_npu
+from sglang.srt.models.qwen3_next import fused_gdn_gating
+from sglang.srt.speculative.spec_info import SpecInput
+from sglang.srt.utils import is_cuda, is_npu
 
-if is_npu():
+if is_cuda():
+    from sglang.srt.layers.attention.mamba.causal_conv1d import (
+        causal_conv1d_fn as causal_conv1d_fn_cuda,
+    )
+
+    causal_conv1d_fn = causal_conv1d_fn_cuda
+elif is_npu():
     from sgl_kernel_npu.fla.chunk import chunk_gated_delta_rule_npu
     from sgl_kernel_npu.fla.fused_sigmoid_gating_recurrent import (
         fused_sigmoid_gating_delta_rule_update_npu,
@@ -58,18 +64,16 @@ class MambaAttnBackend(AttentionBackend):
         self.forward_metadata: ForwardMetadata = None
         self.state_indices_list = []
         self.query_start_loc_list = []
-
-    @classmethod
-    @lru_cache(maxsize=128)
-    def _get_cached_arange(cls, bs: int, device_str: str) -> torch.Tensor:
-        """Cache torch.arange tensors for common batch sizes to avoid repeated allocation."""
-        device = torch.device(device_str)
-        return torch.arange(0, bs + 1, dtype=torch.int32, device=device)
+        self.cached_cuda_graph_decode_query_start_loc: torch.Tensor = None
+        self.cached_cuda_graph_verify_query_start_loc: torch.Tensor = None
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         bs = forward_batch.batch_size
+
         if forward_batch.forward_mode.is_decode_or_idle():
-            query_start_loc = self._get_cached_arange(bs, str(self.device))
+            query_start_loc = torch.arange(
+                0, bs + 1, dtype=torch.int32, device=self.device
+            )
         elif forward_batch.forward_mode.is_extend():
             if forward_batch.forward_mode.is_target_verify():
                 query_start_loc = torch.arange(
@@ -99,6 +103,10 @@
             )
 
     def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
+        assert (
+            max_num_tokens % max_bs == 0
+        ), f"max_num_tokens={max_num_tokens} must be divisible by max_bs={max_bs}"
+        verify_step = max_num_tokens / max_bs
         for i in range(max_bs):
             self.state_indices_list.append(
                 torch.full(
@@ -108,6 +116,16 @@
             self.query_start_loc_list.append(
                 torch.empty((i + 2,), dtype=torch.int32, device=self.device)
             )
+        self.cached_cuda_graph_decode_query_start_loc = torch.arange(
+            0, max_bs + 1, dtype=torch.int32, device=self.device
+        )
+        self.cached_cuda_graph_verify_query_start_loc = torch.arange(
+            0,
+            max_bs * verify_step + 1,
+            step=verify_step,
+            dtype=torch.int32,
+            device=self.device,
+        )
 
     def init_forward_metadata_capture_cuda_graph(
         self,
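Note on the two cached tensors above: they are allocated once in init_cuda_graph_state and then only sliced during CUDA-graph capture/replay, instead of building a fresh torch.arange (or going through an lru_cache helper) on every step. A minimal standalone sketch of the pattern, with toy sizes assumed:

    import torch

    max_bs, verify_step = 4, 2  # assumed toy values; the real verify_step is max_num_tokens / max_bs
    # Built once:
    decode_qsl = torch.arange(0, max_bs + 1, dtype=torch.int32)  # [0, 1, 2, 3, 4]
    verify_qsl = torch.arange(
        0, max_bs * verify_step + 1, step=verify_step, dtype=torch.int32
    )  # [0, 2, 4, 6, 8]

    bs = 3
    # Per-step cost is a slice plus copy_ into a pre-sized buffer, no new allocation:
    qsl_buf = torch.empty(bs + 1, dtype=torch.int32)
    qsl_buf.copy_(decode_qsl[: bs + 1])  # tensor([0, 1, 2, 3], dtype=torch.int32)
    qsl_buf.copy_(verify_qsl[: bs + 1])  # tensor([0, 2, 4, 6], dtype=torch.int32)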
@@ -117,19 +135,15 @@
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
     ):
         if forward_mode.is_decode_or_idle():
-            self.query_start_loc_list[bs - 1].copy_(self._get_cached_arange(bs, "cuda"))
+            self.query_start_loc_list[bs - 1].copy_(
+                self.cached_cuda_graph_decode_query_start_loc[: bs + 1]
+            )
         elif forward_mode.is_target_verify():
             self.query_start_loc_list[bs - 1].copy_(
-                torch.arange(
-                    0,
-                    bs * spec_info.draft_token_num + 1,
-                    step=spec_info.draft_token_num,
-                    dtype=torch.int32,
-                    device=self.device,
-                )
+                self.cached_cuda_graph_verify_query_start_loc[: bs + 1]
             )
         else:
             raise ValueError(f"Invalid forward mode: {forward_mode=}")
@@ -148,7 +162,7 @@
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         num_padding = torch.count_nonzero(
@@ -160,23 +174,29 @@
         mamba_indices[bs - num_padding :] = -1
         self.state_indices_list[bs - 1][: len(mamba_indices)].copy_(mamba_indices)
         if forward_mode.is_decode_or_idle():
-            self.query_start_loc_list[bs - 1].copy_(self._get_cached_arange(bs, "cuda"))
-            if num_padding > 0:
-                self.query_start_loc_list[bs - 1][bs - num_padding :] = bs - num_padding
-        elif forward_mode.is_target_verify():
-            self.query_start_loc_list[bs - 1].copy_(
-                torch.arange(
-                    0,
-                    bs * spec_info.draft_token_num + 1,
-                    step=spec_info.draft_token_num,
-                    dtype=torch.int32,
-                    device=self.device,
+            if num_padding == 0:
+                self.query_start_loc_list[bs - 1].copy_(
+                    self.cached_cuda_graph_decode_query_start_loc[: bs + 1]
                 )
-            )
-            if num_padding > 0:
-                self.query_start_loc_list[bs - 1][bs - num_padding :] = (
+            else:
+                self.query_start_loc_list[bs - 1][: bs - num_padding].copy_(
+                    self.cached_cuda_graph_decode_query_start_loc[: bs - num_padding]
+                )
+                self.query_start_loc_list[bs - 1][bs - num_padding :].copy_(
                     bs - num_padding
-                ) * spec_info.draft_token_num
+                )
+        elif forward_mode.is_target_verify():
+            if num_padding == 0:
+                self.query_start_loc_list[bs - 1].copy_(
+                    self.cached_cuda_graph_verify_query_start_loc[: bs + 1]
+                )
+            else:
+                self.query_start_loc_list[bs - 1][: bs - num_padding].copy_(
+                    self.cached_cuda_graph_verify_query_start_loc[: bs - num_padding]
+                )
+                self.query_start_loc_list[bs - 1][bs - num_padding :].copy_(
+                    (bs - num_padding) * spec_info.draft_token_num
+                )
         else:
             raise ValueError(f"Invalid forward mode: {forward_mode=}")
 
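Note on the padding branches above: a CUDA graph captured for batch size bs can be replayed with fewer live requests, so the padded tail of query_start_loc must stay flat, giving the padded slots zero-length queries. A worked toy example of the target-verify case, with values assumed:

    import torch

    bs, num_padding, draft_token_num = 4, 1, 2  # assumed toy values
    verify_qsl = torch.arange(
        0, bs * draft_token_num + 1, step=draft_token_num, dtype=torch.int32
    )  # [0, 2, 4, 6, 8]

    qsl = torch.empty(bs + 1, dtype=torch.int32)
    qsl[: bs - num_padding].copy_(verify_qsl[: bs - num_padding])        # live prefix: [0, 2, 4]
    qsl[bs - num_padding :].fill_((bs - num_padding) * draft_token_num)  # flat tail:   [6, 6]
    print(qsl)  # tensor([0, 2, 4, 6, 6], dtype=torch.int32)

(The diff itself uses copy_ for the tail; that works there because num_padding comes from torch.count_nonzero, so the scalar expression is a 0-dim tensor that broadcasts.)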
@@ -343,6 +363,7 @@
             has_initial_state=has_initial_states,
             cache_indices=cache_indices,
             query_start_loc=query_start_loc,
+            seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
         ).transpose(0, 1)[:seq_len]
 
         key_split_dim = key_dim // attn_tp_size
@@ -431,7 +452,7 @@ class HybridLinearAttnBackend(AttentionBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
     ):
         for attn_backend in self.attn_backend_list:
             attn_backend.init_forward_metadata_capture_cuda_graph(
@@ -452,7 +473,7 @@
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         for attn_backend in self.attn_backend_list:
@@ -567,36 +588,15 @@
 
         # Compute common indices once to avoid duplication
         last_steps_all = (accepted_length - 1).to(torch.int64)
-        valid_state_indices = state_indices_tensor[valid_mask].to(torch.int64)
-        last_steps = last_steps_all[valid_mask].to(torch.int64)
-
-        if valid_state_indices.numel() > 0:
-            chunk = 256
-            num_valid = valid_state_indices.numel()
-
-            # SSM state updates
-            for i in range(0, num_valid, chunk):
-                idx = valid_state_indices[i : i + chunk]
-                steps = last_steps[i : i + chunk]
-                # per (cache line, step)
-                for j in range(idx.numel()):
-                    ci = idx[j].item()
-                    st = steps[j].item()
-                    ssm_states[:, ci, :].copy_(
-                        intermediate_state_cache[:, ci, st].to(
-                            ssm_states.dtype, copy=False
-                        )
-                    )
-
-            # Conv window updates
-            for i in range(0, num_valid, chunk):
-                idx = valid_state_indices[i : i + chunk]
-                steps = last_steps[i : i + chunk]
-                for j in range(idx.numel()):
-                    ci = idx[j].item()
-                    st = steps[j].item()
-                    conv_states[:, ci, :, :].copy_(
-                        intermediate_conv_window_cache[:, ci, st].to(
-                            conv_states.dtype, copy=False
-                        )
-                    )
+        valid_state_indices = state_indices_tensor[valid_mask].to(torch.int64)  # [N]
+        last_steps = last_steps_all[valid_mask].to(torch.int64)  # [N]
+
+        # scatter into ssm_states at the chosen cache lines
+        ssm_states[:, valid_state_indices, :] = intermediate_state_cache[
+            :, valid_state_indices, last_steps
+        ].to(ssm_states.dtype, copy=False)
+
+        # Scatter into conv_states at the chosen cache lines
+        conv_states[:, valid_state_indices, :, :] = intermediate_conv_window_cache[
+            :, valid_state_indices, last_steps
+        ].to(conv_states.dtype, copy=False)
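Note on the rewrite above: the per-(cache line, step) Python loops are replaced by one vectorized advanced-indexing assignment per cache; pairing the two index tensors selects each (line_i, step_i) jointly. A self-contained equivalence check with toy shapes assumed:

    import torch

    layers, lines, steps, d = 2, 8, 4, 3  # assumed toy shapes
    ssm_states = torch.zeros(layers, lines, d)
    cache = torch.randn(layers, lines, steps, d)
    idx = torch.tensor([1, 4, 6])   # valid cache lines
    last = torch.tensor([3, 0, 2])  # accepted last step per line

    # Old form: one copy per (line, step) pair
    expected = ssm_states.clone()
    for ci, st in zip(idx.tolist(), last.tolist()):
        expected[:, ci, :] = cache[:, ci, st]

    # New form: paired fancy indices scatter all lines at once
    ssm_states[:, idx, :] = cache[:, idx, last]
    assert torch.equal(ssm_states, expected)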
sglang/srt/layers/attention/mamba/causal_conv1d.py
@@ -23,6 +23,7 @@ def causal_conv1d_fn(
     conv_states: Optional[torch.Tensor] = None,
     activation: Optional[str] = "silu",
     pad_slot_id: int = PAD_SLOT_ID,
+    **kwargs,
 ):
     """
     x: (batch, dim, seqlen) or (dim,cu_seq_len) for varlen
sglang/srt/layers/attention/mamba/causal_conv1d_triton.py
@@ -2,7 +2,7 @@
 # Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py
 # and https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
 
-from typing import Optional, Union
+from typing import List, Optional, Union
 
 import numpy as np
 import torch
@@ -22,11 +22,8 @@ def _causal_conv1d_fwd_kernel( # continuous batching
     cache_indices_ptr,  # conv_state_indices_ptr
     has_initial_states_ptr,
     query_start_loc_ptr,
-    batch_ptr,
-    token_chunk_offset_ptr,
     o_ptr,  # (dim, seqlen) - actually pointing to x_ptr
     # Matrix dimensions
-    batch: tl.int32,  # actually padded_batch
     dim: tl.constexpr,
     seqlen: tl.int32,  # cu_seqlen
     num_cache_lines: tl.constexpr,  # added to support vLLM larger cache lines
@@ -69,11 +66,11 @@ def _causal_conv1d_fwd_kernel( # continuous batching
     # rather than mixing sequences - to make updating initial_states across sequences efficiently
 
     # single-sequence id
-    idx_seq = tl.load(batch_ptr + tl.program_id(0))
-    chunk_offset = tl.load(token_chunk_offset_ptr + tl.program_id(0))
+    idx_seq = tl.program_id(0)
+    chunk_offset = tl.program_id(1)
 
     # BLOCK_N elements along the feature-dimension (channel)
-    idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)
+    idx_feats = tl.program_id(2) * BLOCK_N + tl.arange(0, BLOCK_N)
 
     if idx_seq == pad_slot_id:
         return
@@ -86,6 +83,9 @@ def _causal_conv1d_fwd_kernel( # continuous batching
     token_offset = BLOCK_M * chunk_offset
    segment_len = min(BLOCK_M, seqlen - token_offset)
 
+    if segment_len <= 0:
+        return
+
     # base of the sequence
     x_base = (
         x_ptr + sequence_start_index * stride_x_token + idx_feats * stride_x_dim
@@ -382,12 +382,13 @@
     bias: Union[torch.Tensor, None],
     conv_states: torch.Tensor,
     query_start_loc: torch.Tensor,
+    seq_lens_cpu: List[int],
     cache_indices: Optional[torch.Tensor] = None,
     has_initial_state: Optional[torch.Tensor] = None,
     activation: Optional[str] = "silu",
     pad_slot_id: int = PAD_SLOT_ID,
-    metadata=None,
     validate_data=False,
+    **kwargs,
 ):
     """support varlen + continuous batching when x is 2D tensor
 
@@ -413,6 +414,8 @@
         [length(query_start_loc)-1 == batch]
         for example: query_start_loc = torch.Tensor([0,10,16,17]),
         x.shape=(dim,17)
+    seq_lens_cpu: (batch) int32
+        The sequence lengths of the sequences in the batch
     cache_indices: (batch) int32
         indicates the corresponding state index,
         like so: conv_state = conv_states[cache_indices[batch_id]]
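Note: a minimal call sketch for the updated causal_conv1d_fn signature, assuming a CUDA device and toy shapes (two sequences of lengths 3 and 2 packed along the token axis, channel-last layout as the asserts require; the conv_states layout follows the docstring's (..., dim, width - 1)):

    import torch
    from sglang.srt.layers.attention.mamba.causal_conv1d_triton import causal_conv1d_fn

    dim, total_tokens, width = 4, 5, 4
    x = torch.randn(total_tokens, dim, device="cuda").t()  # (dim, cu_seq_len), x.stride(0) == 1
    weight = torch.randn(dim, width, device="cuda")
    conv_states = torch.zeros(8, dim, width - 1, device="cuda")  # cache lines
    out = causal_conv1d_fn(
        x,
        weight,
        bias=None,
        conv_states=conv_states,
        query_start_loc=torch.tensor([0, 3, 5], dtype=torch.int32, device="cuda"),
        seq_lens_cpu=[3, 2],  # the newly required per-sequence lengths
        cache_indices=torch.tensor([0, 1], dtype=torch.int32, device="cuda"),
    )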
@@ -434,26 +437,7 @@
     if isinstance(activation, bool) and activation:
         activation = "silu"
 
-    args = None
     out = torch.empty_like(x)
-    if metadata is not None:
-        cu_seqlen = metadata.cu_seqlen
-        nums_dict = metadata.nums_dict
-        # x = metadata.x
-        args = nums_dict
-        batch_ptr = metadata.batch_ptr
-        token_chunk_offset_ptr = metadata.token_chunk_offset_ptr
-    else:
-        seqlens = np.diff(query_start_loc.to("cpu"))
-        args = seqlens
-        MAX_NUM_PROGRAMS = 1024
-
-        batch_ptr = torch.full(
-            (MAX_NUM_PROGRAMS,), PAD_SLOT_ID, dtype=torch.int32, device=x.device
-        )  # tracking which seq-idx the Triton program is handling
-        token_chunk_offset_ptr = torch.full(
-            (MAX_NUM_PROGRAMS,), PAD_SLOT_ID, dtype=torch.int32, device=x.device
-        )  # tracking BLOCK_M-based index in the sequence the Triton program is handling
 
     is_channel_last = (x.stride(0) == 1) & (x.stride(1) > 1)
     dim, cu_seqlen = x.shape
@@ -461,7 +445,6 @@
     state_len = width - 1
     np2_statelen = triton.next_power_of_2(state_len)
 
-    padded_batch = query_start_loc.size(0) - 1
     stride_x_seq = 0
     stride_x_dim = x.stride(0)
     stride_x_token = x.stride(1)
@@ -501,6 +484,7 @@
         assert query_start_loc is not None
         assert query_start_loc.dim() == 1
         assert x.stride(0) == 1 or x.stride(1) == 1
+        padded_batch = query_start_loc.size(0) - 1
         if bias is not None:
             assert bias.dim() == 1
             assert dim == bias.size(0)
@@ -516,78 +500,14 @@
         assert (dim, width) == weight.shape
         assert is_channel_last, "Need to run in channel-last layout"
 
-    if metadata is None:
-
-        def num_program(META, seqlens):
-            tot = 0
-
-            mlist = []
-            offsetlist = []  # type: ignore
-
-            nums = -(-seqlens // META["BLOCK_M"])
-
-            tot = nums.sum().item()
-            mlist = np.repeat(np.arange(len(nums)), nums)
-            for idx, num in enumerate(nums):
-                offsetlist.extend(
-                    range(num)
-                )  # chunk-idx if a sequence is split into multiple chunks
-
-            if META["batch_ptr"].nelement() < len(mlist):
-                newlen = len(mlist) + 1
-                META["batch_ptr"].resize_(newlen).fill_(PAD_SLOT_ID)
-                META["token_chunk_offset_ptr"].resize_(newlen).fill_(PAD_SLOT_ID)
-
-            if META["batch_ptr"].nelement() >= len(mlist):
-                META["batch_ptr"][0 : len(mlist)].copy_(
-                    torch.from_numpy(np.array(mlist))
-                )
-                META["token_chunk_offset_ptr"][0 : len(mlist)].copy_(
-                    torch.from_numpy(np.array(offsetlist))
-                )
-
-            META["batch_ptr"] = META["batch_ptr"].to(META["x_ptr"].device)
-            META["token_chunk_offset_ptr"] = META["token_chunk_offset_ptr"].to(
-                META["x_ptr"].device
-            )
-            return tot
-
-    else:
-
-        def num_program(META, nums_dict):
-            tot = nums_dict[META["BLOCK_M"]]["tot"]
-
-            mlist = nums_dict[META["BLOCK_M"]]["mlist"]
-            mlist_len = nums_dict[META["BLOCK_M"]]["mlist_len"]
-
-            offsetlist = nums_dict[META["BLOCK_M"]]["offsetlist"]
-
-            if nums_dict[META["BLOCK_M"]]["batch_ptr"] is not None:
-                META["batch_ptr"] = nums_dict[META["BLOCK_M"]]["batch_ptr"]
-                META["token_chunk_offset_ptr"] = nums_dict[META["BLOCK_M"]][
-                    "token_chunk_offset_ptr"
-                ]
-            else:
-                if META["batch_ptr"].nelement() < mlist_len:
-                    newlen = mlist_len + 1
-                    META["batch_ptr"].resize_(newlen).fill_(PAD_SLOT_ID)
-                    META["token_chunk_offset_ptr"].resize_(newlen).fill_(PAD_SLOT_ID)
-
-                if META["batch_ptr"].nelement() >= mlist_len:
-                    META["batch_ptr"][0:mlist_len].copy_(mlist)
-                    META["token_chunk_offset_ptr"][0:mlist_len].copy_(offsetlist)
-            return tot
-
     def grid(META):
+        max_seq_len = max(seq_lens_cpu)
         return (
-            num_program(META, args),
+            len(seq_lens_cpu),  # batch_size
+            (max_seq_len + META["BLOCK_M"] - 1) // META["BLOCK_M"],
             triton.cdiv(dim, META["BLOCK_N"]),
         )
 
-    if batch_ptr.device != x.device:
-        batch_ptr = batch_ptr.to(x.device)
-        token_chunk_offset_ptr = token_chunk_offset_ptr.to(x.device)
-
     _causal_conv1d_fwd_kernel[grid](
         # Pointers to matrices
         x,
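Note on the new grid above: the launch moves from a 1D list of packed programs (batch_ptr / token_chunk_offset_ptr, sized by num_program) to a dense 3D grid, where program_id(0) is the sequence, program_id(1) the BLOCK_M chunk within the longest sequence, and program_id(2) the BLOCK_N feature tile. Shorter sequences get empty trailing chunks, which exit via the new `if segment_len <= 0: return` guard. A toy grid computation with assumed values:

    import triton

    seq_lens_cpu = [10, 3, 7]          # assumed per-sequence lengths
    dim, BLOCK_M, BLOCK_N = 64, 8, 32  # assumed tile sizes
    grid = (
        len(seq_lens_cpu),                             # one row of programs per sequence
        (max(seq_lens_cpu) + BLOCK_M - 1) // BLOCK_M,  # chunks of the longest sequence
        triton.cdiv(dim, BLOCK_N),                     # feature tiles
    )
    print(grid)  # (3, 2, 2) -- sequence 1 (len 3) no-ops in its second chunk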
@@ -597,11 +517,8 @@ def causal_conv1d_fn(
         cache_indices,
         has_initial_state,
         query_start_loc,
-        batch_ptr,
-        token_chunk_offset_ptr,
         out,
         # Matrix dimensions
-        padded_batch,
         dim,
         cu_seqlen,
         num_cache_lines,