sglang 0.4.3.post2__py3-none-any.whl → 0.4.3.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (205) hide show
  1. sglang/api.py +1 -1
  2. sglang/bench_offline_throughput.py +19 -0
  3. sglang/bench_one_batch.py +2 -2
  4. sglang/bench_serving.py +123 -79
  5. sglang/global_config.py +8 -3
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/lang/ir.py +1 -1
  8. sglang/srt/_custom_ops.py +83 -91
  9. sglang/srt/configs/load_config.py +4 -1
  10. sglang/srt/configs/model_config.py +48 -2
  11. sglang/srt/configs/qwen2_5_vl_config.py +5 -2
  12. sglang/srt/constrained/base_grammar_backend.py +117 -15
  13. sglang/srt/constrained/llguidance_backend.py +151 -0
  14. sglang/srt/constrained/outlines_backend.py +24 -33
  15. sglang/srt/constrained/xgrammar_backend.py +69 -38
  16. sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
  17. sglang/srt/distributed/parallel_state.py +48 -3
  18. sglang/srt/entrypoints/engine.py +67 -9
  19. sglang/srt/entrypoints/http_server.py +190 -41
  20. sglang/srt/entrypoints/verl_engine.py +147 -0
  21. sglang/srt/function_call_parser.py +0 -1
  22. sglang/srt/layers/activation.py +11 -0
  23. sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
  24. sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
  25. sglang/srt/layers/attention/flashinfer_backend.py +220 -378
  26. sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
  27. sglang/srt/layers/attention/torch_native_backend.py +1 -1
  28. sglang/srt/layers/attention/triton_backend.py +9 -6
  29. sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
  30. sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
  31. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
  32. sglang/srt/layers/attention/utils.py +39 -0
  33. sglang/srt/layers/attention/vision.py +60 -63
  34. sglang/srt/layers/dp_attention.py +142 -1
  35. sglang/srt/layers/layernorm.py +1 -1
  36. sglang/srt/layers/linear.py +3 -1
  37. sglang/srt/layers/logits_processor.py +281 -45
  38. sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
  39. sglang/srt/layers/moe/ep_moe/layer.py +140 -28
  40. sglang/srt/layers/moe/fused_moe_native.py +2 -0
  41. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
  43. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
  47. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
  48. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
  51. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
  55. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
  56. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
  63. sglang/srt/layers/moe/topk.py +13 -4
  64. sglang/srt/layers/quantization/__init__.py +111 -7
  65. sglang/srt/layers/quantization/blockwise_int8.py +409 -0
  66. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  68. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  69. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  70. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  71. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  72. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  73. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  74. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  75. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  76. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  77. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  78. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  79. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  80. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  81. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  82. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  83. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  84. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  85. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  86. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  87. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  88. sglang/srt/layers/quantization/fp8.py +69 -28
  89. sglang/srt/layers/quantization/fp8_utils.py +17 -1
  90. sglang/srt/layers/quantization/gptq.py +416 -0
  91. sglang/srt/layers/quantization/int8_kernel.py +327 -0
  92. sglang/srt/layers/quantization/int8_utils.py +73 -0
  93. sglang/srt/layers/quantization/modelopt_quant.py +18 -1
  94. sglang/srt/layers/radix_attention.py +1 -0
  95. sglang/srt/layers/rotary_embedding.py +0 -1
  96. sglang/srt/layers/sampler.py +76 -31
  97. sglang/srt/layers/vocab_parallel_embedding.py +14 -13
  98. sglang/srt/lora/lora.py +17 -1
  99. sglang/srt/lora/lora_config.py +5 -0
  100. sglang/srt/lora/lora_manager.py +1 -3
  101. sglang/srt/managers/cache_controller.py +193 -62
  102. sglang/srt/managers/configure_logging.py +2 -1
  103. sglang/srt/managers/data_parallel_controller.py +6 -2
  104. sglang/srt/managers/detokenizer_manager.py +124 -102
  105. sglang/srt/managers/image_processor.py +2 -1
  106. sglang/srt/managers/io_struct.py +143 -6
  107. sglang/srt/managers/schedule_batch.py +237 -197
  108. sglang/srt/managers/schedule_policy.py +29 -29
  109. sglang/srt/managers/scheduler.py +681 -259
  110. sglang/srt/managers/session_controller.py +6 -2
  111. sglang/srt/managers/tokenizer_manager.py +224 -68
  112. sglang/srt/managers/tp_worker.py +15 -4
  113. sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
  114. sglang/srt/mem_cache/chunk_cache.py +18 -11
  115. sglang/srt/mem_cache/hiradix_cache.py +394 -0
  116. sglang/srt/mem_cache/memory_pool.py +44 -18
  117. sglang/srt/mem_cache/radix_cache.py +58 -47
  118. sglang/srt/metrics/collector.py +94 -36
  119. sglang/srt/model_executor/cuda_graph_runner.py +55 -24
  120. sglang/srt/model_executor/forward_batch_info.py +49 -16
  121. sglang/srt/model_executor/model_runner.py +208 -28
  122. sglang/srt/model_loader/loader.py +3 -3
  123. sglang/srt/model_loader/weight_utils.py +36 -14
  124. sglang/srt/models/baichuan.py +31 -6
  125. sglang/srt/models/chatglm.py +39 -7
  126. sglang/srt/models/commandr.py +29 -5
  127. sglang/srt/models/dbrx.py +31 -5
  128. sglang/srt/models/deepseek.py +43 -6
  129. sglang/srt/models/deepseek_nextn.py +32 -19
  130. sglang/srt/models/deepseek_v2.py +265 -32
  131. sglang/srt/models/exaone.py +19 -9
  132. sglang/srt/models/gemma.py +22 -8
  133. sglang/srt/models/gemma2.py +25 -12
  134. sglang/srt/models/gemma2_reward.py +5 -1
  135. sglang/srt/models/gpt2.py +28 -13
  136. sglang/srt/models/gpt_bigcode.py +27 -5
  137. sglang/srt/models/granite.py +21 -9
  138. sglang/srt/models/grok.py +21 -4
  139. sglang/srt/models/internlm2.py +36 -6
  140. sglang/srt/models/internlm2_reward.py +5 -1
  141. sglang/srt/models/llama.py +26 -9
  142. sglang/srt/models/llama_classification.py +5 -1
  143. sglang/srt/models/llama_eagle.py +17 -4
  144. sglang/srt/models/llama_embedding.py +5 -1
  145. sglang/srt/models/llama_reward.py +7 -2
  146. sglang/srt/models/llava.py +19 -3
  147. sglang/srt/models/llavavid.py +10 -1
  148. sglang/srt/models/minicpm.py +26 -2
  149. sglang/srt/models/minicpm3.py +39 -3
  150. sglang/srt/models/minicpmv.py +45 -14
  151. sglang/srt/models/mixtral.py +20 -9
  152. sglang/srt/models/mixtral_quant.py +50 -8
  153. sglang/srt/models/mllama.py +57 -11
  154. sglang/srt/models/olmo.py +34 -6
  155. sglang/srt/models/olmo2.py +34 -13
  156. sglang/srt/models/olmoe.py +26 -4
  157. sglang/srt/models/phi3_small.py +29 -10
  158. sglang/srt/models/qwen.py +26 -3
  159. sglang/srt/models/qwen2.py +26 -4
  160. sglang/srt/models/qwen2_5_vl.py +46 -8
  161. sglang/srt/models/qwen2_eagle.py +17 -5
  162. sglang/srt/models/qwen2_moe.py +44 -6
  163. sglang/srt/models/qwen2_rm.py +78 -0
  164. sglang/srt/models/qwen2_vl.py +39 -8
  165. sglang/srt/models/stablelm.py +32 -5
  166. sglang/srt/models/torch_native_llama.py +5 -2
  167. sglang/srt/models/xverse.py +21 -9
  168. sglang/srt/models/xverse_moe.py +45 -7
  169. sglang/srt/models/yivl.py +2 -1
  170. sglang/srt/openai_api/adapter.py +109 -24
  171. sglang/srt/openai_api/protocol.py +17 -1
  172. sglang/srt/reasoning_parser.py +154 -0
  173. sglang/srt/sampling/penaltylib/__init__.py +4 -6
  174. sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
  175. sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
  176. sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
  177. sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
  178. sglang/srt/sampling/sampling_batch_info.py +79 -157
  179. sglang/srt/sampling/sampling_params.py +16 -13
  180. sglang/srt/server_args.py +136 -52
  181. sglang/srt/speculative/build_eagle_tree.py +2 -8
  182. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +0 -1
  183. sglang/srt/speculative/eagle_utils.py +92 -58
  184. sglang/srt/speculative/eagle_worker.py +186 -94
  185. sglang/srt/speculative/spec_info.py +1 -13
  186. sglang/srt/utils.py +43 -17
  187. sglang/srt/warmup.py +47 -0
  188. sglang/test/few_shot_gsm8k.py +4 -1
  189. sglang/test/runners.py +389 -126
  190. sglang/test/send_one.py +88 -0
  191. sglang/test/test_block_fp8_ep.py +361 -0
  192. sglang/test/test_programs.py +1 -1
  193. sglang/test/test_utils.py +138 -84
  194. sglang/utils.py +50 -60
  195. sglang/version.py +1 -1
  196. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/METADATA +21 -15
  197. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/RECORD +200 -166
  198. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/WHEEL +1 -1
  199. sglang/bench_latency.py +0 -1
  200. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
  201. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
  202. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
  203. sglang/test/srt/sampling/penaltylib/utils.py +0 -344
  204. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/LICENSE +0 -0
  205. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,66 @@
