vllm_npu-0.4.2-py3-none-any.whl

Files changed (219)
  1. vllm/__init__.py +23 -0
  2. vllm/_custom_ops.py +251 -0
  3. vllm/attention/__init__.py +13 -0
  4. vllm/attention/backends/__init__.py +0 -0
  5. vllm/attention/backends/abstract.py +127 -0
  6. vllm/attention/backends/flash_attn.py +271 -0
  7. vllm/attention/backends/flashinfer.py +220 -0
  8. vllm/attention/backends/rocm_flash_attn.py +374 -0
  9. vllm/attention/backends/torch_sdpa.py +250 -0
  10. vllm/attention/backends/xformers.py +393 -0
  11. vllm/attention/layer.py +56 -0
  12. vllm/attention/ops/__init__.py +0 -0
  13. vllm/attention/ops/paged_attn.py +216 -0
  14. vllm/attention/ops/prefix_prefill.py +792 -0
  15. vllm/attention/ops/triton_flash_attention.py +810 -0
  16. vllm/attention/selector.py +91 -0
  17. vllm/block.py +84 -0
  18. vllm/config.py +1225 -0
  19. vllm/core/__init__.py +0 -0
  20. vllm/core/block/__init__.py +0 -0
  21. vllm/core/block/block_table.py +295 -0
  22. vllm/core/block/common.py +199 -0
  23. vllm/core/block/cpu_gpu_block_allocator.py +228 -0
  24. vllm/core/block/interfaces.py +205 -0
  25. vllm/core/block/naive_block.py +318 -0
  26. vllm/core/block/prefix_caching_block.py +606 -0
  27. vllm/core/block_manager_v1.py +625 -0
  28. vllm/core/block_manager_v2.py +258 -0
  29. vllm/core/evictor_v1.py +105 -0
  30. vllm/core/evictor_v2.py +127 -0
  31. vllm/core/interfaces.py +113 -0
  32. vllm/core/policy.py +45 -0
  33. vllm/core/scheduler.py +1163 -0
  34. vllm/distributed/__init__.py +3 -0
  35. vllm/distributed/communication_op.py +237 -0
  36. vllm/distributed/device_communicators/__init__.py +0 -0
  37. vllm/distributed/device_communicators/custom_all_reduce.py +274 -0
  38. vllm/distributed/device_communicators/pynccl.py +287 -0
  39. vllm/distributed/device_communicators/pynccl_utils.py +66 -0
  40. vllm/distributed/parallel_state.py +339 -0
  41. vllm/distributed/utils.py +136 -0
  42. vllm/engine/__init__.py +0 -0
  43. vllm/engine/arg_utils.py +649 -0
  44. vllm/engine/async_llm_engine.py +737 -0
  45. vllm/engine/llm_engine.py +784 -0
  46. vllm/engine/metrics.py +368 -0
  47. vllm/engine/output_processor/__init__.py +0 -0
  48. vllm/engine/output_processor/interfaces.py +76 -0
  49. vllm/engine/output_processor/multi_step.py +142 -0
  50. vllm/engine/output_processor/single_step.py +284 -0
  51. vllm/engine/output_processor/stop_checker.py +101 -0
  52. vllm/engine/output_processor/util.py +19 -0
  53. vllm/entrypoints/__init__.py +0 -0
  54. vllm/entrypoints/api_server.py +119 -0
  55. vllm/entrypoints/llm.py +259 -0
  56. vllm/entrypoints/openai/__init__.py +0 -0
  57. vllm/entrypoints/openai/api_server.py +186 -0
  58. vllm/entrypoints/openai/cli_args.py +115 -0
  59. vllm/entrypoints/openai/protocol.py +460 -0
  60. vllm/entrypoints/openai/serving_chat.py +392 -0
  61. vllm/entrypoints/openai/serving_completion.py +347 -0
  62. vllm/entrypoints/openai/serving_engine.py +234 -0
  63. vllm/envs.py +217 -0
  64. vllm/executor/__init__.py +0 -0
  65. vllm/executor/cpu_executor.py +152 -0
  66. vllm/executor/distributed_gpu_executor.py +115 -0
  67. vllm/executor/executor_base.py +115 -0
  68. vllm/executor/gpu_executor.py +150 -0
  69. vllm/executor/multiproc_worker_utils.py +263 -0
  70. vllm/executor/neuron_executor.py +91 -0
  71. vllm/executor/ray_gpu_executor.py +327 -0
  72. vllm/executor/ray_utils.py +119 -0
  73. vllm/logger.py +153 -0
  74. vllm/logging/__init__.py +5 -0
  75. vllm/logging/formatter.py +15 -0
  76. vllm/lora/__init__.py +0 -0
  77. vllm/lora/fully_sharded_layers.py +262 -0
  78. vllm/lora/layers.py +1181 -0
  79. vllm/lora/lora.py +167 -0
  80. vllm/lora/models.py +645 -0
  81. vllm/lora/punica.py +213 -0
  82. vllm/lora/request.py +32 -0
  83. vllm/lora/utils.py +98 -0
  84. vllm/lora/worker_manager.py +251 -0
  85. vllm/model_executor/__init__.py +7 -0
  86. vllm/model_executor/guided_decoding/__init__.py +25 -0
  87. vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py +70 -0
  88. vllm/model_executor/guided_decoding/outlines_decoding.py +130 -0
  89. vllm/model_executor/guided_decoding/outlines_logits_processors.py +184 -0
  90. vllm/model_executor/layers/__init__.py +0 -0
  91. vllm/model_executor/layers/activation.py +173 -0
  92. vllm/model_executor/layers/fused_moe/__init__.py +7 -0
  93. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  94. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  95. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  96. vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  97. vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  98. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  99. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  100. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  101. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  102. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  103. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  104. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  105. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +140 -0
  106. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  107. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  108. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  109. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  110. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +146 -0
  111. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  112. vllm/model_executor/layers/fused_moe/fused_moe.py +479 -0
  113. vllm/model_executor/layers/layernorm.py +71 -0
  114. vllm/model_executor/layers/linear.py +709 -0
  115. vllm/model_executor/layers/logits_processor.py +115 -0
  116. vllm/model_executor/layers/ops/__init__.py +0 -0
  117. vllm/model_executor/layers/ops/rand.py +157 -0
  118. vllm/model_executor/layers/ops/sample.py +406 -0
  119. vllm/model_executor/layers/quantization/__init__.py +35 -0
  120. vllm/model_executor/layers/quantization/aqlm.py +376 -0
  121. vllm/model_executor/layers/quantization/awq.py +175 -0
  122. vllm/model_executor/layers/quantization/base_config.py +97 -0
  123. vllm/model_executor/layers/quantization/fp8.py +265 -0
  124. vllm/model_executor/layers/quantization/gptq.py +224 -0
  125. vllm/model_executor/layers/quantization/gptq_marlin.py +438 -0
  126. vllm/model_executor/layers/quantization/marlin.py +227 -0
  127. vllm/model_executor/layers/quantization/schema.py +84 -0
  128. vllm/model_executor/layers/quantization/squeezellm.py +137 -0
  129. vllm/model_executor/layers/rejection_sampler.py +405 -0
  130. vllm/model_executor/layers/rotary_embedding.py +525 -0
  131. vllm/model_executor/layers/sampler.py +1051 -0
  132. vllm/model_executor/layers/vocab_parallel_embedding.py +155 -0
  133. vllm/model_executor/model_loader/__init__.py +30 -0
  134. vllm/model_executor/model_loader/loader.py +362 -0
  135. vllm/model_executor/model_loader/neuron.py +136 -0
  136. vllm/model_executor/model_loader/tensorizer.py +368 -0
  137. vllm/model_executor/model_loader/utils.py +41 -0
  138. vllm/model_executor/model_loader/weight_utils.py +372 -0
  139. vllm/model_executor/models/__init__.py +119 -0
  140. vllm/model_executor/models/baichuan.py +410 -0
  141. vllm/model_executor/models/bloom.py +327 -0
  142. vllm/model_executor/models/chatglm.py +386 -0
  143. vllm/model_executor/models/commandr.py +373 -0
  144. vllm/model_executor/models/dbrx.py +413 -0
  145. vllm/model_executor/models/decilm.py +122 -0
  146. vllm/model_executor/models/deepseek.py +438 -0
  147. vllm/model_executor/models/falcon.py +444 -0
  148. vllm/model_executor/models/gemma.py +393 -0
  149. vllm/model_executor/models/gpt2.py +266 -0
  150. vllm/model_executor/models/gpt_bigcode.py +274 -0
  151. vllm/model_executor/models/gpt_j.py +281 -0
  152. vllm/model_executor/models/gpt_neox.py +295 -0
  153. vllm/model_executor/models/internlm2.py +323 -0
  154. vllm/model_executor/models/jais.py +333 -0
  155. vllm/model_executor/models/llama.py +442 -0
  156. vllm/model_executor/models/llava.py +239 -0
  157. vllm/model_executor/models/minicpm.py +531 -0
  158. vllm/model_executor/models/mixtral.py +583 -0
  159. vllm/model_executor/models/mixtral_quant.py +404 -0
  160. vllm/model_executor/models/mpt.py +295 -0
  161. vllm/model_executor/models/olmo.py +356 -0
  162. vllm/model_executor/models/opt.py +349 -0
  163. vllm/model_executor/models/orion.py +319 -0
  164. vllm/model_executor/models/phi.py +300 -0
  165. vllm/model_executor/models/qwen.py +284 -0
  166. vllm/model_executor/models/qwen2.py +367 -0
  167. vllm/model_executor/models/qwen2_moe.py +447 -0
  168. vllm/model_executor/models/stablelm.py +301 -0
  169. vllm/model_executor/models/starcoder2.py +302 -0
  170. vllm/model_executor/models/xverse.py +366 -0
  171. vllm/model_executor/sampling_metadata.py +588 -0
  172. vllm/model_executor/utils.py +35 -0
  173. vllm/outputs.py +150 -0
  174. vllm/py.typed +2 -0
  175. vllm/sampling_params.py +340 -0
  176. vllm/sequence.py +766 -0
  177. vllm/spec_decode/__init__.py +0 -0
  178. vllm/spec_decode/batch_expansion.py +397 -0
  179. vllm/spec_decode/interfaces.py +73 -0
  180. vllm/spec_decode/metrics.py +191 -0
  181. vllm/spec_decode/multi_step_worker.py +203 -0
  182. vllm/spec_decode/ngram_worker.py +176 -0
  183. vllm/spec_decode/spec_decode_worker.py +472 -0
  184. vllm/spec_decode/top1_proposer.py +200 -0
  185. vllm/spec_decode/util.py +228 -0
  186. vllm/test_utils.py +41 -0
  187. vllm/transformers_utils/__init__.py +0 -0
  188. vllm/transformers_utils/config.py +58 -0
  189. vllm/transformers_utils/configs/__init__.py +16 -0
  190. vllm/transformers_utils/configs/chatglm.py +68 -0
  191. vllm/transformers_utils/configs/dbrx.py +278 -0
  192. vllm/transformers_utils/configs/falcon.py +87 -0
  193. vllm/transformers_utils/configs/jais.py +236 -0
  194. vllm/transformers_utils/configs/mpt.py +178 -0
  195. vllm/transformers_utils/detokenizer.py +313 -0
  196. vllm/transformers_utils/tokenizer.py +149 -0
  197. vllm/transformers_utils/tokenizer_group/__init__.py +33 -0
  198. vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py +55 -0
  199. vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py +169 -0
  200. vllm/transformers_utils/tokenizer_group/tokenizer_group.py +78 -0
  201. vllm/transformers_utils/tokenizers/__init__.py +5 -0
  202. vllm/transformers_utils/tokenizers/baichuan.py +255 -0
  203. vllm/usage/__init__.py +0 -0
  204. vllm/usage/usage_lib.py +209 -0
  205. vllm/utils.py +677 -0
  206. vllm/worker/__init__.py +0 -0
  207. vllm/worker/cache_engine.py +105 -0
  208. vllm/worker/cpu_model_runner.py +346 -0
  209. vllm/worker/cpu_worker.py +321 -0
  210. vllm/worker/model_runner.py +1168 -0
  211. vllm/worker/neuron_model_runner.py +196 -0
  212. vllm/worker/neuron_worker.py +98 -0
  213. vllm/worker/worker.py +345 -0
  214. vllm/worker/worker_base.py +146 -0
  215. vllm_npu-0.4.2.dist-info/LICENSE +201 -0
  216. vllm_npu-0.4.2.dist-info/METADATA +173 -0
  217. vllm_npu-0.4.2.dist-info/RECORD +219 -0
  218. vllm_npu-0.4.2.dist-info/WHEEL +5 -0
  219. vllm_npu-0.4.2.dist-info/top_level.txt +1 -0
vllm/config.py ADDED
@@ -0,0 +1,1225 @@
+import enum
+import json
+from dataclasses import dataclass, field, fields
+from typing import TYPE_CHECKING, ClassVar, List, Optional, Union
+
+import torch
+from packaging.version import Version
+from transformers import PretrainedConfig
+
+from vllm.logger import init_logger
+from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS,
+                                                     get_quantization_config)
+from vllm.transformers_utils.config import get_config, get_hf_text_config
+from vllm.utils import (get_cpu_memory, get_nvcc_cuda_version, is_cpu, is_hip,
+                        is_neuron)
+
+GPTQMarlinConfig = get_quantization_config("gptq_marlin")
+
+if TYPE_CHECKING:
+    from ray.util.placement_group import PlacementGroup
+
+    from vllm.model_executor.model_loader.loader import BaseModelLoader
+
+logger = init_logger(__name__)
+
+_GB = 1 << 30
+
+
+class ModelConfig:
+    """Configuration for the model.
+
+    Args:
+        model: Name or path of the huggingface model to use.
+            It is also used as the content for `model_name` tag in metrics
+            output when `served_model_name` is not specified.
+        tokenizer: Name or path of the huggingface tokenizer to use.
+        tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
+            available, and "slow" will always use the slow tokenizer.
+        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
+            downloading the model and tokenizer.
+        dtype: Data type for model weights and activations. The "auto" option
+            will use FP16 precision for FP32 and FP16 models, and BF16 precision
+            for BF16 models.
+        seed: Random seed for reproducibility.
+        revision: The specific model version to use. It can be a branch name,
+            a tag name, or a commit id. If unspecified, will use the default
+            version.
+        code_revision: The specific revision to use for the model code on
+            Hugging Face Hub. It can be a branch name, a tag name, or a
+            commit id. If unspecified, will use the default version.
+        tokenizer_revision: The specific tokenizer version to use. It can be a
+            branch name, a tag name, or a commit id. If unspecified, will use
+            the default version.
+        max_model_len: Maximum length of a sequence (including prompt and
+            output). If None, will be derived from the model.
+        quantization: Quantization method that was used to quantize the model
+            weights. If None, we assume the model weights are not quantized.
+        quantization_param_path: Path to JSON file containing scaling factors.
+            Used to load KV cache scaling factors into the model when KV cache
+            type is FP8_E4M3 on ROCm (AMD GPU). In the future these will also
+            be used to load activation and weight scaling factors when the
+            model dtype is FP8_E4M3 on ROCm.
+        enforce_eager: Whether to enforce eager execution. If True, we will
+            disable CUDA graph and always execute the model in eager mode.
+            If False, we will use CUDA graph and eager execution in hybrid.
+        max_context_len_to_capture: Maximum context len covered by CUDA graphs.
+            When a sequence has context length larger than this, we fall back
+            to eager mode (DEPRECATED. Use max_seq_len_to_capture instead).
+        max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
+            When a sequence has context length larger than this, we fall back
+            to eager mode
+        skip_tokenizer_init: If true, skip initialization of tokenizer and
+            detokenizer.
+        served_model_name: The model name used in metrics tag `model_name`,
+            matches the model name exposed via the APIs. If multiple model
+            names provided, the first name will be used. If not specified,
+            the model name will be the same as `model`.
+    """
+
+    def __init__(
+        self,
+        model: str,
+        tokenizer: str,
+        tokenizer_mode: str,
+        trust_remote_code: bool,
+        dtype: Union[str, torch.dtype],
+        seed: int,
+        revision: Optional[str] = None,
+        code_revision: Optional[str] = None,
+        tokenizer_revision: Optional[str] = None,
+        max_model_len: Optional[int] = None,
+        quantization: Optional[str] = None,
+        quantization_param_path: Optional[str] = None,
+        enforce_eager: bool = False,
+        max_context_len_to_capture: Optional[int] = None,
+        max_seq_len_to_capture: Optional[int] = None,
+        max_logprobs: int = 5,
+        skip_tokenizer_init: bool = False,
+        served_model_name: Optional[Union[str, List[str]]] = None,
+    ) -> None:
+        self.model = model
+        self.tokenizer = tokenizer
+        self.tokenizer_mode = tokenizer_mode
+        self.trust_remote_code = trust_remote_code
+        self.seed = seed
+        self.revision = revision
+        self.code_revision = code_revision
+        self.tokenizer_revision = tokenizer_revision
+        self.quantization = quantization
+        self.quantization_param_path = quantization_param_path
+        self.enforce_eager = enforce_eager
+        self.max_context_len_to_capture = max_context_len_to_capture
+        if self.max_context_len_to_capture is not None:
+            raise ValueError("`max_context_len_to_capture` is deprecated. "
+                             "Use `max_seq_len_to_capture` instead.")
+        self.max_seq_len_to_capture = (max_seq_len_to_capture
+                                       or max_context_len_to_capture)
+        self.max_logprobs = max_logprobs
+        self.skip_tokenizer_init = skip_tokenizer_init
+
+        self.hf_config = get_config(self.model, trust_remote_code, revision,
+                                    code_revision)
+        self.hf_text_config = get_hf_text_config(self.hf_config)
+        self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
+        self.max_model_len = _get_and_verify_max_len(self.hf_text_config,
+                                                     max_model_len)
+        self.served_model_name = get_served_model_name(model,
+                                                       served_model_name)
+        if not self.skip_tokenizer_init:
+            self._verify_tokenizer_mode()
+        self._verify_quantization()
+        self._verify_cuda_graph()
+
+    def _verify_tokenizer_mode(self) -> None:
+        tokenizer_mode = self.tokenizer_mode.lower()
+        if tokenizer_mode not in ["auto", "slow"]:
+            raise ValueError(
+                f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be "
+                "either 'auto' or 'slow'.")
+        self.tokenizer_mode = tokenizer_mode
+
+    def _verify_quantization(self) -> None:
+        supported_quantization = [*QUANTIZATION_METHODS]
+        rocm_supported_quantization = ["gptq", "squeezellm"]
+        if self.quantization is not None:
+            self.quantization = self.quantization.lower()
+
+        # Parse quantization method from the HF model config, if available.
+        quant_cfg = getattr(self.hf_config, "quantization_config", None)
+        if quant_cfg is not None:
+            quant_method = quant_cfg.get("quant_method", "").lower()
+            # compat: autogptq >=0.8.0 use checkpoint_format: str
+            # compat: autogptq <=0.7.1 is_marlin_format: bool
+            is_format_marlin = (quant_cfg.get("checkpoint_format") == "marlin"
+                                or quant_cfg.get("is_marlin_format", False))
+
+            # Check which LinearMethod the GPTQ model should use.
+            if quant_method == "gptq":
+                # If serialized in Marlin format, use MarlinLinearMethod.
+                # TODO (@robertgshaw): migrate under GPTQMarlinLinearMethod.
+                if is_format_marlin:
+                    logger.info("The model is serialized in Marlin format. "
+                                "Using Marlin kernel.")
+                    quant_method = "marlin"
+                    if self.quantization == "gptq":
+                        self.quantization = quant_method
+
+                # If convertible to Marlin format, use GPTQMarlinLinearMethod
+                # unless the user explicitly specified GPTQLinearMethod.
+                elif GPTQMarlinConfig.is_marlin_compatible(quant_cfg):
+                    if self.quantization == "gptq":
+                        logger.warning(
+                            "The model is convertible to Marlin format, but "
+                            "you specified quantization=gptq. Use "
+                            "quantization=marlin for faster inference.")
+                    else:
+                        logger.info(
+                            "The model is convertible to Marlin format. "
+                            "Using Marlin kernel.")
+                        quant_method = "gptq_marlin"
+                        if self.quantization == "marlin":
+                            self.quantization = quant_method
+
+            # Verify quantization configurations.
+            if self.quantization is None:
+                self.quantization = quant_method
+            elif self.quantization != quant_method:
+                raise ValueError(
+                    "Quantization method specified in the model config "
+                    f"({quant_method}) does not match the quantization "
+                    f"method specified in the `quantization` argument "
+                    f"({self.quantization}).")
+
+        if self.quantization is not None:
+            if self.quantization not in supported_quantization:
+                raise ValueError(
+                    f"Unknown quantization method: {self.quantization}. Must "
+                    f"be one of {supported_quantization}.")
+            if is_hip(
+            ) and self.quantization not in rocm_supported_quantization:
+                raise ValueError(
+                    f"{self.quantization} quantization is currently not "
+                    f"supported in ROCm.")
+            if (self.quantization not in ["marlin", "gptq_marlin"]):
+                logger.warning(
+                    "%s quantization is not fully "
+                    "optimized yet. The speed can be slower than "
+                    "non-quantized models.", self.quantization)
+
+    def _verify_cuda_graph(self) -> None:
+        if self.max_seq_len_to_capture is None:
+            self.max_seq_len_to_capture = self.max_model_len
+        self.max_seq_len_to_capture = min(self.max_seq_len_to_capture,
+                                          self.max_model_len)
+
+    def verify_with_parallel_config(
+        self,
+        parallel_config: "ParallelConfig",
+    ) -> None:
+        total_num_attention_heads = self.hf_text_config.num_attention_heads
+        tensor_parallel_size = parallel_config.tensor_parallel_size
+        if total_num_attention_heads % tensor_parallel_size != 0:
+            raise ValueError(
+                f"Total number of attention heads ({total_num_attention_heads})"
+                " must be divisible by tensor parallel size "
+                f"({tensor_parallel_size}).")
+
+        total_num_hidden_layers = self.hf_text_config.num_hidden_layers
+        pipeline_parallel_size = parallel_config.pipeline_parallel_size
+        if total_num_hidden_layers % pipeline_parallel_size != 0:
+            raise ValueError(
+                f"Total number of hidden layers ({total_num_hidden_layers}) "
+                "must be divisible by pipeline parallel size "
+                f"({pipeline_parallel_size}).")
+
+    def get_sliding_window(self) -> Optional[int]:
+        """Get the sliding window size, or None if disabled.
+        """
+
+        # Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in
+        # addition to sliding window size. We check if that field is present
+        # and if it's False, return None.
+        if (hasattr(self.hf_text_config, "use_sliding_window")
+                and not self.hf_text_config.use_sliding_window):
+            return None
+        return getattr(self.hf_text_config, "sliding_window", None)
+
+    def get_vocab_size(self) -> int:
+        return self.hf_text_config.vocab_size
+
+    def get_hidden_size(self) -> int:
+        return self.hf_text_config.hidden_size
+
+    def get_head_size(self) -> int:
+        if hasattr(self.hf_text_config, "head_dim"):
+            return self.hf_text_config.head_dim
+        # FIXME(woosuk): This may not be true for all models.
+        return (self.hf_text_config.hidden_size //
+                self.hf_text_config.num_attention_heads)
+
+    def get_total_num_kv_heads(self) -> int:
+        """Returns the total number of KV heads."""
+        # For GPTBigCode & Falcon:
+        # NOTE: for falcon, when new_decoder_architecture is True, the
+        # multi_query flag is ignored and we use n_head_kv for the number of
+        # KV heads.
+        falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
+        new_decoder_arch_falcon = (
+            self.hf_config.model_type in falcon_model_types
+            and getattr(self.hf_config, "new_decoder_architecture", False))
+        if not new_decoder_arch_falcon and getattr(self.hf_text_config,
+                                                   "multi_query", False):
+            # Multi-query attention, only one KV head.
+            # Currently, tensor parallelism is not supported in this case.
+            return 1
+
+        # For DBRX and MPT
+        if self.hf_config.model_type in ["dbrx", "mpt"]:
+            return getattr(self.hf_config.attn_config, "kv_n_heads",
+                           self.hf_config.num_attention_heads)
+
+        attributes = [
+            # For Falcon:
+            "n_head_kv",
+            "num_kv_heads",
+            # For LLaMA-2:
+            "num_key_value_heads",
+            # For ChatGLM:
+            "multi_query_group_num",
+        ]
+        for attr in attributes:
+            num_kv_heads = getattr(self.hf_text_config, attr, None)
+            if num_kv_heads is not None:
+                return num_kv_heads
+
+        # For non-grouped-query attention models, the number of KV heads is
+        # equal to the number of attention heads.
+        return self.hf_text_config.num_attention_heads
+
+    def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
+        """Returns the number of KV heads per GPU."""
+        total_num_kv_heads = self.get_total_num_kv_heads()
+        # If tensor parallelism is used, we divide the number of KV heads by
+        # the tensor parallel size. We will replicate the KV heads in the
+        # case where the number of KV heads is smaller than the tensor
+        # parallel size so each GPU has at least one KV head.
+        return max(1,
+                   total_num_kv_heads // parallel_config.tensor_parallel_size)
+
+    def get_num_attention_heads(self,
+                                parallel_config: "ParallelConfig") -> int:
+        return self.hf_text_config.num_attention_heads // \
+            parallel_config.tensor_parallel_size
+
+    def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
+        total_num_hidden_layers = self.hf_text_config.num_hidden_layers
+        return total_num_hidden_layers // parallel_config.pipeline_parallel_size
+
+
+class CacheConfig:
+    """Configuration for the KV cache.
+
+    Args:
+        block_size: Size of a cache block in number of tokens.
+        gpu_memory_utilization: Fraction of GPU memory to use for the
+            vLLM execution.
+        swap_space: Size of the CPU swap space per GPU (in GiB).
+        cache_dtype: Data type for kv cache storage.
+        num_gpu_blocks_override: Number of GPU blocks to use. This overrides the
+            profiled num_gpu_blocks if specified. Does nothing if None.
+    """
+
+    def __init__(
+        self,
+        block_size: int,
+        gpu_memory_utilization: float,
+        swap_space: int,
+        cache_dtype: str,
+        num_gpu_blocks_override: Optional[int] = None,
+        sliding_window: Optional[int] = None,
+        enable_prefix_caching: bool = False,
+    ) -> None:
+        self.block_size = block_size
+        self.gpu_memory_utilization = gpu_memory_utilization
+        self.swap_space_bytes = swap_space * _GB
+        self.num_gpu_blocks_override = num_gpu_blocks_override
+        self.cache_dtype = cache_dtype
+        self.sliding_window = sliding_window
+        self.enable_prefix_caching = enable_prefix_caching
+        self._verify_args()
+        self._verify_cache_dtype()
+
+        # Will be set after profiling.
+        self.num_gpu_blocks = None
+        self.num_cpu_blocks = None
+
+    def metrics_info(self):
+        # convert cache_config to dict(key: str, value: str) for prometheus
+        # metrics info
+        return {key: str(value) for key, value in self.__dict__.items()}
+
+    def _verify_args(self) -> None:
+        if self.gpu_memory_utilization > 1.0:
+            raise ValueError(
+                "GPU memory utilization must be less than 1.0. Got "
+                f"{self.gpu_memory_utilization}.")
+
+    def _verify_cache_dtype(self) -> None:
+        if self.cache_dtype == "auto":
+            pass
+        elif self.cache_dtype == "fp8":
+            if not is_hip():
+                nvcc_cuda_version = get_nvcc_cuda_version()
+                if nvcc_cuda_version is not None \
+                        and nvcc_cuda_version < Version("11.8"):
+                    raise ValueError(
+                        "FP8 is not supported when cuda version is"
+                        "lower than 11.8.")
+            logger.info(
+                "Using fp8 data type to store kv cache. It reduces the GPU "
+                "memory footprint and boosts the performance. "
+                "But it may cause slight accuracy drop without scaling "
+                "factors. FP8_E5M2 (without scaling) is only supported on "
+                "cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 "
+                "is instead supported for common inference criteria.")
+        else:
+            raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")
+
+    def verify_with_parallel_config(
+        self,
+        parallel_config: "ParallelConfig",
+    ) -> None:
+        total_cpu_memory = get_cpu_memory()
+        # FIXME(woosuk): Here, it is assumed that the GPUs in a tensor parallel
+        # group are in the same node. However, the GPUs may span multiple nodes.
+        num_gpus_per_node = parallel_config.tensor_parallel_size
+        cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node
+
+        msg = (f"{cpu_memory_usage / _GB:.2f} GiB out of "
+               f"the {total_cpu_memory / _GB:.2f} GiB total CPU memory is "
+               "allocated for the swap space.")
+        if cpu_memory_usage > 0.7 * total_cpu_memory:
+            raise ValueError("Too large swap space. " + msg)
+        elif cpu_memory_usage > 0.4 * total_cpu_memory:
+            logger.warning("Possibly too large swap space. %s", msg)
+
+
+@dataclass
+class TokenizerPoolConfig:
+    """Configuration for the tokenizer pool.
+
+    Args:
+        pool_size: Number of tokenizer workers in the pool.
+        pool_type: Type of the pool.
+        extra_config: Additional config for the pool.
+            The way the config will be used depends on the
+            pool type.
+    """
+    pool_size: int
+    pool_type: str
+    extra_config: dict
+
+    def __post_init__(self):
+        if self.pool_type not in ("ray", ):
+            raise ValueError(f"Unknown pool type: {self.pool_type}")
+        if not isinstance(self.extra_config, dict):
+            raise ValueError("extra_config must be a dictionary.")
+
+    @classmethod
+    def create_config(
+        cls, tokenizer_pool_size: int, tokenizer_pool_type: str,
+        tokenizer_pool_extra_config: Optional[Union[str, dict]]
+    ) -> Optional["TokenizerPoolConfig"]:
+        """Create a TokenizerPoolConfig from the given parameters.
+
+        If tokenizer_pool_size is 0, return None.
+
+        Args:
+            tokenizer_pool_size: Number of tokenizer workers in the pool.
+            tokenizer_pool_type: Type of the pool.
+            tokenizer_pool_extra_config: Additional config for the pool.
+                The way the config will be used depends on the
+                pool type. This can be a JSON string (will be parsed).
+        """
+        if tokenizer_pool_size:
+            if isinstance(tokenizer_pool_extra_config, str):
+                tokenizer_pool_extra_config_parsed = json.loads(
+                    tokenizer_pool_extra_config)
+            else:
+                tokenizer_pool_extra_config_parsed = (
+                    tokenizer_pool_extra_config or {})
+            tokenizer_pool_config = cls(tokenizer_pool_size,
+                                        tokenizer_pool_type,
+                                        tokenizer_pool_extra_config_parsed)
+        else:
+            tokenizer_pool_config = None
+        return tokenizer_pool_config
+
+
+class LoadFormat(str, enum.Enum):
+    AUTO = "auto"
+    PT = "pt"
+    SAFETENSORS = "safetensors"
+    NPCACHE = "npcache"
+    DUMMY = "dummy"
+    TENSORIZER = "tensorizer"
+
+
+@dataclass
+class LoadConfig:
+    """
+    download_dir: Directory to download and load the weights, default to the
+        default cache directory of huggingface.
+    load_format: The format of the model weights to load:
+        "auto" will try to load the weights in the safetensors format and
+            fall back to the pytorch bin format if safetensors format is
+            not available.
+        "pt" will load the weights in the pytorch bin format.
+        "safetensors" will load the weights in the safetensors format.
+        "npcache" will load the weights in pytorch format and store
+            a numpy cache to speed up the loading.
+        "dummy" will initialize the weights with random values, which is
+            mainly for profiling.
+        "tensorizer" will use CoreWeave's tensorizer library for
+            fast weight loading.
+    """
+
+    load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
+    download_dir: Optional[str] = None
+    model_loader_extra_config: Optional[Union[str, dict]] = field(
+        default_factory=dict)
+
+    def __post_init__(self):
+        model_loader_extra_config = self.model_loader_extra_config or {}
+        if isinstance(model_loader_extra_config, str):
+            self.model_loader_extra_config = json.loads(
+                model_loader_extra_config)
+        self._verify_load_format()
+
+    def _verify_load_format(self) -> None:
+        if not isinstance(self.load_format, str):
+            return
+
+        load_format = self.load_format.lower()
+        self.load_format = LoadFormat(load_format)
+
+        rocm_not_supported_load_format: List[str] = []
+        if is_hip() and load_format in rocm_not_supported_load_format:
+            rocm_supported_load_format = [
+                f for f in LoadFormat.__members__
+                if (f not in rocm_not_supported_load_format)
+            ]
+            raise ValueError(
+                f"load format '{load_format}' is not supported in ROCm. "
+                f"Supported load formats are "
+                f"{rocm_supported_load_format}")
+
+
+class ParallelConfig:
+    """Configuration for the distributed execution.
+
+    Args:
+        pipeline_parallel_size: Number of pipeline parallel groups.
+        tensor_parallel_size: Number of tensor parallel groups.
+        worker_use_ray: Whether to use Ray for model workers. Will be set to
+            True if either pipeline_parallel_size or tensor_parallel_size is
+            greater than 1.
+        max_parallel_loading_workers: Maximum number of multiple batches
+            when load model sequentially. To avoid RAM OOM when using tensor
+            parallel and large models.
+        disable_custom_all_reduce: Disable the custom all-reduce kernel and
+            fall back to NCCL.
+        tokenizer_pool_config: Config for the tokenizer pool.
+            If None, will use synchronous tokenization.
+        ray_workers_use_nsight: Whether to profile Ray workers with nsight, see
+            https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.
+    """
+
+    def __init__(
+        self,
+        pipeline_parallel_size: int,
+        tensor_parallel_size: int,
+        worker_use_ray: bool,
+        max_parallel_loading_workers: Optional[int] = None,
+        disable_custom_all_reduce: bool = False,
+        tokenizer_pool_config: Optional[TokenizerPoolConfig] = None,
+        ray_workers_use_nsight: bool = False,
+        placement_group: Optional["PlacementGroup"] = None,
+    ) -> None:
+        self.pipeline_parallel_size = pipeline_parallel_size
+        self.tensor_parallel_size = tensor_parallel_size
+        self.worker_use_ray = worker_use_ray
+        self.max_parallel_loading_workers = max_parallel_loading_workers
+        self.disable_custom_all_reduce = disable_custom_all_reduce
+        self.tokenizer_pool_config = tokenizer_pool_config
+        self.ray_workers_use_nsight = ray_workers_use_nsight
+        self.placement_group = placement_group
+
+        self.world_size = pipeline_parallel_size * self.tensor_parallel_size
+        if self.world_size > 1:
+            self.worker_use_ray = True
+        self._verify_args()
+
+    def _verify_args(self) -> None:
+        if self.pipeline_parallel_size > 1:
+            raise NotImplementedError(
+                "Pipeline parallelism is not supported yet.")
+        if not self.disable_custom_all_reduce and self.world_size > 1:
+            if is_hip():
+                self.disable_custom_all_reduce = True
+                logger.info(
+                    "Disabled the custom all-reduce kernel because it is not "
+                    "supported on AMD GPUs.")
+            elif self.pipeline_parallel_size > 1:
+                self.disable_custom_all_reduce = True
+                logger.info(
+                    "Disabled the custom all-reduce kernel because it is not "
+                    "supported with pipeline parallelism.")
+        if self.ray_workers_use_nsight and not self.worker_use_ray:
+            raise ValueError("Unable to use nsight profiling unless workers "
+                             "run with Ray.")
+
+
+class SchedulerConfig:
+    """Scheduler configuration.
+
+    Args:
+        max_num_batched_tokens: Maximum number of tokens to be processed in
+            a single iteration.
+        max_num_seqs: Maximum number of sequences to be processed in a single
+            iteration.
+        max_model_len: Maximum length of a sequence (including prompt
+            and generated text).
+        use_v2_block_manager: Whether to use the BlockSpaceManagerV2 or not.
+        num_lookahead_slots: The number of slots to allocate per sequence per
+            step, beyond the known token ids. This is used in speculative
+            decoding to store KV activations of tokens which may or may not be
+            accepted.
+        delay_factor: Apply a delay (of delay factor multiplied by previous
+            prompt latency) before scheduling next prompt.
+        enable_chunked_prefill: If True, prefill requests can be chunked based
+            on the remaining max_num_batched_tokens.
+    """
+
+    def __init__(
+        self,
+        max_num_batched_tokens: Optional[int],
+        max_num_seqs: int,
+        max_model_len: int,
+        use_v2_block_manager: bool = False,
+        num_lookahead_slots: int = 0,
+        delay_factor: float = 0.0,
+        enable_chunked_prefill: bool = False,
+    ) -> None:
+        if max_num_batched_tokens is not None:
+            self.max_num_batched_tokens = max_num_batched_tokens
+        else:
+            if enable_chunked_prefill:
+                # It is the values that have the best balance between ITL
+                # and TTFT on A100. Note it is not optimized for throughput.
+                self.max_num_batched_tokens = 512
+            else:
+                # If max_model_len is too short, use 2048 as the default value
+                # for higher throughput.
+                self.max_num_batched_tokens = max(max_model_len, 2048)
+        if enable_chunked_prefill:
+            logger.info("Chunked prefill is enabled (EXPERIMENTAL).")
+
+        self.max_num_seqs = max_num_seqs
+        self.max_model_len = max_model_len
+        self.use_v2_block_manager = use_v2_block_manager
+        self.num_lookahead_slots = num_lookahead_slots
+        self.delay_factor = delay_factor
+        self.chunked_prefill_enabled = enable_chunked_prefill
+
+        self._verify_args()
+
+    def _verify_args(self) -> None:
+        if (self.max_num_batched_tokens < self.max_model_len
+                and not self.chunked_prefill_enabled):
+            raise ValueError(
+                f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
+                f"smaller than max_model_len ({self.max_model_len}). "
+                "This effectively limits the maximum sequence length to "
+                "max_num_batched_tokens and makes vLLM reject longer "
+                "sequences. Please increase max_num_batched_tokens or "
+                "decrease max_model_len.")
+
+        if self.max_num_batched_tokens < self.max_num_seqs:
+            raise ValueError(
+                f"max_num_batched_tokens ({self.max_num_batched_tokens}) must "
+                "be greater than or equal to max_num_seqs "
+                f"({self.max_num_seqs}).")
+
+        if self.num_lookahead_slots < 0:
+            raise ValueError(
+                "num_lookahead_slots "
+                f"({self.num_lookahead_slots}) must be greater than or "
+                "equal to 0.")
+
+
+class DeviceConfig:
+
+    def __init__(self, device: str = "auto") -> None:
+        if device == "auto":
+            # Automated device type detection
+            if is_neuron():
+                self.device_type = "neuron"
+            elif is_cpu():
+                self.device_type = "cpu"
+            else:
+                # We don't call torch.cuda.is_available() here to
+                # avoid initializing CUDA before workers are forked
+                self.device_type = "cuda"
+        else:
+            # Device type is assigned explicitly
+            self.device_type = device
+
+        # Some device types require processing inputs on CPU
+        if self.device_type in ["neuron"]:
+            self.device = torch.device("cpu")
+        else:
+            # Set device with device type
+            self.device = torch.device(self.device_type)
+
+
+class SpeculativeConfig:
+    """Configuration for speculative decoding.
+
+    The configuration is currently specialized to draft-model speculative
+    decoding with top-1 proposals.
+    """
+
+    @staticmethod
+    def maybe_create_spec_config(
+        target_model_config: ModelConfig,
+        target_parallel_config: ParallelConfig,
+        target_dtype: str,
+        speculative_model: Optional[str],
+        num_speculative_tokens: Optional[int],
+        speculative_max_model_len: Optional[int],
+        enable_chunked_prefill: bool,
+        use_v2_block_manager: bool,
+        ngram_prompt_lookup_max: Optional[int],
+        ngram_prompt_lookup_min: Optional[int],
+    ) -> Optional["SpeculativeConfig"]:
+        """Create a SpeculativeConfig if possible, else return None.
+
+        This function attempts to create a SpeculativeConfig object based on the
+        provided parameters. If the necessary conditions are met, it returns an
+        instance of SpeculativeConfig. Otherwise, it returns None.
+
+        Args:
+            target_model_config (ModelConfig): The configuration of the target
+                model.
+            target_parallel_config (ParallelConfig): The parallel configuration
+                for the target model.
+            target_dtype (str): The data type used for the target model.
+            speculative_model (Optional[str]): The name of the speculative
+                model, if provided.
+            num_speculative_tokens (Optional[int]): The number of speculative
+                tokens, if provided.
+            speculative_max_model_len (Optional[int]): The maximum model len of
+                the speculative model. Used when testing the ability to skip
+                speculation for some sequences.
+            enable_chunked_prefill (bool): Whether vLLM is configured to use
+                chunked prefill or not. Used for raising an error since its not
+                yet compatible with spec decode.
+            use_v2_block_manager (bool): Whether vLLM is configured to use the
+                v2 block manager or not. Used for raising an error since the v2
+                block manager is required with spec decode.
+            ngram_prompt_lookup_max (Optional[int]): Max size of ngram token
+                window, if provided.
+            ngram_prompt_lookup_min (Optional[int]): Min size of ngram token
+                window, if provided.
+
+        Returns:
+            Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
+                the necessary conditions are met, else None.
+        """
+
+        if (speculative_model is None and num_speculative_tokens is None):
+            return None
+
+        if speculative_model is not None and num_speculative_tokens is None:
+            raise ValueError(
+                "Expected both speculative_model and "
+                "num_speculative_tokens to be provided, but found "
+                f"{speculative_model=} and {num_speculative_tokens=}.")
+
+        assert (speculative_model is not None
+                and num_speculative_tokens is not None)
+
+        if enable_chunked_prefill:
+            raise ValueError(
+                "Speculative decoding and chunked prefill are "
+                f"currently mutually exclusive ({enable_chunked_prefill=}).")
+
+        if not use_v2_block_manager:
+            raise ValueError(
+                "Speculative decoding requires usage of the V2 "
+                "block manager. Enable it with --use-v2-block-manager.")
+
+        # TODO: The user should be able to specify revision/quantization/max
+        # model len for the draft model. It is not currently supported.
+        draft_revision = None
+        draft_code_revision = None
+        draft_quantization = None
+
+        if speculative_model == "[ngram]":
+            assert (ngram_prompt_lookup_max is not None
+                    and ngram_prompt_lookup_max > 0)
+            if ngram_prompt_lookup_min is None:
+                ngram_prompt_lookup_min = 0
+            else:
+                assert ngram_prompt_lookup_max > ngram_prompt_lookup_min
+
+            # TODO: current we still need extract vocab_size from target model
+            # config, in future, we may try refactor it out, and set
+            # draft related config as None here.
+            draft_model_config = target_model_config
+            draft_parallel_config = target_parallel_config
+        else:
+            ngram_prompt_lookup_max = 0
+            ngram_prompt_lookup_min = 0
+            draft_model_config = ModelConfig(
+                model=speculative_model,
+                tokenizer=target_model_config.tokenizer,
+                tokenizer_mode=target_model_config.tokenizer_mode,
+                trust_remote_code=target_model_config.trust_remote_code,
+                dtype=target_model_config.dtype,
+                seed=target_model_config.seed,
+                revision=draft_revision,
+                code_revision=draft_code_revision,
+                tokenizer_revision=target_model_config.tokenizer_revision,
+                max_model_len=None,
+                quantization=draft_quantization,
+                enforce_eager=target_model_config.enforce_eager,
+                max_seq_len_to_capture=target_model_config.
+                max_seq_len_to_capture,
+                max_logprobs=target_model_config.max_logprobs,
+            )
+
+            draft_model_config.max_model_len = (
+                SpeculativeConfig._maybe_override_draft_max_model_len(
+                    speculative_max_model_len,
+                    draft_model_config.max_model_len,
+                    target_model_config.max_model_len,
+                ))
+
+            draft_parallel_config = (
+                SpeculativeConfig.create_draft_parallel_config(
+                    target_parallel_config))
+
+        return SpeculativeConfig(
+            draft_model_config,
+            draft_parallel_config,
+            num_speculative_tokens,
+            ngram_prompt_lookup_max,
+            ngram_prompt_lookup_min,
+        )
+
+    @staticmethod
+    def _maybe_override_draft_max_model_len(
+        speculative_max_model_len: Optional[int],
+        draft_max_model_len: int,
+        target_max_model_len: int,
+    ) -> int:
+        """Determine the max sequence len for the draft model. This is usually
+        the draft_max_model_len, but may be the target_max_model_len if it is
+        less than the draft_max_model_len, or may be speculative_max_model_len
+        if it is specified.
+
+        This is necessary so that sequences do not exceed the capacity of the
+        draft model or the target model.
+
+        speculative_max_model_len is mainly used for testing that sequences can
+        skip speculation.
+        """
+
+        if speculative_max_model_len is not None:
+
+            if speculative_max_model_len > draft_max_model_len:
+                raise ValueError(f"{speculative_max_model_len=} cannot be "
+                                 f"larger than {draft_max_model_len=}")
+
+            if speculative_max_model_len > target_max_model_len:
+                raise ValueError(f"{speculative_max_model_len=} cannot be "
+                                 f"larger than {target_max_model_len=}")
+
+            return speculative_max_model_len
+
+        return min(
+            draft_max_model_len,
+            target_max_model_len,
+        )
+
+    @staticmethod
+    def create_draft_parallel_config(
+            target_parallel_config: ParallelConfig) -> ParallelConfig:
+        """Create a parallel config for use by the draft worker.
+
+        This is mostly a copy of the target parallel config. In the future the
+        draft worker can have a different parallel strategy, e.g. TP=1.
+        """
+        draft_parallel_config = ParallelConfig(
+            pipeline_parallel_size=target_parallel_config.
+            pipeline_parallel_size,
+            tensor_parallel_size=target_parallel_config.tensor_parallel_size,
+            worker_use_ray=target_parallel_config.worker_use_ray,
+            max_parallel_loading_workers=target_parallel_config.
+            max_parallel_loading_workers,
+            disable_custom_all_reduce=target_parallel_config.
+            disable_custom_all_reduce,
+            tokenizer_pool_config=target_parallel_config.tokenizer_pool_config,
+            ray_workers_use_nsight=target_parallel_config.
+            ray_workers_use_nsight,
+            placement_group=target_parallel_config.placement_group,
+        )
+
+        return draft_parallel_config
+
+    def __init__(
+        self,
+        draft_model_config: ModelConfig,
+        draft_parallel_config: ParallelConfig,
+        num_speculative_tokens: int,
+        ngram_prompt_lookup_max: int,
+        ngram_prompt_lookup_min: int,
+    ):
+        """Create a SpeculativeConfig object.
+
+        Args:
+            draft_model_config: ModelConfig for the draft model.
+            draft_parallel_config: ParallelConfig for the draft model.
+            num_speculative_tokens: The number of tokens to sample from the
+                draft model before scoring with the target model.
+        """
+        self.draft_model_config = draft_model_config
+        self.draft_parallel_config = draft_parallel_config
+        self.num_speculative_tokens = num_speculative_tokens
+        self.ngram_prompt_lookup_max = ngram_prompt_lookup_max
+        self.ngram_prompt_lookup_min = ngram_prompt_lookup_min
+
+        self._verify_args()
+
+    def _verify_args(self) -> None:
+        if self.num_speculative_tokens <= 0:
+            raise ValueError("Expected num_speculative_tokens to be greater "
+                             f"than zero ({self.num_speculative_tokens}).")
+
+        if self.draft_model_config:
+            self.draft_model_config.verify_with_parallel_config(
+                self.draft_parallel_config)
+
+    @property
+    def num_lookahead_slots(self) -> int:
+        """The number of additional slots the scheduler should allocate per
+        step, in addition to the slots allocated for each known token.
+
+        This is equal to the number of speculative tokens, as each speculative
+        token must be scored.
+        """
+        return self.num_speculative_tokens
+
+    def __repr__(self) -> str:
+        if self.ngram_prompt_lookup_max > 0:
+            draft_model = "[ngram]"
+        else:
+            draft_model = self.draft_model_config.model
+        num_spec_tokens = self.num_speculative_tokens
+        return f"SpeculativeConfig({draft_model=}, {num_spec_tokens=})"
+
+
+@dataclass
+class LoRAConfig:
+    max_lora_rank: int
+    max_loras: int
+    fully_sharded_loras: bool = False
+    max_cpu_loras: Optional[int] = None
+    lora_dtype: Optional[torch.dtype] = None
+    lora_extra_vocab_size: int = 256
+    # This is a constant.
+    lora_vocab_padding_size: ClassVar[int] = 256
+
+    def __post_init__(self):
+        # Keep this in sync with csrc/punica/bgmv/bgmv_config.h
+        possible_max_ranks = (8, 16, 32, 64)
+        possible_lora_extra_vocab_size = (0, 256, 512)
+        if self.max_lora_rank not in possible_max_ranks:
+            raise ValueError(
+                f"max_lora_rank ({self.max_lora_rank}) must be one of "
+                f"{possible_max_ranks}.")
+        if self.lora_extra_vocab_size not in possible_lora_extra_vocab_size:
+            raise ValueError(
+                f"lora_extra_vocab_size ({self.lora_extra_vocab_size}) "
+                f"must be one of {possible_lora_extra_vocab_size}.")
+        if self.max_loras < 1:
+            raise ValueError(f"max_loras ({self.max_loras}) must be >= 1.")
+        if self.max_cpu_loras is None:
+            self.max_cpu_loras = self.max_loras
+        elif self.max_cpu_loras < self.max_loras:
+            raise ValueError(
+                f"max_cpu_loras ({self.max_cpu_loras}) must be >= "
+                f"max_loras ({self.max_loras})")
+
+    def verify_with_model_config(self, model_config: ModelConfig):
+        if self.lora_dtype in (None, "auto"):
+            self.lora_dtype = model_config.dtype
+        elif isinstance(self.lora_dtype, str):
+            self.lora_dtype = getattr(torch, self.lora_dtype)
+        if model_config.quantization and model_config.quantization not in [
+                "awq", "gptq"
+        ]:
+            # TODO support marlin and squeezellm
+            logger.warning("%s quantization is not tested with LoRA yet.",
+                           model_config.quantization)
+
+    def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig):
+        if scheduler_config.max_num_batched_tokens > 65528:
+            raise ValueError(
+                "Due to limitations of the custom LoRA CUDA kernel, "
+                "max_num_batched_tokens must be <= 65528 when "
+                "LoRA is enabled.")
+
+
+@dataclass
+class VisionLanguageConfig:
+    """Configs the input data format and how models should run for
+    vision language models."""
+
+    class ImageInputType(enum.Enum):
+        """Image input type into the vision language model.
+
+        An image roughly goes through the following transformation:
+        Raw image --> pixel values --> image features --> image embeddings.
+
+        The difference between different image input types is where the
+        image encoder (pixel values --> image features) is run.
+        Different image input types also correspond to different tensor shapes.
+
+        For example, for Llava, PIXEL_VALUES: (1, 3, 336, 336).
+        IMAGE_FEATURES: (1, 576, 1024).
+        """
+        PIXEL_VALUES = enum.auto()
+        IMAGE_FEATURES = enum.auto()
+
+    image_input_type: ImageInputType
+    # The input id corresponding to image token.
+    image_token_id: int
+    # Used for running `run_prefill_max_token`.
+    # For models that support varying resolution, this corresponds to
+    # worst case scenario (biggest supported resolution).
+    image_input_shape: tuple
+    image_feature_size: int
+
+    @classmethod
+    def get_image_input_enum_type(
+            cls, value: str) -> "VisionLanguageConfig.ImageInputType":
+        """Get the image input type from a string."""
+        try:
+            return cls.ImageInputType[value.upper()]
+        except KeyError as e:
+            raise ValueError(f"{value} is not a valid choice. "
+                             f"Expecting to choose from "
+                             f"{[x.name for x in cls.ImageInputType]}.") from e
+
+
+_STR_DTYPE_TO_TORCH_DTYPE = {
+    "half": torch.float16,
+    "float16": torch.float16,
+    "float": torch.float32,
+    "float32": torch.float32,
+    "bfloat16": torch.bfloat16,
+}
+
+_ROCM_NOT_SUPPORTED_DTYPE = ["float", "float32"]
+
+
+def _get_and_verify_dtype(
+    config: PretrainedConfig,
+    dtype: Union[str, torch.dtype],
+) -> torch.dtype:
+    # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
+    # because config.torch_dtype can be None.
+    config_dtype = getattr(config, "torch_dtype", None)
+    if config_dtype is None:
+        config_dtype = torch.float32
+
+    if isinstance(dtype, str):
+        dtype = dtype.lower()
+        if dtype == "auto":
+            if config_dtype == torch.float32:
+                # Following the common practice, we use float16 for float32
+                # models.
+                torch_dtype = torch.float16
+            else:
+                torch_dtype = config_dtype
+        else:
+            if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
+                raise ValueError(f"Unknown dtype: {dtype}")
+            torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
+    elif isinstance(dtype, torch.dtype):
+        torch_dtype = dtype
+    else:
+        raise ValueError(f"Unknown dtype: {dtype}")
+
+    if is_hip() and torch_dtype == torch.float32:
+        rocm_supported_dtypes = [
+            k for k, v in _STR_DTYPE_TO_TORCH_DTYPE.items()
+            if (k not in _ROCM_NOT_SUPPORTED_DTYPE)
+        ]
+        raise ValueError(f"dtype '{dtype}' is not supported in ROCm. "
+                         f"Supported dtypes are {rocm_supported_dtypes}")
+
+    # Verify the dtype.
+    if torch_dtype != config_dtype:
+        if torch_dtype == torch.float32:
+            # Upcasting to float32 is allowed.
+            pass
+        elif config_dtype == torch.float32:
+            # Downcasting from float32 to float16 or bfloat16 is allowed.
+            pass
+        else:
+            # Casting between float16 and bfloat16 is allowed with a warning.
+            logger.warning("Casting %s to %s.", config_dtype, torch_dtype)
+
+    return torch_dtype
+
+
+def _get_and_verify_max_len(
+    hf_config: PretrainedConfig,
+    max_model_len: Optional[int],
+) -> int:
+    """Get and verify the model's maximum length."""
+    derived_max_model_len = float("inf")
+    possible_keys = [
+        # OPT
+        "max_position_embeddings",
+        # GPT-2
+        "n_positions",
+        # MPT
+        "max_seq_len",
+        # ChatGLM2
+        "seq_length",
+        # Command-R
+        "model_max_length",
+        # Others
+        "max_sequence_length",
+        "max_seq_length",
+        "seq_len",
+    ]
+    max_len_key = None
+    for key in possible_keys:
+        max_len = getattr(hf_config, key, None)
+        if max_len is not None:
+            max_len_key = key if max_len < derived_max_model_len \
+                else max_len_key
+            derived_max_model_len = min(derived_max_model_len, max_len)
+    if derived_max_model_len == float("inf"):
+        if max_model_len is not None:
+            # If max_model_len is specified, we use it.
+            return max_model_len
+
+        default_max_len = 2048
+        logger.warning(
+            "The model's config.json does not contain any of the following "
+            "keys to determine the original maximum length of the model: "
+            "%s. Assuming the model's maximum length is %d.", possible_keys,
+            default_max_len)
+        derived_max_model_len = default_max_len
+
+    rope_scaling = getattr(hf_config, "rope_scaling", None)
+    if rope_scaling is not None and rope_scaling["type"] != "su":
+        assert "factor" in rope_scaling
+        scaling_factor = rope_scaling["factor"]
+        if rope_scaling["type"] == "yarn":
+            derived_max_model_len = rope_scaling[
+                "original_max_position_embeddings"]
+        derived_max_model_len *= scaling_factor
+
+    if max_model_len is None:
+        max_model_len = int(derived_max_model_len)
+    elif max_model_len > derived_max_model_len:
+        # Some models might have a separate key for specifying model_max_length
+        # that will be bigger than derived_max_model_len. We compare user input
+        # with model_max_length and allow this override when it's smaller.
+        model_max_length = getattr(hf_config, "model_max_length", None)
+        if model_max_length is not None and max_model_len <= model_max_length:
+            pass
+        else:
+            raise ValueError(
+                f"User-specified max_model_len ({max_model_len}) is greater "
+                "than the derived max_model_len "
+                f"({max_len_key}={derived_max_model_len} or model_max_length="
+                f"{model_max_length} in model's config.json). This may lead "
+                "to incorrect model outputs or CUDA errors. Make sure the "
+                "value is correct and within the model context size.")
+    return int(max_model_len)
+
+
+def get_served_model_name(model: str,
+                          served_model_name: Optional[Union[str, List[str]]]):
+    """
+    If the input is a non-empty list, the first model_name in
+    `served_model_name` is taken.
+    If the input is a non-empty string, it is used directly.
+    For cases where the input is either an empty string or an
+    empty list, the fallback is to use `self.model`.
+    """
+    if not served_model_name:
+        return model
+    if isinstance(served_model_name, list):
+        return served_model_name[0]
+    return served_model_name
+
+
+@dataclass
+class DecodingConfig:
+    """Dataclass which contains the decoding strategy of the engine"""
+
+    # Which guided decoding algo to use. 'outlines' / 'lm-format-enforcer'
+    guided_decoding_backend: str = 'outlines'
+
+    def __post_init__(self):
+        valid_guided_backends = ['outlines', 'lm-format-enforcer']
+        backend = self.guided_decoding_backend
+        if backend not in valid_guided_backends:
+            raise ValueError(f"Invalid guided_decoding_backend '{backend},"
+                             f"must be one of {valid_guided_backends}")
+
+
+@dataclass(frozen=True)
+class EngineConfig:
+    """Dataclass which contains all engine-related configuration. This
+    simplifies passing around the distinct configurations in the codebase.
+    """
+
+    model_config: ModelConfig
+    cache_config: CacheConfig
+    parallel_config: ParallelConfig
+    scheduler_config: SchedulerConfig
+    device_config: DeviceConfig
+    load_config: LoadConfig
+    lora_config: Optional[LoRAConfig]
+    vision_language_config: Optional[VisionLanguageConfig]
+    speculative_config: Optional[SpeculativeConfig]
+    decoding_config: Optional[DecodingConfig]
+
+    def __post_init__(self):
+        """Verify configs are valid & consistent with each other.
+        """
+        self.model_config.verify_with_parallel_config(self.parallel_config)
+        self.cache_config.verify_with_parallel_config(self.parallel_config)
+
+        if self.lora_config:
+            self.lora_config.verify_with_model_config(self.model_config)
+            self.lora_config.verify_with_scheduler_config(
+                self.scheduler_config)
+
+    def to_dict(self):
+        """Return the configs as a dictionary, for use in **kwargs.
+        """
+        return dict(
+            (field.name, getattr(self, field.name)) for field in fields(self))
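
As a quick orientation for the diff above, the sketch below wires together a few of the lighter config classes it adds. It is illustrative only and not part of the wheel: every concrete value (block size, parallel sizes, sequence lengths) is a placeholder assumption, and ModelConfig is omitted because its constructor fetches a Hugging Face config.

# Illustrative sketch only -- all concrete values below are assumptions.
from vllm.config import (CacheConfig, DeviceConfig, ParallelConfig,
                         SchedulerConfig)

# ParallelConfig forces worker_use_ray=True whenever world_size > 1,
# and rejects pipeline_parallel_size > 1 in _verify_args().
parallel_config = ParallelConfig(pipeline_parallel_size=1,
                                 tensor_parallel_size=2,
                                 worker_use_ray=False)

# CacheConfig validates gpu_memory_utilization <= 1.0 and the cache dtype.
cache_config = CacheConfig(block_size=16,
                           gpu_memory_utilization=0.9,
                           swap_space=4,  # GiB of CPU swap per GPU
                           cache_dtype="auto")
# Raises if the combined swap space exceeds 70% of host CPU memory.
cache_config.verify_with_parallel_config(parallel_config)

# With max_num_batched_tokens=None this falls back to max(max_model_len, 2048).
scheduler_config = SchedulerConfig(max_num_batched_tokens=None,
                                   max_num_seqs=256,
                                   max_model_len=4096)

# Picks "neuron", "cpu", or "cuda" when device="auto".
device_config = DeviceConfig(device="auto")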