sglang 0.4.3.post2__py3-none-any.whl → 0.4.3.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/api.py +1 -1
- sglang/bench_offline_throughput.py +19 -0
- sglang/bench_one_batch.py +2 -2
- sglang/bench_serving.py +123 -79
- sglang/global_config.py +8 -3
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/lang/ir.py +1 -1
- sglang/srt/_custom_ops.py +83 -91
- sglang/srt/configs/load_config.py +4 -1
- sglang/srt/configs/model_config.py +48 -2
- sglang/srt/configs/qwen2_5_vl_config.py +5 -2
- sglang/srt/constrained/base_grammar_backend.py +117 -15
- sglang/srt/constrained/llguidance_backend.py +151 -0
- sglang/srt/constrained/outlines_backend.py +24 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -38
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
- sglang/srt/distributed/parallel_state.py +48 -3
- sglang/srt/entrypoints/engine.py +67 -9
- sglang/srt/entrypoints/http_server.py +190 -41
- sglang/srt/entrypoints/verl_engine.py +147 -0
- sglang/srt/function_call_parser.py +0 -1
- sglang/srt/layers/activation.py +11 -0
- sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
- sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
- sglang/srt/layers/attention/flashinfer_backend.py +302 -414
- sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
- sglang/srt/layers/attention/torch_native_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +13 -8
- sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
- sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
- sglang/srt/layers/attention/utils.py +39 -0
- sglang/srt/layers/attention/vision.py +60 -63
- sglang/srt/layers/dp_attention.py +142 -1
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/linear.py +3 -1
- sglang/srt/layers/logits_processor.py +281 -45
- sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
- sglang/srt/layers/moe/ep_moe/layer.py +140 -28
- sglang/srt/layers/moe/fused_moe_native.py +2 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
- sglang/srt/layers/moe/topk.py +13 -4
- sglang/srt/layers/quantization/__init__.py +111 -7
- sglang/srt/layers/quantization/blockwise_int8.py +409 -0
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8.py +69 -28
- sglang/srt/layers/quantization/fp8_utils.py +17 -1
- sglang/srt/layers/quantization/gptq.py +416 -0
- sglang/srt/layers/quantization/int8_kernel.py +327 -0
- sglang/srt/layers/quantization/int8_utils.py +73 -0
- sglang/srt/layers/quantization/modelopt_quant.py +18 -1
- sglang/srt/layers/radix_attention.py +1 -0
- sglang/srt/layers/rotary_embedding.py +0 -1
- sglang/srt/layers/sampler.py +76 -31
- sglang/srt/layers/vocab_parallel_embedding.py +14 -13
- sglang/srt/lora/lora.py +17 -1
- sglang/srt/lora/lora_config.py +5 -0
- sglang/srt/lora/lora_manager.py +1 -3
- sglang/srt/managers/cache_controller.py +193 -62
- sglang/srt/managers/configure_logging.py +2 -1
- sglang/srt/managers/data_parallel_controller.py +6 -2
- sglang/srt/managers/detokenizer_manager.py +124 -102
- sglang/srt/managers/image_processor.py +2 -1
- sglang/srt/managers/io_struct.py +144 -6
- sglang/srt/managers/schedule_batch.py +237 -197
- sglang/srt/managers/schedule_policy.py +29 -29
- sglang/srt/managers/scheduler.py +773 -334
- sglang/srt/managers/session_controller.py +6 -2
- sglang/srt/managers/tokenizer_manager.py +225 -68
- sglang/srt/managers/tp_worker.py +15 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
- sglang/srt/mem_cache/chunk_cache.py +18 -11
- sglang/srt/mem_cache/hiradix_cache.py +394 -0
- sglang/srt/mem_cache/memory_pool.py +68 -37
- sglang/srt/mem_cache/radix_cache.py +58 -47
- sglang/srt/metrics/collector.py +102 -36
- sglang/srt/model_executor/cuda_graph_runner.py +56 -31
- sglang/srt/model_executor/forward_batch_info.py +49 -16
- sglang/srt/model_executor/model_runner.py +280 -81
- sglang/srt/model_loader/loader.py +3 -3
- sglang/srt/model_loader/weight_utils.py +36 -14
- sglang/srt/models/baichuan.py +31 -6
- sglang/srt/models/chatglm.py +39 -7
- sglang/srt/models/commandr.py +29 -5
- sglang/srt/models/dbrx.py +31 -5
- sglang/srt/models/deepseek.py +43 -6
- sglang/srt/models/deepseek_nextn.py +32 -19
- sglang/srt/models/deepseek_v2.py +265 -32
- sglang/srt/models/exaone.py +19 -9
- sglang/srt/models/gemma.py +22 -8
- sglang/srt/models/gemma2.py +25 -12
- sglang/srt/models/gemma2_reward.py +5 -1
- sglang/srt/models/gpt2.py +28 -13
- sglang/srt/models/gpt_bigcode.py +27 -5
- sglang/srt/models/granite.py +21 -9
- sglang/srt/models/grok.py +21 -4
- sglang/srt/models/internlm2.py +36 -6
- sglang/srt/models/internlm2_reward.py +5 -1
- sglang/srt/models/llama.py +26 -9
- sglang/srt/models/llama_classification.py +5 -1
- sglang/srt/models/llama_eagle.py +17 -4
- sglang/srt/models/llama_embedding.py +5 -1
- sglang/srt/models/llama_reward.py +7 -2
- sglang/srt/models/llava.py +19 -3
- sglang/srt/models/llavavid.py +10 -1
- sglang/srt/models/minicpm.py +26 -2
- sglang/srt/models/minicpm3.py +39 -3
- sglang/srt/models/minicpmv.py +45 -14
- sglang/srt/models/mixtral.py +20 -9
- sglang/srt/models/mixtral_quant.py +50 -8
- sglang/srt/models/mllama.py +57 -11
- sglang/srt/models/olmo.py +34 -6
- sglang/srt/models/olmo2.py +34 -13
- sglang/srt/models/olmoe.py +26 -4
- sglang/srt/models/phi3_small.py +29 -10
- sglang/srt/models/qwen.py +26 -3
- sglang/srt/models/qwen2.py +26 -4
- sglang/srt/models/qwen2_5_vl.py +46 -8
- sglang/srt/models/qwen2_eagle.py +17 -5
- sglang/srt/models/qwen2_moe.py +44 -6
- sglang/srt/models/qwen2_rm.py +78 -0
- sglang/srt/models/qwen2_vl.py +39 -8
- sglang/srt/models/stablelm.py +32 -5
- sglang/srt/models/torch_native_llama.py +5 -2
- sglang/srt/models/xverse.py +21 -9
- sglang/srt/models/xverse_moe.py +45 -7
- sglang/srt/models/yivl.py +2 -1
- sglang/srt/openai_api/adapter.py +109 -24
- sglang/srt/openai_api/protocol.py +17 -1
- sglang/srt/reasoning_parser.py +154 -0
- sglang/srt/sampling/penaltylib/__init__.py +4 -6
- sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
- sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
- sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
- sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
- sglang/srt/sampling/sampling_batch_info.py +79 -157
- sglang/srt/sampling/sampling_params.py +16 -13
- sglang/srt/server_args.py +135 -60
- sglang/srt/speculative/build_eagle_tree.py +8 -9
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -12
- sglang/srt/speculative/eagle_utils.py +92 -57
- sglang/srt/speculative/eagle_worker.py +238 -111
- sglang/srt/speculative/spec_info.py +1 -13
- sglang/srt/utils.py +43 -17
- sglang/srt/warmup.py +47 -0
- sglang/test/few_shot_gsm8k.py +4 -1
- sglang/test/runners.py +389 -126
- sglang/test/send_one.py +88 -0
- sglang/test/test_block_fp8_ep.py +361 -0
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +138 -84
- sglang/utils.py +50 -60
- sglang/version.py +1 -1
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/METADATA +22 -15
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/RECORD +200 -166
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/WHEEL +1 -1
- sglang/bench_latency.py +0 -1
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
- sglang/test/srt/sampling/penaltylib/utils.py +0 -344
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/top_level.txt +0 -0
sglang/bench_latency.py
DELETED
@@ -1 +0,0 @@
|
|
1
|
-
# Legacy entry point kept only as a tombstone: importing or running this
# module fails immediately and names the replacement script.
raise ValueError(
    "bench_latency.py has been renamed to bench_one_batch.py"
)
|
@@ -1,75 +0,0 @@
|
|
1
|
-
from typing import List
|
2
|
-
|
3
|
-
import torch
|
4
|
-
|
5
|
-
from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
|
6
|
-
|
7
|
-
|
8
|
-
class BatchedFrequencyPenalizer(_BatchedPenalizer):
    """Penalize tokens in proportion to how often they have been generated.

    A ``(num_reqs, vocab_size)`` accumulator starts at zero. Every call to
    ``_cumulate_output_tokens`` adds ``frequency_penalty * occurrence_count``
    to it, and ``_apply`` subtracts the accumulator from the logits in place.
    Prompt (input) tokens are not counted.
    """

    # Per-request penalty, broadcast across the vocab as an expanded view.
    frequency_penalties: torch.Tensor = None
    # Running penalty per (request, token); subtracted from the logits.
    cumulated_frequency_penalties: torch.Tensor = None

    def _is_required(self) -> bool:
        # Needed unless every request uses the neutral penalty of 0.0.
        return not all(
            req.sampling_params.frequency_penalty == 0.0
            for req in self.orchestrator.reqs()
        )

    def _prepare(self):
        orchestrator = self.orchestrator
        per_request = torch.tensor(
            [req.sampling_params.frequency_penalty for req in orchestrator.reqs()],
            dtype=torch.float32,
            device=orchestrator.device,
        )
        # Nothing has been generated yet, so every accumulated penalty is 0.
        self.cumulated_frequency_penalties = torch.zeros(
            (per_request.shape[0], orchestrator.vocab_size),
            dtype=torch.float32,
            device=orchestrator.device,
        )
        # Broadcast each request's scalar across the vocab without copying.
        self.frequency_penalties = per_request.unsqueeze(1).expand(
            -1, orchestrator.vocab_size
        )

    def _teardown(self):
        self.frequency_penalties = self.cumulated_frequency_penalties = None

    def _cumulate_input_tokens(self, input_ids: _TokenIDs):
        """Prompt tokens never contribute to the frequency penalty."""

    def _cumulate_output_tokens(self, output_ids: _TokenIDs):
        # Each occurrence adds one more multiple of the request's penalty.
        increment = self.frequency_penalties * output_ids.occurrence_count()
        self.cumulated_frequency_penalties.add_(increment)

    def _apply(self, logits: torch.Tensor) -> torch.Tensor:
        logits.sub_(self.cumulated_frequency_penalties)
        return logits

    def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
        # Only the tensor form of the surviving row indices is used here.
        for attr in ("frequency_penalties", "cumulated_frequency_penalties"):
            setattr(self, attr, getattr(self, attr)[indices_tensor_to_keep])

    def _merge(self, their: "BatchedFrequencyPenalizer"):
        # Append the other batch's rows below ours.
        for attr in ("frequency_penalties", "cumulated_frequency_penalties"):
            setattr(
                self,
                attr,
                torch.cat([getattr(self, attr), getattr(their, attr)], dim=0),
            )