1
+ import torch
2
+
3
+ from sglang.srt.sampling.penaltylib.orchestrator import (
4
+ BatchedPenalizerOrchestrator,
5
+ _BatchedPenalizer,
6
+ )
7
+
8
+
9
+ class BatchedFrequencyPenalizer(_BatchedPenalizer):
10
+ """
11
+ Frequency penalizer penalizes tokens based on their frequency in the output.
12
+ """
13
+
14
+ def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
15
+ self.orchestrator = orchestrator
16
+ self._is_prepared = False
17
+
18
+ def _is_required(self) -> bool:
19
+ return any(
20
+ req.sampling_params.frequency_penalty != 0.0
21
+ for req in self.orchestrator.reqs()
22
+ )
23
+
24
+ def _prepare(self):
25
+ self.cumulated_frequency_penalties = torch.zeros(
26
+ (len(self.orchestrator.reqs()), self.orchestrator.vocab_size),
27
+ dtype=torch.float32,
28
+ device=self.orchestrator.device,
29
+ )
30
+
31
+ self.frequency_penalties = (
32
+ torch.tensor(
33
+ data=[
34
+ req.sampling_params.frequency_penalty
35
+ for req in self.orchestrator.reqs()
36
+ ],
37
+ dtype=torch.float32,
38
+ device=self.orchestrator.device,
39
+ )
40
+ ).unsqueeze_(1)
41
+
42
+ def _cumulate_output_tokens(self, output_ids: torch.Tensor):
43
+ self.cumulated_frequency_penalties.scatter_add_(
44
+ dim=1,
45
+ index=output_ids.unsqueeze(1),
46
+ src=self.frequency_penalties,
47
+ )
48
+
49
+ def _apply(self, logits: torch.Tensor) -> torch.Tensor:
50
+ logits.sub_(self.cumulated_frequency_penalties)
51
+
52
+ def _filter(self, keep_indices: torch.Tensor):
53
+ self.frequency_penalties = self.frequency_penalties[keep_indices]
54
+ self.cumulated_frequency_penalties = self.cumulated_frequency_penalties[
55
+ keep_indices
56
+ ]
57
+
58
+ def _merge(self, their: "BatchedFrequencyPenalizer"):
59
+ print(f"{self.frequency_penalties.shape=}, {their.frequency_penalties.shape=}")
60
+ self.frequency_penalties = torch.cat(
61
+ [self.frequency_penalties, their.frequency_penalties], dim=0
62
+ )
63
+ self.cumulated_frequency_penalties = torch.cat(
64
+ [self.cumulated_frequency_penalties, their.cumulated_frequency_penalties],
65
+ dim=0,
66
+ )
@@ -1,8 +1,9 @@
1
- from typing import List
2
-
3
1
  import torch
