sglang 0.4.3.post1__py3-none-any.whl → 0.4.3.post3__py3-none-any.whl
This diff compares the contents of two publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their public registry.
- sglang/api.py +1 -1
- sglang/bench_offline_throughput.py +19 -0
- sglang/bench_one_batch.py +2 -2
- sglang/bench_serving.py +123 -79
- sglang/global_config.py +8 -3
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/lang/ir.py +1 -1
- sglang/srt/_custom_ops.py +83 -91
- sglang/srt/configs/load_config.py +4 -1
- sglang/srt/configs/model_config.py +48 -2
- sglang/srt/configs/qwen2_5_vl_config.py +5 -2
- sglang/srt/constrained/base_grammar_backend.py +117 -15
- sglang/srt/constrained/llguidance_backend.py +151 -0
- sglang/srt/constrained/outlines_backend.py +24 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -38
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
- sglang/srt/distributed/parallel_state.py +48 -3
- sglang/srt/entrypoints/engine.py +67 -9
- sglang/srt/entrypoints/http_server.py +190 -41
- sglang/srt/entrypoints/verl_engine.py +147 -0
- sglang/srt/function_call_parser.py +0 -1
- sglang/srt/layers/activation.py +11 -0
- sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
- sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
- sglang/srt/layers/attention/flashinfer_backend.py +208 -295
- sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
- sglang/srt/layers/attention/torch_native_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +9 -6
- sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
- sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
- sglang/srt/layers/attention/utils.py +39 -0
- sglang/srt/layers/attention/vision.py +60 -63
- sglang/srt/layers/dp_attention.py +142 -1
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/linear.py +3 -1
- sglang/srt/layers/logits_processor.py +281 -45
- sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
- sglang/srt/layers/moe/ep_moe/layer.py +140 -28
- sglang/srt/layers/moe/fused_moe_native.py +2 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
- sglang/srt/layers/moe/topk.py +13 -4
- sglang/srt/layers/quantization/__init__.py +111 -7
- sglang/srt/layers/quantization/blockwise_int8.py +409 -0
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/fp8.py +69 -28
- sglang/srt/layers/quantization/fp8_utils.py +17 -1
- sglang/srt/layers/quantization/gptq.py +416 -0
- sglang/srt/layers/quantization/int8_kernel.py +327 -0
- sglang/srt/layers/quantization/int8_utils.py +73 -0
- sglang/srt/layers/quantization/modelopt_quant.py +18 -1
- sglang/srt/layers/radix_attention.py +1 -0
- sglang/srt/layers/rotary_embedding.py +0 -1
- sglang/srt/layers/sampler.py +76 -31
- sglang/srt/layers/vocab_parallel_embedding.py +14 -13
- sglang/srt/lora/lora.py +17 -1
- sglang/srt/lora/lora_config.py +5 -0
- sglang/srt/lora/lora_manager.py +1 -3
- sglang/srt/managers/cache_controller.py +193 -62
- sglang/srt/managers/configure_logging.py +2 -1
- sglang/srt/managers/data_parallel_controller.py +6 -2
- sglang/srt/managers/detokenizer_manager.py +124 -102
- sglang/srt/managers/image_processor.py +2 -1
- sglang/srt/managers/io_struct.py +143 -6
- sglang/srt/managers/schedule_batch.py +238 -197
- sglang/srt/managers/schedule_policy.py +29 -29
- sglang/srt/managers/scheduler.py +681 -259
- sglang/srt/managers/session_controller.py +6 -2
- sglang/srt/managers/tokenizer_manager.py +224 -68
- sglang/srt/managers/tp_worker.py +15 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
- sglang/srt/mem_cache/chunk_cache.py +18 -11
- sglang/srt/mem_cache/hiradix_cache.py +394 -0
- sglang/srt/mem_cache/memory_pool.py +44 -18
- sglang/srt/mem_cache/radix_cache.py +58 -47
- sglang/srt/metrics/collector.py +94 -36
- sglang/srt/model_executor/cuda_graph_runner.py +55 -24
- sglang/srt/model_executor/forward_batch_info.py +49 -16
- sglang/srt/model_executor/model_runner.py +209 -28
- sglang/srt/model_loader/loader.py +3 -3
- sglang/srt/model_loader/weight_utils.py +36 -14
- sglang/srt/models/baichuan.py +31 -6
- sglang/srt/models/chatglm.py +39 -7
- sglang/srt/models/commandr.py +29 -5
- sglang/srt/models/dbrx.py +31 -5
- sglang/srt/models/deepseek.py +43 -6
- sglang/srt/models/deepseek_nextn.py +32 -19
- sglang/srt/models/deepseek_v2.py +265 -29
- sglang/srt/models/exaone.py +19 -9
- sglang/srt/models/gemma.py +22 -8
- sglang/srt/models/gemma2.py +25 -12
- sglang/srt/models/gemma2_reward.py +5 -1
- sglang/srt/models/gpt2.py +28 -13
- sglang/srt/models/gpt_bigcode.py +27 -5
- sglang/srt/models/granite.py +21 -9
- sglang/srt/models/grok.py +21 -4
- sglang/srt/models/internlm2.py +36 -6
- sglang/srt/models/internlm2_reward.py +5 -1
- sglang/srt/models/llama.py +26 -9
- sglang/srt/models/llama_classification.py +5 -1
- sglang/srt/models/llama_eagle.py +17 -4
- sglang/srt/models/llama_embedding.py +5 -1
- sglang/srt/models/llama_reward.py +7 -2
- sglang/srt/models/llava.py +19 -3
- sglang/srt/models/llavavid.py +10 -1
- sglang/srt/models/minicpm.py +26 -2
- sglang/srt/models/minicpm3.py +39 -3
- sglang/srt/models/minicpmv.py +45 -14
- sglang/srt/models/mixtral.py +20 -9
- sglang/srt/models/mixtral_quant.py +50 -8
- sglang/srt/models/mllama.py +57 -11
- sglang/srt/models/olmo.py +34 -6
- sglang/srt/models/olmo2.py +34 -13
- sglang/srt/models/olmoe.py +26 -4
- sglang/srt/models/phi3_small.py +29 -10
- sglang/srt/models/qwen.py +26 -3
- sglang/srt/models/qwen2.py +26 -4
- sglang/srt/models/qwen2_5_vl.py +46 -8
- sglang/srt/models/qwen2_eagle.py +17 -5
- sglang/srt/models/qwen2_moe.py +44 -6
- sglang/srt/models/qwen2_rm.py +78 -0
- sglang/srt/models/qwen2_vl.py +39 -8
- sglang/srt/models/stablelm.py +32 -5
- sglang/srt/models/torch_native_llama.py +5 -2
- sglang/srt/models/xverse.py +21 -9
- sglang/srt/models/xverse_moe.py +45 -7
- sglang/srt/models/yivl.py +2 -1
- sglang/srt/openai_api/adapter.py +109 -24
- sglang/srt/openai_api/protocol.py +17 -1
- sglang/srt/reasoning_parser.py +154 -0
- sglang/srt/sampling/penaltylib/__init__.py +4 -6
- sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
- sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
- sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
- sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
- sglang/srt/sampling/sampling_batch_info.py +79 -157
- sglang/srt/sampling/sampling_params.py +16 -13
- sglang/srt/server_args.py +136 -52
- sglang/srt/speculative/build_eagle_tree.py +2 -8
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +0 -1
- sglang/srt/speculative/eagle_utils.py +92 -58
- sglang/srt/speculative/eagle_worker.py +186 -94
- sglang/srt/speculative/spec_info.py +1 -13
- sglang/srt/utils.py +43 -17
- sglang/srt/warmup.py +47 -0
- sglang/test/few_shot_gsm8k.py +4 -1
- sglang/test/runners.py +389 -126
- sglang/test/send_one.py +88 -0
- sglang/test/test_block_fp8_ep.py +361 -0
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +138 -84
- sglang/utils.py +50 -60
- sglang/version.py +1 -1
- {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/METADATA +21 -15
- {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/RECORD +214 -166
- {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/WHEEL +1 -1
- sglang/bench_latency.py +0 -1
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
- sglang/test/srt/sampling/penaltylib/utils.py +0 -344
- {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/top_level.txt +0 -0
sglang/bench_latency.py
DELETED
@@ -1 +0,0 @@
-raise ValueError("bench_latency.py has been renamed to bench_one_batch.py")
sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py
DELETED
@@ -1,75 +0,0 @@
-from typing import List
-
-import torch
-
-from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
-
-
-class BatchedFrequencyPenalizer(_BatchedPenalizer):
-    """
-    Frequency penalizer penalizes tokens based on their frequency in the output.
-    """
-
-    frequency_penalties: torch.Tensor = None
-    cumulated_frequency_penalties: torch.Tensor = None
-
-    def _is_required(self) -> bool:
-        return any(
-            req.sampling_params.frequency_penalty != 0.0
-            for req in self.orchestrator.reqs()
-        )
-
-    def _prepare(self):
-        self.cumulated_frequency_penalties = (
-            torch.tensor(
-                data=[0.0 for _ in self.orchestrator.reqs()],
-                dtype=torch.float32,
-                device=self.orchestrator.device,
-            )
-            .unsqueeze_(1)
-            .repeat(1, self.orchestrator.vocab_size)
-        )
-
-        self.frequency_penalties = (
-            torch.tensor(
-                data=[
-                    req.sampling_params.frequency_penalty
-                    for req in self.orchestrator.reqs()
-                ],
-                dtype=torch.float32,
-                device=self.orchestrator.device,
-            )
-            .unsqueeze_(1)
-            .expand_as(self.cumulated_frequency_penalties)
-        )
-
-    def _teardown(self):
-        self.frequency_penalties = None
-        self.cumulated_frequency_penalties = None
-
-    def _cumulate_input_tokens(self, input_ids: _TokenIDs):
-        pass
-
-    def _cumulate_output_tokens(self, output_ids: _TokenIDs):
-        self.cumulated_frequency_penalties += (
-            self.frequency_penalties * output_ids.occurrence_count()
-        )
-
-    def _apply(self, logits: torch.Tensor) -> torch.Tensor:
-        logits -= self.cumulated_frequency_penalties
-        return logits
-
-    def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
-        self.frequency_penalties = self.frequency_penalties[indices_tensor_to_keep]
-        self.cumulated_frequency_penalties = self.cumulated_frequency_penalties[
-            indices_tensor_to_keep
-        ]
-
-    def _merge(self, their: "BatchedFrequencyPenalizer"):
-        self.frequency_penalties = torch.cat(
-            [self.frequency_penalties, their.frequency_penalties], dim=0
-        )
-        self.cumulated_frequency_penalties = torch.cat(
-            [self.cumulated_frequency_penalties, their.cumulated_frequency_penalties],
-            dim=0,
-        )
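The removed BatchedFrequencyPenalizer subtracts an accumulated penalty of frequency_penalty * occurrence_count from each token's logit. Below is a minimal standalone sketch of that arithmetic with toy tensors, outside the orchestrator machinery; the variable names and values are illustrative, not part of the sglang API.

```python
import torch

# Two requests over a vocab of 5 tokens, with frequency penalties 0.5 and 1.0.
frequency_penalties = torch.tensor([[0.5], [1.0]])   # shape (batch, 1)
cumulated = torch.zeros(2, 5)                        # running penalty per token

# Output-token counts seen so far; the update is linear, so accumulating
# step by step or all at once gives the same result.
occurrence_count = torch.tensor([[2.0, 0.0, 1.0, 0.0, 0.0],
                                 [0.0, 3.0, 0.0, 0.0, 1.0]])

cumulated += frequency_penalties * occurrence_count  # the _cumulate_output_tokens update
logits = torch.ones(2, 5) - cumulated                # the _apply step: logits -= cumulated
print(logits)
# tensor([[ 0.0000,  1.0000,  0.5000,  1.0000,  1.0000],
#         [ 1.0000, -2.0000,  1.0000,  1.0000,  0.0000]])
```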
sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py
DELETED
@@ -1,74 +0,0 @@
-from typing import List
-
-import torch
-
-from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
-
-
-class BatchedPresencePenalizer(_BatchedPenalizer):
-    """
-    Presence penalizer penalizes tokens based on their presence in the output.
-    """
-
-    presence_penalties: torch.Tensor = None
-    cumulated_presence_penalties: torch.Tensor = None
-
-    def _is_required(self) -> bool:
-        return any(
-            req.sampling_params.presence_penalty != 0.0
-            for req in self.orchestrator.reqs()
-        )
-
-    def _prepare(self):
-        self.cumulated_presence_penalties = (
-            torch.tensor(
-                data=[0.0 for _ in self.orchestrator.reqs()],
-                dtype=torch.float32,
-                device=self.orchestrator.device,
-            )
-            .unsqueeze_(1)
-            .repeat(1, self.orchestrator.vocab_size)
-        )
-
-        self.presence_penalties = (
-            torch.tensor(
-                data=[
-                    req.sampling_params.presence_penalty
-                    for req in self.orchestrator.reqs()
-                ],
-                dtype=torch.float32,
-                device=self.orchestrator.device,
-            )
-            .unsqueeze_(1)
-            .expand_as(self.cumulated_presence_penalties)
-        )
-
-    def _teardown(self):
-        self.presence_penalties = None
-        self.cumulated_presence_penalties = None
-
-    def _cumulate_input_tokens(self, input_ids: _TokenIDs):
-        pass
-
-    def _cumulate_output_tokens(self, output_ids: _TokenIDs):
-        mask = output_ids.occurrence_count() > 0
-        self.cumulated_presence_penalties[mask] = self.presence_penalties[mask]
-
-    def _apply(self, logits: torch.Tensor) -> torch.Tensor:
-        logits -= self.cumulated_presence_penalties
-        return logits
-
-    def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
-        self.presence_penalties = self.presence_penalties[indices_tensor_to_keep]
-        self.cumulated_presence_penalties = self.cumulated_presence_penalties[
-            indices_tensor_to_keep
-        ]
-
-    def _merge(self, their: "BatchedPresencePenalizer"):
-        self.presence_penalties = torch.cat(
-            [self.presence_penalties, their.presence_penalties], dim=0
-        )
-        self.cumulated_presence_penalties = torch.cat(
-            [self.cumulated_presence_penalties, their.cumulated_presence_penalties],
-            dim=0,
-        )
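BatchedPresencePenalizer differs from the frequency penalizer only in that the penalty is flat: any token that has appeared in the output gets the full presence_penalty exactly once, regardless of how many times it occurred. A toy sketch of that masking step follows; values are illustrative, not the sglang API.

```python
import torch

# Two requests over a vocab of 5 tokens, with presence penalties 0.4 and 1.5.
presence_penalties = torch.tensor([[0.4], [1.5]]).repeat(1, 5)  # (batch, vocab)
cumulated = torch.zeros(2, 5)

occurrence_count = torch.tensor([[2, 0, 1, 0, 0],
                                 [0, 3, 0, 0, 1]])

mask = occurrence_count > 0                  # presence is binary
cumulated[mask] = presence_penalties[mask]   # applied once, not per occurrence

logits = torch.ones(2, 5) - cumulated
print(logits)
# tensor([[ 0.6000,  1.0000,  0.6000,  1.0000,  1.0000],
#         [ 1.0000, -0.5000,  1.0000,  1.0000, -0.5000]])
```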
sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py
DELETED
@@ -1,85 +0,0 @@
-from typing import List
-
-import torch
-
-from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
-from sglang.srt.utils import get_compiler_backend
-
-
-@torch.compile(dynamic=True, backend=get_compiler_backend())
-def apply_scaling_penalties(logits, scaling_penalties):
-    logits[:] = torch.where(
-        logits > 0,
-        logits / scaling_penalties,
-        logits * scaling_penalties,
-    )
-
-
-class BatchedRepetitionPenalizer(_BatchedPenalizer):
-    """
-    Repetition penalizer penalizes tokens based on their repetition in the input and output.
-    """
-
-    repetition_penalties: torch.Tensor = None
-    cumulated_repetition_penalties: torch.Tensor = None
-
-    def _is_required(self) -> bool:
-        return any(
-            req.sampling_params.repetition_penalty != 1.0
-            for req in self.orchestrator.reqs()
-        )
-
-    def _prepare(self):
-        self.cumulated_repetition_penalties = (
-            torch.tensor(
-                data=[1.0 for _ in self.orchestrator.reqs()],
-                dtype=torch.float32,
-                device=self.orchestrator.device,
-            )
-            .unsqueeze_(1)
-            .repeat(1, self.orchestrator.vocab_size)
-        )
-
-        self.repetition_penalties = (
-            torch.tensor(
-                data=[
-                    req.sampling_params.repetition_penalty
-                    for req in self.orchestrator.reqs()
-                ],
-                dtype=torch.float32,
-                device=self.orchestrator.device,
-            )
-            .unsqueeze_(1)
-            .expand_as(self.cumulated_repetition_penalties)
-        )
-
-    def _teardown(self):
-        self.repetition_penalties = None
-        self.cumulated_repetition_penalties = None
-
-    def _cumulate_input_tokens(self, input_ids: _TokenIDs):
-        mask = input_ids.occurrence_count() > 0
-        self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask]
-
-    def _cumulate_output_tokens(self, output_ids: _TokenIDs):
-        mask = output_ids.occurrence_count() > 0
-        self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask]
-
-    def _apply(self, logits: torch.Tensor) -> torch.Tensor:
-        apply_scaling_penalties(logits, self.cumulated_repetition_penalties)
-        return logits
-
-    def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
-        self.repetition_penalties = self.repetition_penalties[indices_tensor_to_keep]
-        self.cumulated_repetition_penalties = self.cumulated_repetition_penalties[
-            indices_tensor_to_keep
-        ]
-
-    def _merge(self, their: "BatchedRepetitionPenalizer"):
-        self.repetition_penalties = torch.cat(
-            [self.repetition_penalties, their.repetition_penalties], dim=0
-        )
-        self.cumulated_repetition_penalties = torch.cat(
-            [self.cumulated_repetition_penalties, their.cumulated_repetition_penalties],
-            dim=0,
-        )
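The repetition penalizer rescales rather than subtracts: positive logits of previously seen tokens are divided by the penalty and negative ones are multiplied by it, so both move toward "less likely". A standalone sketch of the same torch.where rule used by apply_scaling_penalties, with toy values and without torch.compile:

```python
import torch

repetition_penalty = 1.2
logits = torch.tensor([[2.0, -1.0, 0.5, -3.0]])

# Tokens 0, 2, and 3 appeared in the prompt or output, so they carry the
# penalty; unseen token 1 keeps a neutral scale of 1.0.
scaling = torch.tensor([[1.2, 1.0, 1.2, 1.2]])

penalized = torch.where(logits > 0, logits / scaling, logits * scaling)
print(penalized)
# tensor([[ 1.6667, -1.0000,  0.4167, -3.6000]])
```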
sglang/test/srt/sampling/penaltylib/utils.py
DELETED
@@ -1,344 +0,0 @@
-import dataclasses
-import enum
-import unittest
-from typing import Dict, List, Optional, Set, Tuple, Type
-
-import torch
-
-from sglang.srt.sampling.penaltylib.orchestrator import (
-    BatchedPenalizerOrchestrator,
-    _BatchedPenalizer,
-    _BatchLike,
-)
-
-
-@dataclasses.dataclass
-class MockSamplingParams:
-    frequency_penalty: float = 0.0
-    min_new_tokens: int = 0
-    stop_token_ids: List[int] = None
-    presence_penalty: float = 0.0
-    repetition_penalty: float = 1.0
-
-
-@dataclasses.dataclass
-class MockTokenizer:
-    eos_token_id: int
-    additional_stop_token_ids: Optional[List[int]] = None
-
-
-@dataclasses.dataclass
-class MockReq:
-    origin_input_ids: List[int]
-    sampling_params: MockSamplingParams
-    tokenizer: MockTokenizer
-
-
-class StepType(enum.Enum):
-    INPUT = "input"
-    OUTPUT = "output"
-
-
-@dataclasses.dataclass
-class Step:
-    type: StepType
-    token_ids: List[int]
-    expected_tensors: Dict[str, torch.Tensor]
-    # assume initial logits are all 1
-    expected_logits: torch.Tensor
-
-
-@dataclasses.dataclass
-class Subject:
-    sampling_params: MockSamplingParams
-    # first step must be input, which will be converted to Req
-    steps: List[Step]
-    eos_token_id: int = -1
-
-    def __post_init__(self):
-        if self.steps[0].type != StepType.INPUT:
-            raise ValueError("First step must be input")
-
-        # each steps should have the same expected_tensors.keys()
-        for i in range(1, len(self.steps)):
-            if self.tensor_keys(i) != self.tensor_keys():
-                raise ValueError(
-                    f"Expected tensors keys must be the same for all steps. Got {self.steps[i].expected_tensors.keys()} for key={i} and {self.steps[0].expected_tensors.keys()}"
-                )
-
-    def tensor_keys(self, i: int = 0) -> Set[str]:
-        return set(self.steps[i].expected_tensors.keys())
-
-    def to_req(self) -> MockReq:
-        return MockReq(
-            origin_input_ids=self.steps[0].token_ids,
-            sampling_params=self.sampling_params,
-            tokenizer=MockTokenizer(eos_token_id=self.eos_token_id),
-        )
-
-
-@dataclasses.dataclass
-class Case:
-    enabled: bool
-    test_subjects: List[Subject]
-
-    def __post_init__(self):
-        # each test_subjects.steps should have the same expected_tensors.keys()
-        for i in range(1, len(self.test_subjects)):
-            if self.tensor_keys(i) != self.tensor_keys():
-                raise ValueError(
-                    f"Expected tensors keys must be the same for all test_subjects. Got {self.test_subjects[i].tensor_keys()} for key={i} and {self.test_subjects[0].tensor_keys()}"
-                )
-
-    def tensor_keys(self, i: int = 0) -> List[str]:
-        return set(self.test_subjects[i].tensor_keys())
-
-
-class BaseBatchedPenalizerTest(unittest.TestCase):
-    Penalizer: Type[_BatchedPenalizer]
-    device = "cuda"
-    vocab_size = 5
-
-    enabled: Subject = None
-    disabled: Subject = None
-
-    def setUp(self):
-        if self.__class__ == BaseBatchedPenalizerTest:
-            self.skipTest("Base class for penalizer tests")
-
-        self.create_test_subjects()
-        self.create_test_cases()
-
-    def tensor(self, data, **kwargs) -> torch.Tensor:
-        """
-        Shortcut to create a tensor with device=self.device.
-        """
-        return torch.tensor(data, **kwargs, device=self.device)
-
-    def create_test_subjects(self) -> List[Subject]:
-        raise NotImplementedError()
-
-    def create_test_cases(self):
-        self.test_cases = [
-            Case(enabled=True, test_subjects=[self.enabled]),
-            Case(enabled=False, test_subjects=[self.disabled]),
-            Case(enabled=True, test_subjects=[self.enabled, self.disabled]),
-        ]
-
-    def _create_penalizer(
-        self, case: Case
-    ) -> Tuple[BatchedPenalizerOrchestrator, _BatchedPenalizer]:
-        orchestrator = BatchedPenalizerOrchestrator(
-            vocab_size=self.vocab_size,
-            batch=_BatchLike(reqs=[subject.to_req() for subject in case.test_subjects]),
-            device=self.device,
-            Penalizers={self.Penalizer},
-        )
-
-        return orchestrator, orchestrator.penalizers[self.Penalizer]
-
-    def test_is_required(self):
-        for case in self.test_cases:
-            with self.subTest(case=case):
-                _, penalizer = self._create_penalizer(case)
-                self.assertEqual(case.enabled, penalizer.is_required())
-
-    def test_prepare(self):
-        for case in self.test_cases:
-            with self.subTest(case=case):
-                orchestrator, penalizer = self._create_penalizer(case)
-                self.assertEqual(case.enabled, penalizer.is_prepared())
-
-                if case.enabled:
-                    for key, tensor in {
-                        key: torch.cat(
-                            tensors=[
-                                subject.steps[0].expected_tensors[key]
-                                for subject in case.test_subjects
-                            ],
-                        )
-                        for key in case.tensor_keys()
-                    }.items():
-                        torch.testing.assert_close(
-                            actual=getattr(penalizer, key),
-                            expected=tensor,
-                            msg=f"key={key}\nactual={getattr(penalizer, key)}\nexpected={tensor}",
-                        )
-
-                original = torch.ones(
-                    size=(len(case.test_subjects), self.vocab_size),
-                    dtype=torch.float32,
-                    device=self.device,
-                )
-                actual = orchestrator.apply(original.clone())
-                expected = torch.cat(
-                    tensors=[
-                        subject.steps[0].expected_logits
-                        for subject in case.test_subjects
-                    ],
-                )
-                if actual is None:
-                    actual = original
-                torch.testing.assert_close(
-                    actual=actual,
-                    expected=expected,
-                    msg=f"logits\nactual={actual}\nexpected={expected}",
-                )
-
-    def test_teardown(self):
-        for case in self.test_cases:
-            with self.subTest(case=case):
-                _, penalizer = self._create_penalizer(case)
-                penalizer.teardown()
-
-                for key in case.test_subjects[0].steps[0].expected_tensors.keys():
-                    self.assertIsNone(getattr(penalizer, key, None))
-
-    def test_filter(self):
-        for case in self.test_cases:
-            with self.subTest(case=case):
-                orchestrator, penalizer = self._create_penalizer(case)
-
-                indices_to_keep = [0]
-                orchestrator.filter(indices_to_keep=indices_to_keep)
-
-                filtered_subjects = [case.test_subjects[i] for i in indices_to_keep]
-
-                if penalizer.is_required():
-                    self.assertTrue(penalizer.is_prepared())
-                    for key, tensor in {
-                        key: torch.cat(
-                            tensors=[
-                                subject.steps[0].expected_tensors[key]
-                                for subject in filtered_subjects
-                            ],
-                        )
-                        for key in case.tensor_keys()
-                    }.items():
-                        torch.testing.assert_close(
-                            actual=getattr(penalizer, key),
-                            expected=tensor,
-                            msg=f"key={key}\nactual={getattr(penalizer, key)}\nexpected={tensor}",
-                        )
-
-                actual_logits = orchestrator.apply(
-                    torch.ones(
-                        size=(len(filtered_subjects), self.vocab_size),
-                        dtype=torch.float32,
-                        device=self.device,
-                    )
-                )
-                if actual_logits is None:
-                    continue
-                filtered_expected_logits = torch.cat(
-                    tensors=[
-                        subject.steps[0].expected_logits
-                        for subject in filtered_subjects
-                    ],
-                )
-                torch.testing.assert_close(
-                    actual=actual_logits,
-                    expected=filtered_expected_logits,
-                    msg=f"logits\nactual={actual_logits}\nexpected={filtered_expected_logits}",
-                )
-
-    def test_merge_enabled_with_disabled(self):
-        enabled_test_case = self.test_cases[0]
-        disabled_test_case = self.test_cases[1]
-
-        orchestrator, penalizer = self._create_penalizer(enabled_test_case)
-        theirs, _ = self._create_penalizer(disabled_test_case)
-
-        orchestrator.merge(theirs)
-
-        for key, tensor in {
-            key: torch.cat(
-                tensors=[
-                    enabled_test_case.test_subjects[0].steps[0].expected_tensors[key],
-                    disabled_test_case.test_subjects[0].steps[0].expected_tensors[key],
-                ],
-            )
-            for key in enabled_test_case.tensor_keys()
-        }.items():
-            torch.testing.assert_close(
-                actual=getattr(penalizer, key),
-                expected=tensor,
-                msg=f"key={key}\nactual={getattr(penalizer, key)}\nexpected={tensor}",
-            )
-
-    def test_cumulate_apply_repeat(self):
-        for case in self.test_cases:
-            with self.subTest(case=case):
-                orchestrator, penalizer = self._create_penalizer(case)
-
-                max_step = max(len(subject.steps) for subject in case.test_subjects)
-                for i in range(1, max_step):
-                    orchestrator.filter(
-                        indices_to_keep=[
-                            j
-                            for j, subject in enumerate(case.test_subjects)
-                            if i < len(subject.steps)
-                        ]
-                    )
-
-                    filtered_subjects = [
-                        subject
-                        for subject in case.test_subjects
-                        if i < len(subject.steps)
-                    ]
-
-                    inputs: List[List[int]] = []
-                    outputs: List[List[int]] = []
-                    for subject in filtered_subjects:
-                        step = subject.steps[i]
-                        if step.type == StepType.INPUT:
-                            raise NotImplementedError()
-                        else:
-                            inputs.append([])
-                            outputs.append(step.token_ids)
-
-                    if any(outputs):
-                        for j in range(max(len(x) for x in outputs)):
-                            tmp_outputs = torch.tensor(
-                                [x[j] for x in outputs],
-                                dtype=torch.int32,
-                                device=orchestrator.device,
-                            )
-                            orchestrator.cumulate_output_tokens(tmp_outputs)
-
-                    if penalizer.is_required():
-                        self.assertTrue(penalizer.is_prepared())
-                        for key, tensor in {
-                            key: torch.cat(
-                                tensors=[
-                                    subject.steps[i].expected_tensors[key]
-                                    for subject in filtered_subjects
-                                ],
-                            )
-                            for key in case.tensor_keys()
-                        }.items():
-                            torch.testing.assert_close(
-                                actual=getattr(penalizer, key),
-                                expected=tensor,
-                                msg=f"key={key}\nactual={getattr(penalizer, key)}\nexpected={tensor}",
-                            )
-
-                    original = torch.ones(
-                        size=(len(filtered_subjects), self.vocab_size),
-                        dtype=torch.float32,
-                        device=self.device,
-                    )
-                    actual_logits = orchestrator.apply(original.clone())
-                    filtered_expected_logits = torch.cat(
-                        tensors=[
-                            subject.steps[i].expected_logits
-                            for subject in filtered_subjects
-                        ],
-                    )
-                    if actual_logits is None:
-                        actual_logits = original
-                    torch.testing.assert_close(
-                        actual=actual_logits,
-                        expected=filtered_expected_logits,
-                        msg=f"logits\nactual={actual_logits}\nexpected={filtered_expected_logits}",
-                    )
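The deleted harness above was meant to be subclassed once per penalizer: a concrete test sets Penalizer, then builds an enabled and a disabled Subject whose expected_tensors keys match the penalizer's attribute names. A hypothetical subclass for the (also removed) frequency penalizer, with illustrative token IDs and expected values against the 0.4.3.post1 layout, might have looked like the sketch below; it is not a reproduction of the real test file.

```python
import unittest

import torch

# Old (pre-0.4.3.post3) module paths, since both files are removed in this release.
from sglang.srt.sampling.penaltylib.penalizers.frequency_penalty import (
    BatchedFrequencyPenalizer,
)
from sglang.test.srt.sampling.penaltylib.utils import (
    BaseBatchedPenalizerTest,
    MockSamplingParams,
    Step,
    StepType,
    Subject,
)


class TestBatchedFrequencyPenalizer(BaseBatchedPenalizerTest):
    Penalizer = BatchedFrequencyPenalizer

    def create_test_subjects(self):
        penalty = 0.5
        ones = self.tensor([[1.0] * self.vocab_size], dtype=torch.float32)
        zeros = self.tensor([[0.0] * self.vocab_size], dtype=torch.float32)

        # Request with a non-zero penalty: one prompt step, then one output step
        # that emits token 1 and is expected to lower only that logit.
        self.enabled = Subject(
            sampling_params=MockSamplingParams(frequency_penalty=penalty),
            steps=[
                Step(
                    type=StepType.INPUT,
                    token_ids=[0, 1, 2],
                    expected_tensors={
                        "frequency_penalties": penalty * ones,
                        "cumulated_frequency_penalties": zeros,
                    },
                    expected_logits=ones,  # prompt tokens are not penalized
                ),
                Step(
                    type=StepType.OUTPUT,
                    token_ids=[1],
                    expected_tensors={
                        "frequency_penalties": penalty * ones,
                        "cumulated_frequency_penalties": self.tensor(
                            [[0.0, penalty, 0.0, 0.0, 0.0]], dtype=torch.float32
                        ),
                    },
                    expected_logits=self.tensor(
                        [[1.0, 1.0 - penalty, 1.0, 1.0, 1.0]], dtype=torch.float32
                    ),
                ),
            ],
        )

        # Request with penalty 0.0: the penalizer must leave its logits untouched.
        self.disabled = Subject(
            sampling_params=MockSamplingParams(frequency_penalty=0.0),
            steps=[
                Step(
                    type=StepType.INPUT,
                    token_ids=[0, 1, 2],
                    expected_tensors={
                        "frequency_penalties": zeros,
                        "cumulated_frequency_penalties": zeros,
                    },
                    expected_logits=ones,
                ),
            ],
        )


if __name__ == "__main__":
    unittest.main()
```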
{sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/LICENSE
File without changes
{sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/top_level.txt
File without changes