|
@@ -1,74 +0,0 @@
|
|
1
|
-
from typing import List
|
2
|
-
|
3
|
-
import torch
|
4
|
-
|
5
|
-
from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
|
6
|
-
|
7
|
-
|
8
|
-
class BatchedPresencePenalizer(_BatchedPenalizer):
    """Apply a flat penalty to any token that has appeared in the output.

    Unlike the frequency penalty, repeats do not stack. Once a token has been
    generated, its accumulated entry is *set* to the request's
    ``presence_penalty``, and ``_apply`` subtracts the accumulator from the
    logits in place. Prompt (input) tokens are ignored.
    """

    # Per-request penalty, broadcast across the vocab as an expanded view.
    presence_penalties: torch.Tensor = None
    # Penalty currently in force per (request, token); 0 until first seen.
    cumulated_presence_penalties: torch.Tensor = None

    def _is_required(self) -> bool:
        # Needed unless every request uses the neutral penalty of 0.0.
        return not all(
            req.sampling_params.presence_penalty == 0.0
            for req in self.orchestrator.reqs()
        )

    def _prepare(self):
        orchestrator = self.orchestrator
        per_request = torch.tensor(
            [req.sampling_params.presence_penalty for req in orchestrator.reqs()],
            dtype=torch.float32,
            device=orchestrator.device,
        )
        self.cumulated_presence_penalties = torch.zeros(
            (per_request.shape[0], orchestrator.vocab_size),
            dtype=torch.float32,
            device=orchestrator.device,
        )
        self.presence_penalties = per_request.unsqueeze(1).expand(
            -1, orchestrator.vocab_size
        )

    def _teardown(self):
        self.presence_penalties = self.cumulated_presence_penalties = None

    def _cumulate_input_tokens(self, input_ids: _TokenIDs):
        """Prompt tokens never trigger the presence penalty."""

    def _cumulate_output_tokens(self, output_ids: _TokenIDs):
        generated = output_ids.occurrence_count() > 0
        # Assignment rather than accumulation: a token is penalized once.
        self.cumulated_presence_penalties[generated] = self.presence_penalties[generated]

    def _apply(self, logits: torch.Tensor) -> torch.Tensor:
        logits.sub_(self.cumulated_presence_penalties)
        return logits

    def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
        # Only the tensor form of the surviving row indices is used here.
        for attr in ("presence_penalties", "cumulated_presence_penalties"):
            setattr(self, attr, getattr(self, attr)[indices_tensor_to_keep])

    def _merge(self, their: "BatchedPresencePenalizer"):
        # Append the other batch's rows below ours.
        for attr in ("presence_penalties", "cumulated_presence_penalties"):
            setattr(
                self,
                attr,
                torch.cat([getattr(self, attr), getattr(their, attr)], dim=0),
            )
