sglang 0.4.3.post1__py3-none-any.whl → 0.4.3.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (219)
  1. sglang/api.py +1 -1
  2. sglang/bench_offline_throughput.py +19 -0
  3. sglang/bench_one_batch.py +2 -2
  4. sglang/bench_serving.py +123 -79
  5. sglang/global_config.py +8 -3
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/lang/ir.py +1 -1
  8. sglang/srt/_custom_ops.py +83 -91
  9. sglang/srt/configs/load_config.py +4 -1
  10. sglang/srt/configs/model_config.py +48 -2
  11. sglang/srt/configs/qwen2_5_vl_config.py +5 -2
  12. sglang/srt/constrained/base_grammar_backend.py +117 -15
  13. sglang/srt/constrained/llguidance_backend.py +151 -0
  14. sglang/srt/constrained/outlines_backend.py +24 -33
  15. sglang/srt/constrained/xgrammar_backend.py +69 -38
  16. sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
  17. sglang/srt/distributed/parallel_state.py +48 -3
  18. sglang/srt/entrypoints/engine.py +67 -9
  19. sglang/srt/entrypoints/http_server.py +190 -41
  20. sglang/srt/entrypoints/verl_engine.py +147 -0
  21. sglang/srt/function_call_parser.py +0 -1
  22. sglang/srt/layers/activation.py +11 -0
  23. sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
  24. sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
  25. sglang/srt/layers/attention/flashinfer_backend.py +208 -295
  26. sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
  27. sglang/srt/layers/attention/torch_native_backend.py +1 -1
  28. sglang/srt/layers/attention/triton_backend.py +9 -6
  29. sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
  30. sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
  31. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
  32. sglang/srt/layers/attention/utils.py +39 -0
  33. sglang/srt/layers/attention/vision.py +60 -63
  34. sglang/srt/layers/dp_attention.py +142 -1
  35. sglang/srt/layers/layernorm.py +1 -1
  36. sglang/srt/layers/linear.py +3 -1
  37. sglang/srt/layers/logits_processor.py +281 -45
  38. sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
  39. sglang/srt/layers/moe/ep_moe/layer.py +140 -28
  40. sglang/srt/layers/moe/fused_moe_native.py +2 -0
  41. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  43. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
  47. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
  48. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
  51. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
  55. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
  56. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
  61. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
  62. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
  63. sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
  64. sglang/srt/layers/moe/topk.py +13 -4
  65. sglang/srt/layers/quantization/__init__.py +111 -7
  66. sglang/srt/layers/quantization/blockwise_int8.py +409 -0
  67. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  68. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  69. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  70. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  71. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  72. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  73. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  74. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  75. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  76. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  77. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  78. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  79. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  80. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  81. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  82. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  83. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  84. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  85. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  86. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  87. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  88. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  89. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  90. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  91. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  92. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  93. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  94. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  95. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  96. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  97. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  98. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  99. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  100. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  101. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  102. sglang/srt/layers/quantization/fp8.py +69 -28
  103. sglang/srt/layers/quantization/fp8_utils.py +17 -1
  104. sglang/srt/layers/quantization/gptq.py +416 -0
  105. sglang/srt/layers/quantization/int8_kernel.py +327 -0
  106. sglang/srt/layers/quantization/int8_utils.py +73 -0
  107. sglang/srt/layers/quantization/modelopt_quant.py +18 -1
  108. sglang/srt/layers/radix_attention.py +1 -0
  109. sglang/srt/layers/rotary_embedding.py +0 -1
  110. sglang/srt/layers/sampler.py +76 -31
  111. sglang/srt/layers/vocab_parallel_embedding.py +14 -13
  112. sglang/srt/lora/lora.py +17 -1
  113. sglang/srt/lora/lora_config.py +5 -0
  114. sglang/srt/lora/lora_manager.py +1 -3
  115. sglang/srt/managers/cache_controller.py +193 -62
  116. sglang/srt/managers/configure_logging.py +2 -1
  117. sglang/srt/managers/data_parallel_controller.py +6 -2
  118. sglang/srt/managers/detokenizer_manager.py +124 -102
  119. sglang/srt/managers/image_processor.py +2 -1
  120. sglang/srt/managers/io_struct.py +143 -6
  121. sglang/srt/managers/schedule_batch.py +238 -197
  122. sglang/srt/managers/schedule_policy.py +29 -29
  123. sglang/srt/managers/scheduler.py +681 -259
  124. sglang/srt/managers/session_controller.py +6 -2
  125. sglang/srt/managers/tokenizer_manager.py +224 -68
  126. sglang/srt/managers/tp_worker.py +15 -4
  127. sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
  128. sglang/srt/mem_cache/chunk_cache.py +18 -11
  129. sglang/srt/mem_cache/hiradix_cache.py +394 -0
  130. sglang/srt/mem_cache/memory_pool.py +44 -18
  131. sglang/srt/mem_cache/radix_cache.py +58 -47
  132. sglang/srt/metrics/collector.py +94 -36
  133. sglang/srt/model_executor/cuda_graph_runner.py +55 -24
  134. sglang/srt/model_executor/forward_batch_info.py +49 -16
  135. sglang/srt/model_executor/model_runner.py +209 -28
  136. sglang/srt/model_loader/loader.py +3 -3
  137. sglang/srt/model_loader/weight_utils.py +36 -14
  138. sglang/srt/models/baichuan.py +31 -6
  139. sglang/srt/models/chatglm.py +39 -7
  140. sglang/srt/models/commandr.py +29 -5
  141. sglang/srt/models/dbrx.py +31 -5
  142. sglang/srt/models/deepseek.py +43 -6
  143. sglang/srt/models/deepseek_nextn.py +32 -19
  144. sglang/srt/models/deepseek_v2.py +265 -29
  145. sglang/srt/models/exaone.py +19 -9
  146. sglang/srt/models/gemma.py +22 -8
  147. sglang/srt/models/gemma2.py +25 -12
  148. sglang/srt/models/gemma2_reward.py +5 -1
  149. sglang/srt/models/gpt2.py +28 -13
  150. sglang/srt/models/gpt_bigcode.py +27 -5
  151. sglang/srt/models/granite.py +21 -9
  152. sglang/srt/models/grok.py +21 -4
  153. sglang/srt/models/internlm2.py +36 -6
  154. sglang/srt/models/internlm2_reward.py +5 -1
  155. sglang/srt/models/llama.py +26 -9
  156. sglang/srt/models/llama_classification.py +5 -1
  157. sglang/srt/models/llama_eagle.py +17 -4
  158. sglang/srt/models/llama_embedding.py +5 -1
  159. sglang/srt/models/llama_reward.py +7 -2
  160. sglang/srt/models/llava.py +19 -3
  161. sglang/srt/models/llavavid.py +10 -1
  162. sglang/srt/models/minicpm.py +26 -2
  163. sglang/srt/models/minicpm3.py +39 -3
  164. sglang/srt/models/minicpmv.py +45 -14
  165. sglang/srt/models/mixtral.py +20 -9
  166. sglang/srt/models/mixtral_quant.py +50 -8
  167. sglang/srt/models/mllama.py +57 -11
  168. sglang/srt/models/olmo.py +34 -6
  169. sglang/srt/models/olmo2.py +34 -13
  170. sglang/srt/models/olmoe.py +26 -4
  171. sglang/srt/models/phi3_small.py +29 -10
  172. sglang/srt/models/qwen.py +26 -3
  173. sglang/srt/models/qwen2.py +26 -4
  174. sglang/srt/models/qwen2_5_vl.py +46 -8
  175. sglang/srt/models/qwen2_eagle.py +17 -5
  176. sglang/srt/models/qwen2_moe.py +44 -6
  177. sglang/srt/models/qwen2_rm.py +78 -0
  178. sglang/srt/models/qwen2_vl.py +39 -8
  179. sglang/srt/models/stablelm.py +32 -5
  180. sglang/srt/models/torch_native_llama.py +5 -2
  181. sglang/srt/models/xverse.py +21 -9
  182. sglang/srt/models/xverse_moe.py +45 -7
  183. sglang/srt/models/yivl.py +2 -1
  184. sglang/srt/openai_api/adapter.py +109 -24
  185. sglang/srt/openai_api/protocol.py +17 -1
  186. sglang/srt/reasoning_parser.py +154 -0
  187. sglang/srt/sampling/penaltylib/__init__.py +4 -6
  188. sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
  189. sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
  190. sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
  191. sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
  192. sglang/srt/sampling/sampling_batch_info.py +79 -157
  193. sglang/srt/sampling/sampling_params.py +16 -13
  194. sglang/srt/server_args.py +136 -52
  195. sglang/srt/speculative/build_eagle_tree.py +2 -8
  196. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +0 -1
  197. sglang/srt/speculative/eagle_utils.py +92 -58
  198. sglang/srt/speculative/eagle_worker.py +186 -94
  199. sglang/srt/speculative/spec_info.py +1 -13
  200. sglang/srt/utils.py +43 -17
  201. sglang/srt/warmup.py +47 -0
  202. sglang/test/few_shot_gsm8k.py +4 -1
  203. sglang/test/runners.py +389 -126
  204. sglang/test/send_one.py +88 -0
  205. sglang/test/test_block_fp8_ep.py +361 -0
  206. sglang/test/test_programs.py +1 -1
  207. sglang/test/test_utils.py +138 -84
  208. sglang/utils.py +50 -60
  209. sglang/version.py +1 -1
  210. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/METADATA +21 -15
  211. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/RECORD +214 -166
  212. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/WHEEL +1 -1
  213. sglang/bench_latency.py +0 -1
  214. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
  215. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
  216. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
  217. sglang/test/srt/sampling/penaltylib/utils.py +0 -344
  218. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/LICENSE +0 -0
  219. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/top_level.txt +0 -0