4
2
 
5
- from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
3
+ from sglang.srt.sampling.penaltylib.orchestrator import (
4
+ BatchedPenalizerOrchestrator,
5
+ _BatchedPenalizer,
6
+ )
6
7
 
7
8
 
8
9
  class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
@@ -10,9 +11,9 @@ class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
10
11
  Min new tokens penalizer penalizes tokens based on the length of the output.
11
12
  """
12
13
 
13
- min_new_tokens: torch.Tensor = None
14
- stop_token_penalties: torch.Tensor = None
15
- len_output_tokens: torch.Tensor = None
14
+ def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
15
+ self.orchestrator = orchestrator
16
+ self._is_prepared = False
16
17
 
17
18
  def _is_required(self) -> bool:
18
19
  return any(
@@ -47,7 +48,7 @@ class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
47
48
  padding_value=self.orchestrator.vocab_size,
48
49
  )
49
50
  self.stop_token_penalties = torch.zeros(
50
- size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size + 1),
51
+ size=(len(self.orchestrator.reqs()), self.orchestrator.vocab_size + 1),
51
52
  dtype=torch.float32,
52
53
  device=self.orchestrator.device,
53
54
  ).scatter_add_(
@@ -64,31 +65,22 @@ class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
64
65
  ]
65
66
 
66
67
  self.len_output_tokens = torch.zeros(
67
- size=(self.orchestrator.batch_size(), 1),
68
+ size=(len(self.orchestrator.reqs()), 1),
68
69
  dtype=torch.int32,
69
70
  device=self.orchestrator.device,
70
71
  )
71
72
 
72
- def _teardown(self):
73
- self.min_new_tokens = None
74
- self.stop_token_penalties = None
75
- self.len_output_tokens = None
76
-
77
- def _cumulate_input_tokens(self, input_ids: _TokenIDs):
78
- pass
79
-
80
- def _cumulate_output_tokens(self, output_ids: _TokenIDs):
73
+ def _cumulate_output_tokens(self, output_ids: torch.Tensor):
81
74
  self.len_output_tokens += 1
82
75
 
83
- def _apply(self, logits: torch.Tensor) -> torch.Tensor:
76
+ def _apply(self, logits: torch.Tensor):
84
77
  mask = (self.len_output_tokens < self.min_new_tokens).expand_as(logits)
85
78
  logits[mask] += self.stop_token_penalties[mask]
86
- return logits
87
79
 
88
- def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
89
- self.min_new_tokens = self.min_new_tokens[indices_tensor_to_keep]
90
- self.stop_token_penalties = self.stop_token_penalties[indices_tensor_to_keep]
91
- self.len_output_tokens = self.len_output_tokens[indices_tensor_to_keep]
80
+ def _filter(self, keep_indices: torch.Tensor):
81
+ self.min_new_tokens = self.min_new_tokens[keep_indices]
82
+ self.stop_token_penalties = self.stop_token_penalties[keep_indices]
83
+ self.len_output_tokens = self.len_output_tokens[keep_indices]
92
84
 
93
85
  def _merge(self, their: "BatchedMinNewTokensPenalizer"):
94
86
  self.min_new_tokens = torch.cat(
@@ -1,35 +1,25 @@
1
+ from __future__ import annotations
2
+
1
3
  import abc
2
- import dataclasses
3
- from typing import List, Set, Type, Union
4
+ from typing import TYPE_CHECKING, Set, Type
4
5
 
5
6
  import torch
6
7
 
7
-
8
- @dataclasses.dataclass
9
- class _ReqLike:
10
- origin_input_ids: List[int]
11
-
12
-
13
- @dataclasses.dataclass
14
- class _BatchLike:
15
- reqs: List[_ReqLike]
16
-
17
- def batch_size(self):
18
- return len(self.reqs)
8
+ if TYPE_CHECKING:
9
+ from sglang.srt.managers.schedule_batch import ScheduleBatch
19
10
 
20
11
 
21
12
  class BatchedPenalizerOrchestrator:
22
13
  def __init__(
23
14
  self,
24
15
  vocab_size: int,
25
- batch: _BatchLike,
26
- device: str,
27
- Penalizers: Set[Type["_BatchedPenalizer"]],
16
+ batch: ScheduleBatch,
17
+ penalizers: Set[Type["_BatchedPenalizer"]],
28
18
  ):
29
19
  self.vocab_size = vocab_size
30
20
  self.batch = batch
31
- self.device = device
32
- self.penalizers = {Penalizer: Penalizer(self) for Penalizer in Penalizers}
21
+ self.device = batch.device
22
+ self.penalizers = {Penalizer: Penalizer(self) for Penalizer in penalizers}
33
23
 
34
24
  is_required = False
35
25
  for penalizer in self.penalizers.values():
@@ -37,31 +27,9 @@ class BatchedPenalizerOrchestrator:
37
27
  is_required |= pen_is_required
38
28
  self.is_required = is_required
39
29
 
40
- input_ids = [
41
- torch.tensor(req.origin_input_ids, dtype=torch.int64, device=self.device)
42
- for req in self.reqs()
43
- ]
44
- if self.is_required:
45
- self.cumulate_input_tokens(input_ids=input_ids)
46
-
47
30
  def reqs(self):
48
31
  return self.batch.reqs
49
32
 
50
- def batch_size(self):
51
- return self.batch.batch_size()
52
-
53
- def cumulate_input_tokens(self, input_ids: List[torch.Tensor]):
54
- """
55
- Feed the input tokens to the penalizers.
56
-
57
- Args:
58
- input_ids (List[torch.Tensor]): The input tokens.
59
- """
60
- token_ids = _TokenIDs(orchestrator=self, token_ids=input_ids)
61
-
62
- for penalizer in self.penalizers.values():
63
- penalizer.cumulate_input_tokens(input_ids=token_ids)
64
-
65
33
  def cumulate_output_tokens(self, output_ids: torch.Tensor):
66
34
  """