|
@@ -1,85 +0,0 @@
|
|
1
|
-
from typing import List
|
2
|
-
|
3
|
-
import torch
|
4
|
-
|
5
|
-
from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
|
6
|
-
from sglang.srt.utils import get_compiler_backend
|
7
|
-
|
8
|
-
|
9
|
-
@torch.compile(dynamic=True, backend=get_compiler_backend())
def apply_scaling_penalties(logits, scaling_penalties):
    """Rescale ``logits`` in place by ``scaling_penalties``.

    Strictly positive logits are divided by the penalty. All other logits are
    multiplied by it, so zeros stay zero. The result is written back into
    ``logits`` through a slice assignment; nothing is returned.
    """
    positive = logits > 0
    rescaled = torch.where(
        positive,
        logits / scaling_penalties,
        logits * scaling_penalties,
    )
    logits[:] = rescaled
|
16
|
-
|
17
|
-
|
18
|
-
class BatchedRepetitionPenalizer(_BatchedPenalizer):
    """Scale the logits of tokens already seen in the prompt or the output.

    The accumulator starts at 1.0 for every (request, token) pair. Once a
    token occurs in either the input or the output, its entry is set to the
    request's ``repetition_penalty``. ``_apply`` passes the accumulator to
    ``apply_scaling_penalties``, which rescales the logits in place.
    """

    # Per-request penalty, broadcast across the vocab as an expanded view.
    repetition_penalties: torch.Tensor = None
    # Scale currently in force per (request, token); 1.0 until first seen.
    cumulated_repetition_penalties: torch.Tensor = None

    def _is_required(self) -> bool:
        # Needed unless every request uses the neutral scale of 1.0.
        return not all(
            req.sampling_params.repetition_penalty == 1.0
            for req in self.orchestrator.reqs()
        )

    def _prepare(self):
        orchestrator = self.orchestrator
        per_request = torch.tensor(
            [req.sampling_params.repetition_penalty for req in orchestrator.reqs()],
            dtype=torch.float32,
            device=orchestrator.device,
        )
        self.cumulated_repetition_penalties = torch.ones(
            (per_request.shape[0], orchestrator.vocab_size),
            dtype=torch.float32,
            device=orchestrator.device,
        )
        self.repetition_penalties = per_request.unsqueeze(1).expand(
            -1, orchestrator.vocab_size
        )

    def _teardown(self):
        self.repetition_penalties = self.cumulated_repetition_penalties = None

    def _record_seen(self, token_ids: _TokenIDs):
        """Switch on the request's penalty for every token that has occurred."""
        seen = token_ids.occurrence_count() > 0
        self.cumulated_repetition_penalties[seen] = self.repetition_penalties[seen]

    def _cumulate_input_tokens(self, input_ids: _TokenIDs):
        # Prompt tokens count as "seen" just like generated ones.
        self._record_seen(input_ids)

    def _cumulate_output_tokens(self, output_ids: _TokenIDs):
        self._record_seen(output_ids)

    def _apply(self, logits: torch.Tensor) -> torch.Tensor:
        apply_scaling_penalties(logits, self.cumulated_repetition_penalties)
        return logits

    def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
        # Only the tensor form of the surviving row indices is used here.
        for attr in ("repetition_penalties", "cumulated_repetition_penalties"):
            setattr(self, attr, getattr(self, attr)[indices_tensor_to_keep])

    def _merge(self, their: "BatchedRepetitionPenalizer"):
        # Append the other batch's rows below ours.
        for attr in ("repetition_penalties", "cumulated_repetition_penalties"):
            setattr(
                self,
                attr,
                torch.cat([getattr(self, attr), getattr(their, attr)], dim=0),
            )
