sglang 0.4.3.post2__py3-none-any.whl → 0.4.3.post3__py3-none-any.whl

This diff compares publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
Files changed (205)
  1. sglang/api.py +1 -1
  2. sglang/bench_offline_throughput.py +19 -0
  3. sglang/bench_one_batch.py +2 -2
  4. sglang/bench_serving.py +123 -79
  5. sglang/global_config.py +8 -3
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/lang/ir.py +1 -1
  8. sglang/srt/_custom_ops.py +83 -91
  9. sglang/srt/configs/load_config.py +4 -1
  10. sglang/srt/configs/model_config.py +48 -2
  11. sglang/srt/configs/qwen2_5_vl_config.py +5 -2
  12. sglang/srt/constrained/base_grammar_backend.py +117 -15
  13. sglang/srt/constrained/llguidance_backend.py +151 -0
  14. sglang/srt/constrained/outlines_backend.py +24 -33
  15. sglang/srt/constrained/xgrammar_backend.py +69 -38
  16. sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
  17. sglang/srt/distributed/parallel_state.py +48 -3
  18. sglang/srt/entrypoints/engine.py +67 -9
  19. sglang/srt/entrypoints/http_server.py +190 -41
  20. sglang/srt/entrypoints/verl_engine.py +147 -0
  21. sglang/srt/function_call_parser.py +0 -1
  22. sglang/srt/layers/activation.py +11 -0
  23. sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
  24. sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
  25. sglang/srt/layers/attention/flashinfer_backend.py +220 -378
  26. sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
  27. sglang/srt/layers/attention/torch_native_backend.py +1 -1
  28. sglang/srt/layers/attention/triton_backend.py +9 -6
  29. sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
  30. sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
  31. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
  32. sglang/srt/layers/attention/utils.py +39 -0
  33. sglang/srt/layers/attention/vision.py +60 -63
  34. sglang/srt/layers/dp_attention.py +142 -1
  35. sglang/srt/layers/layernorm.py +1 -1
  36. sglang/srt/layers/linear.py +3 -1
  37. sglang/srt/layers/logits_processor.py +281 -45
  38. sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
  39. sglang/srt/layers/moe/ep_moe/layer.py +140 -28
  40. sglang/srt/layers/moe/fused_moe_native.py +2 -0
  41. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
  43. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
  47. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
  48. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
  51. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
  55. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
  56. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
  63. sglang/srt/layers/moe/topk.py +13 -4
  64. sglang/srt/layers/quantization/__init__.py +111 -7
  65. sglang/srt/layers/quantization/blockwise_int8.py +409 -0
  66. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  68. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  69. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  70. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  71. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  72. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  73. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  74. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  75. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  76. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  77. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  78. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  79. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  80. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  81. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  82. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  83. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  84. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  85. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  86. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  87. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  88. sglang/srt/layers/quantization/fp8.py +69 -28
  89. sglang/srt/layers/quantization/fp8_utils.py +17 -1
  90. sglang/srt/layers/quantization/gptq.py +416 -0
  91. sglang/srt/layers/quantization/int8_kernel.py +327 -0
  92. sglang/srt/layers/quantization/int8_utils.py +73 -0
  93. sglang/srt/layers/quantization/modelopt_quant.py +18 -1
  94. sglang/srt/layers/radix_attention.py +1 -0
  95. sglang/srt/layers/rotary_embedding.py +0 -1
  96. sglang/srt/layers/sampler.py +76 -31
  97. sglang/srt/layers/vocab_parallel_embedding.py +14 -13
  98. sglang/srt/lora/lora.py +17 -1
  99. sglang/srt/lora/lora_config.py +5 -0
  100. sglang/srt/lora/lora_manager.py +1 -3
  101. sglang/srt/managers/cache_controller.py +193 -62
  102. sglang/srt/managers/configure_logging.py +2 -1
  103. sglang/srt/managers/data_parallel_controller.py +6 -2
  104. sglang/srt/managers/detokenizer_manager.py +124 -102
  105. sglang/srt/managers/image_processor.py +2 -1
  106. sglang/srt/managers/io_struct.py +143 -6
  107. sglang/srt/managers/schedule_batch.py +237 -197
  108. sglang/srt/managers/schedule_policy.py +29 -29
  109. sglang/srt/managers/scheduler.py +681 -259
  110. sglang/srt/managers/session_controller.py +6 -2
  111. sglang/srt/managers/tokenizer_manager.py +224 -68
  112. sglang/srt/managers/tp_worker.py +15 -4
  113. sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
  114. sglang/srt/mem_cache/chunk_cache.py +18 -11
  115. sglang/srt/mem_cache/hiradix_cache.py +394 -0
  116. sglang/srt/mem_cache/memory_pool.py +44 -18
  117. sglang/srt/mem_cache/radix_cache.py +58 -47
  118. sglang/srt/metrics/collector.py +94 -36
  119. sglang/srt/model_executor/cuda_graph_runner.py +55 -24
  120. sglang/srt/model_executor/forward_batch_info.py +49 -16
  121. sglang/srt/model_executor/model_runner.py +208 -28
  122. sglang/srt/model_loader/loader.py +3 -3
  123. sglang/srt/model_loader/weight_utils.py +36 -14
  124. sglang/srt/models/baichuan.py +31 -6
  125. sglang/srt/models/chatglm.py +39 -7
  126. sglang/srt/models/commandr.py +29 -5
  127. sglang/srt/models/dbrx.py +31 -5
  128. sglang/srt/models/deepseek.py +43 -6
  129. sglang/srt/models/deepseek_nextn.py +32 -19
  130. sglang/srt/models/deepseek_v2.py +265 -32
  131. sglang/srt/models/exaone.py +19 -9
  132. sglang/srt/models/gemma.py +22 -8
  133. sglang/srt/models/gemma2.py +25 -12
  134. sglang/srt/models/gemma2_reward.py +5 -1
  135. sglang/srt/models/gpt2.py +28 -13
  136. sglang/srt/models/gpt_bigcode.py +27 -5
  137. sglang/srt/models/granite.py +21 -9
  138. sglang/srt/models/grok.py +21 -4
  139. sglang/srt/models/internlm2.py +36 -6
  140. sglang/srt/models/internlm2_reward.py +5 -1
  141. sglang/srt/models/llama.py +26 -9
  142. sglang/srt/models/llama_classification.py +5 -1
  143. sglang/srt/models/llama_eagle.py +17 -4
  144. sglang/srt/models/llama_embedding.py +5 -1
  145. sglang/srt/models/llama_reward.py +7 -2
  146. sglang/srt/models/llava.py +19 -3
  147. sglang/srt/models/llavavid.py +10 -1
  148. sglang/srt/models/minicpm.py +26 -2
  149. sglang/srt/models/minicpm3.py +39 -3
  150. sglang/srt/models/minicpmv.py +45 -14
  151. sglang/srt/models/mixtral.py +20 -9
  152. sglang/srt/models/mixtral_quant.py +50 -8
  153. sglang/srt/models/mllama.py +57 -11
  154. sglang/srt/models/olmo.py +34 -6
  155. sglang/srt/models/olmo2.py +34 -13
  156. sglang/srt/models/olmoe.py +26 -4
  157. sglang/srt/models/phi3_small.py +29 -10
  158. sglang/srt/models/qwen.py +26 -3
  159. sglang/srt/models/qwen2.py +26 -4
  160. sglang/srt/models/qwen2_5_vl.py +46 -8
  161. sglang/srt/models/qwen2_eagle.py +17 -5
  162. sglang/srt/models/qwen2_moe.py +44 -6
  163. sglang/srt/models/qwen2_rm.py +78 -0
  164. sglang/srt/models/qwen2_vl.py +39 -8
  165. sglang/srt/models/stablelm.py +32 -5
  166. sglang/srt/models/torch_native_llama.py +5 -2
  167. sglang/srt/models/xverse.py +21 -9
  168. sglang/srt/models/xverse_moe.py +45 -7
  169. sglang/srt/models/yivl.py +2 -1
  170. sglang/srt/openai_api/adapter.py +109 -24
  171. sglang/srt/openai_api/protocol.py +17 -1
  172. sglang/srt/reasoning_parser.py +154 -0
  173. sglang/srt/sampling/penaltylib/__init__.py +4 -6
  174. sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
  175. sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
  176. sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
  177. sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
  178. sglang/srt/sampling/sampling_batch_info.py +79 -157
  179. sglang/srt/sampling/sampling_params.py +16 -13
  180. sglang/srt/server_args.py +136 -52
  181. sglang/srt/speculative/build_eagle_tree.py +2 -8
  182. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +0 -1
  183. sglang/srt/speculative/eagle_utils.py +92 -58
  184. sglang/srt/speculative/eagle_worker.py +186 -94
  185. sglang/srt/speculative/spec_info.py +1 -13
  186. sglang/srt/utils.py +43 -17
  187. sglang/srt/warmup.py +47 -0
  188. sglang/test/few_shot_gsm8k.py +4 -1
  189. sglang/test/runners.py +389 -126
  190. sglang/test/send_one.py +88 -0
  191. sglang/test/test_block_fp8_ep.py +361 -0
  192. sglang/test/test_programs.py +1 -1
  193. sglang/test/test_utils.py +138 -84
  194. sglang/utils.py +50 -60
  195. sglang/version.py +1 -1
  196. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/METADATA +21 -15
  197. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/RECORD +200 -166
  198. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/WHEEL +1 -1
  199. sglang/bench_latency.py +0 -1
  200. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
  201. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
  202. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
  203. sglang/test/srt/sampling/penaltylib/utils.py +0 -344
  204. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/LICENSE +0 -0
  205. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/top_level.txt +0 -0
{sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/WHEEL CHANGED
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: setuptools (75.8.0)
+ Generator: setuptools (75.8.2)
  Root-Is-Purelib: true
  Tag: py3-none-any

sglang/bench_latency.py DELETED
@@ -1 +0,0 @@
- raise ValueError("bench_latency.py has been renamed to bench_one_batch.py")
sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py DELETED
@@ -1,75 +0,0 @@
- from typing import List
-
- import torch
-
- from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
-
-
- class BatchedFrequencyPenalizer(_BatchedPenalizer):
-     """
-     Frequency penalizer penalizes tokens based on their frequency in the output.
-     """
-
-     frequency_penalties: torch.Tensor = None
-     cumulated_frequency_penalties: torch.Tensor = None
-
-     def _is_required(self) -> bool:
-         return any(
-             req.sampling_params.frequency_penalty != 0.0
-             for req in self.orchestrator.reqs()
-         )
-
-     def _prepare(self):
-         self.cumulated_frequency_penalties = (
-             torch.tensor(
-                 data=[0.0 for _ in self.orchestrator.reqs()],
-                 dtype=torch.float32,
-                 device=self.orchestrator.device,
-             )
-             .unsqueeze_(1)
-             .repeat(1, self.orchestrator.vocab_size)
-         )
-
-         self.frequency_penalties = (
-             torch.tensor(
-                 data=[
-                     req.sampling_params.frequency_penalty
-                     for req in self.orchestrator.reqs()
-                 ],
-                 dtype=torch.float32,
-                 device=self.orchestrator.device,
-             )
-             .unsqueeze_(1)
-             .expand_as(self.cumulated_frequency_penalties)
-         )
-
-     def _teardown(self):
-         self.frequency_penalties = None
-         self.cumulated_frequency_penalties = None
-
-     def _cumulate_input_tokens(self, input_ids: _TokenIDs):
-         pass
-
-     def _cumulate_output_tokens(self, output_ids: _TokenIDs):
-         self.cumulated_frequency_penalties += (
-             self.frequency_penalties * output_ids.occurrence_count()
-         )
-
-     def _apply(self, logits: torch.Tensor) -> torch.Tensor:
-         logits -= self.cumulated_frequency_penalties
-         return logits
-
-     def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
-         self.frequency_penalties = self.frequency_penalties[indices_tensor_to_keep]
-         self.cumulated_frequency_penalties = self.cumulated_frequency_penalties[
-             indices_tensor_to_keep
-         ]
-
-     def _merge(self, their: "BatchedFrequencyPenalizer"):
-         self.frequency_penalties = torch.cat(
-             [self.frequency_penalties, their.frequency_penalties], dim=0
-         )
-         self.cumulated_frequency_penalties = torch.cat(
-             [self.cumulated_frequency_penalties, their.cumulated_frequency_penalties],
-             dim=0,
-         )
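
Note: in 0.4.3.post3 this penalizer is superseded by sglang/srt/sampling/penaltylib/frequency_penalty.py (entry 174 above). As a reading aid, a minimal sketch of the penalty math the deleted class implements; it is not taken from the new file, and it assumes occurrence counts have been pre-aggregated into a [batch, vocab] tensor:

import torch

def frequency_penalty_sketch(
    logits: torch.Tensor,               # [batch, vocab] raw logits
    output_counts: torch.Tensor,        # [batch, vocab] output occurrence counts
    frequency_penalties: torch.Tensor,  # [batch, 1] per-request penalty
) -> torch.Tensor:
    # Mirrors _cumulate_output_tokens + _apply above: each logit drops in
    # proportion to how often its token has already been generated.
    return logits - frequency_penalties * output_counts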
sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py DELETED
@@ -1,74 +0,0 @@
- from typing import List
-
- import torch
-
- from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
-
-
- class BatchedPresencePenalizer(_BatchedPenalizer):
-     """
-     Presence penalizer penalizes tokens based on their presence in the output.
-     """
-
-     presence_penalties: torch.Tensor = None
-     cumulated_presence_penalties: torch.Tensor = None
-
-     def _is_required(self) -> bool:
-         return any(
-             req.sampling_params.presence_penalty != 0.0
-             for req in self.orchestrator.reqs()
-         )
-
-     def _prepare(self):
-         self.cumulated_presence_penalties = (
-             torch.tensor(
-                 data=[0.0 for _ in self.orchestrator.reqs()],
-                 dtype=torch.float32,
-                 device=self.orchestrator.device,
-             )
-             .unsqueeze_(1)
-             .repeat(1, self.orchestrator.vocab_size)
-         )
-
-         self.presence_penalties = (
-             torch.tensor(
-                 data=[
-                     req.sampling_params.presence_penalty
-                     for req in self.orchestrator.reqs()
-                 ],
-                 dtype=torch.float32,
-                 device=self.orchestrator.device,
-             )
-             .unsqueeze_(1)
-             .expand_as(self.cumulated_presence_penalties)
-         )
-
-     def _teardown(self):
-         self.presence_penalties = None
-         self.cumulated_presence_penalties = None
-
-     def _cumulate_input_tokens(self, input_ids: _TokenIDs):
-         pass
-
-     def _cumulate_output_tokens(self, output_ids: _TokenIDs):
-         mask = output_ids.occurrence_count() > 0
-         self.cumulated_presence_penalties[mask] = self.presence_penalties[mask]
-
-     def _apply(self, logits: torch.Tensor) -> torch.Tensor:
-         logits -= self.cumulated_presence_penalties
-         return logits
-
-     def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
-         self.presence_penalties = self.presence_penalties[indices_tensor_to_keep]
-         self.cumulated_presence_penalties = self.cumulated_presence_penalties[
-             indices_tensor_to_keep
-         ]
-
-     def _merge(self, their: "BatchedPresencePenalizer"):
-         self.presence_penalties = torch.cat(
-             [self.presence_penalties, their.presence_penalties], dim=0
-         )
-         self.cumulated_presence_penalties = torch.cat(
-             [self.cumulated_presence_penalties, their.cumulated_presence_penalties],
-             dim=0,
-         )
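
Note: superseded in 0.4.3.post3 by sglang/srt/sampling/penaltylib/presence_penalty.py (entry 177 above). The presence penalty is the binary counterpart of the frequency penalty; a minimal sketch of the deleted class's semantics, under the same [batch, vocab] count assumption as the sketch above:

import torch

def presence_penalty_sketch(
    logits: torch.Tensor,              # [batch, vocab] raw logits
    output_counts: torch.Tensor,       # [batch, vocab] output occurrence counts
    presence_penalties: torch.Tensor,  # [batch, 1] per-request penalty
) -> torch.Tensor:
    # Mirrors the masked update above: the full penalty applies once a token
    # has appeared at all, no matter how many times.
    appeared = (output_counts > 0).to(logits.dtype)
    return logits - presence_penalties * appeared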
sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py DELETED
@@ -1,85 +0,0 @@
- from typing import List
-
- import torch
-
- from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
- from sglang.srt.utils import get_compiler_backend
-
-
- @torch.compile(dynamic=True, backend=get_compiler_backend())
- def apply_scaling_penalties(logits, scaling_penalties):
-     logits[:] = torch.where(
-         logits > 0,
-         logits / scaling_penalties,
-         logits * scaling_penalties,
-     )
-
-
- class BatchedRepetitionPenalizer(_BatchedPenalizer):
-     """
-     Repetition penalizer penalizes tokens based on their repetition in the input and output.
-     """
-
-     repetition_penalties: torch.Tensor = None
-     cumulated_repetition_penalties: torch.Tensor = None
-
-     def _is_required(self) -> bool:
-         return any(
-             req.sampling_params.repetition_penalty != 1.0
-             for req in self.orchestrator.reqs()
-         )
-
-     def _prepare(self):
-         self.cumulated_repetition_penalties = (
-             torch.tensor(
-                 data=[1.0 for _ in self.orchestrator.reqs()],
-                 dtype=torch.float32,
-                 device=self.orchestrator.device,
-             )
-             .unsqueeze_(1)
-             .repeat(1, self.orchestrator.vocab_size)
-         )
-
-         self.repetition_penalties = (
-             torch.tensor(
-                 data=[
-                     req.sampling_params.repetition_penalty
-                     for req in self.orchestrator.reqs()
-                 ],
-                 dtype=torch.float32,
-                 device=self.orchestrator.device,
-             )
-             .unsqueeze_(1)
-             .expand_as(self.cumulated_repetition_penalties)
-         )
-
-     def _teardown(self):
-         self.repetition_penalties = None
-         self.cumulated_repetition_penalties = None
-
-     def _cumulate_input_tokens(self, input_ids: _TokenIDs):
-         mask = input_ids.occurrence_count() > 0
-         self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask]
-
-     def _cumulate_output_tokens(self, output_ids: _TokenIDs):
-         mask = output_ids.occurrence_count() > 0
-         self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask]
-
-     def _apply(self, logits: torch.Tensor) -> torch.Tensor:
-         apply_scaling_penalties(logits, self.cumulated_repetition_penalties)
-         return logits
-
-     def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
-         self.repetition_penalties = self.repetition_penalties[indices_tensor_to_keep]
-         self.cumulated_repetition_penalties = self.cumulated_repetition_penalties[
-             indices_tensor_to_keep
-         ]
-
-     def _merge(self, their: "BatchedRepetitionPenalizer"):
-         self.repetition_penalties = torch.cat(
-             [self.repetition_penalties, their.repetition_penalties], dim=0
-         )
-         self.cumulated_repetition_penalties = torch.cat(
-             [self.cumulated_repetition_penalties, their.cumulated_repetition_penalties],
-             dim=0,
-         )
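
Note: of the three deleted penalizers, only frequency and presence reappear as new penaltylib modules in the file list above; this one has no standalone replacement file in this diff. Unlike the two additive penalties, it rescales logits multiplicatively. A minimal sketch of what apply_scaling_penalties computes, assuming a boolean seen mask aggregated over input and output tokens:

import torch

def repetition_penalty_sketch(
    logits: torch.Tensor,                # [batch, vocab] raw logits
    seen_mask: torch.Tensor,             # [batch, vocab] bool, True where the token occurred
    repetition_penalties: torch.Tensor,  # [batch, 1], values > 1.0 discourage repeats
) -> torch.Tensor:
    # Unseen tokens keep a neutral factor of 1.0; seen tokens have positive
    # logits divided by the penalty and negative logits multiplied by it,
    # pushing both toward lower probability.
    scale = torch.where(
        seen_mask, repetition_penalties.expand_as(seen_mask), torch.ones_like(logits)
    )
    return torch.where(logits > 0, logits / scale, logits * scale)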
sglang/test/srt/sampling/penaltylib/utils.py DELETED
@@ -1,344 +0,0 @@
- import dataclasses
- import enum
- import unittest
- from typing import Dict, List, Optional, Set, Tuple, Type
-
- import torch
-
- from sglang.srt.sampling.penaltylib.orchestrator import (
-     BatchedPenalizerOrchestrator,
-     _BatchedPenalizer,
-     _BatchLike,
- )
-
-
- @dataclasses.dataclass
- class MockSamplingParams:
-     frequency_penalty: float = 0.0
-     min_new_tokens: int = 0
-     stop_token_ids: List[int] = None
-     presence_penalty: float = 0.0
-     repetition_penalty: float = 1.0
-
-
- @dataclasses.dataclass
- class MockTokenizer:
-     eos_token_id: int
-     additional_stop_token_ids: Optional[List[int]] = None
-
-
- @dataclasses.dataclass
- class MockReq:
-     origin_input_ids: List[int]
-     sampling_params: MockSamplingParams
-     tokenizer: MockTokenizer
-
-
- class StepType(enum.Enum):
-     INPUT = "input"
-     OUTPUT = "output"
-
-
- @dataclasses.dataclass
- class Step:
-     type: StepType
-     token_ids: List[int]
-     expected_tensors: Dict[str, torch.Tensor]
-     # assume initial logits are all 1
-     expected_logits: torch.Tensor
-
-
- @dataclasses.dataclass
- class Subject:
-     sampling_params: MockSamplingParams
-     # first step must be input, which will be converted to Req
-     steps: List[Step]
-     eos_token_id: int = -1
-
-     def __post_init__(self):
-         if self.steps[0].type != StepType.INPUT:
-             raise ValueError("First step must be input")
-
-         # each steps should have the same expected_tensors.keys()
-         for i in range(1, len(self.steps)):
-             if self.tensor_keys(i) != self.tensor_keys():
-                 raise ValueError(
-                     f"Expected tensors keys must be the same for all steps. Got {self.steps[i].expected_tensors.keys()} for key={i} and {self.steps[0].expected_tensors.keys()}"
-                 )
-
-     def tensor_keys(self, i: int = 0) -> Set[str]:
-         return set(self.steps[i].expected_tensors.keys())
-
-     def to_req(self) -> MockReq:
-         return MockReq(
-             origin_input_ids=self.steps[0].token_ids,
-             sampling_params=self.sampling_params,
-             tokenizer=MockTokenizer(eos_token_id=self.eos_token_id),
-         )
-
-
- @dataclasses.dataclass
- class Case:
-     enabled: bool
-     test_subjects: List[Subject]
-
-     def __post_init__(self):
-         # each test_subjects.steps should have the same expected_tensors.keys()
-         for i in range(1, len(self.test_subjects)):
-             if self.tensor_keys(i) != self.tensor_keys():
-                 raise ValueError(
-                     f"Expected tensors keys must be the same for all test_subjects. Got {self.test_subjects[i].tensor_keys()} for key={i} and {self.test_subjects[0].tensor_keys()}"
-                 )
-
-     def tensor_keys(self, i: int = 0) -> List[str]:
-         return set(self.test_subjects[i].tensor_keys())
-
-
- class BaseBatchedPenalizerTest(unittest.TestCase):
-     Penalizer: Type[_BatchedPenalizer]
-     device = "cuda"
-     vocab_size = 5
-
-     enabled: Subject = None
-     disabled: Subject = None
-
-     def setUp(self):
-         if self.__class__ == BaseBatchedPenalizerTest:
-             self.skipTest("Base class for penalizer tests")
-
-         self.create_test_subjects()
-         self.create_test_cases()
-
-     def tensor(self, data, **kwargs) -> torch.Tensor:
-         """
-         Shortcut to create a tensor with device=self.device.
-         """
-         return torch.tensor(data, **kwargs, device=self.device)
-
-     def create_test_subjects(self) -> List[Subject]:
-         raise NotImplementedError()
-
-     def create_test_cases(self):
-         self.test_cases = [
-             Case(enabled=True, test_subjects=[self.enabled]),
-             Case(enabled=False, test_subjects=[self.disabled]),
-             Case(enabled=True, test_subjects=[self.enabled, self.disabled]),
-         ]
-
-     def _create_penalizer(
-         self, case: Case
-     ) -> Tuple[BatchedPenalizerOrchestrator, _BatchedPenalizer]:
-         orchestrator = BatchedPenalizerOrchestrator(
-             vocab_size=self.vocab_size,
-             batch=_BatchLike(reqs=[subject.to_req() for subject in case.test_subjects]),
-             device=self.device,
-             Penalizers={self.Penalizer},
-         )
-
-         return orchestrator, orchestrator.penalizers[self.Penalizer]
-
-     def test_is_required(self):
-         for case in self.test_cases:
-             with self.subTest(case=case):
-                 _, penalizer = self._create_penalizer(case)
-                 self.assertEqual(case.enabled, penalizer.is_required())
-
-     def test_prepare(self):
-         for case in self.test_cases:
-             with self.subTest(case=case):
-                 orchestrator, penalizer = self._create_penalizer(case)
-                 self.assertEqual(case.enabled, penalizer.is_prepared())
-
-                 if case.enabled:
-                     for key, tensor in {
-                         key: torch.cat(
-                             tensors=[
-                                 subject.steps[0].expected_tensors[key]
-                                 for subject in case.test_subjects
-                             ],
-                         )
-                         for key in case.tensor_keys()
-                     }.items():
-                         torch.testing.assert_close(
-                             actual=getattr(penalizer, key),
-                             expected=tensor,
-                             msg=f"key={key}\nactual={getattr(penalizer, key)}\nexpected={tensor}",
-                         )
-
-                 original = torch.ones(
-                     size=(len(case.test_subjects), self.vocab_size),
-                     dtype=torch.float32,
-                     device=self.device,
-                 )
-                 actual = orchestrator.apply(original.clone())
-                 expected = torch.cat(
-                     tensors=[
-                         subject.steps[0].expected_logits
-                         for subject in case.test_subjects
-                     ],
-                 )
-                 if actual is None:
-                     actual = original
-                 torch.testing.assert_close(
-                     actual=actual,
-                     expected=expected,
-                     msg=f"logits\nactual={actual}\nexpected={expected}",
-                 )
-
-     def test_teardown(self):
-         for case in self.test_cases:
-             with self.subTest(case=case):
-                 _, penalizer = self._create_penalizer(case)
-                 penalizer.teardown()
-
-                 for key in case.test_subjects[0].steps[0].expected_tensors.keys():
-                     self.assertIsNone(getattr(penalizer, key, None))
-
-     def test_filter(self):
-         for case in self.test_cases:
-             with self.subTest(case=case):
-                 orchestrator, penalizer = self._create_penalizer(case)
-
-                 indices_to_keep = [0]
-                 orchestrator.filter(indices_to_keep=indices_to_keep)
-
-                 filtered_subjects = [case.test_subjects[i] for i in indices_to_keep]
-
-                 if penalizer.is_required():
-                     self.assertTrue(penalizer.is_prepared())
-                     for key, tensor in {
-                         key: torch.cat(
-                             tensors=[
-                                 subject.steps[0].expected_tensors[key]
-                                 for subject in filtered_subjects
-                             ],
-                         )
-                         for key in case.tensor_keys()
-                     }.items():
-                         torch.testing.assert_close(
-                             actual=getattr(penalizer, key),
-                             expected=tensor,
-                             msg=f"key={key}\nactual={getattr(penalizer, key)}\nexpected={tensor}",
-                         )
-
-                 actual_logits = orchestrator.apply(
-                     torch.ones(
-                         size=(len(filtered_subjects), self.vocab_size),
-                         dtype=torch.float32,
-                         device=self.device,
-                     )
-                 )
-                 if actual_logits is None:
-                     continue
-                 filtered_expected_logits = torch.cat(
-                     tensors=[
-                         subject.steps[0].expected_logits
-                         for subject in filtered_subjects
-                     ],
-                 )
-                 torch.testing.assert_close(
-                     actual=actual_logits,
-                     expected=filtered_expected_logits,
-                     msg=f"logits\nactual={actual_logits}\nexpected={filtered_expected_logits}",
-                 )
-
-     def test_merge_enabled_with_disabled(self):
-         enabled_test_case = self.test_cases[0]
-         disabled_test_case = self.test_cases[1]
-
-         orchestrator, penalizer = self._create_penalizer(enabled_test_case)
-         theirs, _ = self._create_penalizer(disabled_test_case)
-
-         orchestrator.merge(theirs)
-
-         for key, tensor in {
-             key: torch.cat(
-                 tensors=[
-                     enabled_test_case.test_subjects[0].steps[0].expected_tensors[key],
-                     disabled_test_case.test_subjects[0].steps[0].expected_tensors[key],
-                 ],
-             )
-             for key in enabled_test_case.tensor_keys()
-         }.items():
-             torch.testing.assert_close(
-                 actual=getattr(penalizer, key),
-                 expected=tensor,
-                 msg=f"key={key}\nactual={getattr(penalizer, key)}\nexpected={tensor}",
-             )
-
-     def test_cumulate_apply_repeat(self):
-         for case in self.test_cases:
-             with self.subTest(case=case):
-                 orchestrator, penalizer = self._create_penalizer(case)
-
-                 max_step = max(len(subject.steps) for subject in case.test_subjects)
-                 for i in range(1, max_step):
-                     orchestrator.filter(
-                         indices_to_keep=[
-                             j
-                             for j, subject in enumerate(case.test_subjects)
-                             if i < len(subject.steps)
-                         ]
-                     )
-
-                     filtered_subjects = [
-                         subject
-                         for subject in case.test_subjects
-                         if i < len(subject.steps)
-                     ]
-
-                     inputs: List[List[int]] = []
-                     outputs: List[List[int]] = []
-                     for subject in filtered_subjects:
-                         step = subject.steps[i]
-                         if step.type == StepType.INPUT:
-                             raise NotImplementedError()
-                         else:
-                             inputs.append([])
-                             outputs.append(step.token_ids)
-
-                     if any(outputs):
-                         for j in range(max(len(x) for x in outputs)):
-                             tmp_outputs = torch.tensor(
-                                 [x[j] for x in outputs],
-                                 dtype=torch.int32,
-                                 device=orchestrator.device,
-                             )
-                             orchestrator.cumulate_output_tokens(tmp_outputs)
-
-                     if penalizer.is_required():
-                         self.assertTrue(penalizer.is_prepared())
-                         for key, tensor in {
-                             key: torch.cat(
-                                 tensors=[
-                                     subject.steps[i].expected_tensors[key]
-                                     for subject in filtered_subjects
-                                 ],
-                             )
-                             for key in case.tensor_keys()
-                         }.items():
-                             torch.testing.assert_close(
-                                 actual=getattr(penalizer, key),
-                                 expected=tensor,
-                                 msg=f"key={key}\nactual={getattr(penalizer, key)}\nexpected={tensor}",
-                             )
-
-                     original = torch.ones(
-                         size=(len(filtered_subjects), self.vocab_size),
-                         dtype=torch.float32,
-                         device=self.device,
-                     )
-                     actual_logits = orchestrator.apply(original.clone())
-                     filtered_expected_logits = torch.cat(
-                         tensors=[
-                             subject.steps[i].expected_logits
-                             for subject in filtered_subjects
-                         ],
-                     )
-                     if actual_logits is None:
-                         actual_logits = original
-                     torch.testing.assert_close(
-                         actual=actual_logits,
-                         expected=filtered_expected_logits,
-                         msg=f"logits\nactual={actual_logits}\nexpected={filtered_expected_logits}",
-                     )
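
This deleted harness was driven by subclassing BaseBatchedPenalizerTest once per penalizer. A hypothetical sketch of such a subclass, as a reading aid for the tests above; the subject values are illustrative, and all names come from the deleted modules in this diff, not from the new test suite:

import torch

class FrequencyPenalizerTest(BaseBatchedPenalizerTest):
    # BatchedFrequencyPenalizer is the class from the deleted
    # penalizers/frequency_penalty.py shown earlier in this diff.
    Penalizer = BatchedFrequencyPenalizer

    def create_test_subjects(self):
        # Enabled subject: a non-zero frequency_penalty makes _is_required() True.
        self.enabled = Subject(
            sampling_params=MockSamplingParams(frequency_penalty=0.5),
            steps=[
                Step(
                    type=StepType.INPUT,
                    token_ids=[0, 1, 2],
                    expected_tensors={
                        "frequency_penalties": self.tensor(
                            [[0.5] * self.vocab_size], dtype=torch.float32
                        ),
                        "cumulated_frequency_penalties": self.tensor(
                            [[0.0] * self.vocab_size], dtype=torch.float32
                        ),
                    },
                    # input tokens are never penalized, so logits stay at 1
                    expected_logits=self.tensor(
                        [[1.0] * self.vocab_size], dtype=torch.float32
                    ),
                ),
            ],
        )
        # Disabled subject: the default penalty of 0.0 leaves logits untouched.
        self.disabled = Subject(
            sampling_params=MockSamplingParams(frequency_penalty=0.0),
            steps=[
                Step(
                    type=StepType.INPUT,
                    token_ids=[0, 1, 2],
                    expected_tensors={
                        "frequency_penalties": self.tensor(
                            [[0.0] * self.vocab_size], dtype=torch.float32
                        ),
                        "cumulated_frequency_penalties": self.tensor(
                            [[0.0] * self.vocab_size], dtype=torch.float32
                        ),
                    },
                    expected_logits=self.tensor(
                        [[1.0] * self.vocab_size], dtype=torch.float32
                    ),
                ),
            ],
        )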