67
35
  Feed the output tokens to the penalizers.
@@ -69,13 +37,8 @@ class BatchedPenalizerOrchestrator:
69
37
  Args:
70
38
  output_ids (torch.Tensor): The output tokens.
71
39
  """
72
- if not self.is_required:
73
- return
74
-
75
- token_ids = _TokenIDs(orchestrator=self, token_ids=output_ids)
76
-
77
40
  for penalizer in self.penalizers.values():
78
- penalizer.cumulate_output_tokens(output_ids=token_ids)
41
+ penalizer.cumulate_output_tokens(output_ids=output_ids)
79
42
 
80
43
  def apply(self, logits: torch.Tensor) -> torch.Tensor:
81
44
  """
@@ -88,48 +51,33 @@ class BatchedPenalizerOrchestrator:
88
51
  Returns:
89
52
  torch.Tensor: The logits after applying the penalizers.
90
53
  """
91
- if not self.is_required:
92
- return
93
-
94
54
  for penalizer in self.penalizers.values():
95
- logits = penalizer.apply(logits)
96
-
97
- return logits
55
+ penalizer.apply(logits)
98
56
 
99
- def filter(
100
- self,
101
- indices_to_keep: List[int],
102
- indices_tensor_to_keep: torch.Tensor = None,
103
- ):
57
+ def filter(self, keep_indices: torch.Tensor):
104
58
  """