|
@@ -1,344 +0,0 @@
|
|
1
|
-
import dataclasses
|
2
|
-
import enum
|
3
|
-
import unittest
|
4
|
-
from typing import Dict, List, Optional, Set, Tuple, Type
|
5
|
-
|
6
|
-
import torch
|
7
|
-
|
8
|
-
from sglang.srt.sampling.penaltylib.orchestrator import (
|
9
|
-
BatchedPenalizerOrchestrator,
|
10
|
-
_BatchedPenalizer,
|
11
|
-
_BatchLike,
|
12
|
-
)
|
13
|
-
|
14
|
-
|
15
|
-
@dataclasses.dataclass
class MockSamplingParams:
    """Minimal stand-in for the sampling parameters the penalizers read.

    The defaults are the neutral values the penalizers treat as "not
    required": frequency/presence penalty 0.0 and repetition penalty 1.0.
    """

    frequency_penalty: float = 0.0
    min_new_tokens: int = 0
    # Defaults to None, so the annotation must allow it (was List[int]).
    stop_token_ids: Optional[List[int]] = None
    presence_penalty: float = 0.0
    repetition_penalty: float = 1.0
|
22
|
-
|
23
|
-
|
24
|
-
@dataclasses.dataclass
class MockTokenizer:
    """Tokenizer stub exposing only its stop-token attributes."""

    eos_token_id: int
    # Optional extra stop tokens; absent unless a test supplies them.
    additional_stop_token_ids: Optional[List[int]] = dataclasses.field(default=None)
