vllm_npu-0.4.2-py3-none-any.whl

Files changed (219)
  1. vllm/__init__.py +23 -0
  2. vllm/_custom_ops.py +251 -0
  3. vllm/attention/__init__.py +13 -0
  4. vllm/attention/backends/__init__.py +0 -0
  5. vllm/attention/backends/abstract.py +127 -0
  6. vllm/attention/backends/flash_attn.py +271 -0
  7. vllm/attention/backends/flashinfer.py +220 -0
  8. vllm/attention/backends/rocm_flash_attn.py +374 -0
  9. vllm/attention/backends/torch_sdpa.py +250 -0
  10. vllm/attention/backends/xformers.py +393 -0
  11. vllm/attention/layer.py +56 -0
  12. vllm/attention/ops/__init__.py +0 -0
  13. vllm/attention/ops/paged_attn.py +216 -0
  14. vllm/attention/ops/prefix_prefill.py +792 -0
  15. vllm/attention/ops/triton_flash_attention.py +810 -0
  16. vllm/attention/selector.py +91 -0
  17. vllm/block.py +84 -0
  18. vllm/config.py +1225 -0
  19. vllm/core/__init__.py +0 -0
  20. vllm/core/block/__init__.py +0 -0
  21. vllm/core/block/block_table.py +295 -0
  22. vllm/core/block/common.py +199 -0
  23. vllm/core/block/cpu_gpu_block_allocator.py +228 -0
  24. vllm/core/block/interfaces.py +205 -0
  25. vllm/core/block/naive_block.py +318 -0
  26. vllm/core/block/prefix_caching_block.py +606 -0
  27. vllm/core/block_manager_v1.py +625 -0
  28. vllm/core/block_manager_v2.py +258 -0
  29. vllm/core/evictor_v1.py +105 -0
  30. vllm/core/evictor_v2.py +127 -0
  31. vllm/core/interfaces.py +113 -0
  32. vllm/core/policy.py +45 -0
  33. vllm/core/scheduler.py +1163 -0
  34. vllm/distributed/__init__.py +3 -0
  35. vllm/distributed/communication_op.py +237 -0
  36. vllm/distributed/device_communicators/__init__.py +0 -0
  37. vllm/distributed/device_communicators/custom_all_reduce.py +274 -0
  38. vllm/distributed/device_communicators/pynccl.py +287 -0
  39. vllm/distributed/device_communicators/pynccl_utils.py +66 -0
  40. vllm/distributed/parallel_state.py +339 -0
  41. vllm/distributed/utils.py +136 -0
  42. vllm/engine/__init__.py +0 -0
  43. vllm/engine/arg_utils.py +649 -0
  44. vllm/engine/async_llm_engine.py +737 -0
  45. vllm/engine/llm_engine.py +784 -0
  46. vllm/engine/metrics.py +368 -0
  47. vllm/engine/output_processor/__init__.py +0 -0
  48. vllm/engine/output_processor/interfaces.py +76 -0
  49. vllm/engine/output_processor/multi_step.py +142 -0
  50. vllm/engine/output_processor/single_step.py +284 -0
  51. vllm/engine/output_processor/stop_checker.py +101 -0
  52. vllm/engine/output_processor/util.py +19 -0
  53. vllm/entrypoints/__init__.py +0 -0
  54. vllm/entrypoints/api_server.py +119 -0
  55. vllm/entrypoints/llm.py +259 -0
  56. vllm/entrypoints/openai/__init__.py +0 -0
  57. vllm/entrypoints/openai/api_server.py +186 -0
  58. vllm/entrypoints/openai/cli_args.py +115 -0
  59. vllm/entrypoints/openai/protocol.py +460 -0
  60. vllm/entrypoints/openai/serving_chat.py +392 -0
  61. vllm/entrypoints/openai/serving_completion.py +347 -0
  62. vllm/entrypoints/openai/serving_engine.py +234 -0
  63. vllm/envs.py +217 -0
  64. vllm/executor/__init__.py +0 -0
  65. vllm/executor/cpu_executor.py +152 -0
  66. vllm/executor/distributed_gpu_executor.py +115 -0
  67. vllm/executor/executor_base.py +115 -0
  68. vllm/executor/gpu_executor.py +150 -0
  69. vllm/executor/multiproc_worker_utils.py +263 -0
  70. vllm/executor/neuron_executor.py +91 -0
  71. vllm/executor/ray_gpu_executor.py +327 -0
  72. vllm/executor/ray_utils.py +119 -0
  73. vllm/logger.py +153 -0
  74. vllm/logging/__init__.py +5 -0
  75. vllm/logging/formatter.py +15 -0
  76. vllm/lora/__init__.py +0 -0
  77. vllm/lora/fully_sharded_layers.py +262 -0
  78. vllm/lora/layers.py +1181 -0
  79. vllm/lora/lora.py +167 -0
  80. vllm/lora/models.py +645 -0
  81. vllm/lora/punica.py +213 -0
  82. vllm/lora/request.py +32 -0
  83. vllm/lora/utils.py +98 -0
  84. vllm/lora/worker_manager.py +251 -0
  85. vllm/model_executor/__init__.py +7 -0
  86. vllm/model_executor/guided_decoding/__init__.py +25 -0
  87. vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py +70 -0
  88. vllm/model_executor/guided_decoding/outlines_decoding.py +130 -0
  89. vllm/model_executor/guided_decoding/outlines_logits_processors.py +184 -0
  90. vllm/model_executor/layers/__init__.py +0 -0
  91. vllm/model_executor/layers/activation.py +173 -0
  92. vllm/model_executor/layers/fused_moe/__init__.py +7 -0
  93. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  94. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  95. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  96. vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  97. vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  98. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  99. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  100. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  101. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  102. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  103. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  104. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  105. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +140 -0
  106. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  107. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  108. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  109. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  110. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +146 -0
  111. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  112. vllm/model_executor/layers/fused_moe/fused_moe.py +479 -0
  113. vllm/model_executor/layers/layernorm.py +71 -0
  114. vllm/model_executor/layers/linear.py +709 -0
  115. vllm/model_executor/layers/logits_processor.py +115 -0
  116. vllm/model_executor/layers/ops/__init__.py +0 -0
  117. vllm/model_executor/layers/ops/rand.py +157 -0
  118. vllm/model_executor/layers/ops/sample.py +406 -0
  119. vllm/model_executor/layers/quantization/__init__.py +35 -0
  120. vllm/model_executor/layers/quantization/aqlm.py +376 -0
  121. vllm/model_executor/layers/quantization/awq.py +175 -0
  122. vllm/model_executor/layers/quantization/base_config.py +97 -0
  123. vllm/model_executor/layers/quantization/fp8.py +265 -0
  124. vllm/model_executor/layers/quantization/gptq.py +224 -0
  125. vllm/model_executor/layers/quantization/gptq_marlin.py +438 -0
  126. vllm/model_executor/layers/quantization/marlin.py +227 -0
  127. vllm/model_executor/layers/quantization/schema.py +84 -0
  128. vllm/model_executor/layers/quantization/squeezellm.py +137 -0
  129. vllm/model_executor/layers/rejection_sampler.py +405 -0
  130. vllm/model_executor/layers/rotary_embedding.py +525 -0
  131. vllm/model_executor/layers/sampler.py +1051 -0
  132. vllm/model_executor/layers/vocab_parallel_embedding.py +155 -0
  133. vllm/model_executor/model_loader/__init__.py +30 -0
  134. vllm/model_executor/model_loader/loader.py +362 -0
  135. vllm/model_executor/model_loader/neuron.py +136 -0
  136. vllm/model_executor/model_loader/tensorizer.py +368 -0
  137. vllm/model_executor/model_loader/utils.py +41 -0
  138. vllm/model_executor/model_loader/weight_utils.py +372 -0
  139. vllm/model_executor/models/__init__.py +119 -0
  140. vllm/model_executor/models/baichuan.py +410 -0
  141. vllm/model_executor/models/bloom.py +327 -0
  142. vllm/model_executor/models/chatglm.py +386 -0
  143. vllm/model_executor/models/commandr.py +373 -0
  144. vllm/model_executor/models/dbrx.py +413 -0
  145. vllm/model_executor/models/decilm.py +122 -0
  146. vllm/model_executor/models/deepseek.py +438 -0
  147. vllm/model_executor/models/falcon.py +444 -0
  148. vllm/model_executor/models/gemma.py +393 -0
  149. vllm/model_executor/models/gpt2.py +266 -0
  150. vllm/model_executor/models/gpt_bigcode.py +274 -0
  151. vllm/model_executor/models/gpt_j.py +281 -0
  152. vllm/model_executor/models/gpt_neox.py +295 -0
  153. vllm/model_executor/models/internlm2.py +323 -0
  154. vllm/model_executor/models/jais.py +333 -0
  155. vllm/model_executor/models/llama.py +442 -0
  156. vllm/model_executor/models/llava.py +239 -0
  157. vllm/model_executor/models/minicpm.py +531 -0
  158. vllm/model_executor/models/mixtral.py +583 -0
  159. vllm/model_executor/models/mixtral_quant.py +404 -0
  160. vllm/model_executor/models/mpt.py +295 -0
  161. vllm/model_executor/models/olmo.py +356 -0
  162. vllm/model_executor/models/opt.py +349 -0
  163. vllm/model_executor/models/orion.py +319 -0
  164. vllm/model_executor/models/phi.py +300 -0
  165. vllm/model_executor/models/qwen.py +284 -0
  166. vllm/model_executor/models/qwen2.py +367 -0
  167. vllm/model_executor/models/qwen2_moe.py +447 -0
  168. vllm/model_executor/models/stablelm.py +301 -0
  169. vllm/model_executor/models/starcoder2.py +302 -0
  170. vllm/model_executor/models/xverse.py +366 -0
  171. vllm/model_executor/sampling_metadata.py +588 -0
  172. vllm/model_executor/utils.py +35 -0
  173. vllm/outputs.py +150 -0
  174. vllm/py.typed +2 -0
  175. vllm/sampling_params.py +340 -0
  176. vllm/sequence.py +766 -0
  177. vllm/spec_decode/__init__.py +0 -0
  178. vllm/spec_decode/batch_expansion.py +397 -0
  179. vllm/spec_decode/interfaces.py +73 -0
  180. vllm/spec_decode/metrics.py +191 -0
  181. vllm/spec_decode/multi_step_worker.py +203 -0
  182. vllm/spec_decode/ngram_worker.py +176 -0
  183. vllm/spec_decode/spec_decode_worker.py +472 -0
  184. vllm/spec_decode/top1_proposer.py +200 -0
  185. vllm/spec_decode/util.py +228 -0
  186. vllm/test_utils.py +41 -0
  187. vllm/transformers_utils/__init__.py +0 -0
  188. vllm/transformers_utils/config.py +58 -0
  189. vllm/transformers_utils/configs/__init__.py +16 -0
  190. vllm/transformers_utils/configs/chatglm.py +68 -0
  191. vllm/transformers_utils/configs/dbrx.py +278 -0
  192. vllm/transformers_utils/configs/falcon.py +87 -0
  193. vllm/transformers_utils/configs/jais.py +236 -0
  194. vllm/transformers_utils/configs/mpt.py +178 -0
  195. vllm/transformers_utils/detokenizer.py +313 -0
  196. vllm/transformers_utils/tokenizer.py +149 -0
  197. vllm/transformers_utils/tokenizer_group/__init__.py +33 -0
  198. vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py +55 -0
  199. vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py +169 -0
  200. vllm/transformers_utils/tokenizer_group/tokenizer_group.py +78 -0
  201. vllm/transformers_utils/tokenizers/__init__.py +5 -0
  202. vllm/transformers_utils/tokenizers/baichuan.py +255 -0
  203. vllm/usage/__init__.py +0 -0
  204. vllm/usage/usage_lib.py +209 -0
  205. vllm/utils.py +677 -0
  206. vllm/worker/__init__.py +0 -0
  207. vllm/worker/cache_engine.py +105 -0
  208. vllm/worker/cpu_model_runner.py +346 -0
  209. vllm/worker/cpu_worker.py +321 -0
  210. vllm/worker/model_runner.py +1168 -0
  211. vllm/worker/neuron_model_runner.py +196 -0
  212. vllm/worker/neuron_worker.py +98 -0
  213. vllm/worker/worker.py +345 -0
  214. vllm/worker/worker_base.py +146 -0
  215. vllm_npu-0.4.2.dist-info/LICENSE +201 -0
  216. vllm_npu-0.4.2.dist-info/METADATA +173 -0
  217. vllm_npu-0.4.2.dist-info/RECORD +219 -0
  218. vllm_npu-0.4.2.dist-info/WHEEL +5 -0
  219. vllm_npu-0.4.2.dist-info/top_level.txt +1 -0
vllm/entrypoints/openai/protocol.py
@@ -0,0 +1,460 @@
+# Adapted from
+# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
+import time
+from typing import Dict, List, Literal, Optional, Union
+
+import torch
+from openai.types.chat import ChatCompletionMessageParam
+from pydantic import BaseModel, ConfigDict, Field, model_validator
+from typing_extensions import Annotated
+
+from vllm.sampling_params import SamplingParams
+from vllm.utils import random_uuid
+
+
+class OpenAIBaseModel(BaseModel):
+    # OpenAI API does not allow extra fields
+    model_config = ConfigDict(extra="forbid")
+
+
+class ErrorResponse(OpenAIBaseModel):
+    object: str = "error"
+    message: str
+    type: str
+    param: Optional[str] = None
+    code: int
+
+
+class ModelPermission(OpenAIBaseModel):
+    id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
+    object: str = "model_permission"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    allow_create_engine: bool = False
+    allow_sampling: bool = True
+    allow_logprobs: bool = True
+    allow_search_indices: bool = False
+    allow_view: bool = True
+    allow_fine_tuning: bool = False
+    organization: str = "*"
+    group: Optional[str] = None
+    is_blocking: bool = False
+
+
+class ModelCard(OpenAIBaseModel):
+    id: str
+    object: str = "model"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    owned_by: str = "vllm"
+    root: Optional[str] = None
+    parent: Optional[str] = None
+    permission: List[ModelPermission] = Field(default_factory=list)
+
+
+class ModelList(OpenAIBaseModel):
+    object: str = "list"
+    data: List[ModelCard] = Field(default_factory=list)
+
+
+class UsageInfo(OpenAIBaseModel):
+    prompt_tokens: int = 0
+    total_tokens: int = 0
+    completion_tokens: Optional[int] = 0
+
+
+class ResponseFormat(OpenAIBaseModel):
+    # type must be "json_object" or "text"
+    type: Literal["text", "json_object"]
+
+
+class ChatCompletionRequest(OpenAIBaseModel):
+    # Ordered by official OpenAI API documentation
+    # https://platform.openai.com/docs/api-reference/chat/create
+    messages: List[ChatCompletionMessageParam]
+    model: str
+    frequency_penalty: Optional[float] = 0.0
+    logit_bias: Optional[Dict[str, float]] = None
+    logprobs: Optional[bool] = False
+    top_logprobs: Optional[int] = None
+    max_tokens: Optional[int] = None
+    n: Optional[int] = 1
+    presence_penalty: Optional[float] = 0.0
+    response_format: Optional[ResponseFormat] = None
+    seed: Optional[int] = Field(None,
+                                ge=torch.iinfo(torch.long).min,
+                                le=torch.iinfo(torch.long).max)
+    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
+    stream: Optional[bool] = False
+    temperature: Optional[float] = 0.7
+    top_p: Optional[float] = 1.0
+    user: Optional[str] = None
+
+    # doc: begin-chat-completion-sampling-params
+    best_of: Optional[int] = None
+    use_beam_search: Optional[bool] = False
+    top_k: Optional[int] = -1
+    min_p: Optional[float] = 0.0
+    repetition_penalty: Optional[float] = 1.0
+    length_penalty: Optional[float] = 1.0
+    early_stopping: Optional[bool] = False
+    ignore_eos: Optional[bool] = False
+    min_tokens: Optional[int] = 0
+    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
+    skip_special_tokens: Optional[bool] = True
+    spaces_between_special_tokens: Optional[bool] = True
+    # doc: end-chat-completion-sampling-params
+
+    # doc: begin-chat-completion-extra-params
+    echo: Optional[bool] = Field(
+        default=False,
+        description=(
+            "If true, the new message will be prepended with the last message "
+            "if they belong to the same role."),
+    )
+    add_generation_prompt: Optional[bool] = Field(
+        default=True,
+        description=
+        ("If true, the generation prompt will be added to the chat template. "
+         "This is a parameter used by chat template in tokenizer config of the "
+         "model."),
+    )
+    include_stop_str_in_output: Optional[bool] = Field(
+        default=False,
+        description=(
+            "Whether to include the stop string in the output. "
+            "This is only applied when the stop or stop_token_ids is set."),
+    )
+    guided_json: Optional[Union[str, dict, BaseModel]] = Field(
+        default=None,
+        description=("If specified, the output will follow the JSON schema."),
+    )
+    guided_regex: Optional[str] = Field(
+        default=None,
+        description=(
+            "If specified, the output will follow the regex pattern."),
+    )
+    guided_choice: Optional[List[str]] = Field(
+        default=None,
+        description=(
+            "If specified, the output will be exactly one of the choices."),
+    )
+    guided_grammar: Optional[str] = Field(
+        default=None,
+        description=(
+            "If specified, the output will follow the context free grammar."),
+    )
+    guided_decoding_backend: Optional[str] = Field(
+        default=None,
+        description=(
+            "If specified, will override the default guided decoding backend "
+            "of the server for this specific request. If set, must be either "
+            "'outlines' / 'lm-format-enforcer'"))
+    guided_whitespace_pattern: Optional[str] = Field(
+        default=None,
+        description=(
+            "If specified, will override the default whitespace pattern "
+            "for guided json decoding."))
+
+    # doc: end-chat-completion-extra-params
+
+    def to_sampling_params(self) -> SamplingParams:
+        if self.logprobs and not self.top_logprobs:
+            raise ValueError("Top logprobs must be set when logprobs is.")
+
+        logits_processors = None
+        if self.logit_bias:
+
+            def logit_bias_logits_processor(
+                    token_ids: List[int],
+                    logits: torch.Tensor) -> torch.Tensor:
+                assert self.logit_bias is not None
+                for token_id, bias in self.logit_bias.items():
+                    # Clamp the bias between -100 and 100 per OpenAI API spec
+                    bias = min(100, max(-100, bias))
+                    logits[int(token_id)] += bias
+                return logits
+
+            logits_processors = [logit_bias_logits_processor]
+
+        return SamplingParams(
+            n=self.n,
+            presence_penalty=self.presence_penalty,
+            frequency_penalty=self.frequency_penalty,
+            repetition_penalty=self.repetition_penalty,
+            temperature=self.temperature,
+            top_p=self.top_p,
+            min_p=self.min_p,
+            seed=self.seed,
+            stop=self.stop,
+            stop_token_ids=self.stop_token_ids,
+            max_tokens=self.max_tokens,
+            min_tokens=self.min_tokens,
+            logprobs=self.top_logprobs if self.logprobs else None,
+            prompt_logprobs=self.top_logprobs if self.echo else None,
+            best_of=self.best_of,
+            top_k=self.top_k,
+            ignore_eos=self.ignore_eos,
+            use_beam_search=self.use_beam_search,
+            early_stopping=self.early_stopping,
+            skip_special_tokens=self.skip_special_tokens,
+            spaces_between_special_tokens=self.spaces_between_special_tokens,
+            include_stop_str_in_output=self.include_stop_str_in_output,
+            length_penalty=self.length_penalty,
+            logits_processors=logits_processors,
+        )
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_guided_decoding_count(cls, data):
+        guide_count = sum([
+            "guided_json" in data and data["guided_json"] is not None,
+            "guided_regex" in data and data["guided_regex"] is not None,
+            "guided_choice" in data and data["guided_choice"] is not None
+        ])
+        if guide_count > 1:
+            raise ValueError(
+                "You can only use one kind of guided decoding "
+                "('guided_json', 'guided_regex' or 'guided_choice').")
+        return data
+
+
+class CompletionRequest(OpenAIBaseModel):
+    # Ordered by official OpenAI API documentation
+    # https://platform.openai.com/docs/api-reference/completions/create
+    model: str
+    prompt: Union[List[int], List[List[int]], str, List[str]]
+    best_of: Optional[int] = None
+    echo: Optional[bool] = False
+    frequency_penalty: Optional[float] = 0.0
+    logit_bias: Optional[Dict[str, float]] = None
+    logprobs: Optional[int] = None
+    max_tokens: Optional[int] = 16
+    n: int = 1
+    presence_penalty: Optional[float] = 0.0
+    seed: Optional[int] = Field(None,
+                                ge=torch.iinfo(torch.long).min,
+                                le=torch.iinfo(torch.long).max)
+    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
+    stream: Optional[bool] = False
+    suffix: Optional[str] = None
+    temperature: Optional[float] = 1.0
+    top_p: Optional[float] = 1.0
+    user: Optional[str] = None
+
+    # doc: begin-completion-sampling-params
+    use_beam_search: Optional[bool] = False
+    top_k: Optional[int] = -1
+    min_p: Optional[float] = 0.0
+    repetition_penalty: Optional[float] = 1.0
+    length_penalty: Optional[float] = 1.0
+    early_stopping: Optional[bool] = False
+    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
+    ignore_eos: Optional[bool] = False
+    min_tokens: Optional[int] = 0
+    skip_special_tokens: Optional[bool] = True
+    spaces_between_special_tokens: Optional[bool] = True
+    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
+    # doc: end-completion-sampling-params
+
+    # doc: begin-completion-extra-params
+    include_stop_str_in_output: Optional[bool] = Field(
+        default=False,
+        description=(
+            "Whether to include the stop string in the output. "
+            "This is only applied when the stop or stop_token_ids is set."),
+    )
+    response_format: Optional[ResponseFormat] = Field(
+        default=None,
+        description=
+        ("Similar to chat completion, this parameter specifies the format of "
+         "output. Only {'type': 'json_object'} or {'type': 'text' } is "
+         "supported."),
+    )
+    guided_json: Optional[Union[str, dict, BaseModel]] = Field(
+        default=None,
+        description=("If specified, the output will follow the JSON schema."),
+    )
+    guided_regex: Optional[str] = Field(
+        default=None,
+        description=(
+            "If specified, the output will follow the regex pattern."),
+    )
+    guided_choice: Optional[List[str]] = Field(
+        default=None,
+        description=(
+            "If specified, the output will be exactly one of the choices."),
+    )
+    guided_grammar: Optional[str] = Field(
+        default=None,
+        description=(
+            "If specified, the output will follow the context free grammar."),
+    )
+    guided_decoding_backend: Optional[str] = Field(
+        default=None,
+        description=(
+            "If specified, will override the default guided decoding backend "
+            "of the server for this specific request. If set, must be one of "
+            "'outlines' / 'lm-format-enforcer'"))
+    guided_whitespace_pattern: Optional[str] = Field(
+        default=None,
+        description=(
+            "If specified, will override the default whitespace pattern "
+            "for guided json decoding."))
+
+    # doc: end-completion-extra-params
+
+    def to_sampling_params(self):
+        echo_without_generation = self.echo and self.max_tokens == 0
+
+        logits_processors = None
+        if self.logit_bias:
+
+            def logit_bias_logits_processor(
+                    token_ids: List[int],
+                    logits: torch.Tensor) -> torch.Tensor:
+                assert self.logit_bias is not None
+                for token_id, bias in self.logit_bias.items():
+                    # Clamp the bias between -100 and 100 per OpenAI API spec
+                    bias = min(100, max(-100, bias))
+                    logits[int(token_id)] += bias
+                return logits
+
+            logits_processors = [logit_bias_logits_processor]
+
+        return SamplingParams(
+            n=self.n,
+            best_of=self.best_of,
+            presence_penalty=self.presence_penalty,
+            frequency_penalty=self.frequency_penalty,
+            repetition_penalty=self.repetition_penalty,
+            temperature=self.temperature,
+            top_p=self.top_p,
+            top_k=self.top_k,
+            min_p=self.min_p,
+            seed=self.seed,
+            stop=self.stop,
+            stop_token_ids=self.stop_token_ids,
+            ignore_eos=self.ignore_eos,
+            max_tokens=self.max_tokens if not echo_without_generation else 1,
+            min_tokens=self.min_tokens,
+            logprobs=self.logprobs,
+            use_beam_search=self.use_beam_search,
+            early_stopping=self.early_stopping,
+            prompt_logprobs=self.logprobs if self.echo else None,
+            skip_special_tokens=self.skip_special_tokens,
+            spaces_between_special_tokens=(self.spaces_between_special_tokens),
+            include_stop_str_in_output=self.include_stop_str_in_output,
+            length_penalty=self.length_penalty,
+            logits_processors=logits_processors,
+            truncate_prompt_tokens=self.truncate_prompt_tokens,
+        )
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_guided_decoding_count(cls, data):
+        guide_count = sum([
+            "guided_json" in data and data["guided_json"] is not None,
+            "guided_regex" in data and data["guided_regex"] is not None,
+            "guided_choice" in data and data["guided_choice"] is not None
+        ])
+        if guide_count > 1:
+            raise ValueError(
+                "You can only use one kind of guided decoding "
+                "('guided_json', 'guided_regex' or 'guided_choice').")
+        return data
+
+
+class LogProbs(OpenAIBaseModel):
+    text_offset: List[int] = Field(default_factory=list)
+    token_logprobs: List[Optional[float]] = Field(default_factory=list)
+    tokens: List[str] = Field(default_factory=list)
+    top_logprobs: Optional[List[Optional[Dict[str, float]]]] = None
+
+
+class CompletionResponseChoice(OpenAIBaseModel):
+    index: int
+    text: str
+    logprobs: Optional[LogProbs] = None
+    finish_reason: Optional[str] = None
+    stop_reason: Optional[Union[int, str]] = Field(
+        default=None,
+        description=(
+            "The stop string or token id that caused the completion "
+            "to stop, None if the completion finished for some other reason "
+            "including encountering the EOS token"),
+    )
+
+
+class CompletionResponse(OpenAIBaseModel):
+    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
+    object: str = "text_completion"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: List[CompletionResponseChoice]
+    usage: UsageInfo
+
+
+class CompletionResponseStreamChoice(OpenAIBaseModel):
+    index: int
+    text: str
+    logprobs: Optional[LogProbs] = None
+    finish_reason: Optional[str] = None
+    stop_reason: Optional[Union[int, str]] = Field(
+        default=None,
+        description=(
+            "The stop string or token id that caused the completion "
+            "to stop, None if the completion finished for some other reason "
+            "including encountering the EOS token"),
+    )
+
+
+class CompletionStreamResponse(OpenAIBaseModel):
+    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
+    object: str = "text_completion"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: List[CompletionResponseStreamChoice]
+    usage: Optional[UsageInfo] = Field(default=None)
+
+
+class ChatMessage(OpenAIBaseModel):
+    role: str
+    content: str
+
+
+class ChatCompletionResponseChoice(OpenAIBaseModel):
+    index: int
+    message: ChatMessage
+    logprobs: Optional[LogProbs] = None
+    finish_reason: Optional[str] = None
+    stop_reason: Optional[Union[int, str]] = None
+
+
+class ChatCompletionResponse(OpenAIBaseModel):
+    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
+    object: str = "chat.completion"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: List[ChatCompletionResponseChoice]
+    usage: UsageInfo
+
+
+class DeltaMessage(OpenAIBaseModel):
+    role: Optional[str] = None
+    content: Optional[str] = None
+
+
+class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
+    index: int
+    delta: DeltaMessage
+    logprobs: Optional[LogProbs] = None
+    finish_reason: Optional[str] = None
+    stop_reason: Optional[Union[int, str]] = None
+
+
+class ChatCompletionStreamResponse(OpenAIBaseModel):
+    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
+    object: str = "chat.completion.chunk"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: List[ChatCompletionResponseStreamChoice]
+    usage: Optional[UsageInfo] = Field(default=None)
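
As a quick illustration of how the request models in this file are meant to be used, the sketch below validates a chat payload and converts it to SamplingParams via the to_sampling_params() method defined above. It is a minimal sketch, not part of the package: it only uses names that appear in this diff, the model name and token id in the payload are placeholders, and it assumes the wheel and its torch, openai, and pydantic dependencies are installed.

    # Minimal sketch: validate a request payload and derive SamplingParams from it.
    from vllm.entrypoints.openai.protocol import ChatCompletionRequest

    payload = {
        "model": "my-model",  # placeholder model name
        "messages": [{"role": "user", "content": "Hello!"}],
        "temperature": 0.2,
        "max_tokens": 64,
        "logit_bias": {"1234": -100},  # placeholder token id
    }

    # Validation: OpenAIBaseModel forbids unknown fields, and the
    # check_guided_decoding_count validator rejects payloads that set more
    # than one of guided_json / guided_regex / guided_choice.
    request = ChatCompletionRequest(**payload)

    # The returned SamplingParams carries the sampling knobs plus a logits
    # processor that applies the clamped logit_bias values.
    sampling_params = request.to_sampling_params()
    print(sampling_params)

The same pattern applies to CompletionRequest, whose to_sampling_params() additionally handles echo-only requests and truncate_prompt_tokens.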