@@ -261,26 +261,27 @@ class VocabParallelEmbedding(torch.nn.Module):
261
261
  )
262
262
  self.embedding_dim = embedding_dim
263
263
 
264
- linear_method = None
264
+ quant_method = None
265
265
  if quant_config is not None:
266
- linear_method = quant_config.get_quant_method(self, prefix=prefix)
267
- if linear_method is None:
268
- linear_method = UnquantizedEmbeddingMethod()
266
+ quant_method = quant_config.get_quant_method(self, prefix=prefix)
267
+ print("quant_method", quant_method)
268
+ if quant_method is None:
269
+ quant_method = UnquantizedEmbeddingMethod()
269
270
 
270
271
  # If we are making an embedding layer, then our quantization linear
271
272
  # method must implement the embedding operation. If we are another
272
273
  # layer type like ParallelLMHead, this is not important.
273
274
  is_embedding_layer = type(self.__class__) is VocabParallelEmbedding
274
- linear_method_implements_embedding = method_has_implemented_embedding(
275
- type(linear_method)
275
+ quant_method_implements_embedding = method_has_implemented_embedding(
276
+ type(quant_method)
276
277
  )
277
- if is_embedding_layer and not linear_method_implements_embedding:
278
+ if is_embedding_layer and not quant_method_implements_embedding:
278
279
  raise NotImplementedError(
279
- f"The class {type(linear_method).__name__} must implement "
280
+ f"The class {type(quant_method).__name__} must implement "
280
281
  "the 'embedding' method, see UnquantizedEmbeddingMethod."
281
282
  )
282
283
 
283
- self.linear_method: QuantizeMethodBase = linear_method
284
+ self.quant_method: QuantizeMethodBase = quant_method
284
285
 
285
286
  if params_dtype is None:
286
287
  params_dtype = torch.get_default_dtype()
@@ -301,7 +302,7 @@ class VocabParallelEmbedding(torch.nn.Module):
301
302
  - self.shard_indices.added_vocab_start_index
302
303
  )
