sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (282)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +321 -31
  3. sglang/bench_serving.py +10 -3
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +4 -0
  11. sglang/srt/configs/dots_ocr.py +64 -0
  12. sglang/srt/configs/falcon_h1.py +360 -0
  13. sglang/srt/configs/load_config.py +8 -0
  14. sglang/srt/configs/model_config.py +160 -105
  15. sglang/srt/configs/qwen3_vl.py +586 -0
  16. sglang/srt/constrained/base_grammar_backend.py +1 -0
  17. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  18. sglang/srt/constrained/xgrammar_backend.py +6 -4
  19. sglang/srt/debug_utils/dumper.py +10 -3
  20. sglang/srt/disaggregation/ascend/conn.py +2 -2
  21. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  22. sglang/srt/disaggregation/common/conn.py +266 -98
  23. sglang/srt/disaggregation/decode.py +50 -9
  24. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  25. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  26. sglang/srt/disaggregation/mooncake/conn.py +51 -541
  27. sglang/srt/disaggregation/nixl/conn.py +148 -39
  28. sglang/srt/disaggregation/prefill.py +31 -14
  29. sglang/srt/disaggregation/utils.py +36 -5
  30. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  31. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  32. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  33. sglang/srt/distributed/parallel_state.py +135 -80
  34. sglang/srt/entrypoints/engine.py +23 -3
  35. sglang/srt/entrypoints/grpc_request_manager.py +330 -55
  36. sglang/srt/entrypoints/grpc_server.py +232 -102
  37. sglang/srt/entrypoints/http_server.py +49 -9
  38. sglang/srt/entrypoints/openai/protocol.py +110 -5
  39. sglang/srt/entrypoints/openai/serving_base.py +25 -6
  40. sglang/srt/entrypoints/openai/serving_chat.py +178 -49
  41. sglang/srt/entrypoints/openai/serving_completions.py +5 -3
  42. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  43. sglang/srt/entrypoints/openai/serving_responses.py +42 -0
  44. sglang/srt/environ.py +285 -0
  45. sglang/srt/eplb/expert_location.py +30 -5
  46. sglang/srt/function_call/function_call_parser.py +3 -2
  47. sglang/srt/function_call/glm4_moe_detector.py +3 -3
  48. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  49. sglang/srt/function_call/json_array_parser.py +63 -0
  50. sglang/srt/function_call/kimik2_detector.py +17 -4
  51. sglang/srt/function_call/utils.py +96 -5
  52. sglang/srt/grpc/compile_proto.py +245 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
  56. sglang/srt/layers/activation.py +7 -6
  57. sglang/srt/layers/attention/aiter_backend.py +14 -15
  58. sglang/srt/layers/attention/ascend_backend.py +108 -9
  59. sglang/srt/layers/attention/attention_registry.py +206 -0
  60. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  61. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  62. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  63. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  66. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  67. sglang/srt/layers/attention/flashinfer_backend.py +112 -194
  68. sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
  69. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  70. sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
  71. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
  72. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
  73. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
  74. sglang/srt/layers/attention/mamba/mamba.py +566 -1
  75. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  76. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  77. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  78. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  79. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  80. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  81. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  82. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  83. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  84. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  85. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  86. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  87. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  88. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  89. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  90. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  91. sglang/srt/layers/attention/nsa/utils.py +24 -0
  92. sglang/srt/layers/attention/nsa_backend.py +887 -0
  93. sglang/srt/layers/attention/tbo_backend.py +6 -6
  94. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  95. sglang/srt/layers/attention/triton_backend.py +42 -9
  96. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  97. sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
  98. sglang/srt/layers/attention/vision.py +58 -0
  99. sglang/srt/layers/attention/wave_backend.py +4 -4
  100. sglang/srt/layers/communicator.py +8 -0
  101. sglang/srt/layers/dp_attention.py +11 -1
  102. sglang/srt/layers/elementwise.py +3 -1
  103. sglang/srt/layers/layernorm.py +2 -0
  104. sglang/srt/layers/linear.py +21 -4
  105. sglang/srt/layers/logits_processor.py +15 -2
  106. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  107. sglang/srt/layers/moe/ep_moe/layer.py +147 -74
  108. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  109. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  110. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  111. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  112. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
  113. sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
  114. sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
  115. sglang/srt/layers/moe/utils.py +10 -0
  116. sglang/srt/layers/parameter.py +23 -6
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  119. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  121. sglang/srt/layers/quantization/fp8.py +2 -2
  122. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  123. sglang/srt/layers/quantization/modelopt_quant.py +44 -9
  124. sglang/srt/layers/quantization/mxfp4.py +12 -4
  125. sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
  126. sglang/srt/layers/quantization/w4afp8.py +0 -4
  127. sglang/srt/layers/quantization/w8a8_int8.py +15 -3
  128. sglang/srt/layers/rotary_embedding.py +78 -31
  129. sglang/srt/layers/sampler.py +52 -4
  130. sglang/srt/layers/utils.py +23 -0
  131. sglang/srt/lora/backend/base_backend.py +3 -3
  132. sglang/srt/lora/backend/chunked_backend.py +348 -0
  133. sglang/srt/lora/backend/triton_backend.py +10 -4
  134. sglang/srt/lora/lora.py +7 -5
  135. sglang/srt/lora/lora_manager.py +17 -6
  136. sglang/srt/lora/mem_pool.py +1 -1
  137. sglang/srt/lora/triton_ops/__init__.py +4 -0
  138. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  139. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  140. sglang/srt/lora/utils.py +7 -5
  141. sglang/srt/managers/cache_controller.py +42 -142
  142. sglang/srt/managers/data_parallel_controller.py +11 -46
  143. sglang/srt/managers/detokenizer_manager.py +11 -11
  144. sglang/srt/managers/io_struct.py +162 -118
  145. sglang/srt/managers/mm_utils.py +43 -6
  146. sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
  147. sglang/srt/managers/multimodal_processor.py +1 -2
  148. sglang/srt/managers/overlap_utils.py +53 -0
  149. sglang/srt/managers/schedule_batch.py +167 -86
  150. sglang/srt/managers/schedule_policy.py +143 -16
  151. sglang/srt/managers/scheduler.py +359 -214
  152. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  153. sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
  154. sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
  155. sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
  156. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  157. sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
  158. sglang/srt/managers/tokenizer_manager.py +84 -136
  159. sglang/srt/managers/tp_worker.py +39 -29
  160. sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
  161. sglang/srt/managers/utils.py +1 -45
  162. sglang/srt/mem_cache/allocator.py +14 -20
  163. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  164. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  165. sglang/srt/mem_cache/chunk_cache.py +8 -1
  166. sglang/srt/mem_cache/evict_policy.py +23 -0
  167. sglang/srt/mem_cache/hicache_storage.py +40 -1
  168. sglang/srt/mem_cache/hiradix_cache.py +119 -32
  169. sglang/srt/mem_cache/memory_pool.py +188 -10
  170. sglang/srt/mem_cache/memory_pool_host.py +134 -182
  171. sglang/srt/mem_cache/radix_cache.py +222 -71
  172. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  173. sglang/srt/mem_cache/storage/__init__.py +10 -0
  174. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  175. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  176. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  177. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  178. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  179. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
  180. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
  181. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
  182. sglang/srt/mem_cache/swa_radix_cache.py +25 -34
  183. sglang/srt/metrics/collector.py +82 -120
  184. sglang/srt/metrics/func_timer.py +2 -7
  185. sglang/srt/metrics/utils.py +8 -1
  186. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  187. sglang/srt/model_executor/cuda_graph_runner.py +39 -32
  188. sglang/srt/model_executor/forward_batch_info.py +23 -38
  189. sglang/srt/model_executor/model_runner.py +131 -183
  190. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  191. sglang/srt/model_loader/loader.py +14 -10
  192. sglang/srt/model_loader/weight_utils.py +156 -2
  193. sglang/srt/models/bailing_moe.py +27 -4
  194. sglang/srt/models/deepseek_nextn.py +6 -1
  195. sglang/srt/models/deepseek_v2.py +536 -153
  196. sglang/srt/models/dots_ocr.py +173 -0
  197. sglang/srt/models/falcon_h1.py +576 -0
  198. sglang/srt/models/gemma3_causal.py +0 -2
  199. sglang/srt/models/gemma3_mm.py +1 -1
  200. sglang/srt/models/gemma3n_mm.py +1 -1
  201. sglang/srt/models/glm4_moe.py +3 -3
  202. sglang/srt/models/glm4_moe_nextn.py +2 -2
  203. sglang/srt/models/glm4v.py +1 -1
  204. sglang/srt/models/glm4v_moe.py +1 -1
  205. sglang/srt/models/gpt_oss.py +7 -30
  206. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  207. sglang/srt/models/llama.py +4 -0
  208. sglang/srt/models/longcat_flash.py +1 -1
  209. sglang/srt/models/longcat_flash_nextn.py +1 -1
  210. sglang/srt/models/mllama4.py +15 -4
  211. sglang/srt/models/qwen2.py +0 -7
  212. sglang/srt/models/qwen2_5_vl.py +2 -2
  213. sglang/srt/models/qwen2_audio.py +1 -1
  214. sglang/srt/models/qwen2_moe.py +64 -1
  215. sglang/srt/models/qwen2_vl.py +1 -1
  216. sglang/srt/models/qwen3.py +18 -3
  217. sglang/srt/models/qwen3_moe.py +31 -3
  218. sglang/srt/models/qwen3_next.py +36 -9
  219. sglang/srt/models/qwen3_vl.py +787 -0
  220. sglang/srt/models/qwen3_vl_moe.py +471 -0
  221. sglang/srt/models/registry.py +15 -3
  222. sglang/srt/models/sarashina2_vision.py +269 -0
  223. sglang/srt/models/solar.py +505 -0
  224. sglang/srt/models/starcoder2.py +357 -0
  225. sglang/srt/models/torch_native_llama.py +9 -2
  226. sglang/srt/models/utils.py +51 -0
  227. sglang/srt/multimodal/processors/base_processor.py +15 -7
  228. sglang/srt/multimodal/processors/dots_vlm.py +2 -3
  229. sglang/srt/multimodal/processors/internvl.py +20 -8
  230. sglang/srt/multimodal/processors/qwen_vl.py +8 -1
  231. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  232. sglang/srt/parser/jinja_template_utils.py +6 -0
  233. sglang/srt/sampling/sampling_batch_info.py +20 -2
  234. sglang/srt/sampling/sampling_params.py +7 -0
  235. sglang/srt/server_args.py +753 -295
  236. sglang/srt/server_args_config_parser.py +146 -0
  237. sglang/srt/single_batch_overlap.py +151 -0
  238. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  239. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  240. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  241. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  242. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  243. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  244. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
  245. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
  246. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
  247. sglang/srt/speculative/eagle_worker.py +57 -25
  248. sglang/srt/speculative/ngram_utils.py +428 -0
  249. sglang/srt/speculative/ngram_worker.py +245 -0
  250. sglang/srt/speculative/spec_info.py +47 -0
  251. sglang/srt/speculative/spec_utils.py +606 -0
  252. sglang/srt/torch_memory_saver_adapter.py +5 -7
  253. sglang/srt/tracing/trace.py +32 -6
  254. sglang/srt/two_batch_overlap.py +8 -5
  255. sglang/srt/utils/__init__.py +2 -0
  256. sglang/srt/{utils.py → utils/common.py} +399 -74
  257. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
  258. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  259. sglang/srt/utils/rpd_utils.py +452 -0
  260. sglang/srt/utils/slow_rank_detector.py +71 -0
  261. sglang/srt/warmup.py +8 -4
  262. sglang/srt/weight_sync/utils.py +1 -1
  263. sglang/test/get_logits_ut.py +57 -0
  264. sglang/test/run_eval.py +79 -11
  265. sglang/test/runners.py +1 -1
  266. sglang/test/simple_eval_common.py +5 -2
  267. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  268. sglang/test/test_block_fp8.py +2 -2
  269. sglang/test/test_deterministic.py +297 -0
  270. sglang/test/test_disaggregation_utils.py +12 -1
  271. sglang/test/test_programs.py +1 -1
  272. sglang/test/test_utils.py +355 -4
  273. sglang/utils.py +10 -1
  274. sglang/version.py +1 -1
  275. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
  276. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
  277. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  278. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  279. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  280. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  281. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  282. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py (added)
@@ -0,0 +1,622 @@
+ # Adapted from: https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py
+
+ # SPDX-License-Identifier: Apache-2.0
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+ # Copyright (c) 2024, Tri Dao, Albert Gu.
+ # Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_chunk_scan.py
+
+ # ruff: noqa: E501,SIM102
+
+ import torch
+ import triton
+ import triton.language as tl
+ from packaging import version
+
+ TRITON_22 = version.parse(triton.__version__) >= version.parse("2.2.0")
+
+
+ # @triton.autotune(
+ #     configs=[
+ #         triton.Config(
+ #             {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
+ #             num_stages=3,
+ #             num_warps=8,
+ #         ),
+ #         triton.Config(
+ #             {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
+ #             num_stages=4,
+ #             num_warps=4,
+ #         ),
+ #         triton.Config(
+ #             {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
+ #             num_stages=4,
+ #             num_warps=4,
+ #         ),
+ #         triton.Config(
+ #             {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
+ #             num_stages=4,
+ #             num_warps=4,
+ #         ),
+ #         triton.Config(
+ #             {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
+ #             num_stages=4,
+ #             num_warps=4,
+ #         ),
+ #         triton.Config(
+ #             {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64},
+ #             num_stages=4,
+ #             num_warps=4,
+ #         ),
+ #         triton.Config(
+ #             {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64},
+ #             num_stages=4,
+ #             num_warps=4,
+ #         ),
+ #         triton.Config(
+ #             {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
+ #             num_stages=4,
+ #             num_warps=4,
+ #         ),
+ #         triton.Config(
+ #             {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
+ #             num_stages=5,
+ #             num_warps=2,
+ #         ),
+ #         triton.Config(
+ #             {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
+ #             num_stages=5,
+ #             num_warps=2,
+ #         ),
+ #         triton.Config(
+ #             {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
+ #             num_stages=4,
+ #             num_warps=2,
+ #         ),
+ #     ],
+ #     key=["chunk_size", "hdim", "dstate", "IS_CAUSAL"],
+ # )
+ @triton.jit
+ def _chunk_scan_fwd_kernel(
+     # Pointers to matrices
+     cb_ptr,
+     x_ptr,
+     z_ptr,
+     out_ptr,
+     out_x_ptr,
+     dt_ptr,
+     dA_cumsum_ptr,
+     seq_idx_ptr,
+     C_ptr,
+     states_ptr,
+     D_ptr,
+     initstates_ptr,
+     chunk_indices_ptr,
+     chunk_offsets_ptr,
+     chunk_meta_num,
+     # Matrix dimensions
+     chunk_size,
+     hdim,
+     dstate,
+     batch,
+     seqlen,
+     nheads_ngroups_ratio,
+     # Strides
+     stride_cb_batch,
+     stride_cb_chunk,
+     stride_cb_head,
+     stride_cb_csize_m,
+     stride_cb_csize_k,
+     stride_x_batch,
+     stride_x_seqlen,
+     stride_x_head,
+     stride_x_hdim,
+     stride_z_batch,
+     stride_z_seqlen,
+     stride_z_head,
+     stride_z_hdim,
+     stride_out_batch,
+     stride_out_seqlen,
+     stride_out_head,
+     stride_out_hdim,
+     stride_dt_batch,
+     stride_dt_chunk,
+     stride_dt_head,
+     stride_dt_csize,
+     stride_dA_cs_batch,
+     stride_dA_cs_chunk,
+     stride_dA_cs_head,
+     stride_dA_cs_csize,
+     stride_seq_idx_batch,
+     stride_seq_idx_seqlen,
+     stride_C_batch,
+     stride_C_seqlen,
+     stride_C_head,
+     stride_C_dstate,
+     stride_states_batch,
+     stride_states_chunk,
+     stride_states_head,
+     stride_states_hdim,
+     stride_states_dstate,
+     stride_init_states_batch,
+     stride_init_states_head,
+     stride_init_states_hdim,
+     stride_init_states_dstate,
+     stride_D_head,
+     # Meta-parameters
+     IS_CAUSAL: tl.constexpr,
+     HAS_D: tl.constexpr,
+     D_HAS_HDIM: tl.constexpr,
+     HAS_Z: tl.constexpr,
+     HAS_SEQ_IDX: tl.constexpr,
+     BLOCK_SIZE_DSTATE: tl.constexpr,
+     IS_TRITON_22: tl.constexpr,
+     HAS_INITSTATES: tl.constexpr,
+     BLOCK_SIZE_M: tl.constexpr = 16,
+     BLOCK_SIZE_N: tl.constexpr = 16,
+     BLOCK_SIZE_K: tl.constexpr = 16,
+ ):
+     pid_bc = tl.program_id(axis=1).to(tl.int64)
+     pid_c = pid_bc // batch
+     pid_b = pid_bc - pid_c * batch
+     if not HAS_INITSTATES:
+         c_idx = pid_c
+         c_off = 0
+     else:
+         c_idx = tl.load(chunk_indices_ptr + pid_c, mask=pid_c > -1, other=0)
+         c_off = tl.load(chunk_offsets_ptr + pid_c, mask=pid_c > -1, other=0)
+
+     pid_h = tl.program_id(axis=2)
+     num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
+     pid_m = tl.program_id(axis=0) // num_pid_n
+     pid_n = tl.program_id(axis=0) % num_pid_n
+     cb_ptr += (
+         pid_b * stride_cb_batch
+         + c_idx * stride_cb_chunk
+         + (pid_h // nheads_ngroups_ratio) * stride_cb_head
+     )
+     x_ptr += (
+         pid_b * stride_x_batch
+         + c_idx * chunk_size * stride_x_seqlen
+         + pid_h * stride_x_head
+     )
+     dt_ptr += pid_b * stride_dt_batch + c_idx * stride_dt_chunk + pid_h * stride_dt_head
+     dA_cumsum_ptr += (
+         pid_b * stride_dA_cs_batch
+         + c_idx * stride_dA_cs_chunk
+         + pid_h * stride_dA_cs_head
+     )
+     C_ptr += (
+         pid_b * stride_C_batch
+         + c_idx * chunk_size * stride_C_seqlen
+         + (pid_h // nheads_ngroups_ratio) * stride_C_head
+     )
+
+     # M-block offsets and prev states
+     # - logic in next block may override these if there is an active offset
+     offs_m = pid_m * BLOCK_SIZE_M + c_off + tl.arange(0, BLOCK_SIZE_M)
+     prev_states_ptr = (
+         states_ptr
+         + pid_b * stride_states_batch
+         + c_idx * stride_states_chunk
+         + pid_h * stride_states_head
+     )
+     prev_states_hdim = stride_states_hdim
+     prev_states_dstate = stride_states_dstate
+
+     chunk_size_limit = min(chunk_size, seqlen - c_idx * chunk_size)
+     if HAS_SEQ_IDX:
+         seq_idx_ptr += (
+             pid_b * stride_seq_idx_batch + c_idx * chunk_size * stride_seq_idx_seqlen
+         )
+
+         # - we only need seq_idx_prev to be aligned to chunk boundary
+         seq_idx_prev = tl.load(
+             seq_idx_ptr - stride_seq_idx_seqlen, mask=c_idx >= 1, other=0
+         )
+
+         if HAS_INITSTATES:
+             # if there are init states, we only need seq_idx_m to point
+             # to the current seq_idx
+
+             # get current seq idx
+             if (pid_m * BLOCK_SIZE_M + c_off) < chunk_size_limit:
+                 seq_idx_m = tl.load(
+                     seq_idx_ptr
+                     + (pid_m * BLOCK_SIZE_M + c_off) * stride_seq_idx_seqlen,
+                 )
+
+                 # - recall that in ssd_state_passing, for the case c_off == 0
+                 # i.e., the very first sequence, we made states_ptr hold its initial state
+                 # so this edge case is taken care of
+                 if (
+                     (c_off == 0)
+                     and (
+                         seq_idx_prev != seq_idx_m
+                     )  # if a seq is changed exactly on boundary
+                     or (c_off > 0)  # implies a new example (pseudo chunk)
+                 ):
+
+                     # - replace prev_states_ptr with init_states
+                     prev_states_ptr = (
+                         initstates_ptr
+                         + seq_idx_m * stride_init_states_batch
+                         + pid_h * stride_init_states_head
+                     )
+                     prev_states_hdim = stride_init_states_hdim  # override strides
+                     prev_states_dstate = stride_init_states_dstate
+
+     offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+     dA_cs_m = tl.load(
+         dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0
+     ).to(tl.float32)
+
+     # - handle chunk state limit
+     if HAS_INITSTATES:
+
+         # have to split this `if`, otherwise compilation will have problems
+         dA_cs_m_boundary = 0.0
+
+         # get the c_idx for the next (logical) chunk
+         c_idx_n = tl.load(
+             chunk_indices_ptr + (pid_c + 1),
+             mask=pid_c > -1 and (pid_c + 1) < chunk_meta_num,
+             other=-1,  # to trigger different chunk
+         )
+
+         # - there are things to consider
+         # A. if c_off > 0 then we need to move the dA_cs boundary to ensure correct
+         #    contribution of past states
+         # B. if c_off_n < chunk_size_limit, then we need to adjust this so as not to
+         #    encroach into the next sequence, where c_off_n is the offset of the next
+         #    (logical) chunk.
+         # An equivalent check for B is c_idx == c_idx_n, where there is repetition in
+         # (logical) chunk indices.
+
+         if (c_idx == c_idx_n) or c_off > 0:
+
+             # get the next offset
+             c_off_n = tl.load(
+                 chunk_offsets_ptr + (pid_c + 1),
+                 mask=pid_c > -1 and (pid_c + 1) < chunk_meta_num,
+                 other=chunk_size,
+             )
+
+             # in this case, adjust down the chunk_size_limit
+             if c_idx == c_idx_n:
+                 chunk_size_limit = min(c_off_n, chunk_size_limit)
+
+             # get the cs at the offset boundary
+             # - c_off == 0 is a passthrough
+             # - We need dA_cs at the boundary, defined by c_off - no need
+             #   to increase pointer by pid_m (it is a constant offset,
+             #   i.e. the same for all blocks)
+             dA_cs_m_boundary = tl.load(
+                 dA_cumsum_ptr + (c_off - 1) * stride_dA_cs_csize,
+                 mask=(((c_off - 1) > -1) and ((c_off) < chunk_size)),
+                 other=0.0,
+             ).to(tl.float32)
+
+     if HAS_SEQ_IDX:
+         # - handle seq idx when HAS_INITSTATES==False
+         if not HAS_INITSTATES:
+             seq_idx_m = tl.load(
+                 seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
+                 mask=offs_m < chunk_size_limit,
+                 other=-1,
+             )
+
+     acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+
+     # Without the if (pid_c > -1), with Triton 2.1.0, I get
+     # Assertion `!(srcMmaLayout && dstMmaLayout) && "Unexpected mma -> mma layout conversion"' failed.
+     # With Triton 2.2.0, this works
+     if IS_TRITON_22 or c_idx > -1:
+         # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
+         offs_k_dstate = tl.arange(
+             0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K
+         )
+         C_ptrs = C_ptr + (
+             offs_m[:, None] * stride_C_seqlen + offs_k_dstate[None, :] * stride_C_dstate
+         )
+
+         prev_states_ptrs = prev_states_ptr + (
+             offs_n[None, :] * prev_states_hdim
+             + offs_k_dstate[:, None] * prev_states_dstate
+         )
+         if HAS_SEQ_IDX:
+
+             if not HAS_INITSTATES:
+                 # - this is for continuous batching where there is no init states
+                 scale_m = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0)
+             else:
+                 # - if there is initstates, we will rely on prev_states, no zeroing
+                 #   required.
+                 scale_m = tl.exp(dA_cs_m - dA_cs_m_boundary)
+         else:
+             scale_m = tl.exp(dA_cs_m)
+         if BLOCK_SIZE_DSTATE <= 128:
+             C = tl.load(
+                 C_ptrs,
+                 mask=(offs_m[:, None] < chunk_size_limit)
+                 & (offs_k_dstate[None, :] < dstate),
+                 other=0.0,
+             )
+
+             prev_states = tl.load(
+                 prev_states_ptrs,
+                 mask=(offs_k_dstate[:, None] < dstate) & (offs_n[None, :] < hdim),
+                 other=0.0,
+             )
+             prev_states = prev_states.to(C_ptr.dtype.element_ty)
+             acc = tl.dot(C, prev_states) * scale_m[:, None]
+         else:
+             for k in range(0, dstate, BLOCK_SIZE_K):
+                 C = tl.load(
+                     C_ptrs,
+                     mask=(offs_m[:, None] < chunk_size_limit)
+                     & (offs_k_dstate[None, :] < dstate - k),
+                     other=0.0,
+                 )
+                 # C = (C * scale_m[:, None]).to(C_ptr.dtype.element_ty)
+                 prev_states = tl.load(
+                     prev_states_ptrs,
+                     mask=(offs_k_dstate[:, None] < dstate - k)
+                     & (offs_n[None, :] < hdim),
+                     other=0.0,
+                 )
+                 prev_states = prev_states.to(C_ptr.dtype.element_ty)
+                 acc += tl.dot(C, prev_states)
+                 C_ptrs += BLOCK_SIZE_K
+                 prev_states_ptrs += BLOCK_SIZE_K
+             acc *= scale_m[:, None]
+
+     offs_k = tl.arange(0, BLOCK_SIZE_K) + c_off
+     cb_ptrs = cb_ptr + (
+         offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k
+     )
+     x_ptrs = x_ptr + (
+         offs_k[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim
+     )
+     dt_ptrs = dt_ptr + offs_k * stride_dt_csize
+     dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
+     K_MAX = (
+         chunk_size_limit
+         if not IS_CAUSAL
+         else min((pid_m + 1) * BLOCK_SIZE_M, chunk_size_limit)
+     )
+     for k in range(0, K_MAX, BLOCK_SIZE_K):
+         cb = tl.load(
+             cb_ptrs,
+             mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < chunk_size - k),
+             other=0.0,
+         ).to(tl.float32)
+         dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(
+             tl.float32
+         )
+         # If there's seq_idx, we already set cb[i, j] = 0 for seq_idx[i] != seq_idx[j].
+         # So we don't need masking wrt seq_idx here.
+         cb *= tl.exp(dA_cs_m[:, None] - dA_cs_k[None, :])
+         dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32)
+         cb *= dt_k
+         if IS_CAUSAL:
+             mask = offs_m[:, None] >= k + offs_k[None, :]
+             cb = tl.where(mask, cb, 0.0)
+         cb = cb.to(x_ptr.dtype.element_ty)
+         x = tl.load(
+             x_ptrs,
+             mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < hdim),
+             other=0.0,
+         )
+         acc += tl.dot(cb, x)
+         cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k
+         x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
+         dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
+         dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
+
+     offs_out_m = pid_m * BLOCK_SIZE_M + c_off + tl.arange(0, BLOCK_SIZE_M)
+     offs_out_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+
+     if HAS_D:
+         if D_HAS_HDIM:
+             D = tl.load(
+                 D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0
+             ).to(tl.float32)
+         else:
+             D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)
+         x_residual = tl.load(
+             x_ptr
+             + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim),
+             mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
+             other=0.0,
+         ).to(tl.float32)
+         acc += x_residual * D
+
+     if HAS_Z:
+         out_x_ptr += (
+             pid_b * stride_out_batch
+             + c_idx * chunk_size * stride_out_seqlen
+             + pid_h * stride_out_head
+         )
+         out_x_ptrs = out_x_ptr + (
+             stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :]
+         )
+         tl.store(
+             out_x_ptrs,
+             acc,
+             mask=(offs_out_m[:, None] < chunk_size_limit)
+             & (offs_out_n[None, :] < hdim),
+         )
+
+         z_ptr += (
+             pid_b * stride_z_batch
+             + c_idx * chunk_size * stride_z_seqlen
+             + pid_h * stride_z_head
+         )
+         z_ptrs = z_ptr + (
+             stride_z_seqlen * offs_out_m[:, None] + stride_z_hdim * offs_out_n[None, :]
+         )
+         z = tl.load(
+             z_ptrs,
+             mask=(offs_out_m[:, None] < chunk_size_limit)
+             & (offs_out_n[None, :] < hdim),
+             other=0.0,
+         ).to(tl.float32)
+         acc *= z * tl.sigmoid(z)
+
+     out_ptr += (
+         pid_b * stride_out_batch
+         + c_idx * chunk_size * stride_out_seqlen
+         + pid_h * stride_out_head
+     )
+     out_ptrs = out_ptr + (
+         stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :] * stride_out_hdim
+     )
+     tl.store(
+         out_ptrs,
+         acc,
+         mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim),
+     )
+
+
+ def _chunk_scan_fwd(
+     cb,
+     x,
+     dt,
+     dA_cumsum,
+     C,
+     states,
+     D=None,
+     z=None,
+     seq_idx=None,
+     chunk_indices=None,
+     chunk_offsets=None,
+     initial_states=None,
+     out=None,
+ ):
+     batch, seqlen, nheads, headdim = x.shape
+     _, _, nchunks, chunk_size = dt.shape
+     _, _, ngroups, dstate = C.shape
+     assert nheads % ngroups == 0
+     assert C.shape == (batch, seqlen, ngroups, dstate)
+     assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)
+     if z is not None:
+         assert z.shape == x.shape
+     if D is not None:
+         assert D.shape == (nheads, headdim) or D.shape == (nheads,)
+     assert dt.shape == (batch, nheads, nchunks, chunk_size)
+     assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
+     assert states.shape == (batch, nchunks, nheads, headdim, dstate)
+
+     if seq_idx is not None:
+         assert seq_idx.shape == (batch, seqlen)
+
+         if initial_states is not None:
+             # with initial states, we need to take care of how
+             # seq_idx crosses the boundaries
+             assert batch == 1, "chunk scan only supports initial states with batch 1"
+             assert (
+                 chunk_indices is not None and chunk_offsets is not None
+             ), "chunk_indices and chunk_offsets should have been set"
+         else:
+             chunk_indices, chunk_offsets = None, None
+     else:
+         chunk_indices, chunk_offsets = None, None
+
+     assert out.shape == x.shape
+
+     if z is not None:
+         out_x = torch.empty_like(x)
+         assert out_x.stride() == out.stride()
+     else:
+         out_x = None
+
+     grid = lambda META: (
+         triton.cdiv(chunk_size, META["BLOCK_SIZE_M"])
+         * triton.cdiv(headdim, META["BLOCK_SIZE_N"]),
+         batch * nchunks if chunk_offsets is None else len(chunk_offsets),
+         nheads,
+     )
+     z_strides = (
+         (z.stride(0), z.stride(1), z.stride(2), z.stride(3))
+         if z is not None
+         else (0, 0, 0, 0)
+     )
+     _chunk_scan_fwd_kernel[grid](
+         cb,
+         x,
+         z,
+         out,
+         out_x,
+         dt,
+         dA_cumsum,
+         seq_idx,
+         C,
+         states,
+         D,
+         initial_states,
+         chunk_indices,
+         chunk_offsets,
+         len(chunk_indices) if chunk_indices is not None else 0,
+         chunk_size,
+         headdim,
+         dstate,
+         batch,
+         seqlen,
+         nheads // ngroups,
+         cb.stride(0),
+         cb.stride(1),
+         cb.stride(2),
+         cb.stride(3),
+         cb.stride(4),
+         x.stride(0),
+         x.stride(1),
+         x.stride(2),
+         x.stride(3),
+         z_strides[0],
+         z_strides[1],
+         z_strides[2],
+         z_strides[3],
+         out.stride(0),
+         out.stride(1),
+         out.stride(2),
+         out.stride(3),
+         dt.stride(0),
+         dt.stride(2),
+         dt.stride(1),
+         dt.stride(3),
+         dA_cumsum.stride(0),
+         dA_cumsum.stride(2),
+         dA_cumsum.stride(1),
+         dA_cumsum.stride(3),
+         *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
+         C.stride(0),
+         C.stride(1),
+         C.stride(2),
+         C.stride(3),
+         states.stride(0),
+         states.stride(1),
+         states.stride(2),
+         states.stride(3),
+         states.stride(4),
+         *(
+             (
+                 initial_states.stride(0),
+                 initial_states.stride(1),
+                 initial_states.stride(2),
+                 initial_states.stride(3),
+             )
+             if initial_states is not None
+             else (0, 0, 0, 0)
+         ),
+         D.stride(0) if D is not None else 0,
+         True,
+         D is not None,
+         D.dim() == 2 if D is not None else True,
+         BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
+         HAS_Z=z is not None,
+         HAS_SEQ_IDX=seq_idx is not None,
+         IS_TRITON_22=TRITON_22,
+         HAS_INITSTATES=initial_states is not None,
+     )
+     return out_x
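
For orientation (this note is not part of the package diff): per output tile, the kernel above computes acc = scale_m * (C @ prev_states) + causal(cb * exp(dA_cs_m[:, None] - dA_cs_k[None, :]) * dt) @ x, where scale_m = exp(dA_cumsum), then optionally adds the D * x residual and applies z * sigmoid(z) gating, exactly as the code reads. Below is a minimal smoke test for the Python wrapper — an illustrative sketch only, assuming a CUDA device with Triton installed; the shapes follow the asserts in _chunk_scan_fwd, and the module path follows file #80 in the list above.

    import torch

    from sglang.srt.layers.attention.mamba.ops.ssd_chunk_scan import _chunk_scan_fwd

    # Toy sizes; shapes mirror the asserts in _chunk_scan_fwd.
    batch, seqlen, nheads, headdim = 1, 256, 8, 64
    ngroups, dstate, chunk_size = 1, 16, 128
    nchunks = seqlen // chunk_size

    dev, dtype = "cuda", torch.float32
    cb = torch.randn(batch, nchunks, ngroups, chunk_size, chunk_size, device=dev, dtype=dtype)
    x = torch.randn(batch, seqlen, nheads, headdim, device=dev, dtype=dtype)
    dt = torch.rand(batch, nheads, nchunks, chunk_size, device=dev, dtype=dtype)
    dA_cumsum = torch.cumsum(-dt, dim=-1)  # a decaying cumsum stands in for real SSD inputs
    C = torch.randn(batch, seqlen, ngroups, dstate, device=dev, dtype=dtype)
    states = torch.randn(batch, nchunks, nheads, headdim, dstate, device=dev, dtype=dtype)
    out = torch.empty_like(x)

    out_x = _chunk_scan_fwd(cb, x, dt, dA_cumsum, C, states, out=out)
    assert out_x is None  # out_x is only materialized when the z gate is passed
    print(out.shape)      # torch.Size([1, 256, 8, 64])

Passing z (same shape as x) enables the z * sigmoid(z) gating path, in which case the function also returns out_x, the pre-gating output.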