vllm_npu-0.4.2-py3-none-any.whl

Files changed (219)
  1. vllm/__init__.py +23 -0
  2. vllm/_custom_ops.py +251 -0
  3. vllm/attention/__init__.py +13 -0
  4. vllm/attention/backends/__init__.py +0 -0
  5. vllm/attention/backends/abstract.py +127 -0
  6. vllm/attention/backends/flash_attn.py +271 -0
  7. vllm/attention/backends/flashinfer.py +220 -0
  8. vllm/attention/backends/rocm_flash_attn.py +374 -0
  9. vllm/attention/backends/torch_sdpa.py +250 -0
  10. vllm/attention/backends/xformers.py +393 -0
  11. vllm/attention/layer.py +56 -0
  12. vllm/attention/ops/__init__.py +0 -0
  13. vllm/attention/ops/paged_attn.py +216 -0
  14. vllm/attention/ops/prefix_prefill.py +792 -0
  15. vllm/attention/ops/triton_flash_attention.py +810 -0
  16. vllm/attention/selector.py +91 -0
  17. vllm/block.py +84 -0
  18. vllm/config.py +1225 -0
  19. vllm/core/__init__.py +0 -0
  20. vllm/core/block/__init__.py +0 -0
  21. vllm/core/block/block_table.py +295 -0
  22. vllm/core/block/common.py +199 -0
  23. vllm/core/block/cpu_gpu_block_allocator.py +228 -0
  24. vllm/core/block/interfaces.py +205 -0
  25. vllm/core/block/naive_block.py +318 -0
  26. vllm/core/block/prefix_caching_block.py +606 -0
  27. vllm/core/block_manager_v1.py +625 -0
  28. vllm/core/block_manager_v2.py +258 -0
  29. vllm/core/evictor_v1.py +105 -0
  30. vllm/core/evictor_v2.py +127 -0
  31. vllm/core/interfaces.py +113 -0
  32. vllm/core/policy.py +45 -0
  33. vllm/core/scheduler.py +1163 -0
  34. vllm/distributed/__init__.py +3 -0
  35. vllm/distributed/communication_op.py +237 -0
  36. vllm/distributed/device_communicators/__init__.py +0 -0
  37. vllm/distributed/device_communicators/custom_all_reduce.py +274 -0
  38. vllm/distributed/device_communicators/pynccl.py +287 -0
  39. vllm/distributed/device_communicators/pynccl_utils.py +66 -0
  40. vllm/distributed/parallel_state.py +339 -0
  41. vllm/distributed/utils.py +136 -0
  42. vllm/engine/__init__.py +0 -0
  43. vllm/engine/arg_utils.py +649 -0
  44. vllm/engine/async_llm_engine.py +737 -0
  45. vllm/engine/llm_engine.py +784 -0
  46. vllm/engine/metrics.py +368 -0
  47. vllm/engine/output_processor/__init__.py +0 -0
  48. vllm/engine/output_processor/interfaces.py +76 -0
  49. vllm/engine/output_processor/multi_step.py +142 -0
  50. vllm/engine/output_processor/single_step.py +284 -0
  51. vllm/engine/output_processor/stop_checker.py +101 -0
  52. vllm/engine/output_processor/util.py +19 -0
  53. vllm/entrypoints/__init__.py +0 -0
  54. vllm/entrypoints/api_server.py +119 -0
  55. vllm/entrypoints/llm.py +259 -0
  56. vllm/entrypoints/openai/__init__.py +0 -0
  57. vllm/entrypoints/openai/api_server.py +186 -0
  58. vllm/entrypoints/openai/cli_args.py +115 -0
  59. vllm/entrypoints/openai/protocol.py +460 -0
  60. vllm/entrypoints/openai/serving_chat.py +392 -0
  61. vllm/entrypoints/openai/serving_completion.py +347 -0
  62. vllm/entrypoints/openai/serving_engine.py +234 -0
  63. vllm/envs.py +217 -0
  64. vllm/executor/__init__.py +0 -0
  65. vllm/executor/cpu_executor.py +152 -0
  66. vllm/executor/distributed_gpu_executor.py +115 -0
  67. vllm/executor/executor_base.py +115 -0
  68. vllm/executor/gpu_executor.py +150 -0
  69. vllm/executor/multiproc_worker_utils.py +263 -0
  70. vllm/executor/neuron_executor.py +91 -0
  71. vllm/executor/ray_gpu_executor.py +327 -0
  72. vllm/executor/ray_utils.py +119 -0
  73. vllm/logger.py +153 -0
  74. vllm/logging/__init__.py +5 -0
  75. vllm/logging/formatter.py +15 -0
  76. vllm/lora/__init__.py +0 -0
  77. vllm/lora/fully_sharded_layers.py +262 -0
  78. vllm/lora/layers.py +1181 -0
  79. vllm/lora/lora.py +167 -0
  80. vllm/lora/models.py +645 -0
  81. vllm/lora/punica.py +213 -0
  82. vllm/lora/request.py +32 -0
  83. vllm/lora/utils.py +98 -0
  84. vllm/lora/worker_manager.py +251 -0
  85. vllm/model_executor/__init__.py +7 -0
  86. vllm/model_executor/guided_decoding/__init__.py +25 -0
  87. vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py +70 -0
  88. vllm/model_executor/guided_decoding/outlines_decoding.py +130 -0
  89. vllm/model_executor/guided_decoding/outlines_logits_processors.py +184 -0
  90. vllm/model_executor/layers/__init__.py +0 -0
  91. vllm/model_executor/layers/activation.py +173 -0
  92. vllm/model_executor/layers/fused_moe/__init__.py +7 -0
  93. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  94. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  95. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  96. vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  97. vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  98. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  99. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  100. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  101. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  102. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  103. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  104. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  105. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +140 -0
  106. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  107. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  108. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  109. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  110. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +146 -0
  111. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  112. vllm/model_executor/layers/fused_moe/fused_moe.py +479 -0
  113. vllm/model_executor/layers/layernorm.py +71 -0
  114. vllm/model_executor/layers/linear.py +709 -0
  115. vllm/model_executor/layers/logits_processor.py +115 -0
  116. vllm/model_executor/layers/ops/__init__.py +0 -0
  117. vllm/model_executor/layers/ops/rand.py +157 -0
  118. vllm/model_executor/layers/ops/sample.py +406 -0
  119. vllm/model_executor/layers/quantization/__init__.py +35 -0
  120. vllm/model_executor/layers/quantization/aqlm.py +376 -0
  121. vllm/model_executor/layers/quantization/awq.py +175 -0
  122. vllm/model_executor/layers/quantization/base_config.py +97 -0
  123. vllm/model_executor/layers/quantization/fp8.py +265 -0
  124. vllm/model_executor/layers/quantization/gptq.py +224 -0
  125. vllm/model_executor/layers/quantization/gptq_marlin.py +438 -0
  126. vllm/model_executor/layers/quantization/marlin.py +227 -0
  127. vllm/model_executor/layers/quantization/schema.py +84 -0
  128. vllm/model_executor/layers/quantization/squeezellm.py +137 -0
  129. vllm/model_executor/layers/rejection_sampler.py +405 -0
  130. vllm/model_executor/layers/rotary_embedding.py +525 -0
  131. vllm/model_executor/layers/sampler.py +1051 -0
  132. vllm/model_executor/layers/vocab_parallel_embedding.py +155 -0
  133. vllm/model_executor/model_loader/__init__.py +30 -0
  134. vllm/model_executor/model_loader/loader.py +362 -0
  135. vllm/model_executor/model_loader/neuron.py +136 -0
  136. vllm/model_executor/model_loader/tensorizer.py +368 -0
  137. vllm/model_executor/model_loader/utils.py +41 -0
  138. vllm/model_executor/model_loader/weight_utils.py +372 -0
  139. vllm/model_executor/models/__init__.py +119 -0
  140. vllm/model_executor/models/baichuan.py +410 -0
  141. vllm/model_executor/models/bloom.py +327 -0
  142. vllm/model_executor/models/chatglm.py +386 -0
  143. vllm/model_executor/models/commandr.py +373 -0
  144. vllm/model_executor/models/dbrx.py +413 -0
  145. vllm/model_executor/models/decilm.py +122 -0
  146. vllm/model_executor/models/deepseek.py +438 -0
  147. vllm/model_executor/models/falcon.py +444 -0
  148. vllm/model_executor/models/gemma.py +393 -0
  149. vllm/model_executor/models/gpt2.py +266 -0
  150. vllm/model_executor/models/gpt_bigcode.py +274 -0
  151. vllm/model_executor/models/gpt_j.py +281 -0
  152. vllm/model_executor/models/gpt_neox.py +295 -0
  153. vllm/model_executor/models/internlm2.py +323 -0
  154. vllm/model_executor/models/jais.py +333 -0
  155. vllm/model_executor/models/llama.py +442 -0
  156. vllm/model_executor/models/llava.py +239 -0
  157. vllm/model_executor/models/minicpm.py +531 -0
  158. vllm/model_executor/models/mixtral.py +583 -0
  159. vllm/model_executor/models/mixtral_quant.py +404 -0
  160. vllm/model_executor/models/mpt.py +295 -0
  161. vllm/model_executor/models/olmo.py +356 -0
  162. vllm/model_executor/models/opt.py +349 -0
  163. vllm/model_executor/models/orion.py +319 -0
  164. vllm/model_executor/models/phi.py +300 -0
  165. vllm/model_executor/models/qwen.py +284 -0
  166. vllm/model_executor/models/qwen2.py +367 -0
  167. vllm/model_executor/models/qwen2_moe.py +447 -0
  168. vllm/model_executor/models/stablelm.py +301 -0
  169. vllm/model_executor/models/starcoder2.py +302 -0
  170. vllm/model_executor/models/xverse.py +366 -0
  171. vllm/model_executor/sampling_metadata.py +588 -0
  172. vllm/model_executor/utils.py +35 -0
  173. vllm/outputs.py +150 -0
  174. vllm/py.typed +2 -0
  175. vllm/sampling_params.py +340 -0
  176. vllm/sequence.py +766 -0
  177. vllm/spec_decode/__init__.py +0 -0
  178. vllm/spec_decode/batch_expansion.py +397 -0
  179. vllm/spec_decode/interfaces.py +73 -0
  180. vllm/spec_decode/metrics.py +191 -0
  181. vllm/spec_decode/multi_step_worker.py +203 -0
  182. vllm/spec_decode/ngram_worker.py +176 -0
  183. vllm/spec_decode/spec_decode_worker.py +472 -0
  184. vllm/spec_decode/top1_proposer.py +200 -0
  185. vllm/spec_decode/util.py +228 -0
  186. vllm/test_utils.py +41 -0
  187. vllm/transformers_utils/__init__.py +0 -0
  188. vllm/transformers_utils/config.py +58 -0
  189. vllm/transformers_utils/configs/__init__.py +16 -0
  190. vllm/transformers_utils/configs/chatglm.py +68 -0
  191. vllm/transformers_utils/configs/dbrx.py +278 -0
  192. vllm/transformers_utils/configs/falcon.py +87 -0
  193. vllm/transformers_utils/configs/jais.py +236 -0
  194. vllm/transformers_utils/configs/mpt.py +178 -0
  195. vllm/transformers_utils/detokenizer.py +313 -0
  196. vllm/transformers_utils/tokenizer.py +149 -0
  197. vllm/transformers_utils/tokenizer_group/__init__.py +33 -0
  198. vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py +55 -0
  199. vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py +169 -0
  200. vllm/transformers_utils/tokenizer_group/tokenizer_group.py +78 -0
  201. vllm/transformers_utils/tokenizers/__init__.py +5 -0
  202. vllm/transformers_utils/tokenizers/baichuan.py +255 -0
  203. vllm/usage/__init__.py +0 -0
  204. vllm/usage/usage_lib.py +209 -0
  205. vllm/utils.py +677 -0
  206. vllm/worker/__init__.py +0 -0
  207. vllm/worker/cache_engine.py +105 -0
  208. vllm/worker/cpu_model_runner.py +346 -0
  209. vllm/worker/cpu_worker.py +321 -0
  210. vllm/worker/model_runner.py +1168 -0
  211. vllm/worker/neuron_model_runner.py +196 -0
  212. vllm/worker/neuron_worker.py +98 -0
  213. vllm/worker/worker.py +345 -0
  214. vllm/worker/worker_base.py +146 -0
  215. vllm_npu-0.4.2.dist-info/LICENSE +201 -0
  216. vllm_npu-0.4.2.dist-info/METADATA +173 -0
  217. vllm_npu-0.4.2.dist-info/RECORD +219 -0
  218. vllm_npu-0.4.2.dist-info/WHEEL +5 -0
  219. vllm_npu-0.4.2.dist-info/top_level.txt +1 -0
