sglang 0.4.3.post2__py3-none-any.whl → 0.4.3.post3__py3-none-any.whl
This diff compares publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in the public registry.
- sglang/api.py +1 -1
- sglang/bench_offline_throughput.py +19 -0
- sglang/bench_one_batch.py +2 -2
- sglang/bench_serving.py +123 -79
- sglang/global_config.py +8 -3
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/lang/ir.py +1 -1
- sglang/srt/_custom_ops.py +83 -91
- sglang/srt/configs/load_config.py +4 -1
- sglang/srt/configs/model_config.py +48 -2
- sglang/srt/configs/qwen2_5_vl_config.py +5 -2
- sglang/srt/constrained/base_grammar_backend.py +117 -15
- sglang/srt/constrained/llguidance_backend.py +151 -0
- sglang/srt/constrained/outlines_backend.py +24 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -38
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
- sglang/srt/distributed/parallel_state.py +48 -3
- sglang/srt/entrypoints/engine.py +67 -9
- sglang/srt/entrypoints/http_server.py +190 -41
- sglang/srt/entrypoints/verl_engine.py +147 -0
- sglang/srt/function_call_parser.py +0 -1
- sglang/srt/layers/activation.py +11 -0
- sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
- sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
- sglang/srt/layers/attention/flashinfer_backend.py +220 -378
- sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
- sglang/srt/layers/attention/torch_native_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +9 -6
- sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
- sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
- sglang/srt/layers/attention/utils.py +39 -0
- sglang/srt/layers/attention/vision.py +60 -63
- sglang/srt/layers/dp_attention.py +142 -1
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/linear.py +3 -1
- sglang/srt/layers/logits_processor.py +281 -45
- sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
- sglang/srt/layers/moe/ep_moe/layer.py +140 -28
- sglang/srt/layers/moe/fused_moe_native.py +2 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
- sglang/srt/layers/moe/topk.py +13 -4
- sglang/srt/layers/quantization/__init__.py +111 -7
- sglang/srt/layers/quantization/blockwise_int8.py +409 -0
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8.py +69 -28
- sglang/srt/layers/quantization/fp8_utils.py +17 -1
- sglang/srt/layers/quantization/gptq.py +416 -0
- sglang/srt/layers/quantization/int8_kernel.py +327 -0
- sglang/srt/layers/quantization/int8_utils.py +73 -0
- sglang/srt/layers/quantization/modelopt_quant.py +18 -1
- sglang/srt/layers/radix_attention.py +1 -0
- sglang/srt/layers/rotary_embedding.py +0 -1
- sglang/srt/layers/sampler.py +76 -31
- sglang/srt/layers/vocab_parallel_embedding.py +14 -13
- sglang/srt/lora/lora.py +17 -1
- sglang/srt/lora/lora_config.py +5 -0
- sglang/srt/lora/lora_manager.py +1 -3
- sglang/srt/managers/cache_controller.py +193 -62
- sglang/srt/managers/configure_logging.py +2 -1
- sglang/srt/managers/data_parallel_controller.py +6 -2
- sglang/srt/managers/detokenizer_manager.py +124 -102
- sglang/srt/managers/image_processor.py +2 -1
- sglang/srt/managers/io_struct.py +143 -6
- sglang/srt/managers/schedule_batch.py +237 -197
- sglang/srt/managers/schedule_policy.py +29 -29
- sglang/srt/managers/scheduler.py +681 -259
- sglang/srt/managers/session_controller.py +6 -2
- sglang/srt/managers/tokenizer_manager.py +224 -68
- sglang/srt/managers/tp_worker.py +15 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
- sglang/srt/mem_cache/chunk_cache.py +18 -11
- sglang/srt/mem_cache/hiradix_cache.py +394 -0
- sglang/srt/mem_cache/memory_pool.py +44 -18
- sglang/srt/mem_cache/radix_cache.py +58 -47
- sglang/srt/metrics/collector.py +94 -36
- sglang/srt/model_executor/cuda_graph_runner.py +55 -24
- sglang/srt/model_executor/forward_batch_info.py +49 -16
- sglang/srt/model_executor/model_runner.py +208 -28
- sglang/srt/model_loader/loader.py +3 -3
- sglang/srt/model_loader/weight_utils.py +36 -14
- sglang/srt/models/baichuan.py +31 -6
- sglang/srt/models/chatglm.py +39 -7
- sglang/srt/models/commandr.py +29 -5
- sglang/srt/models/dbrx.py +31 -5
- sglang/srt/models/deepseek.py +43 -6
- sglang/srt/models/deepseek_nextn.py +32 -19
- sglang/srt/models/deepseek_v2.py +265 -32
- sglang/srt/models/exaone.py +19 -9
- sglang/srt/models/gemma.py +22 -8
- sglang/srt/models/gemma2.py +25 -12
- sglang/srt/models/gemma2_reward.py +5 -1
- sglang/srt/models/gpt2.py +28 -13
- sglang/srt/models/gpt_bigcode.py +27 -5
- sglang/srt/models/granite.py +21 -9
- sglang/srt/models/grok.py +21 -4
- sglang/srt/models/internlm2.py +36 -6
- sglang/srt/models/internlm2_reward.py +5 -1
- sglang/srt/models/llama.py +26 -9
- sglang/srt/models/llama_classification.py +5 -1
- sglang/srt/models/llama_eagle.py +17 -4
- sglang/srt/models/llama_embedding.py +5 -1
- sglang/srt/models/llama_reward.py +7 -2
- sglang/srt/models/llava.py +19 -3
- sglang/srt/models/llavavid.py +10 -1
- sglang/srt/models/minicpm.py +26 -2
- sglang/srt/models/minicpm3.py +39 -3
- sglang/srt/models/minicpmv.py +45 -14
- sglang/srt/models/mixtral.py +20 -9
- sglang/srt/models/mixtral_quant.py +50 -8
- sglang/srt/models/mllama.py +57 -11
- sglang/srt/models/olmo.py +34 -6
- sglang/srt/models/olmo2.py +34 -13
- sglang/srt/models/olmoe.py +26 -4
- sglang/srt/models/phi3_small.py +29 -10
- sglang/srt/models/qwen.py +26 -3
- sglang/srt/models/qwen2.py +26 -4
- sglang/srt/models/qwen2_5_vl.py +46 -8
- sglang/srt/models/qwen2_eagle.py +17 -5
- sglang/srt/models/qwen2_moe.py +44 -6
- sglang/srt/models/qwen2_rm.py +78 -0
- sglang/srt/models/qwen2_vl.py +39 -8
- sglang/srt/models/stablelm.py +32 -5
- sglang/srt/models/torch_native_llama.py +5 -2
- sglang/srt/models/xverse.py +21 -9
- sglang/srt/models/xverse_moe.py +45 -7
- sglang/srt/models/yivl.py +2 -1
- sglang/srt/openai_api/adapter.py +109 -24
- sglang/srt/openai_api/protocol.py +17 -1
- sglang/srt/reasoning_parser.py +154 -0
- sglang/srt/sampling/penaltylib/__init__.py +4 -6
- sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
- sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
- sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
- sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
- sglang/srt/sampling/sampling_batch_info.py +79 -157
- sglang/srt/sampling/sampling_params.py +16 -13
- sglang/srt/server_args.py +136 -52
- sglang/srt/speculative/build_eagle_tree.py +2 -8
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +0 -1
- sglang/srt/speculative/eagle_utils.py +92 -58
- sglang/srt/speculative/eagle_worker.py +186 -94
- sglang/srt/speculative/spec_info.py +1 -13
- sglang/srt/utils.py +43 -17
- sglang/srt/warmup.py +47 -0
- sglang/test/few_shot_gsm8k.py +4 -1
- sglang/test/runners.py +389 -126
- sglang/test/send_one.py +88 -0
- sglang/test/test_block_fp8_ep.py +361 -0
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +138 -84
- sglang/utils.py +50 -60
- sglang/version.py +1 -1
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/METADATA +21 -15
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/RECORD +200 -166
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/WHEEL +1 -1
- sglang/bench_latency.py +0 -1
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
- sglang/test/srt/sampling/penaltylib/utils.py +0 -344
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/top_level.txt +0 -0
```diff
--- a/sglang/srt/sampling/sampling_batch_info.py
+++ b/sglang/srt/sampling/sampling_batch_info.py
@@ -9,9 +9,6 @@ import torch

 import sglang.srt.sampling.penaltylib as penaltylib
 from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
-from sglang.srt.sampling.penaltylib.penalizers.repetition_penalty import (
-    apply_scaling_penalties,
-)

 logger = logging.getLogger(__name__)

@@ -22,49 +19,45 @@ if TYPE_CHECKING:

 @dataclasses.dataclass
 class SamplingBatchInfo:
-    #
+    # Basic batched sampling params
     temperatures: torch.Tensor
     top_ps: torch.Tensor
     top_ks: torch.Tensor
     min_ps: torch.Tensor

-    #
+    # Whether all requests use greedy sampling
     is_all_greedy: bool

-    #
+    # Whether any request needs min_p sampling
     need_min_p_sampling: bool

-    #
-    has_custom_logit_processor: bool
-
-    # Bias Tensors
+    # Masking tensors for grammar-guided structured outputs
     vocab_size: int
     grammars: Optional[List] = None
-    sampling_info_done: Optional[threading.Event] = None
-    logit_bias: torch.Tensor = None
     vocab_mask: Optional[torch.Tensor] = None
-
+    apply_mask_func: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None
+
+    # An event used for overlap schedule
+    sampling_info_done: Optional[threading.Event] = None

     # Penalizer
     penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None
-
-    scaling_penalties: Optional[torch.Tensor] = None
+    linear_penalty: torch.Tensor = None

-    #
-
-
-    # Custom Parameters
+    # Whether any request has custom logit processor
+    has_custom_logit_processor: bool = False
+    # Custom parameters
     custom_params: Optional[List[Optional[Dict[str, Any]]]] = None
-
-    # Custom Logit Processor
+    # Custom logit processor
     custom_logit_processor: Optional[
         Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]
     ] = None

+    # Device
+    device: str = "cuda"
+
     @classmethod
-    def from_schedule_batch(
-        cls, batch: ScheduleBatch, vocab_size: int, enable_overlap_schedule: bool
-    ):
+    def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
         reqs = batch.reqs
         device = batch.device
         temperatures = (
@@ -118,106 +111,60 @@ class SamplingBatchInfo:
         merged_custom_logit_processor = None
         custom_params = None

-        ret = cls(
-            temperatures=temperatures,
-            top_ps=top_ps,
-            top_ks=top_ks,
-            min_ps=min_ps,
-            need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
-            is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs),
-            has_custom_logit_processor=has_custom_logit_processor,
-            vocab_size=vocab_size,
-            device=device,
-            custom_params=custom_params,
-            custom_logit_processor=merged_custom_logit_processor,
-        )
-        # TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.
-
-        if enable_overlap_schedule:
-            # TODO (lianmin): Some penalizers such as frequency and presence depend on model outputs,
-            # so it is kind of tricky to make it work with overlap scheduler.
-            # It requires correcly updating the penalty logits before the sampling and syncing the events.
-            # We will support them later.
-            penalizers = {
-                penaltylib.BatchedMinNewTokensPenalizer,
-            }
-            if (
-                any(req.sampling_params.frequency_penalty != 0.0 for req in reqs)
-                or any(req.sampling_params.presence_penalty != 0.0 for req in reqs)
-                or any(req.sampling_params.repetition_penalty != 1.0 for req in reqs)
-            ):
-                logger.warning(
-                    "frequency_penalty, presence_penalty, and repetition_penalty are not supported "
-                    "when using the default overlap scheduler. They will be ignored. "
-                    "Please add `--disable-overlap` when launching the server if you need these features. "
-                    "The speed will be slower in that case."
-                )
-        else:
-            penalizers = {
-                penaltylib.BatchedFrequencyPenalizer,
-                penaltylib.BatchedMinNewTokensPenalizer,
-                penaltylib.BatchedPresencePenalizer,
-                penaltylib.BatchedRepetitionPenalizer,
-            }
-
         # Each penalizers will do nothing if they evaluate themselves as not required by looking at
         # the sampling_params of the requests (See {_is_required()} of each penalizers). So this
         # should not add hefty computation overhead other than simple checks.
         #
-        # While we choose not to even create the class instances if they are not required, this
+        # While we can choose not to even create the class instances if they are not required, this
         # could add additional complexity to the {ScheduleBatch} class, especially we need to
         # handle {filter_batch()} and {merge_batch()} cases as well.
-
+        penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
             vocab_size=vocab_size,
             batch=batch,
-
-
+            penalizers={
+                penaltylib.BatchedFrequencyPenalizer,
+                penaltylib.BatchedMinNewTokensPenalizer,
+                penaltylib.BatchedPresencePenalizer,
+            },
         )

-
-
-
+        ret = cls(
+            temperatures=temperatures,
+            top_ps=top_ps,
+            top_ks=top_ks,
+            min_ps=min_ps,
+            is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs),
+            need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
+            vocab_size=vocab_size,
+            penalizer_orchestrator=penalizer_orchestrator,
+            has_custom_logit_processor=has_custom_logit_processor,
+            custom_params=custom_params,
+            custom_logit_processor=merged_custom_logit_processor,
+            device=device,
+        )
         return ret

     def __len__(self):
         return len(self.temperatures)

-    def update_penalties(self):
-        self.scaling_penalties = None
-        self.linear_penalties = None
-
-        for penalizer in self.penalizer_orchestrator.penalizers.values():
-            if not penalizer.is_prepared():
-                continue
-
-            if isinstance(penalizer, penaltylib.BatchedRepetitionPenalizer):
-                self.scaling_penalties = penalizer.cumulated_repetition_penalties
-            else:
-                if self.linear_penalties is None:
-                    bs = self.penalizer_orchestrator.batch.batch_size()
-                    self.linear_penalties = torch.zeros(
-                        (bs, self.vocab_size),
-                        dtype=torch.float32,
-                        device=self.device,
-                    )
-                self.linear_penalties = penalizer.apply(self.linear_penalties)
-
     def update_regex_vocab_mask(self):
         if not self.grammars:
             self.vocab_mask = None
-            self.
+            self.apply_mask_func = None
             return

-        #
+        # Find a grammar from the list
         first_grammar = next(grammar for grammar in self.grammars if grammar)

-        #
+        # TODO(lianmin): Maybe we can reuse the existing mask?
         self.vocab_mask = first_grammar.allocate_vocab_mask(
             vocab_size=self.vocab_size,
             batch_size=len(self.temperatures),
             device=self.device,
         )
-        self.
+        self.apply_mask_func = (
+            first_grammar.apply_vocab_mask
+        )  # force to use static method

         # Apply the mask
         for i, grammar in enumerate(self.grammars):
@@ -227,35 +174,56 @@ class SamplingBatchInfo:
         # Move the mask to the device if needed
         self.vocab_mask = first_grammar.move_vocab_mask(self.vocab_mask, self.device)

-    def
-        self.penalizer_orchestrator.
+    def update_penalties(self):
+        if self.penalizer_orchestrator.is_required:
+            self.linear_penalty = torch.zeros(
+                (len(self.temperatures), self.vocab_size),
+                dtype=torch.float32,
+                device=self.temperatures.device,
+            )
+            self.penalizer_orchestrator.apply(self.linear_penalty)
+        else:
+            self.linear_penalty = None
+
+    def apply_logits_bias(self, logits: torch.Tensor):
+        if self.linear_penalty is not None:
+            # Used in the overlap mode
+            logits.add_(self.linear_penalty)
+
+        if self.penalizer_orchestrator and self.penalizer_orchestrator.is_required:
+            # Used in the non-overlap mode
+            self.penalizer_orchestrator.apply(logits)
+
+        if self.vocab_mask is not None:
+            self.apply_mask_func(logits=logits, vocab_mask=self.vocab_mask)
+
+    def filter_batch(self, keep_indices: List[int], keep_indices_device: torch.Tensor):
+        self.penalizer_orchestrator.filter(keep_indices_device)
+
         if self.has_custom_logit_processor:
-            self._filter_batch_custom_logit_processor(
+            self._filter_batch_custom_logit_processor(keep_indices, keep_indices_device)

         for item in [
             "temperatures",
             "top_ps",
             "top_ks",
             "min_ps",
-            "logit_bias",
         ]:
             value = getattr(self, item, None)
-
-                setattr(self, item, value[new_indices])
+            setattr(self, item, value[keep_indices_device])

     def _filter_batch_custom_logit_processor(
-        self,
+        self, keep_indices: List[int], keep_indices_device: torch.Tensor
     ):
         """Filter the custom logit processor and custom params"""
-
         self.custom_logit_processor = {
-            k: (p, mask[
+            k: (p, mask[keep_indices_device])
             for k, (p, mask) in self.custom_logit_processor.items()
-            if any(
-                mask[
+            if torch.any(
+                mask[keep_indices_device]
             ) # ignore the custom logit processor whose mask is all False
         }
-        self.custom_params = [self.custom_params[i] for i in
+        self.custom_params = [self.custom_params[i] for i in keep_indices]

         # If the custom logit processor is an empty dict, set the flag to False,
         # and set the custom logit processor and custom params to None.
@@ -264,31 +232,6 @@ class SamplingBatchInfo:
         self.custom_params = None
         self.has_custom_logit_processor = False

-    @staticmethod
-    def merge_bias_tensor(
-        lhs: torch.Tensor,
-        rhs: torch.Tensor,
-        bs1: int,
-        bs2: int,
-        device: str,
-        default: int = 0,
-    ):
-        # bias tensor can be None
-        if lhs is not None or rhs is not None:
-            shape, dtype = None, None
-            if lhs is not None:
-                shape, dtype = lhs.shape[1:], lhs.dtype
-            else:
-                shape, dtype = rhs.shape[1:], rhs.dtype
-            with torch.dtype(dtype):
-                if lhs is None:
-                    lhs = torch.empty((bs1, *shape), device=device).fill_(default)
-                if rhs is None:
-                    rhs = torch.empty((bs2, *shape), device=device).fill_(default)
-            return torch.cat([lhs, rhs])
-
-        return None
-
     @staticmethod
     def merge_custom_logit_processor(
         lhs: Optional[Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]],
@@ -332,10 +275,6 @@ class SamplingBatchInfo:
     def merge_batch(self, other: "SamplingBatchInfo"):
         self.penalizer_orchestrator.merge(other.penalizer_orchestrator)

-        # Merge the logit bias tensor
-        self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
-            self.logit_bias, other.logit_bias, len(self), len(other), self.device
-        )
         # Merge the custom logit processors and custom params lists
         if self.has_custom_logit_processor or other.has_custom_logit_processor:
             # Merge the custom logit processors
@@ -369,22 +308,5 @@ class SamplingBatchInfo:
             other_val = getattr(other, item, None)
             setattr(self, item, torch.concat([self_val, other_val]))

-        self.is_all_greedy
-        self.need_min_p_sampling
-
-    def apply_logits_bias(self, logits: torch.Tensor):
-        # Apply logit_bias
-        if self.logit_bias is not None:
-            logits.add_(self.logit_bias)
-
-        # min-token, presence, frequency
-        if self.linear_penalties is not None:
-            logits.add_(self.linear_penalties)
-
-        # repetition
-        if self.scaling_penalties is not None:
-            apply_scaling_penalties(logits, self.scaling_penalties)
-
-        # Apply regex vocab_mask
-        if self.vocab_mask is not None:
-            self.apply_mask(logits=logits, vocab_mask=self.vocab_mask)
+        self.is_all_greedy |= other.is_all_greedy
+        self.need_min_p_sampling |= other.need_min_p_sampling
```
```diff
--- a/sglang/srt/sampling/sampling_params.py
+++ b/sglang/srt/sampling/sampling_params.py
@@ -22,8 +22,8 @@ class SamplingParams:
     """
     The sampling parameters.

-    See docs/
-    https://docs.sglang.ai/
+    See docs/backend/sampling_params.md or
+    https://docs.sglang.ai/backend/sampling_params.html
     for the documentation.
     """

@@ -40,16 +40,23 @@ class SamplingParams:
         presence_penalty: float = 0.0,
         repetition_penalty: float = 1.0,
         min_new_tokens: int = 0,
-        spaces_between_special_tokens: bool = True,
         n: int = 1,
         json_schema: Optional[str] = None,
         regex: Optional[str] = None,
         ebnf: Optional[str] = None,
-
+        structural_tag: Optional[str] = None,
         ignore_eos: bool = False,
         skip_special_tokens: bool = True,
+        spaces_between_special_tokens: bool = True,
+        no_stop_trim: bool = False,
         custom_params: Optional[Dict[str, Any]] = None,
     ) -> None:
+        self.max_new_tokens = max_new_tokens
+        self.stop_strs = stop
+        if stop_token_ids:
+            self.stop_token_ids = set(stop_token_ids)
+        else:
+            self.stop_token_ids = None
         self.temperature = temperature
         self.top_p = top_p
         self.top_k = top_k
@@ -57,25 +64,21 @@ class SamplingParams:
         self.frequency_penalty = frequency_penalty
         self.presence_penalty = presence_penalty
         self.repetition_penalty = repetition_penalty
-        self.stop_strs = stop
-        if stop_token_ids:
-            self.stop_token_ids = set(stop_token_ids)
-        else:
-            self.stop_token_ids = None
-        self.max_new_tokens = max_new_tokens
         self.min_new_tokens = min_new_tokens
-        self.ignore_eos = ignore_eos
-        self.skip_special_tokens = skip_special_tokens
-        self.spaces_between_special_tokens = spaces_between_special_tokens
         self.regex = regex
         self.n = n
         self.json_schema = json_schema
         self.ebnf = ebnf
+        self.structural_tag = structural_tag
+        self.ignore_eos = ignore_eos
+        self.skip_special_tokens = skip_special_tokens
+        self.spaces_between_special_tokens = spaces_between_special_tokens
         self.no_stop_trim = no_stop_trim
         self.custom_params = custom_params

         # Process some special cases
         if self.temperature < _SAMPLING_EPS:
+            # top_k = 1 means greedy sampling
             self.temperature = 1.0
             self.top_k = 1
         if self.top_k == -1:
```