sglang-0.5.2rc2-py3-none-any.whl → sglang-0.5.3rc0-py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
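The file-level comparison below can be reproduced locally. A minimal sketch, assuming both wheels have already been fetched (e.g. with `pip download --no-deps sglang==0.5.2rc2` and `pip download --no-deps sglang==0.5.3rc0`); wheels are plain zip archives, so the standard library is enough for an added/removed summary:

    # Compare the file listings of the two wheels named in the title.
    import zipfile

    old = set(zipfile.ZipFile("sglang-0.5.2rc2-py3-none-any.whl").namelist())
    new = set(zipfile.ZipFile("sglang-0.5.3rc0-py3-none-any.whl").namelist())

    print("added:  ", len(new - old), "files")
    print("removed:", len(old - new), "files")
    print("common: ", len(new & old), "files (may still differ in content)")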
Files changed (238)
  1. sglang/bench_one_batch_server.py +10 -1
  2. sglang/bench_serving.py +257 -29
  3. sglang/srt/configs/__init__.py +4 -0
  4. sglang/srt/configs/device_config.py +3 -1
  5. sglang/srt/configs/dots_vlm.py +139 -0
  6. sglang/srt/configs/load_config.py +1 -0
  7. sglang/srt/configs/model_config.py +50 -6
  8. sglang/srt/configs/qwen3_next.py +326 -0
  9. sglang/srt/connector/__init__.py +8 -1
  10. sglang/srt/connector/remote_instance.py +82 -0
  11. sglang/srt/constrained/base_grammar_backend.py +48 -12
  12. sglang/srt/constrained/llguidance_backend.py +0 -1
  13. sglang/srt/constrained/outlines_backend.py +0 -1
  14. sglang/srt/constrained/xgrammar_backend.py +28 -9
  15. sglang/srt/custom_op.py +11 -1
  16. sglang/srt/debug_utils/dump_comparator.py +81 -44
  17. sglang/srt/debug_utils/dump_loader.py +97 -0
  18. sglang/srt/debug_utils/dumper.py +11 -3
  19. sglang/srt/debug_utils/text_comparator.py +73 -11
  20. sglang/srt/disaggregation/base/conn.py +1 -1
  21. sglang/srt/disaggregation/common/conn.py +15 -12
  22. sglang/srt/disaggregation/decode.py +21 -10
  23. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
  24. sglang/srt/disaggregation/fake/conn.py +1 -1
  25. sglang/srt/disaggregation/mini_lb.py +6 -445
  26. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  27. sglang/srt/disaggregation/nixl/conn.py +180 -16
  28. sglang/srt/disaggregation/prefill.py +5 -3
  29. sglang/srt/disaggregation/utils.py +5 -50
  30. sglang/srt/distributed/parallel_state.py +24 -3
  31. sglang/srt/entrypoints/engine.py +38 -17
  32. sglang/srt/entrypoints/grpc_request_manager.py +580 -0
  33. sglang/srt/entrypoints/grpc_server.py +680 -0
  34. sglang/srt/entrypoints/http_server.py +85 -54
  35. sglang/srt/entrypoints/openai/protocol.py +4 -1
  36. sglang/srt/entrypoints/openai/serving_base.py +46 -3
  37. sglang/srt/entrypoints/openai/serving_chat.py +36 -16
  38. sglang/srt/entrypoints/openai/serving_completions.py +12 -3
  39. sglang/srt/entrypoints/openai/serving_embedding.py +8 -3
  40. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  41. sglang/srt/entrypoints/openai/serving_responses.py +6 -3
  42. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  43. sglang/srt/eplb/eplb_manager.py +2 -2
  44. sglang/srt/eplb/expert_distribution.py +26 -13
  45. sglang/srt/eplb/expert_location.py +8 -3
  46. sglang/srt/eplb/expert_location_updater.py +1 -1
  47. sglang/srt/function_call/base_format_detector.py +3 -6
  48. sglang/srt/function_call/ebnf_composer.py +11 -9
  49. sglang/srt/function_call/function_call_parser.py +6 -0
  50. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  51. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  52. sglang/srt/grpc/__init__.py +1 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
  56. sglang/srt/hf_transformers_utils.py +4 -0
  57. sglang/srt/layers/activation.py +142 -9
  58. sglang/srt/layers/attention/ascend_backend.py +11 -4
  59. sglang/srt/layers/attention/fla/chunk.py +242 -0
  60. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  61. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  62. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  63. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  66. sglang/srt/layers/attention/fla/index.py +37 -0
  67. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  68. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  69. sglang/srt/layers/attention/fla/op.py +66 -0
  70. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  71. sglang/srt/layers/attention/fla/utils.py +331 -0
  72. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  73. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  74. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  75. sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
  76. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  77. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  78. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  79. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  80. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  81. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  82. sglang/srt/layers/attention/triton_backend.py +18 -1
  83. sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
  84. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  85. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  86. sglang/srt/layers/dp_attention.py +30 -1
  87. sglang/srt/layers/layernorm.py +32 -15
  88. sglang/srt/layers/linear.py +34 -3
  89. sglang/srt/layers/logits_processor.py +29 -10
  90. sglang/srt/layers/moe/__init__.py +2 -1
  91. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  92. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  93. sglang/srt/layers/moe/ep_moe/layer.py +182 -62
  94. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
  95. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  96. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json}
  97. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  98. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  99. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  100. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  101. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  102. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  103. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  104. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  105. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  106. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  107. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
  108. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  109. sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
  110. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  111. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  112. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  113. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  114. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  115. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  116. sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
  117. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  118. sglang/srt/layers/moe/topk.py +30 -9
  119. sglang/srt/layers/moe/utils.py +12 -6
  120. sglang/srt/layers/quantization/awq.py +19 -7
  121. sglang/srt/layers/quantization/base_config.py +11 -6
  122. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  123. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  124. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  125. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  126. sglang/srt/layers/quantization/fp8.py +76 -47
  127. sglang/srt/layers/quantization/fp8_utils.py +50 -31
  128. sglang/srt/layers/quantization/gptq.py +25 -17
  129. sglang/srt/layers/quantization/modelopt_quant.py +147 -47
  130. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  131. sglang/srt/layers/quantization/mxfp4.py +64 -40
  132. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  133. sglang/srt/layers/quantization/unquant.py +135 -47
  134. sglang/srt/layers/quantization/w4afp8.py +30 -17
  135. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  136. sglang/srt/layers/quantization/w8a8_int8.py +76 -38
  137. sglang/srt/layers/sampler.py +162 -18
  138. sglang/srt/lora/backend/base_backend.py +50 -8
  139. sglang/srt/lora/backend/triton_backend.py +90 -2
  140. sglang/srt/lora/layers.py +32 -0
  141. sglang/srt/lora/lora.py +4 -1
  142. sglang/srt/lora/lora_manager.py +35 -112
  143. sglang/srt/lora/mem_pool.py +24 -10
  144. sglang/srt/lora/utils.py +18 -9
  145. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  146. sglang/srt/managers/cache_controller.py +158 -160
  147. sglang/srt/managers/data_parallel_controller.py +105 -35
  148. sglang/srt/managers/detokenizer_manager.py +8 -4
  149. sglang/srt/managers/disagg_service.py +46 -0
  150. sglang/srt/managers/io_struct.py +199 -12
  151. sglang/srt/managers/mm_utils.py +1 -0
  152. sglang/srt/managers/multi_tokenizer_mixin.py +350 -400
  153. sglang/srt/managers/schedule_batch.py +77 -56
  154. sglang/srt/managers/schedule_policy.py +1 -1
  155. sglang/srt/managers/scheduler.py +187 -39
  156. sglang/srt/managers/scheduler_metrics_mixin.py +4 -3
  157. sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
  158. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  159. sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
  160. sglang/srt/managers/tokenizer_manager.py +259 -519
  161. sglang/srt/managers/tp_worker.py +53 -4
  162. sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
  163. sglang/srt/mem_cache/hicache_storage.py +3 -23
  164. sglang/srt/mem_cache/hiradix_cache.py +103 -43
  165. sglang/srt/mem_cache/memory_pool.py +347 -48
  166. sglang/srt/mem_cache/memory_pool_host.py +105 -46
  167. sglang/srt/mem_cache/radix_cache.py +0 -2
  168. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  169. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  170. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +86 -4
  171. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  172. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  173. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +49 -7
  174. sglang/srt/mem_cache/swa_radix_cache.py +0 -2
  175. sglang/srt/metrics/collector.py +493 -76
  176. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  177. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  178. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  179. sglang/srt/model_executor/forward_batch_info.py +59 -2
  180. sglang/srt/model_executor/model_runner.py +356 -29
  181. sglang/srt/model_loader/__init__.py +9 -3
  182. sglang/srt/model_loader/loader.py +128 -4
  183. sglang/srt/model_loader/weight_utils.py +2 -1
  184. sglang/srt/models/apertus.py +686 -0
  185. sglang/srt/models/bailing_moe.py +798 -218
  186. sglang/srt/models/bailing_moe_nextn.py +168 -0
  187. sglang/srt/models/deepseek_v2.py +109 -15
  188. sglang/srt/models/dots_vlm.py +174 -0
  189. sglang/srt/models/dots_vlm_vit.py +337 -0
  190. sglang/srt/models/ernie4.py +1 -1
  191. sglang/srt/models/gemma3n_mm.py +1 -1
  192. sglang/srt/models/glm4_moe.py +1 -1
  193. sglang/srt/models/glm4v.py +4 -2
  194. sglang/srt/models/glm4v_moe.py +3 -0
  195. sglang/srt/models/gpt_oss.py +1 -1
  196. sglang/srt/models/llama4.py +9 -0
  197. sglang/srt/models/llama_eagle3.py +13 -0
  198. sglang/srt/models/longcat_flash.py +2 -2
  199. sglang/srt/models/mllama4.py +25 -0
  200. sglang/srt/models/opt.py +637 -0
  201. sglang/srt/models/qwen2.py +7 -0
  202. sglang/srt/models/qwen2_5_vl.py +27 -3
  203. sglang/srt/models/qwen2_moe.py +56 -12
  204. sglang/srt/models/qwen3_moe.py +1 -1
  205. sglang/srt/models/qwen3_next.py +1042 -0
  206. sglang/srt/models/qwen3_next_mtp.py +112 -0
  207. sglang/srt/models/step3_vl.py +1 -1
  208. sglang/srt/multimodal/processors/dots_vlm.py +99 -0
  209. sglang/srt/multimodal/processors/glm4v.py +9 -9
  210. sglang/srt/multimodal/processors/internvl.py +141 -129
  211. sglang/srt/multimodal/processors/qwen_vl.py +15 -5
  212. sglang/srt/offloader.py +27 -3
  213. sglang/srt/remote_instance_weight_loader_utils.py +69 -0
  214. sglang/srt/sampling/sampling_batch_info.py +18 -15
  215. sglang/srt/server_args.py +276 -35
  216. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  217. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  218. sglang/srt/speculative/eagle_utils.py +0 -2
  219. sglang/srt/speculative/eagle_worker.py +43 -4
  220. sglang/srt/speculative/spec_info.py +5 -0
  221. sglang/srt/speculative/standalone_worker.py +109 -0
  222. sglang/srt/tracing/trace.py +552 -0
  223. sglang/srt/utils.py +34 -3
  224. sglang/srt/weight_sync/utils.py +1 -1
  225. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  226. sglang/test/runners.py +4 -0
  227. sglang/test/test_cutlass_moe.py +24 -6
  228. sglang/test/test_disaggregation_utils.py +66 -0
  229. sglang/test/test_fp4_moe.py +370 -1
  230. sglang/test/test_utils.py +28 -1
  231. sglang/utils.py +11 -0
  232. sglang/version.py +1 -1
  233. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
  234. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +237 -178
  235. sglang/srt/disaggregation/launch_lb.py +0 -118
  236. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
  237. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
  238. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/hybrid_linear_attn_backend.py
@@ -0,0 +1,602 @@
+from dataclasses import astuple, dataclass
+from functools import lru_cache
+from typing import Optional, Union
+
+import torch
+import torch.nn.functional as F
+
+from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.layers.attention.fla.chunk import chunk_gated_delta_rule
+from sglang.srt.layers.attention.fla.fused_recurrent import (
+    fused_recurrent_gated_delta_rule_update,
+)
+from sglang.srt.layers.attention.fla.fused_sigmoid_gating_recurrent import (
+    fused_sigmoid_gating_delta_rule_update,
+)
+from sglang.srt.layers.attention.mamba.causal_conv1d_triton import (
+    causal_conv1d_fn,
+    causal_conv1d_update,
+)
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.model_executor.model_runner import ModelRunner
+from sglang.srt.models.qwen3_next import Qwen3HybridLinearDecoderLayer, fused_gdn_gating
+from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+from sglang.srt.utils import is_npu
+
+if is_npu():
+    from sgl_kernel_npu.fla.chunk import chunk_gated_delta_rule_npu
+    from sgl_kernel_npu.fla.fused_sigmoid_gating_recurrent import (
+        fused_sigmoid_gating_delta_rule_update_npu,
+    )
+    from sgl_kernel_npu.mamba.causal_conv1d import (
+        causal_conv1d_fn_npu,
+        causal_conv1d_update_npu,
+    )
+
+    chunk_gated_delta_rule = chunk_gated_delta_rule_npu
+    fused_sigmoid_gating_delta_rule_update = fused_sigmoid_gating_delta_rule_update_npu
+    causal_conv1d_fn = causal_conv1d_fn_npu
+    causal_conv1d_update = causal_conv1d_update_npu
+
+
+@dataclass
+class ForwardMetadata:
+    query_start_loc: Optional[torch.Tensor]
+    mamba_cache_indices: torch.Tensor
+
+
+class MambaAttnBackend(AttentionBackend):
+    """Attention backend using Mamba kernel."""
+
+    def __init__(self, model_runner: ModelRunner):
+        super().__init__()
+        self.pad_slot_id = -1  # Default pad slot id
+        self.device = model_runner.device
+        self.req_to_token_pool: HybridReqToTokenPool = model_runner.req_to_token_pool
+        self.forward_metadata: ForwardMetadata = None
+        self.state_indices_list = []
+        self.query_start_loc_list = []
+
+    @classmethod
+    @lru_cache(maxsize=128)
+    def _get_cached_arange(cls, bs: int, device_str: str) -> torch.Tensor:
+        """Cache torch.arange tensors for common batch sizes to avoid repeated allocation."""
+        device = torch.device(device_str)
+        return torch.arange(0, bs + 1, dtype=torch.int32, device=device)
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        bs = forward_batch.batch_size
+        if forward_batch.forward_mode.is_decode_or_idle():
+            query_start_loc = self._get_cached_arange(bs, str(self.device))
+        elif forward_batch.forward_mode.is_extend():
+            if forward_batch.forward_mode.is_target_verify():
+                query_start_loc = torch.arange(
+                    0,
+                    forward_batch.input_ids.shape[0] + 1,
+                    step=forward_batch.spec_info.draft_token_num,
+                    dtype=torch.int32,
+                    device=forward_batch.input_ids.device,
+                )
+            else:
+                query_start_loc = torch.empty(
+                    (bs + 1,), dtype=torch.int32, device=self.device
+                )
+                query_start_loc[:bs] = forward_batch.extend_start_loc
+                query_start_loc[bs] = (
+                    forward_batch.extend_start_loc[-1]
+                    + forward_batch.extend_seq_lens[-1]
+                )
+        else:
+            raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode=}")
+        mamba_cache_indices = self.req_to_token_pool.get_mamba_indices(
+            forward_batch.req_pool_indices
+        )
+        self.forward_metadata = ForwardMetadata(
+            query_start_loc=query_start_loc,
+            mamba_cache_indices=mamba_cache_indices,
+        )
+
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
+        for i in range(max_bs):
+            self.state_indices_list.append(
+                torch.full(
+                    (i + 1,), self.pad_slot_id, dtype=torch.int32, device=self.device
+                )
+            )
+            self.query_start_loc_list.append(
+                torch.empty((i + 2,), dtype=torch.int32, device=self.device)
+            )
+
+    def init_forward_metadata_capture_cuda_graph(
+        self,
+        bs: int,
+        num_tokens: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+    ):
+        if forward_mode.is_decode_or_idle():
+            self.query_start_loc_list[bs - 1].copy_(self._get_cached_arange(bs, "cuda"))
+        elif forward_mode.is_target_verify():
+            self.query_start_loc_list[bs - 1].copy_(
+                torch.arange(
+                    0,
+                    bs * spec_info.draft_token_num + 1,
+                    step=spec_info.draft_token_num,
+                    dtype=torch.int32,
+                    device=self.device,
+                )
+            )
+        else:
+            raise ValueError(f"Invalid forward mode: {forward_mode=}")
+        mamba_indices = self.req_to_token_pool.get_mamba_indices(req_pool_indices)
+        self.state_indices_list[bs - 1][: len(mamba_indices)].copy_(mamba_indices)
+        self.forward_metadata = ForwardMetadata(
+            query_start_loc=self.query_start_loc_list[bs - 1],
+            mamba_cache_indices=self.state_indices_list[bs - 1],
+        )
+
+    def init_forward_metadata_replay_cuda_graph(
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        seq_lens_cpu: Optional[torch.Tensor],
+    ):
+        num_padding = torch.count_nonzero(
+            seq_lens_cpu == self.get_cuda_graph_seq_len_fill_value()
+        )
+        # Make sure forward metadata is correctly handled for padding reqs
+        req_pool_indices[bs - num_padding :] = 0
+        mamba_indices = self.req_to_token_pool.get_mamba_indices(req_pool_indices)
+        mamba_indices[bs - num_padding :] = -1
+        self.state_indices_list[bs - 1][: len(mamba_indices)].copy_(mamba_indices)
+        if forward_mode.is_decode_or_idle():
+            self.query_start_loc_list[bs - 1].copy_(self._get_cached_arange(bs, "cuda"))
+            if num_padding > 0:
+                self.query_start_loc_list[bs - 1][bs - num_padding :] = bs - num_padding
+        elif forward_mode.is_target_verify():
+            self.query_start_loc_list[bs - 1].copy_(
+                torch.arange(
+                    0,
+                    bs * spec_info.draft_token_num + 1,
+                    step=spec_info.draft_token_num,
+                    dtype=torch.int32,
+                    device=self.device,
+                )
+            )
+            if num_padding > 0:
+                self.query_start_loc_list[bs - 1][bs - num_padding :] = (
+                    bs - num_padding
+                ) * spec_info.draft_token_num
+        else:
+            raise ValueError(f"Invalid forward mode: {forward_mode=}")
+
+        self.forward_metadata = ForwardMetadata(
+            query_start_loc=self.query_start_loc_list[bs - 1],
+            mamba_cache_indices=self.state_indices_list[bs - 1],
+        )
+
+    def get_cuda_graph_seq_len_fill_value(self):
+        return 1  # Mamba attn does not use seq lens to index kv cache
+
+    def forward_decode(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
+        **kwargs,
+    ):
+        mixed_qkv = kwargs["mixed_qkv"]
+        conv_weights = kwargs["conv_weights"]
+        bias = kwargs["bias"]
+        activation = kwargs["activation"]
+        key_dim = kwargs["key_dim"]
+        value_dim = kwargs["value_dim"]
+        attn_tp_size = kwargs["attention_tp_size"]
+        head_k_dim = kwargs["head_k_dim"]
+        head_v_dim = kwargs["head_v_dim"]
+        a = kwargs["a"]
+        b = kwargs["b"]
+        A_log = kwargs["A_log"]
+        dt_bias = kwargs["dt_bias"]
+        layer_id = kwargs["layer_id"]
+
+        conv_states, ssm_states, *rest = self.req_to_token_pool.get_mamba_params(
+            layer_id
+        )
+        query_start_loc = self.forward_metadata.query_start_loc
+        cache_indices = self.forward_metadata.mamba_cache_indices
+
+        mixed_qkv = causal_conv1d_update(
+            mixed_qkv,
+            conv_states,
+            conv_weights,
+            bias,
+            activation,
+            conv_state_indices=cache_indices,
+        )
+
+        query, key, value = torch.split(
+            mixed_qkv,
+            [
+                key_dim // attn_tp_size,
+                key_dim // attn_tp_size,
+                value_dim // attn_tp_size,
+            ],
+            dim=-1,
+        )
+        # Reshape from [l, h*d] to [1, l, h, d]
+        seq_len = query.shape[0]
+        num_heads = query.shape[1] // head_k_dim
+        query = query.view(1, seq_len, num_heads, head_k_dim)
+        key = key.view(1, seq_len, num_heads, head_k_dim)
+        value = value.view(1, seq_len, value.shape[1] // head_v_dim, head_v_dim)
+
+        core_attn_out = fused_sigmoid_gating_delta_rule_update(
+            A_log=A_log,
+            dt_bias=dt_bias,
+            q=query,
+            k=key,
+            v=value,
+            a=a,
+            b=b,
+            initial_state_source=ssm_states,
+            initial_state_indices=cache_indices,
+            cu_seqlens=query_start_loc,
+            use_qk_l2norm_in_kernel=True,
+            softplus_beta=1.0,
+            softplus_threshold=20.0,
+        )
+
+        return core_attn_out
+
+    def forward_extend(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
+        **kwargs,
+    ):
+        mixed_qkv = kwargs["mixed_qkv"]
+        conv_weights = kwargs["conv_weights"]
+        bias = kwargs["bias"]
+        activation = kwargs["activation"]
+        key_dim = kwargs["key_dim"]
+        value_dim = kwargs["value_dim"]
+        attn_tp_size = kwargs["attention_tp_size"]
+        head_k_dim = kwargs["head_k_dim"]
+        head_v_dim = kwargs["head_v_dim"]
+        a = kwargs["a"]
+        b = kwargs["b"]
+        A_log = kwargs["A_log"]
+        dt_bias = kwargs["dt_bias"]
+        layer_id = kwargs["layer_id"]
+        seq_len = kwargs["seq_len"]
+
+        is_target_verify = forward_batch.forward_mode.is_target_verify()
+
+        query_start_loc = self.forward_metadata.query_start_loc
+        cache_indices = self.forward_metadata.mamba_cache_indices
+
+        if is_target_verify:
+            (
+                conv_states,
+                ssm_states,
+                intermediate_state_cache,
+                intermediate_conv_window_cache,
+            ) = self.req_to_token_pool.get_mamba_params(layer_id)
+            has_initial_states = torch.ones(
+                seq_len // forward_batch.spec_info.draft_token_num,
+                dtype=torch.bool,
+                device=forward_batch.input_ids.device,
+            )
+            conv_states_to_use = conv_states.clone()
+        else:
+            conv_states, ssm_states, *rest = self.req_to_token_pool.get_mamba_params(
+                layer_id
+            )
+            has_initial_states = forward_batch.extend_prefix_lens > 0
+            conv_states_to_use = conv_states
+
+        if is_target_verify:
+            batch_size = seq_len // forward_batch.spec_info.draft_token_num
+            draft_token_num = forward_batch.spec_info.draft_token_num
+            mixed_qkv_reshaped = (
+                mixed_qkv.view(batch_size, draft_token_num, -1)
+                .transpose(1, 2)
+                .contiguous()
+            )
+            mixed_qkv_processed = causal_conv1d_update(
+                mixed_qkv_reshaped,
+                conv_states_to_use,
+                conv_weights,
+                bias,
+                activation,
+                conv_state_indices=cache_indices[:batch_size],
+                intermediate_conv_window=intermediate_conv_window_cache,
+            )
+            mixed_qkv = (
+                mixed_qkv_processed.transpose(1, 2).contiguous().view(seq_len, -1)
+            )
+        else:
+            mixed_qkv = causal_conv1d_fn(
+                mixed_qkv.transpose(0, 1),
+                conv_weights,
+                bias,
+                activation=activation,
+                conv_states=conv_states_to_use,
+                has_initial_state=has_initial_states,
+                cache_indices=cache_indices,
+                query_start_loc=query_start_loc,
+            ).transpose(0, 1)[:seq_len]
+
+        key_split_dim = key_dim // attn_tp_size
+        value_split_dim = value_dim // attn_tp_size
+
+        query, key, value = torch.split(
+            mixed_qkv,
+            [key_split_dim, key_split_dim, value_split_dim],
+            dim=-1,
+        )
+
+        actual_seq_len = query.shape[0]
+        num_heads = query.shape[1] // head_k_dim
+        num_value_heads = value.shape[1] // head_v_dim
+
+        query = query.view(1, actual_seq_len, num_heads, head_k_dim)
+        key = key.view(1, actual_seq_len, num_heads, head_k_dim)
+        value = value.view(1, actual_seq_len, num_value_heads, head_v_dim)
+
+        beta = b.sigmoid()
+        g = fused_gdn_gating(A_log, a, dt_bias)
+
+        g = g.unsqueeze(0)
+        beta = beta.unsqueeze(0)
+
+        if is_target_verify:
+            core_attn_out = fused_recurrent_gated_delta_rule_update(
+                q=query,
+                k=key,
+                v=value,
+                g=g,
+                beta=beta,
+                initial_state_source=ssm_states,
+                initial_state_indices=cache_indices,
+                cu_seqlens=query_start_loc,
+                use_qk_l2norm_in_kernel=True,
+                disable_state_update=True,
+                intermediate_states_buffer=intermediate_state_cache,
+                cache_steps=forward_batch.spec_info.draft_token_num,
+            )
+        else:
+            recurrent_state = ssm_states[cache_indices]
+            core_attn_out, last_recurrent_state = chunk_gated_delta_rule(
+                q=query,
+                k=key,
+                v=value,
+                g=g,
+                beta=beta,
+                initial_state=recurrent_state,
+                output_final_state=True,
+                cu_seqlens=query_start_loc,
+                head_first=False,
+                use_qk_l2norm_in_kernel=True,
+            )
+            last_recurrent_state = last_recurrent_state.to(ssm_states.dtype, copy=False)
+            ssm_states[cache_indices] = last_recurrent_state
+
+        return core_attn_out
+
+
+class HybridLinearAttnBackend(AttentionBackend):
+    """Support different backends for prefill and decode."""
+
+    def __init__(
+        self,
+        full_attn_backend: AttentionBackend,
+        linear_attn_backend: AttentionBackend,
+        full_attn_layers: list[int],
+    ):
+        self.full_attn_layers = full_attn_layers
+        self.attn_backend_list = [full_attn_backend, linear_attn_backend]
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        for attn_backend in self.attn_backend_list:
+            attn_backend.init_forward_metadata(forward_batch)
+
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
+        for attn_backend in self.attn_backend_list:
+            attn_backend.init_cuda_graph_state(max_bs, max_num_tokens)
+
+    def init_forward_metadata_capture_cuda_graph(
+        self,
+        bs: int,
+        num_tokens: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+    ):
+        for attn_backend in self.attn_backend_list:
+            attn_backend.init_forward_metadata_capture_cuda_graph(
+                bs,
+                num_tokens,
+                req_pool_indices,
+                seq_lens,
+                encoder_lens,
+                forward_mode,
+                spec_info,
+            )
+
+    def init_forward_metadata_replay_cuda_graph(
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        seq_lens_cpu: Optional[torch.Tensor],
+    ):
+        for attn_backend in self.attn_backend_list:
+            attn_backend.init_forward_metadata_replay_cuda_graph(
+                bs,
+                req_pool_indices,
+                seq_lens,
+                seq_lens_sum,
+                encoder_lens,
+                forward_mode,
+                spec_info,
+                seq_lens_cpu,
+            )
+
+    def get_cuda_graph_seq_len_fill_value(self):
+        return self.attn_backend_list[0].get_cuda_graph_seq_len_fill_value()
+
+    def forward_decode(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
+        **kwargs,
+    ):
+        layer_id = layer.layer_id if layer else kwargs["layer_id"]
+        if layer_id in self.full_attn_layers:
+            return self.attn_backend_list[0].forward_decode(
+                q, k, v, layer, forward_batch, save_kv_cache, **kwargs
+            )
+        return self.attn_backend_list[1].forward_decode(
+            q, k, v, layer, forward_batch, save_kv_cache, **kwargs
+        )
+
+    def forward_extend(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
+        **kwargs,
+    ):
+        layer_id = layer.layer_id if layer else kwargs["layer_id"]
+        if layer_id in self.full_attn_layers:
+            return self.attn_backend_list[0].forward_extend(
+                q, k, v, layer, forward_batch, save_kv_cache, **kwargs
+            )
+        return self.attn_backend_list[1].forward_extend(
+            q, k, v, layer, forward_batch, save_kv_cache, **kwargs
+        )
+
+    def forward(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
+        **kwargs,
+    ):
+        """Run forward on an attention layer."""
+        if forward_batch.forward_mode.is_idle():
+            if layer is None:
+                return torch.empty_like(kwargs["z"])
+            return q.new_empty(q.shape[0], layer.tp_q_head_num * layer.v_head_dim)
+        elif forward_batch.forward_mode.is_decode():
+            return self.forward_decode(
+                q,
+                k,
+                v,
+                layer,
+                forward_batch,
+                save_kv_cache=save_kv_cache,
+                **kwargs,
+            )
+        else:
+            return self.forward_extend(
+                q,
+                k,
+                v,
+                layer,
+                forward_batch,
+                save_kv_cache=save_kv_cache,
+                **kwargs,
+            )
+
+    def update_mamba_state_after_mtp_verify(self, accepted_length, model):
+        request_number = accepted_length.shape[0]
+
+        state_indices_tensor = self.attn_backend_list[
+            1
+        ].forward_metadata.mamba_cache_indices[:request_number]
+
+        mamba_caches = self.attn_backend_list[
+            1
+        ].req_to_token_pool.get_mamba_params_all_layers()
+
+        (
+            conv_states,
+            ssm_states,
+            intermediate_state_cache,
+            intermediate_conv_window_cache,
+        ) = mamba_caches
+
+        # SSM state updates (chunked to reduce peak memory)
+        valid_mask = accepted_length > 0
+
+        # Compute common indices once to avoid duplication
+        last_steps_all = (accepted_length - 1).to(torch.int64)
+        valid_state_indices = state_indices_tensor[valid_mask].to(torch.int64)
+        last_steps = last_steps_all[valid_mask].to(torch.int64)
+
+        if valid_state_indices.numel() > 0:
+            chunk = 256
+            num_valid = valid_state_indices.numel()
+
+            # SSM state updates
+            for i in range(0, num_valid, chunk):
+                idx = valid_state_indices[i : i + chunk]
+                steps = last_steps[i : i + chunk]
+                # per (cache line, step)
+                for j in range(idx.numel()):
+                    ci = idx[j].item()
+                    st = steps[j].item()
+                    ssm_states[:, ci, :].copy_(
+                        intermediate_state_cache[:, ci, st].to(
+                            ssm_states.dtype, copy=False
+                        )
+                    )
+
+            # Conv window updates
+            for i in range(0, num_valid, chunk):
+                idx = valid_state_indices[i : i + chunk]
+                steps = last_steps[i : i + chunk]
+                for j in range(idx.numel()):
+                    ci = idx[j].item()
+                    st = steps[j].item()
+                    conv_states[:, ci, :, :].copy_(
+                        intermediate_conv_window_cache[:, ci, st].to(
+                            conv_states.dtype, copy=False
+                        )
+                    )
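For orientation, the per-layer routing that HybridLinearAttnBackend implements above reduces to the following self-contained sketch. The StubBackend and HybridDispatch classes and the layer split [3, 7] are illustrative stand-ins, not part of the package; only the dispatch rule (full_attn_layers → index 0, everything else → index 1) mirrors the diff:

    class StubBackend:
        def __init__(self, name: str):
            self.name = name

        def forward_decode(self, layer_id: int) -> str:
            return f"{self.name}:{layer_id}"

    class HybridDispatch:
        def __init__(self, full, linear, full_attn_layers):
            self.full_attn_layers = set(full_attn_layers)
            # Index 0 = full attention, index 1 = linear (Mamba) attention,
            # mirroring attn_backend_list in the diff above.
            self.attn_backend_list = [full, linear]

        def forward_decode(self, layer_id: int) -> str:
            idx = 0 if layer_id in self.full_attn_layers else 1
            return self.attn_backend_list[idx].forward_decode(layer_id)

    hybrid = HybridDispatch(StubBackend("full"), StubBackend("linear"), [3, 7])
    print([hybrid.forward_decode(i) for i in range(8)])
    # layers 3 and 7 hit the full-attention backend; all others hit Mamba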
sglang/srt/layers/attention/intel_amx_backend.py
@@ -49,6 +49,9 @@ class IntelAMXAttnBackend(AttentionBackend):
         max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
         self.forward_metadata = (attn_logits, max_extend_len)
 
+    def get_graph_seq_len_fill_value(self):
+        return 1
+
     def forward_extend(
         self,
         q,
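The seq-len fill value pattern this hunk adds is the same one MambaAttnBackend uses in the hybrid backend diff above, where init_forward_metadata_replay_cuda_graph recovers the number of padded slots by counting entries of seq_lens_cpu equal to the fill value. A minimal sketch of that detection; the tensor values are made up, only the counting pattern matches the diff:

    import torch

    FILL_VALUE = 1  # what the fill-value method returns

    # A graph captured for batch size 4, replayed with 2 real requests:
    # unused tail slots carry FILL_VALUE in the padded seq-lens tensor.
    seq_lens_cpu = torch.tensor([17, 5, FILL_VALUE, FILL_VALUE])

    num_padding = torch.count_nonzero(seq_lens_cpu == FILL_VALUE)
    print(int(num_padding))  # -> 2 padded slots at the tail of the batch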