|
28
|
-
|
29
|
-
|
30
|
-
@dataclasses.dataclass
class MockReq:
    """Lightweight request double, as produced by ``Subject.to_req``."""

    # Prompt token ids (taken from the subject's first, INPUT step).
    origin_input_ids: List[int]
    sampling_params: MockSamplingParams
    tokenizer: MockTokenizer
|
35
|
-
|
36
|
-
|
37
|
-
class StepType(enum.Enum):
    """Whether a scripted step feeds prompt tokens or generated tokens."""

    # Prompt tokens; only valid as a subject's first step.
    INPUT = "input"
    # Tokens produced by the model during decoding.
    OUTPUT = "output"
|
40
|
-
|
41
|
-
|
42
|
-
@dataclasses.dataclass
class Step:
    """One scripted step of a test subject and the state it should produce."""

    # Kind of tokens carried by ``token_ids``.
    type: StepType
    token_ids: List[int]
    # Penalizer attribute name -> tensor expected after this step.
    expected_tensors: Dict[str, torch.Tensor]
    # Logits expected after penalizing an all-ones logits tensor.
    expected_logits: torch.Tensor
|
49
|
-
|
50
|
-
|
51
|
-
@dataclasses.dataclass
class Subject:
    """One request's scripted trajectory through a penalizer test.

    ``steps[0]`` must be an INPUT step; its token ids become the request's
    prompt in :meth:`to_req`. Every step must declare the same set of
    expected tensor names so the per-step comparisons line up.

    Raises:
        ValueError: if ``steps`` is empty or does not start with an INPUT
            step, or if the steps disagree on their expected tensor names.
    """

    sampling_params: MockSamplingParams
    # The first step must be an input step; it is converted into the request.
    steps: List[Step]
    eos_token_id: int = -1

    def __post_init__(self):
        # Guard the empty case too: previously steps[0] raised IndexError.
        if not self.steps or self.steps[0].type != StepType.INPUT:
            raise ValueError("First step must be input")

        # Every step must track the same expected tensor names as step 0.
        reference_keys = self.tensor_keys()
        for i in range(1, len(self.steps)):
            if self.tensor_keys(i) != reference_keys:
                raise ValueError(
                    f"Expected tensors keys must be the same for all steps. Got {self.steps[i].expected_tensors.keys()} for step={i} and {self.steps[0].expected_tensors.keys()}"
                )

    def tensor_keys(self, i: int = 0) -> Set[str]:
        """Return the names of the tensors expected at step ``i``."""
        return set(self.steps[i].expected_tensors.keys())

    def to_req(self) -> MockReq:
        """Build the mock request whose prompt is the first step's tokens."""
        return MockReq(
            origin_input_ids=self.steps[0].token_ids,
            sampling_params=self.sampling_params,
            tokenizer=MockTokenizer(eos_token_id=self.eos_token_id),
        )
