sglang 0.5.1.post3__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (245)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/bench_one_batch_server.py +10 -1
  3. sglang/bench_serving.py +251 -26
  4. sglang/lang/interpreter.py +1 -1
  5. sglang/srt/configs/__init__.py +4 -0
  6. sglang/srt/configs/internvl.py +6 -0
  7. sglang/srt/configs/longcat_flash.py +104 -0
  8. sglang/srt/configs/model_config.py +37 -7
  9. sglang/srt/configs/qwen3_next.py +326 -0
  10. sglang/srt/connector/__init__.py +1 -1
  11. sglang/srt/connector/base_connector.py +1 -2
  12. sglang/srt/connector/redis.py +2 -2
  13. sglang/srt/connector/serde/__init__.py +1 -1
  14. sglang/srt/connector/serde/safe_serde.py +4 -3
  15. sglang/srt/custom_op.py +11 -1
  16. sglang/srt/debug_utils/dump_comparator.py +81 -44
  17. sglang/srt/debug_utils/dump_loader.py +97 -0
  18. sglang/srt/debug_utils/dumper.py +11 -3
  19. sglang/srt/debug_utils/text_comparator.py +73 -11
  20. sglang/srt/disaggregation/ascend/conn.py +75 -0
  21. sglang/srt/disaggregation/base/conn.py +1 -1
  22. sglang/srt/disaggregation/common/conn.py +15 -12
  23. sglang/srt/disaggregation/decode.py +6 -4
  24. sglang/srt/disaggregation/fake/conn.py +1 -1
  25. sglang/srt/disaggregation/mini_lb.py +6 -420
  26. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  27. sglang/srt/disaggregation/nixl/conn.py +180 -16
  28. sglang/srt/disaggregation/prefill.py +6 -4
  29. sglang/srt/disaggregation/utils.py +5 -50
  30. sglang/srt/distributed/parallel_state.py +94 -58
  31. sglang/srt/entrypoints/engine.py +34 -14
  32. sglang/srt/entrypoints/http_server.py +172 -47
  33. sglang/srt/entrypoints/openai/protocol.py +63 -3
  34. sglang/srt/entrypoints/openai/serving_base.py +6 -2
  35. sglang/srt/entrypoints/openai/serving_chat.py +34 -19
  36. sglang/srt/entrypoints/openai/serving_completions.py +10 -4
  37. sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
  38. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  39. sglang/srt/eplb/eplb_manager.py +28 -4
  40. sglang/srt/eplb/expert_distribution.py +55 -15
  41. sglang/srt/eplb/expert_location.py +8 -3
  42. sglang/srt/eplb/expert_location_updater.py +1 -1
  43. sglang/srt/function_call/ebnf_composer.py +11 -9
  44. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  45. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  46. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  47. sglang/srt/hf_transformers_utils.py +12 -0
  48. sglang/srt/layers/activation.py +44 -9
  49. sglang/srt/layers/attention/aiter_backend.py +93 -68
  50. sglang/srt/layers/attention/ascend_backend.py +250 -112
  51. sglang/srt/layers/attention/fla/chunk.py +242 -0
  52. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  53. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  54. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  55. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  56. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  57. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  58. sglang/srt/layers/attention/fla/index.py +37 -0
  59. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  60. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  61. sglang/srt/layers/attention/fla/op.py +66 -0
  62. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  63. sglang/srt/layers/attention/fla/utils.py +331 -0
  64. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  65. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  66. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  67. sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
  68. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
  69. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  70. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  71. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  72. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  73. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  74. sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
  75. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  76. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  77. sglang/srt/layers/communicator.py +45 -7
  78. sglang/srt/layers/layernorm.py +54 -12
  79. sglang/srt/layers/logits_processor.py +10 -3
  80. sglang/srt/layers/moe/__init__.py +2 -1
  81. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
  82. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  83. sglang/srt/layers/moe/ep_moe/layer.py +110 -49
  84. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  85. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  86. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  87. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  88. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
  89. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  90. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  91. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  92. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  93. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
  94. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  95. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
  96. sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
  97. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  98. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  99. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  100. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  101. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  102. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  103. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  104. sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
  105. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  106. sglang/srt/layers/moe/topk.py +43 -12
  107. sglang/srt/layers/moe/utils.py +6 -5
  108. sglang/srt/layers/quantization/awq.py +19 -7
  109. sglang/srt/layers/quantization/base_config.py +11 -6
  110. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  111. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  112. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  113. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
  114. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
  115. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  116. sglang/srt/layers/quantization/fp8.py +76 -47
  117. sglang/srt/layers/quantization/fp8_utils.py +43 -29
  118. sglang/srt/layers/quantization/gptq.py +25 -17
  119. sglang/srt/layers/quantization/modelopt_quant.py +107 -40
  120. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  121. sglang/srt/layers/quantization/mxfp4.py +77 -45
  122. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  123. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  124. sglang/srt/layers/quantization/quark/utils.py +97 -0
  125. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  126. sglang/srt/layers/quantization/unquant.py +135 -47
  127. sglang/srt/layers/quantization/utils.py +13 -0
  128. sglang/srt/layers/quantization/w4afp8.py +60 -42
  129. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  130. sglang/srt/layers/quantization/w8a8_int8.py +83 -41
  131. sglang/srt/layers/rocm_linear_utils.py +44 -0
  132. sglang/srt/layers/rotary_embedding.py +28 -19
  133. sglang/srt/layers/sampler.py +29 -5
  134. sglang/srt/lora/backend/base_backend.py +50 -8
  135. sglang/srt/lora/backend/triton_backend.py +90 -2
  136. sglang/srt/lora/layers.py +32 -0
  137. sglang/srt/lora/lora.py +4 -1
  138. sglang/srt/lora/lora_manager.py +35 -112
  139. sglang/srt/lora/mem_pool.py +24 -10
  140. sglang/srt/lora/utils.py +18 -9
  141. sglang/srt/managers/cache_controller.py +242 -278
  142. sglang/srt/managers/data_parallel_controller.py +30 -15
  143. sglang/srt/managers/detokenizer_manager.py +13 -2
  144. sglang/srt/managers/disagg_service.py +46 -0
  145. sglang/srt/managers/io_struct.py +160 -11
  146. sglang/srt/managers/mm_utils.py +6 -1
  147. sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
  148. sglang/srt/managers/schedule_batch.py +27 -44
  149. sglang/srt/managers/schedule_policy.py +4 -3
  150. sglang/srt/managers/scheduler.py +90 -115
  151. sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
  152. sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
  153. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  154. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  155. sglang/srt/managers/template_manager.py +3 -3
  156. sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
  157. sglang/srt/managers/tokenizer_manager.py +41 -477
  158. sglang/srt/managers/tp_worker.py +16 -4
  159. sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
  160. sglang/srt/mem_cache/allocator.py +1 -1
  161. sglang/srt/mem_cache/chunk_cache.py +1 -1
  162. sglang/srt/mem_cache/hicache_storage.py +24 -22
  163. sglang/srt/mem_cache/hiradix_cache.py +184 -101
  164. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  165. sglang/srt/mem_cache/memory_pool.py +324 -41
  166. sglang/srt/mem_cache/memory_pool_host.py +25 -18
  167. sglang/srt/mem_cache/radix_cache.py +5 -6
  168. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  169. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  170. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  171. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  172. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +149 -12
  173. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  174. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  175. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +74 -19
  176. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  177. sglang/srt/mem_cache/swa_radix_cache.py +1 -3
  178. sglang/srt/metrics/collector.py +484 -63
  179. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  180. sglang/srt/metrics/utils.py +48 -0
  181. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  182. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  183. sglang/srt/model_executor/forward_batch_info.py +72 -18
  184. sglang/srt/model_executor/model_runner.py +189 -31
  185. sglang/srt/model_loader/__init__.py +9 -3
  186. sglang/srt/model_loader/loader.py +33 -28
  187. sglang/srt/model_loader/utils.py +12 -0
  188. sglang/srt/model_loader/weight_utils.py +2 -1
  189. sglang/srt/models/deepseek_v2.py +311 -50
  190. sglang/srt/models/gemma3n_mm.py +1 -1
  191. sglang/srt/models/glm4_moe.py +10 -1
  192. sglang/srt/models/glm4v.py +4 -2
  193. sglang/srt/models/gpt_oss.py +5 -18
  194. sglang/srt/models/internvl.py +28 -0
  195. sglang/srt/models/llama4.py +9 -0
  196. sglang/srt/models/llama_eagle3.py +17 -0
  197. sglang/srt/models/longcat_flash.py +1026 -0
  198. sglang/srt/models/longcat_flash_nextn.py +699 -0
  199. sglang/srt/models/minicpmv.py +165 -3
  200. sglang/srt/models/mllama4.py +25 -0
  201. sglang/srt/models/opt.py +637 -0
  202. sglang/srt/models/qwen2.py +33 -3
  203. sglang/srt/models/qwen2_5_vl.py +90 -42
  204. sglang/srt/models/qwen2_moe.py +79 -14
  205. sglang/srt/models/qwen3.py +8 -2
  206. sglang/srt/models/qwen3_moe.py +39 -8
  207. sglang/srt/models/qwen3_next.py +1039 -0
  208. sglang/srt/models/qwen3_next_mtp.py +109 -0
  209. sglang/srt/models/torch_native_llama.py +1 -1
  210. sglang/srt/models/transformers.py +1 -1
  211. sglang/srt/multimodal/processors/base_processor.py +4 -2
  212. sglang/srt/multimodal/processors/glm4v.py +9 -9
  213. sglang/srt/multimodal/processors/internvl.py +141 -129
  214. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  215. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  216. sglang/srt/sampling/sampling_batch_info.py +18 -15
  217. sglang/srt/server_args.py +297 -79
  218. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  219. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  220. sglang/srt/speculative/eagle_worker.py +216 -120
  221. sglang/srt/speculative/spec_info.py +5 -0
  222. sglang/srt/speculative/standalone_worker.py +109 -0
  223. sglang/srt/utils.py +37 -2
  224. sglang/srt/weight_sync/utils.py +1 -1
  225. sglang/test/attention/test_trtllm_mla_backend.py +181 -8
  226. sglang/test/few_shot_gsm8k.py +1 -0
  227. sglang/test/runners.py +4 -0
  228. sglang/test/test_cutlass_moe.py +24 -6
  229. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  230. sglang/test/test_disaggregation_utils.py +66 -0
  231. sglang/test/test_utils.py +25 -1
  232. sglang/utils.py +5 -0
  233. sglang/version.py +1 -1
  234. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/METADATA +11 -9
  235. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/RECORD +243 -194
  236. sglang/srt/disaggregation/launch_lb.py +0 -131
  237. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  238. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  239. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  240. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  241. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  242. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  243. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
  244. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
  245. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,584 @@
