vllm-npu 0.4.2 (vllm_npu-0.4.2-py3-none-any.whl)

Files changed (219)
  1. vllm/__init__.py +23 -0
  2. vllm/_custom_ops.py +251 -0
  3. vllm/attention/__init__.py +13 -0
  4. vllm/attention/backends/__init__.py +0 -0
  5. vllm/attention/backends/abstract.py +127 -0
  6. vllm/attention/backends/flash_attn.py +271 -0
  7. vllm/attention/backends/flashinfer.py +220 -0
  8. vllm/attention/backends/rocm_flash_attn.py +374 -0
  9. vllm/attention/backends/torch_sdpa.py +250 -0
  10. vllm/attention/backends/xformers.py +393 -0
  11. vllm/attention/layer.py +56 -0
  12. vllm/attention/ops/__init__.py +0 -0
  13. vllm/attention/ops/paged_attn.py +216 -0
  14. vllm/attention/ops/prefix_prefill.py +792 -0
  15. vllm/attention/ops/triton_flash_attention.py +810 -0
  16. vllm/attention/selector.py +91 -0
  17. vllm/block.py +84 -0
  18. vllm/config.py +1225 -0
  19. vllm/core/__init__.py +0 -0
  20. vllm/core/block/__init__.py +0 -0
  21. vllm/core/block/block_table.py +295 -0
  22. vllm/core/block/common.py +199 -0
  23. vllm/core/block/cpu_gpu_block_allocator.py +228 -0
  24. vllm/core/block/interfaces.py +205 -0
  25. vllm/core/block/naive_block.py +318 -0
  26. vllm/core/block/prefix_caching_block.py +606 -0
  27. vllm/core/block_manager_v1.py +625 -0
  28. vllm/core/block_manager_v2.py +258 -0
  29. vllm/core/evictor_v1.py +105 -0
  30. vllm/core/evictor_v2.py +127 -0
  31. vllm/core/interfaces.py +113 -0
  32. vllm/core/policy.py +45 -0
  33. vllm/core/scheduler.py +1163 -0
  34. vllm/distributed/__init__.py +3 -0
  35. vllm/distributed/communication_op.py +237 -0
  36. vllm/distributed/device_communicators/__init__.py +0 -0
  37. vllm/distributed/device_communicators/custom_all_reduce.py +274 -0
  38. vllm/distributed/device_communicators/pynccl.py +287 -0
  39. vllm/distributed/device_communicators/pynccl_utils.py +66 -0
  40. vllm/distributed/parallel_state.py +339 -0
  41. vllm/distributed/utils.py +136 -0
  42. vllm/engine/__init__.py +0 -0
  43. vllm/engine/arg_utils.py +649 -0
  44. vllm/engine/async_llm_engine.py +737 -0
  45. vllm/engine/llm_engine.py +784 -0
  46. vllm/engine/metrics.py +368 -0
  47. vllm/engine/output_processor/__init__.py +0 -0
  48. vllm/engine/output_processor/interfaces.py +76 -0
  49. vllm/engine/output_processor/multi_step.py +142 -0
  50. vllm/engine/output_processor/single_step.py +284 -0
  51. vllm/engine/output_processor/stop_checker.py +101 -0
  52. vllm/engine/output_processor/util.py +19 -0
  53. vllm/entrypoints/__init__.py +0 -0
  54. vllm/entrypoints/api_server.py +119 -0
  55. vllm/entrypoints/llm.py +259 -0
  56. vllm/entrypoints/openai/__init__.py +0 -0
  57. vllm/entrypoints/openai/api_server.py +186 -0
  58. vllm/entrypoints/openai/cli_args.py +115 -0
  59. vllm/entrypoints/openai/protocol.py +460 -0
  60. vllm/entrypoints/openai/serving_chat.py +392 -0
  61. vllm/entrypoints/openai/serving_completion.py +347 -0
  62. vllm/entrypoints/openai/serving_engine.py +234 -0
  63. vllm/envs.py +217 -0
  64. vllm/executor/__init__.py +0 -0
  65. vllm/executor/cpu_executor.py +152 -0
  66. vllm/executor/distributed_gpu_executor.py +115 -0
  67. vllm/executor/executor_base.py +115 -0
  68. vllm/executor/gpu_executor.py +150 -0
  69. vllm/executor/multiproc_worker_utils.py +263 -0
  70. vllm/executor/neuron_executor.py +91 -0
  71. vllm/executor/ray_gpu_executor.py +327 -0
  72. vllm/executor/ray_utils.py +119 -0
  73. vllm/logger.py +153 -0
  74. vllm/logging/__init__.py +5 -0
  75. vllm/logging/formatter.py +15 -0
  76. vllm/lora/__init__.py +0 -0
  77. vllm/lora/fully_sharded_layers.py +262 -0
  78. vllm/lora/layers.py +1181 -0
  79. vllm/lora/lora.py +167 -0
  80. vllm/lora/models.py +645 -0
  81. vllm/lora/punica.py +213 -0
  82. vllm/lora/request.py +32 -0
  83. vllm/lora/utils.py +98 -0
  84. vllm/lora/worker_manager.py +251 -0
  85. vllm/model_executor/__init__.py +7 -0
  86. vllm/model_executor/guided_decoding/__init__.py +25 -0
  87. vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py +70 -0
  88. vllm/model_executor/guided_decoding/outlines_decoding.py +130 -0
  89. vllm/model_executor/guided_decoding/outlines_logits_processors.py +184 -0
  90. vllm/model_executor/layers/__init__.py +0 -0
  91. vllm/model_executor/layers/activation.py +173 -0
  92. vllm/model_executor/layers/fused_moe/__init__.py +7 -0
  93. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  94. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  95. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  96. vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  97. vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  98. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  99. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  100. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  101. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  102. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  103. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  104. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  105. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +140 -0
  106. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  107. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  108. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  109. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  110. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +146 -0
  111. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  112. vllm/model_executor/layers/fused_moe/fused_moe.py +479 -0
  113. vllm/model_executor/layers/layernorm.py +71 -0
  114. vllm/model_executor/layers/linear.py +709 -0
  115. vllm/model_executor/layers/logits_processor.py +115 -0
  116. vllm/model_executor/layers/ops/__init__.py +0 -0
  117. vllm/model_executor/layers/ops/rand.py +157 -0
  118. vllm/model_executor/layers/ops/sample.py +406 -0
  119. vllm/model_executor/layers/quantization/__init__.py +35 -0
  120. vllm/model_executor/layers/quantization/aqlm.py +376 -0
  121. vllm/model_executor/layers/quantization/awq.py +175 -0
  122. vllm/model_executor/layers/quantization/base_config.py +97 -0
  123. vllm/model_executor/layers/quantization/fp8.py +265 -0
  124. vllm/model_executor/layers/quantization/gptq.py +224 -0
  125. vllm/model_executor/layers/quantization/gptq_marlin.py +438 -0
  126. vllm/model_executor/layers/quantization/marlin.py +227 -0
  127. vllm/model_executor/layers/quantization/schema.py +84 -0
  128. vllm/model_executor/layers/quantization/squeezellm.py +137 -0
  129. vllm/model_executor/layers/rejection_sampler.py +405 -0
  130. vllm/model_executor/layers/rotary_embedding.py +525 -0
  131. vllm/model_executor/layers/sampler.py +1051 -0
  132. vllm/model_executor/layers/vocab_parallel_embedding.py +155 -0
  133. vllm/model_executor/model_loader/__init__.py +30 -0
  134. vllm/model_executor/model_loader/loader.py +362 -0
  135. vllm/model_executor/model_loader/neuron.py +136 -0
  136. vllm/model_executor/model_loader/tensorizer.py +368 -0
  137. vllm/model_executor/model_loader/utils.py +41 -0
  138. vllm/model_executor/model_loader/weight_utils.py +372 -0
  139. vllm/model_executor/models/__init__.py +119 -0
  140. vllm/model_executor/models/baichuan.py +410 -0
  141. vllm/model_executor/models/bloom.py +327 -0
  142. vllm/model_executor/models/chatglm.py +386 -0
  143. vllm/model_executor/models/commandr.py +373 -0
  144. vllm/model_executor/models/dbrx.py +413 -0
  145. vllm/model_executor/models/decilm.py +122 -0
  146. vllm/model_executor/models/deepseek.py +438 -0
  147. vllm/model_executor/models/falcon.py +444 -0
  148. vllm/model_executor/models/gemma.py +393 -0
  149. vllm/model_executor/models/gpt2.py +266 -0
  150. vllm/model_executor/models/gpt_bigcode.py +274 -0
  151. vllm/model_executor/models/gpt_j.py +281 -0
  152. vllm/model_executor/models/gpt_neox.py +295 -0
  153. vllm/model_executor/models/internlm2.py +323 -0
  154. vllm/model_executor/models/jais.py +333 -0
  155. vllm/model_executor/models/llama.py +442 -0
  156. vllm/model_executor/models/llava.py +239 -0
  157. vllm/model_executor/models/minicpm.py +531 -0
  158. vllm/model_executor/models/mixtral.py +583 -0
  159. vllm/model_executor/models/mixtral_quant.py +404 -0
  160. vllm/model_executor/models/mpt.py +295 -0
  161. vllm/model_executor/models/olmo.py +356 -0
  162. vllm/model_executor/models/opt.py +349 -0
  163. vllm/model_executor/models/orion.py +319 -0
  164. vllm/model_executor/models/phi.py +300 -0
  165. vllm/model_executor/models/qwen.py +284 -0
  166. vllm/model_executor/models/qwen2.py +367 -0
  167. vllm/model_executor/models/qwen2_moe.py +447 -0
  168. vllm/model_executor/models/stablelm.py +301 -0
  169. vllm/model_executor/models/starcoder2.py +302 -0
  170. vllm/model_executor/models/xverse.py +366 -0
  171. vllm/model_executor/sampling_metadata.py +588 -0
  172. vllm/model_executor/utils.py +35 -0
  173. vllm/outputs.py +150 -0
  174. vllm/py.typed +2 -0
  175. vllm/sampling_params.py +340 -0
  176. vllm/sequence.py +766 -0
  177. vllm/spec_decode/__init__.py +0 -0
  178. vllm/spec_decode/batch_expansion.py +397 -0
  179. vllm/spec_decode/interfaces.py +73 -0
  180. vllm/spec_decode/metrics.py +191 -0
  181. vllm/spec_decode/multi_step_worker.py +203 -0
  182. vllm/spec_decode/ngram_worker.py +176 -0
  183. vllm/spec_decode/spec_decode_worker.py +472 -0
  184. vllm/spec_decode/top1_proposer.py +200 -0
  185. vllm/spec_decode/util.py +228 -0
  186. vllm/test_utils.py +41 -0
  187. vllm/transformers_utils/__init__.py +0 -0
  188. vllm/transformers_utils/config.py +58 -0
  189. vllm/transformers_utils/configs/__init__.py +16 -0
  190. vllm/transformers_utils/configs/chatglm.py +68 -0
  191. vllm/transformers_utils/configs/dbrx.py +278 -0
  192. vllm/transformers_utils/configs/falcon.py +87 -0
  193. vllm/transformers_utils/configs/jais.py +236 -0
  194. vllm/transformers_utils/configs/mpt.py +178 -0
  195. vllm/transformers_utils/detokenizer.py +313 -0
  196. vllm/transformers_utils/tokenizer.py +149 -0
  197. vllm/transformers_utils/tokenizer_group/__init__.py +33 -0
  198. vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py +55 -0
  199. vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py +169 -0
  200. vllm/transformers_utils/tokenizer_group/tokenizer_group.py +78 -0
  201. vllm/transformers_utils/tokenizers/__init__.py +5 -0
  202. vllm/transformers_utils/tokenizers/baichuan.py +255 -0
  203. vllm/usage/__init__.py +0 -0
  204. vllm/usage/usage_lib.py +209 -0
  205. vllm/utils.py +677 -0
  206. vllm/worker/__init__.py +0 -0
  207. vllm/worker/cache_engine.py +105 -0
  208. vllm/worker/cpu_model_runner.py +346 -0
  209. vllm/worker/cpu_worker.py +321 -0
  210. vllm/worker/model_runner.py +1168 -0
  211. vllm/worker/neuron_model_runner.py +196 -0
  212. vllm/worker/neuron_worker.py +98 -0
  213. vllm/worker/worker.py +345 -0
  214. vllm/worker/worker_base.py +146 -0
  215. vllm_npu-0.4.2.dist-info/LICENSE +201 -0
  216. vllm_npu-0.4.2.dist-info/METADATA +173 -0
  217. vllm_npu-0.4.2.dist-info/RECORD +219 -0
  218. vllm_npu-0.4.2.dist-info/WHEEL +5 -0
  219. vllm_npu-0.4.2.dist-info/top_level.txt +1 -0
vllm/model_executor/model_loader/tensorizer.py
@@ -0,0 +1,368 @@
+import argparse
+import dataclasses
+import io
+import os
+import time
+import typing
+from dataclasses import dataclass
+from typing import Generator, Optional, Tuple, Type, Union
+
+import torch
+from torch import nn
+from transformers import PretrainedConfig
+
+import vllm.envs as envs
+from vllm.config import ModelConfig, ParallelConfig
+from vllm.logger import init_logger
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig)
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    VocabParallelEmbedding)
+
+tensorizer_load_fail = None
+
+try:
+    from tensorizer import (DecryptionParams, EncryptionParams,
+                            TensorDeserializer, TensorSerializer)
+    from tensorizer.stream_io import open_stream
+    from tensorizer.utils import (convert_bytes, get_mem_usage,
+                                  no_init_or_tensor)
+except ImportError as e:
+    tensorizer_load_fail = e
+
+__all__ = [
+    'EncryptionParams', 'DecryptionParams', 'TensorDeserializer',
+    'TensorSerializer', 'open_stream', 'convert_bytes', 'get_mem_usage',
+    'no_init_or_tensor', 'TensorizerConfig'
+]
+
+logger = init_logger(__name__)
+
+
+@dataclass
+class TensorizerConfig:
+    tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, typing.BinaryIO,
+                          str, bytes, os.PathLike, int]
+    vllm_tensorized: bool
+    verify_hash: Optional[bool] = False
+    num_readers: Optional[int] = None
+    encryption_keyfile: Optional[str] = None
+    s3_access_key_id: Optional[str] = None
+    s3_secret_access_key: Optional[str] = None
+    s3_endpoint: Optional[str] = None
+    model_class: Optional[Type[torch.nn.Module]] = None
+    hf_config: Optional[PretrainedConfig] = None
+    dtype: Optional[Union[str, torch.dtype]] = None
+
+    def _construct_tensorizer_args(self) -> "TensorizerArgs":
+        tensorizer_args = {
+            "tensorizer_uri": self.tensorizer_uri,
+            "vllm_tensorized": self.vllm_tensorized,
+            "verify_hash": self.verify_hash,
+            "num_readers": self.num_readers,
+            "encryption_keyfile": self.encryption_keyfile,
+            "s3_access_key_id": self.s3_access_key_id,
+            "s3_secret_access_key": self.s3_secret_access_key,
+            "s3_endpoint": self.s3_endpoint,
+        }
+        return TensorizerArgs(**tensorizer_args)  # type: ignore
+
+    def verify_with_parallel_config(
+        self,
+        parallel_config: "ParallelConfig",
+    ) -> None:
+        if (parallel_config.tensor_parallel_size > 1
+                and self.tensorizer_uri is not None):
+            raise ValueError(
+                "Loading to multiple GPUs is not currently supported with "
+                "vLLM-serialized models. Please set tensor_parallel_size=1."
+                " or use a non-vLLM-serialized model, such as a "
+                "serialized Hugging Face `PretrainedModel`.")
+
+    def verify_with_model_config(self, model_config: "ModelConfig") -> None:
+        if (model_config.quantization is not None
+                and self.tensorizer_uri is not None):
+            logger.warning(
+                "Loading a model using Tensorizer with quantization on vLLM"
+                " is unstable and may lead to errors.")
+
+
+def load_with_tensorizer(tensorizer_config: TensorizerConfig,
+                         **extra_kwargs) -> nn.Module:
+    tensorizer = TensorizerAgent(tensorizer_config, **extra_kwargs)
+    return tensorizer.deserialize()
+
+
+def is_vllm_serialized_tensorizer(tensorizer_config: TensorizerConfig) -> bool:
+    if tensorizer_config is None:
+        return False
+    return tensorizer_config.vllm_tensorized
+
+
+@dataclass
+class TensorizerArgs:
+    tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, typing.BinaryIO,
+                          str, bytes, os.PathLike, int]
+    vllm_tensorized: bool
+    verify_hash: Optional[bool] = False
+    num_readers: Optional[int] = None
+    encryption_keyfile: Optional[str] = None
+    s3_access_key_id: Optional[str] = None
+    s3_secret_access_key: Optional[str] = None
+    s3_endpoint: Optional[str] = None
+    """
+    Args for the TensorizerAgent class. These are used to configure the behavior
+    of the TensorDeserializer when loading tensors from a serialized model.
+
+    Args:
+        tensorizer_uri: Path to serialized model tensors. Can be a local file
+            path or a S3 URI.
+        vllm_tensorized: If True, indicates that the serialized model is a
+            vLLM model. This is used to determine the behavior of the
+            TensorDeserializer when loading tensors from a serialized model.
+            It is far faster to deserialize a vLLM model as it utilizes
+            tensorizer's optimized GPU loading.
+        verify_hash: If True, the hashes of each tensor will be verified against
+            the hashes stored in the metadata. A `HashMismatchError` will be
+            raised if any of the hashes do not match.
+        num_readers: Controls how many threads are allowed to read concurrently
+            from the source file. Default is `None`, which will dynamically set
+            the number of readers based on the number of available
+            resources and model size. This greatly increases performance.
+        encryption_keyfile: File path to a binary file containing a
+            binary key to use for decryption. `None` (the default) means
+            no decryption. See the example script in
+            examples/tensorize_vllm_model.py.
+        s3_access_key_id: The access key for the S3 bucket. Can also be set via
+            the S3_ACCESS_KEY_ID environment variable.
+        s3_secret_access_key: The secret access key for the S3 bucket. Can also
+            be set via the S3_SECRET_ACCESS_KEY environment variable.
+        s3_endpoint: The endpoint for the S3 bucket. Can also be set via the
+            S3_ENDPOINT_URL environment variable.
+    """
+
+    def __post_init__(self):
+        self.file_obj = self.tensorizer_uri
+        self.s3_access_key_id = self.s3_access_key_id or envs.S3_ACCESS_KEY_ID
+        self.s3_secret_access_key = (self.s3_secret_access_key
+                                     or envs.S3_SECRET_ACCESS_KEY)
+        self.s3_endpoint = self.s3_endpoint or envs.S3_ENDPOINT_URL
+        self.stream_params = {
+            "s3_access_key_id": self.s3_access_key_id,
+            "s3_secret_access_key": self.s3_secret_access_key,
+            "s3_endpoint": self.s3_endpoint,
+        }
+
+        self.deserializer_params = {
+            "verify_hash": self.verify_hash,
+            "encryption": self.encryption_keyfile,
+            "num_readers": self.num_readers
+        }
+        if self.encryption_keyfile:
+            with open_stream(
+                    self.encryption_keyfile,
+                    **self.stream_params,
+            ) as stream:
+                key = stream.read()
+                decryption_params = DecryptionParams.from_key(key)
+                self.deserializer_params['encryption'] = decryption_params
+
+    @staticmethod
+    def add_cli_args(
+            parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
+        """Tensorizer CLI arguments"""
+
+        # Tensorizer options arg group
+        group = parser.add_argument_group(
+            'tensorizer options',
+            description=('Options for configuring the behavior of the'
+                         ' tensorizer deserializer when '
+                         '--load-format=tensorizer'))
+
+        group.add_argument(
+            "--tensorizer-uri",
+            help="Path to serialized model tensors. Can be a local file path,"
+            " or an HTTP(S) or S3 URI.",
+        )
+        group.add_argument(
+            "--verify-hash",
+            action="store_true",
+            help="If enabled, the hashes of each tensor will be verified"
+            " against the hashes stored in the file metadata. An exception"
+            " will be raised if any of the hashes do not match.",
+        )
+        group.add_argument(
+            "--encryption-keyfile",
+            default=None,
+            help="The file path to a binary file containing a binary key to "
+            "use for decryption. Can be a file path or S3 network URI.")
+        group.add_argument(
+            "--num-readers",
+            default=None,
+            type=int,
+            help="Controls how many threads are allowed to read concurrently "
+            "from the source file. Default is `None`, which will dynamically "
+            "set the number of readers based on the available resources "
+            "and model size. This greatly increases performance.")
+        group.add_argument(
+            "--s3-access-key-id",
+            default=None,
+            help="The access key for the S3 bucket. Can also be set via the "
+            "S3_ACCESS_KEY_ID environment variable.",
+        )
+        group.add_argument(
+            "--s3-secret-access-key",
+            default=None,
+            help="The secret access key for the S3 bucket. Can also be set via "
+            "the S3_SECRET_ACCESS_KEY environment variable.",
+        )
+        group.add_argument(
+            "--s3-endpoint",
+            default=None,
+            help="The endpoint for the S3 bucket. Can also be set via the "
+            "S3_ENDPOINT_URL environment variable.",
+        )
+        group.add_argument(
+            "--vllm-tensorized",
+            action="store_true",
+            help="If enabled, indicates that the serialized model is a vLLM "
+            "model. This is used to determine the behavior of the "
+            "TensorDeserializer when loading tensors from a "
+            "serialized model.")
+
+        return parser
+
+    @classmethod
+    def from_cli_args(cls, args: argparse.Namespace) -> "TensorizerArgs":
+        attrs = [attr.name for attr in dataclasses.fields(cls)]
+        tensorizer_args = cls(**{
+            attr: getattr(args, attr)
+            for attr in attrs if hasattr(args, attr)
+        })
+        return tensorizer_args
+
+
+class TensorizerAgent:
+    """
+    A class for performing tensorizer deserializations specifically for
+    vLLM models using plaid_mode. Uses TensorizerArgs to configure the
+    behavior of the TensorDeserializer when loading tensors from a serialized
+    model. For deserializations of HuggingFace models, TensorDeserializer is
+    instead used as an iterator directly in the func hf_model_weights_iterator
+    in vllm/model_executor/model_loader/weight_utils.py
+    """
+
+    def __init__(self, tensorizer_config: TensorizerConfig,
+                 quant_config: QuantizationConfig, **extra_kwargs):
+        if tensorizer_load_fail is not None:
+            raise ImportError(
+                "Tensorizer is not installed. Please install tensorizer "
+                "to use this feature with `pip install vllm[tensorizer]`."
+            ) from tensorizer_load_fail
+
+        self.tensorizer_config = tensorizer_config
+        self.tensorizer_args = (
+            self.tensorizer_config._construct_tensorizer_args())
+        self.extra_kwargs = extra_kwargs
+        if extra_kwargs.get("quant_config", None) is not None:
+            self.quant_config = extra_kwargs["quant_config"]
+        else:
+            self.quant_config = quant_config
+        self.model = self._init_model()
+
+    def _init_model(self):
+        assert self.tensorizer_config.hf_config is not None
+        model_args = self.tensorizer_config.hf_config
+        model_args.torch_dtype = self.tensorizer_config.dtype
+        assert self.tensorizer_config.model_class is not None
+        with no_init_or_tensor():
+            return self.tensorizer_config.model_class(
+                config=model_args,
+                quant_config=self.quant_config,
+                **self.extra_kwargs)
+
+    def _resize_lora_embeddings(self):
+        """Modify LoRA embedding layers to use bigger tensors
+        to allow for adapter added tokens."""
+        for child in self.model.modules():
+            if (isinstance(child, VocabParallelEmbedding)
+                    and child.weight.shape[0] <
+                    child.num_embeddings_per_partition):
+                new_weight = torch.empty(child.num_embeddings_per_partition,
+                                         child.embedding_dim,
+                                         dtype=child.weight.dtype,
+                                         device=child.weight.device)
+                new_weight[:child.weight.shape[0]].copy_(child.weight.data)
+                new_weight[child.weight.shape[0]:].fill_(0)
+                child.weight.data = new_weight
+
+    def _check_tensors_on_meta_device(self):
+        for tensor in self.model.state_dict().values():
+            if tensor.device.type == 'meta':
+                raise ValueError(
+                    "The serialized model contains tensors on the meta device,"
+                    " indicating that some tensors were not loaded properly."
+                    " Please check that the parameters of the model being"
+                    " specified match that of the serialized model, such as"
+                    " its quantization.")
+
+    def deserialize(self):
+        """
+        Deserialize the model using the TensorDeserializer. This method is
+        specifically for vLLM models using tensorizer's plaid_mode.
+
+        The deserializer makes use of tensorizer_args.stream_params
+        to configure the behavior of the stream when loading tensors from a
+        serialized model. The deserializer_params are used to configure the
+        behavior of the TensorDeserializer when loading tensors themselves.
+        Documentation on these params can be found in TensorizerArgs
+
+        Returns:
+            nn.Module: The deserialized model.
+        """
+        before_mem = get_mem_usage()
+        start = time.perf_counter()
+        with open_stream(
+                self.tensorizer_args.tensorizer_uri,
+                mode="rb",
+                **self.tensorizer_args.stream_params,
+        ) as stream, TensorDeserializer(
+                stream,
+                dtype=self.tensorizer_config.dtype,
+                **self.tensorizer_args.deserializer_params) as deserializer:
+            deserializer.load_into_module(self.model)
+            end = time.perf_counter()
+
+        total_bytes_str = convert_bytes(deserializer.total_tensor_bytes)
+        duration = end - start
+        per_second = convert_bytes(deserializer.total_tensor_bytes / duration)
+        after_mem = get_mem_usage()
+        deserializer.close()
+        logger.info("Deserialized %s in %0.2fs, %s/s", total_bytes_str,
+                    end - start, per_second)
+        logger.info("Memory usage before: %s", before_mem)
+        logger.info("Memory usage after: %s", after_mem)
+
+        self._check_tensors_on_meta_device()
+        self._resize_lora_embeddings()
+        return self.model.eval()
+
+
+def tensorizer_weights_iterator(
+    tensorizer_args: "TensorizerArgs"
+) -> Generator[Tuple[str, torch.Tensor], None, None]:
+    logger.warning(
+        "Deserializing HuggingFace models is not optimized for "
+        "loading on vLLM, as tensorizer is forced to load to CPU. "
+        "Consider deserializing a vLLM model instead for faster "
+        "load times. See the examples/tensorize_vllm_model.py example "
+        "script for serializing vLLM models.")
+
+    deserializer_args = tensorizer_args.deserializer_params
+    stream_params = tensorizer_args.stream_params
+    stream = open_stream(tensorizer_args.tensorizer_uri, **stream_params)
+    with TensorDeserializer(stream, **deserializer_args,
+                            device="cpu") as state:
+        for name, param in state.items():
+            yield name, param
+    del state
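
Editorial usage note (not part of the wheel): a vLLM-serialized checkpoint could be loaded through the module above roughly as follows. This is a minimal sketch assuming the 0.4.2 loader API; the S3 URI and model name are placeholders, not real artifacts.

    # Minimal sketch (editorial): loading a tensorized checkpoint via the
    # LLM entrypoint. The S3 URI and model name are placeholders.
    from vllm import LLM
    from vllm.model_executor.model_loader.tensorizer import TensorizerConfig

    # vllm_tensorized=True selects the fast plaid_mode path implemented by
    # TensorizerAgent above; HuggingFace-serialized checkpoints instead go
    # through tensorizer_weights_iterator on CPU.
    config = TensorizerConfig(
        tensorizer_uri="s3://example-bucket/opt-125m.tensors",  # placeholder
        vllm_tensorized=True,
    )

    llm = LLM(
        model="facebook/opt-125m",
        load_format="tensorizer",
        model_loader_extra_config=config,
    )
    print(llm.generate("Hello, my name is")[0].outputs[0].text)

Serialization itself is handled by the examples/tensorize_vllm_model.py script referenced in the docstrings above.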
vllm/model_executor/model_loader/utils.py
@@ -0,0 +1,41 @@
+"""Utilities for selecting and loading models."""
+import contextlib
+from typing import Tuple, Type
+
+import torch
+from torch import nn
+
+from vllm.config import ModelConfig
+from vllm.model_executor.models import ModelRegistry
+
+
+@contextlib.contextmanager
+def set_default_torch_dtype(dtype: torch.dtype):
+    """Sets the default torch dtype to the given dtype."""
+    old_dtype = torch.get_default_dtype()
+    torch.set_default_dtype(dtype)
+    yield
+    torch.set_default_dtype(old_dtype)
+
+
+def get_model_architecture(
+        model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
+    architectures = getattr(model_config.hf_config, "architectures", [])
+    # Special handling for quantized Mixtral.
+    # FIXME(woosuk): This is a temporary hack.
+    if (model_config.quantization is not None
+            and model_config.quantization != "fp8"
+            and "MixtralForCausalLM" in architectures):
+        architectures = ["QuantMixtralForCausalLM"]
+
+    for arch in architectures:
+        model_cls = ModelRegistry.load_model_cls(arch)
+        if model_cls is not None:
+            return (model_cls, arch)
+    raise ValueError(
+        f"Model architectures {architectures} are not supported for now. "
+        f"Supported architectures: {ModelRegistry.get_supported_archs()}")
+
+
+def get_architecture_class_name(model_config: ModelConfig) -> str:
+    return get_model_architecture(model_config)[1]
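
Editorial usage note (not part of the wheel): a minimal sketch of how these two helpers behave. The model name is a placeholder, and the ModelConfig construction is simplified; its exact signature may differ slightly in vllm-npu 0.4.2.

    # Minimal sketch (editorial) of set_default_torch_dtype and
    # get_model_architecture; model name is a placeholder.
    import torch

    from vllm.config import ModelConfig
    from vllm.model_executor.model_loader.utils import (
        get_model_architecture, set_default_torch_dtype)

    # Tensors created inside the context default to the requested dtype,
    # and the previous default is restored on exit.
    with set_default_torch_dtype(torch.float16):
        weight = torch.empty(1024, 1024)
    assert weight.dtype == torch.float16

    # Resolve the registered model class from hf_config.architectures.
    model_config = ModelConfig(
        model="facebook/opt-125m",       # placeholder model
        tokenizer="facebook/opt-125m",
        tokenizer_mode="auto",
        trust_remote_code=False,
        dtype="float16",
        seed=0,
    )
    model_cls, arch = get_model_architecture(model_config)
    print(arch)  # e.g. "OPTForCausalLM"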