|
78
|
-
|
79
|
-
|
80
|
-
@dataclasses.dataclass
class Case:
    """A batch of subjects plus whether the penalizer should be required.

    All subjects must agree on the set of expected tensor names, because the
    tests concatenate each named tensor across subjects row-wise.

    Raises:
        ValueError: if the subjects disagree on their expected tensor names.
    """

    enabled: bool
    test_subjects: List[Subject]

    def __post_init__(self):
        # Every subject must track the same expected tensor names as subject 0.
        for i in range(1, len(self.test_subjects)):
            if self.tensor_keys(i) != self.tensor_keys():
                raise ValueError(
                    f"Expected tensors keys must be the same for all test_subjects. Got {self.test_subjects[i].tensor_keys()} for subject={i} and {self.test_subjects[0].tensor_keys()}"
                )

    def tensor_keys(self, i: int = 0) -> Set[str]:
        """Return the expected tensor names of subject ``i`` (a fresh set).

        The annotation previously said ``List[str]`` although a set is returned.
        """
        # Subject.tensor_keys already builds a new set; no extra copy needed.
        return self.test_subjects[i].tensor_keys()
|
95
|
-
|
96
|
-
|
97
|
-
class BaseBatchedPenalizerTest(unittest.TestCase):
    """Reusable unittest suite for a single batched penalizer.

    Subclasses set ``Penalizer`` and override :meth:`create_test_subjects`.
    That method must assign two subjects:

    * ``self.enabled``: its sampling params make the penalizer required.
    * ``self.disabled``: its sampling params do not.

    Each test then runs over the cases built by :meth:`create_test_cases`.
    Logits are checked by penalizing an all-ones tensor.
    """

    Penalizer: Type[_BatchedPenalizer]
    device = "cuda"
    vocab_size = 5

    # Filled in by create_test_subjects(); None until then.
    enabled: Optional[Subject] = None
    disabled: Optional[Subject] = None

    def setUp(self):
        # The base class has no penalizer or subjects of its own.
        if self.__class__ == BaseBatchedPenalizerTest:
            self.skipTest("Base class for penalizer tests")

        self.create_test_subjects()
        self.create_test_cases()

    def tensor(self, data, **kwargs) -> torch.Tensor:
        """
        Shortcut to create a tensor with device=self.device.
        """
        return torch.tensor(data, **kwargs, device=self.device)

    def create_test_subjects(self) -> None:
        """Assign ``self.enabled`` and ``self.disabled``; subclasses must override.

        The annotation previously claimed ``List[Subject]``, but callers use
        the side effect, not a return value.
        """
        raise NotImplementedError()

    def create_test_cases(self):
        """Build the enabled-only, disabled-only and mixed batches."""
        self.test_cases = [
            Case(enabled=True, test_subjects=[self.enabled]),
            Case(enabled=False, test_subjects=[self.disabled]),
            Case(enabled=True, test_subjects=[self.enabled, self.disabled]),
        ]

    def _create_penalizer(
        self, case: Case
    ) -> Tuple[BatchedPenalizerOrchestrator, _BatchedPenalizer]:
        """Return an orchestrator over ``case``'s requests and its penalizer."""
        orchestrator = BatchedPenalizerOrchestrator(
            vocab_size=self.vocab_size,
            batch=_BatchLike(reqs=[subject.to_req() for subject in case.test_subjects]),
            device=self.device,
            Penalizers={self.Penalizer},
        )

        return orchestrator, orchestrator.penalizers[self.Penalizer]

    def _ones_logits(self, num_rows: int) -> torch.Tensor:
        """Return all-ones float32 logits of shape ``(num_rows, vocab_size)``."""
        return torch.ones(
            size=(num_rows, self.vocab_size),
            dtype=torch.float32,
            device=self.device,
        )

    def _assert_penalizer_tensors(self, penalizer, subjects, step_index, keys):
        """Check each penalizer attribute named in ``keys``.

        The expected value is the row-wise concatenation of ``subjects``'
        expected tensors at ``step_index``.
        """
        for key in keys:
            expected = torch.cat(
                tensors=[
                    subject.steps[step_index].expected_tensors[key]
                    for subject in subjects
                ],
            )
            actual = getattr(penalizer, key)
            torch.testing.assert_close(
                actual=actual,
                expected=expected,
                msg=f"key={key}\nactual={actual}\nexpected={expected}",
            )

    def _assert_logits(self, actual, subjects, step_index):
        """Compare ``actual`` with ``subjects``' expected logits at ``step_index``."""
        expected = torch.cat(
            tensors=[subject.steps[step_index].expected_logits for subject in subjects],
        )
        torch.testing.assert_close(
            actual=actual,
            expected=expected,
            msg=f"logits\nactual={actual}\nexpected={expected}",
        )

    def test_is_required(self):
        for case in self.test_cases:
            with self.subTest(case=case):
                _, penalizer = self._create_penalizer(case)
                self.assertEqual(case.enabled, penalizer.is_required())

    def test_prepare(self):
        for case in self.test_cases:
            with self.subTest(case=case):
                orchestrator, penalizer = self._create_penalizer(case)
                self.assertEqual(case.enabled, penalizer.is_prepared())

                if case.enabled:
                    self._assert_penalizer_tensors(
                        penalizer, case.test_subjects, 0, case.tensor_keys()
                    )

                original = self._ones_logits(len(case.test_subjects))
                actual = orchestrator.apply(original.clone())
                # A None result is treated as "logits left unchanged".
                if actual is None:
                    actual = original
                self._assert_logits(actual, case.test_subjects, 0)

    def test_teardown(self):
        for case in self.test_cases:
            with self.subTest(case=case):
                _, penalizer = self._create_penalizer(case)
                penalizer.teardown()

                for key in case.tensor_keys():
                    self.assertIsNone(getattr(penalizer, key, None))

    def test_filter(self):
        for case in self.test_cases:
            with self.subTest(case=case):
                orchestrator, penalizer = self._create_penalizer(case)

                indices_to_keep = [0]
                orchestrator.filter(indices_to_keep=indices_to_keep)

                filtered_subjects = [case.test_subjects[i] for i in indices_to_keep]

                if penalizer.is_required():
                    self.assertTrue(penalizer.is_prepared())
                    self._assert_penalizer_tensors(
                        penalizer, filtered_subjects, 0, case.tensor_keys()
                    )

                actual_logits = orchestrator.apply(
                    self._ones_logits(len(filtered_subjects))
                )
                if actual_logits is None:
                    continue
                self._assert_logits(actual_logits, filtered_subjects, 0)

    def test_merge_enabled_with_disabled(self):
        enabled_test_case = self.test_cases[0]
        disabled_test_case = self.test_cases[1]

        orchestrator, penalizer = self._create_penalizer(enabled_test_case)
        theirs, _ = self._create_penalizer(disabled_test_case)

        orchestrator.merge(theirs)

        # Merged rows: the enabled subject first, then the disabled one.
        self._assert_penalizer_tensors(
            penalizer,
            [
                enabled_test_case.test_subjects[0],
                disabled_test_case.test_subjects[0],
            ],
            0,
            enabled_test_case.tensor_keys(),
        )

    def test_cumulate_apply_repeat(self):
        for case in self.test_cases:
            with self.subTest(case=case):
                orchestrator, penalizer = self._create_penalizer(case)

                max_step = max(len(subject.steps) for subject in case.test_subjects)
                for i in range(1, max_step):
                    # Drop subjects whose script ended before step i.
                    orchestrator.filter(
                        indices_to_keep=[
                            j
                            for j, subject in enumerate(case.test_subjects)
                            if i < len(subject.steps)
                        ]
                    )

                    filtered_subjects = [
                        subject
                        for subject in case.test_subjects
                        if i < len(subject.steps)
                    ]

                    outputs: List[List[int]] = []
                    for subject in filtered_subjects:
                        step = subject.steps[i]
                        if step.type == StepType.INPUT:
                            # Only a subject's first step may carry input tokens.
                            raise NotImplementedError()
                        outputs.append(step.token_ids)

                    if any(outputs):
                        # Feed generated tokens one position at a time, one per request.
                        # NOTE(review): assumes all requests emit the same number
                        # of tokens at this step; a shorter list raises IndexError.
                        for j in range(max(len(x) for x in outputs)):
                            tmp_outputs = torch.tensor(
                                [x[j] for x in outputs],
                                dtype=torch.int32,
                                device=orchestrator.device,
                            )
                            orchestrator.cumulate_output_tokens(tmp_outputs)

                    if penalizer.is_required():
                        self.assertTrue(penalizer.is_prepared())
                        self._assert_penalizer_tensors(
                            penalizer, filtered_subjects, i, case.tensor_keys()
                        )

                    original = self._ones_logits(len(filtered_subjects))
                    actual_logits = orchestrator.apply(original.clone())
                    # A None result is treated as "logits left unchanged".
                    if actual_logits is None:
                        actual_logits = original
                    self._assert_logits(actual_logits, filtered_subjects, i)
|
File without changes
|
File without changes
|