1
+ from dataclasses import astuple, dataclass
2
+ from functools import lru_cache
3
+ from typing import Optional, Union
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+
8
+ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
9
+ from sglang.srt.layers.attention.fla.chunk import chunk_gated_delta_rule
10
+ from sglang.srt.layers.attention.fla.fused_recurrent import (
11
+ fused_recurrent_gated_delta_rule_update,
12
+ )
13
+ from sglang.srt.layers.attention.fla.fused_sigmoid_gating_recurrent import (
14
+ fused_sigmoid_gating_delta_rule_update,
15
+ )
16
+ from sglang.srt.layers.attention.mamba.causal_conv1d_triton import (
17
+ causal_conv1d_fn,
18
+ causal_conv1d_update,
19
+ )
20
+ from sglang.srt.layers.radix_attention import RadixAttention
21
+ from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool
22
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
23
+ from sglang.srt.model_executor.model_runner import ModelRunner
24
+ from sglang.srt.models.qwen3_next import Qwen3HybridLinearDecoderLayer, fused_gdn_gating
25
+ from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
26
+
27
+
28
@dataclass
class ForwardMetadata:
    """Per-forward-pass addressing metadata for the Mamba state caches.

    Attributes:
        query_start_loc: Cumulative per-request query offsets (passed to the
            kernels as ``cu_seqlens``); ``None`` until one of the
            ``init_forward_metadata*`` hooks has populated it.
        mamba_cache_indices: Slot index into the Mamba state pool for each
            request in the batch.
    """

    # None until metadata initialization runs for the current batch.
    query_start_loc: Union[torch.Tensor, None]
    mamba_cache_indices: torch.Tensor
