sglang 0.4.3.post1__py3-none-any.whl → 0.4.3.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/api.py +1 -1
- sglang/bench_offline_throughput.py +19 -0
- sglang/bench_one_batch.py +2 -2
- sglang/bench_serving.py +123 -79
- sglang/global_config.py +8 -3
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/lang/ir.py +1 -1
- sglang/srt/_custom_ops.py +83 -91
- sglang/srt/configs/load_config.py +4 -1
- sglang/srt/configs/model_config.py +48 -2
- sglang/srt/configs/qwen2_5_vl_config.py +5 -2
- sglang/srt/constrained/base_grammar_backend.py +117 -15
- sglang/srt/constrained/llguidance_backend.py +151 -0
- sglang/srt/constrained/outlines_backend.py +24 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -38
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
- sglang/srt/distributed/parallel_state.py +48 -3
- sglang/srt/entrypoints/engine.py +67 -9
- sglang/srt/entrypoints/http_server.py +190 -41
- sglang/srt/entrypoints/verl_engine.py +147 -0
- sglang/srt/function_call_parser.py +0 -1
- sglang/srt/layers/activation.py +11 -0
- sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
- sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
- sglang/srt/layers/attention/flashinfer_backend.py +208 -295
- sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
- sglang/srt/layers/attention/torch_native_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +9 -6
- sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
- sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
- sglang/srt/layers/attention/utils.py +39 -0
- sglang/srt/layers/attention/vision.py +60 -63
- sglang/srt/layers/dp_attention.py +142 -1
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/linear.py +3 -1
- sglang/srt/layers/logits_processor.py +281 -45
- sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
- sglang/srt/layers/moe/ep_moe/layer.py +140 -28
- sglang/srt/layers/moe/fused_moe_native.py +2 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
- sglang/srt/layers/moe/topk.py +13 -4
- sglang/srt/layers/quantization/__init__.py +111 -7
- sglang/srt/layers/quantization/blockwise_int8.py +409 -0
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/fp8.py +69 -28
- sglang/srt/layers/quantization/fp8_utils.py +17 -1
- sglang/srt/layers/quantization/gptq.py +416 -0
- sglang/srt/layers/quantization/int8_kernel.py +327 -0
- sglang/srt/layers/quantization/int8_utils.py +73 -0
- sglang/srt/layers/quantization/modelopt_quant.py +18 -1
- sglang/srt/layers/radix_attention.py +1 -0
- sglang/srt/layers/rotary_embedding.py +0 -1
- sglang/srt/layers/sampler.py +76 -31
- sglang/srt/layers/vocab_parallel_embedding.py +14 -13
- sglang/srt/lora/lora.py +17 -1
- sglang/srt/lora/lora_config.py +5 -0
- sglang/srt/lora/lora_manager.py +1 -3
- sglang/srt/managers/cache_controller.py +193 -62
- sglang/srt/managers/configure_logging.py +2 -1
- sglang/srt/managers/data_parallel_controller.py +6 -2
- sglang/srt/managers/detokenizer_manager.py +124 -102
- sglang/srt/managers/image_processor.py +2 -1
- sglang/srt/managers/io_struct.py +143 -6
- sglang/srt/managers/schedule_batch.py +238 -197
- sglang/srt/managers/schedule_policy.py +29 -29
- sglang/srt/managers/scheduler.py +681 -259
- sglang/srt/managers/session_controller.py +6 -2
- sglang/srt/managers/tokenizer_manager.py +224 -68
- sglang/srt/managers/tp_worker.py +15 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
- sglang/srt/mem_cache/chunk_cache.py +18 -11
- sglang/srt/mem_cache/hiradix_cache.py +394 -0
- sglang/srt/mem_cache/memory_pool.py +44 -18
- sglang/srt/mem_cache/radix_cache.py +58 -47
- sglang/srt/metrics/collector.py +94 -36
- sglang/srt/model_executor/cuda_graph_runner.py +55 -24
- sglang/srt/model_executor/forward_batch_info.py +49 -16
- sglang/srt/model_executor/model_runner.py +209 -28
- sglang/srt/model_loader/loader.py +3 -3
- sglang/srt/model_loader/weight_utils.py +36 -14
- sglang/srt/models/baichuan.py +31 -6
- sglang/srt/models/chatglm.py +39 -7
- sglang/srt/models/commandr.py +29 -5
- sglang/srt/models/dbrx.py +31 -5
- sglang/srt/models/deepseek.py +43 -6
- sglang/srt/models/deepseek_nextn.py +32 -19
- sglang/srt/models/deepseek_v2.py +265 -29
- sglang/srt/models/exaone.py +19 -9
- sglang/srt/models/gemma.py +22 -8
- sglang/srt/models/gemma2.py +25 -12
- sglang/srt/models/gemma2_reward.py +5 -1
- sglang/srt/models/gpt2.py +28 -13
- sglang/srt/models/gpt_bigcode.py +27 -5
- sglang/srt/models/granite.py +21 -9
- sglang/srt/models/grok.py +21 -4
- sglang/srt/models/internlm2.py +36 -6
- sglang/srt/models/internlm2_reward.py +5 -1
- sglang/srt/models/llama.py +26 -9
- sglang/srt/models/llama_classification.py +5 -1
- sglang/srt/models/llama_eagle.py +17 -4
- sglang/srt/models/llama_embedding.py +5 -1
- sglang/srt/models/llama_reward.py +7 -2
- sglang/srt/models/llava.py +19 -3
- sglang/srt/models/llavavid.py +10 -1
- sglang/srt/models/minicpm.py +26 -2
- sglang/srt/models/minicpm3.py +39 -3
- sglang/srt/models/minicpmv.py +45 -14
- sglang/srt/models/mixtral.py +20 -9
- sglang/srt/models/mixtral_quant.py +50 -8
- sglang/srt/models/mllama.py +57 -11
- sglang/srt/models/olmo.py +34 -6
- sglang/srt/models/olmo2.py +34 -13
- sglang/srt/models/olmoe.py +26 -4
- sglang/srt/models/phi3_small.py +29 -10
- sglang/srt/models/qwen.py +26 -3
- sglang/srt/models/qwen2.py +26 -4
- sglang/srt/models/qwen2_5_vl.py +46 -8
- sglang/srt/models/qwen2_eagle.py +17 -5
- sglang/srt/models/qwen2_moe.py +44 -6
- sglang/srt/models/qwen2_rm.py +78 -0
- sglang/srt/models/qwen2_vl.py +39 -8
- sglang/srt/models/stablelm.py +32 -5
- sglang/srt/models/torch_native_llama.py +5 -2
- sglang/srt/models/xverse.py +21 -9
- sglang/srt/models/xverse_moe.py +45 -7
- sglang/srt/models/yivl.py +2 -1
- sglang/srt/openai_api/adapter.py +109 -24
- sglang/srt/openai_api/protocol.py +17 -1
- sglang/srt/reasoning_parser.py +154 -0
- sglang/srt/sampling/penaltylib/__init__.py +4 -6
- sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
- sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
- sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
- sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
- sglang/srt/sampling/sampling_batch_info.py +79 -157
- sglang/srt/sampling/sampling_params.py +16 -13
- sglang/srt/server_args.py +136 -52
- sglang/srt/speculative/build_eagle_tree.py +2 -8
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +0 -1
- sglang/srt/speculative/eagle_utils.py +92 -58
- sglang/srt/speculative/eagle_worker.py +186 -94
- sglang/srt/speculative/spec_info.py +1 -13
- sglang/srt/utils.py +43 -17
- sglang/srt/warmup.py +47 -0
- sglang/test/few_shot_gsm8k.py +4 -1
- sglang/test/runners.py +389 -126
- sglang/test/send_one.py +88 -0
- sglang/test/test_block_fp8_ep.py +361 -0
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +138 -84
- sglang/utils.py +50 -60
- sglang/version.py +1 -1
- {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/METADATA +21 -15
- {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/RECORD +214 -166
- {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/WHEEL +1 -1
- sglang/bench_latency.py +0 -1
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
- sglang/test/srt/sampling/penaltylib/utils.py +0 -344
- {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,66 @@
|
|
1
|
+
import torch
|
2
|
+
|
3
|
+
from sglang.srt.sampling.penaltylib.orchestrator import (
|
4
|
+
BatchedPenalizerOrchestrator,
|
5
|
+
_BatchedPenalizer,
|
6
|
+
)
|
7
|
+
|
8
|
+
|
9
|
+
class BatchedFrequencyPenalizer(_BatchedPenalizer):
    """
    Frequency penalizer penalizes tokens based on their frequency in the output.

    Every time a request emits a token, that request's ``frequency_penalty`` is
    added to an accumulator at the token's vocabulary column. The accumulated
    values are subtracted from the logits, so a token emitted k times is
    penalized by ``k * frequency_penalty``.
    """

    def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
        self.orchestrator = orchestrator
        self._is_prepared = False

    def _is_required(self) -> bool:
        """Return True if any request in the batch uses a non-zero frequency penalty."""
        return any(
            req.sampling_params.frequency_penalty != 0.0
            for req in self.orchestrator.reqs()
        )

    def _prepare(self):
        """Allocate the accumulator and the per-request penalty column."""
        # Running penalty per (request, token), shape (num_reqs, vocab_size).
        self.cumulated_frequency_penalties = torch.zeros(
            (len(self.orchestrator.reqs()), self.orchestrator.vocab_size),
            dtype=torch.float32,
            device=self.orchestrator.device,
        )

        # Per-request penalty as a (num_reqs, 1) column so it can be used as the
        # ``src`` of the row-wise scatter in ``_cumulate_output_tokens``.
        self.frequency_penalties = torch.tensor(
            data=[
                req.sampling_params.frequency_penalty
                for req in self.orchestrator.reqs()
            ],
            dtype=torch.float32,
            device=self.orchestrator.device,
        ).unsqueeze_(1)

    def _cumulate_output_tokens(self, output_ids: torch.Tensor):
        """
        Add each request's penalty at the column of its newly emitted token.

        Args:
            output_ids (torch.Tensor): 1-D tensor of token ids, presumably one
                per request in batch order (TODO confirm against caller).
        """
        # scatter_add_ (not scatter_) so repeated tokens keep accumulating.
        self.cumulated_frequency_penalties.scatter_add_(
            dim=1,
            index=output_ids.unsqueeze(1),
            src=self.frequency_penalties,
        )

    def _apply(self, logits: torch.Tensor) -> None:
        """Subtract the accumulated penalties from ``logits`` in place."""
        logits.sub_(self.cumulated_frequency_penalties)

    def _filter(self, keep_indices: torch.Tensor):
        """Keep only the rows of the requests that remain in the batch."""
        self.frequency_penalties = self.frequency_penalties[keep_indices]
        self.cumulated_frequency_penalties = self.cumulated_frequency_penalties[
            keep_indices
        ]

    def _merge(self, their: "BatchedFrequencyPenalizer"):
        """Append ``their`` rows after this penalizer's rows (batch dimension)."""
        # A stray debug print of the tensor shapes used to live here; it wrote to
        # stdout on every batch merge and has been removed.
        self.frequency_penalties = torch.cat(
            [self.frequency_penalties, their.frequency_penalties], dim=0
        )
        self.cumulated_frequency_penalties = torch.cat(
            [self.cumulated_frequency_penalties, their.cumulated_frequency_penalties],
            dim=0,
        )
|
@@ -1,8 +1,9 @@
|
|
1
|
-
from typing import List
|
2
|
-
|
3
1
|
import torch
|
4
2
|
|
5
|
-
from sglang.srt.sampling.penaltylib.orchestrator import
|
3
|
+
from sglang.srt.sampling.penaltylib.orchestrator import (
|
4
|
+
BatchedPenalizerOrchestrator,
|
5
|
+
_BatchedPenalizer,
|
6
|
+
)
|
6
7
|
|
7
8
|
|
8
9
|
class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
|
@@ -10,9 +11,9 @@ class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
|
|
10
11
|
Min new tokens penalizer penalizes tokens based on the length of the output.
|
11
12
|
"""
|
12
13
|
|
13
|
-
|
14
|
-
|
15
|
-
|
14
|
+
def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
|
15
|
+
self.orchestrator = orchestrator
|
16
|
+
self._is_prepared = False
|
16
17
|
|
17
18
|
def _is_required(self) -> bool:
|
18
19
|
return any(
|
@@ -47,7 +48,7 @@ class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
|
|
47
48
|
padding_value=self.orchestrator.vocab_size,
|
48
49
|
)
|
49
50
|
self.stop_token_penalties = torch.zeros(
|
50
|
-
size=(self.orchestrator.
|
51
|
+
size=(len(self.orchestrator.reqs()), self.orchestrator.vocab_size + 1),
|
51
52
|
dtype=torch.float32,
|
52
53
|
device=self.orchestrator.device,
|
53
54
|
).scatter_add_(
|
@@ -64,31 +65,22 @@ class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
|
|
64
65
|
]
|
65
66
|
|
66
67
|
self.len_output_tokens = torch.zeros(
|
67
|
-
size=(self.orchestrator.
|
68
|
+
size=(len(self.orchestrator.reqs()), 1),
|
68
69
|
dtype=torch.int32,
|
69
70
|
device=self.orchestrator.device,
|
70
71
|
)
|
71
72
|
|
72
|
-
def
|
73
|
-
self.min_new_tokens = None
|
74
|
-
self.stop_token_penalties = None
|
75
|
-
self.len_output_tokens = None
|
76
|
-
|
77
|
-
def _cumulate_input_tokens(self, input_ids: _TokenIDs):
|
78
|
-
pass
|
79
|
-
|
80
|
-
def _cumulate_output_tokens(self, output_ids: _TokenIDs):
|
73
|
+
def _cumulate_output_tokens(self, output_ids: torch.Tensor):
|
81
74
|
self.len_output_tokens += 1
|
82
75
|
|
83
|
-
def _apply(self, logits: torch.Tensor)
|
76
|
+
def _apply(self, logits: torch.Tensor):
|
84
77
|
mask = (self.len_output_tokens < self.min_new_tokens).expand_as(logits)
|
85
78
|
logits[mask] += self.stop_token_penalties[mask]
|
86
|
-
return logits
|
87
79
|
|
88
|
-
def _filter(self,
|
89
|
-
self.min_new_tokens = self.min_new_tokens[
|
90
|
-
self.stop_token_penalties = self.stop_token_penalties[
|
91
|
-
self.len_output_tokens = self.len_output_tokens[
|
80
|
+
def _filter(self, keep_indices: torch.Tensor):
|
81
|
+
self.min_new_tokens = self.min_new_tokens[keep_indices]
|
82
|
+
self.stop_token_penalties = self.stop_token_penalties[keep_indices]
|
83
|
+
self.len_output_tokens = self.len_output_tokens[keep_indices]
|
92
84
|
|
93
85
|
def _merge(self, their: "BatchedMinNewTokensPenalizer"):
|
94
86
|
self.min_new_tokens = torch.cat(
|
@@ -1,35 +1,25 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
1
3
|
import abc
|
2
|
-
import
|
3
|
-
from typing import List, Set, Type, Union
|
4
|
+
from typing import TYPE_CHECKING, Set, Type
|
4
5
|
|
5
6
|
import torch
|
6
7
|
|
7
|
-
|
8
|
-
|
9
|
-
class _ReqLike:
|
10
|
-
origin_input_ids: List[int]
|
11
|
-
|
12
|
-
|
13
|
-
@dataclasses.dataclass
|
14
|
-
class _BatchLike:
|
15
|
-
reqs: List[_ReqLike]
|
16
|
-
|
17
|
-
def batch_size(self):
|
18
|
-
return len(self.reqs)
|
8
|
+
if TYPE_CHECKING:
|
9
|
+
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
19
10
|
|
20
11
|
|
21
12
|
class BatchedPenalizerOrchestrator:
|
22
13
|
def __init__(
|
23
14
|
self,
|
24
15
|
vocab_size: int,
|
25
|
-
batch:
|
26
|
-
|
27
|
-
Penalizers: Set[Type["_BatchedPenalizer"]],
|
16
|
+
batch: ScheduleBatch,
|
17
|
+
penalizers: Set[Type["_BatchedPenalizer"]],
|
28
18
|
):
|
29
19
|
self.vocab_size = vocab_size
|
30
20
|
self.batch = batch
|
31
|
-
self.device = device
|
32
|
-
self.penalizers = {Penalizer: Penalizer(self) for Penalizer in
|
21
|
+
self.device = batch.device
|
22
|
+
self.penalizers = {Penalizer: Penalizer(self) for Penalizer in penalizers}
|
33
23
|
|
34
24
|
is_required = False
|
35
25
|
for penalizer in self.penalizers.values():
|
@@ -37,31 +27,9 @@ class BatchedPenalizerOrchestrator:
|
|
37
27
|
is_required |= pen_is_required
|
38
28
|
self.is_required = is_required
|
39
29
|
|
40
|
-
input_ids = [
|
41
|
-
torch.tensor(req.origin_input_ids, dtype=torch.int64, device=self.device)
|
42
|
-
for req in self.reqs()
|
43
|
-
]
|
44
|
-
if self.is_required:
|
45
|
-
self.cumulate_input_tokens(input_ids=input_ids)
|
46
|
-
|
47
30
|
def reqs(self):
|
48
31
|
return self.batch.reqs
|
49
32
|
|
50
|
-
def batch_size(self):
|
51
|
-
return self.batch.batch_size()
|
52
|
-
|
53
|
-
def cumulate_input_tokens(self, input_ids: List[torch.Tensor]):
|
54
|
-
"""
|
55
|
-
Feed the input tokens to the penalizers.
|
56
|
-
|
57
|
-
Args:
|
58
|
-
input_ids (List[torch.Tensor]): The input tokens.
|
59
|
-
"""
|
60
|
-
token_ids = _TokenIDs(orchestrator=self, token_ids=input_ids)
|
61
|
-
|
62
|
-
for penalizer in self.penalizers.values():
|
63
|
-
penalizer.cumulate_input_tokens(input_ids=token_ids)
|
64
|
-
|
65
33
|
def cumulate_output_tokens(self, output_ids: torch.Tensor):
|
66
34
|
"""
|
67
35
|
Feed the output tokens to the penalizers.
|
@@ -69,13 +37,8 @@ class BatchedPenalizerOrchestrator:
|
|
69
37
|
Args:
|
70
38
|
output_ids (torch.Tensor): The output tokens.
|
71
39
|
"""
|
72
|
-
if not self.is_required:
|
73
|
-
return
|
74
|
-
|
75
|
-
token_ids = _TokenIDs(orchestrator=self, token_ids=output_ids)
|
76
|
-
|
77
40
|
for penalizer in self.penalizers.values():
|
78
|
-
penalizer.cumulate_output_tokens(output_ids=
|
41
|
+
penalizer.cumulate_output_tokens(output_ids=output_ids)
|
79
42
|
|
80
43
|
def apply(self, logits: torch.Tensor) -> torch.Tensor:
|
81
44
|
"""
|
@@ -88,48 +51,33 @@ class BatchedPenalizerOrchestrator:
|
|
88
51
|
Returns:
|
89
52
|
torch.Tensor: The logits after applying the penalizers.
|
90
53
|
"""
|
91
|
-
if not self.is_required:
|
92
|
-
return
|
93
|
-
|
94
54
|
for penalizer in self.penalizers.values():
|
95
|
-
|
96
|
-
|
97
|
-
return logits
|
55
|
+
penalizer.apply(logits)
|
98
56
|
|
99
|
-
def filter(
|
100
|
-
self,
|
101
|
-
indices_to_keep: List[int],
|
102
|
-
indices_tensor_to_keep: torch.Tensor = None,
|
103
|
-
):
|
57
|
+
def filter(self, keep_indices: torch.Tensor):
|
104
58
|
"""
|
105
59
|
Filter the penalizers based on the indices to keep in the batch.
|
106
60
|
|
107
61
|
Args:
|
108
|
-
|
109
|
-
indices_tensor_to_keep (torch.Tensor = None): Tensor of indices to keep in the batch. If not None, it will be used instead of converting indices_to_keep to a tensor.
|
62
|
+
keep_indices (torch.Tensor): Tensor of indices to keep in the batch.
|
110
63
|
"""
|
111
64
|
if not self.is_required:
|
112
65
|
return
|
113
66
|
|
114
|
-
|
67
|
+
if len(keep_indices) == 0:
|
68
|
+
self.is_required = False
|
69
|
+
for penalizer in self.penalizers.values():
|
70
|
+
penalizer.teardown()
|
71
|
+
return
|
115
72
|
|
116
73
|
is_required = False
|
117
74
|
for penalizer in self.penalizers.values():
|
118
75
|
tmp_is_required = penalizer.is_required()
|
119
|
-
is_required
|
120
|
-
if
|
121
|
-
penalizer.
|
76
|
+
is_required |= tmp_is_required
|
77
|
+
if tmp_is_required:
|
78
|
+
penalizer.filter(keep_indices=keep_indices)
|
122
79
|
else:
|
123
|
-
|
124
|
-
if indices_tensor_to_keep is None:
|
125
|
-
indices_tensor_to_keep = torch.tensor(
|
126
|
-
indices_to_keep, dtype=torch.int32, device=self.device
|
127
|
-
)
|
128
|
-
|
129
|
-
penalizer.filter(
|
130
|
-
indices_to_keep=indices_to_keep,
|
131
|
-
indices_tensor_to_keep=indices_tensor_to_keep,
|
132
|
-
)
|
80
|
+
penalizer.teardown()
|
133
81
|
self.is_required = is_required
|
134
82
|
|
135
83
|
def merge(self, their: "BatchedPenalizerOrchestrator"):
|
@@ -146,75 +94,9 @@ class BatchedPenalizerOrchestrator:
|
|
146
94
|
if not self.is_required and not their.is_required:
|
147
95
|
return
|
148
96
|
|
149
|
-
self.is_required
|
150
|
-
for
|
151
|
-
|
152
|
-
raise ValueError(f"Penalizer {Penalizer} not found in self.penalizers")
|
153
|
-
|
154
|
-
self.penalizers[Penalizer].merge(their_penalizer)
|
155
|
-
|
156
|
-
|
157
|
-
class _TokenIDs:
|
158
|
-
"""
|
159
|
-
A class that wraps token IDs to provide additional utility functions to penalizers.
|
160
|
-
|
161
|
-
Attributes:
|
162
|
-
orchestrator (BatchedPenalizerOrchestrator): The orchestrator that this token IDs belong to.
|
163
|
-
token_ids (Union[torch.Tensor, List[torch.Tensor]]): The token IDs.
|
164
|
-
cached_counts (torch.Tensor): The cached occurrence count tensor.
|
165
|
-
"""
|
166
|
-
|
167
|
-
def __init__(
|
168
|
-
self,
|
169
|
-
orchestrator: BatchedPenalizerOrchestrator,
|
170
|
-
token_ids: Union[torch.Tensor, List[torch.Tensor]],
|
171
|
-
):
|
172
|
-
self.orchestrator = orchestrator
|
173
|
-
self.token_ids = token_ids
|
174
|
-
self.cached_counts = None
|
175
|
-
|
176
|
-
def occurrence_count(self) -> torch.Tensor:
|
177
|
-
"""
|
178
|
-
Returns a tensor of shape (batch_size, vocab_size) where each element is the number of times the corresponding token appears in the batch.
|
179
|
-
|
180
|
-
Returns:
|
181
|
-
torch.Tensor: The occurrence count tensor.
|
182
|
-
"""
|
183
|
-
if self.cached_counts is not None:
|
184
|
-
return self.cached_counts
|
185
|
-
|
186
|
-
token_ids = self.token_ids
|
187
|
-
|
188
|
-
if isinstance(token_ids, list):
|
189
|
-
# TODO: optimize this part
|
190
|
-
padded_token_ids = torch.nn.utils.rnn.pad_sequence(
|
191
|
-
sequences=token_ids,
|
192
|
-
batch_first=True,
|
193
|
-
padding_value=self.orchestrator.vocab_size,
|
194
|
-
)
|
195
|
-
self.cached_counts = torch.zeros(
|
196
|
-
size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size + 1),
|
197
|
-
dtype=torch.int64,
|
198
|
-
device=self.orchestrator.device,
|
199
|
-
).scatter_add_(
|
200
|
-
dim=1,
|
201
|
-
index=padded_token_ids,
|
202
|
-
src=torch.ones_like(padded_token_ids),
|
203
|
-
)[
|
204
|
-
:, : self.orchestrator.vocab_size
|
205
|
-
]
|
206
|
-
else:
|
207
|
-
# TODO: optimize this part. We do not need to create this big tensor every time.
|
208
|
-
# We can directly apply the results on the logits.
|
209
|
-
self.cached_counts = torch.zeros(
|
210
|
-
size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size),
|
211
|
-
device=self.orchestrator.device,
|
212
|
-
)
|
213
|
-
self.cached_counts[
|
214
|
-
torch.arange(len(token_ids), device=self.orchestrator.device), token_ids
|
215
|
-
] = 1
|
216
|
-
|
217
|
-
return self.cached_counts
|
97
|
+
self.is_required = True
|
98
|
+
for penalizer, their_penalizer in their.penalizers.items():
|
99
|
+
self.penalizers[penalizer].merge(their_penalizer)
|
218
100
|
|
219
101
|
|
220
102
|
class _BatchedPenalizer(abc.ABC):
|
@@ -222,10 +104,6 @@ class _BatchedPenalizer(abc.ABC):
|
|
222
104
|
An abstract class for a batched penalizer.
|
223
105
|
"""
|
224
106
|
|
225
|
-
def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
|
226
|
-
self.orchestrator = orchestrator
|
227
|
-
self._is_prepared = False
|
228
|
-
|
229
107
|
def is_prepared(self) -> bool:
|
230
108
|
return self._is_prepared
|
231
109
|
|
@@ -233,51 +111,40 @@ class _BatchedPenalizer(abc.ABC):
|
|
233
111
|
return self._is_required()
|
234
112
|
|
235
113
|
def prepare(self):
|
236
|
-
if not self.
|
114
|
+
if not self._is_prepared:
|
237
115
|
self._prepare()
|
238
116
|
self._is_prepared = True
|
239
117
|
|
240
118
|
def prepare_if_required(self):
|
241
|
-
if self.
|
119
|
+
if self._is_required():
|
242
120
|
self.prepare()
|
243
121
|
return True
|
244
122
|
else:
|
245
123
|
return False
|
246
124
|
|
247
125
|
def teardown(self):
|
248
|
-
|
249
|
-
self._teardown()
|
250
|
-
self._is_prepared = False
|
251
|
-
|
252
|
-
def cumulate_input_tokens(self, input_ids: _TokenIDs):
|
253
|
-
if not self.is_prepared():
|
254
|
-
return
|
255
|
-
|
256
|
-
self._cumulate_input_tokens(input_ids=input_ids)
|
126
|
+
self._is_prepared = False
|
257
127
|
|
258
|
-
def cumulate_output_tokens(self, output_ids:
|
259
|
-
if not self.
|
128
|
+
def cumulate_output_tokens(self, output_ids: torch.Tensor):
|
129
|
+
if not self._is_prepared:
|
260
130
|
return
|
261
131
|
|
262
132
|
self._cumulate_output_tokens(output_ids=output_ids)
|
263
133
|
|
264
134
|
def apply(self, logits: torch.Tensor) -> torch.Tensor:
|
265
|
-
if not self.
|
266
|
-
return
|
135
|
+
if not self._is_prepared:
|
136
|
+
return
|
267
137
|
|
268
|
-
|
138
|
+
self._apply(logits=logits)
|
269
139
|
|
270
|
-
def filter(self,
|
271
|
-
if not self.
|
140
|
+
def filter(self, keep_indices: torch.Tensor):
|
141
|
+
if not self._is_prepared:
|
272
142
|
return
|
273
143
|
|
274
|
-
self._filter(
|
275
|
-
indices_to_keep=indices_to_keep,
|
276
|
-
indices_tensor_to_keep=indices_tensor_to_keep,
|
277
|
-
)
|
144
|
+
self._filter(keep_indices=keep_indices)
|
278
145
|
|
279
146
|
def merge(self, their: "_BatchedPenalizer"):
|
280
|
-
if not self.
|
147
|
+
if not self._is_prepared and not their._is_prepared:
|
281
148
|
return
|
282
149
|
|
283
150
|
self.prepare()
|
@@ -300,23 +167,7 @@ class _BatchedPenalizer(abc.ABC):
|
|
300
167
|
pass
|
301
168
|
|
302
169
|
@abc.abstractmethod
|
303
|
-
def
|
304
|
-
"""
|
305
|
-
Tear down the penalizer.
|
306
|
-
Usually, this is where the penalizer frees its tensors.
|
307
|
-
"""
|
308
|
-
pass
|
309
|
-
|
310
|
-
@abc.abstractmethod
|
311
|
-
def _cumulate_input_tokens(self, input_ids: _TokenIDs):
|
312
|
-
"""
|
313
|
-
Cumulate the input tokens.
|
314
|
-
Orchestrator will call this function to feed the input tokens to the penalizer.
|
315
|
-
"""
|
316
|
-
pass
|
317
|
-
|
318
|
-
@abc.abstractmethod
|
319
|
-
def _cumulate_output_tokens(self, output_ids: _TokenIDs):
|
170
|
+
def _cumulate_output_tokens(self, output_ids: torch.Tensor):
|
320
171
|
"""
|
321
172
|
Cumulate the output tokens.
|
322
173
|
Orchestrator will call this function to feed the output tokens to the penalizer.
|
@@ -332,7 +183,7 @@ class _BatchedPenalizer(abc.ABC):
|
|
332
183
|
pass
|
333
184
|
|
334
185
|
@abc.abstractmethod
|
335
|
-
def _filter(self,
|
186
|
+
def _filter(self, keep_indices: torch.Tensor):
|
336
187
|
"""
|
337
188
|
Filter the penalizer (tensors or underlying data) based on the indices to keep in the batch.
|
338
189
|
"""
|
@@ -0,0 +1,66 @@
|
|
1
|
+
import torch
|
2
|
+
|
3
|
+
from sglang.srt.sampling.penaltylib.orchestrator import (
|
4
|
+
BatchedPenalizerOrchestrator,
|
5
|
+
_BatchedPenalizer,
|
6
|
+
)
|
7
|
+
|
8
|
+
|
9
|
+
class BatchedPresencePenalizer(_BatchedPenalizer):
    """
    Presence penalizer penalizes tokens based on their presence in the output.

    Once a request emits a token, that request's ``presence_penalty`` is written
    (not added) at the token's vocabulary column. Any token that has appeared at
    least once is therefore penalized by exactly ``presence_penalty``, however
    many times it has been emitted.
    """

    def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
        self.orchestrator = orchestrator
        self._is_prepared = False

    def _is_required(self) -> bool:
        """Return True if any request in the batch uses a non-zero presence penalty."""
        return any(
            req.sampling_params.presence_penalty != 0.0
            for req in self.orchestrator.reqs()
        )

    def _prepare(self):
        """Allocate the penalty table and the per-request penalty column."""
        # Penalty per (request, token), shape (num_reqs, vocab_size).
        self.cumulated_presence_penalties = torch.zeros(
            (len(self.orchestrator.reqs()), self.orchestrator.vocab_size),
            dtype=torch.float32,
            device=self.orchestrator.device,
        )

        # Per-request penalty as a (num_reqs, 1) column so it can be used as the
        # ``src`` of the row-wise scatter in ``_cumulate_output_tokens``.
        self.presence_penalties = torch.tensor(
            data=[
                req.sampling_params.presence_penalty
                for req in self.orchestrator.reqs()
            ],
            dtype=torch.float32,
            device=self.orchestrator.device,
        ).unsqueeze_(1)

    def _cumulate_output_tokens(self, output_ids: torch.Tensor):
        """
        Mark each request's newly emitted token with that request's penalty.

        Args:
            output_ids (torch.Tensor): 1-D tensor of token ids, presumably one
                per request in batch order (TODO confirm against caller).
        """
        # scatter_ overwrites, so repeated emissions do not grow the penalty.
        self.cumulated_presence_penalties.scatter_(
            dim=1,
            index=output_ids.unsqueeze(1),
            src=self.presence_penalties,
        )

    def _apply(self, logits: torch.Tensor) -> None:
        """Subtract the presence penalties from ``logits`` in place."""
        logits.sub_(self.cumulated_presence_penalties)

    def _filter(self, keep_indices: torch.Tensor):
        """Keep only the rows of the requests that remain in the batch."""
        self.presence_penalties = self.presence_penalties[keep_indices]
        self.cumulated_presence_penalties = self.cumulated_presence_penalties[
            keep_indices
        ]

    def _merge(self, their: "BatchedPresencePenalizer"):
        """Append ``their`` rows after this penalizer's rows (batch dimension)."""
        # A stray debug print of the tensor shapes used to live here; it wrote to
        # stdout on every batch merge and has been removed.
        self.presence_penalties = torch.cat(
            [self.presence_penalties, their.presence_penalties], dim=0
        )
        self.cumulated_presence_penalties = torch.cat(
            [self.cumulated_presence_penalties, their.cumulated_presence_penalties],
            dim=0,
        )
|