vllm_npu-0.4.2-py3-none-any.whl
- vllm/__init__.py +23 -0
- vllm/_custom_ops.py +251 -0
- vllm/attention/__init__.py +13 -0
- vllm/attention/backends/__init__.py +0 -0
- vllm/attention/backends/abstract.py +127 -0
- vllm/attention/backends/flash_attn.py +271 -0
- vllm/attention/backends/flashinfer.py +220 -0
- vllm/attention/backends/rocm_flash_attn.py +374 -0
- vllm/attention/backends/torch_sdpa.py +250 -0
- vllm/attention/backends/xformers.py +393 -0
- vllm/attention/layer.py +56 -0
- vllm/attention/ops/__init__.py +0 -0
- vllm/attention/ops/paged_attn.py +216 -0
- vllm/attention/ops/prefix_prefill.py +792 -0
- vllm/attention/ops/triton_flash_attention.py +810 -0
- vllm/attention/selector.py +91 -0
- vllm/block.py +84 -0
- vllm/config.py +1225 -0
- vllm/core/__init__.py +0 -0
- vllm/core/block/__init__.py +0 -0
- vllm/core/block/block_table.py +295 -0
- vllm/core/block/common.py +199 -0
- vllm/core/block/cpu_gpu_block_allocator.py +228 -0
- vllm/core/block/interfaces.py +205 -0
- vllm/core/block/naive_block.py +318 -0
- vllm/core/block/prefix_caching_block.py +606 -0
- vllm/core/block_manager_v1.py +625 -0
- vllm/core/block_manager_v2.py +258 -0
- vllm/core/evictor_v1.py +105 -0
- vllm/core/evictor_v2.py +127 -0
- vllm/core/interfaces.py +113 -0
- vllm/core/policy.py +45 -0
- vllm/core/scheduler.py +1163 -0
- vllm/distributed/__init__.py +3 -0
- vllm/distributed/communication_op.py +237 -0
- vllm/distributed/device_communicators/__init__.py +0 -0
- vllm/distributed/device_communicators/custom_all_reduce.py +274 -0
- vllm/distributed/device_communicators/pynccl.py +287 -0
- vllm/distributed/device_communicators/pynccl_utils.py +66 -0
- vllm/distributed/parallel_state.py +339 -0
- vllm/distributed/utils.py +136 -0
- vllm/engine/__init__.py +0 -0
- vllm/engine/arg_utils.py +649 -0
- vllm/engine/async_llm_engine.py +737 -0
- vllm/engine/llm_engine.py +784 -0
- vllm/engine/metrics.py +368 -0
- vllm/engine/output_processor/__init__.py +0 -0
- vllm/engine/output_processor/interfaces.py +76 -0
- vllm/engine/output_processor/multi_step.py +142 -0
- vllm/engine/output_processor/single_step.py +284 -0
- vllm/engine/output_processor/stop_checker.py +101 -0
- vllm/engine/output_processor/util.py +19 -0
- vllm/entrypoints/__init__.py +0 -0
- vllm/entrypoints/api_server.py +119 -0
- vllm/entrypoints/llm.py +259 -0
- vllm/entrypoints/openai/__init__.py +0 -0
- vllm/entrypoints/openai/api_server.py +186 -0
- vllm/entrypoints/openai/cli_args.py +115 -0
- vllm/entrypoints/openai/protocol.py +460 -0
- vllm/entrypoints/openai/serving_chat.py +392 -0
- vllm/entrypoints/openai/serving_completion.py +347 -0
- vllm/entrypoints/openai/serving_engine.py +234 -0
- vllm/envs.py +217 -0
- vllm/executor/__init__.py +0 -0
- vllm/executor/cpu_executor.py +152 -0
- vllm/executor/distributed_gpu_executor.py +115 -0
- vllm/executor/executor_base.py +115 -0
- vllm/executor/gpu_executor.py +150 -0
- vllm/executor/multiproc_worker_utils.py +263 -0
- vllm/executor/neuron_executor.py +91 -0
- vllm/executor/ray_gpu_executor.py +327 -0
- vllm/executor/ray_utils.py +119 -0
- vllm/logger.py +153 -0
- vllm/logging/__init__.py +5 -0
- vllm/logging/formatter.py +15 -0
- vllm/lora/__init__.py +0 -0
- vllm/lora/fully_sharded_layers.py +262 -0
- vllm/lora/layers.py +1181 -0
- vllm/lora/lora.py +167 -0
- vllm/lora/models.py +645 -0
- vllm/lora/punica.py +213 -0
- vllm/lora/request.py +32 -0
- vllm/lora/utils.py +98 -0
- vllm/lora/worker_manager.py +251 -0
- vllm/model_executor/__init__.py +7 -0
- vllm/model_executor/guided_decoding/__init__.py +25 -0
- vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py +70 -0
- vllm/model_executor/guided_decoding/outlines_decoding.py +130 -0
- vllm/model_executor/guided_decoding/outlines_logits_processors.py +184 -0
- vllm/model_executor/layers/__init__.py +0 -0
- vllm/model_executor/layers/activation.py +173 -0
- vllm/model_executor/layers/fused_moe/__init__.py +7 -0
- vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +140 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- vllm/model_executor/layers/fused_moe/fused_moe.py +479 -0
- vllm/model_executor/layers/layernorm.py +71 -0
- vllm/model_executor/layers/linear.py +709 -0
- vllm/model_executor/layers/logits_processor.py +115 -0
- vllm/model_executor/layers/ops/__init__.py +0 -0
- vllm/model_executor/layers/ops/rand.py +157 -0
- vllm/model_executor/layers/ops/sample.py +406 -0
- vllm/model_executor/layers/quantization/__init__.py +35 -0
- vllm/model_executor/layers/quantization/aqlm.py +376 -0
- vllm/model_executor/layers/quantization/awq.py +175 -0
- vllm/model_executor/layers/quantization/base_config.py +97 -0
- vllm/model_executor/layers/quantization/fp8.py +265 -0
- vllm/model_executor/layers/quantization/gptq.py +224 -0
- vllm/model_executor/layers/quantization/gptq_marlin.py +438 -0
- vllm/model_executor/layers/quantization/marlin.py +227 -0
- vllm/model_executor/layers/quantization/schema.py +84 -0
- vllm/model_executor/layers/quantization/squeezellm.py +137 -0
- vllm/model_executor/layers/rejection_sampler.py +405 -0
- vllm/model_executor/layers/rotary_embedding.py +525 -0
- vllm/model_executor/layers/sampler.py +1051 -0
- vllm/model_executor/layers/vocab_parallel_embedding.py +155 -0
- vllm/model_executor/model_loader/__init__.py +30 -0
- vllm/model_executor/model_loader/loader.py +362 -0
- vllm/model_executor/model_loader/neuron.py +136 -0
- vllm/model_executor/model_loader/tensorizer.py +368 -0
- vllm/model_executor/model_loader/utils.py +41 -0
- vllm/model_executor/model_loader/weight_utils.py +372 -0
- vllm/model_executor/models/__init__.py +119 -0
- vllm/model_executor/models/baichuan.py +410 -0
- vllm/model_executor/models/bloom.py +327 -0
- vllm/model_executor/models/chatglm.py +386 -0
- vllm/model_executor/models/commandr.py +373 -0
- vllm/model_executor/models/dbrx.py +413 -0
- vllm/model_executor/models/decilm.py +122 -0
- vllm/model_executor/models/deepseek.py +438 -0
- vllm/model_executor/models/falcon.py +444 -0
- vllm/model_executor/models/gemma.py +393 -0
- vllm/model_executor/models/gpt2.py +266 -0
- vllm/model_executor/models/gpt_bigcode.py +274 -0
- vllm/model_executor/models/gpt_j.py +281 -0
- vllm/model_executor/models/gpt_neox.py +295 -0
- vllm/model_executor/models/internlm2.py +323 -0
- vllm/model_executor/models/jais.py +333 -0
- vllm/model_executor/models/llama.py +442 -0
- vllm/model_executor/models/llava.py +239 -0
- vllm/model_executor/models/minicpm.py +531 -0
- vllm/model_executor/models/mixtral.py +583 -0
- vllm/model_executor/models/mixtral_quant.py +404 -0
- vllm/model_executor/models/mpt.py +295 -0
- vllm/model_executor/models/olmo.py +356 -0
- vllm/model_executor/models/opt.py +349 -0
- vllm/model_executor/models/orion.py +319 -0
- vllm/model_executor/models/phi.py +300 -0
- vllm/model_executor/models/qwen.py +284 -0
- vllm/model_executor/models/qwen2.py +367 -0
- vllm/model_executor/models/qwen2_moe.py +447 -0
- vllm/model_executor/models/stablelm.py +301 -0
- vllm/model_executor/models/starcoder2.py +302 -0
- vllm/model_executor/models/xverse.py +366 -0
- vllm/model_executor/sampling_metadata.py +588 -0
- vllm/model_executor/utils.py +35 -0
- vllm/outputs.py +150 -0
- vllm/py.typed +2 -0
- vllm/sampling_params.py +340 -0
- vllm/sequence.py +766 -0
- vllm/spec_decode/__init__.py +0 -0
- vllm/spec_decode/batch_expansion.py +397 -0
- vllm/spec_decode/interfaces.py +73 -0
- vllm/spec_decode/metrics.py +191 -0
- vllm/spec_decode/multi_step_worker.py +203 -0
- vllm/spec_decode/ngram_worker.py +176 -0
- vllm/spec_decode/spec_decode_worker.py +472 -0
- vllm/spec_decode/top1_proposer.py +200 -0
- vllm/spec_decode/util.py +228 -0
- vllm/test_utils.py +41 -0
- vllm/transformers_utils/__init__.py +0 -0
- vllm/transformers_utils/config.py +58 -0
- vllm/transformers_utils/configs/__init__.py +16 -0
- vllm/transformers_utils/configs/chatglm.py +68 -0
- vllm/transformers_utils/configs/dbrx.py +278 -0
- vllm/transformers_utils/configs/falcon.py +87 -0
- vllm/transformers_utils/configs/jais.py +236 -0
- vllm/transformers_utils/configs/mpt.py +178 -0
- vllm/transformers_utils/detokenizer.py +313 -0
- vllm/transformers_utils/tokenizer.py +149 -0
- vllm/transformers_utils/tokenizer_group/__init__.py +33 -0
- vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py +55 -0
- vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py +169 -0
- vllm/transformers_utils/tokenizer_group/tokenizer_group.py +78 -0
- vllm/transformers_utils/tokenizers/__init__.py +5 -0
- vllm/transformers_utils/tokenizers/baichuan.py +255 -0
- vllm/usage/__init__.py +0 -0
- vllm/usage/usage_lib.py +209 -0
- vllm/utils.py +677 -0
- vllm/worker/__init__.py +0 -0
- vllm/worker/cache_engine.py +105 -0
- vllm/worker/cpu_model_runner.py +346 -0
- vllm/worker/cpu_worker.py +321 -0
- vllm/worker/model_runner.py +1168 -0
- vllm/worker/neuron_model_runner.py +196 -0
- vllm/worker/neuron_worker.py +98 -0
- vllm/worker/worker.py +345 -0
- vllm/worker/worker_base.py +146 -0
- vllm_npu-0.4.2.dist-info/LICENSE +201 -0
- vllm_npu-0.4.2.dist-info/METADATA +173 -0
- vllm_npu-0.4.2.dist-info/RECORD +219 -0
- vllm_npu-0.4.2.dist-info/WHEEL +5 -0
- vllm_npu-0.4.2.dist-info/top_level.txt +1 -0
vllm/outputs.py
ADDED
@@ -0,0 +1,150 @@
import time
from typing import List, Optional, Union

from vllm.lora.request import LoRARequest
from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
                           SequenceGroup, SequenceStatus)


class CompletionOutput:
    """The output data of one completion output of a request.

    Args:
        index: The index of the output in the request.
        text: The generated output text.
        token_ids: The token IDs of the generated output text.
        cumulative_logprob: The cumulative log probability of the generated
            output text.
        logprobs: The log probabilities of the top probability words at each
            position if the logprobs are requested.
        finish_reason: The reason why the sequence is finished.
        stop_reason: The stop string or token id that caused the completion
            to stop, None if the completion finished for some other reason
            including encountering the EOS token.
        lora_request: The LoRA request that was used to generate the output.
    """

    def __init__(
        self,
        index: int,
        text: str,
        token_ids: List[int],
        cumulative_logprob: float,
        logprobs: Optional[SampleLogprobs],
        finish_reason: Optional[str] = None,
        stop_reason: Union[int, str, None] = None,
        lora_request: Optional[LoRARequest] = None,
    ) -> None:
        self.index = index
        self.text = text
        self.token_ids = token_ids
        self.cumulative_logprob = cumulative_logprob
        self.logprobs = logprobs
        self.finish_reason = finish_reason
        self.stop_reason = stop_reason
        self.lora_request = lora_request

    def finished(self) -> bool:
        return self.finish_reason is not None

    def __repr__(self) -> str:
        return (f"CompletionOutput(index={self.index}, "
                f"text={self.text!r}, "
                f"token_ids={self.token_ids}, "
                f"cumulative_logprob={self.cumulative_logprob}, "
                f"logprobs={self.logprobs}, "
                f"finish_reason={self.finish_reason}, "
                f"stop_reason={self.stop_reason})")


class RequestOutput:
    """The output data of a request to the LLM.

    Args:
        request_id: The unique ID of the request.
        prompt: The prompt string of the request.
        prompt_token_ids: The token IDs of the prompt.
        prompt_logprobs: The log probabilities to return per prompt token.
        outputs: The output sequences of the request.
        finished: Whether the whole request is finished.
        metrics: Metrics associated with the request.
        lora_request: The LoRA request that was used to generate the output.
    """

    def __init__(
        self,
        request_id: str,
        prompt: str,
        prompt_token_ids: List[int],
        prompt_logprobs: Optional[PromptLogprobs],
        outputs: List[CompletionOutput],
        finished: bool,
        metrics: Optional[RequestMetrics] = None,
        lora_request: Optional[LoRARequest] = None,
    ) -> None:
        self.request_id = request_id
        self.prompt = prompt
        self.prompt_token_ids = prompt_token_ids
        self.prompt_logprobs = prompt_logprobs
        self.outputs = outputs
        self.finished = finished
        self.metrics = metrics
        self.lora_request = lora_request

    @classmethod
    def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
        seqs = seq_group.get_seqs()
        if len(seqs) == 1:
            top_n_seqs = seqs
        else:
            # Get the top-n sequences.
            n = seq_group.sampling_params.n
            if seq_group.sampling_params.use_beam_search:
                sorting_key = lambda seq: seq.get_beam_search_score(
                    seq_group.sampling_params.length_penalty)
            else:
                sorting_key = lambda seq: seq.get_cumulative_logprob()
            sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
            top_n_seqs = sorted_seqs[:n]

        # Create the outputs.
        # NOTE: We need to omit logprobs here explicitly because the sequence
        # always has the logprobs of the sampled tokens even if the
        # logprobs are not requested.
        include_logprobs = seq_group.sampling_params.logprobs is not None
        text_buffer_length = seq_group.sampling_params.output_text_buffer_length
        outputs = [
            CompletionOutput(seqs.index(seq),
                             seq.get_output_text_to_return(text_buffer_length),
                             seq.get_output_token_ids(),
                             seq.get_cumulative_logprob(),
                             seq.output_logprobs if include_logprobs else None,
                             SequenceStatus.get_finished_reason(seq.status),
                             seq.stop_reason) for seq in top_n_seqs
        ]

        # Every sequence in the sequence group should have the same prompt.
        prompt = seq_group.prompt
        prompt_token_ids = seq_group.prompt_token_ids
        prompt_logprobs = seq_group.prompt_logprobs
        finished = seq_group.is_finished()
        finished_time = time.time() if finished else None
        seq_group.set_finished_time(finished_time)
        return cls(seq_group.request_id,
                   prompt,
                   prompt_token_ids,
                   prompt_logprobs,
                   outputs,
                   finished,
                   seq_group.metrics,
                   lora_request=seq_group.lora_request)

    def __repr__(self) -> str:
        return (f"RequestOutput(request_id={self.request_id}, "
                f"prompt={self.prompt!r}, "
                f"prompt_token_ids={self.prompt_token_ids}, "
                f"prompt_logprobs={self.prompt_logprobs}, "
                f"outputs={self.outputs}, "
                f"finished={self.finished}, "
                f"metrics={self.metrics}, "
                f"lora_request={self.lora_request})")
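For orientation, here is a minimal usage sketch showing how RequestOutput and CompletionOutput are typically consumed through the offline LLM entrypoint (vllm/entrypoints/llm.py in the listing above). The model name and prompt are illustrative placeholders, not part of this package's contents.

from vllm import LLM, SamplingParams

# Hypothetical example: any HF-compatible model id could be used here.
llm = LLM(model="facebook/opt-125m")
params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=32)

# generate() returns one RequestOutput per prompt.
for request_output in llm.generate(["Hello, my name is"], params):
    print(request_output.request_id, request_output.finished)
    # Each RequestOutput holds up to n CompletionOutput candidates.
    for completion in request_output.outputs:
        print(completion.index, repr(completion.text), completion.finish_reason)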
vllm/py.typed
ADDED
vllm/sampling_params.py
ADDED
@@ -0,0 +1,340 @@
"""Sampling parameters for text generation."""
import copy
from enum import IntEnum
from functools import cached_property
from typing import Any, Callable, Dict, List, Optional, Union

import torch
from pydantic import Field
from typing_extensions import Annotated

_SAMPLING_EPS = 1e-5


class SamplingType(IntEnum):
    GREEDY = 0
    RANDOM = 1
    RANDOM_SEED = 2
    BEAM = 3


LogitsProcessor = Callable[[List[int], torch.Tensor], torch.Tensor]
"""LogitsProcessor is a function that takes a list of previously generated
tokens and a tensor of the logits for the next token, and returns a modified
tensor of logits to sample from."""


class SamplingParams:
    """Sampling parameters for text generation.

    Overall, we follow the sampling parameters from the OpenAI text completion
    API (https://platform.openai.com/docs/api-reference/completions/create).
    In addition, we support beam search, which is not supported by OpenAI.

    Args:
        n: Number of output sequences to return for the given prompt.
        best_of: Number of output sequences that are generated from the prompt.
            From these `best_of` sequences, the top `n` sequences are returned.
            `best_of` must be greater than or equal to `n`. This is treated as
            the beam width when `use_beam_search` is True. By default, `best_of`
            is set to `n`.
        presence_penalty: Float that penalizes new tokens based on whether they
            appear in the generated text so far. Values > 0 encourage the model
            to use new tokens, while values < 0 encourage the model to repeat
            tokens.
        frequency_penalty: Float that penalizes new tokens based on their
            frequency in the generated text so far. Values > 0 encourage the
            model to use new tokens, while values < 0 encourage the model to
            repeat tokens.
        repetition_penalty: Float that penalizes new tokens based on whether
            they appear in the prompt and the generated text so far. Values > 1
            encourage the model to use new tokens, while values < 1 encourage
            the model to repeat tokens.
        temperature: Float that controls the randomness of the sampling. Lower
            values make the model more deterministic, while higher values make
            the model more random. Zero means greedy sampling.
        top_p: Float that controls the cumulative probability of the top tokens
            to consider. Must be in (0, 1]. Set to 1 to consider all tokens.
        top_k: Integer that controls the number of top tokens to consider. Set
            to -1 to consider all tokens.
        min_p: Float that represents the minimum probability for a token to be
            considered, relative to the probability of the most likely token.
            Must be in [0, 1]. Set to 0 to disable this.
        seed: Random seed to use for the generation.
        use_beam_search: Whether to use beam search instead of sampling.
        length_penalty: Float that penalizes sequences based on their length.
            Used in beam search.
        early_stopping: Controls the stopping condition for beam search. It
            accepts the following values: `True`, where the generation stops as
            soon as there are `best_of` complete candidates; `False`, where a
            heuristic is applied and the generation stops when it is very
            unlikely to find better candidates; `"never"`, where the beam search
            procedure only stops when there cannot be better candidates
            (canonical beam search algorithm).
        stop: List of strings that stop the generation when they are generated.
            The returned output will not contain the stop strings.
        stop_token_ids: List of tokens that stop the generation when they are
            generated. The returned output will contain the stop tokens unless
            the stop tokens are special tokens.
        include_stop_str_in_output: Whether to include the stop strings in
            output text. Defaults to False.
        ignore_eos: Whether to ignore the EOS token and continue generating
            tokens after the EOS token is generated.
        max_tokens: Maximum number of tokens to generate per output sequence.
        min_tokens: Minimum number of tokens to generate per output sequence
            before EOS or stop_token_ids can be generated.
        logprobs: Number of log probabilities to return per output token.
            Note that the implementation follows the OpenAI API: The return
            result includes the log probabilities on the `logprobs` most likely
            tokens, as well as the chosen tokens. The API will always return the
            log probability of the sampled token, so there may be up to
            `logprobs+1` elements in the response.
        prompt_logprobs: Number of log probabilities to return per prompt token.
        detokenize: Whether to detokenize the output. Defaults to True.
        skip_special_tokens: Whether to skip special tokens in the output.
        spaces_between_special_tokens: Whether to add spaces between special
            tokens in the output. Defaults to True.
        logits_processors: List of functions that modify logits based on
            previously generated tokens.
        truncate_prompt_tokens: If set to an integer k, will use only the last k
            tokens from the prompt (i.e., left truncation). Defaults to None
            (i.e., no truncation).
    """

    def __init__(
        self,
        n: int = 1,
        best_of: Optional[int] = None,
        presence_penalty: float = 0.0,
        frequency_penalty: float = 0.0,
        repetition_penalty: float = 1.0,
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
        min_p: float = 0.0,
        seed: Optional[int] = None,
        use_beam_search: bool = False,
        length_penalty: float = 1.0,
        early_stopping: Union[bool, str] = False,
        stop: Optional[Union[str, List[str]]] = None,
        stop_token_ids: Optional[List[int]] = None,
        include_stop_str_in_output: bool = False,
        ignore_eos: bool = False,
        max_tokens: Optional[int] = 16,
        min_tokens: int = 0,
        logprobs: Optional[int] = None,
        prompt_logprobs: Optional[int] = None,
        detokenize: bool = True,
        skip_special_tokens: bool = True,
        spaces_between_special_tokens: bool = True,
        logits_processors: Optional[List[LogitsProcessor]] = None,
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
    ) -> None:
        self.n = n
        self.best_of = best_of if best_of is not None else n
        self.presence_penalty = presence_penalty
        self.frequency_penalty = frequency_penalty
        self.repetition_penalty = repetition_penalty
        self.temperature = temperature
        self.top_p = top_p
        self.top_k = top_k
        self.min_p = min_p
        if seed == -1:
            self.seed = None
        else:
            self.seed = seed
        self.use_beam_search = use_beam_search
        self.length_penalty = length_penalty
        self.early_stopping = early_stopping
        if stop is None:
            self.stop = []
        elif isinstance(stop, str):
            self.stop = [stop]
        else:
            self.stop = list(stop)
        if stop_token_ids is None:
            self.stop_token_ids = []
        else:
            self.stop_token_ids = list(stop_token_ids)
        self.ignore_eos = ignore_eos
        self.max_tokens = max_tokens
        self.min_tokens = min_tokens
        self.logprobs = logprobs
        self.prompt_logprobs = prompt_logprobs
        # NOTE: This parameter is only exposed at the engine level for now.
        # It is not exposed in the OpenAI API server, as the OpenAI API does
        # not support returning only a list of token IDs.
        self.detokenize = detokenize
        self.skip_special_tokens = skip_special_tokens
        self.spaces_between_special_tokens = spaces_between_special_tokens
        self.logits_processors = logits_processors
        self.include_stop_str_in_output = include_stop_str_in_output
        self.truncate_prompt_tokens = truncate_prompt_tokens
        # Number of characters to hold back for stop string evaluation
        # until sequence is finished.
        if self.stop and not include_stop_str_in_output:
            self.output_text_buffer_length = max(len(s) for s in self.stop) - 1
        else:
            self.output_text_buffer_length = 0

        self._verify_args()
        if self.use_beam_search:
            self._verify_beam_search()
        else:
            self._verify_non_beam_search()
            if self.temperature < _SAMPLING_EPS:
                # Zero temperature means greedy sampling.
                self.top_p = 1.0
                self.top_k = -1
                self.min_p = 0.0
                self._verify_greedy_sampling()
        # eos_token_id is added to this by the engine
        self.all_stop_token_ids = set(self.stop_token_ids)

    def _verify_args(self) -> None:
        if self.n < 1:
            raise ValueError(f"n must be at least 1, got {self.n}.")
        if self.best_of < self.n:
            raise ValueError(f"best_of must be greater than or equal to n, "
                             f"got n={self.n} and best_of={self.best_of}.")
        if not -2.0 <= self.presence_penalty <= 2.0:
            raise ValueError("presence_penalty must be in [-2, 2], got "
                             f"{self.presence_penalty}.")
        if not -2.0 <= self.frequency_penalty <= 2.0:
            raise ValueError("frequency_penalty must be in [-2, 2], got "
                             f"{self.frequency_penalty}.")
        if not 0.0 < self.repetition_penalty <= 2.0:
            raise ValueError("repetition_penalty must be in (0, 2], got "
                             f"{self.repetition_penalty}.")
        if self.temperature < 0.0:
            raise ValueError(
                f"temperature must be non-negative, got {self.temperature}.")
        if not 0.0 < self.top_p <= 1.0:
            raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.")
        if self.top_k < -1 or self.top_k == 0:
            raise ValueError(f"top_k must be -1 (disable), or at least 1, "
                             f"got {self.top_k}.")
        if not 0.0 <= self.min_p <= 1.0:
            raise ValueError("min_p must be in [0, 1], got "
                             f"{self.min_p}.")
        if self.max_tokens is not None and self.max_tokens < 1:
            raise ValueError(
                f"max_tokens must be at least 1, got {self.max_tokens}.")
        if self.min_tokens < 0:
            raise ValueError(f"min_tokens must be greater than or equal to 0, "
                             f"got {self.min_tokens}.")
        if self.max_tokens is not None and self.min_tokens > self.max_tokens:
            raise ValueError(
                f"min_tokens must be less than or equal to "
                f"max_tokens={self.max_tokens}, got {self.min_tokens}.")
        if self.logprobs is not None and self.logprobs < 0:
            raise ValueError(
                f"logprobs must be non-negative, got {self.logprobs}.")
        if self.prompt_logprobs is not None and self.prompt_logprobs < 0:
            raise ValueError(f"prompt_logprobs must be non-negative, got "
                             f"{self.prompt_logprobs}.")
        if (self.truncate_prompt_tokens is not None
                and self.truncate_prompt_tokens < 1):
            raise ValueError(f"truncate_prompt_tokens must be >= 1, "
                             f"got {self.truncate_prompt_tokens}")
        if any(not stop_str for stop_str in self.stop):
            raise ValueError("stop cannot contain an empty string.")
        if self.stop and not self.detokenize:
            raise ValueError(
                "stop strings are only supported when detokenize is True. "
                "Set detokenize=True to use stop.")

    def _verify_beam_search(self) -> None:
        if self.best_of == 1:
            raise ValueError("best_of must be greater than 1 when using beam "
                             f"search. Got {self.best_of}.")
        if self.temperature > _SAMPLING_EPS:
            raise ValueError("temperature must be 0 when using beam search.")
        if self.top_p < 1.0 - _SAMPLING_EPS:
            raise ValueError("top_p must be 1 when using beam search.")
        if self.top_k != -1:
            raise ValueError("top_k must be -1 when using beam search.")
        if self.early_stopping not in [True, False, "never"]:
            raise ValueError(
                f"early_stopping must be True, False, or 'never', "
                f"got {self.early_stopping}.")

    def _verify_non_beam_search(self) -> None:
        if self.early_stopping is not False:
            raise ValueError("early_stopping is not effective and must be "
                             "False when not using beam search.")
        if (self.length_penalty < 1.0 - _SAMPLING_EPS
                or self.length_penalty > 1.0 + _SAMPLING_EPS):
            raise ValueError(
                "length_penalty is not effective and must be the "
                "default value of 1.0 when not using beam search.")

    def _verify_greedy_sampling(self) -> None:
        if self.best_of > 1:
            raise ValueError("best_of must be 1 when using greedy sampling. "
                             f"Got {self.best_of}.")

    def update_from_generation_config(
            self, generation_config: Dict[str, Any]) -> None:
        """Update if there are non-default values from generation_config."""
        # Update eos_token_id for generation
        if (not self.ignore_eos) and (eos_ids :=
                                      generation_config.get("eos_token_id")):
            # it can be either int or list of int
            if isinstance(eos_ids, int):
                eos_ids = [eos_ids]
            original_stop_token_ids = set(self.stop_token_ids)
            original_stop_token_ids.update(eos_ids)
            self.stop_token_ids = list(original_stop_token_ids)

    @cached_property
    def sampling_type(self) -> SamplingType:
        if self.use_beam_search:
            return SamplingType.BEAM
        if self.temperature < _SAMPLING_EPS:
            return SamplingType.GREEDY
        if self.seed is not None:
            return SamplingType.RANDOM_SEED
        return SamplingType.RANDOM

    def clone(self) -> "SamplingParams":
        """Deep copy excluding LogitsProcessor objects.

        LogitsProcessor objects are excluded because they may contain an
        arbitrary, nontrivial amount of data.
        See https://github.com/vllm-project/vllm/issues/3087
        """

        logit_processor_refs = None if self.logits_processors is None else {
            id(lp): lp
            for lp in self.logits_processors
        }
        return copy.deepcopy(self, memo=logit_processor_refs)

    def __repr__(self) -> str:
        return (
            f"SamplingParams(n={self.n}, "
            f"best_of={self.best_of}, "
            f"presence_penalty={self.presence_penalty}, "
            f"frequency_penalty={self.frequency_penalty}, "
            f"repetition_penalty={self.repetition_penalty}, "
            f"temperature={self.temperature}, "
            f"top_p={self.top_p}, "
            f"top_k={self.top_k}, "
            f"min_p={self.min_p}, "
            f"seed={self.seed}, "
            f"use_beam_search={self.use_beam_search}, "
            f"length_penalty={self.length_penalty}, "
            f"early_stopping={self.early_stopping}, "
            f"stop={self.stop}, "
            f"stop_token_ids={self.stop_token_ids}, "
            f"include_stop_str_in_output={self.include_stop_str_in_output}, "
            f"ignore_eos={self.ignore_eos}, "
            f"max_tokens={self.max_tokens}, "
            f"min_tokens={self.min_tokens}, "
            f"logprobs={self.logprobs}, "
            f"prompt_logprobs={self.prompt_logprobs}, "
            f"skip_special_tokens={self.skip_special_tokens}, "
            "spaces_between_special_tokens="
            f"{self.spaces_between_special_tokens}, "
            f"truncate_prompt_tokens={self.truncate_prompt_tokens})")
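To make the validation and post-processing rules above concrete, here is a short usage sketch against the SamplingParams class as defined in this file; the concrete values are illustrative only.

from vllm import SamplingParams

# temperature=0 takes the greedy path: top_p, top_k and min_p are reset
# internally and _verify_greedy_sampling() requires best_of == 1.
greedy = SamplingParams(temperature=0.0, max_tokens=64)

# Nucleus sampling with a stop string. Because include_stop_str_in_output
# defaults to False, output_text_buffer_length holds back len("\n\n") - 1
# characters until the sequence finishes.
sampled = SamplingParams(
    n=2,
    best_of=4,
    temperature=0.7,
    top_p=0.9,
    stop=["\n\n"],
    max_tokens=128,
)

# clone() deep-copies the parameters but keeps any logits_processors
# shared by reference (see the id()-keyed memo in clone()).
params_copy = sampled.clone()

# An eos_token_id taken from a model's generation_config is merged into
# stop_token_ids unless ignore_eos is set.
sampled.update_from_generation_config({"eos_token_id": 2})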