32
+
33
+
34
class MambaAttnBackend(AttentionBackend):
    """Attention backend using Mamba kernel.

    Wraps the causal-conv1d and gated-delta-rule kernels behind the generic
    ``AttentionBackend`` interface.  Per-request conv/SSM states live in the
    ``HybridReqToTokenPool``; this backend only computes the addressing
    metadata (cumulative query offsets and cache slot indices) and invokes
    the fused kernels with it.
    """

    def __init__(self, model_runner: ModelRunner):
        super().__init__()
        self.pad_slot_id = -1  # Default pad slot id
        self.device = model_runner.device
        self.req_to_token_pool: HybridReqToTokenPool = model_runner.req_to_token_pool
        # Populated by init_forward_metadata() or the CUDA-graph hooks before
        # any forward_* call reads it.
        self.forward_metadata: Optional[ForwardMetadata] = None
        # Per-batch-size preallocated CUDA-graph buffers; filled lazily by
        # init_cuda_graph_state() (index bs-1 holds the buffers for batch bs).
        self.state_indices_list = []
        self.query_start_loc_list = []

    @classmethod
    @lru_cache(maxsize=128)
    def _get_cached_arange(cls, bs: int, device_str: str) -> torch.Tensor:
        """Cache torch.arange tensors for common batch sizes to avoid repeated allocation."""
        # NOTE: lru_cache on a classmethod is safe here — the key is
        # (cls, bs, device_str) and cls is a single class object, so nothing
        # per-instance is kept alive.
        device = torch.device(device_str)
        return torch.arange(0, bs + 1, dtype=torch.int32, device=device)

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        """Build ForwardMetadata (cu_seqlens + cache slots) for an eager batch."""
        bs = forward_batch.batch_size
        if forward_batch.forward_mode.is_decode_or_idle():
            # Decode: one token per request, so offsets are simply 0..bs.
            query_start_loc = self._get_cached_arange(bs, str(self.device))
        elif forward_batch.forward_mode.is_extend():
            if forward_batch.forward_mode.is_target_verify():
                # Target-verify: every request contributes exactly
                # draft_token_num tokens, so offsets form a uniform stride.
                query_start_loc = torch.arange(
                    0,
                    forward_batch.input_ids.shape[0] + 1,
                    step=forward_batch.spec_info.draft_token_num,
                    dtype=torch.int32,
                    device=forward_batch.input_ids.device,
                )
            else:
                # Regular extend: offsets come from the batch's extend
                # start locations, plus one final sentinel = total tokens.
                query_start_loc = torch.empty(
                    (bs + 1,), dtype=torch.int32, device=self.device
                )
                query_start_loc[:bs] = forward_batch.extend_start_loc
                query_start_loc[bs] = (
                    forward_batch.extend_start_loc[-1]
                    + forward_batch.extend_seq_lens[-1]
                )
        else:
            raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode=}")
        mamba_cache_indices = self.req_to_token_pool.get_mamba_indices(
            forward_batch.req_pool_indices
        )
        self.forward_metadata = ForwardMetadata(
            query_start_loc=query_start_loc,
            mamba_cache_indices=mamba_cache_indices,
        )

    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
        """Preallocate one (state_indices, query_start_loc) buffer pair per batch size.

        Buffer i serves batch size i+1; state indices are filled with the pad
        slot id so unused tail entries are inert.
        """
        for i in range(max_bs):
            self.state_indices_list.append(
                torch.full((i + 1,), self.pad_slot_id, dtype=torch.int32, device="cuda")
            )
            self.query_start_loc_list.append(
                torch.empty((i + 2,), dtype=torch.int32, device="cuda")
            )

    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        num_tokens: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
    ):
        """Fill the preallocated buffers in place for CUDA-graph capture.

        The graph records reads from these fixed buffers, so only copy_()
        writes into them are allowed (no reallocation).
        """
        if forward_mode.is_decode_or_idle():
            self.query_start_loc_list[bs - 1].copy_(self._get_cached_arange(bs, "cuda"))
        elif forward_mode.is_target_verify():
            self.query_start_loc_list[bs - 1].copy_(
                torch.arange(
                    0,
                    bs * spec_info.draft_token_num + 1,
                    step=spec_info.draft_token_num,
                    dtype=torch.int32,
                    device="cuda",
                )
            )
        else:
            raise ValueError(f"Invalid forward mode: {forward_mode=}")
        mamba_indices = self.req_to_token_pool.get_mamba_indices(req_pool_indices)
        self.state_indices_list[bs - 1][: len(mamba_indices)].copy_(mamba_indices)
        self.forward_metadata = ForwardMetadata(
            query_start_loc=self.query_start_loc_list[bs - 1],
            mamba_cache_indices=self.state_indices_list[bs - 1],
        )

    def init_forward_metadata_replay_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
        seq_lens_cpu: Optional[torch.Tensor],
    ):
        """Refresh the captured buffers before replaying a CUDA graph.

        Padding requests (those whose seq len equals the graph fill value)
        get slot index -1 and zero-length query spans so the kernels treat
        them as no-ops.
        """
        # Requests padded to reach the captured batch size carry the fill
        # value (1) as their seq len; count them from the tail.
        num_padding = torch.count_nonzero(
            seq_lens_cpu == self.get_cuda_graph_seq_len_fill_value()
        )
        # Make sure forward metadata is correctly handled for padding reqs
        # NOTE(review): this writes into the caller-provided req_pool_indices
        # tensor in place — confirm upstream does not reuse those tail entries.
        req_pool_indices[bs - num_padding :] = 0
        mamba_indices = self.req_to_token_pool.get_mamba_indices(req_pool_indices)
        mamba_indices[bs - num_padding :] = -1
        self.state_indices_list[bs - 1][: len(mamba_indices)].copy_(mamba_indices)
        if forward_mode.is_decode_or_idle():
            self.query_start_loc_list[bs - 1].copy_(self._get_cached_arange(bs, "cuda"))
            if num_padding > 0:
                # Clamp padded offsets so padded requests span zero tokens.
                self.query_start_loc_list[bs - 1][bs - num_padding :] = bs - num_padding
        elif forward_mode.is_target_verify():
            self.query_start_loc_list[bs - 1].copy_(
                torch.arange(
                    0,
                    bs * spec_info.draft_token_num + 1,
                    step=spec_info.draft_token_num,
                    dtype=torch.int32,
                    device="cuda",
                )
            )
            if num_padding > 0:
                # Same clamping, scaled by tokens-per-request for verify mode.
                self.query_start_loc_list[bs - 1][bs - num_padding :] = (
                    bs - num_padding
                ) * spec_info.draft_token_num
        else:
            raise ValueError(f"Invalid forward mode: {forward_mode=}")

        self.forward_metadata = ForwardMetadata(
            query_start_loc=self.query_start_loc_list[bs - 1],
            mamba_cache_indices=self.state_indices_list[bs - 1],
        )

    def get_cuda_graph_seq_len_fill_value(self):
        return 1  # Mamba attn does not use seq lens to index kv cache

    def forward_decode(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
        **kwargs,
    ):
        """Single-token decode via conv-state update + fused gated delta rule.

        Note: q/k/v from the generic interface are unused here; the real
        inputs (pre-projection mixed QKV, conv weights, gating params, etc.)
        arrive through kwargs from the Mamba layer.
        """
        mixed_qkv = kwargs["mixed_qkv"]
        conv_weights = kwargs["conv_weights"]
        bias = kwargs["bias"]
        activation = kwargs["activation"]
        key_dim = kwargs["key_dim"]
        value_dim = kwargs["value_dim"]
        attn_tp_size = kwargs["attention_tp_size"]
        head_k_dim = kwargs["head_k_dim"]
        head_v_dim = kwargs["head_v_dim"]
        a = kwargs["a"]
        b = kwargs["b"]
        A_log = kwargs["A_log"]
        dt_bias = kwargs["dt_bias"]
        layer_id = kwargs["layer_id"]

        conv_states, ssm_states, *rest = self.req_to_token_pool.get_mamba_params(
            layer_id
        )
        query_start_loc = self.forward_metadata.query_start_loc
        cache_indices = self.forward_metadata.mamba_cache_indices

        # Advance the per-request conv state by one token, in place.
        mixed_qkv = causal_conv1d_update(
            mixed_qkv,
            conv_states,
            conv_weights,
            bias,
            activation,
            conv_state_indices=cache_indices,
        )

        # Split the fused projection back into Q/K/V (per-TP-rank widths).
        query, key, value = torch.split(
            mixed_qkv,
            [
                key_dim // attn_tp_size,
                key_dim // attn_tp_size,
                value_dim // attn_tp_size,
            ],
            dim=-1,
        )
        # Reshape from [l, h*d] to [1, l, h, d]
        seq_len = query.shape[0]
        num_heads = query.shape[1] // head_k_dim
        query = query.view(1, seq_len, num_heads, head_k_dim)
        key = key.view(1, seq_len, num_heads, head_k_dim)
        value = value.view(1, seq_len, value.shape[1] // head_v_dim, head_v_dim)

        # Fused kernel: sigmoid gating + delta-rule recurrence; reads and
        # updates the SSM state slots addressed by cache_indices.
        core_attn_out = fused_sigmoid_gating_delta_rule_update(
            A_log=A_log,
            dt_bias=dt_bias,
            q=query,
            k=key,
            v=value,
            a=a,
            b=b,
            initial_state_source=ssm_states,
            initial_state_indices=cache_indices,
            cu_seqlens=query_start_loc,
            use_qk_l2norm_in_kernel=True,
            softplus_beta=1.0,
            softplus_threshold=20.0,
        )

        return core_attn_out

    def forward_extend(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
        **kwargs,
    ):
        """Prefill/extend (and speculative target-verify) path.

        As in forward_decode, q/k/v are unused; inputs come via kwargs.
        Target-verify runs on cloned conv state and defers SSM commits so a
        rejected speculation leaves the committed caches untouched
        (presumably — the commit happens later in
        update_mamba_state_after_mtp_verify).
        """
        mixed_qkv = kwargs["mixed_qkv"]
        conv_weights = kwargs["conv_weights"]
        bias = kwargs["bias"]
        activation = kwargs["activation"]
        key_dim = kwargs["key_dim"]
        value_dim = kwargs["value_dim"]
        attn_tp_size = kwargs["attention_tp_size"]
        head_k_dim = kwargs["head_k_dim"]
        head_v_dim = kwargs["head_v_dim"]
        a = kwargs["a"]
        b = kwargs["b"]
        A_log = kwargs["A_log"]
        dt_bias = kwargs["dt_bias"]
        layer_id = kwargs["layer_id"]
        seq_len = kwargs["seq_len"]

        is_target_verify = forward_batch.forward_mode.is_target_verify()

        query_start_loc = self.forward_metadata.query_start_loc
        cache_indices = self.forward_metadata.mamba_cache_indices

        if is_target_verify:
            (
                conv_states,
                ssm_states,
                intermediate_state_cache,
                intermediate_conv_window_cache,
            ) = self.req_to_token_pool.get_mamba_params(layer_id)
            has_initial_states = torch.ones(
                seq_len // forward_batch.spec_info.draft_token_num,
                dtype=torch.bool,
                device=forward_batch.input_ids.device,
            )
            # Work on a copy so the committed conv state survives rejection.
            conv_states_to_use = conv_states.clone()
        else:
            conv_states, ssm_states, *rest = self.req_to_token_pool.get_mamba_params(
                layer_id
            )
            # A request with a nonzero prefix already has conv/SSM state.
            has_initial_states = forward_batch.extend_prefix_lens > 0
            conv_states_to_use = conv_states

        if is_target_verify:
            # Verify path: reshape to (batch, channels, draft_tokens) for the
            # step-wise conv update that also records per-step conv windows.
            batch_size = seq_len // forward_batch.spec_info.draft_token_num
            draft_token_num = forward_batch.spec_info.draft_token_num
            mixed_qkv_reshaped = (
                mixed_qkv.view(batch_size, draft_token_num, -1)
                .transpose(1, 2)
                .contiguous()
            )
            mixed_qkv_processed = causal_conv1d_update(
                mixed_qkv_reshaped,
                conv_states_to_use,
                conv_weights,
                bias,
                activation,
                conv_state_indices=cache_indices[:batch_size],
                intermediate_conv_window=intermediate_conv_window_cache,
            )
            mixed_qkv = (
                mixed_qkv_processed.transpose(1, 2).contiguous().view(seq_len, -1)
            )
        else:
            # Prefill path: full causal conv over (channels, tokens); the
            # trailing slice drops any kernel-side padding beyond seq_len.
            mixed_qkv = causal_conv1d_fn(
                mixed_qkv.transpose(0, 1),
                conv_weights,
                bias,
                activation=activation,
                conv_states=conv_states_to_use,
                has_initial_state=has_initial_states,
                cache_indices=cache_indices,
                query_start_loc=query_start_loc,
            ).transpose(0, 1)[:seq_len]

        key_split_dim = key_dim // attn_tp_size
        value_split_dim = value_dim // attn_tp_size

        query, key, value = torch.split(
            mixed_qkv,
            [key_split_dim, key_split_dim, value_split_dim],
            dim=-1,
        )

        actual_seq_len = query.shape[0]
        num_heads = query.shape[1] // head_k_dim
        num_value_heads = value.shape[1] // head_v_dim

        query = query.view(1, actual_seq_len, num_heads, head_k_dim)
        key = key.view(1, actual_seq_len, num_heads, head_k_dim)
        value = value.view(1, actual_seq_len, num_value_heads, head_v_dim)

        # Gating terms: beta = sigmoid(b); g from the fused GDN gating helper.
        beta = b.sigmoid()
        g = fused_gdn_gating(A_log, a, dt_bias)

        g = g.unsqueeze(0)
        beta = beta.unsqueeze(0)

        if is_target_verify:
            # disable_state_update: do not commit SSM state yet; stash every
            # per-step state in intermediate_states_buffer for later commit.
            core_attn_out = fused_recurrent_gated_delta_rule_update(
                q=query,
                k=key,
                v=value,
                g=g,
                beta=beta,
                initial_state_source=ssm_states,
                initial_state_indices=cache_indices,
                cu_seqlens=query_start_loc,
                use_qk_l2norm_in_kernel=True,
                disable_state_update=True,
                intermediate_states_buffer=intermediate_state_cache,
                cache_steps=forward_batch.spec_info.draft_token_num,
            )
        else:
            # Prefill commits immediately: run chunked recurrence from the
            # gathered initial states and scatter the final states back.
            recurrent_state = ssm_states[cache_indices]
            core_attn_out, last_recurrent_state = chunk_gated_delta_rule(
                q=query,
                k=key,
                v=value,
                g=g,
                beta=beta,
                initial_state=recurrent_state,
                output_final_state=True,
                cu_seqlens=query_start_loc,
                head_first=False,
                use_qk_l2norm_in_kernel=True,
            )
            last_recurrent_state = last_recurrent_state.to(ssm_states.dtype, copy=False)
            ssm_states[cache_indices] = last_recurrent_state

        return core_attn_out
386
+
387
+
388
class HybridLinearAttnBackend(AttentionBackend):
    """Support different backends for prefill and decode.

    Routes each layer to one of two child backends by layer id:
    layers listed in ``full_attn_layers`` use the full-attention backend
    (index 0); every other layer uses the linear-attention backend
    (index 1).  Metadata/CUDA-graph hooks are broadcast to both children.
    """

    def __init__(
        self,
        full_attn_backend: AttentionBackend,
        linear_attn_backend: AttentionBackend,
        full_attn_layers: list[int],
    ):
        # NOTE(review): no super().__init__() call here, unlike
        # MambaAttnBackend above — confirm AttentionBackend.__init__ has
        # no required setup.
        self.full_attn_layers = full_attn_layers
        self.attn_backend_list = [full_attn_backend, linear_attn_backend]

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        # Both children build their own metadata for the same batch.
        for attn_backend in self.attn_backend_list:
            attn_backend.init_forward_metadata(forward_batch)

    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
        for attn_backend in self.attn_backend_list:
            attn_backend.init_cuda_graph_state(max_bs, max_num_tokens)

    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        num_tokens: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
    ):
        # Broadcast the capture hook to both children unchanged.
        for attn_backend in self.attn_backend_list:
            attn_backend.init_forward_metadata_capture_cuda_graph(
                bs,
                num_tokens,
                req_pool_indices,
                seq_lens,
                encoder_lens,
                forward_mode,
                spec_info,
            )

    def init_forward_metadata_replay_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
        seq_lens_cpu: Optional[torch.Tensor],
    ):
        # Broadcast the replay hook to both children unchanged.
        # NOTE(review): MambaAttnBackend's replay hook mutates
        # req_pool_indices in place, so the second child sees the first
        # child's writes — confirm this ordering is intended.
        for attn_backend in self.attn_backend_list:
            attn_backend.init_forward_metadata_replay_cuda_graph(
                bs,
                req_pool_indices,
                seq_lens,
                seq_lens_sum,
                encoder_lens,
                forward_mode,
                spec_info,
                seq_lens_cpu,
            )

    def get_cuda_graph_seq_len_fill_value(self):
        # Delegate to the full-attention child; both children are assumed
        # to agree on the fill value.
        return self.attn_backend_list[0].get_cuda_graph_seq_len_fill_value()

    def forward_decode(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
        **kwargs,
    ):
        """Dispatch decode to the child backend owning this layer id."""
        # layer may be None on the linear path; then the id comes via kwargs.
        layer_id = layer.layer_id if layer else kwargs["layer_id"]
        if layer_id in self.full_attn_layers:
            return self.attn_backend_list[0].forward_decode(
                q, k, v, layer, forward_batch, save_kv_cache, **kwargs
            )
        return self.attn_backend_list[1].forward_decode(
            q, k, v, layer, forward_batch, save_kv_cache, **kwargs
        )

    def forward_extend(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
        **kwargs,
    ):
        """Dispatch extend/prefill to the child backend owning this layer id."""
        layer_id = layer.layer_id if layer else kwargs["layer_id"]
        if layer_id in self.full_attn_layers:
            return self.attn_backend_list[0].forward_extend(
                q, k, v, layer, forward_batch, save_kv_cache, **kwargs
            )
        return self.attn_backend_list[1].forward_extend(
            q, k, v, layer, forward_batch, save_kv_cache, **kwargs
        )

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
        **kwargs,
    ):
        """Run forward on an attention layer."""
        if forward_batch.forward_mode.is_idle():
            # Idle batches produce correctly-shaped empty outputs without
            # touching any cache state.
            if layer is None:
                return torch.empty_like(kwargs["z"])
            return q.new_empty(q.shape[0], layer.tp_q_head_num * layer.v_head_dim)
        elif forward_batch.forward_mode.is_decode():
            return self.forward_decode(
                q,
                k,
                v,
                layer,
                forward_batch,
                save_kv_cache=save_kv_cache,
                **kwargs,
            )
        else:
            return self.forward_extend(
                q,
                k,
                v,
                layer,
                forward_batch,
                save_kv_cache=save_kv_cache,
                **kwargs,
            )

    def update_mamba_state_after_mtp_verify(self, accepted_length, model):
        """Commit speculative Mamba states for accepted draft tokens.

        After target-verify, copy each request's intermediate SSM state and
        conv window at its last accepted step into the committed caches.
        Requests with accepted_length == 0 are skipped.

        Args:
            accepted_length: per-request count of accepted draft tokens
                (1-D tensor, one entry per request).
            model: unused here; kept for interface compatibility.
        """
        request_number = accepted_length.shape[0]

        state_indices_tensor = self.attn_backend_list[
            1
        ].forward_metadata.mamba_cache_indices[:request_number]

        mamba_caches = self.attn_backend_list[
            1
        ].req_to_token_pool.get_mamba_params_all_layers()

        (
            conv_states,
            ssm_states,
            intermediate_state_cache,
            intermediate_conv_window_cache,
        ) = mamba_caches

        # SSM state updates (chunked to reduce peak memory)
        valid_mask = accepted_length > 0

        # Compute common indices once to avoid duplication
        # last accepted step per request (0-based into the cached steps)
        last_steps_all = (accepted_length - 1).to(torch.int64)
        valid_state_indices = state_indices_tensor[valid_mask].to(torch.int64)
        last_steps = last_steps_all[valid_mask].to(torch.int64)

        if valid_state_indices.numel() > 0:
            chunk = 256
            num_valid = valid_state_indices.numel()

            # SSM state updates
            # NOTE(review): the inner .item() loop synchronizes per element;
            # a batched index_copy_ could avoid this — left as-is pending
            # verification of the cache layouts.
            for i in range(0, num_valid, chunk):
                idx = valid_state_indices[i : i + chunk]
                steps = last_steps[i : i + chunk]
                # per (cache line, step)
                for j in range(idx.numel()):
                    ci = idx[j].item()
                    st = steps[j].item()
                    ssm_states[:, ci, :].copy_(
                        intermediate_state_cache[:, ci, st].to(
                            ssm_states.dtype, copy=False
                        )
                    )

            # Conv window updates
            for i in range(0, num_valid, chunk):
                idx = valid_state_indices[i : i + chunk]
                steps = last_steps[i : i + chunk]
                for j in range(idx.numel()):
                    ci = idx[j].item()
                    st = steps[j].item()
                    conv_states[:, ci, :, :].copy_(
                        intermediate_conv_window_cache[:, ci, st].to(
                            conv_states.dtype, copy=False
                        )
                    )
@@ -49,6 +49,9 @@ class IntelAMXAttnBackend(AttentionBackend):
49
49
  max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
50
50
  self.forward_metadata = (attn_logits, max_extend_len)
51
51
 
52
    def get_graph_seq_len_fill_value(self):
        """Padding value used for sequence lengths during graph capture.

        NOTE(review): other backends in this release expose this as
        ``get_cuda_graph_seq_len_fill_value`` — confirm this spelling matches
        what the (CPU) graph runner actually calls.
        """
        return 1
54
+
52
55
  def forward_extend(
53
56
  self,
54
57
  q,