sglang 0.4.3.post1__py3-none-any.whl → 0.4.3.post3__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (219)
  1. sglang/api.py +1 -1
  2. sglang/bench_offline_throughput.py +19 -0
  3. sglang/bench_one_batch.py +2 -2
  4. sglang/bench_serving.py +123 -79
  5. sglang/global_config.py +8 -3
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/lang/ir.py +1 -1
  8. sglang/srt/_custom_ops.py +83 -91
  9. sglang/srt/configs/load_config.py +4 -1
  10. sglang/srt/configs/model_config.py +48 -2
  11. sglang/srt/configs/qwen2_5_vl_config.py +5 -2
  12. sglang/srt/constrained/base_grammar_backend.py +117 -15
  13. sglang/srt/constrained/llguidance_backend.py +151 -0
  14. sglang/srt/constrained/outlines_backend.py +24 -33
  15. sglang/srt/constrained/xgrammar_backend.py +69 -38
  16. sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
  17. sglang/srt/distributed/parallel_state.py +48 -3
  18. sglang/srt/entrypoints/engine.py +67 -9
  19. sglang/srt/entrypoints/http_server.py +190 -41
  20. sglang/srt/entrypoints/verl_engine.py +147 -0
  21. sglang/srt/function_call_parser.py +0 -1
  22. sglang/srt/layers/activation.py +11 -0
  23. sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
  24. sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
  25. sglang/srt/layers/attention/flashinfer_backend.py +208 -295
  26. sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
  27. sglang/srt/layers/attention/torch_native_backend.py +1 -1
  28. sglang/srt/layers/attention/triton_backend.py +9 -6
  29. sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
  30. sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
  31. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
  32. sglang/srt/layers/attention/utils.py +39 -0
  33. sglang/srt/layers/attention/vision.py +60 -63
  34. sglang/srt/layers/dp_attention.py +142 -1
  35. sglang/srt/layers/layernorm.py +1 -1
  36. sglang/srt/layers/linear.py +3 -1
  37. sglang/srt/layers/logits_processor.py +281 -45
  38. sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
  39. sglang/srt/layers/moe/ep_moe/layer.py +140 -28
  40. sglang/srt/layers/moe/fused_moe_native.py +2 -0
  41. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  43. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
  47. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
  48. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
  51. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
  55. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
  56. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
  61. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
  62. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
  63. sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
  64. sglang/srt/layers/moe/topk.py +13 -4
  65. sglang/srt/layers/quantization/__init__.py +111 -7
  66. sglang/srt/layers/quantization/blockwise_int8.py +409 -0
  67. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  68. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  69. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  70. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  71. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  72. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  73. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  74. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  75. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  76. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  77. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  78. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  79. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  80. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  81. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  82. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  83. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  84. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  85. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  86. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  87. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  88. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  89. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  90. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  91. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  92. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  93. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  94. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  95. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  96. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  97. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  98. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  99. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  100. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  101. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  102. sglang/srt/layers/quantization/fp8.py +69 -28
  103. sglang/srt/layers/quantization/fp8_utils.py +17 -1
  104. sglang/srt/layers/quantization/gptq.py +416 -0
  105. sglang/srt/layers/quantization/int8_kernel.py +327 -0
  106. sglang/srt/layers/quantization/int8_utils.py +73 -0
  107. sglang/srt/layers/quantization/modelopt_quant.py +18 -1
  108. sglang/srt/layers/radix_attention.py +1 -0
  109. sglang/srt/layers/rotary_embedding.py +0 -1
  110. sglang/srt/layers/sampler.py +76 -31
  111. sglang/srt/layers/vocab_parallel_embedding.py +14 -13
  112. sglang/srt/lora/lora.py +17 -1
  113. sglang/srt/lora/lora_config.py +5 -0
  114. sglang/srt/lora/lora_manager.py +1 -3
  115. sglang/srt/managers/cache_controller.py +193 -62
  116. sglang/srt/managers/configure_logging.py +2 -1
  117. sglang/srt/managers/data_parallel_controller.py +6 -2
  118. sglang/srt/managers/detokenizer_manager.py +124 -102
  119. sglang/srt/managers/image_processor.py +2 -1
  120. sglang/srt/managers/io_struct.py +143 -6
  121. sglang/srt/managers/schedule_batch.py +238 -197
  122. sglang/srt/managers/schedule_policy.py +29 -29
  123. sglang/srt/managers/scheduler.py +681 -259
  124. sglang/srt/managers/session_controller.py +6 -2
  125. sglang/srt/managers/tokenizer_manager.py +224 -68
  126. sglang/srt/managers/tp_worker.py +15 -4
  127. sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
  128. sglang/srt/mem_cache/chunk_cache.py +18 -11
  129. sglang/srt/mem_cache/hiradix_cache.py +394 -0
  130. sglang/srt/mem_cache/memory_pool.py +44 -18
  131. sglang/srt/mem_cache/radix_cache.py +58 -47
  132. sglang/srt/metrics/collector.py +94 -36
  133. sglang/srt/model_executor/cuda_graph_runner.py +55 -24
  134. sglang/srt/model_executor/forward_batch_info.py +49 -16
  135. sglang/srt/model_executor/model_runner.py +209 -28
  136. sglang/srt/model_loader/loader.py +3 -3
  137. sglang/srt/model_loader/weight_utils.py +36 -14
  138. sglang/srt/models/baichuan.py +31 -6
  139. sglang/srt/models/chatglm.py +39 -7
  140. sglang/srt/models/commandr.py +29 -5
  141. sglang/srt/models/dbrx.py +31 -5
  142. sglang/srt/models/deepseek.py +43 -6
  143. sglang/srt/models/deepseek_nextn.py +32 -19
  144. sglang/srt/models/deepseek_v2.py +265 -29
  145. sglang/srt/models/exaone.py +19 -9
  146. sglang/srt/models/gemma.py +22 -8
  147. sglang/srt/models/gemma2.py +25 -12
  148. sglang/srt/models/gemma2_reward.py +5 -1
  149. sglang/srt/models/gpt2.py +28 -13
  150. sglang/srt/models/gpt_bigcode.py +27 -5
  151. sglang/srt/models/granite.py +21 -9
  152. sglang/srt/models/grok.py +21 -4
  153. sglang/srt/models/internlm2.py +36 -6
  154. sglang/srt/models/internlm2_reward.py +5 -1
  155. sglang/srt/models/llama.py +26 -9
  156. sglang/srt/models/llama_classification.py +5 -1
  157. sglang/srt/models/llama_eagle.py +17 -4
  158. sglang/srt/models/llama_embedding.py +5 -1
  159. sglang/srt/models/llama_reward.py +7 -2
  160. sglang/srt/models/llava.py +19 -3
  161. sglang/srt/models/llavavid.py +10 -1
  162. sglang/srt/models/minicpm.py +26 -2
  163. sglang/srt/models/minicpm3.py +39 -3
  164. sglang/srt/models/minicpmv.py +45 -14
  165. sglang/srt/models/mixtral.py +20 -9
  166. sglang/srt/models/mixtral_quant.py +50 -8
  167. sglang/srt/models/mllama.py +57 -11
  168. sglang/srt/models/olmo.py +34 -6
  169. sglang/srt/models/olmo2.py +34 -13
  170. sglang/srt/models/olmoe.py +26 -4
  171. sglang/srt/models/phi3_small.py +29 -10
  172. sglang/srt/models/qwen.py +26 -3
  173. sglang/srt/models/qwen2.py +26 -4
  174. sglang/srt/models/qwen2_5_vl.py +46 -8
  175. sglang/srt/models/qwen2_eagle.py +17 -5
  176. sglang/srt/models/qwen2_moe.py +44 -6
  177. sglang/srt/models/qwen2_rm.py +78 -0
  178. sglang/srt/models/qwen2_vl.py +39 -8
  179. sglang/srt/models/stablelm.py +32 -5
  180. sglang/srt/models/torch_native_llama.py +5 -2
  181. sglang/srt/models/xverse.py +21 -9
  182. sglang/srt/models/xverse_moe.py +45 -7
  183. sglang/srt/models/yivl.py +2 -1
  184. sglang/srt/openai_api/adapter.py +109 -24
  185. sglang/srt/openai_api/protocol.py +17 -1
  186. sglang/srt/reasoning_parser.py +154 -0
  187. sglang/srt/sampling/penaltylib/__init__.py +4 -6
  188. sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
  189. sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
  190. sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
  191. sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
  192. sglang/srt/sampling/sampling_batch_info.py +79 -157
  193. sglang/srt/sampling/sampling_params.py +16 -13
  194. sglang/srt/server_args.py +136 -52
  195. sglang/srt/speculative/build_eagle_tree.py +2 -8
  196. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +0 -1
  197. sglang/srt/speculative/eagle_utils.py +92 -58
  198. sglang/srt/speculative/eagle_worker.py +186 -94
  199. sglang/srt/speculative/spec_info.py +1 -13
  200. sglang/srt/utils.py +43 -17
  201. sglang/srt/warmup.py +47 -0
  202. sglang/test/few_shot_gsm8k.py +4 -1
  203. sglang/test/runners.py +389 -126
  204. sglang/test/send_one.py +88 -0
  205. sglang/test/test_block_fp8_ep.py +361 -0
  206. sglang/test/test_programs.py +1 -1
  207. sglang/test/test_utils.py +138 -84
  208. sglang/utils.py +50 -60
  209. sglang/version.py +1 -1
  210. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/METADATA +21 -15
  211. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/RECORD +214 -166
  212. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/WHEEL +1 -1
  213. sglang/bench_latency.py +0 -1
  214. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
  215. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
  216. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
  217. sglang/test/srt/sampling/penaltylib/utils.py +0 -344
  218. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/LICENSE +0 -0
  219. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/top_level.txt +0 -0
sglang/srt/sampling/sampling_batch_info.py

@@ -9,9 +9,6 @@ import torch
 
 import sglang.srt.sampling.penaltylib as penaltylib
 from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
-from sglang.srt.sampling.penaltylib.penalizers.repetition_penalty import (
-    apply_scaling_penalties,
-)
 
 logger = logging.getLogger(__name__)
 
@@ -22,49 +19,45 @@ if TYPE_CHECKING:
 
 @dataclasses.dataclass
 class SamplingBatchInfo:
-    # Batched sampling params
+    # Basic batched sampling params
     temperatures: torch.Tensor
     top_ps: torch.Tensor
     top_ks: torch.Tensor
     min_ps: torch.Tensor
 
-    # All requests use greedy sampling
+    # Whether all requests use greedy sampling
     is_all_greedy: bool
 
-    # Dispatch in CUDA graph
+    # Whether any request needs min_p sampling
     need_min_p_sampling: bool
 
-    # Whether any request has custom logit processor
-    has_custom_logit_processor: bool
-
-    # Bias Tensors
+    # Masking tensors for grammar-guided structured outputs
     vocab_size: int
     grammars: Optional[List] = None
-    sampling_info_done: Optional[threading.Event] = None
-    logit_bias: torch.Tensor = None
     vocab_mask: Optional[torch.Tensor] = None
-    apply_mask: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None
+    apply_mask_func: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None
+
+    # An event used for overlap schedule
+    sampling_info_done: Optional[threading.Event] = None
 
     # Penalizer
     penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None
-    linear_penalties: Optional[torch.Tensor] = None
-    scaling_penalties: Optional[torch.Tensor] = None
+    linear_penalty: torch.Tensor = None
 
-    # Device
-    device: str = "cuda"
-
-    # Custom Parameters
+    # Whether any request has custom logit processor
+    has_custom_logit_processor: bool = False
+    # Custom parameters
     custom_params: Optional[List[Optional[Dict[str, Any]]]] = None
-
-    # Custom Logit Processor
+    # Custom logit processor
     custom_logit_processor: Optional[
         Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]
     ] = None
 
+    # Device
+    device: str = "cuda"
+
     @classmethod
-    def from_schedule_batch(
-        cls, batch: ScheduleBatch, vocab_size: int, enable_overlap_schedule: bool
-    ):
+    def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
         reqs = batch.reqs
         device = batch.device
         temperatures = (
@@ -118,106 +111,60 @@ class SamplingBatchInfo:
             merged_custom_logit_processor = None
             custom_params = None
 
-        ret = cls(
-            temperatures=temperatures,
-            top_ps=top_ps,
-            top_ks=top_ks,
-            min_ps=min_ps,
-            need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
-            is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs),
-            has_custom_logit_processor=has_custom_logit_processor,
-            vocab_size=vocab_size,
-            device=device,
-            custom_params=custom_params,
-            custom_logit_processor=merged_custom_logit_processor,
-        )
-        # TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.
-
-        if enable_overlap_schedule:
-            # TODO (lianmin): Some penalizers such as frequency and presence depend on model outputs,
-            # so it is kind of tricky to make it work with overlap scheduler.
-            # It requires correcly updating the penalty logits before the sampling and syncing the events.
-            # We will support them later.
-            penalizers = {
-                penaltylib.BatchedMinNewTokensPenalizer,
-            }
-            if (
-                any(req.sampling_params.frequency_penalty != 0.0 for req in reqs)
-                or any(req.sampling_params.presence_penalty != 0.0 for req in reqs)
-                or any(req.sampling_params.repetition_penalty != 1.0 for req in reqs)
-            ):
-                logger.warning(
-                    "frequency_penalty, presence_penalty, and repetition_penalty are not supported "
-                    "when using the default overlap scheduler. They will be ignored. "
-                    "Please add `--disable-overlap` when launching the server if you need these features. "
-                    "The speed will be slower in that case."
-                )
-        else:
-            penalizers = {
-                penaltylib.BatchedFrequencyPenalizer,
-                penaltylib.BatchedMinNewTokensPenalizer,
-                penaltylib.BatchedPresencePenalizer,
-                penaltylib.BatchedRepetitionPenalizer,
-            }
-
         # Each penalizers will do nothing if they evaluate themselves as not required by looking at
         # the sampling_params of the requests (See {_is_required()} of each penalizers). So this
         # should not add hefty computation overhead other than simple checks.
         #
-        # While we choose not to even create the class instances if they are not required, this
+        # While we can choose not to even create the class instances if they are not required, this
         # could add additional complexity to the {ScheduleBatch} class, especially we need to
         # handle {filter_batch()} and {merge_batch()} cases as well.
-        ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
+        penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
             vocab_size=vocab_size,
             batch=batch,
-            device=batch.device,
-            Penalizers=penalizers,
+            penalizers={
+                penaltylib.BatchedFrequencyPenalizer,
+                penaltylib.BatchedMinNewTokensPenalizer,
+                penaltylib.BatchedPresencePenalizer,
+            },
         )
 
-        # Handle logit bias but only allocate when needed
-        ret.logit_bias = None
-
+        ret = cls(
+            temperatures=temperatures,
+            top_ps=top_ps,
+            top_ks=top_ks,
+            min_ps=min_ps,
+            is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs),
+            need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
+            vocab_size=vocab_size,
+            penalizer_orchestrator=penalizer_orchestrator,
+            has_custom_logit_processor=has_custom_logit_processor,
+            custom_params=custom_params,
+            custom_logit_processor=merged_custom_logit_processor,
+            device=device,
+        )
         return ret
 
     def __len__(self):
         return len(self.temperatures)
 
-    def update_penalties(self):
-        self.scaling_penalties = None
-        self.linear_penalties = None
-
-        for penalizer in self.penalizer_orchestrator.penalizers.values():
-            if not penalizer.is_prepared():
-                continue
-
-            if isinstance(penalizer, penaltylib.BatchedRepetitionPenalizer):
-                self.scaling_penalties = penalizer.cumulated_repetition_penalties
-            else:
-                if self.linear_penalties is None:
-                    bs = self.penalizer_orchestrator.batch.batch_size()
-                    self.linear_penalties = torch.zeros(
-                        (bs, self.vocab_size),
-                        dtype=torch.float32,
-                        device=self.device,
-                    )
-                self.linear_penalties = penalizer.apply(self.linear_penalties)
-
     def update_regex_vocab_mask(self):
         if not self.grammars:
             self.vocab_mask = None
-            self.apply_mask = None
+            self.apply_mask_func = None
             return
 
-        # find a grammar from the list
+        # Find a grammar from the list
         first_grammar = next(grammar for grammar in self.grammars if grammar)
 
-        # maybe we can reuse the existing mask?
+        # TODO(lianmin): Maybe we can reuse the existing mask?
         self.vocab_mask = first_grammar.allocate_vocab_mask(
             vocab_size=self.vocab_size,
             batch_size=len(self.temperatures),
             device=self.device,
         )
-        self.apply_mask = first_grammar.apply_vocab_mask  # force to use static method
+        self.apply_mask_func = (
+            first_grammar.apply_vocab_mask
+        )  # force to use static method
 
         # Apply the mask
         for i, grammar in enumerate(self.grammars):
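Note: the overlap/non-overlap branch is gone from the hunk above — the orchestrator is now always built with the frequency, presence, and min-new-tokens penalizers (the repetition penalizer is no longer registered here), and enable_overlap_schedule disappears from the constructor. A minimal sketch of the new call, with the surrounding names assumed rather than taken from this diff:

    # Hypothetical call site; `batch` is a ScheduleBatch, vocab_size an int.
    sampling_info = SamplingBatchInfo.from_schedule_batch(batch, vocab_size)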
@@ -227,35 +174,56 @@ class SamplingBatchInfo:
         # Move the mask to the device if needed
         self.vocab_mask = first_grammar.move_vocab_mask(self.vocab_mask, self.device)
 
-    def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
-        self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
+    def update_penalties(self):
+        if self.penalizer_orchestrator.is_required:
+            self.linear_penalty = torch.zeros(
+                (len(self.temperatures), self.vocab_size),
+                dtype=torch.float32,
+                device=self.temperatures.device,
+            )
+            self.penalizer_orchestrator.apply(self.linear_penalty)
+        else:
+            self.linear_penalty = None
+
+    def apply_logits_bias(self, logits: torch.Tensor):
+        if self.linear_penalty is not None:
+            # Used in the overlap mode
+            logits.add_(self.linear_penalty)
+
+        if self.penalizer_orchestrator and self.penalizer_orchestrator.is_required:
+            # Used in the non-overlap mode
+            self.penalizer_orchestrator.apply(logits)
+
+        if self.vocab_mask is not None:
+            self.apply_mask_func(logits=logits, vocab_mask=self.vocab_mask)
+
+    def filter_batch(self, keep_indices: List[int], keep_indices_device: torch.Tensor):
+        self.penalizer_orchestrator.filter(keep_indices_device)
+
         if self.has_custom_logit_processor:
-            self._filter_batch_custom_logit_processor(unfinished_indices, new_indices)
+            self._filter_batch_custom_logit_processor(keep_indices, keep_indices_device)
 
         for item in [
             "temperatures",
             "top_ps",
             "top_ks",
             "min_ps",
-            "logit_bias",
         ]:
             value = getattr(self, item, None)
-            if value is not None:  # logit_bias can be None
-                setattr(self, item, value[new_indices])
+            setattr(self, item, value[keep_indices_device])
 
     def _filter_batch_custom_logit_processor(
-        self, unfinished_indices: List[int], new_indices: torch.Tensor
+        self, keep_indices: List[int], keep_indices_device: torch.Tensor
     ):
         """Filter the custom logit processor and custom params"""
-
         self.custom_logit_processor = {
-            k: (p, mask[new_indices])
+            k: (p, mask[keep_indices_device])
             for k, (p, mask) in self.custom_logit_processor.items()
-            if any(
-                mask[new_indices]
+            if torch.any(
+                mask[keep_indices_device]
             )  # ignore the custom logit processor whose mask is all False
         }
-        self.custom_params = [self.custom_params[i] for i in unfinished_indices]
+        self.custom_params = [self.custom_params[i] for i in keep_indices]
 
         # If the custom logit processor is an empty dict, set the flag to False,
         # and set the custom logit processor and custom params to None.
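Note: the hunk above folds the old logit_bias/linear_penalties/scaling_penalties branches into a single linear_penalty tensor plus the orchestrator, applied through one apply_logits_bias() call. A minimal sketch of how the two modes are meant to drive it — the scheduler/sampler glue below is an assumption, not part of this diff:

    import copy

    import torch


    def prepare_overlap_copy(sampling_info):
        # Overlap mode (assumed driver): bake the frequency/presence/
        # min-new-tokens penalties into the dense linear_penalty tensor up
        # front, then hand the sampler thread a copy without the orchestrator,
        # so apply_logits_bias() only adds linear_penalty and the vocab mask.
        sampling_info.update_penalties()
        snapshot = copy.copy(sampling_info)
        snapshot.penalizer_orchestrator = None
        return snapshot


    def sample(sampling_info, logits: torch.Tensor) -> torch.Tensor:
        # Non-overlap mode: linear_penalty stays None and the orchestrator
        # applies the penalties in place; the grammar mask is applied last.
        sampling_info.apply_logits_bias(logits)
        return logits.argmax(dim=-1)  # stand-in for the real sampling kernel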
@@ -264,31 +232,6 @@ class SamplingBatchInfo:
             self.custom_params = None
             self.has_custom_logit_processor = False
 
-    @staticmethod
-    def merge_bias_tensor(
-        lhs: torch.Tensor,
-        rhs: torch.Tensor,
-        bs1: int,
-        bs2: int,
-        device: str,
-        default: int = 0,
-    ):
-        # bias tensor can be None
-        if lhs is not None or rhs is not None:
-            shape, dtype = None, None
-            if lhs is not None:
-                shape, dtype = lhs.shape[1:], lhs.dtype
-            else:
-                shape, dtype = rhs.shape[1:], rhs.dtype
-            with torch.dtype(dtype):
-                if lhs is None:
-                    lhs = torch.empty((bs1, *shape), device=device).fill_(default)
-                if rhs is None:
-                    rhs = torch.empty((bs2, *shape), device=device).fill_(default)
-            return torch.cat([lhs, rhs])
-
-        return None
-
     @staticmethod
     def merge_custom_logit_processor(
         lhs: Optional[Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]],
@@ -332,10 +275,6 @@ class SamplingBatchInfo:
     def merge_batch(self, other: "SamplingBatchInfo"):
         self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
 
-        # Merge the logit bias tensor
-        self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
-            self.logit_bias, other.logit_bias, len(self), len(other), self.device
-        )
         # Merge the custom logit processors and custom params lists
         if self.has_custom_logit_processor or other.has_custom_logit_processor:
             # Merge the custom logit processors
@@ -369,22 +308,5 @@ class SamplingBatchInfo:
             other_val = getattr(other, item, None)
             setattr(self, item, torch.concat([self_val, other_val]))
 
-        self.is_all_greedy = self.is_all_greedy and other.is_all_greedy
-        self.need_min_p_sampling = self.need_min_p_sampling or other.need_min_p_sampling
-
-    def apply_logits_bias(self, logits: torch.Tensor):
-        # Apply logit_bias
-        if self.logit_bias is not None:
-            logits.add_(self.logit_bias)
-
-        # min-token, presence, frequency
-        if self.linear_penalties is not None:
-            logits.add_(self.linear_penalties)
-
-        # repetition
-        if self.scaling_penalties is not None:
-            apply_scaling_penalties(logits, self.scaling_penalties)
-
-        # Apply regex vocab_mask
-        if self.vocab_mask is not None:
-            self.apply_mask(logits=logits, vocab_mask=self.vocab_mask)
+        self.is_all_greedy |= other.is_all_greedy
+        self.need_min_p_sampling |= other.need_min_p_sampling
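Note: filter_batch() now receives the kept-request indices twice — as a Python list and as a device tensor — replacing the old unfinished_indices/new_indices pair. A hypothetical call site (everything except filter_batch is made up for illustration; sampling_info is an existing SamplingBatchInfo):

    import torch

    # Requests 0, 2, and 3 are still running; request 1 finished this step.
    keep_indices = [0, 2, 3]
    keep_indices_device = torch.tensor(keep_indices, dtype=torch.int64, device="cuda")

    # The list indexes Python-side state (custom_params); the device tensor
    # indexes the batched GPU tensors (temperatures, top_ps, ...) and the
    # penalizers.
    sampling_info.filter_batch(keep_indices, keep_indices_device)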
sglang/srt/sampling/sampling_params.py

@@ -22,8 +22,8 @@ class SamplingParams:
     """
     The sampling parameters.
 
-    See docs/references/sampling_params.md or
-    https://docs.sglang.ai/references/sampling_params.html
+    See docs/backend/sampling_params.md or
+    https://docs.sglang.ai/backend/sampling_params.html
     for the documentation.
     """
 
@@ -40,16 +40,23 @@ class SamplingParams:
         presence_penalty: float = 0.0,
         repetition_penalty: float = 1.0,
         min_new_tokens: int = 0,
-        spaces_between_special_tokens: bool = True,
         n: int = 1,
         json_schema: Optional[str] = None,
         regex: Optional[str] = None,
         ebnf: Optional[str] = None,
-        no_stop_trim: bool = False,
+        structural_tag: Optional[str] = None,
         ignore_eos: bool = False,
         skip_special_tokens: bool = True,
+        spaces_between_special_tokens: bool = True,
+        no_stop_trim: bool = False,
         custom_params: Optional[Dict[str, Any]] = None,
     ) -> None:
+        self.max_new_tokens = max_new_tokens
+        self.stop_strs = stop
+        if stop_token_ids:
+            self.stop_token_ids = set(stop_token_ids)
+        else:
+            self.stop_token_ids = None
         self.temperature = temperature
         self.top_p = top_p
         self.top_k = top_k
@@ -57,25 +64,21 @@ class SamplingParams:
         self.frequency_penalty = frequency_penalty
         self.presence_penalty = presence_penalty
         self.repetition_penalty = repetition_penalty
-        self.stop_strs = stop
-        if stop_token_ids:
-            self.stop_token_ids = set(stop_token_ids)
-        else:
-            self.stop_token_ids = None
-        self.max_new_tokens = max_new_tokens
         self.min_new_tokens = min_new_tokens
-        self.ignore_eos = ignore_eos
-        self.skip_special_tokens = skip_special_tokens
-        self.spaces_between_special_tokens = spaces_between_special_tokens
         self.regex = regex
         self.n = n
         self.json_schema = json_schema
         self.ebnf = ebnf
+        self.structural_tag = structural_tag
+        self.ignore_eos = ignore_eos
+        self.skip_special_tokens = skip_special_tokens
+        self.spaces_between_special_tokens = spaces_between_special_tokens
         self.no_stop_trim = no_stop_trim
         self.custom_params = custom_params
 
         # Process some special cases
         if self.temperature < _SAMPLING_EPS:
+            # top_k = 1 means greedy sampling
             self.temperature = 1.0
             self.top_k = 1
         if self.top_k == -1:
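Note: besides reordering the keyword arguments, the constructor gains a structural_tag option next to json_schema/regex/ebnf. A hypothetical construction showing the greedy special case from the last hunk (the values are illustrative only):

    params = SamplingParams(
        temperature=0.0,      # below _SAMPLING_EPS, rewritten to the greedy pair
        max_new_tokens=64,
        stop=["</s>"],
        stop_token_ids=[2],   # stored internally as the set {2}
        structural_tag=None,  # new in this version
    )
    assert params.temperature == 1.0 and params.top_k == 1  # greedy sampling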