105
59
  Filter the penalizers based on the indices to keep in the batch.
106
60
 
107
61
  Args:
108
- indices_to_keep (List[int]): List of indices to keep in the batch.
109
- indices_tensor_to_keep (torch.Tensor = None): Tensor of indices to keep in the batch. If not None, it will be used instead of converting indices_to_keep to a tensor.
62
+ keep_indices (torch.Tensor): Tensor of indices to keep in the batch.
110
63
  """
111
64
  if not self.is_required:
112
65
  return
113
66
 
114
- empty_indices = len(indices_to_keep) == 0
67
+ if len(keep_indices) == 0:
68
+ self.is_required = False
69
+ for penalizer in self.penalizers.values():
70
+ penalizer.teardown()
71
+ return
115
72
 
116
73
  is_required = False
117
74
  for penalizer in self.penalizers.values():
118
75
  tmp_is_required = penalizer.is_required()
119
- is_required = is_required or tmp_is_required
120
- if not tmp_is_required or empty_indices:
121
- penalizer.teardown()
76
+ is_required |= tmp_is_required
77
+ if tmp_is_required:
78
+ penalizer.filter(keep_indices=keep_indices)
122
79
  else:
123
- # create tensor index only when it's needed
124
- if indices_tensor_to_keep is None:
125
- indices_tensor_to_keep = torch.tensor(
126
- indices_to_keep, dtype=torch.int32, device=self.device
127
- )
128
-
129
- penalizer.filter(
130
- indices_to_keep=indices_to_keep,
131
- indices_tensor_to_keep=indices_tensor_to_keep,
132
- )
80
+ penalizer.teardown()
133
81
  self.is_required = is_required
134
82
 
135
83
  def merge(self, their: "BatchedPenalizerOrchestrator"):
@@ -146,75 +94,9 @@ class BatchedPenalizerOrchestrator:
146
94
  if not self.is_required and not their.is_required:
147
95
  return
148
96
 
149
- self.is_required |= their.is_required
150
- for Penalizer, their_penalizer in their.penalizers.items():
151
- if Penalizer not in self.penalizers:
152
- raise ValueError(f"Penalizer {Penalizer} not found in self.penalizers")
153
-
154
- self.penalizers[Penalizer].merge(their_penalizer)
155
-
156
-
157
- class _TokenIDs:
158
- """
159
- A class that wraps token IDs to provide additional utility functions to penalizers.
160
-
161
- Attributes:
162
- orchestrator (BatchedPenalizerOrchestrator): The orchestrator that this token IDs belong to.
163
- token_ids (Union[torch.Tensor, List[torch.Tensor]]): The token IDs.
164
- cached_counts (torch.Tensor): The cached occurrence count tensor.
165
- """
166
-
167
- def __init__(
168
- self,
169
- orchestrator: BatchedPenalizerOrchestrator,
170
- token_ids: Union[torch.Tensor, List[torch.Tensor]],
171
- ):
172
- self.orchestrator = orchestrator
173
- self.token_ids = token_ids
174
- self.cached_counts = None
175
-
176
- def occurrence_count(self) -> torch.Tensor:
177
- """
178
- Returns a tensor of shape (batch_size, vocab_size) where each element is the number of times the corresponding token appears in the batch.
179
-
180
- Returns:
181
- torch.Tensor: The occurrence count tensor.
182
- """
183
- if self.cached_counts is not None:
184
- return self.cached_counts
185
-
186
- token_ids = self.token_ids
187
-
188
- if isinstance(token_ids, list):
189
- # TODO: optimize this part
190
- padded_token_ids = torch.nn.utils.rnn.pad_sequence(
191
- sequences=token_ids,
192
- batch_first=True,
193
- padding_value=self.orchestrator.vocab_size,
194
- )
195
- self.cached_counts = torch.zeros(
196
- size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size + 1),
197
- dtype=torch.int64,
198
- device=self.orchestrator.device,
199
- ).scatter_add_(
200
- dim=1,
201
- index=padded_token_ids,
202
- src=torch.ones_like(padded_token_ids),
203
- )[
204
- :, : self.orchestrator.vocab_size
205
- ]
206
- else:
207
- # TODO: optimize this part. We do not need to create this big tensor every time.
208
- # We can directly apply the results on the logits.
209
- self.cached_counts = torch.zeros(
210
- size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size),
211
- device=self.orchestrator.device,
212
- )
213
- self.cached_counts[
214
- torch.arange(len(token_ids), device=self.orchestrator.device), token_ids
215
- ] = 1
216
-
217
- return self.cached_counts
97
+ self.is_required = True
98
+ for penalizer, their_penalizer in their.penalizers.items():
99
+ self.penalizers[penalizer].merge(their_penalizer)
218
100
 
219
101
 
220
102
  class _BatchedPenalizer(abc.ABC):
@@ -222,10 +104,6 @@ class _BatchedPenalizer(abc.ABC):
222
104
  An abstract class for a batched penalizer.
223
105
  """
