sglang 0.5.2rc2__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (377)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +8 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +360 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/model_config.py +181 -82
  17. sglang/srt/configs/qwen3_next.py +326 -0
  18. sglang/srt/configs/qwen3_vl.py +586 -0
  19. sglang/srt/connector/__init__.py +8 -1
  20. sglang/srt/connector/remote_instance.py +82 -0
  21. sglang/srt/constrained/base_grammar_backend.py +49 -12
  22. sglang/srt/constrained/llguidance_backend.py +0 -1
  23. sglang/srt/constrained/outlines_backend.py +0 -1
  24. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  25. sglang/srt/constrained/xgrammar_backend.py +30 -9
  26. sglang/srt/custom_op.py +11 -1
  27. sglang/srt/debug_utils/dump_comparator.py +81 -44
  28. sglang/srt/debug_utils/dump_loader.py +97 -0
  29. sglang/srt/debug_utils/dumper.py +21 -6
  30. sglang/srt/debug_utils/text_comparator.py +73 -11
  31. sglang/srt/disaggregation/ascend/conn.py +2 -2
  32. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  33. sglang/srt/disaggregation/base/conn.py +1 -1
  34. sglang/srt/disaggregation/common/conn.py +279 -108
  35. sglang/srt/disaggregation/decode.py +71 -19
  36. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  37. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  38. sglang/srt/disaggregation/fake/conn.py +1 -1
  39. sglang/srt/disaggregation/mini_lb.py +6 -445
  40. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  41. sglang/srt/disaggregation/nixl/conn.py +326 -53
  42. sglang/srt/disaggregation/prefill.py +36 -17
  43. sglang/srt/disaggregation/utils.py +40 -54
  44. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  45. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  46. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  47. sglang/srt/distributed/parallel_state.py +156 -80
  48. sglang/srt/entrypoints/engine.py +59 -18
  49. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  50. sglang/srt/entrypoints/grpc_server.py +810 -0
  51. sglang/srt/entrypoints/http_server.py +130 -59
  52. sglang/srt/entrypoints/openai/protocol.py +112 -4
  53. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  54. sglang/srt/entrypoints/openai/serving_chat.py +204 -55
  55. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  56. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  57. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  58. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  59. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  60. sglang/srt/environ.py +285 -0
  61. sglang/srt/eplb/eplb_manager.py +2 -2
  62. sglang/srt/eplb/expert_distribution.py +26 -13
  63. sglang/srt/eplb/expert_location.py +38 -8
  64. sglang/srt/eplb/expert_location_updater.py +1 -1
  65. sglang/srt/function_call/base_format_detector.py +3 -6
  66. sglang/srt/function_call/ebnf_composer.py +11 -9
  67. sglang/srt/function_call/function_call_parser.py +9 -2
  68. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  69. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  70. sglang/srt/function_call/json_array_parser.py +63 -0
  71. sglang/srt/function_call/kimik2_detector.py +17 -4
  72. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  73. sglang/srt/function_call/utils.py +96 -5
  74. sglang/srt/grpc/__init__.py +1 -0
  75. sglang/srt/grpc/compile_proto.py +245 -0
  76. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  77. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  79. sglang/srt/layers/activation.py +143 -9
  80. sglang/srt/layers/attention/aiter_backend.py +14 -15
  81. sglang/srt/layers/attention/ascend_backend.py +115 -9
  82. sglang/srt/layers/attention/attention_registry.py +206 -0
  83. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  84. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  85. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  86. sglang/srt/layers/attention/fla/chunk.py +242 -0
  87. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  88. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  89. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  90. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  91. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  92. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  93. sglang/srt/layers/attention/fla/index.py +37 -0
  94. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  95. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  96. sglang/srt/layers/attention/fla/op.py +66 -0
  97. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  98. sglang/srt/layers/attention/fla/utils.py +331 -0
  99. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  100. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  101. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  102. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  103. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  104. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  105. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  106. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  107. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  108. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  109. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  110. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  111. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  112. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  113. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  114. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  115. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  119. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  120. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  121. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  122. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  123. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  124. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  125. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  126. sglang/srt/layers/attention/nsa/utils.py +24 -0
  127. sglang/srt/layers/attention/nsa_backend.py +887 -0
  128. sglang/srt/layers/attention/tbo_backend.py +6 -6
  129. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  130. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  131. sglang/srt/layers/attention/triton_backend.py +57 -7
  132. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  133. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  134. sglang/srt/layers/attention/vision.py +58 -0
  135. sglang/srt/layers/attention/wave_backend.py +4 -4
  136. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  137. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  138. sglang/srt/layers/communicator.py +8 -0
  139. sglang/srt/layers/dp_attention.py +41 -2
  140. sglang/srt/layers/elementwise.py +3 -1
  141. sglang/srt/layers/layernorm.py +34 -15
  142. sglang/srt/layers/linear.py +55 -7
  143. sglang/srt/layers/logits_processor.py +44 -12
  144. sglang/srt/layers/moe/__init__.py +2 -1
  145. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  146. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  147. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  148. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  149. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  150. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  151. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  164. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  165. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  166. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  167. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  168. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  169. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  170. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  171. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  172. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  173. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  174. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  175. sglang/srt/layers/moe/topk.py +30 -9
  176. sglang/srt/layers/moe/utils.py +22 -6
  177. sglang/srt/layers/parameter.py +23 -6
  178. sglang/srt/layers/quantization/awq.py +19 -7
  179. sglang/srt/layers/quantization/base_config.py +11 -6
  180. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  182. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  185. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  186. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  187. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  188. sglang/srt/layers/quantization/fp8.py +78 -49
  189. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  190. sglang/srt/layers/quantization/gptq.py +25 -17
  191. sglang/srt/layers/quantization/modelopt_quant.py +190 -55
  192. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  193. sglang/srt/layers/quantization/mxfp4.py +74 -42
  194. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  195. sglang/srt/layers/quantization/unquant.py +135 -47
  196. sglang/srt/layers/quantization/w4afp8.py +26 -17
  197. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  198. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  199. sglang/srt/layers/rotary_embedding.py +78 -31
  200. sglang/srt/layers/sampler.py +213 -21
  201. sglang/srt/layers/utils.py +23 -0
  202. sglang/srt/lora/backend/base_backend.py +50 -8
  203. sglang/srt/lora/backend/chunked_backend.py +348 -0
  204. sglang/srt/lora/backend/triton_backend.py +99 -5
  205. sglang/srt/lora/layers.py +32 -0
  206. sglang/srt/lora/lora.py +8 -3
  207. sglang/srt/lora/lora_manager.py +52 -118
  208. sglang/srt/lora/mem_pool.py +25 -11
  209. sglang/srt/lora/triton_ops/__init__.py +4 -0
  210. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  211. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  212. sglang/srt/lora/utils.py +22 -11
  213. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  214. sglang/srt/managers/cache_controller.py +199 -301
  215. sglang/srt/managers/data_parallel_controller.py +115 -80
  216. sglang/srt/managers/detokenizer_manager.py +19 -15
  217. sglang/srt/managers/disagg_service.py +46 -0
  218. sglang/srt/managers/io_struct.py +340 -109
  219. sglang/srt/managers/mm_utils.py +44 -6
  220. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  221. sglang/srt/managers/multimodal_processor.py +1 -2
  222. sglang/srt/managers/overlap_utils.py +53 -0
  223. sglang/srt/managers/schedule_batch.py +240 -138
  224. sglang/srt/managers/schedule_policy.py +144 -17
  225. sglang/srt/managers/scheduler.py +502 -209
  226. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  227. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  228. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  229. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  230. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  231. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  232. sglang/srt/managers/tokenizer_manager.py +320 -632
  233. sglang/srt/managers/tp_worker.py +81 -22
  234. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  235. sglang/srt/managers/utils.py +1 -45
  236. sglang/srt/mem_cache/allocator.py +14 -20
  237. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  238. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  239. sglang/srt/mem_cache/chunk_cache.py +8 -1
  240. sglang/srt/mem_cache/evict_policy.py +23 -0
  241. sglang/srt/mem_cache/hicache_storage.py +43 -24
  242. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  243. sglang/srt/mem_cache/memory_pool.py +535 -58
  244. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  245. sglang/srt/mem_cache/radix_cache.py +222 -73
  246. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  247. sglang/srt/mem_cache/storage/__init__.py +10 -0
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  249. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  250. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  251. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  252. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  253. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  254. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  255. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  256. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  257. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  259. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  260. sglang/srt/metrics/collector.py +511 -132
  261. sglang/srt/metrics/func_timer.py +2 -7
  262. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  263. sglang/srt/metrics/utils.py +8 -1
  264. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  265. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  266. sglang/srt/model_executor/forward_batch_info.py +82 -40
  267. sglang/srt/model_executor/model_runner.py +432 -157
  268. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  269. sglang/srt/model_loader/__init__.py +9 -3
  270. sglang/srt/model_loader/loader.py +133 -5
  271. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  272. sglang/srt/model_loader/weight_utils.py +158 -3
  273. sglang/srt/models/apertus.py +686 -0
  274. sglang/srt/models/bailing_moe.py +820 -217
  275. sglang/srt/models/bailing_moe_nextn.py +168 -0
  276. sglang/srt/models/deepseek_nextn.py +6 -1
  277. sglang/srt/models/deepseek_v2.py +607 -130
  278. sglang/srt/models/dots_ocr.py +173 -0
  279. sglang/srt/models/dots_vlm.py +174 -0
  280. sglang/srt/models/dots_vlm_vit.py +337 -0
  281. sglang/srt/models/ernie4.py +1 -1
  282. sglang/srt/models/falcon_h1.py +576 -0
  283. sglang/srt/models/gemma3_causal.py +0 -2
  284. sglang/srt/models/gemma3_mm.py +1 -1
  285. sglang/srt/models/gemma3n_mm.py +2 -2
  286. sglang/srt/models/glm4_moe.py +4 -4
  287. sglang/srt/models/glm4_moe_nextn.py +2 -2
  288. sglang/srt/models/glm4v.py +5 -3
  289. sglang/srt/models/glm4v_moe.py +4 -1
  290. sglang/srt/models/gpt_oss.py +8 -31
  291. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  292. sglang/srt/models/llama.py +4 -0
  293. sglang/srt/models/llama4.py +9 -0
  294. sglang/srt/models/llama_eagle3.py +13 -0
  295. sglang/srt/models/longcat_flash.py +3 -3
  296. sglang/srt/models/longcat_flash_nextn.py +1 -1
  297. sglang/srt/models/mllama4.py +40 -4
  298. sglang/srt/models/opt.py +637 -0
  299. sglang/srt/models/qwen2_5_vl.py +29 -5
  300. sglang/srt/models/qwen2_audio.py +1 -1
  301. sglang/srt/models/qwen2_moe.py +120 -13
  302. sglang/srt/models/qwen2_vl.py +1 -1
  303. sglang/srt/models/qwen3.py +18 -3
  304. sglang/srt/models/qwen3_moe.py +32 -4
  305. sglang/srt/models/qwen3_next.py +1069 -0
  306. sglang/srt/models/qwen3_next_mtp.py +112 -0
  307. sglang/srt/models/qwen3_vl.py +787 -0
  308. sglang/srt/models/qwen3_vl_moe.py +471 -0
  309. sglang/srt/models/registry.py +15 -3
  310. sglang/srt/models/sarashina2_vision.py +269 -0
  311. sglang/srt/models/solar.py +505 -0
  312. sglang/srt/models/starcoder2.py +357 -0
  313. sglang/srt/models/step3_vl.py +1 -1
  314. sglang/srt/models/torch_native_llama.py +9 -2
  315. sglang/srt/models/utils.py +51 -0
  316. sglang/srt/multimodal/processors/base_processor.py +15 -7
  317. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  318. sglang/srt/multimodal/processors/glm4v.py +9 -9
  319. sglang/srt/multimodal/processors/internvl.py +153 -129
  320. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  321. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  322. sglang/srt/offloader.py +27 -3
  323. sglang/srt/parser/jinja_template_utils.py +6 -0
  324. sglang/srt/sampling/sampling_batch_info.py +38 -17
  325. sglang/srt/sampling/sampling_params.py +7 -0
  326. sglang/srt/server_args.py +966 -267
  327. sglang/srt/server_args_config_parser.py +146 -0
  328. sglang/srt/single_batch_overlap.py +151 -0
  329. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  330. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  331. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  332. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  333. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  334. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  335. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  336. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  337. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  338. sglang/srt/speculative/eagle_worker.py +99 -28
  339. sglang/srt/speculative/ngram_utils.py +428 -0
  340. sglang/srt/speculative/ngram_worker.py +245 -0
  341. sglang/srt/speculative/spec_info.py +52 -0
  342. sglang/srt/speculative/spec_utils.py +606 -0
  343. sglang/srt/speculative/standalone_worker.py +109 -0
  344. sglang/srt/torch_memory_saver_adapter.py +5 -7
  345. sglang/srt/tracing/trace.py +578 -0
  346. sglang/srt/two_batch_overlap.py +8 -5
  347. sglang/srt/utils/__init__.py +2 -0
  348. sglang/srt/{utils.py → utils/common.py} +433 -77
  349. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  350. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  351. sglang/srt/utils/rpd_utils.py +452 -0
  352. sglang/srt/utils/slow_rank_detector.py +71 -0
  353. sglang/srt/warmup.py +8 -4
  354. sglang/srt/weight_sync/utils.py +2 -2
  355. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  356. sglang/test/get_logits_ut.py +57 -0
  357. sglang/test/run_eval.py +79 -11
  358. sglang/test/runners.py +5 -1
  359. sglang/test/simple_eval_common.py +5 -2
  360. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  361. sglang/test/test_block_fp8.py +2 -2
  362. sglang/test/test_cutlass_moe.py +24 -6
  363. sglang/test/test_deterministic.py +297 -0
  364. sglang/test/test_disaggregation_utils.py +77 -0
  365. sglang/test/test_fp4_moe.py +370 -1
  366. sglang/test/test_programs.py +1 -1
  367. sglang/test/test_utils.py +383 -5
  368. sglang/utils.py +21 -1
  369. sglang/version.py +1 -1
  370. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
  371. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/RECORD +375 -245
  372. sglang/srt/disaggregation/launch_lb.py +0 -118
  373. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  374. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  375. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
  376. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
  377. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0
--- a/sglang/srt/lora/lora_manager.py
+++ b/sglang/srt/lora/lora_manager.py
@@ -21,7 +21,6 @@ from typing import Dict, Iterable, List, Optional, Set, Tuple
  import torch

  from sglang.srt.configs.load_config import LoadConfig
- from sglang.srt.hf_transformers_utils import AutoConfig
  from sglang.srt.lora.backend.base_backend import BaseLoRABackend, get_backend_from_name
  from sglang.srt.lora.layers import BaseLayerWithLoRA, get_lora_layer
  from sglang.srt.lora.lora import LoRAAdapter
@@ -35,9 +34,11 @@ from sglang.srt.lora.utils import (
  get_normalized_target_modules,
  get_target_module_name,
  )
- from sglang.srt.managers.io_struct import LoRAUpdateResult
+ from sglang.srt.managers.io_struct import LoRAUpdateOutput
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.server_args import ServerArgs
  from sglang.srt.utils import replace_submodule
+ from sglang.srt.utils.hf_transformers_utils import AutoConfig

  logger = logging.getLogger(__name__)

@@ -56,6 +57,7 @@ class LoRAManager:
  max_lora_rank: Optional[int] = None,
  target_modules: Optional[Iterable[str]] = None,
  lora_paths: Optional[List[LoRARef]] = None,
+ server_args: Optional[ServerArgs] = None,
  ):
  self.base_model: torch.nn.Module = base_model
  self.base_hf_config: AutoConfig = base_hf_config
@@ -69,7 +71,11 @@ class LoRAManager:
  # LoRA backend for running sgemm kernels
  logger.info(f"Using {lora_backend} as backend of LoRA kernels.")
  backend_type = get_backend_from_name(lora_backend)
- self.lora_backend: BaseLoRABackend = backend_type(lora_backend)
+ self.lora_backend: BaseLoRABackend = backend_type(
+ max_loras_per_batch=max_loras_per_batch,
+ device=self.device,
+ server_args=server_args,
+ )

  # Initialize mutable internal state of the LoRAManager.
  self.init_state(
@@ -82,34 +88,27 @@ class LoRAManager:
  self.max_bs_in_cuda_graph = max_bs_in_cuda_graph
  with torch.device("cuda"):
  self.cuda_graph_batch_info = LoRABatchInfo(
- bs=self.max_bs_in_cuda_graph,
- seg_lens=torch.zeros(self.max_bs_in_cuda_graph, dtype=torch.int32),
- seg_indptr=torch.zeros(
- self.max_bs_in_cuda_graph + 1, dtype=torch.int32
- ),
+ bs=max_bs_in_cuda_graph,
+ use_cuda_graph=True,
+ num_segments=None,
+ seg_lens=torch.zeros(max_bs_in_cuda_graph, dtype=torch.int32),
+ seg_indptr=torch.zeros(max_bs_in_cuda_graph + 1, dtype=torch.int32),
  max_len=1,
- weight_indices=torch.zeros(
- self.max_bs_in_cuda_graph, dtype=torch.int32
- ),
+ weight_indices=torch.zeros(max_bs_in_cuda_graph, dtype=torch.int32),
+ permutation=torch.zeros(max_bs_in_cuda_graph, dtype=torch.int32),
  lora_ranks=torch.zeros(self.max_loras_per_batch, dtype=torch.int32),
  scalings=torch.zeros(self.max_loras_per_batch, dtype=torch.float),
  )

- # Initialize seg_lens and seg_indptr for CUDA graph as they remain constant
- # across batches.
- self.cuda_graph_batch_info.seg_lens[: self.max_bs_in_cuda_graph].fill_(1)
- torch.cumsum(
- self.cuda_graph_batch_info.seg_lens[: self.max_bs_in_cuda_graph],
- dim=0,
- out=self.cuda_graph_batch_info.seg_indptr[
- 1 : self.max_bs_in_cuda_graph + 1
- ],
- )
+ self.lora_backend.init_cuda_graph_batch_info(
+ cuda_graph_batch_info=self.cuda_graph_batch_info,
+ max_bs_in_cuda_graph=max_bs_in_cuda_graph,
+ )

  def create_lora_update_result(
  self, success: bool, error_message: str = ""
- ) -> LoRAUpdateResult:
- return LoRAUpdateResult(
+ ) -> LoRAUpdateOutput:
+ return LoRAUpdateOutput(
  success=success,
  error_message=error_message,
  loaded_adapters={
@@ -118,7 +117,7 @@ class LoRAManager:
  },
  )

- def load_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateResult:
+ def load_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateOutput:
  """
  Load a single LoRA adapter from the specified path.

@@ -175,7 +174,7 @@ class LoRAManager:
  "`--max-loras-per-batch` or load it as unpinned LoRA adapters."
  )

- def unload_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateResult:
+ def unload_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateOutput:
  """
  Unload LoRA adapters by their names. This will remove the adapters from the memory pool and
  delete the corresponding LoRA modules.
@@ -232,7 +231,6 @@ class LoRAManager:
  return required_slots <= mem_pool_vacancy

  def prepare_lora_batch(self, forward_batch: ForwardBatch):
-
  # Load active loras into lora memory pool
  cur_uids = set(forward_batch.lora_ids)

@@ -247,102 +245,30 @@ class LoRAManager:
  # set up batch info shared by all lora modules
  bs = forward_batch.batch_size

- def transfer_adapter_info(
- weight_indices_out: torch.Tensor,
- lora_ranks_out: torch.Tensor,
- scalings_out: torch.Tensor,
- ):
- """
- Transfer adapter metadata (weight indices, LoRA rank, scalings) from host
- to device (CUDA) asynchronously.
- """
- weight_indices = [0] * len(forward_batch.lora_ids)
- lora_ranks = [0] * self.max_loras_per_batch
- scalings = [0] * self.max_loras_per_batch
- for i, uid in enumerate(forward_batch.lora_ids):
- weight_indices[i] = self.memory_pool.get_buffer_id(uid)
- if uid is not None:
- lora = self.loras[uid]
- lora_ranks[weight_indices[i]] = lora.config.r
- scalings[weight_indices[i]] = lora.scaling
-
- # Use pinned memory to avoid synchronizations during host-to-device transfer
- weight_indices_tensor = torch.tensor(
- weight_indices, dtype=torch.int32, pin_memory=True, device="cpu"
- )
- lora_ranks_tensor = torch.tensor(
- lora_ranks, dtype=torch.int32, pin_memory=True, device="cpu"
- )
- scalings_tensor = torch.tensor(
- scalings, dtype=torch.float, pin_memory=True, device="cpu"
- )
-
- # Copy to device tensors asynchronously
- weight_indices_out[:bs].copy_(weight_indices_tensor, non_blocking=True)
- lora_ranks_out[: self.max_loras_per_batch].copy_(
- lora_ranks_tensor, non_blocking=True
- )
- scalings_out[: self.max_loras_per_batch].copy_(
- scalings_tensor, non_blocking=True
- )
-
- if (
+ use_cuda_graph = (
  hasattr(self, "max_bs_in_cuda_graph")
  and bs <= self.max_bs_in_cuda_graph
  and forward_batch.forward_mode.is_cuda_graph()
- ):
- # Do in-place updates when CUDA graph is enabled and the batch forward mode
- # could use CUDA graph.
-
- transfer_adapter_info(
- self.cuda_graph_batch_info.weight_indices,
- self.cuda_graph_batch_info.lora_ranks,
- self.cuda_graph_batch_info.scalings,
- )
-
- self.cuda_graph_batch_info.bs = bs
- self.cuda_graph_batch_info.max_len = 1
- batch_info = self.cuda_graph_batch_info
- else:
- weight_indices = torch.empty((bs,), dtype=torch.int32, device=self.device)
- lora_ranks = torch.zeros(
- (self.max_loras_per_batch,), dtype=torch.int64, device=self.device
- )
- scalings = torch.zeros(
- (self.max_loras_per_batch,), dtype=torch.float, device=self.device
- )
- transfer_adapter_info(
- weight_indices,
- lora_ranks,
- scalings,
- )
-
- seg_lens = (
- forward_batch.extend_seq_lens
- if forward_batch.forward_mode.is_extend()
- else torch.ones(bs, device=self.device)
- )
-
- max_len = (
- # Calculate max_len from the CPU copy to avoid D2H transfer.
- max(forward_batch.extend_seq_lens_cpu)
- if forward_batch.forward_mode.is_extend()
- else 1
- )
+ )

- seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device)
- seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
-
- batch_info = LoRABatchInfo(
- bs=bs,
- seg_lens=seg_lens,
- seg_indptr=seg_indptr,
- max_len=max_len,
- weight_indices=weight_indices,
- lora_ranks=lora_ranks,
- scalings=scalings,
- )
- self.lora_backend.set_batch_info(batch_info)
+ weight_indices = [0] * len(forward_batch.lora_ids)
+ lora_ranks = [0] * self.max_loras_per_batch
+ scalings = [0] * self.max_loras_per_batch
+ for i, uid in enumerate(forward_batch.lora_ids):
+ weight_indices[i] = self.memory_pool.get_buffer_id(uid)
+ if uid is not None:
+ lora = self.loras[uid]
+ lora_ranks[weight_indices[i]] = lora.config.r
+ scalings[weight_indices[i]] = lora.scaling
+ # Do in-place updates when CUDA graph is enabled and the batch forward mode
+ # could use CUDA graph.
+ self.lora_backend.prepare_lora_batch(
+ forward_batch=forward_batch,
+ weight_indices=weight_indices,
+ lora_ranks=lora_ranks,
+ scalings=scalings,
+ batch_info=self.cuda_graph_batch_info if use_cuda_graph else None,
+ )

  def update_lora_info(self):
  """
@@ -492,6 +418,10 @@ class LoRAManager:
  replace_submodule(self.base_model, module_name, lora_module)
  return lora_module

+ def should_skip_lora_for_vision_model(self, module_name):
+ # TODO: support different vision models
+ return module_name.find("vision_model.model") != -1
+
  def init_lora_modules(self):
  # Look-up table that essentially maps (layer_index, module_name) to the corresponding LoRA module.
  self.lora_modules: List[Dict[str, BaseLayerWithLoRA]] = [
@@ -509,6 +439,10 @@ class LoRAManager:
  ) and not self.base_model.should_apply_lora(module_name):
  continue

+ # Skip vision model
+ if self.should_skip_lora_for_vision_model(module_name):
+ continue
+
  # The module should be converted if it is included in target_names
  if module_name.split(".")[-1] in self.target_modules:
  layer_id = get_layer_id(module_name)
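Note on the prepare_lora_batch rewrite above: the removed inline transfer_adapter_info helper staged adapter metadata (weight indices, ranks, scalings) in pinned host memory and copied it to preallocated device buffers with non_blocking=True, so the scheduler never had to synchronize with the GPU; that responsibility now lives behind lora_backend.prepare_lora_batch. A minimal sketch of the same pinned-memory async-copy pattern, with illustrative names rather than the backend's actual API:

import torch

def async_copy_metadata(weight_indices, lora_ranks, scalings,
                        weight_indices_out, lora_ranks_out, scalings_out):
    """Stage host lists in pinned memory, then copy to device buffers asynchronously.

    Illustrative sketch of the pattern used by the removed transfer_adapter_info
    helper; the real backend API may differ. Assumes a CUDA build (pinned memory
    and the *_out tensors living on the GPU).
    """
    # Pinned host staging buffers allow truly non-blocking H2D copies.
    wi = torch.tensor(weight_indices, dtype=torch.int32, pin_memory=True)
    lr = torch.tensor(lora_ranks, dtype=torch.int32, pin_memory=True)
    sc = torch.tensor(scalings, dtype=torch.float, pin_memory=True)

    # The copies are enqueued on the current CUDA stream and overlap with host work.
    weight_indices_out[: wi.numel()].copy_(wi, non_blocking=True)
    lora_ranks_out[: lr.numel()].copy_(lr, non_blocking=True)
    scalings_out[: sc.numel()].copy_(sc, non_blocking=True)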
--- a/sglang/srt/lora/mem_pool.py
+++ b/sglang/srt/lora/mem_pool.py
@@ -4,7 +4,6 @@ from typing import Callable, Dict, Iterable, List, Optional, Set, Tuple, Union
  import torch

  from sglang.srt.distributed import divide
- from sglang.srt.hf_transformers_utils import AutoConfig
  from sglang.srt.lora.layers import BaseLayerWithLoRA
  from sglang.srt.lora.lora import LoRAAdapter
  from sglang.srt.lora.lora_config import LoRAConfig
@@ -17,6 +16,7 @@ from sglang.srt.lora.utils import (
  get_stacked_multiply,
  get_target_module_name,
  )
+ from sglang.srt.utils.hf_transformers_utils import AutoConfig

  logger = logging.getLogger(__name__)

@@ -104,12 +104,18 @@ class LoRAMemoryPool:
  return all(_can_support(x) for x in config)

  def get_lora_A_shape(
- self, module_name: str, base_model: torch.nn.Module, max_lora_dim: int
+ self,
+ module_name: str,
+ base_model: torch.nn.Module,
+ max_lora_dim: int,
+ layer_idx: int,
  ) -> Tuple[int]:
  """
  Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
  """
- input_dim, _ = get_hidden_dim(module_name, self.base_hf_config, base_model)
+ input_dim, _ = get_hidden_dim(
+ module_name, self.base_hf_config, base_model, layer_idx
+ )
  c = get_stacked_multiply(module_name)
  if self.tp_size > 1 and module_name in ROW_PARALLELISM_LINEAR_LORA_NAMES:
  input_dim = divide(input_dim, self.tp_size)
@@ -120,12 +126,18 @@ class LoRAMemoryPool:
  )

  def get_lora_B_shape(
- self, module_name: str, base_model: torch.nn.Module, max_lora_dim: int
+ self,
+ module_name: str,
+ base_model: torch.nn.Module,
+ max_lora_dim: int,
+ layer_idx: int,
  ) -> Tuple[int]:
  """
  Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
  """
- _, output_dim = get_hidden_dim(module_name, self.base_hf_config, base_model)
+ _, output_dim = get_hidden_dim(
+ module_name, self.base_hf_config, base_model, layer_idx
+ )
  if self.tp_size > 1 and module_name not in ROW_PARALLELISM_LINEAR_LORA_NAMES:
  output_dim = divide(output_dim, self.tp_size)
  return (
@@ -140,19 +152,21 @@ class LoRAMemoryPool:
  def init_buffer(
  buffer: Dict[str, List[torch.Tensor]],
  target_modules: Set[str],
- get_lora_shape_fn: Callable[[str, torch.nn.Module, int], Tuple[int]],
+ get_lora_shape_fn: Callable[[str, torch.nn.Module, int, int], Tuple[int]],
  ):
  for module_name in target_modules:
- lora_shape = get_lora_shape_fn(
- module_name, base_model, self.max_lora_rank
- )
  buffer[module_name] = [
  torch.empty(
- lora_shape,
+ get_lora_shape_fn(
+ module_name,
+ base_model,
+ self.max_lora_rank,
+ idx,
+ ),
  dtype=self.dtype,
  device=device,
  )
- for _ in range(self.num_layer)
+ for idx in range(self.num_layer)
  ]

  init_buffer(
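The mem_pool.py change above threads a layer index through the shape helpers so each layer's LoRA buffers can be sized independently instead of sharing one shape for all layers. A simplified, self-contained sketch of that allocation pattern, assuming a caller-provided shape function (a stand-in for the real get_lora_A_shape/get_lora_B_shape):

from typing import Callable, Dict, List, Tuple

import torch

def init_layerwise_buffers(
    target_modules: List[str],
    num_layers: int,
    get_shape: Callable[[str, int], Tuple[int, ...]],  # (module_name, layer_idx) -> shape
    dtype: torch.dtype = torch.bfloat16,
    device: str = "cpu",
) -> Dict[str, List[torch.Tensor]]:
    """Allocate one buffer per (module, layer), letting the shape depend on the layer index."""
    buffers: Dict[str, List[torch.Tensor]] = {}
    for module_name in target_modules:
        buffers[module_name] = [
            torch.empty(get_shape(module_name, idx), dtype=dtype, device=device)
            for idx in range(num_layers)
        ]
    return buffers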
--- a/sglang/srt/lora/triton_ops/__init__.py
+++ b/sglang/srt/lora/triton_ops/__init__.py
@@ -1,3 +1,5 @@
+ from .chunked_sgmv_expand import chunked_sgmv_lora_expand_forward
+ from .chunked_sgmv_shrink import chunked_sgmv_lora_shrink_forward
  from .gate_up_lora_b import gate_up_lora_b_fwd
  from .qkv_lora_b import qkv_lora_b_fwd
  from .sgemm_lora_a import sgemm_lora_a_fwd
@@ -8,4 +10,6 @@ __all__ = [
  "qkv_lora_b_fwd",
  "sgemm_lora_a_fwd",
  "sgemm_lora_b_fwd",
+ "chunked_sgmv_lora_shrink_forward",
+ "chunked_sgmv_lora_expand_forward",
  ]
--- /dev/null
+++ b/sglang/srt/lora/triton_ops/chunked_sgmv_expand.py
@@ -0,0 +1,214 @@
+ from typing import Optional
+
+ import torch
+ import triton
+ import triton.language as tl
+
+ from sglang.srt.lora.utils import LoRABatchInfo
+ from sglang.srt.utils import cached_triton_kernel
+
+
+ @cached_triton_kernel(lambda _, kwargs: (kwargs["NUM_SLICES"], kwargs["BLOCK_M"]))
+ @triton.jit
+ def _chunked_lora_expand_kernel(
+ # Pointers to matrices
+ x,
+ weights,
+ output,
+ # Information on sequence lengths and weight id
+ seg_indptr,
+ weight_indices,
+ lora_ranks,
+ permutation,
+ num_segs,
+ # For fused output scaling
+ scalings,
+ # Offsets of q/k/v slice on output dimension
+ slice_offsets,
+ # Meta parameters
+ NUM_SLICES: tl.constexpr,
+ OUTPUT_DIM: tl.constexpr,
+ MAX_RANK: tl.constexpr, # K = R
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+ ):
+ """
+ Computes a chunked SGMV for LoRA expand operations.
+
+ When a sequence's rank is 0, the kernel is essentially a no-op, following
+ the convention in pytorch where the product of two matrices of shape (m, 0)
+ and (0, n) is an all-zero matrix of shape (m, n).
+
+ Args:
+ x (Tensor): The input tensor, which is the result of the LoRA A projection.
+ Shape: (s, num_slices * K), where s is the sum of all sequence lengths in the
+ batch and K is the maximum LoRA rank.
+ weights (Tensor): The LoRA B weights for all adapters.
+ Shape: (num_lora, output_dim, K).
+ output (Tensor): The output tensor where the result is stored.
+ Shape: (s, output_dim).
+ """
+ tl.static_assert(NUM_SLICES <= 3)
+
+ x_stride_0: tl.constexpr = NUM_SLICES * MAX_RANK
+ x_stride_1: tl.constexpr = 1
+
+ w_stride_0: tl.constexpr = OUTPUT_DIM * MAX_RANK
+ w_stride_1: tl.constexpr = MAX_RANK
+ w_stride_2: tl.constexpr = 1
+
+ output_stride_0: tl.constexpr = OUTPUT_DIM
+ output_stride_1: tl.constexpr = 1
+
+ pid_s = tl.program_id(axis=2)
+ if pid_s >= num_segs:
+ return
+
+ # Current block computes sequence with batch_id,
+ # which starts from row seg_start of x with length seg_len.
+ # qkv_id decides which of q,k,v to compute (0: q, 1: k, 2: v)
+ w_index = tl.load(weight_indices + pid_s)
+ cur_rank = tl.load(lora_ranks + w_index)
+
+ # If rank is 0, this kernel is a no-op.
+ if cur_rank == 0:
+ return
+
+ seg_start = tl.load(seg_indptr + pid_s)
+ seg_end = tl.load(seg_indptr + pid_s + 1)
+
+ slice_id = tl.program_id(axis=1)
+ slice_start = tl.load(slice_offsets + slice_id)
+ slice_end = tl.load(slice_offsets + slice_id + 1)
+
+ scaling = tl.load(scalings + w_index)
+ # Adjust K (rank) according to the specific LoRA adapter
+ cur_rank = tl.minimum(MAX_RANK, cur_rank)
+
+ # Map logical sequence index to physical index
+ s_offset_logical = tl.arange(0, BLOCK_M) + seg_start
+ s_offset_physical = tl.load(
+ permutation + s_offset_logical, mask=s_offset_logical < seg_end
+ )
+
+ # Create pointers for the first block of x and weights[batch_id][n_start: n_end][:]
+ # The pointers will be advanced as we move in the K direction
+ # and accumulate
+ pid_n = tl.program_id(axis=0)
+ n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + slice_start
+ k_offset = tl.arange(0, BLOCK_K)
+
+ x_ptrs = (
+ x
+ + slice_id * cur_rank * x_stride_1
+ + (s_offset_physical[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1)
+ )
+ w_ptrs = (weights + w_index * w_stride_0) + (
+ k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
+ )
+
+ # Iterate to compute the block in output matrix
+ partial_sum = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ for k in range(0, tl.cdiv(cur_rank, BLOCK_K)):
+ x_tile = tl.load(
+ x_ptrs,
+ mask=(s_offset_logical[:, None] < seg_end)
+ & (k_offset[None, :] < cur_rank - k * BLOCK_K),
+ other=0.0,
+ )
+ w_tile = tl.load(
+ w_ptrs,
+ mask=(k_offset[:, None] < cur_rank - k * BLOCK_K)
+ & (n_offset[None, :] < slice_end),
+ other=0.0,
+ )
+ partial_sum += tl.dot(x_tile, w_tile)
+
+ x_ptrs += BLOCK_K * x_stride_1
+ w_ptrs += BLOCK_K * w_stride_2
+
+ # Store result to output matrix
+ partial_sum *= scaling
+ partial_sum = partial_sum.to(x.dtype.element_ty)
+ output_ptr = output + (
+ s_offset_physical[:, None] * output_stride_0
+ + n_offset[None, :] * output_stride_1
+ )
+ output_mask = (s_offset_logical[:, None] < seg_end) & (
+ n_offset[None, :] < slice_end
+ )
+ partial_sum += tl.load(output_ptr, mask=output_mask, other=0.0)
+ tl.store(output_ptr, partial_sum, mask=output_mask)
+
+
+ def chunked_sgmv_lora_expand_forward(
+ x: torch.Tensor,
+ weights: torch.Tensor,
+ batch_info: LoRABatchInfo,
+ slice_offsets: torch.Tensor,
+ max_slice_size: int,
+ base_output: Optional[torch.Tensor],
+ ) -> torch.Tensor:
+
+ # x: (s, slice_num * r)
+ # weights: (num_lora, output_dim, r)
+ # slice_offsets: boundaries for different slices in the output dimension
+ # output: (s, output_dim)
+
+ # Compute lora_output with shape (s, output_dim) as follows:
+ # For each slice i, accumulates:
+ # lora_output[:, slice_offsets[i]:slice_offsets[i+1]] += scaling * sgemm(x[:, i*cur_rank:(i+1)*cur_rank], weights[:, slice_offsets[i]:slice_offsets[i+1], :])
+
+ assert x.is_contiguous()
+ assert weights.is_contiguous()
+ assert len(x.shape) == 2
+ assert len(weights.shape) == 3
+
+ # Get dims
+ M = x.shape[0]
+ input_dim = x.shape[1]
+ OUTPUT_DIM = weights.shape[1]
+ MAX_RANK = weights.shape[2]
+ num_slices = len(slice_offsets) - 1
+ assert input_dim == num_slices * MAX_RANK
+
+ # TODO (lifuhuang): fine-tune per operation
+ BLOCK_M = batch_info.max_len
+ BLOCK_K = 16
+ BLOCK_N = 64
+
+ num_segments = batch_info.num_segments
+
+ grid = (
+ triton.cdiv(max_slice_size, BLOCK_N),
+ num_slices, # number of slices in the input/output
+ batch_info.bs if batch_info.use_cuda_graph else num_segments,
+ )
+
+ if base_output is None:
+ output = torch.zeros((M, OUTPUT_DIM), device=x.device, dtype=x.dtype)
+ else:
+ output = base_output
+
+ _chunked_lora_expand_kernel[grid](
+ x=x,
+ weights=weights,
+ output=output,
+ seg_indptr=batch_info.seg_indptr,
+ weight_indices=batch_info.weight_indices,
+ lora_ranks=batch_info.lora_ranks,
+ permutation=batch_info.permutation,
+ num_segs=num_segments,
+ scalings=batch_info.scalings,
+ slice_offsets=slice_offsets,
+ # constants
+ NUM_SLICES=num_slices,
+ OUTPUT_DIM=OUTPUT_DIM,
+ MAX_RANK=MAX_RANK,
+ BLOCK_M=BLOCK_M,
+ BLOCK_N=BLOCK_N,
+ BLOCK_K=BLOCK_K,
+ )
+
+ return output
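To make the new kernel's semantics concrete, here is an unoptimized pure-PyTorch reference of the expand step as described by the in-code comments above: for each segment, look up its adapter, and for each output slice accumulate scaling * (LoRA-A activations) @ (LoRA-B slice)^T into the output columns given by slice_offsets. This is a sketch for intuition only, following the tensor layouts documented in the kernel docstring; it is not the code path sglang actually executes.

from typing import Optional

import torch

def chunked_sgmv_expand_reference(
    x: torch.Tensor,               # (s, num_slices * max_rank), LoRA-A outputs
    weights: torch.Tensor,         # (num_lora, output_dim, max_rank), LoRA-B weights
    seg_indptr: torch.Tensor,      # (num_segments + 1,) logical row boundaries
    weight_indices: torch.Tensor,  # (num_segments,) adapter id per segment
    permutation: torch.Tensor,     # logical -> physical row mapping
    lora_ranks: torch.Tensor,      # (num_lora,) actual rank per adapter
    scalings: torch.Tensor,        # (num_lora,) scaling per adapter
    slice_offsets: torch.Tensor,   # (num_slices + 1,) output-column boundaries
    base_output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Reference (slow) implementation of the chunked SGMV LoRA expand."""
    s, output_dim = x.shape[0], weights.shape[1]
    out = base_output if base_output is not None else torch.zeros(
        (s, output_dim), dtype=x.dtype, device=x.device
    )
    num_segments = weight_indices.shape[0]
    num_slices = slice_offsets.numel() - 1
    for j in range(num_segments):
        # Rows of this segment, mapped to their physical positions.
        rows = permutation[seg_indptr[j]: seg_indptr[j + 1]]
        w = int(weight_indices[j])
        r = int(lora_ranks[w])
        if r == 0:  # rank-0 adapter contributes nothing (kernel no-op case)
            continue
        scale = scalings[w]
        for i in range(num_slices):
            lo, hi = int(slice_offsets[i]), int(slice_offsets[i + 1])
            xs = x[rows, i * r:(i + 1) * r]   # (seg_len, r) slice of LoRA-A output
            ws = weights[w, lo:hi, :r]        # (slice_size, r) slice of LoRA-B weight
            out[rows, lo:hi] += (scale * (xs @ ws.T)).to(out.dtype)
    return out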