303
304
 
304
- self.linear_method.create_weights(
305
+ self.quant_method.create_weights(
305
306
  self,
306
307
  self.embedding_dim,
307
308
  [self.num_embeddings_per_partition],
@@ -446,7 +447,7 @@ class VocabParallelEmbedding(torch.nn.Module):
446
447
  packed_factor = (
447
448
  param.packed_factor
448
449
  if isinstance(param, BasevLLMParameter)
449
- else param.pack_factor
450
+ else param.packed_factor
450
451
  )
451
452
  assert loaded_weight.shape[output_dim] == (
452
453
  self.org_vocab_size // param.packed_factor
@@ -457,7 +458,7 @@ class VocabParallelEmbedding(torch.nn.Module):
457
458
  assert loaded_weight.shape[output_dim] == (
458
459
  self.org_vocab_size
459
460
  // (self.tp_size if self.use_presharded_weights else 1)
460
- )
461
+ ), f"{self.org_vocab_size=} {self.use_presharded_weights=} {loaded_weight.shape[output_dim]=}"
461
462
 
462
463
  # Copy the data.
463
464
  if not self.use_presharded_weights:
@@ -479,7 +480,7 @@ class VocabParallelEmbedding(torch.nn.Module):
479
480
  else:
480
481
  masked_input = input_
481
482
  # Get the embeddings.
482
- output_parallel = self.linear_method.embedding(self, masked_input.long())
483
+ output_parallel = self.quant_method.embedding(self, masked_input.long())
483
484
  # Mask the output embedding.
484
485
  if self.tp_size > 1:
485
486
  output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
sglang/srt/lora/lora.py CHANGED
@@ -18,6 +18,7 @@
18
18
  # LoRA layers class inheritance adapted from:
19
19
  # https://github.com/vllm-project/vllm/blob/4abf6336ec65c270343eb895e7b18786e9274176/vllm/lora/layers.py
20
20
 
21
+ import logging
21
22
  import re
22
23
  from typing import Dict, List
23
24
 
@@ -30,6 +31,8 @@ from sglang.srt.lora.backend import BaseLoRABackend
30
31
  from sglang.srt.lora.lora_config import LoRAConfig
31
32
  from sglang.srt.model_loader.loader import DefaultModelLoader
32
33
 
34
+ logger = logging.getLogger(__name__)
35
+
33
36
 
34
37
  class LoRALayer(nn.Module):
35
38
  def __init__(self, config: LoRAConfig, base_hf_config: AutoConfig):
@@ -173,6 +176,18 @@ class LoRAAdapter(nn.Module):
173
176
  if "gate_proj" in weight_name:
174
177
  up_name = weight_name.replace("gate_proj", "up_proj")
175
178
  gate_up_name = weight_name.replace("gate_proj", "gate_up_proj")
179
+ if up_name not in weights:
180
+ logger.warning(
181
+ f"Gate projection {weight_name} does not have a corresponding up projection {up_name}. "
182
+ f"Initializing up projection to zero."
183
+ )
184
+ weights[up_name] = torch.zeros_like(weights[weight_name])
185
+ # FIXME: Add gate-only support for flashinfer in future implementations
186
+ assert self.lora_backend.name == "triton", (
187
+ f"LoRA weight initialization currently only supported for 'triton' backend. "
188
+ f"Received backend: {self.lora_backend.name}. Please verify your backend configuration "
189
+ f"or consider implementing custom initialization logic for other backends."
190
+ )
176
191
  if "lora_A" in weight_name:
177
192
  weights[gate_up_name] = torch.cat(
178
193
  (weights[weight_name], weights[up_name]), 0
@@ -182,4 +197,5 @@ class LoRAAdapter(nn.Module):
182
197
  [weights[weight_name], weights[up_name]], dim=0
183
198
  )
184
199
  weights.pop(weight_name)
185
- weights.pop(up_name)
200
+ if up_name in weights:
201
+ weights.pop(up_name)
@@ -26,6 +26,11 @@ class LoRAConfig:
26
26
  self.path = path
27
27
  self.hf_config = self.get_lora_config()
28
28
  self.target_modules = self.hf_config["target_modules"]
29
+
30
+ # TODO: Support more modules
31
+ if any(module in self.target_modules for module in ["embed_tokens", "lm_head"]):
32
+ raise ValueError("Not supported yet")
33
+
29
34
  self.r = self.hf_config["r"]
30
35
  self.lora_alpha = self.hf_config["lora_alpha"]
31
36
 
@@ -76,9 +76,7 @@ class LoRAManager:
76
76
  self.hf_target_names: Set[str] = set()
77
77
  for name, path in self.lora_paths.items():
78
78
  self.configs[name] = LoRAConfig(path)
79
- self.hf_target_names = set(self.hf_target_names) | set(
80
- self.configs[name].target_modules
81
- )
79
+ self.hf_target_names.update(self.configs[name].target_modules)
82
80
 
83
81
  # Target lora weight names for lora_a and lora_b modules respectively.
84
82
  # e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj")}
@@ -5,9 +5,7 @@ Copyright 2023-2025 SGLang Team
5
5
  Licensed under the Apache License, Version 2.0 (the "License");
6
6
  you may not use this file except in compliance with the License.
7
7
  You may obtain a copy of the License at
8
-
9
8
  http://www.apache.org/licenses/LICENSE-2.0
10
-
11
9
  Unless required by applicable law or agreed to in writing, software
12
10
  distributed under the License is distributed on an "AS IS" BASIS,
13
11
  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -15,14 +13,16 @@ See the License for the specific language governing permissions and
15
13
  limitations under the License.
16
14
  """
17
15
 
16
+ import concurrent.futures
18
17
  import logging
18
+ import math
19
19
  import threading
20
- from queue import PriorityQueue, Queue
21
- from typing import Optional
20
+ from queue import Empty, Full, PriorityQueue, Queue
21
+ from typing import List, Optional
22
22
 
23
23
  import torch
24
24
 
25
- from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MLATokenToKVPoolHost
25
+ from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MHATokenToKVPoolHost
26
26
 
27
27
  logger = logging.getLogger(__name__)
28
28
 
@@ -55,6 +55,27 @@ class CacheOperation:
55
55
  self.priority = min(self.priority, other.priority)
56
56
  self.node_ids.extend(other.node_ids)
57
57
 
58
+ def split(self, factor) -> List["CacheOperation"]:
59
+ # split an operation into smaller operations to reduce the size of intermediate buffers
60
+ if factor <= 1:
61
+ return [self]
62
+
63
+ chunk_size = math.ceil(len(self.host_indices) / factor)
64
+ split_ops = []
65
+ for i in range(0, len(self.host_indices), chunk_size):
66
+ split_ops.append(
67
+ CacheOperation(
68
+ host_indices=self.host_indices[i : i + chunk_size],
69
+ device_indices=self.device_indices[i : i + chunk_size],
70
+ node_id=0,
71
+ )
72
+ )
73
+ # Inherit the node_ids on the final chunk
74
+ if split_ops:
75
+ split_ops[-1].node_ids = self.node_ids
76
+
77
+ return split_ops
78
+
58
79
  def __lt__(self, other: "CacheOperation"):
59
80
  return self.priority < other.priority
60
81
 
@@ -64,7 +85,10 @@ class TransferBuffer:
64
85
  Overlapping buffer preparation and transfer operations to improve throughput.
65
86
  """
66
87
 
67
- def __init__(self, buffer_count: int = 3, max_buffer_size: int = 1000) -> None:
88
+ def __init__(
89
+ self, stop_event, buffer_count: int = 3, max_buffer_size: int = 1000
90
+ ) -> None:
91
+ self.stop_event = stop_event
68
92
  self.buffers = Queue(maxsize=buffer_count)
69
93
  # todo: adjust the buffer size based on throughput profile of the system
70
94
  self.max_buffer_size = max_buffer_size
@@ -75,22 +99,36 @@ class TransferBuffer:
75
99
  def empty(self) -> bool:
76
100
  return self.buffers.empty()
77
101
 
78
- def put(self, item, block=True) -> None:
79
- self.buffers.put(item, block=block)
102
+ def put(self, item, block=True, timeout=1) -> None:
103
+ while not self.stop_event.is_set():
104
+ try:
105
+ self.buffers.put(item, block=block, timeout=timeout)
106
+ break
107
+ except Full:
108
+ if not block:
109
+ break
110
+ continue
111
+ except Exception as e:
112
+ logger.error(e)
80
113
 
81
- def get(self, block=True) -> Optional[CacheOperation]:
114
+ def get(self, block=True, timeout=1) -> Optional[CacheOperation]:
82
115
  try:
83
- return self.buffers.get(block=block)
116
+ return self.buffers.get(block=block, timeout=timeout)
117
+ except Empty:
118
+ return None
84
119
  except Exception as e:
85
120
  logger.error(e)
86
121
 
122
+ def clear(self):
123
+ self.buffers.queue.clear()
124
+
87
125
 
88
126
  class HiCacheController:
89
127
 
90
128
  def __init__(
91
129
  self,
92
130
  mem_pool_device: MHATokenToKVPool,
93
- mem_pool_host: MLATokenToKVPoolHost,
131
+ mem_pool_host: MHATokenToKVPoolHost,
94
132
  write_policy: str = "write_through_selective",
95
133
  ):
96
134
 
@@ -111,8 +149,11 @@ class HiCacheController:
111
149
  self.ack_write_queue = Queue()
112
150
  self.ack_load_queue = Queue()
113
151
 
114
- self.write_buffer = TransferBuffer()
115
- self.load_buffer = TransferBuffer()
152
+ self.stop_event = threading.Event()
153
+ self.write_buffer = TransferBuffer(self.stop_event)
154
+ self.load_buffer = TransferBuffer(
155
+ self.stop_event, buffer_count=10, max_buffer_size=100
156
+ )
116
157
 
117
158
  self.write_stream = torch.cuda.Stream()
118
159
  self.load_stream = torch.cuda.Stream()
@@ -126,6 +167,28 @@ class HiCacheController:
126
167
  self.write_thread.start()
127
168
  self.load_thread.start()
128
169
 
170
+ def reset(self):
171
+ self.stop_event.set()
172
+ self.write_thread.join()
173
+ self.load_thread.join()
174
+
175
+ self.write_queue.queue.clear()
176
+ self.load_queue.queue.clear()
177
+ self.write_buffer.clear()
178
+ self.load_buffer.clear()
179
+ self.ack_write_queue.queue.clear()
180
+ self.ack_load_queue.queue.clear()
181
+
182
+ self.write_thread = threading.Thread(
183
+ target=self.write_thread_func_buffer, daemon=True
184
+ )
185
+ self.load_thread = threading.Thread(
186
+ target=self.load_thread_func_buffer, daemon=True
187
+ )
188
+ self.stop_event.clear()
189
+ self.write_thread.start()
190
+ self.load_thread.start()
191
+
129
192
  def write(
130
193
  self,
131
194
  device_indices: torch.Tensor,
@@ -138,10 +201,10 @@ class HiCacheController:
138
201
  host_indices = self.mem_pool_host.alloc(len(device_indices))
139
202
  if host_indices is None:
140
203
  return None
204
+ self.mem_pool_host.protect_write(host_indices)
141
205
  self.write_queue.put(
142
206
  CacheOperation(host_indices, device_indices, node_id, priority)
143
207
  )
144
- self.mem_pool_host.protect_write(host_indices)
145
208
  return host_indices
146
209
 
147
210
  def load(
@@ -156,10 +219,10 @@ class HiCacheController:
156
219
  device_indices = self.mem_pool_device.alloc(len(host_indices))
157
220
  if device_indices is None:
158
221
  return None
222
+ self.mem_pool_host.protect_load(host_indices)
159
223
  self.load_queue.put(
160
224
  CacheOperation(host_indices, device_indices, node_id, priority)
161
225
  )
162
- self.mem_pool_host.protect_load(host_indices)
163
226
  return device_indices
164
227
 
165
228
  def write_thread_func_direct(self):
@@ -167,16 +230,19 @@ class HiCacheController:
167
230
  Directly write through KV caches to host memory without buffering.
168
231
  """
169
232
  with torch.cuda.stream(self.write_stream):
170
- while True:
233
+ while not self.stop_event.is_set():
171
234
  try:
172
- operation = self.write_queue.get(block=True)
235
+ operation = self.write_queue.get(block=True, timeout=1)
173
236
  operation.data = self.mem_pool_device.get_flat_data(
174
237
  operation.device_indices
175
238
  )
176
239
  self.mem_pool_host.transfer(operation.host_indices, operation.data)
177
240
  self.mem_pool_host.complete_io(operation.host_indices)
178
241
  for node_id in operation.node_ids:
179
- self.ack_write_queue.put(node_id)
242
+ if node_id != 0:
243
+ self.ack_write_queue.put(node_id)
244
+ except Empty:
245
+ continue
180
246
  except Exception as e:
181
247
  logger.error(e)
182
248
 
@@ -185,9 +251,10 @@ class HiCacheController:
185
251
  Directly load KV caches from host memory to device memory without buffering.
186
252
  """
187
253
  with torch.cuda.stream(self.load_stream):
188
- while True:
254
+ while not self.stop_event.is_set():
189
255
  try:
190
- operation = self.load_queue.get(block=True)
256
+ operation = self.load_queue.get(block=True, timeout=1)
257
+ # time.sleep(18e-6 * len(operation.host_indices))
191
258
  operation.data = self.mem_pool_host.get_flat_data(
192
259
  operation.host_indices
193
260
  )
@@ -196,7 +263,10 @@ class HiCacheController:
196
263
  )
197
264
  self.mem_pool_host.complete_io(operation.host_indices)
198
265
  for node_id in operation.node_ids:
199
- self.ack_load_queue.put(node_id)
266
+ if node_id != 0:
267
+ self.ack_load_queue.put(node_id)
268
+ except Empty:
269
+ continue
200
270
  except Exception as e:
201
271
  logger.error(e)
202
272
 
@@ -204,39 +274,98 @@ class HiCacheController:
204
274
  """
205
275
  Auxiliary function to prepare the buffer for write operations.
206
276
  """
277
+
278
+ def _to_op(op_):
279
+ assert op_.device_indices.is_cuda, "Device indices should be on GPU"
280
+ op_.data = self.mem_pool_device.get_flat_data(op_.device_indices).to(
281
+ self.mem_pool_host.device
282
+ )
283
+ self.write_buffer.put(op_)
284
+ return op_
285
+
207
286
  buffer = None
208
- while True:
209
- try:
210
- operation = self.write_queue.get(block=True)
211
- if buffer is None:
212
- buffer = operation
213
- else:
214
- buffer.merge(operation)
215
- if (
216
- no_wait
217
- or len(buffer.host_indices) >= self.write_buffer.max_buffer_size
218
- or self.write_queue.empty()
219
- or self.write_buffer.empty()
220
- ):
221
- assert (
222
- buffer.device_indices.is_cuda
223
- ), "Device indices should be on GPU"
224
- buffer.data = self.mem_pool_device.get_flat_data(
225
- buffer.device_indices
226
- ).contiguous()
227
- self.write_buffer.put(buffer, block=True)
228
- buffer = None
229
- except Exception as e:
230
- logger.error(e)
287
+ with torch.cuda.stream(self.write_stream):
288
+ while not self.stop_event.is_set():
289
+ try:
290
+ operation = self.write_queue.get(block=True, timeout=1)
291
+ factor = (
292
+ len(operation.device_indices)
293
+ // self.write_buffer.max_buffer_size
294
+ )
295
+
296
+ if factor >= 1:
297
+ if buffer is not None:
298
+ _to_op(buffer)
299
+ buffer = None
300
+
301
+ if factor < 2:
302
+ _to_op(operation)
303
+ else:
304
+ split_ops = operation.split(factor)
305
+ for op_ in split_ops:
306
+ _to_op(op_)
307
+ continue
308
+
309
+ if buffer is None:
310
+ buffer = operation
311
+ else:
312
+ buffer.merge(operation)
313
+ if (
314
+ no_wait
315
+ or len(buffer.host_indices) >= self.write_buffer.max_buffer_size
316
+ or self.write_queue.empty()
317
+ or self.write_buffer.empty()
318
+ ):
319
+ _to_op(buffer)
320
+ buffer = None
321
+ except Empty:
322
+ continue
323
+ except Exception as e:
324
+ logger.error(e)
231
325
 
232
326
  def load_aux_func(self):
233
327
  """
234
328
  Auxiliary function to prepare the buffer for load operations.
235
329
  """
330
+
331
+ def _pin_op(op_, put=True):
332
+ op_.data = (
333
+ self.mem_pool_host.get_flat_data(op_.host_indices)
334
+ .contiguous()
335
+ .pin_memory()
336
+ )
337
+ if put:
338
+ self.load_buffer.put(op_)
339
+ return op_
340
+
236
341
  buffer = None
237
- while True:
342
+ while not self.stop_event.is_set():
238
343
  try:
239
- operation = self.load_queue.get(block=True)
344
+ operation = self.load_queue.get(block=True, timeout=1)
345
+ factor = len(operation.host_indices) // self.load_buffer.max_buffer_size
346
+
347
+ if factor >= 1:
348
+ if buffer is not None:
349
+ _pin_op(buffer)
350
+ buffer = None
351
+
352
+ if factor < 2:
353
+ _pin_op(operation)
354
+ else:
355
+ split_ops = operation.split(factor)
356
+ split_args = [(op_, True) for op_ in split_ops[:-1]]
357
+ split_args.append((split_ops[-1], False))
358
+ # Spawn threads to pin each op concurrently
359
+ with concurrent.futures.ThreadPoolExecutor() as executor:
360
+ pinned_ops = list(
361
+ executor.map(
362
+ lambda x: _pin_op(x[0], put=x[1]), split_args
363
+ )
364
+ )
365
+ # preserve the order of last op to ensure correct ack
366
+ self.load_buffer.put(pinned_ops[-1])
367
+ continue
368
+
240
369
  if buffer is None:
241
370
  buffer = operation
242
371
  else:
@@ -246,41 +375,43 @@ class HiCacheController:
246
375
  or self.load_queue.empty()
247
376
  or self.load_buffer.empty()
248
377
  ):
249
- buffer.data = (
250
- self.mem_pool_host.get_flat_data(buffer.host_indices)
251
- .contiguous()
252
- .pin_memory()
253
- )
254
- self.load_buffer.put(buffer, block=True)
378
+ _pin_op(buffer)
255
379
  buffer = None
380
+ except Empty:
381
+ continue
256
382
  except Exception as e:
257
383
  logger.error(e)
258
384
 
259
385
  def write_thread_func_buffer(self):
260
386
  aux_thread = threading.Thread(target=self.write_aux_func, daemon=True)
261
387
  aux_thread.start()
262
- with torch.cuda.stream(self.write_stream):
263
- while True:
264
- operation = self.write_buffer.get()
265
- if operation is None:
266
- continue
267
- self.mem_pool_host.transfer(operation.host_indices, operation.data)
268
- self.mem_pool_host.complete_io(operation.host_indices)
269
- for node_id in operation.node_ids:
388
+
389
+ while not self.stop_event.is_set():
390
+ operation = self.write_buffer.get()
391
+ if operation is None:
392
+ continue
393
+ self.mem_pool_host.assign_flat_data(operation.host_indices, operation.data)
394
+ self.mem_pool_host.complete_io(operation.host_indices)
395
+ for node_id in operation.node_ids:
396
+ if node_id != 0:
270
397
  self.ack_write_queue.put(node_id)
398
+ aux_thread.join()
271
399
 
272
400
  def load_thread_func_buffer(self):
273
401
  aux_thread = threading.Thread(target=self.load_aux_func, daemon=True)
274
402
  aux_thread.start()
403
+
275
404
  with torch.cuda.stream(self.load_stream):
276
- while True:
405
+ while not self.stop_event.is_set():
277
406
  operation = self.load_buffer.get()
278
407
  if operation is None:
279
408
  continue
280
409
  self.mem_pool_device.transfer(operation.device_indices, operation.data)
281
410
  self.mem_pool_host.complete_io(operation.host_indices)
282
411
  for node_id in operation.node_ids:
283
- self.ack_load_queue.put(node_id)
412
+ if node_id != 0:
413
+ self.ack_load_queue.put(node_id)
414
+ aux_thread.join()
284
415
 
285
416
  def evict_device(
286
417
  self, device_indices: torch.Tensor, host_indices: torch.Tensor
@@ -28,6 +28,7 @@ if __name__ == "__main__":
28
28
  parser = argparse.ArgumentParser()
29
29
  parser.add_argument("--url", type=str, default="http://localhost:30000")
30
30
  parser.add_argument("--log-requests", action="store_true")
31
+ parser.add_argument("--log-requests-level", type=int, default=2)
31
32
  parser.add_argument(
32
33
  "--dump-requests-folder", type=str, default="/tmp/sglang_request_dump"
33
34
  )
@@ -38,7 +39,7 @@ if __name__ == "__main__":
38
39
  args.url + "/configure_logging",
39
40
  json={
40
41
  "log_requests": args.log_requests,
41
- "log_requests_level": 1, # Log full requests
42
+ "log_requests_level": args.log_requests_level, # Log full requests
42
43
  "dump_requests_folder": args.dump_requests_folder,
43
44
  "dump_requests_threshold": args.dump_requests_threshold,
44
45
  },
@@ -121,7 +121,7 @@ class DataParallelController:
121
121
  args=(server_args, tmp_port_args, base_gpu_id, dp_rank),
122
122
  )
123
123
  threads.append(thread)
124
- base_gpu_id += server_args.tp_size
124
+ base_gpu_id += server_args.tp_size * server_args.gpu_id_step
125
125
 
126
126
  # Free all sockets before starting the threads to launch TP workers
127
127
  for sock in sockets:
@@ -177,7 +177,11 @@ class DataParallelController:
177
177
  rank_port_args.nccl_port = port_args.nccl_port
178
178
 
179
179
  reader, writer = mp.Pipe(duplex=False)
180
- gpu_id = server_args.base_gpu_id + base_gpu_id + tp_rank % tp_size_per_node
180
+ gpu_id = (
181
+ server_args.base_gpu_id
182
+ + base_gpu_id
183
+ + (tp_rank % tp_size_per_node) * server_args.gpu_id_step
184
+ )
181
185
  proc = mp.Process(
182
186
  target=run_scheduler_process,
183
187
  args=(server_args, rank_port_args, gpu_id, tp_rank, dp_rank, writer),