224
106
 
225
- def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
226
- self.orchestrator = orchestrator
227
- self._is_prepared = False
228
-
229
107
  def is_prepared(self) -> bool:
230
108
  return self._is_prepared
231
109
 
@@ -233,51 +111,40 @@ class _BatchedPenalizer(abc.ABC):
233
111
  return self._is_required()
234
112
 
235
113
  def prepare(self):
236
- if not self.is_prepared():
114
+ if not self._is_prepared:
237
115
  self._prepare()
238
116
  self._is_prepared = True
239
117
 
240
118
  def prepare_if_required(self):
241
- if self.is_required():
119
+ if self._is_required():
242
120
  self.prepare()
243
121
  return True
244
122
  else:
245
123
  return False
246
124
 
247
125
  def teardown(self):
248
- if self.is_prepared():
249
- self._teardown()
250
- self._is_prepared = False
251
-
252
- def cumulate_input_tokens(self, input_ids: _TokenIDs):
253
- if not self.is_prepared():
254
- return
255
-
256
- self._cumulate_input_tokens(input_ids=input_ids)
126
+ self._is_prepared = False
257
127
 
258
- def cumulate_output_tokens(self, output_ids: _TokenIDs):
259
- if not self.is_prepared():
128
+ def cumulate_output_tokens(self, output_ids: torch.Tensor):
129
+ if not self._is_prepared:
260
130
  return
