sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (238)
  1. sglang/bench_one_batch_server.py +10 -1
  2. sglang/bench_serving.py +257 -29
  3. sglang/srt/configs/__init__.py +4 -0
  4. sglang/srt/configs/device_config.py +3 -1
  5. sglang/srt/configs/dots_vlm.py +139 -0
  6. sglang/srt/configs/load_config.py +1 -0
  7. sglang/srt/configs/model_config.py +50 -6
  8. sglang/srt/configs/qwen3_next.py +326 -0
  9. sglang/srt/connector/__init__.py +8 -1
  10. sglang/srt/connector/remote_instance.py +82 -0
  11. sglang/srt/constrained/base_grammar_backend.py +48 -12
  12. sglang/srt/constrained/llguidance_backend.py +0 -1
  13. sglang/srt/constrained/outlines_backend.py +0 -1
  14. sglang/srt/constrained/xgrammar_backend.py +28 -9
  15. sglang/srt/custom_op.py +11 -1
  16. sglang/srt/debug_utils/dump_comparator.py +81 -44
  17. sglang/srt/debug_utils/dump_loader.py +97 -0
  18. sglang/srt/debug_utils/dumper.py +11 -3
  19. sglang/srt/debug_utils/text_comparator.py +73 -11
  20. sglang/srt/disaggregation/base/conn.py +1 -1
  21. sglang/srt/disaggregation/common/conn.py +15 -12
  22. sglang/srt/disaggregation/decode.py +21 -10
  23. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
  24. sglang/srt/disaggregation/fake/conn.py +1 -1
  25. sglang/srt/disaggregation/mini_lb.py +6 -445
  26. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  27. sglang/srt/disaggregation/nixl/conn.py +180 -16
  28. sglang/srt/disaggregation/prefill.py +5 -3
  29. sglang/srt/disaggregation/utils.py +5 -50
  30. sglang/srt/distributed/parallel_state.py +24 -3
  31. sglang/srt/entrypoints/engine.py +38 -17
  32. sglang/srt/entrypoints/grpc_request_manager.py +580 -0
  33. sglang/srt/entrypoints/grpc_server.py +680 -0
  34. sglang/srt/entrypoints/http_server.py +85 -54
  35. sglang/srt/entrypoints/openai/protocol.py +4 -1
  36. sglang/srt/entrypoints/openai/serving_base.py +46 -3
  37. sglang/srt/entrypoints/openai/serving_chat.py +36 -16
  38. sglang/srt/entrypoints/openai/serving_completions.py +12 -3
  39. sglang/srt/entrypoints/openai/serving_embedding.py +8 -3
  40. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  41. sglang/srt/entrypoints/openai/serving_responses.py +6 -3
  42. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  43. sglang/srt/eplb/eplb_manager.py +2 -2
  44. sglang/srt/eplb/expert_distribution.py +26 -13
  45. sglang/srt/eplb/expert_location.py +8 -3
  46. sglang/srt/eplb/expert_location_updater.py +1 -1
  47. sglang/srt/function_call/base_format_detector.py +3 -6
  48. sglang/srt/function_call/ebnf_composer.py +11 -9
  49. sglang/srt/function_call/function_call_parser.py +6 -0
  50. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  51. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  52. sglang/srt/grpc/__init__.py +1 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
  56. sglang/srt/hf_transformers_utils.py +4 -0
  57. sglang/srt/layers/activation.py +142 -9
  58. sglang/srt/layers/attention/ascend_backend.py +11 -4
  59. sglang/srt/layers/attention/fla/chunk.py +242 -0
  60. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  61. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  62. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  63. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  66. sglang/srt/layers/attention/fla/index.py +37 -0
  67. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  68. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  69. sglang/srt/layers/attention/fla/op.py +66 -0
  70. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  71. sglang/srt/layers/attention/fla/utils.py +331 -0
  72. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  73. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  74. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  75. sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
  76. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  77. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  78. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  79. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  80. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  81. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  82. sglang/srt/layers/attention/triton_backend.py +18 -1
  83. sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
  84. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  85. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  86. sglang/srt/layers/dp_attention.py +30 -1
  87. sglang/srt/layers/layernorm.py +32 -15
  88. sglang/srt/layers/linear.py +34 -3
  89. sglang/srt/layers/logits_processor.py +29 -10
  90. sglang/srt/layers/moe/__init__.py +2 -1
  91. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  92. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  93. sglang/srt/layers/moe/ep_moe/layer.py +182 -62
  94. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
  95. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  96. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  97. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  98. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  99. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  100. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  101. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  102. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  103. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  104. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  105. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  106. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  107. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
  108. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  109. sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
  110. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  111. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  112. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  113. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  114. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  115. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  116. sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
  117. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  118. sglang/srt/layers/moe/topk.py +30 -9
  119. sglang/srt/layers/moe/utils.py +12 -6
  120. sglang/srt/layers/quantization/awq.py +19 -7
  121. sglang/srt/layers/quantization/base_config.py +11 -6
  122. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  123. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  124. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  125. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  126. sglang/srt/layers/quantization/fp8.py +76 -47
  127. sglang/srt/layers/quantization/fp8_utils.py +50 -31
  128. sglang/srt/layers/quantization/gptq.py +25 -17
  129. sglang/srt/layers/quantization/modelopt_quant.py +147 -47
  130. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  131. sglang/srt/layers/quantization/mxfp4.py +64 -40
  132. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  133. sglang/srt/layers/quantization/unquant.py +135 -47
  134. sglang/srt/layers/quantization/w4afp8.py +30 -17
  135. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  136. sglang/srt/layers/quantization/w8a8_int8.py +76 -38
  137. sglang/srt/layers/sampler.py +162 -18
  138. sglang/srt/lora/backend/base_backend.py +50 -8
  139. sglang/srt/lora/backend/triton_backend.py +90 -2
  140. sglang/srt/lora/layers.py +32 -0
  141. sglang/srt/lora/lora.py +4 -1
  142. sglang/srt/lora/lora_manager.py +35 -112
  143. sglang/srt/lora/mem_pool.py +24 -10
  144. sglang/srt/lora/utils.py +18 -9
  145. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  146. sglang/srt/managers/cache_controller.py +158 -160
  147. sglang/srt/managers/data_parallel_controller.py +105 -35
  148. sglang/srt/managers/detokenizer_manager.py +8 -4
  149. sglang/srt/managers/disagg_service.py +46 -0
  150. sglang/srt/managers/io_struct.py +199 -12
  151. sglang/srt/managers/mm_utils.py +1 -0
  152. sglang/srt/managers/multi_tokenizer_mixin.py +350 -400
  153. sglang/srt/managers/schedule_batch.py +77 -56
  154. sglang/srt/managers/schedule_policy.py +1 -1
  155. sglang/srt/managers/scheduler.py +187 -39
  156. sglang/srt/managers/scheduler_metrics_mixin.py +4 -3
  157. sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
  158. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  159. sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
  160. sglang/srt/managers/tokenizer_manager.py +259 -519
  161. sglang/srt/managers/tp_worker.py +53 -4
  162. sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
  163. sglang/srt/mem_cache/hicache_storage.py +3 -23
  164. sglang/srt/mem_cache/hiradix_cache.py +103 -43
  165. sglang/srt/mem_cache/memory_pool.py +347 -48
  166. sglang/srt/mem_cache/memory_pool_host.py +105 -46
  167. sglang/srt/mem_cache/radix_cache.py +0 -2
  168. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  169. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  170. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +86 -4
  171. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  172. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  173. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +49 -7
  174. sglang/srt/mem_cache/swa_radix_cache.py +0 -2
  175. sglang/srt/metrics/collector.py +493 -76
  176. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  177. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  178. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  179. sglang/srt/model_executor/forward_batch_info.py +59 -2
  180. sglang/srt/model_executor/model_runner.py +356 -29
  181. sglang/srt/model_loader/__init__.py +9 -3
  182. sglang/srt/model_loader/loader.py +128 -4
  183. sglang/srt/model_loader/weight_utils.py +2 -1
  184. sglang/srt/models/apertus.py +686 -0
  185. sglang/srt/models/bailing_moe.py +798 -218
  186. sglang/srt/models/bailing_moe_nextn.py +168 -0
  187. sglang/srt/models/deepseek_v2.py +109 -15
  188. sglang/srt/models/dots_vlm.py +174 -0
  189. sglang/srt/models/dots_vlm_vit.py +337 -0
  190. sglang/srt/models/ernie4.py +1 -1
  191. sglang/srt/models/gemma3n_mm.py +1 -1
  192. sglang/srt/models/glm4_moe.py +1 -1
  193. sglang/srt/models/glm4v.py +4 -2
  194. sglang/srt/models/glm4v_moe.py +3 -0
  195. sglang/srt/models/gpt_oss.py +1 -1
  196. sglang/srt/models/llama4.py +9 -0
  197. sglang/srt/models/llama_eagle3.py +13 -0
  198. sglang/srt/models/longcat_flash.py +2 -2
  199. sglang/srt/models/mllama4.py +25 -0
  200. sglang/srt/models/opt.py +637 -0
  201. sglang/srt/models/qwen2.py +7 -0
  202. sglang/srt/models/qwen2_5_vl.py +27 -3
  203. sglang/srt/models/qwen2_moe.py +56 -12
  204. sglang/srt/models/qwen3_moe.py +1 -1
  205. sglang/srt/models/qwen3_next.py +1042 -0
  206. sglang/srt/models/qwen3_next_mtp.py +112 -0
  207. sglang/srt/models/step3_vl.py +1 -1
  208. sglang/srt/multimodal/processors/dots_vlm.py +99 -0
  209. sglang/srt/multimodal/processors/glm4v.py +9 -9
  210. sglang/srt/multimodal/processors/internvl.py +141 -129
  211. sglang/srt/multimodal/processors/qwen_vl.py +15 -5
  212. sglang/srt/offloader.py +27 -3
  213. sglang/srt/remote_instance_weight_loader_utils.py +69 -0
  214. sglang/srt/sampling/sampling_batch_info.py +18 -15
  215. sglang/srt/server_args.py +276 -35
  216. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  217. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  218. sglang/srt/speculative/eagle_utils.py +0 -2
  219. sglang/srt/speculative/eagle_worker.py +43 -4
  220. sglang/srt/speculative/spec_info.py +5 -0
  221. sglang/srt/speculative/standalone_worker.py +109 -0
  222. sglang/srt/tracing/trace.py +552 -0
  223. sglang/srt/utils.py +34 -3
  224. sglang/srt/weight_sync/utils.py +1 -1
  225. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  226. sglang/test/runners.py +4 -0
  227. sglang/test/test_cutlass_moe.py +24 -6
  228. sglang/test/test_disaggregation_utils.py +66 -0
  229. sglang/test/test_fp4_moe.py +370 -1
  230. sglang/test/test_utils.py +28 -1
  231. sglang/utils.py +11 -0
  232. sglang/version.py +1 -1
  233. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
  234. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +237 -178
  235. sglang/srt/disaggregation/launch_lb.py +0 -118
  236. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
  237. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
  238. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
sglang/srt/lora/lora_manager.py CHANGED
@@ -69,7 +69,10 @@ class LoRAManager:
         # LoRA backend for running sgemm kernels
         logger.info(f"Using {lora_backend} as backend of LoRA kernels.")
         backend_type = get_backend_from_name(lora_backend)
-        self.lora_backend: BaseLoRABackend = backend_type(lora_backend)
+        self.lora_backend: BaseLoRABackend = backend_type(
+            max_loras_per_batch=max_loras_per_batch,
+            device=self.device,
+        )

         # Initialize mutable internal state of the LoRAManager.
         self.init_state(
@@ -82,29 +85,22 @@ class LoRAManager:
         self.max_bs_in_cuda_graph = max_bs_in_cuda_graph
         with torch.device("cuda"):
             self.cuda_graph_batch_info = LoRABatchInfo(
-                bs=self.max_bs_in_cuda_graph,
-                seg_lens=torch.zeros(self.max_bs_in_cuda_graph, dtype=torch.int32),
-                seg_indptr=torch.zeros(
-                    self.max_bs_in_cuda_graph + 1, dtype=torch.int32
-                ),
+                bs=max_bs_in_cuda_graph,
+                use_cuda_graph=True,
+                num_segments=None,
+                seg_lens=torch.zeros(max_bs_in_cuda_graph, dtype=torch.int32),
+                seg_indptr=torch.zeros(max_bs_in_cuda_graph + 1, dtype=torch.int32),
                 max_len=1,
-                weight_indices=torch.zeros(
-                    self.max_bs_in_cuda_graph, dtype=torch.int32
-                ),
+                weight_indices=torch.zeros(max_bs_in_cuda_graph, dtype=torch.int32),
+                permutation=torch.zeros(max_bs_in_cuda_graph, dtype=torch.int32),
                 lora_ranks=torch.zeros(self.max_loras_per_batch, dtype=torch.int32),
                 scalings=torch.zeros(self.max_loras_per_batch, dtype=torch.float),
             )

-            # Initialize seg_lens and seg_indptr for CUDA graph as they remain constant
-            # across batches.
-            self.cuda_graph_batch_info.seg_lens[: self.max_bs_in_cuda_graph].fill_(1)
-            torch.cumsum(
-                self.cuda_graph_batch_info.seg_lens[: self.max_bs_in_cuda_graph],
-                dim=0,
-                out=self.cuda_graph_batch_info.seg_indptr[
-                    1 : self.max_bs_in_cuda_graph + 1
-                ],
-            )
+            self.lora_backend.init_cuda_graph_batch_info(
+                cuda_graph_batch_info=self.cuda_graph_batch_info,
+                max_bs_in_cuda_graph=max_bs_in_cuda_graph,
+            )

     def create_lora_update_result(
         self, success: bool, error_message: str = ""
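The constant segment setup that previously ran inline in LoRAManager now has to be provided by the backend. A minimal sketch of what init_cuda_graph_batch_info could do for the Triton backend, assuming it only replays the logic removed above (illustrative, not the actual backend code):

    import torch

    def init_cuda_graph_batch_info(self, cuda_graph_batch_info, max_bs_in_cuda_graph):
        # Under CUDA graph, decode batches contribute one token per sequence, so
        # every segment length is 1 and seg_indptr is a simple running count.
        cuda_graph_batch_info.seg_lens[:max_bs_in_cuda_graph].fill_(1)
        torch.cumsum(
            cuda_graph_batch_info.seg_lens[:max_bs_in_cuda_graph],
            dim=0,
            out=cuda_graph_batch_info.seg_indptr[1 : max_bs_in_cuda_graph + 1],
        )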
@@ -232,7 +228,6 @@ class LoRAManager:
         return required_slots <= mem_pool_vacancy

     def prepare_lora_batch(self, forward_batch: ForwardBatch):
-
         # Load active loras into lora memory pool
         cur_uids = set(forward_batch.lora_ids)

@@ -247,102 +242,30 @@ class LoRAManager:
         # set up batch info shared by all lora modules
         bs = forward_batch.batch_size

-        def transfer_adapter_info(
-            weight_indices_out: torch.Tensor,
-            lora_ranks_out: torch.Tensor,
-            scalings_out: torch.Tensor,
-        ):
-            """
-            Transfer adapter metadata (weight indices, LoRA rank, scalings) from host
-            to device (CUDA) asynchronously.
-            """
-            weight_indices = [0] * len(forward_batch.lora_ids)
-            lora_ranks = [0] * self.max_loras_per_batch
-            scalings = [0] * self.max_loras_per_batch
-            for i, uid in enumerate(forward_batch.lora_ids):
-                weight_indices[i] = self.memory_pool.get_buffer_id(uid)
-                if uid is not None:
-                    lora = self.loras[uid]
-                    lora_ranks[weight_indices[i]] = lora.config.r
-                    scalings[weight_indices[i]] = lora.scaling
-
-            # Use pinned memory to avoid synchronizations during host-to-device transfer
-            weight_indices_tensor = torch.tensor(
-                weight_indices, dtype=torch.int32, pin_memory=True, device="cpu"
-            )
-            lora_ranks_tensor = torch.tensor(
-                lora_ranks, dtype=torch.int32, pin_memory=True, device="cpu"
-            )
-            scalings_tensor = torch.tensor(
-                scalings, dtype=torch.float, pin_memory=True, device="cpu"
-            )
-
-            # Copy to device tensors asynchronously
-            weight_indices_out[:bs].copy_(weight_indices_tensor, non_blocking=True)
-            lora_ranks_out[: self.max_loras_per_batch].copy_(
-                lora_ranks_tensor, non_blocking=True
-            )
-            scalings_out[: self.max_loras_per_batch].copy_(
-                scalings_tensor, non_blocking=True
-            )
-
-        if (
+        use_cuda_graph = (
             hasattr(self, "max_bs_in_cuda_graph")
             and bs <= self.max_bs_in_cuda_graph
             and forward_batch.forward_mode.is_cuda_graph()
-        ):
-            # Do in-place updates when CUDA graph is enabled and the batch forward mode
-            # could use CUDA graph.
-
-            transfer_adapter_info(
-                self.cuda_graph_batch_info.weight_indices,
-                self.cuda_graph_batch_info.lora_ranks,
-                self.cuda_graph_batch_info.scalings,
-            )
-
-            self.cuda_graph_batch_info.bs = bs
-            self.cuda_graph_batch_info.max_len = 1
-            batch_info = self.cuda_graph_batch_info
-        else:
-            weight_indices = torch.empty((bs,), dtype=torch.int32, device=self.device)
-            lora_ranks = torch.zeros(
-                (self.max_loras_per_batch,), dtype=torch.int64, device=self.device
-            )
-            scalings = torch.zeros(
-                (self.max_loras_per_batch,), dtype=torch.float, device=self.device
-            )
-            transfer_adapter_info(
-                weight_indices,
-                lora_ranks,
-                scalings,
-            )
-
-            seg_lens = (
-                forward_batch.extend_seq_lens
-                if forward_batch.forward_mode.is_extend()
-                else torch.ones(bs, device=self.device)
-            )
-
-            max_len = (
-                # Calculate max_len from the CPU copy to avoid D2H transfer.
-                max(forward_batch.extend_seq_lens_cpu)
-                if forward_batch.forward_mode.is_extend()
-                else 1
-            )
+        )

-            seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device)
-            seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
-
-            batch_info = LoRABatchInfo(
-                bs=bs,
-                seg_lens=seg_lens,
-                seg_indptr=seg_indptr,
-                max_len=max_len,
-                weight_indices=weight_indices,
-                lora_ranks=lora_ranks,
-                scalings=scalings,
-            )
-        self.lora_backend.set_batch_info(batch_info)
+        weight_indices = [0] * len(forward_batch.lora_ids)
+        lora_ranks = [0] * self.max_loras_per_batch
+        scalings = [0] * self.max_loras_per_batch
+        for i, uid in enumerate(forward_batch.lora_ids):
+            weight_indices[i] = self.memory_pool.get_buffer_id(uid)
+            if uid is not None:
+                lora = self.loras[uid]
+                lora_ranks[weight_indices[i]] = lora.config.r
+                scalings[weight_indices[i]] = lora.scaling
+        # Do in-place updates when CUDA graph is enabled and the batch forward mode
+        # could use CUDA graph.
+        self.lora_backend.prepare_lora_batch(
+            forward_batch=forward_batch,
+            weight_indices=weight_indices,
+            lora_ranks=lora_ranks,
+            scalings=scalings,
+            batch_info=self.cuda_graph_batch_info if use_cuda_graph else None,
+        )

     def update_lora_info(self):
         """
sglang/srt/lora/mem_pool.py CHANGED
@@ -104,12 +104,18 @@ class LoRAMemoryPool:
         return all(_can_support(x) for x in config)

     def get_lora_A_shape(
-        self, module_name: str, base_model: torch.nn.Module, max_lora_dim: int
+        self,
+        module_name: str,
+        base_model: torch.nn.Module,
+        max_lora_dim: int,
+        layer_idx: int,
     ) -> Tuple[int]:
         """
         Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
         """
-        input_dim, _ = get_hidden_dim(module_name, self.base_hf_config, base_model)
+        input_dim, _ = get_hidden_dim(
+            module_name, self.base_hf_config, base_model, layer_idx
+        )
         c = get_stacked_multiply(module_name)
         if self.tp_size > 1 and module_name in ROW_PARALLELISM_LINEAR_LORA_NAMES:
             input_dim = divide(input_dim, self.tp_size)
@@ -120,12 +126,18 @@ class LoRAMemoryPool:
         )

     def get_lora_B_shape(
-        self, module_name: str, base_model: torch.nn.Module, max_lora_dim: int
+        self,
+        module_name: str,
+        base_model: torch.nn.Module,
+        max_lora_dim: int,
+        layer_idx: int,
     ) -> Tuple[int]:
         """
         Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
         """
-        _, output_dim = get_hidden_dim(module_name, self.base_hf_config, base_model)
+        _, output_dim = get_hidden_dim(
+            module_name, self.base_hf_config, base_model, layer_idx
+        )
         if self.tp_size > 1 and module_name not in ROW_PARALLELISM_LINEAR_LORA_NAMES:
             output_dim = divide(output_dim, self.tp_size)
         return (
@@ -140,19 +152,21 @@ class LoRAMemoryPool:
         def init_buffer(
             buffer: Dict[str, List[torch.Tensor]],
             target_modules: Set[str],
-            get_lora_shape_fn: Callable[[str, torch.nn.Module, int], Tuple[int]],
+            get_lora_shape_fn: Callable[[str, torch.nn.Module, int, int], Tuple[int]],
         ):
             for module_name in target_modules:
-                lora_shape = get_lora_shape_fn(
-                    module_name, base_model, self.max_lora_rank
-                )
                 buffer[module_name] = [
                     torch.empty(
-                        lora_shape,
+                        get_lora_shape_fn(
+                            module_name,
+                            base_model,
+                            self.max_lora_rank,
+                            idx,
+                        ),
                         dtype=self.dtype,
                         device=device,
                     )
-                    for _ in range(self.num_layer)
+                    for idx in range(self.num_layer)
                 ]

         init_buffer(
sglang/srt/lora/utils.py CHANGED
@@ -10,19 +10,19 @@ from sglang.srt.hf_transformers_utils import AutoConfig

 @dataclass
 class LoRABatchInfo:
+    # The forward mode is using CUDA Graph.
+    use_cuda_graph: bool
+
     # Batch size
     bs: int

-    # Lengths of each sequence in shape (bs,)
-    seg_lens: torch.Tensor
+    # Number of segments. For triton backend, it is equal to batch size.
+    num_segments: int

-    # Indice pointers of each sequence in shape (bs + 1, )
+    # Indice pointers of each segment in shape (num_segments + 1, )
     seg_indptr: torch.Tensor

-    # Maximum sequence length of current batch
-    max_len: int
-
-    # The index of lora adapter used by each sequence, in shape (bs,)
+    # The index of lora adapter used by each segment, in shape (num_segments,)
     weight_indices: torch.Tensor

     # ranks of each lora adapter, in shape (lora_num,)
@@ -31,6 +31,15 @@ class LoRABatchInfo:
     # scaling of each lora adapter, in shape (lora_num,)
     scalings: torch.Tensor

+    # Lengths of each segments in shape (num_segments,)
+    seg_lens: Optional[torch.Tensor]
+
+    # Maximum segment length of current batch
+    max_len: Optional[int]
+
+    # The logical (re)ordering of input rows (tokens), in shape (num_tokens,)
+    permutation: Optional[torch.Tensor]
+

 class LoRAType(Enum):
     LORA_A = 0
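With these changes, LoRABatchInfo keeps the required per-batch fields up front and moves backend-specific data (seg_lens, max_len, permutation) into optional fields. A hedged sketch of how a regular (non-CUDA-graph) batch might be described; the variable names are illustrative:

    batch_info = LoRABatchInfo(
        use_cuda_graph=False,
        bs=bs,
        num_segments=bs,                # Triton backend: one segment per sequence
        seg_indptr=seg_indptr,          # int32, shape (num_segments + 1,)
        weight_indices=weight_indices,  # adapter slot used by each segment
        lora_ranks=lora_ranks,          # shape (max_loras_per_batch,)
        scalings=scalings,              # shape (max_loras_per_batch,)
        seg_lens=seg_lens,              # optional per-segment lengths
        max_len=int(seg_lens.max().item()),
        permutation=None,               # optional token reordering
    )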
@@ -48,14 +57,14 @@ def get_layer_id(name: str) -> int:


 def get_hidden_dim(
-    module_name: str, config: AutoConfig, base_model: torch.nn.Module
+    module_name: str, config: AutoConfig, base_model: torch.nn.Module, layer_idx: int
 ) -> Tuple[int]:
     """
     Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
     """

     if hasattr(base_model, "get_hidden_dim"):
-        return base_model.get_hidden_dim(module_name)
+        return base_model.get_hidden_dim(module_name, layer_idx)
     else:
         """
         WARNING: get_hidden_dim() is not defined,
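The added layer_idx lets a model report per-layer hidden dimensions instead of one global shape. A hypothetical override showing the expected hook signature; the class and config attribute names are illustrative and not taken from this diff:

    from typing import Tuple
    import torch.nn as nn

    class MyModelForCausalLM(nn.Module):
        def get_hidden_dim(self, module_name: str, layer_idx: int) -> Tuple[int, int]:
            # A model with non-uniform layers would branch on layer_idx here.
            cfg = self.config
            if module_name == "gate_up_proj":
                return cfg.hidden_size, cfg.intermediate_size
            if module_name == "down_proj":
                return cfg.intermediate_size, cfg.hidden_size
            return cfg.hidden_size, cfg.hidden_size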
sglang/srt/managers/async_dynamic_batch_tokenizer.py ADDED
@@ -0,0 +1,170 @@
+"""
+Asynchronous dynamic batch tokenizer for SGLang.
+
+This module provides an async tokenizer with dynamic batching capabilities
+to reduce tokenization overhead when multiple requests arrive concurrently.
+"""
+
+import asyncio
+import logging
+from concurrent.futures import ThreadPoolExecutor
+from functools import partial
+from typing import Any, Dict, List, Optional
+
+logger = logging.getLogger(__name__)
+
+
+class AsyncDynamicbatchTokenizer:
+    """Asynchronous tokenizer with dynamic batching for single string prompts.
+
+    Dynamically batches pending encode requests from a queue to reduce overhead.
+    Only handles single string prompts - regular batch processing of multiple
+    strings per request should be handled at a higher level.
+    A single-thread ThreadPoolExecutor is used so the event loop stays responsive.
+
+    Note: Uses lazy initialization for asyncio components because this class
+    is instantiated in TokenizerManager.__init__() before the event loop starts.
+    """
+
+    def __init__(
+        self,
+        tokenizer,
+        max_batch_size: int = 32,
+        batch_wait_timeout_s: float = 0.002,
+    ) -> None:
+        self.tokenizer = tokenizer
+        self.max_batch_size = max_batch_size
+        self.batch_wait_timeout_s = batch_wait_timeout_s
+
+        # Single queue for all encode requests - initialized lazily
+        self._queue: Optional[asyncio.Queue] = None
+        self._batcher_task: Optional[asyncio.Task] = None
+
+        # Single-thread executor for blocking tokenizer calls
+        self._executor = ThreadPoolExecutor(max_workers=1)
+        self._initialized = False
+
+    def _ensure_initialized(self):
+        """Lazy initialization of event loop dependent components."""
+        if not self._initialized:
+            self._queue = asyncio.Queue()
+            self._batcher_task = asyncio.create_task(self._dynamic_batch_loop())
+            self._initialized = True
+
+    async def __call__(self, prompt: str, **kwargs) -> Any:
+        """Encode a single prompt."""
+        return await self.encode(prompt, **kwargs)
+
+    async def encode(self, prompt: str, **kwargs) -> Any:
+        """Encode a single prompt."""
+        self._ensure_initialized()
+        result_future: asyncio.Future = asyncio.get_running_loop().create_future()
+        await self._queue.put((prompt, kwargs, result_future))
+        return await result_future
+
+    async def _dynamic_batch_loop(self):
+        """Dynamically batch incoming encode requests for efficiency."""
+        while True:
+            try:
+                # Get the first request
+                prompt, kwargs, result_future = await self._queue.get()
+
+                # Collect requests into dynamic batch
+                prompts = [prompt]
+                kwargs_list = [kwargs]
+                result_futures = [result_future]
+
+                # Check if there are more items immediately available in the queue
+                # If queue is empty, process single item immediately without timeout
+                if self._queue.empty():
+                    # No other requests waiting, process immediately
+                    pass
+                else:
+                    # There might be more requests, wait for dynamic batching opportunity
+                    start_time = asyncio.get_running_loop().time()
+
+                    # Collect more requests up to max_batch_size or batch_wait_timeout_s
+                    while len(prompts) < self.max_batch_size:
+                        elapsed = asyncio.get_running_loop().time() - start_time
+                        if elapsed >= self.batch_wait_timeout_s:
+                            break
+
+                        remaining_time = self.batch_wait_timeout_s - elapsed
+                        try:
+                            prompt, kwargs, result_future = await asyncio.wait_for(
+                                self._queue.get(), remaining_time
+                            )
+                            prompts.append(prompt)
+                            kwargs_list.append(kwargs)
+                            result_futures.append(result_future)
+                        except asyncio.TimeoutError:
+                            break
+
+                # Log dynamic batch information
+                logger.debug(
+                    f"AsyncDynamicbatchTokenizer: Processing dynamic batch of size {len(prompts)}"
+                )
+
+                # Process the dynamic batch
+                await self._process_dynamic_batch(prompts, kwargs_list, result_futures)
+
+            except Exception as e:
+                logger.error(f"Error in dynamic batch loop: {e}")
+                # Continue the loop to handle other requests
+
+    async def _process_dynamic_batch(
+        self,
+        prompts: List[str],
+        kwargs_list: List[Dict],
+        result_futures: List[asyncio.Future],
+    ) -> None:
+        """Process a dynamic batch of encode requests for single string prompts."""
+        # Check if all kwargs are identical for efficient batch processing
+        can_batch = len(set(str(sorted(kw.items())) for kw in kwargs_list)) == 1
+        kwargs = kwargs_list[0] if can_batch else None
+
+        try:
+            # If every request uses identical kwargs we can run a single
+            # batch tokenizer call for a big speed-up.
+            if can_batch and len(prompts) > 1:
+                encode_fn = partial(self.tokenizer, prompts, **kwargs)
+                results = await asyncio.get_running_loop().run_in_executor(
+                    self._executor, encode_fn
+                )
+
+                for i, fut in enumerate(result_futures):
+                    if not fut.done():
+                        data = {k: v[i] for k, v in results.items()}
+                        fut.set_result(data)
+            else:
+                # Process each request individually due to different kwargs
+                if len(prompts) > 1 and not can_batch:
+                    logger.warning(
+                        f"AsyncDynamicbatchTokenizer: Dynamic batching disabled for batch of {len(prompts)} "
+                        f"requests due to differing kwargs. This reduces performance benefits. "
+                        f"Consider using consistent tokenization parameters across requests."
+                    )
+
+                encode_fn = lambda prompts=prompts, kwargs=kwargs_list: [
+                    self.tokenizer(p, **kw) for p, kw in zip(prompts, kwargs_list)
+                ]
+                results = await asyncio.get_running_loop().run_in_executor(
+                    self._executor, encode_fn
+                )
+
+                for fut, res in zip(result_futures, results):
+                    if not fut.done():
+                        fut.set_result(res)
+        except Exception as e:
+            logger.error(f"Error in dynamic batch processing: {e}")
+            for fut in result_futures:
+                if not fut.done():
+                    fut.set_exception(e)
+
+    def __del__(self):
+        """Clean up background tasks."""
+        if hasattr(self, "_batcher_task") and self._batcher_task:
+            if not self._batcher_task.done():
+                self._batcher_task.cancel()
+        if hasattr(self, "_executor"):
+            self._executor.shutdown(wait=False)
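A hedged usage sketch of the new tokenizer: concurrent encode() calls that arrive within batch_wait_timeout_s may be coalesced into a single underlying tokenizer call (the Hugging Face model name is only an example):

    import asyncio
    from transformers import AutoTokenizer

    async def main():
        hf_tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative tokenizer
        tok = AsyncDynamicbatchTokenizer(
            hf_tokenizer, max_batch_size=32, batch_wait_timeout_s=0.002
        )
        # Depending on timing, these requests are tokenized in one batched call.
        results = await asyncio.gather(*(tok.encode(p) for p in ["hello", "world", "!"]))
        print([r["input_ids"] for r in results])

    asyncio.run(main())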