sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (282)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +321 -31
  3. sglang/bench_serving.py +10 -3
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +4 -0
  11. sglang/srt/configs/dots_ocr.py +64 -0
  12. sglang/srt/configs/falcon_h1.py +360 -0
  13. sglang/srt/configs/load_config.py +8 -0
  14. sglang/srt/configs/model_config.py +160 -105
  15. sglang/srt/configs/qwen3_vl.py +586 -0
  16. sglang/srt/constrained/base_grammar_backend.py +1 -0
  17. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  18. sglang/srt/constrained/xgrammar_backend.py +6 -4
  19. sglang/srt/debug_utils/dumper.py +10 -3
  20. sglang/srt/disaggregation/ascend/conn.py +2 -2
  21. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  22. sglang/srt/disaggregation/common/conn.py +266 -98
  23. sglang/srt/disaggregation/decode.py +50 -9
  24. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  25. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  26. sglang/srt/disaggregation/mooncake/conn.py +51 -541
  27. sglang/srt/disaggregation/nixl/conn.py +148 -39
  28. sglang/srt/disaggregation/prefill.py +31 -14
  29. sglang/srt/disaggregation/utils.py +36 -5
  30. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  31. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  32. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  33. sglang/srt/distributed/parallel_state.py +135 -80
  34. sglang/srt/entrypoints/engine.py +23 -3
  35. sglang/srt/entrypoints/grpc_request_manager.py +330 -55
  36. sglang/srt/entrypoints/grpc_server.py +232 -102
  37. sglang/srt/entrypoints/http_server.py +49 -9
  38. sglang/srt/entrypoints/openai/protocol.py +110 -5
  39. sglang/srt/entrypoints/openai/serving_base.py +25 -6
  40. sglang/srt/entrypoints/openai/serving_chat.py +178 -49
  41. sglang/srt/entrypoints/openai/serving_completions.py +5 -3
  42. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  43. sglang/srt/entrypoints/openai/serving_responses.py +42 -0
  44. sglang/srt/environ.py +285 -0
  45. sglang/srt/eplb/expert_location.py +30 -5
  46. sglang/srt/function_call/function_call_parser.py +3 -2
  47. sglang/srt/function_call/glm4_moe_detector.py +3 -3
  48. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  49. sglang/srt/function_call/json_array_parser.py +63 -0
  50. sglang/srt/function_call/kimik2_detector.py +17 -4
  51. sglang/srt/function_call/utils.py +96 -5
  52. sglang/srt/grpc/compile_proto.py +245 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
  56. sglang/srt/layers/activation.py +7 -6
  57. sglang/srt/layers/attention/aiter_backend.py +14 -15
  58. sglang/srt/layers/attention/ascend_backend.py +108 -9
  59. sglang/srt/layers/attention/attention_registry.py +206 -0
  60. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  61. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  62. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  63. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  66. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  67. sglang/srt/layers/attention/flashinfer_backend.py +112 -194
  68. sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
  69. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  70. sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
  71. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
  72. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
  73. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
  74. sglang/srt/layers/attention/mamba/mamba.py +566 -1
  75. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  76. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  77. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  78. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  79. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  80. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  81. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  82. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  83. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  84. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  85. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  86. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  87. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  88. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  89. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  90. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  91. sglang/srt/layers/attention/nsa/utils.py +24 -0
  92. sglang/srt/layers/attention/nsa_backend.py +887 -0
  93. sglang/srt/layers/attention/tbo_backend.py +6 -6
  94. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  95. sglang/srt/layers/attention/triton_backend.py +42 -9
  96. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  97. sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
  98. sglang/srt/layers/attention/vision.py +58 -0
  99. sglang/srt/layers/attention/wave_backend.py +4 -4
  100. sglang/srt/layers/communicator.py +8 -0
  101. sglang/srt/layers/dp_attention.py +11 -1
  102. sglang/srt/layers/elementwise.py +3 -1
  103. sglang/srt/layers/layernorm.py +2 -0
  104. sglang/srt/layers/linear.py +21 -4
  105. sglang/srt/layers/logits_processor.py +15 -2
  106. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  107. sglang/srt/layers/moe/ep_moe/layer.py +147 -74
  108. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  109. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  110. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  111. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  112. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
  113. sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
  114. sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
  115. sglang/srt/layers/moe/utils.py +10 -0
  116. sglang/srt/layers/parameter.py +23 -6
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  119. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  121. sglang/srt/layers/quantization/fp8.py +2 -2
  122. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  123. sglang/srt/layers/quantization/modelopt_quant.py +44 -9
  124. sglang/srt/layers/quantization/mxfp4.py +12 -4
  125. sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
  126. sglang/srt/layers/quantization/w4afp8.py +0 -4
  127. sglang/srt/layers/quantization/w8a8_int8.py +15 -3
  128. sglang/srt/layers/rotary_embedding.py +78 -31
  129. sglang/srt/layers/sampler.py +52 -4
  130. sglang/srt/layers/utils.py +23 -0
  131. sglang/srt/lora/backend/base_backend.py +3 -3
  132. sglang/srt/lora/backend/chunked_backend.py +348 -0
  133. sglang/srt/lora/backend/triton_backend.py +10 -4
  134. sglang/srt/lora/lora.py +7 -5
  135. sglang/srt/lora/lora_manager.py +17 -6
  136. sglang/srt/lora/mem_pool.py +1 -1
  137. sglang/srt/lora/triton_ops/__init__.py +4 -0
  138. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  139. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  140. sglang/srt/lora/utils.py +7 -5
  141. sglang/srt/managers/cache_controller.py +42 -142
  142. sglang/srt/managers/data_parallel_controller.py +11 -46
  143. sglang/srt/managers/detokenizer_manager.py +11 -11
  144. sglang/srt/managers/io_struct.py +162 -118
  145. sglang/srt/managers/mm_utils.py +43 -6
  146. sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
  147. sglang/srt/managers/multimodal_processor.py +1 -2
  148. sglang/srt/managers/overlap_utils.py +53 -0
  149. sglang/srt/managers/schedule_batch.py +167 -86
  150. sglang/srt/managers/schedule_policy.py +143 -16
  151. sglang/srt/managers/scheduler.py +359 -214
  152. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  153. sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
  154. sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
  155. sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
  156. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  157. sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
  158. sglang/srt/managers/tokenizer_manager.py +84 -136
  159. sglang/srt/managers/tp_worker.py +39 -29
  160. sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
  161. sglang/srt/managers/utils.py +1 -45
  162. sglang/srt/mem_cache/allocator.py +14 -20
  163. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  164. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  165. sglang/srt/mem_cache/chunk_cache.py +8 -1
  166. sglang/srt/mem_cache/evict_policy.py +23 -0
  167. sglang/srt/mem_cache/hicache_storage.py +40 -1
  168. sglang/srt/mem_cache/hiradix_cache.py +119 -32
  169. sglang/srt/mem_cache/memory_pool.py +188 -10
  170. sglang/srt/mem_cache/memory_pool_host.py +134 -182
  171. sglang/srt/mem_cache/radix_cache.py +222 -71
  172. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  173. sglang/srt/mem_cache/storage/__init__.py +10 -0
  174. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  175. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  176. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  177. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  178. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  179. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
  180. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
  181. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
  182. sglang/srt/mem_cache/swa_radix_cache.py +25 -34
  183. sglang/srt/metrics/collector.py +82 -120
  184. sglang/srt/metrics/func_timer.py +2 -7
  185. sglang/srt/metrics/utils.py +8 -1
  186. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  187. sglang/srt/model_executor/cuda_graph_runner.py +39 -32
  188. sglang/srt/model_executor/forward_batch_info.py +23 -38
  189. sglang/srt/model_executor/model_runner.py +131 -183
  190. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  191. sglang/srt/model_loader/loader.py +14 -10
  192. sglang/srt/model_loader/weight_utils.py +156 -2
  193. sglang/srt/models/bailing_moe.py +27 -4
  194. sglang/srt/models/deepseek_nextn.py +6 -1
  195. sglang/srt/models/deepseek_v2.py +536 -153
  196. sglang/srt/models/dots_ocr.py +173 -0
  197. sglang/srt/models/falcon_h1.py +576 -0
  198. sglang/srt/models/gemma3_causal.py +0 -2
  199. sglang/srt/models/gemma3_mm.py +1 -1
  200. sglang/srt/models/gemma3n_mm.py +1 -1
  201. sglang/srt/models/glm4_moe.py +3 -3
  202. sglang/srt/models/glm4_moe_nextn.py +2 -2
  203. sglang/srt/models/glm4v.py +1 -1
  204. sglang/srt/models/glm4v_moe.py +1 -1
  205. sglang/srt/models/gpt_oss.py +7 -30
  206. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  207. sglang/srt/models/llama.py +4 -0
  208. sglang/srt/models/longcat_flash.py +1 -1
  209. sglang/srt/models/longcat_flash_nextn.py +1 -1
  210. sglang/srt/models/mllama4.py +15 -4
  211. sglang/srt/models/qwen2.py +0 -7
  212. sglang/srt/models/qwen2_5_vl.py +2 -2
  213. sglang/srt/models/qwen2_audio.py +1 -1
  214. sglang/srt/models/qwen2_moe.py +64 -1
  215. sglang/srt/models/qwen2_vl.py +1 -1
  216. sglang/srt/models/qwen3.py +18 -3
  217. sglang/srt/models/qwen3_moe.py +31 -3
  218. sglang/srt/models/qwen3_next.py +36 -9
  219. sglang/srt/models/qwen3_vl.py +787 -0
  220. sglang/srt/models/qwen3_vl_moe.py +471 -0
  221. sglang/srt/models/registry.py +15 -3
  222. sglang/srt/models/sarashina2_vision.py +269 -0
  223. sglang/srt/models/solar.py +505 -0
  224. sglang/srt/models/starcoder2.py +357 -0
  225. sglang/srt/models/torch_native_llama.py +9 -2
  226. sglang/srt/models/utils.py +51 -0
  227. sglang/srt/multimodal/processors/base_processor.py +15 -7
  228. sglang/srt/multimodal/processors/dots_vlm.py +2 -3
  229. sglang/srt/multimodal/processors/internvl.py +20 -8
  230. sglang/srt/multimodal/processors/qwen_vl.py +8 -1
  231. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  232. sglang/srt/parser/jinja_template_utils.py +6 -0
  233. sglang/srt/sampling/sampling_batch_info.py +20 -2
  234. sglang/srt/sampling/sampling_params.py +7 -0
  235. sglang/srt/server_args.py +753 -295
  236. sglang/srt/server_args_config_parser.py +146 -0
  237. sglang/srt/single_batch_overlap.py +151 -0
  238. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  239. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  240. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  241. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  242. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  243. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  244. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
  245. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
  246. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
  247. sglang/srt/speculative/eagle_worker.py +57 -25
  248. sglang/srt/speculative/ngram_utils.py +428 -0
  249. sglang/srt/speculative/ngram_worker.py +245 -0
  250. sglang/srt/speculative/spec_info.py +47 -0
  251. sglang/srt/speculative/spec_utils.py +606 -0
  252. sglang/srt/torch_memory_saver_adapter.py +5 -7
  253. sglang/srt/tracing/trace.py +32 -6
  254. sglang/srt/two_batch_overlap.py +8 -5
  255. sglang/srt/utils/__init__.py +2 -0
  256. sglang/srt/{utils.py → utils/common.py} +399 -74
  257. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
  258. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  259. sglang/srt/utils/rpd_utils.py +452 -0
  260. sglang/srt/utils/slow_rank_detector.py +71 -0
  261. sglang/srt/warmup.py +8 -4
  262. sglang/srt/weight_sync/utils.py +1 -1
  263. sglang/test/get_logits_ut.py +57 -0
  264. sglang/test/run_eval.py +79 -11
  265. sglang/test/runners.py +1 -1
  266. sglang/test/simple_eval_common.py +5 -2
  267. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  268. sglang/test/test_block_fp8.py +2 -2
  269. sglang/test/test_deterministic.py +297 -0
  270. sglang/test/test_disaggregation_utils.py +12 -1
  271. sglang/test/test_programs.py +1 -1
  272. sglang/test/test_utils.py +355 -4
  273. sglang/utils.py +10 -1
  274. sglang/version.py +1 -1
  275. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
  276. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
  277. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  278. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  279. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  280. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  281. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  282. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
sglang/srt/lora/triton_ops/chunked_sgmv_expand.py ADDED
@@ -0,0 +1,214 @@
+ from typing import Optional
+
+ import torch
+ import triton
+ import triton.language as tl
+
+ from sglang.srt.lora.utils import LoRABatchInfo
+ from sglang.srt.utils import cached_triton_kernel
+
+
+ @cached_triton_kernel(lambda _, kwargs: (kwargs["NUM_SLICES"], kwargs["BLOCK_M"]))
+ @triton.jit
+ def _chunked_lora_expand_kernel(
+     # Pointers to matrices
+     x,
+     weights,
+     output,
+     # Information on sequence lengths and weight id
+     seg_indptr,
+     weight_indices,
+     lora_ranks,
+     permutation,
+     num_segs,
+     # For fused output scaling
+     scalings,
+     # Offsets of q/k/v slice on output dimension
+     slice_offsets,
+     # Meta parameters
+     NUM_SLICES: tl.constexpr,
+     OUTPUT_DIM: tl.constexpr,
+     MAX_RANK: tl.constexpr,  # K = R
+     BLOCK_M: tl.constexpr,
+     BLOCK_N: tl.constexpr,
+     BLOCK_K: tl.constexpr,
+ ):
+     """
+     Computes a chunked SGMV for LoRA expand operations.
+
+     When a sequence's rank is 0, the kernel is essentially a no-op, following
+     the convention in pytorch where the product of two matrices of shape (m, 0)
+     and (0, n) is an all-zero matrix of shape (m, n).
+
+     Args:
+         x (Tensor): The input tensor, which is the result of the LoRA A projection.
+             Shape: (s, num_slices * K), where s is the sum of all sequence lengths in the
+             batch and K is the maximum LoRA rank.
+         weights (Tensor): The LoRA B weights for all adapters.
+             Shape: (num_lora, output_dim, K).
+         output (Tensor): The output tensor where the result is stored.
+             Shape: (s, output_dim).
+     """
+     tl.static_assert(NUM_SLICES <= 3)
+
+     x_stride_0: tl.constexpr = NUM_SLICES * MAX_RANK
+     x_stride_1: tl.constexpr = 1
+
+     w_stride_0: tl.constexpr = OUTPUT_DIM * MAX_RANK
+     w_stride_1: tl.constexpr = MAX_RANK
+     w_stride_2: tl.constexpr = 1
+
+     output_stride_0: tl.constexpr = OUTPUT_DIM
+     output_stride_1: tl.constexpr = 1
+
+     pid_s = tl.program_id(axis=2)
+     if pid_s >= num_segs:
+         return
+
+     # Current block computes sequence with batch_id,
+     # which starts from row seg_start of x with length seg_len.
+     # qkv_id decides which of q,k,v to compute (0: q, 1: k, 2: v)
+     w_index = tl.load(weight_indices + pid_s)
+     cur_rank = tl.load(lora_ranks + w_index)
+
+     # If rank is 0, this kernel is a no-op.
+     if cur_rank == 0:
+         return
+
+     seg_start = tl.load(seg_indptr + pid_s)
+     seg_end = tl.load(seg_indptr + pid_s + 1)
+
+     slice_id = tl.program_id(axis=1)
+     slice_start = tl.load(slice_offsets + slice_id)
+     slice_end = tl.load(slice_offsets + slice_id + 1)
+
+     scaling = tl.load(scalings + w_index)
+     # Adjust K (rank) according to the specific LoRA adapter
+     cur_rank = tl.minimum(MAX_RANK, cur_rank)
+
+     # Map logical sequence index to physical index
+     s_offset_logical = tl.arange(0, BLOCK_M) + seg_start
+     s_offset_physical = tl.load(
+         permutation + s_offset_logical, mask=s_offset_logical < seg_end
+     )
+
+     # Create pointers for the first block of x and weights[batch_id][n_start: n_end][:]
+     # The pointers will be advanced as we move in the K direction
+     # and accumulate
+     pid_n = tl.program_id(axis=0)
+     n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + slice_start
+     k_offset = tl.arange(0, BLOCK_K)
+
+     x_ptrs = (
+         x
+         + slice_id * cur_rank * x_stride_1
+         + (s_offset_physical[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1)
+     )
+     w_ptrs = (weights + w_index * w_stride_0) + (
+         k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
+     )
+
+     # Iterate to compute the block in output matrix
+     partial_sum = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+     for k in range(0, tl.cdiv(cur_rank, BLOCK_K)):
+         x_tile = tl.load(
+             x_ptrs,
+             mask=(s_offset_logical[:, None] < seg_end)
+             & (k_offset[None, :] < cur_rank - k * BLOCK_K),
+             other=0.0,
+         )
+         w_tile = tl.load(
+             w_ptrs,
+             mask=(k_offset[:, None] < cur_rank - k * BLOCK_K)
+             & (n_offset[None, :] < slice_end),
+             other=0.0,
+         )
+         partial_sum += tl.dot(x_tile, w_tile)
+
+         x_ptrs += BLOCK_K * x_stride_1
+         w_ptrs += BLOCK_K * w_stride_2
+
+     # Store result to output matrix
+     partial_sum *= scaling
+     partial_sum = partial_sum.to(x.dtype.element_ty)
+     output_ptr = output + (
+         s_offset_physical[:, None] * output_stride_0
+         + n_offset[None, :] * output_stride_1
+     )
+     output_mask = (s_offset_logical[:, None] < seg_end) & (
+         n_offset[None, :] < slice_end
+     )
+     partial_sum += tl.load(output_ptr, mask=output_mask, other=0.0)
+     tl.store(output_ptr, partial_sum, mask=output_mask)
+
+
+ def chunked_sgmv_lora_expand_forward(
+     x: torch.Tensor,
+     weights: torch.Tensor,
+     batch_info: LoRABatchInfo,
+     slice_offsets: torch.Tensor,
+     max_slice_size: int,
+     base_output: Optional[torch.Tensor],
+ ) -> torch.Tensor:
+
+     # x: (s, slice_num * r)
+     # weights: (num_lora, output_dim, r)
+     # slice_offsets: boundaries for different slices in the output dimension
+     # output: (s, output_dim)
+
+     # Compute lora_output with shape (s, output_dim) as follows:
+     # For each slice i, accumulates:
+     #     lora_output[:, slice_offsets[i]:slice_offsets[i+1]] += scaling * sgemm(x[:, i*cur_rank:(i+1)*cur_rank], weights[:, slice_offsets[i]:slice_offsets[i+1], :])
+
+     assert x.is_contiguous()
+     assert weights.is_contiguous()
+     assert len(x.shape) == 2
+     assert len(weights.shape) == 3
+
+     # Get dims
+     M = x.shape[0]
+     input_dim = x.shape[1]
+     OUTPUT_DIM = weights.shape[1]
+     MAX_RANK = weights.shape[2]
+     num_slices = len(slice_offsets) - 1
+     assert input_dim == num_slices * MAX_RANK
+
+     # TODO (lifuhuang): fine-tune per operation
+     BLOCK_M = batch_info.max_len
+     BLOCK_K = 16
+     BLOCK_N = 64
+
+     num_segments = batch_info.num_segments
+
+     grid = (
+         triton.cdiv(max_slice_size, BLOCK_N),
+         num_slices,  # number of slices in the input/output
+         batch_info.bs if batch_info.use_cuda_graph else num_segments,
+     )
+
+     if base_output is None:
+         output = torch.zeros((M, OUTPUT_DIM), device=x.device, dtype=x.dtype)
+     else:
+         output = base_output
+
+     _chunked_lora_expand_kernel[grid](
+         x=x,
+         weights=weights,
+         output=output,
+         seg_indptr=batch_info.seg_indptr,
+         weight_indices=batch_info.weight_indices,
+         lora_ranks=batch_info.lora_ranks,
+         permutation=batch_info.permutation,
+         num_segs=num_segments,
+         scalings=batch_info.scalings,
+         slice_offsets=slice_offsets,
+         # constants
+         NUM_SLICES=num_slices,
+         OUTPUT_DIM=OUTPUT_DIM,
+         MAX_RANK=MAX_RANK,
+         BLOCK_M=BLOCK_M,
+         BLOCK_N=BLOCK_N,
+         BLOCK_K=BLOCK_K,
+     )
+
+     return output
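For orientation, the following dense PyTorch sketch shows what the expand kernel above computes per segment. It is illustrative only (not part of the package) and deliberately ignores the logical-to-physical row `permutation` and the CUDA-graph grid padding that the real kernel handles:

```python
import torch

def reference_chunked_lora_expand(
    x, weights, seg_indptr, weight_indices, lora_ranks, scalings, slice_offsets,
    base_output=None,
):
    """Dense per-segment equivalent of _chunked_lora_expand_kernel (sketch)."""
    s, output_dim = x.shape[0], weights.shape[1]
    out = base_output if base_output is not None else torch.zeros(
        s, output_dim, dtype=x.dtype, device=x.device
    )
    num_slices = len(slice_offsets) - 1
    for seg in range(len(seg_indptr) - 1):
        lo, hi = int(seg_indptr[seg]), int(seg_indptr[seg + 1])
        w = int(weight_indices[seg])  # adapter used by this segment
        r = int(lora_ranks[w])        # its LoRA rank
        if r == 0:                    # rank-0 adapter contributes nothing
            continue
        for i in range(num_slices):   # e.g. q/k/v slices of the output dim
            n0, n1 = int(slice_offsets[i]), int(slice_offsets[i + 1])
            # columns [i*r, (i+1)*r) of x hold slice i's LoRA-A projection
            out[lo:hi, n0:n1] += float(scalings[w]) * (
                x[lo:hi, i * r : (i + 1) * r] @ weights[w, n0:n1, :r].T
            )
    return out
```

The Triton kernel fuses the `scaling` multiply and the accumulation into `base_output` into the same pass, so no separate elementwise step is needed.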
sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py ADDED
@@ -0,0 +1,174 @@
+ import torch
+ import triton
+ import triton.language as tl
+
+ from sglang.srt.lora.utils import LoRABatchInfo
+ from sglang.srt.utils import cached_triton_kernel
+
+
+ @cached_triton_kernel(lambda _, kwargs: (kwargs["NUM_SLICES"], kwargs["BLOCK_M"]))
+ @triton.jit
+ def _chunked_lora_shrink_kernel(
+     # Pointers to matrices
+     x,
+     weights,
+     output,
+     # Information on sequence lengths,ranks and weight id
+     seg_indptr,
+     weight_indices,
+     lora_ranks,
+     permutation,
+     num_segs,
+     # Meta parameters
+     N: tl.constexpr,  # num_slices * r
+     K: tl.constexpr,  # input_dim
+     NUM_SLICES: tl.constexpr,
+     BLOCK_M: tl.constexpr,
+     BLOCK_N: tl.constexpr,
+     BLOCK_K: tl.constexpr,
+ ):
+     """
+     Computes a chunked SGMV for LoRA shrink operations.
+
+     The kernel ensures that output[seg_start:seg_start + seg_len, :rank * num_slices]
+     stores the product of the input `x` and the LoRA weights for the corresponding
+     sequence. This implies that when rank is 0, the kernel is essentially a no-op,
+     as output[seg_start:seg_start + seg_len, :0] is trivially correct (empty).
+
+     Args:
+         x (torch.Tensor): The input activations tensor of shape `(s, K)`, where `s`
+             is the sum of all sequence lengths in the batch.
+         weights (torch.Tensor): The LoRA A weights for all available adapters,
+             with shape `(num_lora, N, K)` where N = num_slices * r.
+         output (torch.Tensor): The output tensor of shape `(s, N)`.
+     """
+     x_stride_1: tl.constexpr = 1
+     x_stride_0: tl.constexpr = K
+
+     w_stride_0: tl.constexpr = N * K
+     w_stride_1: tl.constexpr = K
+     w_stride_2: tl.constexpr = 1
+
+     output_stride_0: tl.constexpr = N
+     output_stride_1: tl.constexpr = 1
+
+     pid_s = tl.program_id(1)
+     if pid_s >= num_segs:
+         return
+
+     pid_n = tl.program_id(0)
+
+     # Current block computes sequence with batch_id,
+     # which starts from row seg_start of x with length seg_len
+     w_index = tl.load(weight_indices + pid_s)
+     rank = tl.load(lora_ranks + w_index)
+
+     # If rank is 0, this kernel becomes a no-op as the output is always trivially correct.
+     if rank == 0:
+         return
+
+     seg_start = tl.load(seg_indptr + pid_s)
+     seg_end = tl.load(seg_indptr + pid_s + 1)
+
+     # Adjust N dim according to the specific LoRA adapter
+     cur_n = tl.minimum(N, rank * NUM_SLICES)
+
+     # Map logical sequence index to physical index
+     s_offset_logical = tl.arange(0, BLOCK_M) + seg_start
+     s_offset_physical = tl.load(
+         permutation + s_offset_logical, mask=s_offset_logical < seg_end
+     )
+
+     n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
+     k_offset = tl.arange(0, BLOCK_K)
+     x_ptrs = x + (
+         s_offset_physical[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1
+     )
+     w_ptrs = (weights + w_index * w_stride_0) + (
+         k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
+     )
+
+     # Iterate to compute the block in output matrix
+     partial_sum = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+     for k in range(0, tl.cdiv(K, BLOCK_K)):
+         x_tile = tl.load(
+             x_ptrs,
+             mask=(s_offset_logical[:, None] < seg_end)
+             & (k_offset[None, :] < K - k * BLOCK_K),
+             other=0.0,
+         )
+         w_tile = tl.load(
+             w_ptrs,
+             mask=(k_offset[:, None] < K - k * BLOCK_K) & (n_offset[None, :] < cur_n),
+             other=0.0,
+         )
+         partial_sum += tl.dot(x_tile, w_tile)
+
+         x_ptrs += BLOCK_K * x_stride_1
+         w_ptrs += BLOCK_K * w_stride_2
+
+     # Store result to output matrix
+     partial_sum = partial_sum.to(x.dtype.element_ty)
+     output_ptr = output + (
+         s_offset_physical[:, None] * output_stride_0
+         + n_offset[None, :] * output_stride_1
+     )
+     output_mask = (s_offset_logical[:, None] < seg_end) & (n_offset[None, :] < cur_n)
+     tl.store(output_ptr, partial_sum, mask=output_mask)
+
+
+ def chunked_sgmv_lora_shrink_forward(
+     x: torch.Tensor,
+     weights: torch.Tensor,
+     batch_info: LoRABatchInfo,
+     num_slices: int,
+ ) -> torch.Tensor:
+     # x: (s, input_dim)
+     # weights: (num_lora, num_slices * r, input_dim)
+     # output: (s, num_slices * r)
+     # num_slices: qkv=3, gate_up=2, others=1
+     # when called with multiple slices, the weights.shape[-2] will be num_slices * r
+     # input_dim is much larger than r
+
+     assert x.is_contiguous()
+     assert weights.is_contiguous()
+     assert len(x.shape) == 2
+     assert len(weights.shape) == 3
+
+     # Block shapes
+     # TODO (lifuhuang): experiment with split-k
+     BLOCK_M = batch_info.max_len
+     BLOCK_N = 16
+     BLOCK_K = 256
+
+     S = x.shape[0]
+     N = weights.shape[1]
+     K = weights.shape[2]
+     assert x.shape[-1] == K
+
+     num_segments = batch_info.num_segments
+     grid = (
+         triton.cdiv(N, BLOCK_N),
+         batch_info.bs if batch_info.use_cuda_graph else num_segments,
+     )
+
+     output = torch.empty((S, N), device=x.device, dtype=x.dtype)
+     _chunked_lora_shrink_kernel[grid](
+         x=x,
+         weights=weights,
+         output=output,
+         seg_indptr=batch_info.seg_indptr,
+         weight_indices=batch_info.weight_indices,
+         lora_ranks=batch_info.lora_ranks,
+         permutation=batch_info.permutation,
+         num_segs=num_segments,
+         # constants
+         N=N,
+         K=K,
+         NUM_SLICES=num_slices,
+         BLOCK_M=BLOCK_M,
+         BLOCK_N=BLOCK_N,
+         BLOCK_K=BLOCK_K,
+     )
+
+     return output
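Analogously, a rough dense equivalent of the shrink kernel (again illustrative only: it skips the `permutation` indirection, and it zero-fills the columns beyond `rank * num_slices` that the real forward leaves uninitialized in its `torch.empty` output):

```python
import torch

def reference_chunked_lora_shrink(
    x, weights, seg_indptr, weight_indices, lora_ranks, num_slices
):
    """Dense per-segment equivalent of _chunked_lora_shrink_kernel (sketch)."""
    s, n = x.shape[0], weights.shape[1]  # n == num_slices * max_rank
    out = torch.zeros(s, n, dtype=x.dtype, device=x.device)
    for seg in range(len(seg_indptr) - 1):
        lo, hi = int(seg_indptr[seg]), int(seg_indptr[seg + 1])
        w = int(weight_indices[seg])
        r = int(lora_ranks[w])
        cols = min(n, r * num_slices)    # only these columns are defined
        if cols == 0:                    # rank-0 adapter: no-op
            continue
        out[lo:hi, :cols] = x[lo:hi] @ weights[w, :cols, :].T
    return out
```

Note that, unlike the expand path, no scaling is applied here; the scaling of the full LoRA delta is deferred to the expand kernel.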
sglang/srt/lora/utils.py CHANGED
@@ -5,7 +5,7 @@ from typing import Iterable, Optional, Set, Tuple
 
  import torch
 
- from sglang.srt.hf_transformers_utils import AutoConfig
+ from sglang.srt.utils.hf_transformers_utils import AutoConfig
 
 
  @dataclass
@@ -19,6 +19,9 @@ class LoRABatchInfo:
      # Number of segments. For triton backend, it is equal to batch size.
      num_segments: int
 
+     # Maximum segment length of current batch
+     max_len: int
+
      # Indice pointers of each segment in shape (num_segments + 1, )
      seg_indptr: torch.Tensor
 
@@ -34,9 +37,6 @@ class LoRABatchInfo:
      # Lengths of each segments in shape (num_segments,)
      seg_lens: Optional[torch.Tensor]
 
-     # Maximum segment length of current batch
-     max_len: Optional[int]
-
      # The logical (re)ordering of input rows (tokens), in shape (num_tokens,)
      permutation: Optional[torch.Tensor]
 
@@ -98,6 +98,7 @@ def get_normalized_target_modules(
  ) -> set[str]:
      """
      Mapping a list of target module name to names of the normalized LoRA weights.
+     Handles both base module names (e.g., "gate_proj") and prefixed module names (e.g., "feed_forward.gate_proj").
      """
      params_mapping = {
          "q_proj": "qkv_proj",
@@ -109,7 +110,8 @@ def get_normalized_target_modules(
 
      result = set()
      for name in target_modules:
-         normalized_name = params_mapping.get(name, name)
+         base_name = name.split(".")[-1]
+         normalized_name = params_mapping.get(base_name, base_name)
          result.add(normalized_name)
      return result
 
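The practical effect of the `get_normalized_target_modules` change above is that fully qualified module names now normalize the same way as bare suffixes. A small hypothetical usage sketch, assuming only the `"q_proj" -> "qkv_proj"` mapping visible in the diff:

```python
from sglang.srt.lora.utils import get_normalized_target_modules

# Bare and fully qualified names now normalize identically.
print(get_normalized_target_modules(["q_proj"]))
# {'qkv_proj'}
print(get_normalized_target_modules(["model.layers.0.self_attn.q_proj"]))
# {'qkv_proj'} as well, since only the trailing "q_proj" is looked up
```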