sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (238)
  1. sglang/bench_one_batch_server.py +10 -1
  2. sglang/bench_serving.py +257 -29
  3. sglang/srt/configs/__init__.py +4 -0
  4. sglang/srt/configs/device_config.py +3 -1
  5. sglang/srt/configs/dots_vlm.py +139 -0
  6. sglang/srt/configs/load_config.py +1 -0
  7. sglang/srt/configs/model_config.py +50 -6
  8. sglang/srt/configs/qwen3_next.py +326 -0
  9. sglang/srt/connector/__init__.py +8 -1
  10. sglang/srt/connector/remote_instance.py +82 -0
  11. sglang/srt/constrained/base_grammar_backend.py +48 -12
  12. sglang/srt/constrained/llguidance_backend.py +0 -1
  13. sglang/srt/constrained/outlines_backend.py +0 -1
  14. sglang/srt/constrained/xgrammar_backend.py +28 -9
  15. sglang/srt/custom_op.py +11 -1
  16. sglang/srt/debug_utils/dump_comparator.py +81 -44
  17. sglang/srt/debug_utils/dump_loader.py +97 -0
  18. sglang/srt/debug_utils/dumper.py +11 -3
  19. sglang/srt/debug_utils/text_comparator.py +73 -11
  20. sglang/srt/disaggregation/base/conn.py +1 -1
  21. sglang/srt/disaggregation/common/conn.py +15 -12
  22. sglang/srt/disaggregation/decode.py +21 -10
  23. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
  24. sglang/srt/disaggregation/fake/conn.py +1 -1
  25. sglang/srt/disaggregation/mini_lb.py +6 -445
  26. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  27. sglang/srt/disaggregation/nixl/conn.py +180 -16
  28. sglang/srt/disaggregation/prefill.py +5 -3
  29. sglang/srt/disaggregation/utils.py +5 -50
  30. sglang/srt/distributed/parallel_state.py +24 -3
  31. sglang/srt/entrypoints/engine.py +38 -17
  32. sglang/srt/entrypoints/grpc_request_manager.py +580 -0
  33. sglang/srt/entrypoints/grpc_server.py +680 -0
  34. sglang/srt/entrypoints/http_server.py +85 -54
  35. sglang/srt/entrypoints/openai/protocol.py +4 -1
  36. sglang/srt/entrypoints/openai/serving_base.py +46 -3
  37. sglang/srt/entrypoints/openai/serving_chat.py +36 -16
  38. sglang/srt/entrypoints/openai/serving_completions.py +12 -3
  39. sglang/srt/entrypoints/openai/serving_embedding.py +8 -3
  40. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  41. sglang/srt/entrypoints/openai/serving_responses.py +6 -3
  42. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  43. sglang/srt/eplb/eplb_manager.py +2 -2
  44. sglang/srt/eplb/expert_distribution.py +26 -13
  45. sglang/srt/eplb/expert_location.py +8 -3
  46. sglang/srt/eplb/expert_location_updater.py +1 -1
  47. sglang/srt/function_call/base_format_detector.py +3 -6
  48. sglang/srt/function_call/ebnf_composer.py +11 -9
  49. sglang/srt/function_call/function_call_parser.py +6 -0
  50. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  51. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  52. sglang/srt/grpc/__init__.py +1 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
  56. sglang/srt/hf_transformers_utils.py +4 -0
  57. sglang/srt/layers/activation.py +142 -9
  58. sglang/srt/layers/attention/ascend_backend.py +11 -4
  59. sglang/srt/layers/attention/fla/chunk.py +242 -0
  60. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  61. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  62. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  63. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  66. sglang/srt/layers/attention/fla/index.py +37 -0
  67. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  68. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  69. sglang/srt/layers/attention/fla/op.py +66 -0
  70. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  71. sglang/srt/layers/attention/fla/utils.py +331 -0
  72. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  73. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  74. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  75. sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
  76. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  77. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  78. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  79. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  80. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  81. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  82. sglang/srt/layers/attention/triton_backend.py +18 -1
  83. sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
  84. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  85. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  86. sglang/srt/layers/dp_attention.py +30 -1
  87. sglang/srt/layers/layernorm.py +32 -15
  88. sglang/srt/layers/linear.py +34 -3
  89. sglang/srt/layers/logits_processor.py +29 -10
  90. sglang/srt/layers/moe/__init__.py +2 -1
  91. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  92. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  93. sglang/srt/layers/moe/ep_moe/layer.py +182 -62
  94. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
  95. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  96. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +35 -35
  97. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  98. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  99. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  100. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  101. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  102. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  103. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  104. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  105. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  106. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  107. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
  108. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  109. sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
  110. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  111. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  112. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  113. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  114. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  115. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  116. sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
  117. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  118. sglang/srt/layers/moe/topk.py +30 -9
  119. sglang/srt/layers/moe/utils.py +12 -6
  120. sglang/srt/layers/quantization/awq.py +19 -7
  121. sglang/srt/layers/quantization/base_config.py +11 -6
  122. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  123. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  124. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  125. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  126. sglang/srt/layers/quantization/fp8.py +76 -47
  127. sglang/srt/layers/quantization/fp8_utils.py +50 -31
  128. sglang/srt/layers/quantization/gptq.py +25 -17
  129. sglang/srt/layers/quantization/modelopt_quant.py +147 -47
  130. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  131. sglang/srt/layers/quantization/mxfp4.py +64 -40
  132. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  133. sglang/srt/layers/quantization/unquant.py +135 -47
  134. sglang/srt/layers/quantization/w4afp8.py +30 -17
  135. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  136. sglang/srt/layers/quantization/w8a8_int8.py +76 -38
  137. sglang/srt/layers/sampler.py +162 -18
  138. sglang/srt/lora/backend/base_backend.py +50 -8
  139. sglang/srt/lora/backend/triton_backend.py +90 -2
  140. sglang/srt/lora/layers.py +32 -0
  141. sglang/srt/lora/lora.py +4 -1
  142. sglang/srt/lora/lora_manager.py +35 -112
  143. sglang/srt/lora/mem_pool.py +24 -10
  144. sglang/srt/lora/utils.py +18 -9
  145. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  146. sglang/srt/managers/cache_controller.py +158 -160
  147. sglang/srt/managers/data_parallel_controller.py +105 -35
  148. sglang/srt/managers/detokenizer_manager.py +8 -4
  149. sglang/srt/managers/disagg_service.py +46 -0
  150. sglang/srt/managers/io_struct.py +199 -12
  151. sglang/srt/managers/mm_utils.py +1 -0
  152. sglang/srt/managers/multi_tokenizer_mixin.py +350 -400
  153. sglang/srt/managers/schedule_batch.py +77 -56
  154. sglang/srt/managers/schedule_policy.py +1 -1
  155. sglang/srt/managers/scheduler.py +187 -39
  156. sglang/srt/managers/scheduler_metrics_mixin.py +4 -3
  157. sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
  158. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  159. sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
  160. sglang/srt/managers/tokenizer_manager.py +259 -519
  161. sglang/srt/managers/tp_worker.py +53 -4
  162. sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
  163. sglang/srt/mem_cache/hicache_storage.py +3 -23
  164. sglang/srt/mem_cache/hiradix_cache.py +103 -43
  165. sglang/srt/mem_cache/memory_pool.py +347 -48
  166. sglang/srt/mem_cache/memory_pool_host.py +105 -46
  167. sglang/srt/mem_cache/radix_cache.py +0 -2
  168. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  169. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  170. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +86 -4
  171. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  172. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  173. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +49 -7
  174. sglang/srt/mem_cache/swa_radix_cache.py +0 -2
  175. sglang/srt/metrics/collector.py +493 -76
  176. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  177. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  178. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  179. sglang/srt/model_executor/forward_batch_info.py +59 -2
  180. sglang/srt/model_executor/model_runner.py +356 -29
  181. sglang/srt/model_loader/__init__.py +9 -3
  182. sglang/srt/model_loader/loader.py +128 -4
  183. sglang/srt/model_loader/weight_utils.py +2 -1
  184. sglang/srt/models/apertus.py +686 -0
  185. sglang/srt/models/bailing_moe.py +798 -218
  186. sglang/srt/models/bailing_moe_nextn.py +168 -0
  187. sglang/srt/models/deepseek_v2.py +109 -15
  188. sglang/srt/models/dots_vlm.py +174 -0
  189. sglang/srt/models/dots_vlm_vit.py +337 -0
  190. sglang/srt/models/ernie4.py +1 -1
  191. sglang/srt/models/gemma3n_mm.py +1 -1
  192. sglang/srt/models/glm4_moe.py +1 -1
  193. sglang/srt/models/glm4v.py +4 -2
  194. sglang/srt/models/glm4v_moe.py +3 -0
  195. sglang/srt/models/gpt_oss.py +1 -1
  196. sglang/srt/models/llama4.py +9 -0
  197. sglang/srt/models/llama_eagle3.py +13 -0
  198. sglang/srt/models/longcat_flash.py +2 -2
  199. sglang/srt/models/mllama4.py +25 -0
  200. sglang/srt/models/opt.py +637 -0
  201. sglang/srt/models/qwen2.py +7 -0
  202. sglang/srt/models/qwen2_5_vl.py +27 -3
  203. sglang/srt/models/qwen2_moe.py +56 -12
  204. sglang/srt/models/qwen3_moe.py +1 -1
  205. sglang/srt/models/qwen3_next.py +1042 -0
  206. sglang/srt/models/qwen3_next_mtp.py +112 -0
  207. sglang/srt/models/step3_vl.py +1 -1
  208. sglang/srt/multimodal/processors/dots_vlm.py +99 -0
  209. sglang/srt/multimodal/processors/glm4v.py +9 -9
  210. sglang/srt/multimodal/processors/internvl.py +141 -129
  211. sglang/srt/multimodal/processors/qwen_vl.py +15 -5
  212. sglang/srt/offloader.py +27 -3
  213. sglang/srt/remote_instance_weight_loader_utils.py +69 -0
  214. sglang/srt/sampling/sampling_batch_info.py +18 -15
  215. sglang/srt/server_args.py +276 -35
  216. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  217. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  218. sglang/srt/speculative/eagle_utils.py +0 -2
  219. sglang/srt/speculative/eagle_worker.py +43 -4
  220. sglang/srt/speculative/spec_info.py +5 -0
  221. sglang/srt/speculative/standalone_worker.py +109 -0
  222. sglang/srt/tracing/trace.py +552 -0
  223. sglang/srt/utils.py +34 -3
  224. sglang/srt/weight_sync/utils.py +1 -1
  225. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  226. sglang/test/runners.py +4 -0
  227. sglang/test/test_cutlass_moe.py +24 -6
  228. sglang/test/test_disaggregation_utils.py +66 -0
  229. sglang/test/test_fp4_moe.py +370 -1
  230. sglang/test/test_utils.py +28 -1
  231. sglang/utils.py +11 -0
  232. sglang/version.py +1 -1
  233. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
  234. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +237 -178
  235. sglang/srt/disaggregation/launch_lb.py +0 -118
  236. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
  237. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
  238. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
sglang/srt/layers/sampler.py CHANGED
@@ -1,5 +1,5 @@
  import logging
- from typing import List
+ from typing import List, Tuple

  import torch
  import torch.distributed as dist
@@ -39,6 +39,25 @@ class Sampler(nn.Module):
          if is_dp_attention_enabled():
              self.tp_sync_group = get_attention_tp_group().device_group

+     def _preprocess_logits(
+         self, logits: torch.Tensor, sampling_info: SamplingBatchInfo
+     ) -> torch.Tensor:
+         """Apply custom logit processors and handle NaN detection."""
+         # Apply the custom logit processors if registered in the sampling info
+         if sampling_info.has_custom_logit_processor:
+             apply_custom_logit_processor(logits, sampling_info)
+
+         # Detect and handle NaN values in logits
+         if self.use_nan_detection and torch.any(torch.isnan(logits)):
+             logger.warning("Detected errors during sampling! NaN in the logits.")
+             logits = torch.where(
+                 torch.isnan(logits), torch.full_like(logits, -1e5), logits
+             )
+             if crash_on_warnings():
+                 raise ValueError("Detected errors during sampling! NaN in the logits.")
+
+         return logits
+
      def forward(
          self,
          logits_output: LogitsProcessorOutput,
@@ -61,17 +80,8 @@ class Sampler(nn.Module):
          """
          logits = logits_output.next_token_logits

-         # Apply the custom logit processors if registered in the sampling info.
-         if sampling_info.has_custom_logit_processor:
-             apply_custom_logit_processor(logits, sampling_info)
-
-         if self.use_nan_detection and torch.any(torch.isnan(logits)):
-             logger.warning("Detected errors during sampling! NaN in the logits.")
-             logits = torch.where(
-                 torch.isnan(logits), torch.full_like(logits, -1e5), logits
-             )
-             if crash_on_warnings():
-                 raise ValueError("Detected errors during sampling! NaN in the logits.")
+         # Preprocess logits (custom processors and NaN handling)
+         logits = self._preprocess_logits(logits, sampling_info)

          if sampling_info.is_all_greedy:
              # Use torch.argmax if all requests use greedy sampling
@@ -80,9 +90,9 @@ class Sampler(nn.Module):
                  logprobs = torch.nn.functional.log_softmax(logits, dim=-1)

          else:
-             # Post process original logits. if temperatures are all 1.0, no need to rescale
+             # If requested, cache probabilities from original logits before temperature scaling.
              if return_logprob and RETURN_ORIGINAL_LOGPROB:
-                 logprobs = torch.softmax(logits, dim=-1)
+                 probs_without_temp_scaling = torch.softmax(logits, dim=-1)

              # Post process logits
              logits.div_(sampling_info.temperatures)
@@ -123,9 +133,10 @@ class Sampler(nn.Module):
              if return_logprob:
                  # clamp to avoid -inf
                  if RETURN_ORIGINAL_LOGPROB:
-                     logprobs = torch.log(logprobs).clamp(
-                         min=torch.finfo(logprobs.dtype).min
+                     logprobs = torch.log(probs_without_temp_scaling).clamp(
+                         min=torch.finfo(probs_without_temp_scaling.dtype).min
                      )
+                     del probs_without_temp_scaling
                  else:
                      logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)

@@ -164,6 +175,54 @@ class Sampler(nn.Module):

          return batch_next_token_ids

+     def compute_logprobs_only(
+         self,
+         logits_output: LogitsProcessorOutput,
+         sampling_info: SamplingBatchInfo,
+         return_logprob: bool,
+         top_logprobs_nums: List[int],
+         token_ids_logprobs: List[List[int]],
+     ) -> None:
+         """
+         Compute logprobs for requested token IDs without performing sampling.
+
+         Optimized for prefill-only scoring requests that need token probabilities
+         but don't require next token generation.
+         """
+         if logits_output.next_token_logits is None:
+             logger.warning("No logits available for logprob computation")
+             return
+
+         # Check if any requests actually need logprobs computation
+         needs_token_ids_logprobs = any(
+             token_ids is not None and len(token_ids) > 0
+             for token_ids in token_ids_logprobs
+         )
+         needs_top_logprobs = any(x > 0 for x in top_logprobs_nums)
+
+         if not (needs_token_ids_logprobs or needs_top_logprobs):
+             return
+
+         # Preprocess logits (custom processors and NaN handling)
+         logits = self._preprocess_logits(logits_output.next_token_logits, sampling_info)
+
+         # Compute logprobs
+         logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
+
+         # Handle top logprobs if requested
+         if needs_top_logprobs:
+             (
+                 logits_output.next_token_top_logprobs_val,
+                 logits_output.next_token_top_logprobs_idx,
+             ) = get_top_logprobs(logprobs, top_logprobs_nums)
+
+         # Handle token_ids logprobs if requested
+         if needs_token_ids_logprobs:
+             (
+                 logits_output.next_token_token_ids_logprobs_val,
+                 logits_output.next_token_token_ids_logprobs_idx,
+             ) = get_token_ids_logprobs_batch_optimized(logprobs, token_ids_logprobs)
+


  def top_k_top_p_min_p_sampling_from_probs_torch(
@@ -233,10 +292,95 @@ def get_top_logprobs(
      )


- def get_token_ids_logprobs(
+ def get_token_ids_logprobs_batch_optimized(
      logprobs: torch.Tensor,
      token_ids_logprobs: List[List[int]],
- ):
+ ) -> Tuple[List, List]:
+     """
+     Vectorized batch processing for token ID logprobs extraction.
+
+     Uses a single GPU kernel call for the entire batch instead of multiple
+     separate calls, significantly improving performance for large batches.
+
+     Args:
+         logprobs: Log probabilities tensor [batch_size, vocab_size]
+         token_ids_logprobs: List of token IDs to extract logprobs for
+
+     Example:
+         # Input: batch_size=3, vocab_size=5
+         logprobs = torch.tensor([
+             [-1.2, -2.1, -0.8, -3.0, -1.5],  # batch 0
+             [-0.5, -1.8, -2.2, -1.1, -2.7],  # batch 1
+             [-2.0, -0.9, -1.4, -2.8, -1.6],  # batch 2
+         ])
+         token_ids_logprobs = [[1, 3], [2], [0, 2, 4]]
+
+         # Output:
+         # values = [tensor([-2.1, -3.0]), tensor([-2.2]), tensor([-2.0, -1.4, -1.6])]
+         # indices = [[1, 3], [2], [0, 2, 4]]
+     """
+     batch_size = len(token_ids_logprobs)
+     device = logprobs.device
+
+     # Step 1: Calculate lengths for each request, treating None as empty list
+     # Example: [[1, 3], [2], [0, 2, 4]] -> token_lengths = tensor([2, 1, 3])
+     token_lengths = torch.tensor(
+         [len(token_ids or []) for token_ids in token_ids_logprobs], device=device
+     )
+     total_tokens = int(token_lengths.sum().item())  # 2 + 1 + 3 = 6
+
+     # Handle edge case where no tokens are requested
+     if total_tokens == 0:
+         return [logprobs.new_empty(0) for _ in token_ids_logprobs], [
+             [] for _ in token_ids_logprobs
+         ]
+
+     # Step 2: Build flattened indices using torch operations
+     # Example: row_indices = [0, 0, 1, 2, 2, 2] (batch indices repeated by their lengths)
+     row_indices = torch.repeat_interleave(
+         torch.arange(batch_size, device=device), token_lengths
+     )
+     # Example: col_indices = [1, 3, 2, 0, 2, 4] (flattened token IDs from all requests)
+     col_indices = torch.tensor(
+         [
+             token_id
+             for token_ids in token_ids_logprobs
+             for token_id in (token_ids or [])
+         ],
+         device=device,
+         dtype=torch.long,
+     )
+
+     # Step 3: Single vectorized gather operation
+     # Example: logprobs[row_indices, col_indices] -> [-2.1, -3.0, -2.2, -2.0, -1.4, -1.6]
+     gathered_logprobs = logprobs[row_indices, col_indices]
+
+     # Step 4: Split results back per request using torch operations
+     # Example: split tensor [6] into chunks of sizes [2, 1, 3] -> [tensor(2), tensor(1), tensor(3)]
+     split_logprobs = torch.split_with_sizes(
+         gathered_logprobs, token_lengths.tolist(), dim=0
+     )
+
+     # Step 5: Format output to match expected return structure
+     # Example: Convert split tensors back to list format with proper empty handling
+     #   i=0: [1, 3] -> append split_logprobs[0] and [1, 3]
+     #   i=1: [2] -> append split_logprobs[1] and [2]
+     #   i=2: [0, 2, 4] -> append split_logprobs[2] and [0, 2, 4]
+     output_token_ids_logprobs_val = []
+     output_token_ids_logprobs_idx = []
+
+     for i, token_ids in enumerate(token_ids_logprobs):
+         if token_ids is not None and len(token_ids) > 0:
+             output_token_ids_logprobs_val.append(split_logprobs[i])
+             output_token_ids_logprobs_idx.append(token_ids)
+         else:
+             output_token_ids_logprobs_val.append(logprobs.new_empty(0))
+             output_token_ids_logprobs_idx.append([])
+
+     return output_token_ids_logprobs_val, output_token_ids_logprobs_idx
+
+
+ def get_token_ids_logprobs(logprobs: torch.Tensor, token_ids_logprobs: List[List[int]]):
      output_token_ids_logprobs_val = []
      output_token_ids_logprobs_idx = []
      for i, token_ids in enumerate(token_ids_logprobs):
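
Note: the new get_token_ids_logprobs_batch_optimized above replaces per-request gather calls with a single advanced-indexing gather over the whole batch. As a standalone illustration (not part of the diff), the sketch below reproduces that gather on the example from the docstring; the values match the documented output.

import torch

# Inputs from the docstring example: batch_size=3, vocab_size=5
logprobs = torch.tensor([
    [-1.2, -2.1, -0.8, -3.0, -1.5],
    [-0.5, -1.8, -2.2, -1.1, -2.7],
    [-2.0, -0.9, -1.4, -2.8, -1.6],
])
token_ids_logprobs = [[1, 3], [2], [0, 2, 4]]

# Flatten (request, token_id) pairs so one indexing call gathers everything.
lengths = torch.tensor([len(ids) for ids in token_ids_logprobs])
rows = torch.repeat_interleave(torch.arange(len(token_ids_logprobs)), lengths)
cols = torch.tensor([t for ids in token_ids_logprobs for t in ids])
gathered = logprobs[rows, cols]

# Split the flat result back into one tensor per request.
per_request = torch.split_with_sizes(gathered, lengths.tolist())
# per_request is approximately (tensor([-2.1, -3.0]), tensor([-2.2]), tensor([-2.0, -1.4, -1.6]))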
sglang/srt/lora/backend/base_backend.py CHANGED
@@ -1,8 +1,9 @@
- from typing import Tuple, Union
+ from typing import Optional, Tuple, Union

  import torch

  from sglang.srt.lora.utils import LoRABatchInfo
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch


  class BaseLoRABackend:
@@ -10,13 +11,14 @@ class BaseLoRABackend:
      Each backend has its own implementation of Lora kernels.

      Args:
-         name: name of backend
-         batch_info: information of current batch for use
+         max_loras_per_batch: maximum number of different lora weights
+             that can be applied in a single forward batch.
+         device: the device where the backend runs.
      """

-     def __init__(self, name: str, batch_info: LoRABatchInfo = None):
-         self.name = name
-         self.batch_info = batch_info
+     def __init__(self, max_loras_per_batch: int, device: torch.device):
+         self.max_loras_per_batch = max_loras_per_batch
+         self.device = device

      def run_lora_a_sgemm(
          self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
@@ -93,8 +95,44 @@ class BaseLoRABackend:
          """
          pass

-     def set_batch_info(self, batch_info: LoRABatchInfo):
-         self.batch_info = batch_info
+     def init_cuda_graph_batch_info(
+         self,
+         cuda_graph_batch_info: LoRABatchInfo,
+         max_bs_in_cuda_graph: int,
+     ):
+         """Initialize the batch info for CUDA Graph mode.
+
+         This method provides a hook for each backend to conduct its own initialization
+         logic for CUDA Graph mode.
+
+         Args:
+             cuda_graph_batch_info: the LoRABatchInfo object created in LoraManager
+             max_bs_in_cuda_graph: maximum batch size for CUDA Graph mode
+         """
+         pass
+
+     def prepare_lora_batch(
+         self,
+         forward_batch: ForwardBatch,
+         weight_indices: list[int],
+         lora_ranks: list[int],
+         scalings: list[float],
+         batch_info: Optional[LoRABatchInfo] = None,
+     ):
+         """Prepare the lora weights and batch info for current forward batch.
+
+         This method provides a hook for each backend to conduct its own preparation
+         logic for each forward batch.
+
+         Args:
+             forward_batch: the ForwardBatch object for current forward pass
+             weight_indices: list of indices of lora weights to be applied for current batch
+             lora_ranks: list of lora ranks corresponding to weight_indices
+             scalings: list of scaling factors corresponding to weight_indices
+             batch_info: optional LoRABatchInfo object, if not provided, the backend should use its own
+                 internal batch info (e.g., self.cuda_graph_batch_info for CUDA Graph mode)
+         """
+         pass


  def get_backend_from_name(name: str) -> BaseLoRABackend:
@@ -105,6 +143,10 @@ def get_backend_from_name(name: str) -> BaseLoRABackend:
          from sglang.srt.lora.backend.triton_backend import TritonLoRABackend

          return TritonLoRABackend
+     # elif name == "csgmv":
+     #     from sglang.srt.lora.backend.chunked_backend import ChunkedSgmvLoRABackend

+     #     return ChunkedSgmvLoRABackend
      elif name == "flashinfer":
          raise ValueError(
              "FlashInfer LoRA backend has been deprecated, please use `triton` instead."
sglang/srt/lora/backend/triton_backend.py CHANGED
@@ -1,3 +1,5 @@
+ from typing import Optional
+
  import torch

  from sglang.srt.lora.backend.base_backend import BaseLoRABackend
@@ -8,12 +10,14 @@ from sglang.srt.lora.triton_ops import (
      sgemm_lora_b_fwd,
  )
  from sglang.srt.lora.utils import LoRABatchInfo
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch


  class TritonLoRABackend(BaseLoRABackend):
+     name = "triton"

-     def __init__(self, name: str, batch_info: LoRABatchInfo = None):
-         super().__init__(name, batch_info)
+     def __init__(self, max_loras_per_batch: int, device: torch.device):
+         super().__init__(max_loras_per_batch, device)

      def run_lora_a_sgemm(
          self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
@@ -86,3 +90,87 @@ class TritonLoRABackend(BaseLoRABackend):
              base_output,
          )
          return lora_output
+
+     def init_cuda_graph_batch_info(
+         self, cuda_graph_batch_info: LoRABatchInfo, max_bs_in_cuda_graph: int
+     ):
+         # Initialize seg_lens and seg_indptr for CUDA graph as they remain constant
+         # across batches.
+         cuda_graph_batch_info.seg_lens[:max_bs_in_cuda_graph].fill_(1)
+         torch.cumsum(
+             cuda_graph_batch_info.seg_lens[:max_bs_in_cuda_graph],
+             dim=0,
+             out=cuda_graph_batch_info.seg_indptr[1 : max_bs_in_cuda_graph + 1],
+         )
+
+     def prepare_lora_batch(
+         self,
+         forward_batch: ForwardBatch,
+         weight_indices: list[int],
+         lora_ranks: list[int],
+         scalings: list[float],
+         batch_info: Optional[LoRABatchInfo] = None,
+     ):
+         # Use pinned memory to avoid synchronizations during host-to-device transfer
+         weight_indices_tensor = torch.tensor(
+             weight_indices, dtype=torch.int32, pin_memory=True, device="cpu"
+         )
+         lora_ranks_tensor = torch.tensor(
+             lora_ranks, dtype=torch.int32, pin_memory=True, device="cpu"
+         )
+         scalings_tensor = torch.tensor(
+             scalings, dtype=torch.float, pin_memory=True, device="cpu"
+         )
+
+         bs = forward_batch.batch_size
+
+         if batch_info is not None:
+             assert (
+                 batch_info.use_cuda_graph
+             ), "batch_info.use_cuda_graph must be True when batch_info is provided"
+             batch_info.bs = forward_batch.batch_size
+             batch_info.num_segments = forward_batch.batch_size
+         else:
+             max_len = (
+                 # Calculate max_len from the CPU copy to avoid D2H transfer.
+                 max(forward_batch.extend_seq_lens_cpu)
+                 if forward_batch.forward_mode.is_extend()
+                 else 1
+             )
+             seg_lens = (
+                 forward_batch.extend_seq_lens
+                 if forward_batch.forward_mode.is_extend()
+                 else torch.ones(bs, device=self.device)
+             )
+             seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device)
+             seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
+
+             batch_info = LoRABatchInfo(
+                 bs=forward_batch.batch_size,
+                 num_segments=forward_batch.batch_size,
+                 max_len=max_len,
+                 use_cuda_graph=False,
+                 seg_lens=seg_lens,
+                 seg_indptr=seg_indptr,
+                 weight_indices=torch.empty(
+                     (bs,), dtype=torch.int32, device=self.device
+                 ),
+                 lora_ranks=torch.empty(
+                     (self.max_loras_per_batch,), dtype=torch.int64, device=self.device
+                 ),
+                 scalings=torch.empty(
+                     (self.max_loras_per_batch,), dtype=torch.float, device=self.device
+                 ),
+                 permutation=None,
+             )
+
+         # Copy to device asynchronously
+         batch_info.lora_ranks[: self.max_loras_per_batch].copy_(
+             lora_ranks_tensor, non_blocking=True
+         )
+         batch_info.scalings[: self.max_loras_per_batch].copy_(
+             scalings_tensor, non_blocking=True
+         )
+         batch_info.weight_indices[:bs].copy_(weight_indices_tensor, non_blocking=True)
+
+         self.batch_info = batch_info
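
prepare_lora_batch above stages the per-batch metadata in pinned host tensors and then copies it into preallocated device tensors with non_blocking=True, so the host thread does not synchronize with the GPU for this bookkeeping. The same pattern in isolation (tensor names and sizes here are illustrative, and a CUDA device is assumed):

import torch

device = torch.device("cuda")

# Preallocated device-side buffer, analogous to batch_info.weight_indices.
device_buf = torch.empty(8, dtype=torch.int32, device=device)

# Host values staged in pinned (page-locked) memory so the copy below can be
# an asynchronous DMA transfer instead of a blocking host-to-device copy.
host_vals = torch.tensor(
    [3, 1, 4, 1, 5, 9, 2, 6], dtype=torch.int32, pin_memory=True, device="cpu"
)

# Non-blocking H2D copy: returns immediately; kernels issued later on the same
# CUDA stream that read device_buf are ordered after the copy completes.
device_buf.copy_(host_vals, non_blocking=True)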
sglang/srt/lora/layers.py CHANGED
@@ -66,6 +66,15 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
          lora_backend: BaseLoRABackend,
      ) -> None:
          super().__init__(base_layer, lora_backend)
+         shard_size = self.base_layer.output_partition_sizes[0]
+         self.output_offset = torch.tensor(
+             [
+                 0,
+                 shard_size,
+             ],
+             dtype=torch.int32,
+             device=next(self.base_layer.parameters()).device,
+         )

      def set_lora_info(
          self,
@@ -81,6 +90,7 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
          lora_output = self.lora_backend.run_lora_b_sgemm(
              x=lora_a_output,
              weights=self.B_buffer,
+             output_offset=self.output_offset,
              base_output=base_output,
          )
          return lora_output
@@ -130,11 +140,23 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
          self.A_buffer_gate_up = A_buffer
          self.B_buffer_gate_up = B_buffer

+         shard_size = self.base_layer.output_partition_sizes[0]
+         self.output_offset = torch.tensor(
+             [
+                 0,
+                 shard_size,
+                 2 * shard_size,
+             ],
+             dtype=torch.int32,
+             device=next(self.base_layer.parameters()).device,
+         )
+
      def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
          lora_output = self.lora_backend.run_gate_up_lora(
              x=x,
              gate_up_lora_a=self.A_buffer_gate_up,
              gate_up_lora_b=self.B_buffer_gate_up,
+             output_offset=self.output_offset,
              base_output=base_output,
          )
          return lora_output
@@ -243,12 +265,22 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
          self.set_lora = True
          self.A_buffer = A_buffer
          self.B_buffer = B_buffer
+         output_size = self.base_layer.output_size
+         self.output_offset = torch.tensor(
+             [
+                 0,
+                 output_size,
+             ],
+             dtype=torch.int32,
+             device=next(self.base_layer.parameters()).device,
+         )

      def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
          lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
          lora_output = self.lora_backend.run_lora_b_sgemm(
              x=lora_a_output,
              weights=self.B_buffer,
+             output_offset=self.output_offset,
              base_output=base_output,
          )
          return lora_output
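
The output_offset tensors added above record the cumulative column offsets of the slices a LoRA delta writes into the layer output: [0, shard_size] for a single projection and [0, shard_size, 2 * shard_size] for the merged gate_up projection, matching the new output_offset argument passed to the LoRA kernels. A small sketch of what those offsets delimit (shard_size is an illustrative value):

import torch

shard_size = 4096  # illustrative per-shard output width
output_offset = torch.tensor([0, shard_size, 2 * shard_size], dtype=torch.int32)

# In the merged gate_up output, columns [0, shard_size) belong to gate_proj
# and [shard_size, 2 * shard_size) to up_proj.
merged_output = torch.zeros(2, 2 * shard_size)  # two tokens
gate_slice = merged_output[:, output_offset[0] : output_offset[1]]
up_slice = merged_output[:, output_offset[1] : output_offset[2]]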
sglang/srt/lora/lora.py CHANGED
@@ -28,6 +28,9 @@ from torch import nn
  from sglang.srt.configs.load_config import LoadConfig
  from sglang.srt.hf_transformers_utils import AutoConfig
  from sglang.srt.lora.backend.base_backend import BaseLoRABackend
+
+ # from sglang.srt.lora.backend.chunked_backend import ChunkedSgmvLoRABackend
+ from sglang.srt.lora.backend.triton_backend import TritonLoRABackend
  from sglang.srt.lora.lora_config import LoRAConfig
  from sglang.srt.model_loader.loader import DefaultModelLoader

@@ -156,7 +159,7 @@ class LoRAAdapter(nn.Module):
                  gate_up_name = weight_name.replace("gate_proj", "gate_up_proj")
                  if up_name not in weights:
                      weights[up_name] = torch.zeros_like(weights[weight_name])
-                     assert self.lora_backend.name == "triton", (
+                     assert isinstance(self.lora_backend, TritonLoRABackend), (
                          f"LoRA weight initialization currently only supported for 'triton' backend. "
                          f"Received backend: {self.lora_backend.name}. Please verify your backend configuration "
                          f"or consider implementing custom initialization logic for other backends."