vllm_npu-0.4.2-py3-none-any.whl

Files changed (219)
  1. vllm/__init__.py +23 -0
  2. vllm/_custom_ops.py +251 -0
  3. vllm/attention/__init__.py +13 -0
  4. vllm/attention/backends/__init__.py +0 -0
  5. vllm/attention/backends/abstract.py +127 -0
  6. vllm/attention/backends/flash_attn.py +271 -0
  7. vllm/attention/backends/flashinfer.py +220 -0
  8. vllm/attention/backends/rocm_flash_attn.py +374 -0
  9. vllm/attention/backends/torch_sdpa.py +250 -0
  10. vllm/attention/backends/xformers.py +393 -0
  11. vllm/attention/layer.py +56 -0
  12. vllm/attention/ops/__init__.py +0 -0
  13. vllm/attention/ops/paged_attn.py +216 -0
  14. vllm/attention/ops/prefix_prefill.py +792 -0
  15. vllm/attention/ops/triton_flash_attention.py +810 -0
  16. vllm/attention/selector.py +91 -0
  17. vllm/block.py +84 -0
  18. vllm/config.py +1225 -0
  19. vllm/core/__init__.py +0 -0
  20. vllm/core/block/__init__.py +0 -0
  21. vllm/core/block/block_table.py +295 -0
  22. vllm/core/block/common.py +199 -0
  23. vllm/core/block/cpu_gpu_block_allocator.py +228 -0
  24. vllm/core/block/interfaces.py +205 -0
  25. vllm/core/block/naive_block.py +318 -0
  26. vllm/core/block/prefix_caching_block.py +606 -0
  27. vllm/core/block_manager_v1.py +625 -0
  28. vllm/core/block_manager_v2.py +258 -0
  29. vllm/core/evictor_v1.py +105 -0
  30. vllm/core/evictor_v2.py +127 -0
  31. vllm/core/interfaces.py +113 -0
  32. vllm/core/policy.py +45 -0
  33. vllm/core/scheduler.py +1163 -0
  34. vllm/distributed/__init__.py +3 -0
  35. vllm/distributed/communication_op.py +237 -0
  36. vllm/distributed/device_communicators/__init__.py +0 -0
  37. vllm/distributed/device_communicators/custom_all_reduce.py +274 -0
  38. vllm/distributed/device_communicators/pynccl.py +287 -0
  39. vllm/distributed/device_communicators/pynccl_utils.py +66 -0
  40. vllm/distributed/parallel_state.py +339 -0
  41. vllm/distributed/utils.py +136 -0
  42. vllm/engine/__init__.py +0 -0
  43. vllm/engine/arg_utils.py +649 -0
  44. vllm/engine/async_llm_engine.py +737 -0
  45. vllm/engine/llm_engine.py +784 -0
  46. vllm/engine/metrics.py +368 -0
  47. vllm/engine/output_processor/__init__.py +0 -0
  48. vllm/engine/output_processor/interfaces.py +76 -0
  49. vllm/engine/output_processor/multi_step.py +142 -0
  50. vllm/engine/output_processor/single_step.py +284 -0
  51. vllm/engine/output_processor/stop_checker.py +101 -0
  52. vllm/engine/output_processor/util.py +19 -0
  53. vllm/entrypoints/__init__.py +0 -0
  54. vllm/entrypoints/api_server.py +119 -0
  55. vllm/entrypoints/llm.py +259 -0
  56. vllm/entrypoints/openai/__init__.py +0 -0
  57. vllm/entrypoints/openai/api_server.py +186 -0
  58. vllm/entrypoints/openai/cli_args.py +115 -0
  59. vllm/entrypoints/openai/protocol.py +460 -0
  60. vllm/entrypoints/openai/serving_chat.py +392 -0
  61. vllm/entrypoints/openai/serving_completion.py +347 -0
  62. vllm/entrypoints/openai/serving_engine.py +234 -0
  63. vllm/envs.py +217 -0
  64. vllm/executor/__init__.py +0 -0
  65. vllm/executor/cpu_executor.py +152 -0
  66. vllm/executor/distributed_gpu_executor.py +115 -0
  67. vllm/executor/executor_base.py +115 -0
  68. vllm/executor/gpu_executor.py +150 -0
  69. vllm/executor/multiproc_worker_utils.py +263 -0
  70. vllm/executor/neuron_executor.py +91 -0
  71. vllm/executor/ray_gpu_executor.py +327 -0
  72. vllm/executor/ray_utils.py +119 -0
  73. vllm/logger.py +153 -0
  74. vllm/logging/__init__.py +5 -0
  75. vllm/logging/formatter.py +15 -0
  76. vllm/lora/__init__.py +0 -0
  77. vllm/lora/fully_sharded_layers.py +262 -0
  78. vllm/lora/layers.py +1181 -0
  79. vllm/lora/lora.py +167 -0
  80. vllm/lora/models.py +645 -0
  81. vllm/lora/punica.py +213 -0
  82. vllm/lora/request.py +32 -0
  83. vllm/lora/utils.py +98 -0
  84. vllm/lora/worker_manager.py +251 -0
  85. vllm/model_executor/__init__.py +7 -0
  86. vllm/model_executor/guided_decoding/__init__.py +25 -0
  87. vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py +70 -0
  88. vllm/model_executor/guided_decoding/outlines_decoding.py +130 -0
  89. vllm/model_executor/guided_decoding/outlines_logits_processors.py +184 -0
  90. vllm/model_executor/layers/__init__.py +0 -0
  91. vllm/model_executor/layers/activation.py +173 -0
  92. vllm/model_executor/layers/fused_moe/__init__.py +7 -0
  93. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  94. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  95. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  96. vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  97. vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  98. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  99. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  100. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  101. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  102. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  103. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  104. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  105. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +140 -0
  106. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  107. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  108. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  109. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  110. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +146 -0
  111. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  112. vllm/model_executor/layers/fused_moe/fused_moe.py +479 -0
  113. vllm/model_executor/layers/layernorm.py +71 -0
  114. vllm/model_executor/layers/linear.py +709 -0
  115. vllm/model_executor/layers/logits_processor.py +115 -0
  116. vllm/model_executor/layers/ops/__init__.py +0 -0
  117. vllm/model_executor/layers/ops/rand.py +157 -0
  118. vllm/model_executor/layers/ops/sample.py +406 -0
  119. vllm/model_executor/layers/quantization/__init__.py +35 -0
  120. vllm/model_executor/layers/quantization/aqlm.py +376 -0
  121. vllm/model_executor/layers/quantization/awq.py +175 -0
  122. vllm/model_executor/layers/quantization/base_config.py +97 -0
  123. vllm/model_executor/layers/quantization/fp8.py +265 -0
  124. vllm/model_executor/layers/quantization/gptq.py +224 -0
  125. vllm/model_executor/layers/quantization/gptq_marlin.py +438 -0
  126. vllm/model_executor/layers/quantization/marlin.py +227 -0
  127. vllm/model_executor/layers/quantization/schema.py +84 -0
  128. vllm/model_executor/layers/quantization/squeezellm.py +137 -0
  129. vllm/model_executor/layers/rejection_sampler.py +405 -0
  130. vllm/model_executor/layers/rotary_embedding.py +525 -0
  131. vllm/model_executor/layers/sampler.py +1051 -0
  132. vllm/model_executor/layers/vocab_parallel_embedding.py +155 -0
  133. vllm/model_executor/model_loader/__init__.py +30 -0
  134. vllm/model_executor/model_loader/loader.py +362 -0
  135. vllm/model_executor/model_loader/neuron.py +136 -0
  136. vllm/model_executor/model_loader/tensorizer.py +368 -0
  137. vllm/model_executor/model_loader/utils.py +41 -0
  138. vllm/model_executor/model_loader/weight_utils.py +372 -0
  139. vllm/model_executor/models/__init__.py +119 -0
  140. vllm/model_executor/models/baichuan.py +410 -0
  141. vllm/model_executor/models/bloom.py +327 -0
  142. vllm/model_executor/models/chatglm.py +386 -0
  143. vllm/model_executor/models/commandr.py +373 -0
  144. vllm/model_executor/models/dbrx.py +413 -0
  145. vllm/model_executor/models/decilm.py +122 -0
  146. vllm/model_executor/models/deepseek.py +438 -0
  147. vllm/model_executor/models/falcon.py +444 -0
  148. vllm/model_executor/models/gemma.py +393 -0
  149. vllm/model_executor/models/gpt2.py +266 -0
  150. vllm/model_executor/models/gpt_bigcode.py +274 -0
  151. vllm/model_executor/models/gpt_j.py +281 -0
  152. vllm/model_executor/models/gpt_neox.py +295 -0
  153. vllm/model_executor/models/internlm2.py +323 -0
  154. vllm/model_executor/models/jais.py +333 -0
  155. vllm/model_executor/models/llama.py +442 -0
  156. vllm/model_executor/models/llava.py +239 -0
  157. vllm/model_executor/models/minicpm.py +531 -0
  158. vllm/model_executor/models/mixtral.py +583 -0
  159. vllm/model_executor/models/mixtral_quant.py +404 -0
  160. vllm/model_executor/models/mpt.py +295 -0
  161. vllm/model_executor/models/olmo.py +356 -0
  162. vllm/model_executor/models/opt.py +349 -0
  163. vllm/model_executor/models/orion.py +319 -0
  164. vllm/model_executor/models/phi.py +300 -0
  165. vllm/model_executor/models/qwen.py +284 -0
  166. vllm/model_executor/models/qwen2.py +367 -0
  167. vllm/model_executor/models/qwen2_moe.py +447 -0
  168. vllm/model_executor/models/stablelm.py +301 -0
  169. vllm/model_executor/models/starcoder2.py +302 -0
  170. vllm/model_executor/models/xverse.py +366 -0
  171. vllm/model_executor/sampling_metadata.py +588 -0
  172. vllm/model_executor/utils.py +35 -0
  173. vllm/outputs.py +150 -0
  174. vllm/py.typed +2 -0
  175. vllm/sampling_params.py +340 -0
  176. vllm/sequence.py +766 -0
  177. vllm/spec_decode/__init__.py +0 -0
  178. vllm/spec_decode/batch_expansion.py +397 -0
  179. vllm/spec_decode/interfaces.py +73 -0
  180. vllm/spec_decode/metrics.py +191 -0
  181. vllm/spec_decode/multi_step_worker.py +203 -0
  182. vllm/spec_decode/ngram_worker.py +176 -0
  183. vllm/spec_decode/spec_decode_worker.py +472 -0
  184. vllm/spec_decode/top1_proposer.py +200 -0
  185. vllm/spec_decode/util.py +228 -0
  186. vllm/test_utils.py +41 -0
  187. vllm/transformers_utils/__init__.py +0 -0
  188. vllm/transformers_utils/config.py +58 -0
  189. vllm/transformers_utils/configs/__init__.py +16 -0
  190. vllm/transformers_utils/configs/chatglm.py +68 -0
  191. vllm/transformers_utils/configs/dbrx.py +278 -0
  192. vllm/transformers_utils/configs/falcon.py +87 -0
  193. vllm/transformers_utils/configs/jais.py +236 -0
  194. vllm/transformers_utils/configs/mpt.py +178 -0
  195. vllm/transformers_utils/detokenizer.py +313 -0
  196. vllm/transformers_utils/tokenizer.py +149 -0
  197. vllm/transformers_utils/tokenizer_group/__init__.py +33 -0
  198. vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py +55 -0
  199. vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py +169 -0
  200. vllm/transformers_utils/tokenizer_group/tokenizer_group.py +78 -0
  201. vllm/transformers_utils/tokenizers/__init__.py +5 -0
  202. vllm/transformers_utils/tokenizers/baichuan.py +255 -0
  203. vllm/usage/__init__.py +0 -0
  204. vllm/usage/usage_lib.py +209 -0
  205. vllm/utils.py +677 -0
  206. vllm/worker/__init__.py +0 -0
  207. vllm/worker/cache_engine.py +105 -0
  208. vllm/worker/cpu_model_runner.py +346 -0
  209. vllm/worker/cpu_worker.py +321 -0
  210. vllm/worker/model_runner.py +1168 -0
  211. vllm/worker/neuron_model_runner.py +196 -0
  212. vllm/worker/neuron_worker.py +98 -0
  213. vllm/worker/worker.py +345 -0
  214. vllm/worker/worker_base.py +146 -0
  215. vllm_npu-0.4.2.dist-info/LICENSE +201 -0
  216. vllm_npu-0.4.2.dist-info/METADATA +173 -0
  217. vllm_npu-0.4.2.dist-info/RECORD +219 -0
  218. vllm_npu-0.4.2.dist-info/WHEEL +5 -0
  219. vllm_npu-0.4.2.dist-info/top_level.txt +1 -0
vllm/model_executor/layers/quantization/__init__.py
@@ -0,0 +1,35 @@
+ from typing import Dict, Type
+
+ from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
+ from vllm.model_executor.layers.quantization.awq import AWQConfig
+ from vllm.model_executor.layers.quantization.base_config import (
+     QuantizationConfig)
+ from vllm.model_executor.layers.quantization.fp8 import Fp8Config
+ from vllm.model_executor.layers.quantization.gptq import GPTQConfig
+ from vllm.model_executor.layers.quantization.gptq_marlin import (
+     GPTQMarlinConfig)
+ from vllm.model_executor.layers.quantization.marlin import MarlinConfig
+ from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
+
+ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
+     "aqlm": AQLMConfig,
+     "awq": AWQConfig,
+     "fp8": Fp8Config,
+     "gptq": GPTQConfig,
+     "squeezellm": SqueezeLLMConfig,
+     "gptq_marlin": GPTQMarlinConfig,
+     "marlin": MarlinConfig,
+ }
+
+
+ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
+     if quantization not in QUANTIZATION_METHODS:
+         raise ValueError(f"Invalid quantization method: {quantization}")
+     return QUANTIZATION_METHODS[quantization]
+
+
+ __all__ = [
+     "QuantizationConfig",
+     "get_quantization_config",
+     "QUANTIZATION_METHODS",
+ ]
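
For orientation, here is a hedged usage sketch (not shipped in the wheel) of the registry above: resolve a backend class by name, then build it from a model's quantization config dict. It assumes the vllm package from this wheel imports cleanly in your environment; the AWQ config values are illustrative only, not taken from any particular checkpoint.

from vllm.model_executor.layers.quantization import get_quantization_config

quant_cls = get_quantization_config("awq")       # -> AWQConfig (see awq.py below)
quant_config = quant_cls.from_config({
    "w_bit": 4,           # illustrative values; a real checkpoint ships its own
    "q_group_size": 128,  # quant_config.json / quantize_config.json
    "zero_point": True,
})
print(quant_config)  # AWQConfig(weight_bits=4, group_size=128, zero_point=True)

# Unknown names raise:
# get_quantization_config("int3")  -> ValueError: Invalid quantization method: int3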
vllm/model_executor/layers/quantization/aqlm.py
@@ -0,0 +1,376 @@
+ # Supports AQLM compression, see https://github.com/Vahe1994/AQLM
+ # and https://arxiv.org/pdf/2401.06118.pdf
+
+ import math
+ from typing import Any, Dict, List, Optional
+
+ import torch
+ import torch.nn.functional as F
+ from torch.nn.parameter import Parameter
+
+ from vllm import _custom_ops as ops
+ from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
+ from vllm.model_executor.layers.quantization.base_config import (
+     QuantizationConfig)
+ from vllm.model_executor.utils import set_weight_attrs
+
+
+ def get_int_dtype(nbits: int) -> torch.dtype:
+     if nbits <= 8:
+         return torch.int8
+     if nbits <= 16:
+         return torch.int16
+     if nbits <= 32:
+         return torch.int32
+     if nbits <= 64:
+         return torch.int64
+     raise ValueError(f"No dtype available for {nbits}-bit codebooks")
+
+
+ @torch.inference_mode()
+ def unpack_int_data(data: torch.IntTensor, nbits: int) -> torch.IntTensor:
+     return data.to(torch.int64) % (2**nbits)
+
+
+ def dequantize_weight(codes: torch.Tensor,
+                       codebooks: torch.Tensor,
+                       scales: Optional[torch.Tensor] = None) -> torch.Tensor:
+     """
+     Decode float weights from quantization codes. Differentiable.
+     :param codes: tensor of integer quantization codes, shape
+         [*dims, num_out_groups, num_in_groups, num_codebooks]
+     :param codebooks: tensor of vectors for each quantization code,
+         [num_codebooks, codebook_size, out_group_size, in_group_size]
+     :param scales: weight will be multiplied by this factor, must be
+         broadcastable with
+         [*dims, out_groups, num_in_groups, out_group_size, in_group_size]
+     :return: reconstructed weight tensor of shape
+         [*dims, num_in_groups*group_size]
+     """
+     num_out_groups, num_in_groups, num_codebooks = codes.shape[-3:]
+     num_codebooks, codebook_size, out_group_size, in_group_size = \
+         codebooks.shape
+     out_features = num_out_groups * out_group_size
+     in_features = num_in_groups * in_group_size
+     codebook_offsets = torch.arange(
+         0, num_codebooks * codebook_size, codebook_size,
+         device=codes.device)  # shape: [num_codebooks]
+     reconstructed_weight_flat = F.embedding_bag(
+         codes.flatten(0, -2) + codebook_offsets,
+         codebooks.flatten(0, 1).flatten(-2, -1),
+         mode="sum"
+     )  # [prod(dims) * num_out_groups * num_in_groups, out_group_size
+     # * in_group_size]
+
+     reconstructed_weight_groupwise = reconstructed_weight_flat.view(
+         list(codes.shape[:-3]) +
+         [num_out_groups, num_in_groups, out_group_size, in_group_size])
+     if scales is not None:
+         reconstructed_weight_groupwise = reconstructed_weight_groupwise.mul(
+             scales)
+     return reconstructed_weight_groupwise.swapaxes(
+         -3, -2).reshape(list(codes.shape[:-3]) + [out_features, in_features])
+
+
+ def dequantize_gemm(
+     input: torch.Tensor,  # [..., in_features]
+     codes: torch.IntTensor,  # [num_out_groups, num_in_groups, num_codebooks]
+     codebooks: torch.
+     Tensor,  # [num_codebooks, codebook_size, out_group_size, in_group_size]
+     scales: torch.Tensor,  # [num_out_groups, 1, 1, 1]
+     bias: Optional[torch.Tensor],
+ ) -> torch.Tensor:
+     dequantized_weight = dequantize_weight(
+         unpack_int_data(codes, codebooks.shape[1].bit_length() - 1),
+         codebooks,
+         scales,
+     )
+     return F.linear(input, dequantized_weight, bias)
+
+
+ # Generic dequantization, slow but flexible.
+ def generic_dequantize_gemm(
+     input: torch.Tensor,  # [..., in_features]
+     codes: torch.IntTensor,  # [num_out_groups, num_in_groups, num_codebooks]
+     codebooks: torch.
+     Tensor,  # [num_codebooks, codebook_size, out_group_size, in_group_size]
+     scales: torch.Tensor,  # [num_out_groups, 1, 1, 1]
+     output_partition_sizes: torch.IntTensor,
+     bias: Optional[torch.Tensor],
+ ) -> torch.Tensor:
+     output_shape = input.shape[:-1] + (scales.shape[0], )
+     output = torch.empty(output_shape, dtype=input.dtype, device=input.device)
+     num_outputs = len(output_partition_sizes)
+
+     # break the inputs and codebooks apart then combine the outputs.
+     # Surprisingly (to me) this is faster than doing 3 de-quants and 1 big
+     # multiply at the end.
+     num_codebooks = codebooks.shape[0] // num_outputs
+     assert (scales.shape[0] == codes.shape[0])
+     assert (sum(output_partition_sizes) == scales.shape[0])
+     output_offset = 0
+     codebooks_offset = 0
+     for output_size in output_partition_sizes:
+         shard_output = dequantize_gemm(
+             input, codes.narrow(0, output_offset, output_size),
+             codebooks.narrow(0, codebooks_offset, num_codebooks),
+             scales.narrow(0, output_offset, output_size), None
+             if bias is None else bias.narrow(0, output_offset, output_size))
+
+         output_slice = output.narrow(-1, output_offset, output_size)
+         assert (output_slice.shape == shard_output.shape)
+         output_slice.copy_(shard_output)
+         output_offset += output_size
+         codebooks_offset += num_codebooks
+     return output
+
+
+ # Optimized dequantize/decompression kernels, supports 1x16 and 2x8
+ # at 6 and 9 times faster than the generic version above, respectively.
+ def optimized_dequantize_gemm(
+     input: torch.Tensor,  # [..., in_features]
+     codes: torch.IntTensor,  # [num_out_groups, num_in_groups, num_codebooks]
+     codebooks: torch.
+     Tensor,  # [num_codebooks, codebook_size, out_group_size, in_group_size]
+     scales: torch.Tensor,  # [num_out_groups, 1, 1, 1]
+     output_partition_sizes: torch.IntTensor,
+     bias: Optional[torch.Tensor],
+ ) -> torch.Tensor:
+     weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)
+
+     if bias is None:
+         # scaling the output is fastest, so we do that when possible.
+         output = F.linear(input, weights, bias)
+         orig_shape = output.shape
+         flattened_output = output.view(-1, output.size(-1))
+         f_scales = scales.view(-1, scales.shape[0])
+         b_scales = f_scales.expand(flattened_output.shape[0], -1)
+         flattened_output *= b_scales
+         return output.view(orig_shape)
+     else:
+         b_scales = scales.view(scales.shape[:-3] + (-1, )).expand(
+             -1, weights.shape[1])
+         weights *= b_scales
+         return F.linear(input, weights, bias)
+
+
+ class AQLMConfig(QuantizationConfig):
+     """Config class for AQLM.
+
+     Reference: https://github.com/Vahe1994/AQLM
+     """
+
+     def __init__(
+         self,
+         in_group_size: int,
+         nbits_per_codebook: int,
+         num_codebooks: int,
+         out_group_size: int,
+     ) -> None:
+         self.in_group_size = in_group_size
+         self.nbits_per_codebook = nbits_per_codebook
+         self.num_codebooks = num_codebooks
+         self.out_group_size = out_group_size
+
+         # out_group_size > 1 is untested, and probably won't work as-is.
+         assert (self.out_group_size == 1)
+         self.pack_factor = (self.in_group_size * self.out_group_size)
+
+     def __repr__(self) -> str:
+         return (f"AQLMConfig(in_group_size={self.in_group_size}, "
+                 f"nbits_per_codebook={self.nbits_per_codebook}, "
+                 f"num_codebooks={self.num_codebooks}, "
+                 f"out_group_size={self.out_group_size})")
+
+     @classmethod
+     def get_name(cls) -> str:
+         return "aqlm"
+
+     @classmethod
+     def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+         return [torch.half]
+
+     @classmethod
+     def get_min_capability(cls) -> int:
+         return 70
+
+     @classmethod
+     def get_config_filenames(cls) -> List[str]:
+         return []  # no extra configs.
+
+     @classmethod
+     def from_config(cls, config: Dict[str, Any]) -> "AQLMConfig":
+         in_group_size = cls.get_from_keys(config, ["in_group_size"])
+         nbits_per_codebook = cls.get_from_keys(config, ["nbits_per_codebook"])
+         num_code_books = cls.get_from_keys(config, ["num_codebooks"])
+         out_group_size = cls.get_from_keys(config, ["out_group_size"])
+         return cls(in_group_size, nbits_per_codebook, num_code_books,
+                    out_group_size)
+
+     def get_quant_method(
+             self, layer: torch.nn.Module) -> Optional["AQLMLinearMethod"]:
+         if isinstance(layer, LinearBase):
+             return AQLMLinearMethod(self)
+         return None
+
+     def get_scaled_act_names(self) -> List[str]:
+         return []
+
+
+ class AQLMLinearMethod(LinearMethodBase):
+     """Linear method for AQLM.
+
+     Args:
+         quant_config: The AQLM quantization config.
+     """
+
+     def __init__(self, quant_config: AQLMConfig):
+         self.quant_config = quant_config
+
+     def create_weights(self, layer: torch.nn.Module,
+                        input_size_per_partition: int,
+                        output_partition_sizes: List[int], input_size: int,
+                        output_size: int, params_dtype: torch.dtype,
+                        **extra_weight_attrs):
+         del output_size  # Unused.
+         del input_size  # Unused.
+
+         if params_dtype != torch.half:
+             raise ValueError("Only half is currently supported by aqlm")
+         if input_size_per_partition % self.quant_config.in_group_size != 0:
+             raise ValueError(
+                 "The input size is not aligned with the quantized "
+                 "weight shape. This can be caused by too large "
+                 "tensor parallel size.")
+
+         output_size_per_partition = sum(output_partition_sizes)
+         if output_size_per_partition % self.quant_config.out_group_size != 0:
+             raise ValueError(
+                 "The output size is not aligned with the quantized "
+                 "weight shape. This can be caused by too large "
+                 "tensor parallel size.")
+
+         codes = Parameter(
+             torch.empty(
+                 # There could actually be two pack factors, one along input and
+                 # one along output, but we don't currently support
+                 # out_group_size, and only the one along output needs to be
+                 # marked with "packed_dim" in order for QKVLinear to work.
+                 output_size_per_partition,
+                 input_size_per_partition // self.quant_config.pack_factor,
+                 self.quant_config.num_codebooks,
+                 dtype=get_int_dtype(self.quant_config.nbits_per_codebook),
+             ),
+             requires_grad=False,
+         )
+
+         set_weight_attrs(
+             codes,
+             {
+                 "input_dim": 1,
+                 "output_dim": 0,
+                 "packed_dim": 1,
+                 "pack_factor": self.quant_config.pack_factor,
+             },
+         )
+
+         codebooks = Parameter(
+             torch.empty(
+                 self.quant_config.num_codebooks * len(output_partition_sizes),
+                 2**self.quant_config.nbits_per_codebook,
+                 self.quant_config.out_group_size,
+                 self.quant_config.in_group_size,
+                 dtype=params_dtype,
+             ),
+             requires_grad=False,
+         )
+         set_weight_attrs(
+             codebooks,
+             {
+                 # metadata indicates fixed size concatenated along dim 0
+                 "is_metadata":
+                 True,
+                 "output_partition_sizes":
+                 torch.tensor(output_partition_sizes, device='cpu'),
+             },
+         )
+
+         scales = Parameter(
+             torch.empty(
+                 (
+                     output_size_per_partition //
+                     self.quant_config.out_group_size,
+                     1,
+                     1,
+                     1,
+                 ),
+                 dtype=params_dtype,
+             ),
+             requires_grad=False,
+         )
+         set_weight_attrs(
+             scales,
+             {
+                 "output_dim": 0,
+                 "packed_dim": 0,
+                 "pack_factor": self.quant_config.out_group_size
+             },
+         )
+
+         layer.register_parameter("codes", codes)
+         set_weight_attrs(codes, extra_weight_attrs)
+         layer.register_parameter("codebooks", codebooks)
+         set_weight_attrs(codebooks, extra_weight_attrs)
+         layer.register_parameter("scales", scales)
+         set_weight_attrs(scales, extra_weight_attrs)
+
+     def apply(
+         self,
+         layer: torch.nn.Module,
+         x: torch.Tensor,
+         bias: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         codebooks = layer.codebooks
+         codes = layer.codes
+         scales = layer.scales
+         output_partition_sizes = getattr(codebooks, "output_partition_sizes",
+                                          None)
+
+         nbooks = codes.shape[2]
+         ingroups = codebooks.shape[3]
+         outgroups = codebooks.shape[2]
+         bits = codebooks.shape[1]
+
+         # We support these formats with dedicated gemm and decompression
+         # kernels.
+         if ingroups == 8 and outgroups == 1 and (
+             (bits == 256 and nbooks == 2) or (bits == 65536 and nbooks == 1)):
+
+             # thresholds determined by timings on an A6000, one GPU
+             use_gemv = math.prod(x.shape[:-1]) <= 6
+
+             return ops.aqlm_gemm(
+                 x,
+                 codes,
+                 codebooks,
+                 scales,
+                 output_partition_sizes,
+                 bias,
+             ) if use_gemv else optimized_dequantize_gemm(
+                 x,
+                 codes,
+                 codebooks,
+                 scales,
+                 output_partition_sizes,
+                 bias,
+             )
+
+         # fall back for all unoptimized formats
+         return generic_dequantize_gemm(
+             x,
+             codes,
+             codebooks,
+             scales,
+             output_partition_sizes,
+             bias,
+         )
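
A hedged, CPU-only shape sketch (not part of the wheel) may help make the tensor layouts above concrete: it drives the pure-PyTorch fallback path (dequantize_weight) with randomly generated codes in the 2x8 layout (two codebooks, 8-bit codes, in_group_size=8, out_group_size=1). All sizes are made up for brevity, and it assumes the vllm package from this wheel imports on your machine.

import torch

from vllm.model_executor.layers.quantization.aqlm import dequantize_weight

num_codebooks, codebook_size = 2, 2**8            # the "2x8" AQLM format
out_group_size, in_group_size = 1, 8
num_out_groups, num_in_groups = 32, 8             # -> weight of shape (32, 64)

codes = torch.randint(0, codebook_size,
                      (num_out_groups, num_in_groups, num_codebooks))
codebooks = torch.randn(num_codebooks, codebook_size,
                        out_group_size, in_group_size)
scales = torch.rand(num_out_groups, 1, 1, 1)

weight = dequantize_weight(codes, codebooks, scales)
assert weight.shape == (num_out_groups * out_group_size,
                        num_in_groups * in_group_size)   # (32, 64)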
vllm/model_executor/layers/quantization/awq.py
@@ -0,0 +1,175 @@
+ from typing import Any, Dict, List, Optional
+
+ import torch
+ from torch.nn.parameter import Parameter
+
+ from vllm import _custom_ops as ops
+ from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
+ from vllm.model_executor.layers.quantization.base_config import (
+     QuantizationConfig)
+ from vllm.model_executor.utils import set_weight_attrs
+
+
+ class AWQConfig(QuantizationConfig):
+     """Config class for AWQ.
+
+     Reference: https://arxiv.org/abs/2306.00978
+     """
+
+     def __init__(
+         self,
+         weight_bits: int,
+         group_size: int,
+         zero_point: bool,
+     ) -> None:
+         self.weight_bits = weight_bits
+         self.group_size = group_size
+         self.zero_point = zero_point
+
+         if self.weight_bits != 4:
+             raise ValueError(
+                 "Currently, only 4-bit weight quantization is supported for "
+                 f"AWQ, but got {self.weight_bits} bits.")
+         self.pack_factor = 32 // self.weight_bits
+
+     def __repr__(self) -> str:
+         return (f"AWQConfig(weight_bits={self.weight_bits}, "
+                 f"group_size={self.group_size}, "
+                 f"zero_point={self.zero_point})")
+
+     def get_name(self) -> str:
+         return "awq"
+
+     def get_supported_act_dtypes(self) -> List[torch.dtype]:
+         return [torch.half]
+
+     def get_min_capability(self) -> int:
+         # The AWQ kernel only supports Turing or newer GPUs.
+         return 75
+
+     @staticmethod
+     def get_config_filenames() -> List[str]:
+         return [
+             "quant_config.json",  # E.g., casperhansen/vicuna-7b-v1.5-awq
+             # E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq
+             "quantize_config.json",
+         ]
+
+     @classmethod
+     def from_config(cls, config: Dict[str, Any]) -> "AWQConfig":
+         weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
+         group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
+         zero_point = cls.get_from_keys(config, ["zero_point"])
+         return cls(weight_bits, group_size, zero_point)
+
+     def get_quant_method(
+             self, layer: torch.nn.Module) -> Optional["AWQLinearMethod"]:
+         if isinstance(layer, LinearBase):
+             return AWQLinearMethod(self)
+         return None
+
+     def get_scaled_act_names(self) -> List[str]:
+         return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
+
+
+ class AWQLinearMethod(LinearMethodBase):
+     """Linear method for AWQ.
+
+     Args:
+         quant_config: The AWQ quantization config.
+     """
+
+     def __init__(self, quant_config: AWQConfig):
+         self.quant_config = quant_config
+
+     def create_weights(self, layer: torch.nn.Module,
+                        input_size_per_partition: int,
+                        output_partition_sizes: List[int], input_size: int,
+                        output_size: int, params_dtype: torch.dtype,
+                        **extra_weight_attrs):
+         if input_size_per_partition % self.quant_config.group_size != 0:
+             raise ValueError(
+                 "The input size is not aligned with the quantized "
+                 "weight shape. This can be caused by too large "
+                 "tensor parallel size.")
+
+         output_size_per_partition = sum(output_partition_sizes)
+         if output_size_per_partition % self.quant_config.pack_factor != 0:
+             raise ValueError(
+                 "The output size is not aligned with the quantized "
+                 "weight shape. This can be caused by too large "
+                 "tensor parallel size.")
+
+         qweight = Parameter(
+             torch.empty(
+                 input_size_per_partition,
+                 output_size_per_partition // self.quant_config.pack_factor,
+                 dtype=torch.int32,
+             ),
+             requires_grad=False,
+         )
+         set_weight_attrs(
+             qweight, {
+                 "input_dim": 0,
+                 "output_dim": 1,
+                 "packed_dim": 1,
+                 "pack_factor": self.quant_config.pack_factor,
+             })
+         qzeros = Parameter(
+             torch.empty(
+                 input_size_per_partition // self.quant_config.group_size,
+                 output_size_per_partition // self.quant_config.pack_factor,
+                 dtype=torch.int32,
+             ),
+             requires_grad=False,
+         )
+         set_weight_attrs(
+             qzeros, {
+                 "input_dim": 0,
+                 "output_dim": 1,
+                 "packed_dim": 1,
+                 "pack_factor": self.quant_config.pack_factor,
+             })
+         scales = Parameter(
+             torch.empty(
+                 input_size_per_partition // self.quant_config.group_size,
+                 output_size_per_partition,
+                 dtype=params_dtype,
+             ),
+             requires_grad=False,
+         )
+         set_weight_attrs(scales, {
+             "input_dim": 0,
+             "output_dim": 1,
+         })
+
+         layer.register_parameter("qweight", qweight)
+         set_weight_attrs(qweight, extra_weight_attrs)
+         layer.register_parameter("qzeros", qzeros)
+         set_weight_attrs(qzeros, extra_weight_attrs)
+         layer.register_parameter("scales", scales)
+         set_weight_attrs(scales, extra_weight_attrs)
+
+     def apply(self,
+               layer: torch.nn.Module,
+               x: torch.Tensor,
+               bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+         qweight = layer.qweight
+         scales = layer.scales
+         qzeros = layer.qzeros
+         pack_factor = self.quant_config.pack_factor
+         out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, ))
+         reshaped_x = x.reshape(-1, x.shape[-1])
+
+         # num_tokens >= threshold
+         FP16_MATMUL_HEURISTIC_CONDITION = x.shape[:-1].numel() >= 256
+
+         if FP16_MATMUL_HEURISTIC_CONDITION:
+             out = ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
+             out = torch.matmul(reshaped_x, out)
+         else:
+             out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros,
+                                pack_factor)
+         if bias is not None:
+             out.add_(bias)
+         return out.reshape(out_shape)
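
As a hedged sketch (not part of the wheel), the snippet below shows the two config-key spellings AWQConfig.from_config accepts and works out the packed parameter shapes that create_weights would allocate; the 4096x4096 shard sizes are illustrative only, and it assumes the vllm package from this wheel imports in your environment.

from vllm.model_executor.layers.quantization.awq import AWQConfig

# Both key spellings found in checkpoints are accepted:
# {"w_bit", "bits"}, {"q_group_size", "group_size"}, plus "zero_point".
cfg = AWQConfig.from_config({"w_bit": 4, "q_group_size": 128, "zero_point": True})
assert cfg.pack_factor == 8  # 32 // 4: eight 4-bit weights per int32

in_size, out_size = 4096, 4096  # hypothetical per-partition sizes
qweight_shape = (in_size, out_size // cfg.pack_factor)                   # (4096, 512), int32
qzeros_shape = (in_size // cfg.group_size, out_size // cfg.pack_factor)  # (32, 512), int32
scales_shape = (in_size // cfg.group_size, out_size)                     # (32, 4096), params_dtype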
vllm/model_executor/layers/quantization/base_config.py
@@ -0,0 +1,97 @@
+ from abc import ABC, abstractmethod
+ from typing import Any, Dict, List, Optional
+
+ import torch
+ from torch import nn
+
+
+ class QuantizeMethodBase(ABC):
+     """Base class for different quantized methods."""
+
+     @abstractmethod
+     def create_weights(self, layer: torch.nn.Module, *weight_args,
+                        **extra_weight_attrs):
+         """Create weights for a layer.
+
+         The weights will be set as attributes of the layer."""
+         raise NotImplementedError
+
+     @abstractmethod
+     def apply(self, layer: torch.nn.Module, *args, **kwargs) -> torch.Tensor:
+         """Apply the weights in layer to the input tensor.
+
+         Expects create_weights to have been called before on the layer."""
+         raise NotImplementedError
+
+     def process_weights_after_loading(self, layer: nn.Module) -> None:
+         """Process the weight after loading.
+
+         This can be used, for example, to transpose weights for computation.
+         """
+         return
+
+
+ class QuantizationConfig(ABC):
+     """Base class for quantization configs."""
+
+     @abstractmethod
+     def get_name(self) -> str:
+         """Name of the quantization method."""
+         raise NotImplementedError
+
+     @abstractmethod
+     def get_supported_act_dtypes(self) -> List[torch.dtype]:
+         """List of supported activation dtypes."""
+         raise NotImplementedError
+
+     @abstractmethod
+     def get_min_capability(self) -> int:
+         """Minimum GPU capability to support the quantization method.
+
+         E.g., 70 for Volta, 75 for Turing, 80 for Ampere.
+         This requirement is due to the custom CUDA kernels used by the
+         quantization method.
+         """
+         raise NotImplementedError
+
+     @staticmethod
+     @abstractmethod
+     def get_config_filenames() -> List[str]:
+         """List of filenames to search for in the model directory."""
+         raise NotImplementedError
+
+     @classmethod
+     @abstractmethod
+     def from_config(cls, config: Dict[str, Any]) -> "QuantizationConfig":
+         """Create a config class from the model's quantization config."""
+         raise NotImplementedError
+
+     @staticmethod
+     def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any:
+         """Get a value from the model's quantization config."""
+         for key in keys:
+             if key in config:
+                 return config[key]
+         raise ValueError(f"Cannot find any of {keys} in the model's "
+                          "quantization config.")
+
+     @abstractmethod
+     def get_quant_method(
+             self, layer: torch.nn.Module) -> Optional[QuantizeMethodBase]:
+         """Get the quantize method to use for the quantized layer.
+
+         Args:
+             layer: The layer for the quant method.
+         Returns:
+             The quantize method, or None if the given layer doesn't support
+             the quant method.
+         """
+         raise NotImplementedError
+
+     @abstractmethod
+     def get_scaled_act_names(self) -> List[str]:
+         """Returns the activation function names that should be post-scaled.
+
+         For now, this is only used by AWQ.
+         """
+         raise NotImplementedError
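
To show how the abstract interface above is meant to be filled in, here is a hedged skeleton of a do-nothing backend; NoopConfig and NoopLinearMethod are made-up names for illustration and are not part of this wheel. The method set mirrors the abstract methods declared in base_config.py.

from typing import Any, Dict, List, Optional

import torch

from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)


class NoopLinearMethod(QuantizeMethodBase):

    def create_weights(self, layer: torch.nn.Module, *weight_args,
                       **extra_weight_attrs):
        pass  # a real backend registers its packed parameters on `layer` here

    def apply(self, layer: torch.nn.Module, *args, **kwargs) -> torch.Tensor:
        raise NotImplementedError("toy example only")


class NoopConfig(QuantizationConfig):

    def get_name(self) -> str:
        return "noop"

    def get_supported_act_dtypes(self) -> List[torch.dtype]:
        return [torch.half]

    def get_min_capability(self) -> int:
        return 70  # Volta or newer, per the base class docstring

    @staticmethod
    def get_config_filenames() -> List[str]:
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "NoopConfig":
        return cls()

    def get_quant_method(
            self, layer: torch.nn.Module) -> Optional[QuantizeMethodBase]:
        return NoopLinearMethod()

    def get_scaled_act_names(self) -> List[str]:
        return []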