sglang-0.5.3rc0-py3-none-any.whl → sglang-0.5.3rc2-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.
Files changed (282)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +321 -31
  3. sglang/bench_serving.py +10 -3
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +4 -0
  11. sglang/srt/configs/dots_ocr.py +64 -0
  12. sglang/srt/configs/falcon_h1.py +360 -0
  13. sglang/srt/configs/load_config.py +8 -0
  14. sglang/srt/configs/model_config.py +160 -105
  15. sglang/srt/configs/qwen3_vl.py +586 -0
  16. sglang/srt/constrained/base_grammar_backend.py +1 -0
  17. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  18. sglang/srt/constrained/xgrammar_backend.py +6 -4
  19. sglang/srt/debug_utils/dumper.py +10 -3
  20. sglang/srt/disaggregation/ascend/conn.py +2 -2
  21. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  22. sglang/srt/disaggregation/common/conn.py +266 -98
  23. sglang/srt/disaggregation/decode.py +50 -9
  24. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  25. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  26. sglang/srt/disaggregation/mooncake/conn.py +51 -541
  27. sglang/srt/disaggregation/nixl/conn.py +148 -39
  28. sglang/srt/disaggregation/prefill.py +31 -14
  29. sglang/srt/disaggregation/utils.py +36 -5
  30. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  31. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  32. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  33. sglang/srt/distributed/parallel_state.py +135 -80
  34. sglang/srt/entrypoints/engine.py +23 -3
  35. sglang/srt/entrypoints/grpc_request_manager.py +330 -55
  36. sglang/srt/entrypoints/grpc_server.py +232 -102
  37. sglang/srt/entrypoints/http_server.py +49 -9
  38. sglang/srt/entrypoints/openai/protocol.py +110 -5
  39. sglang/srt/entrypoints/openai/serving_base.py +25 -6
  40. sglang/srt/entrypoints/openai/serving_chat.py +178 -49
  41. sglang/srt/entrypoints/openai/serving_completions.py +5 -3
  42. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  43. sglang/srt/entrypoints/openai/serving_responses.py +42 -0
  44. sglang/srt/environ.py +285 -0
  45. sglang/srt/eplb/expert_location.py +30 -5
  46. sglang/srt/function_call/function_call_parser.py +3 -2
  47. sglang/srt/function_call/glm4_moe_detector.py +3 -3
  48. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  49. sglang/srt/function_call/json_array_parser.py +63 -0
  50. sglang/srt/function_call/kimik2_detector.py +17 -4
  51. sglang/srt/function_call/utils.py +96 -5
  52. sglang/srt/grpc/compile_proto.py +245 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
  56. sglang/srt/layers/activation.py +7 -6
  57. sglang/srt/layers/attention/aiter_backend.py +14 -15
  58. sglang/srt/layers/attention/ascend_backend.py +108 -9
  59. sglang/srt/layers/attention/attention_registry.py +206 -0
  60. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  61. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  62. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  63. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  66. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  67. sglang/srt/layers/attention/flashinfer_backend.py +112 -194
  68. sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
  69. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  70. sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
  71. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
  72. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
  73. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
  74. sglang/srt/layers/attention/mamba/mamba.py +566 -1
  75. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  76. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  77. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  78. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  79. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  80. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  81. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  82. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  83. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  84. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  85. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  86. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  87. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  88. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  89. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  90. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  91. sglang/srt/layers/attention/nsa/utils.py +24 -0
  92. sglang/srt/layers/attention/nsa_backend.py +887 -0
  93. sglang/srt/layers/attention/tbo_backend.py +6 -6
  94. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  95. sglang/srt/layers/attention/triton_backend.py +42 -9
  96. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  97. sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
  98. sglang/srt/layers/attention/vision.py +58 -0
  99. sglang/srt/layers/attention/wave_backend.py +4 -4
  100. sglang/srt/layers/communicator.py +8 -0
  101. sglang/srt/layers/dp_attention.py +11 -1
  102. sglang/srt/layers/elementwise.py +3 -1
  103. sglang/srt/layers/layernorm.py +2 -0
  104. sglang/srt/layers/linear.py +21 -4
  105. sglang/srt/layers/logits_processor.py +15 -2
  106. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  107. sglang/srt/layers/moe/ep_moe/layer.py +147 -74
  108. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  109. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  110. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  111. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  112. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
  113. sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
  114. sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
  115. sglang/srt/layers/moe/utils.py +10 -0
  116. sglang/srt/layers/parameter.py +23 -6
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  119. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  121. sglang/srt/layers/quantization/fp8.py +2 -2
  122. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  123. sglang/srt/layers/quantization/modelopt_quant.py +44 -9
  124. sglang/srt/layers/quantization/mxfp4.py +12 -4
  125. sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
  126. sglang/srt/layers/quantization/w4afp8.py +0 -4
  127. sglang/srt/layers/quantization/w8a8_int8.py +15 -3
  128. sglang/srt/layers/rotary_embedding.py +78 -31
  129. sglang/srt/layers/sampler.py +52 -4
  130. sglang/srt/layers/utils.py +23 -0
  131. sglang/srt/lora/backend/base_backend.py +3 -3
  132. sglang/srt/lora/backend/chunked_backend.py +348 -0
  133. sglang/srt/lora/backend/triton_backend.py +10 -4
  134. sglang/srt/lora/lora.py +7 -5
  135. sglang/srt/lora/lora_manager.py +17 -6
  136. sglang/srt/lora/mem_pool.py +1 -1
  137. sglang/srt/lora/triton_ops/__init__.py +4 -0
  138. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  139. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  140. sglang/srt/lora/utils.py +7 -5
  141. sglang/srt/managers/cache_controller.py +42 -142
  142. sglang/srt/managers/data_parallel_controller.py +11 -46
  143. sglang/srt/managers/detokenizer_manager.py +11 -11
  144. sglang/srt/managers/io_struct.py +162 -118
  145. sglang/srt/managers/mm_utils.py +43 -6
  146. sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
  147. sglang/srt/managers/multimodal_processor.py +1 -2
  148. sglang/srt/managers/overlap_utils.py +53 -0
  149. sglang/srt/managers/schedule_batch.py +167 -86
  150. sglang/srt/managers/schedule_policy.py +143 -16
  151. sglang/srt/managers/scheduler.py +359 -214
  152. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  153. sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
  154. sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
  155. sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
  156. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  157. sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
  158. sglang/srt/managers/tokenizer_manager.py +84 -136
  159. sglang/srt/managers/tp_worker.py +39 -29
  160. sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
  161. sglang/srt/managers/utils.py +1 -45
  162. sglang/srt/mem_cache/allocator.py +14 -20
  163. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  164. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  165. sglang/srt/mem_cache/chunk_cache.py +8 -1
  166. sglang/srt/mem_cache/evict_policy.py +23 -0
  167. sglang/srt/mem_cache/hicache_storage.py +40 -1
  168. sglang/srt/mem_cache/hiradix_cache.py +119 -32
  169. sglang/srt/mem_cache/memory_pool.py +188 -10
  170. sglang/srt/mem_cache/memory_pool_host.py +134 -182
  171. sglang/srt/mem_cache/radix_cache.py +222 -71
  172. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  173. sglang/srt/mem_cache/storage/__init__.py +10 -0
  174. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  175. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  176. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  177. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  178. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  179. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
  180. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
  181. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
  182. sglang/srt/mem_cache/swa_radix_cache.py +25 -34
  183. sglang/srt/metrics/collector.py +82 -120
  184. sglang/srt/metrics/func_timer.py +2 -7
  185. sglang/srt/metrics/utils.py +8 -1
  186. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  187. sglang/srt/model_executor/cuda_graph_runner.py +39 -32
  188. sglang/srt/model_executor/forward_batch_info.py +23 -38
  189. sglang/srt/model_executor/model_runner.py +131 -183
  190. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  191. sglang/srt/model_loader/loader.py +14 -10
  192. sglang/srt/model_loader/weight_utils.py +156 -2
  193. sglang/srt/models/bailing_moe.py +27 -4
  194. sglang/srt/models/deepseek_nextn.py +6 -1
  195. sglang/srt/models/deepseek_v2.py +536 -153
  196. sglang/srt/models/dots_ocr.py +173 -0
  197. sglang/srt/models/falcon_h1.py +576 -0
  198. sglang/srt/models/gemma3_causal.py +0 -2
  199. sglang/srt/models/gemma3_mm.py +1 -1
  200. sglang/srt/models/gemma3n_mm.py +1 -1
  201. sglang/srt/models/glm4_moe.py +3 -3
  202. sglang/srt/models/glm4_moe_nextn.py +2 -2
  203. sglang/srt/models/glm4v.py +1 -1
  204. sglang/srt/models/glm4v_moe.py +1 -1
  205. sglang/srt/models/gpt_oss.py +7 -30
  206. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  207. sglang/srt/models/llama.py +4 -0
  208. sglang/srt/models/longcat_flash.py +1 -1
  209. sglang/srt/models/longcat_flash_nextn.py +1 -1
  210. sglang/srt/models/mllama4.py +15 -4
  211. sglang/srt/models/qwen2.py +0 -7
  212. sglang/srt/models/qwen2_5_vl.py +2 -2
  213. sglang/srt/models/qwen2_audio.py +1 -1
  214. sglang/srt/models/qwen2_moe.py +64 -1
  215. sglang/srt/models/qwen2_vl.py +1 -1
  216. sglang/srt/models/qwen3.py +18 -3
  217. sglang/srt/models/qwen3_moe.py +31 -3
  218. sglang/srt/models/qwen3_next.py +36 -9
  219. sglang/srt/models/qwen3_vl.py +787 -0
  220. sglang/srt/models/qwen3_vl_moe.py +471 -0
  221. sglang/srt/models/registry.py +15 -3
  222. sglang/srt/models/sarashina2_vision.py +269 -0
  223. sglang/srt/models/solar.py +505 -0
  224. sglang/srt/models/starcoder2.py +357 -0
  225. sglang/srt/models/torch_native_llama.py +9 -2
  226. sglang/srt/models/utils.py +51 -0
  227. sglang/srt/multimodal/processors/base_processor.py +15 -7
  228. sglang/srt/multimodal/processors/dots_vlm.py +2 -3
  229. sglang/srt/multimodal/processors/internvl.py +20 -8
  230. sglang/srt/multimodal/processors/qwen_vl.py +8 -1
  231. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  232. sglang/srt/parser/jinja_template_utils.py +6 -0
  233. sglang/srt/sampling/sampling_batch_info.py +20 -2
  234. sglang/srt/sampling/sampling_params.py +7 -0
  235. sglang/srt/server_args.py +753 -295
  236. sglang/srt/server_args_config_parser.py +146 -0
  237. sglang/srt/single_batch_overlap.py +151 -0
  238. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  239. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  240. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  241. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  242. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  243. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  244. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
  245. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
  246. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
  247. sglang/srt/speculative/eagle_worker.py +57 -25
  248. sglang/srt/speculative/ngram_utils.py +428 -0
  249. sglang/srt/speculative/ngram_worker.py +245 -0
  250. sglang/srt/speculative/spec_info.py +47 -0
  251. sglang/srt/speculative/spec_utils.py +606 -0
  252. sglang/srt/torch_memory_saver_adapter.py +5 -7
  253. sglang/srt/tracing/trace.py +32 -6
  254. sglang/srt/two_batch_overlap.py +8 -5
  255. sglang/srt/utils/__init__.py +2 -0
  256. sglang/srt/{utils.py → utils/common.py} +399 -74
  257. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
  258. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  259. sglang/srt/utils/rpd_utils.py +452 -0
  260. sglang/srt/utils/slow_rank_detector.py +71 -0
  261. sglang/srt/warmup.py +8 -4
  262. sglang/srt/weight_sync/utils.py +1 -1
  263. sglang/test/get_logits_ut.py +57 -0
  264. sglang/test/run_eval.py +79 -11
  265. sglang/test/runners.py +1 -1
  266. sglang/test/simple_eval_common.py +5 -2
  267. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  268. sglang/test/test_block_fp8.py +2 -2
  269. sglang/test/test_deterministic.py +297 -0
  270. sglang/test/test_disaggregation_utils.py +12 -1
  271. sglang/test/test_programs.py +1 -1
  272. sglang/test/test_utils.py +355 -4
  273. sglang/utils.py +10 -1
  274. sglang/version.py +1 -1
  275. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
  276. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
  277. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  278. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  279. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  280. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  281. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  282. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
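
Each entry above shows per-file added/removed line counts between the two wheels. A rough, generic sketch of reproducing such a file-level comparison locally (standard-library `zipfile` only; the local wheel paths are assumptions):

```python
import zipfile

# Hypothetical local paths to the two downloaded wheels.
OLD = "sglang-0.5.3rc0-py3-none-any.whl"
NEW = "sglang-0.5.3rc2-py3-none-any.whl"

with zipfile.ZipFile(OLD) as old, zipfile.ZipFile(NEW) as new:
    old_names, new_names = set(old.namelist()), set(new.namelist())
    print("added:", sorted(new_names - old_names))
    print("removed:", sorted(old_names - new_names))
    # Files present in both wheels whose bytes differ:
    changed = [
        n for n in sorted(old_names & new_names)
        if old.read(n) != new.read(n)
    ]
    print("changed:", changed)
```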
sglang/srt/grpc/sglang_scheduler_pb2.pyi
@@ -3,7 +3,6 @@ import datetime
 from google.protobuf import timestamp_pb2 as _timestamp_pb2
 from google.protobuf import struct_pb2 as _struct_pb2
 from google.protobuf.internal import containers as _containers
-from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper
 from google.protobuf import descriptor as _descriptor
 from google.protobuf import message as _message
 from collections.abc import Iterable as _Iterable, Mapping as _Mapping
@@ -12,7 +11,7 @@ from typing import ClassVar as _ClassVar, Optional as _Optional, Union as _Union
 DESCRIPTOR: _descriptor.FileDescriptor

 class SamplingParams(_message.Message):
-    __slots__ = ("temperature", "top_p", "top_k", "min_p", "frequency_penalty", "presence_penalty", "repetition_penalty", "max_new_tokens", "stop", "stop_token_ids", "skip_special_tokens", "spaces_between_special_tokens", "regex", "json_schema", "ebnf_grammar", "lora_path", "n", "token_healing", "min_new_tokens", "ignore_eos", "no_stop_trim", "stream_interval", "logit_bias", "structural_tag", "custom_params")
+    __slots__ = ("temperature", "top_p", "top_k", "min_p", "frequency_penalty", "presence_penalty", "repetition_penalty", "max_new_tokens", "stop", "stop_token_ids", "skip_special_tokens", "spaces_between_special_tokens", "regex", "json_schema", "ebnf_grammar", "structural_tag", "lora_path", "n", "token_healing", "min_new_tokens", "ignore_eos", "no_stop_trim", "stream_interval", "logit_bias", "custom_params")
     class LogitBiasEntry(_message.Message):
         __slots__ = ("key", "value")
         KEY_FIELD_NUMBER: _ClassVar[int]
@@ -35,6 +34,7 @@ class SamplingParams(_message.Message):
     REGEX_FIELD_NUMBER: _ClassVar[int]
     JSON_SCHEMA_FIELD_NUMBER: _ClassVar[int]
     EBNF_GRAMMAR_FIELD_NUMBER: _ClassVar[int]
+    STRUCTURAL_TAG_FIELD_NUMBER: _ClassVar[int]
     LORA_PATH_FIELD_NUMBER: _ClassVar[int]
     N_FIELD_NUMBER: _ClassVar[int]
     TOKEN_HEALING_FIELD_NUMBER: _ClassVar[int]
@@ -43,7 +43,6 @@ class SamplingParams(_message.Message):
     NO_STOP_TRIM_FIELD_NUMBER: _ClassVar[int]
     STREAM_INTERVAL_FIELD_NUMBER: _ClassVar[int]
     LOGIT_BIAS_FIELD_NUMBER: _ClassVar[int]
-    STRUCTURAL_TAG_FIELD_NUMBER: _ClassVar[int]
     CUSTOM_PARAMS_FIELD_NUMBER: _ClassVar[int]
     temperature: float
     top_p: float
@@ -60,6 +59,7 @@ class SamplingParams(_message.Message):
     regex: str
     json_schema: str
     ebnf_grammar: str
+    structural_tag: str
     lora_path: str
     n: int
     token_healing: bool
@@ -68,9 +68,8 @@ class SamplingParams(_message.Message):
     no_stop_trim: bool
     stream_interval: int
     logit_bias: _containers.ScalarMap[str, float]
-    structural_tag: str
     custom_params: _struct_pb2.Struct
-    def __init__(self, temperature: _Optional[float] = ..., top_p: _Optional[float] = ..., top_k: _Optional[int] = ..., min_p: _Optional[float] = ..., frequency_penalty: _Optional[float] = ..., presence_penalty: _Optional[float] = ..., repetition_penalty: _Optional[float] = ..., max_new_tokens: _Optional[int] = ..., stop: _Optional[_Iterable[str]] = ..., stop_token_ids: _Optional[_Iterable[int]] = ..., skip_special_tokens: bool = ..., spaces_between_special_tokens: bool = ..., regex: _Optional[str] = ..., json_schema: _Optional[str] = ..., ebnf_grammar: _Optional[str] = ..., lora_path: _Optional[str] = ..., n: _Optional[int] = ..., token_healing: bool = ..., min_new_tokens: _Optional[int] = ..., ignore_eos: bool = ..., no_stop_trim: bool = ..., stream_interval: _Optional[int] = ..., logit_bias: _Optional[_Mapping[str, float]] = ..., structural_tag: _Optional[str] = ..., custom_params: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ...
+    def __init__(self, temperature: _Optional[float] = ..., top_p: _Optional[float] = ..., top_k: _Optional[int] = ..., min_p: _Optional[float] = ..., frequency_penalty: _Optional[float] = ..., presence_penalty: _Optional[float] = ..., repetition_penalty: _Optional[float] = ..., max_new_tokens: _Optional[int] = ..., stop: _Optional[_Iterable[str]] = ..., stop_token_ids: _Optional[_Iterable[int]] = ..., skip_special_tokens: bool = ..., spaces_between_special_tokens: bool = ..., regex: _Optional[str] = ..., json_schema: _Optional[str] = ..., ebnf_grammar: _Optional[str] = ..., structural_tag: _Optional[str] = ..., lora_path: _Optional[str] = ..., n: _Optional[int] = ..., token_healing: bool = ..., min_new_tokens: _Optional[int] = ..., ignore_eos: bool = ..., no_stop_trim: bool = ..., stream_interval: _Optional[int] = ..., logit_bias: _Optional[_Mapping[str, float]] = ..., custom_params: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ...

 class DisaggregatedParams(_message.Message):
     __slots__ = ("bootstrap_host", "bootstrap_port", "bootstrap_room")
@@ -83,7 +82,7 @@ class DisaggregatedParams(_message.Message):
     def __init__(self, bootstrap_host: _Optional[str] = ..., bootstrap_port: _Optional[int] = ..., bootstrap_room: _Optional[int] = ...) -> None: ...

 class GenerateRequest(_message.Message):
-    __slots__ = ("request_id", "tokenized", "mm_inputs", "sampling_params", "return_logprob", "logprob_start_len", "top_logprobs_num", "token_ids_logprob", "return_hidden_states", "disaggregated_params", "custom_logit_processor", "timestamp", "log_metrics", "input_embeds", "lora_id", "data_parallel_rank", "dp_balance_id")
+    __slots__ = ("request_id", "tokenized", "mm_inputs", "sampling_params", "return_logprob", "logprob_start_len", "top_logprobs_num", "token_ids_logprob", "return_hidden_states", "disaggregated_params", "custom_logit_processor", "timestamp", "log_metrics", "input_embeds", "lora_id", "data_parallel_rank", "stream")
     REQUEST_ID_FIELD_NUMBER: _ClassVar[int]
     TOKENIZED_FIELD_NUMBER: _ClassVar[int]
     MM_INPUTS_FIELD_NUMBER: _ClassVar[int]
@@ -100,7 +99,7 @@ class GenerateRequest(_message.Message):
     INPUT_EMBEDS_FIELD_NUMBER: _ClassVar[int]
     LORA_ID_FIELD_NUMBER: _ClassVar[int]
     DATA_PARALLEL_RANK_FIELD_NUMBER: _ClassVar[int]
-    DP_BALANCE_ID_FIELD_NUMBER: _ClassVar[int]
+    STREAM_FIELD_NUMBER: _ClassVar[int]
     request_id: str
     tokenized: TokenizedInput
     mm_inputs: MultimodalInputs
@@ -117,8 +116,8 @@ class GenerateRequest(_message.Message):
     input_embeds: _containers.RepeatedScalarFieldContainer[float]
     lora_id: str
     data_parallel_rank: int
-    dp_balance_id: int
-    def __init__(self, request_id: _Optional[str] = ..., tokenized: _Optional[_Union[TokenizedInput, _Mapping]] = ..., mm_inputs: _Optional[_Union[MultimodalInputs, _Mapping]] = ..., sampling_params: _Optional[_Union[SamplingParams, _Mapping]] = ..., return_logprob: bool = ..., logprob_start_len: _Optional[int] = ..., top_logprobs_num: _Optional[int] = ..., token_ids_logprob: _Optional[_Iterable[int]] = ..., return_hidden_states: bool = ..., disaggregated_params: _Optional[_Union[DisaggregatedParams, _Mapping]] = ..., custom_logit_processor: _Optional[str] = ..., timestamp: _Optional[_Union[datetime.datetime, _timestamp_pb2.Timestamp, _Mapping]] = ..., log_metrics: bool = ..., input_embeds: _Optional[_Iterable[float]] = ..., lora_id: _Optional[str] = ..., data_parallel_rank: _Optional[int] = ..., dp_balance_id: _Optional[int] = ...) -> None: ...
+    stream: bool
+    def __init__(self, request_id: _Optional[str] = ..., tokenized: _Optional[_Union[TokenizedInput, _Mapping]] = ..., mm_inputs: _Optional[_Union[MultimodalInputs, _Mapping]] = ..., sampling_params: _Optional[_Union[SamplingParams, _Mapping]] = ..., return_logprob: bool = ..., logprob_start_len: _Optional[int] = ..., top_logprobs_num: _Optional[int] = ..., token_ids_logprob: _Optional[_Iterable[int]] = ..., return_hidden_states: bool = ..., disaggregated_params: _Optional[_Union[DisaggregatedParams, _Mapping]] = ..., custom_logit_processor: _Optional[str] = ..., timestamp: _Optional[_Union[datetime.datetime, _timestamp_pb2.Timestamp, _Mapping]] = ..., log_metrics: bool = ..., input_embeds: _Optional[_Iterable[float]] = ..., lora_id: _Optional[str] = ..., data_parallel_rank: _Optional[int] = ..., stream: bool = ...) -> None: ...

 class TokenizedInput(_message.Message):
     __slots__ = ("original_text", "input_ids")
@@ -161,52 +160,50 @@ class GenerateResponse(_message.Message):
     def __init__(self, request_id: _Optional[str] = ..., chunk: _Optional[_Union[GenerateStreamChunk, _Mapping]] = ..., complete: _Optional[_Union[GenerateComplete, _Mapping]] = ..., error: _Optional[_Union[GenerateError, _Mapping]] = ...) -> None: ...

 class GenerateStreamChunk(_message.Message):
-    __slots__ = ("token_id", "text", "prompt_tokens", "completion_tokens", "cached_tokens", "logprobs", "hidden_states", "generation_time", "queue_time")
-    TOKEN_ID_FIELD_NUMBER: _ClassVar[int]
-    TEXT_FIELD_NUMBER: _ClassVar[int]
+    __slots__ = ("token_ids", "prompt_tokens", "completion_tokens", "cached_tokens", "output_logprobs", "hidden_states", "input_logprobs", "index")
+    TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
     PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
     COMPLETION_TOKENS_FIELD_NUMBER: _ClassVar[int]
     CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int]
-    LOGPROBS_FIELD_NUMBER: _ClassVar[int]
+    OUTPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
     HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int]
-    GENERATION_TIME_FIELD_NUMBER: _ClassVar[int]
-    QUEUE_TIME_FIELD_NUMBER: _ClassVar[int]
-    token_id: int
-    text: str
+    INPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
+    INDEX_FIELD_NUMBER: _ClassVar[int]
+    token_ids: _containers.RepeatedScalarFieldContainer[int]
     prompt_tokens: int
     completion_tokens: int
     cached_tokens: int
-    logprobs: LogProbs
+    output_logprobs: OutputLogProbs
     hidden_states: _containers.RepeatedScalarFieldContainer[float]
-    generation_time: float
-    queue_time: int
-    def __init__(self, token_id: _Optional[int] = ..., text: _Optional[str] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., logprobs: _Optional[_Union[LogProbs, _Mapping]] = ..., hidden_states: _Optional[_Iterable[float]] = ..., generation_time: _Optional[float] = ..., queue_time: _Optional[int] = ...) -> None: ...
+    input_logprobs: InputLogProbs
+    index: int
+    def __init__(self, token_ids: _Optional[_Iterable[int]] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., output_logprobs: _Optional[_Union[OutputLogProbs, _Mapping]] = ..., hidden_states: _Optional[_Iterable[float]] = ..., input_logprobs: _Optional[_Union[InputLogProbs, _Mapping]] = ..., index: _Optional[int] = ...) -> None: ...

 class GenerateComplete(_message.Message):
-    __slots__ = ("output_ids", "output_text", "finish_reason", "all_logprobs", "all_hidden_states")
-    class FinishReason(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
-        __slots__ = ()
-        STOP: _ClassVar[GenerateComplete.FinishReason]
-        LENGTH: _ClassVar[GenerateComplete.FinishReason]
-        EOS_TOKEN: _ClassVar[GenerateComplete.FinishReason]
-        STOP_STR: _ClassVar[GenerateComplete.FinishReason]
-        ABORT: _ClassVar[GenerateComplete.FinishReason]
-    STOP: GenerateComplete.FinishReason
-    LENGTH: GenerateComplete.FinishReason
-    EOS_TOKEN: GenerateComplete.FinishReason
-    STOP_STR: GenerateComplete.FinishReason
-    ABORT: GenerateComplete.FinishReason
+    __slots__ = ("output_ids", "finish_reason", "prompt_tokens", "completion_tokens", "cached_tokens", "output_logprobs", "all_hidden_states", "matched_token_id", "matched_stop_str", "input_logprobs", "index")
     OUTPUT_IDS_FIELD_NUMBER: _ClassVar[int]
-    OUTPUT_TEXT_FIELD_NUMBER: _ClassVar[int]
     FINISH_REASON_FIELD_NUMBER: _ClassVar[int]
-    ALL_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
+    PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
+    COMPLETION_TOKENS_FIELD_NUMBER: _ClassVar[int]
+    CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int]
+    OUTPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
     ALL_HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int]
+    MATCHED_TOKEN_ID_FIELD_NUMBER: _ClassVar[int]
+    MATCHED_STOP_STR_FIELD_NUMBER: _ClassVar[int]
+    INPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
+    INDEX_FIELD_NUMBER: _ClassVar[int]
     output_ids: _containers.RepeatedScalarFieldContainer[int]
-    output_text: str
-    finish_reason: GenerateComplete.FinishReason
-    all_logprobs: _containers.RepeatedCompositeFieldContainer[LogProbs]
+    finish_reason: str
+    prompt_tokens: int
+    completion_tokens: int
+    cached_tokens: int
+    output_logprobs: OutputLogProbs
     all_hidden_states: _containers.RepeatedCompositeFieldContainer[HiddenStates]
-    def __init__(self, output_ids: _Optional[_Iterable[int]] = ..., output_text: _Optional[str] = ..., finish_reason: _Optional[_Union[GenerateComplete.FinishReason, str]] = ..., all_logprobs: _Optional[_Iterable[_Union[LogProbs, _Mapping]]] = ..., all_hidden_states: _Optional[_Iterable[_Union[HiddenStates, _Mapping]]] = ...) -> None: ...
+    matched_token_id: int
+    matched_stop_str: str
+    input_logprobs: InputLogProbs
+    index: int
+    def __init__(self, output_ids: _Optional[_Iterable[int]] = ..., finish_reason: _Optional[str] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., output_logprobs: _Optional[_Union[OutputLogProbs, _Mapping]] = ..., all_hidden_states: _Optional[_Iterable[_Union[HiddenStates, _Mapping]]] = ..., matched_token_id: _Optional[int] = ..., matched_stop_str: _Optional[str] = ..., input_logprobs: _Optional[_Union[InputLogProbs, _Mapping]] = ..., index: _Optional[int] = ...) -> None: ...

 class GenerateError(_message.Message):
     __slots__ = ("message", "http_status_code", "details")
@@ -218,27 +215,39 @@ class GenerateError(_message.Message):
     details: str
     def __init__(self, message: _Optional[str] = ..., http_status_code: _Optional[str] = ..., details: _Optional[str] = ...) -> None: ...

-class LogProbs(_message.Message):
-    __slots__ = ("token_logprobs", "token_ids", "top_logprobs", "token_texts")
+class OutputLogProbs(_message.Message):
+    __slots__ = ("token_logprobs", "token_ids", "top_logprobs")
     TOKEN_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
     TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
     TOP_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
-    TOKEN_TEXTS_FIELD_NUMBER: _ClassVar[int]
     token_logprobs: _containers.RepeatedScalarFieldContainer[float]
     token_ids: _containers.RepeatedScalarFieldContainer[int]
     top_logprobs: _containers.RepeatedCompositeFieldContainer[TopLogProbs]
-    token_texts: _containers.RepeatedScalarFieldContainer[str]
-    def __init__(self, token_logprobs: _Optional[_Iterable[float]] = ..., token_ids: _Optional[_Iterable[int]] = ..., top_logprobs: _Optional[_Iterable[_Union[TopLogProbs, _Mapping]]] = ..., token_texts: _Optional[_Iterable[str]] = ...) -> None: ...
+    def __init__(self, token_logprobs: _Optional[_Iterable[float]] = ..., token_ids: _Optional[_Iterable[int]] = ..., top_logprobs: _Optional[_Iterable[_Union[TopLogProbs, _Mapping]]] = ...) -> None: ...
+
+class InputLogProbs(_message.Message):
+    __slots__ = ("token_logprobs", "token_ids", "top_logprobs")
+    TOKEN_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
+    TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
+    TOP_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
+    token_logprobs: _containers.RepeatedCompositeFieldContainer[InputTokenLogProb]
+    token_ids: _containers.RepeatedScalarFieldContainer[int]
+    top_logprobs: _containers.RepeatedCompositeFieldContainer[TopLogProbs]
+    def __init__(self, token_logprobs: _Optional[_Iterable[_Union[InputTokenLogProb, _Mapping]]] = ..., token_ids: _Optional[_Iterable[int]] = ..., top_logprobs: _Optional[_Iterable[_Union[TopLogProbs, _Mapping]]] = ...) -> None: ...
+
+class InputTokenLogProb(_message.Message):
+    __slots__ = ("value",)
+    VALUE_FIELD_NUMBER: _ClassVar[int]
+    value: float
+    def __init__(self, value: _Optional[float] = ...) -> None: ...

 class TopLogProbs(_message.Message):
-    __slots__ = ("values", "token_ids", "token_texts")
+    __slots__ = ("values", "token_ids")
     VALUES_FIELD_NUMBER: _ClassVar[int]
     TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
-    TOKEN_TEXTS_FIELD_NUMBER: _ClassVar[int]
     values: _containers.RepeatedScalarFieldContainer[float]
     token_ids: _containers.RepeatedScalarFieldContainer[int]
-    token_texts: _containers.RepeatedScalarFieldContainer[str]
-    def __init__(self, values: _Optional[_Iterable[float]] = ..., token_ids: _Optional[_Iterable[int]] = ..., token_texts: _Optional[_Iterable[str]] = ...) -> None: ...
+    def __init__(self, values: _Optional[_Iterable[float]] = ..., token_ids: _Optional[_Iterable[int]] = ...) -> None: ...

 class HiddenStates(_message.Message):
     __slots__ = ("values", "layer", "position")
@@ -283,20 +292,18 @@ class EmbedResponse(_message.Message):
     def __init__(self, request_id: _Optional[str] = ..., complete: _Optional[_Union[EmbedComplete, _Mapping]] = ..., error: _Optional[_Union[EmbedError, _Mapping]] = ...) -> None: ...

 class EmbedComplete(_message.Message):
-    __slots__ = ("embedding", "prompt_tokens", "cached_tokens", "embedding_dim", "generation_time", "batch_embeddings")
+    __slots__ = ("embedding", "prompt_tokens", "cached_tokens", "embedding_dim", "batch_embeddings")
     EMBEDDING_FIELD_NUMBER: _ClassVar[int]
     PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
     CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int]
     EMBEDDING_DIM_FIELD_NUMBER: _ClassVar[int]
-    GENERATION_TIME_FIELD_NUMBER: _ClassVar[int]
     BATCH_EMBEDDINGS_FIELD_NUMBER: _ClassVar[int]
     embedding: _containers.RepeatedScalarFieldContainer[float]
     prompt_tokens: int
     cached_tokens: int
     embedding_dim: int
-    generation_time: float
    batch_embeddings: _containers.RepeatedCompositeFieldContainer[Embedding]
-    def __init__(self, embedding: _Optional[_Iterable[float]] = ..., prompt_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., embedding_dim: _Optional[int] = ..., generation_time: _Optional[float] = ..., batch_embeddings: _Optional[_Iterable[_Union[Embedding, _Mapping]]] = ...) -> None: ...
+    def __init__(self, embedding: _Optional[_Iterable[float]] = ..., prompt_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., embedding_dim: _Optional[int] = ..., batch_embeddings: _Optional[_Iterable[_Union[Embedding, _Mapping]]] = ...) -> None: ...

 class Embedding(_message.Message):
     __slots__ = ("values", "index")

sglang/srt/grpc/sglang_scheduler_pb2_grpc.py
@@ -1,3 +1,6 @@
+# This file is auto-generated. Do not edit manually.
+# Regenerate with: python compile_proto.py
+
 # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
 """Client and server classes corresponding to protobuf-defined services."""
 import grpc
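
Per the new header, these generated modules are meant to be rebuilt from the .proto via compile_proto.py rather than edited by hand. A small smoke-test sketch (module paths taken from the file list above) that the regenerated stubs import cleanly and expose client stub classes:

```python
# Sanity check after running `python compile_proto.py` (the regeneration
# command named in the generated header above).
from sglang.srt.grpc import sglang_scheduler_pb2 as pb2
from sglang.srt.grpc import sglang_scheduler_pb2_grpc as pb2_grpc

print(pb2.DESCRIPTOR.name)  # the compiled .proto file name
print([n for n in dir(pb2_grpc) if n.endswith("Stub")])  # generated client stubs
```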

sglang/srt/layers/activation.py
@@ -224,12 +224,13 @@ class XIELU(CustomOp):
                 self._xielu_cuda_fn = self._xielu_cuda
                 logger.warning_once(msg)
         except Exception as err:
-            logger.warning_once(
-                "CUDA-fused xIELU not available (%s) –"
-                " falling back to a Python version.\n"
-                "For CUDA xIELU (experimental), `pip install git+https://github.com/nickjbrowning/XIELU`",
-                str(err),
-            )
+            pass
+            # logger.warning_once(
+            #     "CUDA-fused xIELU not available (%s) "
+            #     " falling back to a Python version.\n"
+            #     "For CUDA xIELU (experimental), `pip install git+https://github.com/nickjbrowning/XIELU`",
+            #     str(err),
+            # )

     def _xielu_python(self, x: torch.Tensor) -> torch.Tensor:
         alpha_p = nn.functional.softplus(self.alpha_p)

sglang/srt/layers/attention/aiter_backend.py
@@ -4,18 +4,13 @@ from __future__ import annotations
 end to end attention solution with aiter kernels
 """

-import math
-import os
 from dataclasses import dataclass
 from enum import Enum, auto
-from functools import partial
-from typing import TYPE_CHECKING, List, Optional, Union
+from typing import TYPE_CHECKING, Optional

 import torch
 import triton
-import triton.language as tl

-from sglang.global_config import global_config
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import (
@@ -27,7 +22,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMo
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
-    from sglang.srt.speculative.spec_info import SpecInfo
+    from sglang.srt.speculative.spec_info import SpecInput

 try:
     from aiter import (
@@ -374,7 +369,7 @@ class AiterAttnBackend(AttentionBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[SpecInput],
     ):
         if forward_mode.is_decode_or_idle():
             qo_indptr = None
@@ -509,7 +504,7 @@
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         if forward_mode.is_decode_or_idle():
@@ -619,7 +614,11 @@
         assert len(k.shape) == 3
         assert len(v.shape) == 3

-        if forward_batch.forward_mode.is_extend():
+        if (
+            forward_batch.forward_mode.is_extend()
+            and not forward_batch.forward_mode.is_target_verify()
+            and not forward_batch.forward_mode.is_draft_extend()
+        ):
             if kv_indices.shape[0] == 0:
                 o = flash_attn_varlen_func(
                     q,
@@ -884,7 +883,7 @@ class AiterIndicesUpdaterPrefill:
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[SpecInput],
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
         raise NotImplementedError()
@@ -896,7 +895,7 @@
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[SpecInput],
     ):

         kv_start_idx = None
@@ -980,7 +979,7 @@ class AiterMlaIndicesUpdaterPrefill:
         extend_lens: torch.Tensor,
         max_q_len: int,
         max_kv_len: int,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[SpecInput],
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
         raise NotImplementedError()
@@ -993,7 +992,7 @@
         extend_lens: torch.Tensor,
         max_q_len: int,
         max_kv_len: int,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[SpecInput],
     ):
         bs = len(req_pool_indices)

@@ -1050,7 +1049,7 @@ class AiterMultiStepDraftBackend:
         topk: int,
         speculative_num_steps: int,
     ):
-        from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
+        from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices

         self.topk = topk
         self.speculative_num_steps = speculative_num_steps

sglang/srt/layers/attention/ascend_backend.py
@@ -5,14 +5,15 @@ from typing import TYPE_CHECKING, List, Optional

 import torch
 import torch_npu
-from torch.nn.functional import scaled_dot_product_attention

 from sglang.srt.configs.model_config import AttentionArch
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.layers.attention.npu_ops.mla_preprocess import is_mla_preprocess_enabled
 from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.radix_attention import AttentionType
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.speculative.spec_info import SpecInput
 from sglang.srt.utils import get_bool_env_var

 if TYPE_CHECKING:
@@ -35,6 +36,8 @@ class ForwardMetadata:
     seq_lens_cpu_int: Optional[torch.Tensor] = None
     seq_lens_cpu_list: Optional[List[int]] = None
     seq_lens_list_cumsum: Optional[List[int]] = None
+    seq_lens: Optional[torch.Tensor] = None
+    actual_seq_lengths_q: Optional[torch.Tensor] = None


 class AscendAttnBackend(AttentionBackend):
@@ -66,6 +69,9 @@ class AscendAttnBackend(AttentionBackend):
         if self.use_mla:
             self.kv_lora_rank = model_runner.model_config.kv_lora_rank
             self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
+            self.q_head_dim = (
+                self.qk_rope_head_dim + model_runner.model_config.qk_nope_head_dim
+            )
             self.native_attn = TorchNativeAttnBackend(model_runner)
         self.graph_metadata = {}
         self.max_context_len = model_runner.model_config.context_len
@@ -101,10 +107,6 @@
             self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()

         seq_lens_list_cumsum = np.cumsum(forward_batch.extend_seq_lens_cpu)
-        if forward_batch.is_extend_in_batch:
-            seq_lens_list_cumsum[-1] = (
-                (seq_lens_list_cumsum[-1] - 1) // tp_size + 1
-            ) * tp_size
         self.forward_metadata.seq_lens_list_cumsum = seq_lens_list_cumsum

         self.graph_mode = False
@@ -126,12 +128,16 @@
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
     ):
         metadata = ForwardMetadata()

         metadata.block_tables = self.graph_metadata["block_tables"][:bs, :]
         metadata.seq_lens_cpu_list = seq_lens.cpu().int().tolist()
+        metadata.seq_lens = seq_lens
+        metadata.actual_seq_lengths_q = torch.tensor(
+            [1 + i * 1 for i in range(bs)], dtype=torch.int32, device=seq_lens.device
+        )

         self.graph_metadata[bs] = metadata
         self.forward_metadata = metadata
@@ -146,7 +152,7 @@
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         metadata = self.graph_metadata[bs]
@@ -160,6 +166,8 @@
         metadata.block_tables[:bs, max_seq_pages:].fill_(0)
         metadata.block_tables[bs:, :].fill_(0)

+        metadata.seq_lens[:bs].copy_(seq_lens[:bs])
+
         self.forward_metadata = metadata

         self.graph_mode = True
@@ -167,6 +175,64 @@
     def get_cuda_graph_seq_len_fill_value(self):
         return 0

+    def forward_sparse(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
+        # For multi_head latent attention
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
+        topk_indices: torch.Tensor = None,
+    ):
+
+        is_prefill = forward_batch.forward_mode.is_extend()
+
+        if save_kv_cache:
+            k = k.view(-1, layer.tp_k_head_num, self.kv_lora_rank)
+            k_rope = k_rope.view(-1, layer.tp_k_head_num, self.qk_rope_head_dim)
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, k_rope
+            )
+        q_nope, q_pe = q, q_rope
+        k_nope, k_pe = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
+        block_table = self.forward_metadata.block_tables
+        if is_prefill:
+            actual_seq_qlen = torch.cumsum(forward_batch.seq_lens, dim=0)
+        else:
+            if self.forward_metadata.actual_seq_lengths_q is None:
+                actual_seq_qlen = (
+                    torch.arange(1, q.shape[0] + 1).to(q.device).to(torch.int32)
+                )
+            else:
+                actual_seq_qlen = self.forward_metadata.actual_seq_lengths_q
+        if self.forward_metadata.seq_lens_cpu_int is None:
+            actual_seq_lengths_kv = self.forward_metadata.seq_lens
+        else:
+            actual_seq_lengths_kv = self.forward_metadata.seq_lens_cpu_int
+
+        attn_out = torch.ops.custom.npu_sparse_flash_attention(
+            query=q_nope,
+            key=k_nope,
+            value=k_nope,
+            query_rope=q_pe,
+            key_rope=k_pe,
+            sparse_indices=topk_indices,
+            scale_value=layer.scaling,
+            actual_seq_lengths_query=actual_seq_qlen.to(torch.int32),
+            actual_seq_lengths_kv=actual_seq_lengths_kv.to(q.device),
+            block_table=block_table,
+            sparse_block_size=1,
+            layout_query="TND",
+            layout_kv="PA_BSND",
+            sparse_mode=3,
+        )
+
+        return attn_out
+
     def forward_extend(
         self,
         q,
@@ -175,7 +241,23 @@
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
+        # For multi_head latent attention
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
+        topk_indices: Optional[torch.Tensor] = None,
     ):
+        if topk_indices is not None:
+            return self.forward_sparse(
+                q,
+                k,
+                v,
+                layer,
+                forward_batch,
+                save_kv_cache,
+                q_rope,
+                k_rope,
+                topk_indices,
+            )
         if not self.use_mla:
             if save_kv_cache:
                 forward_batch.token_to_kv_pool.set_kv_buffer(
@@ -401,7 +483,7 @@
             antiquant_scale=None,
             sparse_mode=0,
         )
-        output = torch.zeros_like(q_nope, dtype=q.dtype, device=q.device)
+        output = torch.empty_like(q_nope, dtype=q.dtype, device=q.device)
         softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device)

         torch_npu.npu_fused_infer_attention_score.out(
@@ -436,7 +518,24 @@
         # For multi-head latent attention
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
+        topk_indices: Optional[torch.Tensor] = None,
     ):
+        if is_mla_preprocess_enabled():
+            # MLAPO does saving kv_cache
+            save_kv_cache = False
+        if topk_indices is not None:
+            return self.forward_sparse(
+                q,
+                k,
+                v,
+                layer,
+                forward_batch,
+                save_kv_cache,
+                q_rope,
+                k_rope,
+                topk_indices,
+            )
+
         if self.graph_mode:
             return self.forward_decode_graph(
                 q,