261
131
 
262
132
  self._cumulate_output_tokens(output_ids=output_ids)
263
133
 
264
134
  def apply(self, logits: torch.Tensor) -> torch.Tensor:
265
- if not self.is_prepared():
266
- return logits
135
+ if not self._is_prepared:
136
+ return
267
137
 
268
- return self._apply(logits=logits)
138
+ self._apply(logits=logits)
269
139
 
270
- def filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
271
- if not self.is_prepared():
140
+ def filter(self, keep_indices: torch.Tensor):
141
+ if not self._is_prepared:
272
142
  return
273
143
 
274
- self._filter(
275
- indices_to_keep=indices_to_keep,
276
- indices_tensor_to_keep=indices_tensor_to_keep,
277
- )
144
+ self._filter(keep_indices=keep_indices)
278
145
 
279
146
  def merge(self, their: "_BatchedPenalizer"):
280
- if not self.is_prepared() and not their.is_prepared():
147
+ if not self._is_prepared and not their._is_prepared:
281
148
  return
282
149
 
283
150
  self.prepare()
@@ -300,23 +167,7 @@ class _BatchedPenalizer(abc.ABC):
300
167
  pass
301
168
 
302
169
  @abc.abstractmethod
303
- def _teardown(self):
304
- """
305
- Tear down the penalizer.
306
- Usually, this is where the penalizer frees its tensors.
307
- """
308
- pass
309
-
310
- @abc.abstractmethod
311
- def _cumulate_input_tokens(self, input_ids: _TokenIDs):
312
- """
313
- Cumulate the input tokens.
314
- Orchestrator will call this function to feed the input tokens to the penalizer.
315
- """
316
- pass
317
-
318
- @abc.abstractmethod
319
- def _cumulate_output_tokens(self, output_ids: _TokenIDs):
170
+ def _cumulate_output_tokens(self, output_ids: torch.Tensor):
320
171
  """
321
172
  Cumulate the output tokens.
322
173
  Orchestrator will call this function to feed the output tokens to the penalizer.
@@ -332,7 +183,7 @@ class _BatchedPenalizer(abc.ABC):
332
183
  pass
333
184
 
334
185
  @abc.abstractmethod
335
- def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
186
+ def _filter(self, keep_indices: torch.Tensor):
336
187
  """
