vllm_npu-0.4.2-py3-none-any.whl

Files changed (219)
  1. vllm/__init__.py +23 -0
  2. vllm/_custom_ops.py +251 -0
  3. vllm/attention/__init__.py +13 -0
  4. vllm/attention/backends/__init__.py +0 -0
  5. vllm/attention/backends/abstract.py +127 -0
  6. vllm/attention/backends/flash_attn.py +271 -0
  7. vllm/attention/backends/flashinfer.py +220 -0
  8. vllm/attention/backends/rocm_flash_attn.py +374 -0
  9. vllm/attention/backends/torch_sdpa.py +250 -0
  10. vllm/attention/backends/xformers.py +393 -0
  11. vllm/attention/layer.py +56 -0
  12. vllm/attention/ops/__init__.py +0 -0
  13. vllm/attention/ops/paged_attn.py +216 -0
  14. vllm/attention/ops/prefix_prefill.py +792 -0
  15. vllm/attention/ops/triton_flash_attention.py +810 -0
  16. vllm/attention/selector.py +91 -0
  17. vllm/block.py +84 -0
  18. vllm/config.py +1225 -0
  19. vllm/core/__init__.py +0 -0
  20. vllm/core/block/__init__.py +0 -0
  21. vllm/core/block/block_table.py +295 -0
  22. vllm/core/block/common.py +199 -0
  23. vllm/core/block/cpu_gpu_block_allocator.py +228 -0
  24. vllm/core/block/interfaces.py +205 -0
  25. vllm/core/block/naive_block.py +318 -0
  26. vllm/core/block/prefix_caching_block.py +606 -0
  27. vllm/core/block_manager_v1.py +625 -0
  28. vllm/core/block_manager_v2.py +258 -0
  29. vllm/core/evictor_v1.py +105 -0
  30. vllm/core/evictor_v2.py +127 -0
  31. vllm/core/interfaces.py +113 -0
  32. vllm/core/policy.py +45 -0
  33. vllm/core/scheduler.py +1163 -0
  34. vllm/distributed/__init__.py +3 -0
  35. vllm/distributed/communication_op.py +237 -0
  36. vllm/distributed/device_communicators/__init__.py +0 -0
  37. vllm/distributed/device_communicators/custom_all_reduce.py +274 -0
  38. vllm/distributed/device_communicators/pynccl.py +287 -0
  39. vllm/distributed/device_communicators/pynccl_utils.py +66 -0
  40. vllm/distributed/parallel_state.py +339 -0
  41. vllm/distributed/utils.py +136 -0
  42. vllm/engine/__init__.py +0 -0
  43. vllm/engine/arg_utils.py +649 -0
  44. vllm/engine/async_llm_engine.py +737 -0
  45. vllm/engine/llm_engine.py +784 -0
  46. vllm/engine/metrics.py +368 -0
  47. vllm/engine/output_processor/__init__.py +0 -0
  48. vllm/engine/output_processor/interfaces.py +76 -0
  49. vllm/engine/output_processor/multi_step.py +142 -0
  50. vllm/engine/output_processor/single_step.py +284 -0
  51. vllm/engine/output_processor/stop_checker.py +101 -0
  52. vllm/engine/output_processor/util.py +19 -0
  53. vllm/entrypoints/__init__.py +0 -0
  54. vllm/entrypoints/api_server.py +119 -0
  55. vllm/entrypoints/llm.py +259 -0
  56. vllm/entrypoints/openai/__init__.py +0 -0
  57. vllm/entrypoints/openai/api_server.py +186 -0
  58. vllm/entrypoints/openai/cli_args.py +115 -0
  59. vllm/entrypoints/openai/protocol.py +460 -0
  60. vllm/entrypoints/openai/serving_chat.py +392 -0
  61. vllm/entrypoints/openai/serving_completion.py +347 -0
  62. vllm/entrypoints/openai/serving_engine.py +234 -0
  63. vllm/envs.py +217 -0
  64. vllm/executor/__init__.py +0 -0
  65. vllm/executor/cpu_executor.py +152 -0
  66. vllm/executor/distributed_gpu_executor.py +115 -0
  67. vllm/executor/executor_base.py +115 -0
  68. vllm/executor/gpu_executor.py +150 -0
  69. vllm/executor/multiproc_worker_utils.py +263 -0
  70. vllm/executor/neuron_executor.py +91 -0
  71. vllm/executor/ray_gpu_executor.py +327 -0
  72. vllm/executor/ray_utils.py +119 -0
  73. vllm/logger.py +153 -0
  74. vllm/logging/__init__.py +5 -0
  75. vllm/logging/formatter.py +15 -0
  76. vllm/lora/__init__.py +0 -0
  77. vllm/lora/fully_sharded_layers.py +262 -0
  78. vllm/lora/layers.py +1181 -0
  79. vllm/lora/lora.py +167 -0
  80. vllm/lora/models.py +645 -0
  81. vllm/lora/punica.py +213 -0
  82. vllm/lora/request.py +32 -0
  83. vllm/lora/utils.py +98 -0
  84. vllm/lora/worker_manager.py +251 -0
  85. vllm/model_executor/__init__.py +7 -0
  86. vllm/model_executor/guided_decoding/__init__.py +25 -0
  87. vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py +70 -0
  88. vllm/model_executor/guided_decoding/outlines_decoding.py +130 -0
  89. vllm/model_executor/guided_decoding/outlines_logits_processors.py +184 -0
  90. vllm/model_executor/layers/__init__.py +0 -0
  91. vllm/model_executor/layers/activation.py +173 -0
  92. vllm/model_executor/layers/fused_moe/__init__.py +7 -0
  93. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  94. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  95. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  96. vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  97. vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  98. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  99. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  100. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  101. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  102. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  103. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  104. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  105. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +140 -0
  106. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  107. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  108. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  109. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  110. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +146 -0
  111. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  112. vllm/model_executor/layers/fused_moe/fused_moe.py +479 -0
  113. vllm/model_executor/layers/layernorm.py +71 -0
  114. vllm/model_executor/layers/linear.py +709 -0
  115. vllm/model_executor/layers/logits_processor.py +115 -0
  116. vllm/model_executor/layers/ops/__init__.py +0 -0
  117. vllm/model_executor/layers/ops/rand.py +157 -0
  118. vllm/model_executor/layers/ops/sample.py +406 -0
  119. vllm/model_executor/layers/quantization/__init__.py +35 -0
  120. vllm/model_executor/layers/quantization/aqlm.py +376 -0
  121. vllm/model_executor/layers/quantization/awq.py +175 -0
  122. vllm/model_executor/layers/quantization/base_config.py +97 -0
  123. vllm/model_executor/layers/quantization/fp8.py +265 -0
  124. vllm/model_executor/layers/quantization/gptq.py +224 -0
  125. vllm/model_executor/layers/quantization/gptq_marlin.py +438 -0
  126. vllm/model_executor/layers/quantization/marlin.py +227 -0
  127. vllm/model_executor/layers/quantization/schema.py +84 -0
  128. vllm/model_executor/layers/quantization/squeezellm.py +137 -0
  129. vllm/model_executor/layers/rejection_sampler.py +405 -0
  130. vllm/model_executor/layers/rotary_embedding.py +525 -0
  131. vllm/model_executor/layers/sampler.py +1051 -0
  132. vllm/model_executor/layers/vocab_parallel_embedding.py +155 -0
  133. vllm/model_executor/model_loader/__init__.py +30 -0
  134. vllm/model_executor/model_loader/loader.py +362 -0
  135. vllm/model_executor/model_loader/neuron.py +136 -0
  136. vllm/model_executor/model_loader/tensorizer.py +368 -0
  137. vllm/model_executor/model_loader/utils.py +41 -0
  138. vllm/model_executor/model_loader/weight_utils.py +372 -0
  139. vllm/model_executor/models/__init__.py +119 -0
  140. vllm/model_executor/models/baichuan.py +410 -0
  141. vllm/model_executor/models/bloom.py +327 -0
  142. vllm/model_executor/models/chatglm.py +386 -0
  143. vllm/model_executor/models/commandr.py +373 -0
  144. vllm/model_executor/models/dbrx.py +413 -0
  145. vllm/model_executor/models/decilm.py +122 -0
  146. vllm/model_executor/models/deepseek.py +438 -0
  147. vllm/model_executor/models/falcon.py +444 -0
  148. vllm/model_executor/models/gemma.py +393 -0
  149. vllm/model_executor/models/gpt2.py +266 -0
  150. vllm/model_executor/models/gpt_bigcode.py +274 -0
  151. vllm/model_executor/models/gpt_j.py +281 -0
  152. vllm/model_executor/models/gpt_neox.py +295 -0
  153. vllm/model_executor/models/internlm2.py +323 -0
  154. vllm/model_executor/models/jais.py +333 -0
  155. vllm/model_executor/models/llama.py +442 -0
  156. vllm/model_executor/models/llava.py +239 -0
  157. vllm/model_executor/models/minicpm.py +531 -0
  158. vllm/model_executor/models/mixtral.py +583 -0
  159. vllm/model_executor/models/mixtral_quant.py +404 -0
  160. vllm/model_executor/models/mpt.py +295 -0
  161. vllm/model_executor/models/olmo.py +356 -0
  162. vllm/model_executor/models/opt.py +349 -0
  163. vllm/model_executor/models/orion.py +319 -0
  164. vllm/model_executor/models/phi.py +300 -0
  165. vllm/model_executor/models/qwen.py +284 -0
  166. vllm/model_executor/models/qwen2.py +367 -0
  167. vllm/model_executor/models/qwen2_moe.py +447 -0
  168. vllm/model_executor/models/stablelm.py +301 -0
  169. vllm/model_executor/models/starcoder2.py +302 -0
  170. vllm/model_executor/models/xverse.py +366 -0
  171. vllm/model_executor/sampling_metadata.py +588 -0
  172. vllm/model_executor/utils.py +35 -0
  173. vllm/outputs.py +150 -0
  174. vllm/py.typed +2 -0
  175. vllm/sampling_params.py +340 -0
  176. vllm/sequence.py +766 -0
  177. vllm/spec_decode/__init__.py +0 -0
  178. vllm/spec_decode/batch_expansion.py +397 -0
  179. vllm/spec_decode/interfaces.py +73 -0
  180. vllm/spec_decode/metrics.py +191 -0
  181. vllm/spec_decode/multi_step_worker.py +203 -0
  182. vllm/spec_decode/ngram_worker.py +176 -0
  183. vllm/spec_decode/spec_decode_worker.py +472 -0
  184. vllm/spec_decode/top1_proposer.py +200 -0
  185. vllm/spec_decode/util.py +228 -0
  186. vllm/test_utils.py +41 -0
  187. vllm/transformers_utils/__init__.py +0 -0
  188. vllm/transformers_utils/config.py +58 -0
  189. vllm/transformers_utils/configs/__init__.py +16 -0
  190. vllm/transformers_utils/configs/chatglm.py +68 -0
  191. vllm/transformers_utils/configs/dbrx.py +278 -0
  192. vllm/transformers_utils/configs/falcon.py +87 -0
  193. vllm/transformers_utils/configs/jais.py +236 -0
  194. vllm/transformers_utils/configs/mpt.py +178 -0
  195. vllm/transformers_utils/detokenizer.py +313 -0
  196. vllm/transformers_utils/tokenizer.py +149 -0
  197. vllm/transformers_utils/tokenizer_group/__init__.py +33 -0
  198. vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py +55 -0
  199. vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py +169 -0
  200. vllm/transformers_utils/tokenizer_group/tokenizer_group.py +78 -0
  201. vllm/transformers_utils/tokenizers/__init__.py +5 -0
  202. vllm/transformers_utils/tokenizers/baichuan.py +255 -0
  203. vllm/usage/__init__.py +0 -0
  204. vllm/usage/usage_lib.py +209 -0
  205. vllm/utils.py +677 -0
  206. vllm/worker/__init__.py +0 -0
  207. vllm/worker/cache_engine.py +105 -0
  208. vllm/worker/cpu_model_runner.py +346 -0
  209. vllm/worker/cpu_worker.py +321 -0
  210. vllm/worker/model_runner.py +1168 -0
  211. vllm/worker/neuron_model_runner.py +196 -0
  212. vllm/worker/neuron_worker.py +98 -0
  213. vllm/worker/worker.py +345 -0
  214. vllm/worker/worker_base.py +146 -0
  215. vllm_npu-0.4.2.dist-info/LICENSE +201 -0
  216. vllm_npu-0.4.2.dist-info/METADATA +173 -0
  217. vllm_npu-0.4.2.dist-info/RECORD +219 -0
  218. vllm_npu-0.4.2.dist-info/WHEEL +5 -0
  219. vllm_npu-0.4.2.dist-info/top_level.txt +1 -0
vllm/entrypoints/llm.py
@@ -0,0 +1,259 @@
+ from typing import List, Optional, Union
+
+ import torch
+ from tqdm import tqdm
+ from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
+
+ from vllm.engine.arg_utils import EngineArgs
+ from vllm.engine.llm_engine import LLMEngine
+ from vllm.lora.request import LoRARequest
+ from vllm.outputs import RequestOutput
+ from vllm.sampling_params import SamplingParams
+ from vllm.sequence import MultiModalData
+ from vllm.usage.usage_lib import UsageContext
+ from vllm.utils import Counter
+
+
+ class LLM:
+     """An LLM for generating texts from given prompts and sampling parameters.
+
+     This class includes a tokenizer, a language model (possibly distributed
+     across multiple GPUs), and GPU memory space allocated for intermediate
+     states (aka KV cache). Given a batch of prompts and sampling parameters,
+     this class generates texts from the model, using an intelligent batching
+     mechanism and efficient memory management.
+
+     NOTE: This class is intended to be used for offline inference. For online
+     serving, use the `AsyncLLMEngine` class instead.
+     NOTE: For the comprehensive list of arguments, see `EngineArgs`.
+
+     Args:
+         model: The name or path of a HuggingFace Transformers model.
+         tokenizer: The name or path of a HuggingFace Transformers tokenizer.
+         tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
+             if available, and "slow" will always use the slow tokenizer.
+         skip_tokenizer_init: If true, skip initialization of tokenizer and
+             detokenizer. Expect valid prompt_token_ids and None for prompt
+             from the input.
+         trust_remote_code: Trust remote code (e.g., from HuggingFace) when
+             downloading the model and tokenizer.
+         tensor_parallel_size: The number of GPUs to use for distributed
+             execution with tensor parallelism.
+         dtype: The data type for the model weights and activations. Currently,
+             we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
+             the `torch_dtype` attribute specified in the model config file.
+             However, if the `torch_dtype` in the config is `float32`, we will
+             use `float16` instead.
+         quantization: The method used to quantize the model weights. Currently,
+             we support "awq", "gptq", "squeezellm", and "fp8" (experimental).
+             If None, we first check the `quantization_config` attribute in the
+             model config file. If that is None, we assume the model weights are
+             not quantized and use `dtype` to determine the data type of
+             the weights.
+         revision: The specific model version to use. It can be a branch name,
+             a tag name, or a commit id.
+         tokenizer_revision: The specific tokenizer version to use. It can be a
+             branch name, a tag name, or a commit id.
+         seed: The seed to initialize the random number generator for sampling.
+         gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
+             reserve for the model weights, activations, and KV cache. Higher
+             values will increase the KV cache size and thus improve the model's
+             throughput. However, if the value is too high, it may cause out-of-
+             memory (OOM) errors.
+         swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
+             This can be used for temporarily storing the states of the requests
+             when their `best_of` sampling parameters are larger than 1. If all
+             requests will have `best_of=1`, you can safely set this to 0.
+             Otherwise, too small values may cause out-of-memory (OOM) errors.
+         enforce_eager: Whether to enforce eager execution. If True, we will
+             disable CUDA graph and always execute the model in eager mode.
+             If False, we will use CUDA graph and eager execution in hybrid.
+         max_context_len_to_capture: Maximum context len covered by CUDA graphs.
+             When a sequence has context length larger than this, we fall back
+             to eager mode (DEPRECATED. Use `max_seq_len_to_capture` instead).
+         max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
+             When a sequence has context length larger than this, we fall back
+             to eager mode.
+         disable_custom_all_reduce: See ParallelConfig
+     """
+
+     def __init__(
+         self,
+         model: str,
+         tokenizer: Optional[str] = None,
+         tokenizer_mode: str = "auto",
+         skip_tokenizer_init: bool = False,
+         trust_remote_code: bool = False,
+         tensor_parallel_size: int = 1,
+         dtype: str = "auto",
+         quantization: Optional[str] = None,
+         revision: Optional[str] = None,
+         tokenizer_revision: Optional[str] = None,
+         seed: int = 0,
+         gpu_memory_utilization: float = 0.9,
+         swap_space: int = 4,
+         enforce_eager: bool = False,
+         max_context_len_to_capture: Optional[int] = None,
+         max_seq_len_to_capture: int = 8192,
+         disable_custom_all_reduce: bool = False,
+         **kwargs,
+     ) -> None:
+         if "disable_log_stats" not in kwargs:
+             kwargs["disable_log_stats"] = True
+         engine_args = EngineArgs(
+             model=model,
+             tokenizer=tokenizer,
+             tokenizer_mode=tokenizer_mode,
+             skip_tokenizer_init=skip_tokenizer_init,
+             trust_remote_code=trust_remote_code,
+             tensor_parallel_size=tensor_parallel_size,
+             dtype=dtype,
+             quantization=quantization,
+             revision=revision,
+             tokenizer_revision=tokenizer_revision,
+             seed=seed,
+             gpu_memory_utilization=gpu_memory_utilization,
+             swap_space=swap_space,
+             enforce_eager=enforce_eager,
+             max_context_len_to_capture=max_context_len_to_capture,
+             max_seq_len_to_capture=max_seq_len_to_capture,
+             disable_custom_all_reduce=disable_custom_all_reduce,
+             **kwargs,
+         )
+         self.llm_engine = LLMEngine.from_engine_args(
+             engine_args, usage_context=UsageContext.LLM_CLASS)
+         self.request_counter = Counter()
+
+     def get_tokenizer(
+             self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
+         return self.llm_engine.tokenizer.tokenizer
+
+     def set_tokenizer(
+         self,
+         tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+     ) -> None:
+         self.llm_engine.tokenizer.tokenizer = tokenizer
+
+     def generate(
+         self,
+         prompts: Optional[Union[str, List[str]]] = None,
+         sampling_params: Optional[Union[SamplingParams,
+                                         List[SamplingParams]]] = None,
+         prompt_token_ids: Optional[List[List[int]]] = None,
+         use_tqdm: bool = True,
+         lora_request: Optional[LoRARequest] = None,
+         multi_modal_data: Optional[MultiModalData] = None,
+     ) -> List[RequestOutput]:
+         """Generates the completions for the input prompts.
+
+         NOTE: This class automatically batches the given prompts, considering
+         the memory constraint. For the best performance, put all of your prompts
+         into a single list and pass it to this method.
+
+         Args:
+             prompts: A list of prompts to generate completions for.
+             sampling_params: The sampling parameters for text generation. If
+                 None, we use the default sampling parameters.
+                 When it is a single value, it is applied to every prompt.
+                 When it is a list, the list must have the same length as the
+                 prompts and it is paired one by one with the prompt.
+             prompt_token_ids: A list of token IDs for the prompts. If None, we
+                 use the tokenizer to convert the prompts to token IDs.
+             use_tqdm: Whether to use tqdm to display the progress bar.
+             lora_request: LoRA request to use for generation, if any.
+             multi_modal_data: Multi modal data.
+
+         Returns:
+             A list of `RequestOutput` objects containing the generated
+             completions in the same order as the input prompts.
+         """
+         if prompts is None and prompt_token_ids is None:
+             raise ValueError("Either prompts or prompt_token_ids must be "
+                              "provided.")
+         if self.llm_engine.model_config.skip_tokenizer_init \
+             and prompts is not None:
+             raise ValueError("prompts must be None if skip_tokenizer_init "
+                              "is True")
+         if isinstance(prompts, str):
+             # Convert a single prompt to a list.
+             prompts = [prompts]
+         if (prompts is not None and prompt_token_ids is not None
+                 and len(prompts) != len(prompt_token_ids)):
+             raise ValueError("The lengths of prompts and prompt_token_ids "
+                              "must be the same.")
+
+         if prompts is not None:
+             num_requests = len(prompts)
+         else:
+             assert prompt_token_ids is not None
+             num_requests = len(prompt_token_ids)
+
+         if sampling_params is None:
+             # Use default sampling params.
+             sampling_params = SamplingParams()
+
+         elif isinstance(sampling_params,
+                         list) and len(sampling_params) != num_requests:
+             raise ValueError("The lengths of prompts and sampling_params "
+                              "must be the same.")
+         if multi_modal_data:
+             multi_modal_data.data = multi_modal_data.data.to(torch.float16)
+
+         # Add requests to the engine.
+         for i in range(num_requests):
+             prompt = prompts[i] if prompts is not None else None
+             token_ids = None if prompt_token_ids is None else prompt_token_ids[
+                 i]
+             self._add_request(
+                 prompt,
+                 sampling_params[i]
+                 if isinstance(sampling_params, list) else sampling_params,
+                 token_ids,
+                 lora_request=lora_request,
+                 # Get ith image while maintaining the batch dim.
+                 multi_modal_data=MultiModalData(
+                     type=multi_modal_data.type,
+                     data=multi_modal_data.data[i].unsqueeze(0))
+                 if multi_modal_data else None,
+             )
+         return self._run_engine(use_tqdm)
+
+     def _add_request(
+         self,
+         prompt: Optional[str],
+         sampling_params: SamplingParams,
+         prompt_token_ids: Optional[List[int]],
+         lora_request: Optional[LoRARequest] = None,
+         multi_modal_data: Optional[MultiModalData] = None,
+     ) -> None:
+         request_id = str(next(self.request_counter))
+         self.llm_engine.add_request(request_id,
+                                     prompt,
+                                     sampling_params,
+                                     prompt_token_ids,
+                                     lora_request=lora_request,
+                                     multi_modal_data=multi_modal_data)
+
+     def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]:
+         # Initialize tqdm.
+         if use_tqdm:
+             num_requests = self.llm_engine.get_num_unfinished_requests()
+             pbar = tqdm(total=num_requests,
+                         desc="Processed prompts",
+                         dynamic_ncols=True)
+         # Run the engine.
+         outputs: List[RequestOutput] = []
+         while self.llm_engine.has_unfinished_requests():
+             step_outputs = self.llm_engine.step()
+             for output in step_outputs:
+                 if output.finished:
+                     outputs.append(output)
+                     if use_tqdm:
+                         pbar.update(1)
+         if use_tqdm:
+             pbar.close()
+         # Sort the outputs by request ID.
+         # This is necessary because some requests may be finished earlier than
+         # its previous requests.
+         outputs = sorted(outputs, key=lambda x: int(x.request_id))
+         return outputs
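
For orientation, a minimal offline-inference sketch using the `LLM` class added above. The model name and prompts are placeholders and the default engine settings are assumed; this snippet is an illustration, not part of the packaged source.

    # Illustrative sketch only; "facebook/opt-125m" is a placeholder model name.
    from vllm import LLM, SamplingParams

    prompts = [
        "Hello, my name is",
        "The capital of France is",
    ]
    # A single SamplingParams object is applied to every prompt (see generate() above).
    sampling_params = SamplingParams(temperature=0.0, max_tokens=32)

    llm = LLM(model="facebook/opt-125m")
    outputs = llm.generate(prompts, sampling_params)
    for output in outputs:
        print(output.prompt, "->", output.outputs[0].text)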
File without changes
vllm/entrypoints/openai/api_server.py
@@ -0,0 +1,186 @@
+ import asyncio
+ import importlib
+ import inspect
+ import re
+ from contextlib import asynccontextmanager
+ from http import HTTPStatus
+ from typing import Any, Set
+
+ import fastapi
+ import uvicorn
+ from fastapi import Request
+ from fastapi.exceptions import RequestValidationError
+ from fastapi.middleware.cors import CORSMiddleware
+ from fastapi.responses import JSONResponse, Response, StreamingResponse
+ from prometheus_client import make_asgi_app
+ from starlette.routing import Mount
+
+ import vllm
+ import vllm.envs as envs
+ from vllm.engine.arg_utils import AsyncEngineArgs
+ from vllm.engine.async_llm_engine import AsyncLLMEngine
+ from vllm.entrypoints.openai.cli_args import make_arg_parser
+ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
+                                               ChatCompletionResponse,
+                                               CompletionRequest, ErrorResponse)
+ from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
+ from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
+ from vllm.logger import init_logger
+ from vllm.usage.usage_lib import UsageContext
+
+ TIMEOUT_KEEP_ALIVE = 5  # seconds
+
+ openai_serving_chat: OpenAIServingChat
+ openai_serving_completion: OpenAIServingCompletion
+ logger = init_logger(__name__)
+
+ _running_tasks: Set[asyncio.Task[Any]] = set()
+
+
+ @asynccontextmanager
+ async def lifespan(app: fastapi.FastAPI):
+
+     async def _force_log():
+         while True:
+             await asyncio.sleep(10)
+             await engine.do_log_stats()
+
+     if not engine_args.disable_log_stats:
+         task = asyncio.create_task(_force_log())
+         _running_tasks.add(task)
+         task.add_done_callback(_running_tasks.remove)
+
+     yield
+
+
+ app = fastapi.FastAPI(lifespan=lifespan)
+
+
+ def parse_args():
+     parser = make_arg_parser()
+     return parser.parse_args()
+
+
+ # Add prometheus asgi middleware to route /metrics requests
+ route = Mount("/metrics", make_asgi_app())
+ # Workaround for 307 Redirect for /metrics
+ route.path_regex = re.compile('^/metrics(?P<path>.*)$')
+ app.routes.append(route)
+
+
+ @app.exception_handler(RequestValidationError)
+ async def validation_exception_handler(_, exc):
+     err = openai_serving_chat.create_error_response(message=str(exc))
+     return JSONResponse(err.model_dump(), status_code=HTTPStatus.BAD_REQUEST)
+
+
+ @app.get("/health")
+ async def health() -> Response:
+     """Health check."""
+     await openai_serving_chat.engine.check_health()
+     return Response(status_code=200)
+
+
+ @app.get("/v1/models")
+ async def show_available_models():
+     models = await openai_serving_chat.show_available_models()
+     return JSONResponse(content=models.model_dump())
+
+
+ @app.get("/version")
+ async def show_version():
+     ver = {"version": vllm.__version__}
+     return JSONResponse(content=ver)
+
+
+ @app.post("/v1/chat/completions")
+ async def create_chat_completion(request: ChatCompletionRequest,
+                                  raw_request: Request):
+     generator = await openai_serving_chat.create_chat_completion(
+         request, raw_request)
+     if isinstance(generator, ErrorResponse):
+         return JSONResponse(content=generator.model_dump(),
+                             status_code=generator.code)
+     if request.stream:
+         return StreamingResponse(content=generator,
+                                  media_type="text/event-stream")
+     else:
+         assert isinstance(generator, ChatCompletionResponse)
+         return JSONResponse(content=generator.model_dump())
+
+
+ @app.post("/v1/completions")
+ async def create_completion(request: CompletionRequest, raw_request: Request):
+     generator = await openai_serving_completion.create_completion(
+         request, raw_request)
+     if isinstance(generator, ErrorResponse):
+         return JSONResponse(content=generator.model_dump(),
+                             status_code=generator.code)
+     if request.stream:
+         return StreamingResponse(content=generator,
+                                  media_type="text/event-stream")
+     else:
+         return JSONResponse(content=generator.model_dump())
+
+
+ if __name__ == "__main__":
+     args = parse_args()
+
+     app.add_middleware(
+         CORSMiddleware,
+         allow_origins=args.allowed_origins,
+         allow_credentials=args.allow_credentials,
+         allow_methods=args.allowed_methods,
+         allow_headers=args.allowed_headers,
+     )
+
+     if token := envs.VLLM_API_KEY or args.api_key:
+
+         @app.middleware("http")
+         async def authentication(request: Request, call_next):
+             root_path = "" if args.root_path is None else args.root_path
+             if not request.url.path.startswith(f"{root_path}/v1"):
+                 return await call_next(request)
+             if request.headers.get("Authorization") != "Bearer " + token:
+                 return JSONResponse(content={"error": "Unauthorized"},
+                                     status_code=401)
+             return await call_next(request)
+
+     for middleware in args.middleware:
+         module_path, object_name = middleware.rsplit(".", 1)
+         imported = getattr(importlib.import_module(module_path), object_name)
+         if inspect.isclass(imported):
+             app.add_middleware(imported)
+         elif inspect.iscoroutinefunction(imported):
+             app.middleware("http")(imported)
+         else:
+             raise ValueError(f"Invalid middleware {middleware}. "
+                              f"Must be a function or a class.")
+
+     logger.info("vLLM API server version %s", vllm.__version__)
+     logger.info("args: %s", args)
+
+     if args.served_model_name is not None:
+         served_model_names = args.served_model_name
+     else:
+         served_model_names = [args.model]
+     engine_args = AsyncEngineArgs.from_cli_args(args)
+     engine = AsyncLLMEngine.from_engine_args(
+         engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
+     openai_serving_chat = OpenAIServingChat(engine, served_model_names,
+                                             args.response_role,
+                                             args.lora_modules,
+                                             args.chat_template)
+     openai_serving_completion = OpenAIServingCompletion(
+         engine, served_model_names, args.lora_modules)
+
+     app.root_path = args.root_path
+     uvicorn.run(app,
+                 host=args.host,
+                 port=args.port,
+                 log_level=args.uvicorn_log_level,
+                 timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
+                 ssl_keyfile=args.ssl_keyfile,
+                 ssl_certfile=args.ssl_certfile,
+                 ssl_ca_certs=args.ssl_ca_certs,
+                 ssl_cert_reqs=args.ssl_cert_reqs)
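
As a companion to the endpoints defined above, a hedged client-side sketch against a locally running server. It assumes the default port 8000 (see cli_args.py below), a served model name of "my-model" (placeholder), and the third-party `requests` package; adjust to your deployment.

    # Illustrative sketch only; host, port, and model name are placeholders.
    import requests

    resp = requests.post(
        "http://localhost:8000/v1/completions",
        json={
            "model": "my-model",
            "prompt": "San Francisco is a",
            "max_tokens": 16,
            "temperature": 0.0,
        },
        # If the server was started with --api-key (or VLLM_API_KEY is set),
        # the /v1 routes also require:
        # headers={"Authorization": "Bearer <key>"},
    )
    print(resp.json()["choices"][0]["text"])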
vllm/entrypoints/openai/cli_args.py
@@ -0,0 +1,115 @@
+ """
+ This file contains the command line arguments for vLLM's
+ OpenAI-compatible server. It is kept in a separate file for documentation
+ purposes.
+ """
+
+ import argparse
+ import json
+ import ssl
+
+ from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
+ from vllm.entrypoints.openai.serving_engine import LoRAModulePath
+
+
+ class LoRAParserAction(argparse.Action):
+
+     def __call__(self, parser, namespace, values, option_string=None):
+         lora_list = []
+         for item in values:
+             name, path = item.split('=')
+             lora_list.append(LoRAModulePath(name, path))
+         setattr(namespace, self.dest, lora_list)
+
+
+ def make_arg_parser():
+     parser = argparse.ArgumentParser(
+         description="vLLM OpenAI-Compatible RESTful API server.")
+     parser.add_argument("--host",
+                         type=nullable_str,
+                         default=None,
+                         help="host name")
+     parser.add_argument("--port", type=int, default=8000, help="port number")
+     parser.add_argument(
+         "--uvicorn-log-level",
+         type=str,
+         default="info",
+         choices=['debug', 'info', 'warning', 'error', 'critical', 'trace'],
+         help="log level for uvicorn")
+     parser.add_argument("--allow-credentials",
+                         action="store_true",
+                         help="allow credentials")
+     parser.add_argument("--allowed-origins",
+                         type=json.loads,
+                         default=["*"],
+                         help="allowed origins")
+     parser.add_argument("--allowed-methods",
+                         type=json.loads,
+                         default=["*"],
+                         help="allowed methods")
+     parser.add_argument("--allowed-headers",
+                         type=json.loads,
+                         default=["*"],
+                         help="allowed headers")
+     parser.add_argument("--api-key",
+                         type=nullable_str,
+                         default=None,
+                         help="If provided, the server will require this key "
+                         "to be presented in the header.")
+     parser.add_argument(
+         "--lora-modules",
+         type=nullable_str,
+         default=None,
+         nargs='+',
+         action=LoRAParserAction,
+         help="LoRA module configurations in the format name=path. "
+         "Multiple modules can be specified.")
+     parser.add_argument("--chat-template",
+                         type=nullable_str,
+                         default=None,
+                         help="The file path to the chat template, "
+                         "or the template in single-line form "
+                         "for the specified model")
+     parser.add_argument("--response-role",
+                         type=nullable_str,
+                         default="assistant",
+                         help="The role name to return if "
+                         "`request.add_generation_prompt=true`.")
+     parser.add_argument("--ssl-keyfile",
+                         type=nullable_str,
+                         default=None,
+                         help="The file path to the SSL key file")
+     parser.add_argument("--ssl-certfile",
+                         type=nullable_str,
+                         default=None,
+                         help="The file path to the SSL cert file")
+     parser.add_argument("--ssl-ca-certs",
+                         type=nullable_str,
+                         default=None,
+                         help="The CA certificates file")
+     parser.add_argument(
+         "--ssl-cert-reqs",
+         type=int,
+         default=int(ssl.CERT_NONE),
+         help="Whether client certificate is required (see stdlib ssl module's)"
+     )
+     parser.add_argument(
+         "--root-path",
+         type=nullable_str,
+         default=None,
+         help="FastAPI root_path when app is behind a path based routing proxy")
+     parser.add_argument(
+         "--middleware",
+         type=nullable_str,
+         action="append",
+         default=[],
+         help="Additional ASGI middleware to apply to the app. "
+         "We accept multiple --middleware arguments. "
+         "The value should be an import path. "
+         "If a function is provided, vLLM will add it to the server "
+         "using @app.middleware('http'). "
+         "If a class is provided, vLLM will add it to the server "
+         "using app.add_middleware(). ")
+
+     parser = AsyncEngineArgs.add_cli_args(parser)
+     return parser
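
Finally, a small sketch of how the parser above composes with the engine arguments and how `LoRAParserAction` handles `--lora-modules`. The model name, adapter path, and API key are placeholders; this only parses arguments and does not start a server.

    # Illustrative sketch only; paths and names are placeholders.
    from vllm.entrypoints.openai.cli_args import make_arg_parser

    parser = make_arg_parser()
    args = parser.parse_args([
        "--model", "facebook/opt-125m",  # flag added by AsyncEngineArgs.add_cli_args
        "--port", "8000",
        "--api-key", "secret-key",
        "--lora-modules", "sql-lora=/path/to/sql_lora_adapter",
    ])
    # LoRAParserAction turns each name=path pair into a LoRAModulePath entry.
    print(args.port, args.api_key)
    print(args.lora_modules)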