vllm/outputs.py ADDED
@@ -0,0 +1,150 @@
+ import time
+ from typing import List, Optional, Union
+
+ from vllm.lora.request import LoRARequest
+ from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
+                            SequenceGroup, SequenceStatus)
+
+
+ class CompletionOutput:
+     """The output data of one completion output of a request.
+
+     Args:
+         index: The index of the output in the request.
+         text: The generated output text.
+         token_ids: The token IDs of the generated output text.
+         cumulative_logprob: The cumulative log probability of the generated
+             output text.
+         logprobs: The log probabilities of the top probability words at each
+             position if the logprobs are requested.
+         finish_reason: The reason why the sequence is finished.
+         stop_reason: The stop string or token id that caused the completion
+             to stop, None if the completion finished for some other reason
+             including encountering the EOS token.
+         lora_request: The LoRA request that was used to generate the output.
+     """
+
+     def __init__(
+         self,
+         index: int,
+         text: str,
+         token_ids: List[int],
+         cumulative_logprob: float,
+         logprobs: Optional[SampleLogprobs],
+         finish_reason: Optional[str] = None,
+         stop_reason: Union[int, str, None] = None,
+         lora_request: Optional[LoRARequest] = None,
+     ) -> None:
+         self.index = index
+         self.text = text
+         self.token_ids = token_ids
+         self.cumulative_logprob = cumulative_logprob
+         self.logprobs = logprobs
+         self.finish_reason = finish_reason
+         self.stop_reason = stop_reason
+         self.lora_request = lora_request
+
+     def finished(self) -> bool:
+         return self.finish_reason is not None
+
+     def __repr__(self) -> str:
+         return (f"CompletionOutput(index={self.index}, "
+                 f"text={self.text!r}, "
+                 f"token_ids={self.token_ids}, "
+                 f"cumulative_logprob={self.cumulative_logprob}, "
+                 f"logprobs={self.logprobs}, "
+                 f"finish_reason={self.finish_reason}, "
+                 f"stop_reason={self.stop_reason})")
+
+
+ class RequestOutput:
+     """The output data of a request to the LLM.
+
+     Args:
+         request_id: The unique ID of the request.
+         prompt: The prompt string of the request.
+         prompt_token_ids: The token IDs of the prompt.
+         prompt_logprobs: The log probabilities to return per prompt token.
+         outputs: The output sequences of the request.
+         finished: Whether the whole request is finished.
+         metrics: Metrics associated with the request.
+         lora_request: The LoRA request that was used to generate the output.
+     """
+
+     def __init__(
+         self,
+         request_id: str,
+         prompt: str,
+         prompt_token_ids: List[int],
+         prompt_logprobs: Optional[PromptLogprobs],
+         outputs: List[CompletionOutput],
+         finished: bool,
+         metrics: Optional[RequestMetrics] = None,
+         lora_request: Optional[LoRARequest] = None,
+     ) -> None:
+         self.request_id = request_id
+         self.prompt = prompt
+         self.prompt_token_ids = prompt_token_ids
+         self.prompt_logprobs = prompt_logprobs
+         self.outputs = outputs
+         self.finished = finished
+         self.metrics = metrics
+         self.lora_request = lora_request
+
+     @classmethod
+     def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
+         seqs = seq_group.get_seqs()
+         if len(seqs) == 1:
+             top_n_seqs = seqs
+         else:
+             # Get the top-n sequences.
+             n = seq_group.sampling_params.n
+             if seq_group.sampling_params.use_beam_search:
+                 sorting_key = lambda seq: seq.get_beam_search_score(
+                     seq_group.sampling_params.length_penalty)
+             else:
+                 sorting_key = lambda seq: seq.get_cumulative_logprob()
+             sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
+             top_n_seqs = sorted_seqs[:n]
+
+         # Create the outputs.
+         # NOTE: We need to omit logprobs here explicitly because the sequence
+         # always has the logprobs of the sampled tokens even if the
+         # logprobs are not requested.
+         include_logprobs = seq_group.sampling_params.logprobs is not None
+         text_buffer_length = seq_group.sampling_params.output_text_buffer_length
+         outputs = [
+             CompletionOutput(seqs.index(seq),
+                              seq.get_output_text_to_return(text_buffer_length),
+                              seq.get_output_token_ids(),
+                              seq.get_cumulative_logprob(),
+                              seq.output_logprobs if include_logprobs else None,
+                              SequenceStatus.get_finished_reason(seq.status),
+                              seq.stop_reason) for seq in top_n_seqs
+         ]
+
+         # Every sequence in the sequence group should have the same prompt.
+         prompt = seq_group.prompt
+         prompt_token_ids = seq_group.prompt_token_ids
+         prompt_logprobs = seq_group.prompt_logprobs
+         finished = seq_group.is_finished()
+         finished_time = time.time() if finished else None
+         seq_group.set_finished_time(finished_time)
+         return cls(seq_group.request_id,
+                    prompt,
+                    prompt_token_ids,
+                    prompt_logprobs,
+                    outputs,
+                    finished,
+                    seq_group.metrics,
+                    lora_request=seq_group.lora_request)
+
+     def __repr__(self) -> str:
+         return (f"RequestOutput(request_id={self.request_id}, "
+                 f"prompt={self.prompt!r}, "
+                 f"prompt_token_ids={self.prompt_token_ids}, "
+                 f"prompt_logprobs={self.prompt_logprobs}, "
+                 f"outputs={self.outputs}, "
+                 f"finished={self.finished}, "
+                 f"metrics={self.metrics}, "
+                 f"lora_request={self.lora_request})")
vllm/py.typed ADDED
@@ -0,0 +1,2 @@
+ # Marker file for PEP 561.
+ # The vllm package uses inline types.
vllm/sampling_params.py ADDED
@@ -0,0 +1,340 @@
+ """Sampling parameters for text generation."""
+ import copy
+ from enum import IntEnum
+ from functools import cached_property
+ from typing import Any, Callable, Dict, List, Optional, Union
+
+ import torch
+ from pydantic import Field
+ from typing_extensions import Annotated
+
+ _SAMPLING_EPS = 1e-5
+
+
+ class SamplingType(IntEnum):
+     GREEDY = 0
+     RANDOM = 1
+     RANDOM_SEED = 2
+     BEAM = 3
+
+
+ LogitsProcessor = Callable[[List[int], torch.Tensor], torch.Tensor]
+ """LogitsProcessor is a function that takes a list of previously generated
+ tokens and a tensor of the logits for the next token, and returns a modified
+ tensor of logits to sample from."""
+
+
+ class SamplingParams:
+     """Sampling parameters for text generation.
+
+     Overall, we follow the sampling parameters from the OpenAI text completion
+     API (https://platform.openai.com/docs/api-reference/completions/create).
+     In addition, we support beam search, which is not supported by OpenAI.
+
+     Args:
+         n: Number of output sequences to return for the given prompt.
+         best_of: Number of output sequences that are generated from the prompt.
+             From these `best_of` sequences, the top `n` sequences are returned.
+             `best_of` must be greater than or equal to `n`. This is treated as
+             the beam width when `use_beam_search` is True. By default, `best_of`
+             is set to `n`.
+         presence_penalty: Float that penalizes new tokens based on whether they
+             appear in the generated text so far. Values > 0 encourage the model
+             to use new tokens, while values < 0 encourage the model to repeat
+             tokens.
+         frequency_penalty: Float that penalizes new tokens based on their
+             frequency in the generated text so far. Values > 0 encourage the
+             model to use new tokens, while values < 0 encourage the model to
+             repeat tokens.
+         repetition_penalty: Float that penalizes new tokens based on whether
+             they appear in the prompt and the generated text so far. Values > 1
+             encourage the model to use new tokens, while values < 1 encourage
+             the model to repeat tokens.
+         temperature: Float that controls the randomness of the sampling. Lower
+             values make the model more deterministic, while higher values make
+             the model more random. Zero means greedy sampling.
+         top_p: Float that controls the cumulative probability of the top tokens
+             to consider. Must be in (0, 1]. Set to 1 to consider all tokens.
+         top_k: Integer that controls the number of top tokens to consider. Set
+             to -1 to consider all tokens.
+         min_p: Float that represents the minimum probability for a token to be
+             considered, relative to the probability of the most likely token.
+             Must be in [0, 1]. Set to 0 to disable this.
+         seed: Random seed to use for the generation.
+         use_beam_search: Whether to use beam search instead of sampling.
+         length_penalty: Float that penalizes sequences based on their length.
+             Used in beam search.
+         early_stopping: Controls the stopping condition for beam search. It
+             accepts the following values: `True`, where the generation stops as
+             soon as there are `best_of` complete candidates; `False`, where a
+             heuristic is applied and the generation stops when it is very
+             unlikely to find better candidates; `"never"`, where the beam search
+             procedure only stops when there cannot be better candidates
+             (canonical beam search algorithm).
+         stop: List of strings that stop the generation when they are generated.
+             The returned output will not contain the stop strings.
+         stop_token_ids: List of tokens that stop the generation when they are
+             generated. The returned output will contain the stop tokens unless
+             the stop tokens are special tokens.
+         include_stop_str_in_output: Whether to include the stop strings in
+             output text. Defaults to False.
+         ignore_eos: Whether to ignore the EOS token and continue generating
+             tokens after the EOS token is generated.
+         max_tokens: Maximum number of tokens to generate per output sequence.
+         min_tokens: Minimum number of tokens to generate per output sequence
+             before EOS or stop_token_ids can be generated.
+         logprobs: Number of log probabilities to return per output token.
+             Note that the implementation follows the OpenAI API: The return
+             result includes the log probabilities on the `logprobs` most likely
+             tokens, as well as the chosen tokens. The API will always return the
+             log probability of the sampled token, so there may be up to
+             `logprobs+1` elements in the response.
+         prompt_logprobs: Number of log probabilities to return per prompt token.
+         detokenize: Whether to detokenize the output. Defaults to True.
+         skip_special_tokens: Whether to skip special tokens in the output.
+         spaces_between_special_tokens: Whether to add spaces between special
+             tokens in the output. Defaults to True.
+         logits_processors: List of functions that modify logits based on
+             previously generated tokens.
+         truncate_prompt_tokens: If set to an integer k, will use only the last k
+             tokens from the prompt (i.e., left truncation). Defaults to None
+             (i.e., no truncation).
+     """
+
+     def __init__(
+         self,
+         n: int = 1,
+         best_of: Optional[int] = None,
+         presence_penalty: float = 0.0,
+         frequency_penalty: float = 0.0,
+         repetition_penalty: float = 1.0,
+         temperature: float = 1.0,
+         top_p: float = 1.0,
+         top_k: int = -1,
+         min_p: float = 0.0,
+         seed: Optional[int] = None,
+         use_beam_search: bool = False,
+         length_penalty: float = 1.0,
+         early_stopping: Union[bool, str] = False,
+         stop: Optional[Union[str, List[str]]] = None,
+         stop_token_ids: Optional[List[int]] = None,
+         include_stop_str_in_output: bool = False,
+         ignore_eos: bool = False,
+         max_tokens: Optional[int] = 16,
+         min_tokens: int = 0,
+         logprobs: Optional[int] = None,
+         prompt_logprobs: Optional[int] = None,
+         detokenize: bool = True,
+         skip_special_tokens: bool = True,
+         spaces_between_special_tokens: bool = True,
+         logits_processors: Optional[List[LogitsProcessor]] = None,
+         truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
+     ) -> None:
+         self.n = n
+         self.best_of = best_of if best_of is not None else n
+         self.presence_penalty = presence_penalty
+         self.frequency_penalty = frequency_penalty
+         self.repetition_penalty = repetition_penalty
+         self.temperature = temperature
+         self.top_p = top_p
+         self.top_k = top_k
+         self.min_p = min_p
+         if seed == -1:
+             self.seed = None
+         else:
+             self.seed = seed
+         self.use_beam_search = use_beam_search
+         self.length_penalty = length_penalty
+         self.early_stopping = early_stopping
+         if stop is None:
+             self.stop = []
+         elif isinstance(stop, str):
+             self.stop = [stop]
+         else:
+             self.stop = list(stop)
+         if stop_token_ids is None:
+             self.stop_token_ids = []
+         else:
+             self.stop_token_ids = list(stop_token_ids)
+         self.ignore_eos = ignore_eos
+         self.max_tokens = max_tokens
+         self.min_tokens = min_tokens
+         self.logprobs = logprobs
+         self.prompt_logprobs = prompt_logprobs
+         # NOTE: This parameter is only exposed at the engine level for now.
+         # It is not exposed in the OpenAI API server, as the OpenAI API does
+         # not support returning only a list of token IDs.
+         self.detokenize = detokenize
+         self.skip_special_tokens = skip_special_tokens
+         self.spaces_between_special_tokens = spaces_between_special_tokens
+         self.logits_processors = logits_processors
+         self.include_stop_str_in_output = include_stop_str_in_output
+         self.truncate_prompt_tokens = truncate_prompt_tokens
+         # Number of characters to hold back for stop string evaluation
+         # until sequence is finished.
+         if self.stop and not include_stop_str_in_output:
+             self.output_text_buffer_length = max(len(s) for s in self.stop) - 1
+         else:
+             self.output_text_buffer_length = 0
+
+         self._verify_args()
+         if self.use_beam_search:
+             self._verify_beam_search()
+         else:
+             self._verify_non_beam_search()
+         if self.temperature < _SAMPLING_EPS:
+             # Zero temperature means greedy sampling.
+             self.top_p = 1.0
+             self.top_k = -1
+             self.min_p = 0.0
+             self._verify_greedy_sampling()
+         # eos_token_id is added to this by the engine
+         self.all_stop_token_ids = set(self.stop_token_ids)
+
+     def _verify_args(self) -> None:
+         if self.n < 1:
+             raise ValueError(f"n must be at least 1, got {self.n}.")
+         if self.best_of < self.n:
+             raise ValueError(f"best_of must be greater than or equal to n, "
+                              f"got n={self.n} and best_of={self.best_of}.")
+         if not -2.0 <= self.presence_penalty <= 2.0:
+             raise ValueError("presence_penalty must be in [-2, 2], got "
+                              f"{self.presence_penalty}.")
+         if not -2.0 <= self.frequency_penalty <= 2.0:
+             raise ValueError("frequency_penalty must be in [-2, 2], got "
+                              f"{self.frequency_penalty}.")
+         if not 0.0 < self.repetition_penalty <= 2.0:
+             raise ValueError("repetition_penalty must be in (0, 2], got "
+                              f"{self.repetition_penalty}.")
+         if self.temperature < 0.0:
+             raise ValueError(
+                 f"temperature must be non-negative, got {self.temperature}.")
+         if not 0.0 < self.top_p <= 1.0:
+             raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.")
+         if self.top_k < -1 or self.top_k == 0:
+             raise ValueError(f"top_k must be -1 (disable), or at least 1, "
+                              f"got {self.top_k}.")
+         if not 0.0 <= self.min_p <= 1.0:
+             raise ValueError("min_p must be in [0, 1], got "
+                              f"{self.min_p}.")
+         if self.max_tokens is not None and self.max_tokens < 1:
+             raise ValueError(
+                 f"max_tokens must be at least 1, got {self.max_tokens}.")
+         if self.min_tokens < 0:
+             raise ValueError(f"min_tokens must be greater than or equal to 0, "
+                              f"got {self.min_tokens}.")
+         if self.max_tokens is not None and self.min_tokens > self.max_tokens:
+             raise ValueError(
+                 f"min_tokens must be less than or equal to "
+                 f"max_tokens={self.max_tokens}, got {self.min_tokens}.")
+         if self.logprobs is not None and self.logprobs < 0:
+             raise ValueError(
+                 f"logprobs must be non-negative, got {self.logprobs}.")
+         if self.prompt_logprobs is not None and self.prompt_logprobs < 0:
+             raise ValueError(f"prompt_logprobs must be non-negative, got "
+                              f"{self.prompt_logprobs}.")
+         if (self.truncate_prompt_tokens is not None
+                 and self.truncate_prompt_tokens < 1):
+             raise ValueError(f"truncate_prompt_tokens must be >= 1, "
+                              f"got {self.truncate_prompt_tokens}")
+         if any(not stop_str for stop_str in self.stop):
+             raise ValueError("stop cannot contain an empty string.")
+         if self.stop and not self.detokenize:
+             raise ValueError(
+                 "stop strings are only supported when detokenize is True. "
+                 "Set detokenize=True to use stop.")
+
+     def _verify_beam_search(self) -> None:
+         if self.best_of == 1:
+             raise ValueError("best_of must be greater than 1 when using beam "
+                              f"search. Got {self.best_of}.")
+         if self.temperature > _SAMPLING_EPS:
+             raise ValueError("temperature must be 0 when using beam search.")
+         if self.top_p < 1.0 - _SAMPLING_EPS:
+             raise ValueError("top_p must be 1 when using beam search.")
+         if self.top_k != -1:
+             raise ValueError("top_k must be -1 when using beam search.")
+         if self.early_stopping not in [True, False, "never"]:
+             raise ValueError(
+                 f"early_stopping must be True, False, or 'never', "
+                 f"got {self.early_stopping}.")
+
+     def _verify_non_beam_search(self) -> None:
+         if self.early_stopping is not False:
+             raise ValueError("early_stopping is not effective and must be "
+                              "False when not using beam search.")
+         if (self.length_penalty < 1.0 - _SAMPLING_EPS
+                 or self.length_penalty > 1.0 + _SAMPLING_EPS):
+             raise ValueError(
+                 "length_penalty is not effective and must be the "
+                 "default value of 1.0 when not using beam search.")
+
+     def _verify_greedy_sampling(self) -> None:
+         if self.best_of > 1:
+             raise ValueError("best_of must be 1 when using greedy sampling. "
+                              f"Got {self.best_of}.")
+
+     def update_from_generation_config(
+             self, generation_config: Dict[str, Any]) -> None:
+         """Update if there are non-default values from generation_config."""
+         # Update eos_token_id for generation
+         if (not self.ignore_eos) and (eos_ids :=
+                                       generation_config.get("eos_token_id")):
+             # it can be either int or list of int
+             if isinstance(eos_ids, int):
+                 eos_ids = [eos_ids]
+             original_stop_token_ids = set(self.stop_token_ids)
+             original_stop_token_ids.update(eos_ids)
+             self.stop_token_ids = list(original_stop_token_ids)
+
+     @cached_property
+     def sampling_type(self) -> SamplingType:
+         if self.use_beam_search:
+             return SamplingType.BEAM
+         if self.temperature < _SAMPLING_EPS:
+             return SamplingType.GREEDY
+         if self.seed is not None:
+             return SamplingType.RANDOM_SEED
+         return SamplingType.RANDOM
+
+     def clone(self) -> "SamplingParams":
+         """Deep copy excluding LogitsProcessor objects.
+
+         LogitsProcessor objects are excluded because they may contain an
+         arbitrary, nontrivial amount of data.
+         See https://github.com/vllm-project/vllm/issues/3087
+         """
+
+         logit_processor_refs = None if self.logits_processors is None else {
+             id(lp): lp
+             for lp in self.logits_processors
+         }
+         return copy.deepcopy(self, memo=logit_processor_refs)
+
+     def __repr__(self) -> str:
+         return (
+             f"SamplingParams(n={self.n}, "
+             f"best_of={self.best_of}, "
+             f"presence_penalty={self.presence_penalty}, "
+             f"frequency_penalty={self.frequency_penalty}, "
+             f"repetition_penalty={self.repetition_penalty}, "
+             f"temperature={self.temperature}, "
+             f"top_p={self.top_p}, "
+             f"top_k={self.top_k}, "
+             f"min_p={self.min_p}, "
+             f"seed={self.seed}, "
+             f"use_beam_search={self.use_beam_search}, "
+             f"length_penalty={self.length_penalty}, "
+             f"early_stopping={self.early_stopping}, "
+             f"stop={self.stop}, "
+             f"stop_token_ids={self.stop_token_ids}, "
+             f"include_stop_str_in_output={self.include_stop_str_in_output}, "
+             f"ignore_eos={self.ignore_eos}, "
+             f"max_tokens={self.max_tokens}, "
+             f"min_tokens={self.min_tokens}, "
+             f"logprobs={self.logprobs}, "
+             f"prompt_logprobs={self.prompt_logprobs}, "
+             f"skip_special_tokens={self.skip_special_tokens}, "
+             "spaces_between_special_tokens="
+             f"{self.spaces_between_special_tokens}, "
+             f"truncate_prompt_tokens={self.truncate_prompt_tokens})")