337
188
  Filter the penalizer (tensors or underlying data) based on the indices to keep in the batch.
338
189
  """
@@ -0,0 +1,66 @@
1
+ import torch
2
+
3
+ from sglang.srt.sampling.penaltylib.orchestrator import (
4
+ BatchedPenalizerOrchestrator,
5
+ _BatchedPenalizer,
6
+ )
7
+
8
+
9
+ class BatchedPresencePenalizer(_BatchedPenalizer):
10
+ """
11
+ Presence penalizer penalizes tokens based on their presence in the output.
12
+ """
13
+
14
+ def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
15
+ self.orchestrator = orchestrator
16
+ self._is_prepared = False
17
+
18
+ def _is_required(self) -> bool:
19
+ return any(
20
+ req.sampling_params.presence_penalty != 0.0
21
+ for req in self.orchestrator.reqs()
22
+ )
23
+
24
+ def _prepare(self):
25
+ self.cumulated_presence_penalties = torch.zeros(
26
+ (len(self.orchestrator.reqs()), self.orchestrator.vocab_size),
27
+ dtype=torch.float32,
28
+ device=self.orchestrator.device,
29
+ )
30
+
31
+ self.presence_penalties = (
32
+ torch.tensor(
33
+ data=[
34
+ req.sampling_params.presence_penalty
35
+ for req in self.orchestrator.reqs()
36
+ ],
37
+ dtype=torch.float32,
38
+ device=self.orchestrator.device,
39
+ )
40
+ ).unsqueeze_(1)
41
+
42
+ def _cumulate_output_tokens(self, output_ids: torch.Tensor):
43
+ self.cumulated_presence_penalties.scatter_(
44
+ dim=1,
45
+ index=output_ids.unsqueeze(1),
46
+ src=self.presence_penalties,
47
+ )
48
+
49
+ def _apply(self, logits: torch.Tensor) -> torch.Tensor:
50
+ logits.sub_(self.cumulated_presence_penalties)
51
+
52
+ def _filter(self, keep_indices: torch.Tensor):
53
+ self.presence_penalties = self.presence_penalties[keep_indices]
54
+ self.cumulated_presence_penalties = self.cumulated_presence_penalties[
55
+ keep_indices
56
+ ]
57
+
58
+ def _merge(self, their: "BatchedPresencePenalizer"):
59
+ print(f"{self.presence_penalties.shape=}, {their.presence_penalties.shape=}")
60
+ self.presence_penalties = torch.cat(
61
+ [self.presence_penalties, their.presence_penalties], dim=0
62
+ )
63
+ self.cumulated_presence_penalties = torch.cat(
64
+ [self.cumulated_presence_penalties, their.cumulated_presence_penalties],
65
+ dim=0,
66
+ )