sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.
Files changed (282)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +321 -31
  3. sglang/bench_serving.py +10 -3
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +4 -0
  11. sglang/srt/configs/dots_ocr.py +64 -0
  12. sglang/srt/configs/falcon_h1.py +360 -0
  13. sglang/srt/configs/load_config.py +8 -0
  14. sglang/srt/configs/model_config.py +160 -105
  15. sglang/srt/configs/qwen3_vl.py +586 -0
  16. sglang/srt/constrained/base_grammar_backend.py +1 -0
  17. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  18. sglang/srt/constrained/xgrammar_backend.py +6 -4
  19. sglang/srt/debug_utils/dumper.py +10 -3
  20. sglang/srt/disaggregation/ascend/conn.py +2 -2
  21. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  22. sglang/srt/disaggregation/common/conn.py +266 -98
  23. sglang/srt/disaggregation/decode.py +50 -9
  24. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  25. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  26. sglang/srt/disaggregation/mooncake/conn.py +51 -541
  27. sglang/srt/disaggregation/nixl/conn.py +148 -39
  28. sglang/srt/disaggregation/prefill.py +31 -14
  29. sglang/srt/disaggregation/utils.py +36 -5
  30. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  31. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  32. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  33. sglang/srt/distributed/parallel_state.py +135 -80
  34. sglang/srt/entrypoints/engine.py +23 -3
  35. sglang/srt/entrypoints/grpc_request_manager.py +330 -55
  36. sglang/srt/entrypoints/grpc_server.py +232 -102
  37. sglang/srt/entrypoints/http_server.py +49 -9
  38. sglang/srt/entrypoints/openai/protocol.py +110 -5
  39. sglang/srt/entrypoints/openai/serving_base.py +25 -6
  40. sglang/srt/entrypoints/openai/serving_chat.py +178 -49
  41. sglang/srt/entrypoints/openai/serving_completions.py +5 -3
  42. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  43. sglang/srt/entrypoints/openai/serving_responses.py +42 -0
  44. sglang/srt/environ.py +285 -0
  45. sglang/srt/eplb/expert_location.py +30 -5
  46. sglang/srt/function_call/function_call_parser.py +3 -2
  47. sglang/srt/function_call/glm4_moe_detector.py +3 -3
  48. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  49. sglang/srt/function_call/json_array_parser.py +63 -0
  50. sglang/srt/function_call/kimik2_detector.py +17 -4
  51. sglang/srt/function_call/utils.py +96 -5
  52. sglang/srt/grpc/compile_proto.py +245 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
  56. sglang/srt/layers/activation.py +7 -6
  57. sglang/srt/layers/attention/aiter_backend.py +14 -15
  58. sglang/srt/layers/attention/ascend_backend.py +108 -9
  59. sglang/srt/layers/attention/attention_registry.py +206 -0
  60. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  61. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  62. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  63. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  66. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  67. sglang/srt/layers/attention/flashinfer_backend.py +112 -194
  68. sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
  69. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  70. sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
  71. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
  72. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
  73. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
  74. sglang/srt/layers/attention/mamba/mamba.py +566 -1
  75. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  76. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  77. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  78. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  79. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  80. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  81. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  82. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  83. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  84. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  85. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  86. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  87. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  88. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  89. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  90. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  91. sglang/srt/layers/attention/nsa/utils.py +24 -0
  92. sglang/srt/layers/attention/nsa_backend.py +887 -0
  93. sglang/srt/layers/attention/tbo_backend.py +6 -6
  94. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  95. sglang/srt/layers/attention/triton_backend.py +42 -9
  96. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  97. sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
  98. sglang/srt/layers/attention/vision.py +58 -0
  99. sglang/srt/layers/attention/wave_backend.py +4 -4
  100. sglang/srt/layers/communicator.py +8 -0
  101. sglang/srt/layers/dp_attention.py +11 -1
  102. sglang/srt/layers/elementwise.py +3 -1
  103. sglang/srt/layers/layernorm.py +2 -0
  104. sglang/srt/layers/linear.py +21 -4
  105. sglang/srt/layers/logits_processor.py +15 -2
  106. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  107. sglang/srt/layers/moe/ep_moe/layer.py +147 -74
  108. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  109. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  110. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  111. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  112. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
  113. sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
  114. sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
  115. sglang/srt/layers/moe/utils.py +10 -0
  116. sglang/srt/layers/parameter.py +23 -6
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  119. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  121. sglang/srt/layers/quantization/fp8.py +2 -2
  122. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  123. sglang/srt/layers/quantization/modelopt_quant.py +44 -9
  124. sglang/srt/layers/quantization/mxfp4.py +12 -4
  125. sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
  126. sglang/srt/layers/quantization/w4afp8.py +0 -4
  127. sglang/srt/layers/quantization/w8a8_int8.py +15 -3
  128. sglang/srt/layers/rotary_embedding.py +78 -31
  129. sglang/srt/layers/sampler.py +52 -4
  130. sglang/srt/layers/utils.py +23 -0
  131. sglang/srt/lora/backend/base_backend.py +3 -3
  132. sglang/srt/lora/backend/chunked_backend.py +348 -0
  133. sglang/srt/lora/backend/triton_backend.py +10 -4
  134. sglang/srt/lora/lora.py +7 -5
  135. sglang/srt/lora/lora_manager.py +17 -6
  136. sglang/srt/lora/mem_pool.py +1 -1
  137. sglang/srt/lora/triton_ops/__init__.py +4 -0
  138. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  139. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  140. sglang/srt/lora/utils.py +7 -5
  141. sglang/srt/managers/cache_controller.py +42 -142
  142. sglang/srt/managers/data_parallel_controller.py +11 -46
  143. sglang/srt/managers/detokenizer_manager.py +11 -11
  144. sglang/srt/managers/io_struct.py +162 -118
  145. sglang/srt/managers/mm_utils.py +43 -6
  146. sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
  147. sglang/srt/managers/multimodal_processor.py +1 -2
  148. sglang/srt/managers/overlap_utils.py +53 -0
  149. sglang/srt/managers/schedule_batch.py +167 -86
  150. sglang/srt/managers/schedule_policy.py +143 -16
  151. sglang/srt/managers/scheduler.py +359 -214
  152. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  153. sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
  154. sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
  155. sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
  156. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  157. sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
  158. sglang/srt/managers/tokenizer_manager.py +84 -136
  159. sglang/srt/managers/tp_worker.py +39 -29
  160. sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
  161. sglang/srt/managers/utils.py +1 -45
  162. sglang/srt/mem_cache/allocator.py +14 -20
  163. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  164. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  165. sglang/srt/mem_cache/chunk_cache.py +8 -1
  166. sglang/srt/mem_cache/evict_policy.py +23 -0
  167. sglang/srt/mem_cache/hicache_storage.py +40 -1
  168. sglang/srt/mem_cache/hiradix_cache.py +119 -32
  169. sglang/srt/mem_cache/memory_pool.py +188 -10
  170. sglang/srt/mem_cache/memory_pool_host.py +134 -182
  171. sglang/srt/mem_cache/radix_cache.py +222 -71
  172. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  173. sglang/srt/mem_cache/storage/__init__.py +10 -0
  174. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  175. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  176. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  177. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  178. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  179. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
  180. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
  181. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
  182. sglang/srt/mem_cache/swa_radix_cache.py +25 -34
  183. sglang/srt/metrics/collector.py +82 -120
  184. sglang/srt/metrics/func_timer.py +2 -7
  185. sglang/srt/metrics/utils.py +8 -1
  186. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  187. sglang/srt/model_executor/cuda_graph_runner.py +39 -32
  188. sglang/srt/model_executor/forward_batch_info.py +23 -38
  189. sglang/srt/model_executor/model_runner.py +131 -183
  190. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  191. sglang/srt/model_loader/loader.py +14 -10
  192. sglang/srt/model_loader/weight_utils.py +156 -2
  193. sglang/srt/models/bailing_moe.py +27 -4
  194. sglang/srt/models/deepseek_nextn.py +6 -1
  195. sglang/srt/models/deepseek_v2.py +536 -153
  196. sglang/srt/models/dots_ocr.py +173 -0
  197. sglang/srt/models/falcon_h1.py +576 -0
  198. sglang/srt/models/gemma3_causal.py +0 -2
  199. sglang/srt/models/gemma3_mm.py +1 -1
  200. sglang/srt/models/gemma3n_mm.py +1 -1
  201. sglang/srt/models/glm4_moe.py +3 -3
  202. sglang/srt/models/glm4_moe_nextn.py +2 -2
  203. sglang/srt/models/glm4v.py +1 -1
  204. sglang/srt/models/glm4v_moe.py +1 -1
  205. sglang/srt/models/gpt_oss.py +7 -30
  206. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  207. sglang/srt/models/llama.py +4 -0
  208. sglang/srt/models/longcat_flash.py +1 -1
  209. sglang/srt/models/longcat_flash_nextn.py +1 -1
  210. sglang/srt/models/mllama4.py +15 -4
  211. sglang/srt/models/qwen2.py +0 -7
  212. sglang/srt/models/qwen2_5_vl.py +2 -2
  213. sglang/srt/models/qwen2_audio.py +1 -1
  214. sglang/srt/models/qwen2_moe.py +64 -1
  215. sglang/srt/models/qwen2_vl.py +1 -1
  216. sglang/srt/models/qwen3.py +18 -3
  217. sglang/srt/models/qwen3_moe.py +31 -3
  218. sglang/srt/models/qwen3_next.py +36 -9
  219. sglang/srt/models/qwen3_vl.py +787 -0
  220. sglang/srt/models/qwen3_vl_moe.py +471 -0
  221. sglang/srt/models/registry.py +15 -3
  222. sglang/srt/models/sarashina2_vision.py +269 -0
  223. sglang/srt/models/solar.py +505 -0
  224. sglang/srt/models/starcoder2.py +357 -0
  225. sglang/srt/models/torch_native_llama.py +9 -2
  226. sglang/srt/models/utils.py +51 -0
  227. sglang/srt/multimodal/processors/base_processor.py +15 -7
  228. sglang/srt/multimodal/processors/dots_vlm.py +2 -3
  229. sglang/srt/multimodal/processors/internvl.py +20 -8
  230. sglang/srt/multimodal/processors/qwen_vl.py +8 -1
  231. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  232. sglang/srt/parser/jinja_template_utils.py +6 -0
  233. sglang/srt/sampling/sampling_batch_info.py +20 -2
  234. sglang/srt/sampling/sampling_params.py +7 -0
  235. sglang/srt/server_args.py +753 -295
  236. sglang/srt/server_args_config_parser.py +146 -0
  237. sglang/srt/single_batch_overlap.py +151 -0
  238. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  239. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  240. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  241. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  242. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  243. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  244. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
  245. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
  246. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
  247. sglang/srt/speculative/eagle_worker.py +57 -25
  248. sglang/srt/speculative/ngram_utils.py +428 -0
  249. sglang/srt/speculative/ngram_worker.py +245 -0
  250. sglang/srt/speculative/spec_info.py +47 -0
  251. sglang/srt/speculative/spec_utils.py +606 -0
  252. sglang/srt/torch_memory_saver_adapter.py +5 -7
  253. sglang/srt/tracing/trace.py +32 -6
  254. sglang/srt/two_batch_overlap.py +8 -5
  255. sglang/srt/utils/__init__.py +2 -0
  256. sglang/srt/{utils.py → utils/common.py} +399 -74
  257. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
  258. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  259. sglang/srt/utils/rpd_utils.py +452 -0
  260. sglang/srt/utils/slow_rank_detector.py +71 -0
  261. sglang/srt/warmup.py +8 -4
  262. sglang/srt/weight_sync/utils.py +1 -1
  263. sglang/test/get_logits_ut.py +57 -0
  264. sglang/test/run_eval.py +79 -11
  265. sglang/test/runners.py +1 -1
  266. sglang/test/simple_eval_common.py +5 -2
  267. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  268. sglang/test/test_block_fp8.py +2 -2
  269. sglang/test/test_deterministic.py +297 -0
  270. sglang/test/test_disaggregation_utils.py +12 -1
  271. sglang/test/test_programs.py +1 -1
  272. sglang/test/test_utils.py +355 -4
  273. sglang/utils.py +10 -1
  274. sglang/version.py +1 -1
  275. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
  276. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
  277. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  278. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  279. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  280. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  281. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  282. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/rotary_embedding.py

@@ -12,6 +12,7 @@ from sglang.srt.custom_op import CustomOp
 from sglang.srt.utils import (
     cpu_has_amx_support,
     get_bool_env_var,
+    get_compiler_backend,
     is_cpu,
     is_cuda,
     is_hip,
@@ -26,13 +27,19 @@ _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
 
 if _is_cuda:
-    from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
+    from sgl_kernel import FusedSetKVBufferArg, apply_rope_with_cos_sin_cache_inplace
+else:
+    FusedSetKVBufferArg = None
+
 if _use_aiter:
     from aiter.rotary_embedding import get_rope as aiter_get_rope
 
 if is_npu():
     import torch_npu
 
+NPU_ROTARY_MUL_MAX_NUM_HEADS = 1000
+NPU_ROTARY_MUL_MAX_HEAD_SIZE = 896
+
 
 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
     x1 = x[..., : x.shape[-1] // 2]
@@ -142,8 +149,13 @@ class RotaryEmbedding(CustomOp):
         query: torch.Tensor,
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
+        fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """A PyTorch-native implementation of forward()."""
+        assert (
+            fused_set_kv_buffer_arg is None
+        ), "fused_set_kv_buffer_arg is not supported for native implementation"
+
         if offsets is not None:
             positions = positions + offsets
         positions = positions.flatten()
@@ -172,12 +184,17 @@ class RotaryEmbedding(CustomOp):
         query: torch.Tensor,
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
+        fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """A PyTorch-npu implementation of forward()."""
-        import os
+        assert (
+            fused_set_kv_buffer_arg is None
+        ), "fused_set_kv_buffer_arg is not supported for npu implementation"
 
         if get_bool_env_var("SGLANG_ENABLE_TORCH_COMPILE"):
-            return self.forward_native(positions, query, key, offsets)
+            return self.forward_native(
+                positions, query, key, offsets, fused_set_kv_buffer_arg
+            )
         else:
             rotary_mode = "half"
             if self.is_neox_style:
@@ -202,7 +219,12 @@ class RotaryEmbedding(CustomOp):
         query: torch.Tensor,
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
+        fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
+        assert (
+            fused_set_kv_buffer_arg is None
+        ), "fused_set_kv_buffer_arg is not supported for cpu implementation"
+
         positions = torch.add(positions, offsets) if offsets is not None else positions
         if _is_cpu_amx_available:
             return torch.ops.sgl_kernel.rotary_embedding_cpu(
@@ -214,7 +236,9 @@ class RotaryEmbedding(CustomOp):
                 self.is_neox_style,
             )
         else:
-            return self.forward_native(positions, query, key, offsets)
+            return self.forward_native(
+                positions, query, key, offsets, fused_set_kv_buffer_arg
+            )
 
     def forward_cuda(
         self,
@@ -222,7 +246,7 @@ class RotaryEmbedding(CustomOp):
         query: torch.Tensor,
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
-        fused_set_kv_buffer_arg=None,  # Optional[FusedSetKVBufferArg]
+        fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         if _is_cuda and (self.head_size in [64, 128, 256, 512]):
             apply_rope_with_cos_sin_cache_inplace(
@@ -782,27 +806,33 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        # NOTE: now npu_mrope can only support `numQHeads*headSize <= 4096` pattern,
-        # and generalization to more scenarios will be supported in the future.
-        if query.shape[1] * query.shape[2] > 4096:
-            return self.forward_native(positions, query, key, offsets)
-        num_tokens = query.shape[0]
-        rotary_mode = "half" if self.is_neox_style else "interleave"
+        num_tokens, num_q_heads, _ = query.shape
+        num_k_heads = key.shape[1]
+
         self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(positions.device)
+        cos_sin = self.cos_sin_cache[
+            torch.add(positions, offsets) if offsets is not None else positions
+        ]
+        cos, sin = cos_sin.chunk(2, dim=-1)
+        # Reshape to [batchsize, head_dim, seq, rotary_dim]
+        cos = cos.repeat(1, 2).unsqueeze(-2).unsqueeze(-2)
+        sin = sin.repeat(1, 2).unsqueeze(-2).unsqueeze(-2)
+
         query_rot = query[..., : self.rotary_dim]
         key_rot = key[..., : self.rotary_dim]
         if self.rotary_dim < self.head_size:
             query_pass = query[..., self.rotary_dim :]
             key_pass = key[..., self.rotary_dim :]
 
-        query_rot, key_rot = torch_npu.npu_mrope(
-            torch.add(positions, offsets) if offsets is not None else positions,
-            query_rot.reshape(num_tokens, -1),
-            key_rot.reshape(num_tokens, -1),
-            self.cos_sin_cache,
-            self.rotary_dim,
-            mrope_section=[0, 0, 0],
-            rotary_mode=rotary_mode,
+        query_rot = torch_npu.npu_interleave_rope(
+            query_rot.reshape(num_tokens, num_q_heads, 1, self.rotary_dim),
+            cos,
+            sin,
+        )
+        key_rot = torch_npu.npu_interleave_rope(
+            key_rot.reshape(num_tokens, num_k_heads, 1, self.rotary_dim),
+            cos,
+            sin,
         )
         query_rot = query_rot.reshape(num_tokens, -1, self.rotary_dim)
         key_rot = key_rot.reshape(num_tokens, -1, self.rotary_dim)
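
Note on the reshape dance above: each `cos_sin_cache` row holds the cos and sin halves side by side, and the added code expands them so they broadcast against the 4-D tensors `npu_interleave_rope` consumes. A standalone shape trace (illustrative only, with made-up sizes; this is not the sglang code path):

import torch

# Hypothetical sizes for illustration.
num_tokens, num_q_heads, rotary_dim = 4, 8, 64

cos_sin = torch.randn(num_tokens, rotary_dim)       # cache row = [cos | sin]
cos, sin = cos_sin.chunk(2, dim=-1)                 # each: [num_tokens, rotary_dim // 2]
cos = cos.repeat(1, 2).unsqueeze(-2).unsqueeze(-2)  # [num_tokens, 1, 1, rotary_dim]
sin = sin.repeat(1, 2).unsqueeze(-2).unsqueeze(-2)

query_rot = torch.randn(num_tokens, num_q_heads, rotary_dim)
query_rot = query_rot.reshape(num_tokens, num_q_heads, 1, rotary_dim)

# cos/sin broadcast over the head dim and the singleton seq dim.
assert torch.broadcast_shapes(cos.shape, query_rot.shape) == query_rot.shape
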
@@ -1029,12 +1059,13 @@ class MRotaryEmbedding(RotaryEmbedding):
                 f"Corrected mrope_section: {self.mrope_section} (sum={sum(self.mrope_section)})"
             )
 
-    @torch.compile(dynamic=True)
+    @torch.compile(dynamic=True, backend=get_compiler_backend())
     def forward(
         self,
         positions: torch.Tensor,
         query: torch.Tensor,
         key: torch.Tensor,
+        fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """PyTorch-native implementation equivalent to forward().
 
@@ -1045,6 +1076,9 @@ class MRotaryEmbedding(RotaryEmbedding):
             query: [num_tokens, num_heads * head_size]
             key: [num_tokens, num_kv_heads * head_size]
         """
+        assert (
+            fused_set_kv_buffer_arg is None
+        ), "save kv cache is not supported for MRotaryEmbedding."
         assert positions.ndim == 1 or positions.ndim == 2
 
         num_tokens = positions.shape[-1]
@@ -1177,7 +1211,7 @@ class MRotaryEmbedding(RotaryEmbedding):
 
             time_tensor_long = time_tensor.long()
             t_index = time_tensor_long.flatten()
-        elif model_type == "qwen2_vl":
+        elif model_type in ("qwen2_vl", "qwen3_vl", "qwen3_vl_moe"):
             t_index = (
                 torch.arange(llm_grid_t)
                 .view(-1, 1)
@@ -1888,17 +1922,30 @@ def apply_rotary_pos_emb_npu(
     sin: torch.Tensor,
     unsqueeze_dim=1,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    if q.shape[1] != 128:
+    """Ascend implementation equivalent to apply_rotary_pos_emb_native.
+
+    Args:
+        q: [num_tokens, num_heads, head_size]
+        k: [num_tokens, num_kv_heads, head_size]
+        cos: [num_tokens, head_size]
+        sin: [num_tokens, head_size]
+    """
+    if (
+        cos.dim() != 2
+        or q.dim() != 3
+        or q.shape[1] >= NPU_ROTARY_MUL_MAX_NUM_HEADS
+        or q.shape[2] >= NPU_ROTARY_MUL_MAX_HEAD_SIZE
+    ):
+        # Note: num_heads and head_size of q must be less than 1000 and 896, respectively
         return apply_rotary_pos_emb_native(q, k, cos, sin, unsqueeze_dim)
-    cos = cos.unsqueeze(unsqueeze_dim)
-    cos = torch.transpose(cos, 1, 2)
-    sin = sin.unsqueeze(unsqueeze_dim)
-    sin = torch.transpose(sin, 1, 2)
-    q = torch.transpose(q, 1, 2)
-    k = torch.transpose(k, 1, 2)
-    q_embed, k_embed = torch_npu.npu_apply_rotary_pos_emb(q, k, cos, sin)
-    q_embed = torch.transpose(q_embed, 1, 2)
-    k_embed = torch.transpose(k_embed, 1, 2)
+    cos = cos.unsqueeze(unsqueeze_dim).unsqueeze(0)
+    sin = sin.unsqueeze(unsqueeze_dim).unsqueeze(0)
+    q = q.unsqueeze(0)
+    k = k.unsqueeze(0)
+    q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
+    k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
+    q_embed = q_embed.squeeze(0)
+    k_embed = k_embed.squeeze(0)
     return q_embed, k_embed
 
 
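For context, the `apply_rotary_pos_emb_native` fallback referenced above is the conventional rotate-half RoPE formulation. A minimal self-contained sketch of that formulation (the function names and details here are illustrative, not copied from sglang):

import torch
from typing import Tuple


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # (x1, x2) -> (-x2, x1) on the last dimension.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb_reference(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    unsqueeze_dim: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Insert a head dim so [num_tokens, head_size] cos/sin broadcast
    # against [num_tokens, num_heads, head_size] q/k.
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
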
sglang/srt/layers/sampler.py

@@ -1,5 +1,5 @@
 import logging
-from typing import List, Tuple
+from typing import List, Optional, Tuple
 
 import torch
 import torch.distributed as dist
@@ -65,6 +65,7 @@ class Sampler(nn.Module):
         return_logprob: bool,
         top_logprobs_nums: List[int],
         token_ids_logprobs: List[List[int]],
+        positions: torch.Tensor,
     ):
         """Run a sampler & compute logprobs and update logits_output accordingly.
 
@@ -77,6 +78,8 @@ class Sampler(nn.Module):
             batch_next_token_ids: next token IDs. If set, skip sampling and only
                 compute output logprobs It is used for speculative decoding which
                 performs sampling in draft workers.
+            positions: The positions of the tokens in the sequence. Used for deterministic sampling
+                to get the unique seed for each position.
         """
         logits = logits_output.next_token_logits
 
@@ -124,6 +127,8 @@ class Sampler(nn.Module):
                     sampling_info.top_ps,
                     sampling_info.min_ps,
                     sampling_info.need_min_p_sampling,
+                    sampling_info.sampling_seed,
+                    positions,
                 )
             else:
                 raise ValueError(
@@ -189,6 +194,7 @@ class Sampler(nn.Module):
         Optimized for prefill-only scoring requests that need token probabilities
         but don't require next token generation.
         """
+
         if logits_output.next_token_logits is None:
             logger.warning("No logits available for logprob computation")
             return
@@ -230,8 +236,14 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
     top_ps: torch.Tensor,
    min_ps: torch.Tensor,
     need_min_p_sampling: bool,
+    sampling_seed: Optional[torch.Tensor],
+    positions: torch.Tensor,
 ):
-    """A top-k, top-p and min-p sampling implementation with native pytorch operations."""
+    """
+    A top-k, top-p and min-p sampling implementation with native pytorch operations.
+    When sampling_seed is not None, deterministic inference will be enabled, it will sample
+    with the sampling_seed of each request.
+    """
     probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
     probs_sum = torch.cumsum(probs_sort, dim=-1)
     probs_sort[
@@ -243,14 +255,50 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
     if need_min_p_sampling:
         min_p_thresholds = probs_sort[:, 0] * min_ps
         probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
-
-    sampled_index = torch.multinomial(probs_sort, num_samples=1)
+    if sampling_seed is not None:
+        sampled_index = multinomial_with_seed(probs_sort, sampling_seed, positions)
+    else:
+        sampled_index = torch.multinomial(probs_sort, num_samples=1)
     # int32 range is enough to represent the token ids
     probs_idx = probs_idx.to(torch.int32)
     batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
     return batch_next_token_ids
 
 
+def multinomial_with_seed(
+    inputs: torch.Tensor, seed: torch.Tensor, positions: torch.Tensor
+) -> torch.Tensor:
+    """
+    Samples n elements from an input tensor `inputs` of shape (n, m) using
+    a unique random seed for each row. This is a deterministic batched alternative to
+    `torch.multinomial`.
+
+    Args:
+        inputs: A float tensor of shape (n, m) representing n categorical
+            distributions with m categories each. The values are treated
+            as weights and do not need to sum to 1.
+        seed: An integer tensor of shape (n,) containing the random seed
+            for each corresponding row in `inputs`.
+        positions: The positions of the tokens in the sequence. Used for deterministic sampling
+            to get the unique seed for each position.
+
+    Returns:
+        A tensor of shape (n,) where the i-th element is an index sampled
+        from the distribution in `inputs[i]` using `seed[i]`.
+    """
+    n, m = inputs.shape
+    col_indices = torch.arange(m, device=inputs.device).unsqueeze(0)
+    step_seed = seed * 19349663 ^ positions * 73856093
+    seed_expanded = step_seed.unsqueeze(-1)
+    hashed = seed_expanded * 8589934591 ^ col_indices * 479001599
+    uniform_samples = (hashed % (2**24)).float() / (2**24)
+    epsilon = 1e-9
+    gumbel_noise = -torch.log(-torch.log(uniform_samples + epsilon) + epsilon)
+    log_probs = torch.log(inputs + epsilon)
+    perturbed_log_probs = log_probs + gumbel_noise
+    return torch.argmax(perturbed_log_probs, dim=1, keepdim=True)
+
+
 def sampling_from_probs_torch(probs: torch.Tensor):
     """A sampling implementation with native pytorch operations, without
     top-k, top-p, or min-p filtering."""
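
Because `multinomial_with_seed` draws its randomness entirely from an integer hash of `(seed, position, column)` combined with Gumbel-max sampling, repeated calls with identical arguments yield identical samples; that is the property the deterministic-inference path relies on. A quick standalone check (tensor values are made up; assumes the function above is in scope):

import torch

probs = torch.tensor([[0.1, 0.2, 0.3, 0.4],
                      [0.7, 0.1, 0.1, 0.1]])
seed = torch.tensor([1234, 5678])  # one sampling seed per request
positions = torch.tensor([0, 7])   # decode position of each request

first = multinomial_with_seed(probs, seed, positions)
second = multinomial_with_seed(probs, seed, positions)
assert torch.equal(first, second)  # same (seed, position) -> same token
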
sglang/srt/layers/utils.py

@@ -15,6 +15,29 @@ def get_layer_id(weight_name):
     return None
 
 
+def pad_or_narrow_weight(
+    loaded_weight: torch.Tensor, input_dim: int, start_idx: int, shard_size: int
+) -> torch.Tensor:
+    # Padding with zeros for special case such as qwen2_5_VL's mlp which is not 8-aligned
+    valid_size = max(loaded_weight.shape[input_dim] - start_idx, 0)
+
+    if valid_size > 0:
+        loaded_slice = loaded_weight.narrow(input_dim, start_idx, valid_size)
+        pad_shape = list(loaded_weight.shape)
+        pad_shape[input_dim] = shard_size - valid_size
+        pad = torch.zeros(
+            pad_shape, dtype=loaded_weight.dtype, device=loaded_weight.device
+        )
+        return torch.cat([loaded_slice, pad], dim=input_dim)
+
+    # All padding
+    pad_shape = list(loaded_weight.shape)
+    pad_shape[input_dim] = shard_size
+    return torch.zeros(
+        pad_shape, dtype=loaded_weight.dtype, device=loaded_weight.device
+    )
+
+
 class PPMissingLayer(torch.nn.Identity):
     # Adapted from
     # https://github.com/vllm-project/vllm/blob/18ed3132d2bfe1df9a74729457b69243955221e8/vllm/model_executor/models/utils.py#L468C1-L486C1
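
The helper above is only meaningful for the tail shard, where `start_idx + shard_size` runs past the end of the weight (as the qwen2_5_VL comment suggests); otherwise `shard_size - valid_size` would go negative. A small illustration with made-up shapes:

import torch

# Hypothetical tail shard: a [10, 4] weight, sharded along dim 0 with
# shard_size=4; the last rank starts at row 8, so only 2 rows remain valid.
w = torch.arange(40, dtype=torch.float32).reshape(10, 4)
tail = pad_or_narrow_weight(w, input_dim=0, start_idx=8, shard_size=4)

assert tail.shape == (4, 4)
assert torch.equal(tail[:2], w[8:])  # valid rows are kept
assert torch.all(tail[2:] == 0)      # missing rows are zero-filled
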
sglang/srt/lora/backend/base_backend.py

@@ -143,10 +143,10 @@ def get_backend_from_name(name: str) -> BaseLoRABackend:
         from sglang.srt.lora.backend.triton_backend import TritonLoRABackend
 
         return TritonLoRABackend
-    # elif name == "csgmv":
-    #     from sglang.srt.lora.backend.chunked_backend import ChunkedSgmvLoRABackend
+    elif name == "csgmv":
+        from sglang.srt.lora.backend.chunked_backend import ChunkedSgmvLoRABackend
 
-    #     return ChunkedSgmvLoRABackend
+        return ChunkedSgmvLoRABackend
     elif name == "flashinfer":
         raise ValueError(
             "FlashInfer LoRA backend has been deprecated, please use `triton` instead."