vllm-npu 0.4.2 (vllm_npu-0.4.2-py3-none-any.whl)

Files changed (219)
  1. vllm/__init__.py +23 -0
  2. vllm/_custom_ops.py +251 -0
  3. vllm/attention/__init__.py +13 -0
  4. vllm/attention/backends/__init__.py +0 -0
  5. vllm/attention/backends/abstract.py +127 -0
  6. vllm/attention/backends/flash_attn.py +271 -0
  7. vllm/attention/backends/flashinfer.py +220 -0
  8. vllm/attention/backends/rocm_flash_attn.py +374 -0
  9. vllm/attention/backends/torch_sdpa.py +250 -0
  10. vllm/attention/backends/xformers.py +393 -0
  11. vllm/attention/layer.py +56 -0
  12. vllm/attention/ops/__init__.py +0 -0
  13. vllm/attention/ops/paged_attn.py +216 -0
  14. vllm/attention/ops/prefix_prefill.py +792 -0
  15. vllm/attention/ops/triton_flash_attention.py +810 -0
  16. vllm/attention/selector.py +91 -0
  17. vllm/block.py +84 -0
  18. vllm/config.py +1225 -0
  19. vllm/core/__init__.py +0 -0
  20. vllm/core/block/__init__.py +0 -0
  21. vllm/core/block/block_table.py +295 -0
  22. vllm/core/block/common.py +199 -0
  23. vllm/core/block/cpu_gpu_block_allocator.py +228 -0
  24. vllm/core/block/interfaces.py +205 -0
  25. vllm/core/block/naive_block.py +318 -0
  26. vllm/core/block/prefix_caching_block.py +606 -0
  27. vllm/core/block_manager_v1.py +625 -0
  28. vllm/core/block_manager_v2.py +258 -0
  29. vllm/core/evictor_v1.py +105 -0
  30. vllm/core/evictor_v2.py +127 -0
  31. vllm/core/interfaces.py +113 -0
  32. vllm/core/policy.py +45 -0
  33. vllm/core/scheduler.py +1163 -0
  34. vllm/distributed/__init__.py +3 -0
  35. vllm/distributed/communication_op.py +237 -0
  36. vllm/distributed/device_communicators/__init__.py +0 -0
  37. vllm/distributed/device_communicators/custom_all_reduce.py +274 -0
  38. vllm/distributed/device_communicators/pynccl.py +287 -0
  39. vllm/distributed/device_communicators/pynccl_utils.py +66 -0
  40. vllm/distributed/parallel_state.py +339 -0
  41. vllm/distributed/utils.py +136 -0
  42. vllm/engine/__init__.py +0 -0
  43. vllm/engine/arg_utils.py +649 -0
  44. vllm/engine/async_llm_engine.py +737 -0
  45. vllm/engine/llm_engine.py +784 -0
  46. vllm/engine/metrics.py +368 -0
  47. vllm/engine/output_processor/__init__.py +0 -0
  48. vllm/engine/output_processor/interfaces.py +76 -0
  49. vllm/engine/output_processor/multi_step.py +142 -0
  50. vllm/engine/output_processor/single_step.py +284 -0
  51. vllm/engine/output_processor/stop_checker.py +101 -0
  52. vllm/engine/output_processor/util.py +19 -0
  53. vllm/entrypoints/__init__.py +0 -0
  54. vllm/entrypoints/api_server.py +119 -0
  55. vllm/entrypoints/llm.py +259 -0
  56. vllm/entrypoints/openai/__init__.py +0 -0
  57. vllm/entrypoints/openai/api_server.py +186 -0
  58. vllm/entrypoints/openai/cli_args.py +115 -0
  59. vllm/entrypoints/openai/protocol.py +460 -0
  60. vllm/entrypoints/openai/serving_chat.py +392 -0
  61. vllm/entrypoints/openai/serving_completion.py +347 -0
  62. vllm/entrypoints/openai/serving_engine.py +234 -0
  63. vllm/envs.py +217 -0
  64. vllm/executor/__init__.py +0 -0
  65. vllm/executor/cpu_executor.py +152 -0
  66. vllm/executor/distributed_gpu_executor.py +115 -0
  67. vllm/executor/executor_base.py +115 -0
  68. vllm/executor/gpu_executor.py +150 -0
  69. vllm/executor/multiproc_worker_utils.py +263 -0
  70. vllm/executor/neuron_executor.py +91 -0
  71. vllm/executor/ray_gpu_executor.py +327 -0
  72. vllm/executor/ray_utils.py +119 -0
  73. vllm/logger.py +153 -0
  74. vllm/logging/__init__.py +5 -0
  75. vllm/logging/formatter.py +15 -0
  76. vllm/lora/__init__.py +0 -0
  77. vllm/lora/fully_sharded_layers.py +262 -0
  78. vllm/lora/layers.py +1181 -0
  79. vllm/lora/lora.py +167 -0
  80. vllm/lora/models.py +645 -0
  81. vllm/lora/punica.py +213 -0
  82. vllm/lora/request.py +32 -0
  83. vllm/lora/utils.py +98 -0
  84. vllm/lora/worker_manager.py +251 -0
  85. vllm/model_executor/__init__.py +7 -0
  86. vllm/model_executor/guided_decoding/__init__.py +25 -0
  87. vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py +70 -0
  88. vllm/model_executor/guided_decoding/outlines_decoding.py +130 -0
  89. vllm/model_executor/guided_decoding/outlines_logits_processors.py +184 -0
  90. vllm/model_executor/layers/__init__.py +0 -0
  91. vllm/model_executor/layers/activation.py +173 -0
  92. vllm/model_executor/layers/fused_moe/__init__.py +7 -0
  93. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  94. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  95. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  96. vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  97. vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  98. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  99. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  100. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  101. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  102. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  103. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  104. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  105. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +140 -0
  106. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  107. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  108. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  109. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  110. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +146 -0
  111. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  112. vllm/model_executor/layers/fused_moe/fused_moe.py +479 -0
  113. vllm/model_executor/layers/layernorm.py +71 -0
  114. vllm/model_executor/layers/linear.py +709 -0
  115. vllm/model_executor/layers/logits_processor.py +115 -0
  116. vllm/model_executor/layers/ops/__init__.py +0 -0
  117. vllm/model_executor/layers/ops/rand.py +157 -0
  118. vllm/model_executor/layers/ops/sample.py +406 -0
  119. vllm/model_executor/layers/quantization/__init__.py +35 -0
  120. vllm/model_executor/layers/quantization/aqlm.py +376 -0
  121. vllm/model_executor/layers/quantization/awq.py +175 -0
  122. vllm/model_executor/layers/quantization/base_config.py +97 -0
  123. vllm/model_executor/layers/quantization/fp8.py +265 -0
  124. vllm/model_executor/layers/quantization/gptq.py +224 -0
  125. vllm/model_executor/layers/quantization/gptq_marlin.py +438 -0
  126. vllm/model_executor/layers/quantization/marlin.py +227 -0
  127. vllm/model_executor/layers/quantization/schema.py +84 -0
  128. vllm/model_executor/layers/quantization/squeezellm.py +137 -0
  129. vllm/model_executor/layers/rejection_sampler.py +405 -0
  130. vllm/model_executor/layers/rotary_embedding.py +525 -0
  131. vllm/model_executor/layers/sampler.py +1051 -0
  132. vllm/model_executor/layers/vocab_parallel_embedding.py +155 -0
  133. vllm/model_executor/model_loader/__init__.py +30 -0
  134. vllm/model_executor/model_loader/loader.py +362 -0
  135. vllm/model_executor/model_loader/neuron.py +136 -0
  136. vllm/model_executor/model_loader/tensorizer.py +368 -0
  137. vllm/model_executor/model_loader/utils.py +41 -0
  138. vllm/model_executor/model_loader/weight_utils.py +372 -0
  139. vllm/model_executor/models/__init__.py +119 -0
  140. vllm/model_executor/models/baichuan.py +410 -0
  141. vllm/model_executor/models/bloom.py +327 -0
  142. vllm/model_executor/models/chatglm.py +386 -0
  143. vllm/model_executor/models/commandr.py +373 -0
  144. vllm/model_executor/models/dbrx.py +413 -0
  145. vllm/model_executor/models/decilm.py +122 -0
  146. vllm/model_executor/models/deepseek.py +438 -0
  147. vllm/model_executor/models/falcon.py +444 -0
  148. vllm/model_executor/models/gemma.py +393 -0
  149. vllm/model_executor/models/gpt2.py +266 -0
  150. vllm/model_executor/models/gpt_bigcode.py +274 -0
  151. vllm/model_executor/models/gpt_j.py +281 -0
  152. vllm/model_executor/models/gpt_neox.py +295 -0
  153. vllm/model_executor/models/internlm2.py +323 -0
  154. vllm/model_executor/models/jais.py +333 -0
  155. vllm/model_executor/models/llama.py +442 -0
  156. vllm/model_executor/models/llava.py +239 -0
  157. vllm/model_executor/models/minicpm.py +531 -0
  158. vllm/model_executor/models/mixtral.py +583 -0
  159. vllm/model_executor/models/mixtral_quant.py +404 -0
  160. vllm/model_executor/models/mpt.py +295 -0
  161. vllm/model_executor/models/olmo.py +356 -0
  162. vllm/model_executor/models/opt.py +349 -0
  163. vllm/model_executor/models/orion.py +319 -0
  164. vllm/model_executor/models/phi.py +300 -0
  165. vllm/model_executor/models/qwen.py +284 -0
  166. vllm/model_executor/models/qwen2.py +367 -0
  167. vllm/model_executor/models/qwen2_moe.py +447 -0
  168. vllm/model_executor/models/stablelm.py +301 -0
  169. vllm/model_executor/models/starcoder2.py +302 -0
  170. vllm/model_executor/models/xverse.py +366 -0
  171. vllm/model_executor/sampling_metadata.py +588 -0
  172. vllm/model_executor/utils.py +35 -0
  173. vllm/outputs.py +150 -0
  174. vllm/py.typed +2 -0
  175. vllm/sampling_params.py +340 -0
  176. vllm/sequence.py +766 -0
  177. vllm/spec_decode/__init__.py +0 -0
  178. vllm/spec_decode/batch_expansion.py +397 -0
  179. vllm/spec_decode/interfaces.py +73 -0
  180. vllm/spec_decode/metrics.py +191 -0
  181. vllm/spec_decode/multi_step_worker.py +203 -0
  182. vllm/spec_decode/ngram_worker.py +176 -0
  183. vllm/spec_decode/spec_decode_worker.py +472 -0
  184. vllm/spec_decode/top1_proposer.py +200 -0
  185. vllm/spec_decode/util.py +228 -0
  186. vllm/test_utils.py +41 -0
  187. vllm/transformers_utils/__init__.py +0 -0
  188. vllm/transformers_utils/config.py +58 -0
  189. vllm/transformers_utils/configs/__init__.py +16 -0
  190. vllm/transformers_utils/configs/chatglm.py +68 -0
  191. vllm/transformers_utils/configs/dbrx.py +278 -0
  192. vllm/transformers_utils/configs/falcon.py +87 -0
  193. vllm/transformers_utils/configs/jais.py +236 -0
  194. vllm/transformers_utils/configs/mpt.py +178 -0
  195. vllm/transformers_utils/detokenizer.py +313 -0
  196. vllm/transformers_utils/tokenizer.py +149 -0
  197. vllm/transformers_utils/tokenizer_group/__init__.py +33 -0
  198. vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py +55 -0
  199. vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py +169 -0
  200. vllm/transformers_utils/tokenizer_group/tokenizer_group.py +78 -0
  201. vllm/transformers_utils/tokenizers/__init__.py +5 -0
  202. vllm/transformers_utils/tokenizers/baichuan.py +255 -0
  203. vllm/usage/__init__.py +0 -0
  204. vllm/usage/usage_lib.py +209 -0
  205. vllm/utils.py +677 -0
  206. vllm/worker/__init__.py +0 -0
  207. vllm/worker/cache_engine.py +105 -0
  208. vllm/worker/cpu_model_runner.py +346 -0
  209. vllm/worker/cpu_worker.py +321 -0
  210. vllm/worker/model_runner.py +1168 -0
  211. vllm/worker/neuron_model_runner.py +196 -0
  212. vllm/worker/neuron_worker.py +98 -0
  213. vllm/worker/worker.py +345 -0
  214. vllm/worker/worker_base.py +146 -0
  215. vllm_npu-0.4.2.dist-info/LICENSE +201 -0
  216. vllm_npu-0.4.2.dist-info/METADATA +173 -0
  217. vllm_npu-0.4.2.dist-info/RECORD +219 -0
  218. vllm_npu-0.4.2.dist-info/WHEEL +5 -0
  219. vllm_npu-0.4.2.dist-info/top_level.txt +1 -0
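
Every file in the listing is a new addition (+N -0) and mirrors the upstream vLLM 0.4.2 package tree. For orientation only, here is a minimal offline-inference sketch against the vllm.LLM entrypoint shipped in vllm/entrypoints/llm.py; the model id and sampling values are illustrative placeholders, not anything pinned by this wheel:

# Minimal sketch: assumes the wheel is installed and a supported model is
# reachable (locally or via the Hugging Face Hub); the model id is an example.
from vllm import LLM, SamplingParams

prompts = ["Hello, my name is"]
params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=32)

llm = LLM(model="facebook/opt-125m")
for out in llm.generate(prompts, params):
    print(out.outputs[0].text)

The two hunks below are the newly added Marlin quantization layers: gptq_marlin.py (+438 lines) and marlin.py (+227 lines).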
--- /dev/null
+++ b/vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -0,0 +1,438 @@
+import enum
+from enum import Enum
+from typing import Any, Dict, List, Optional
+
+import torch
+from torch.nn.parameter import Parameter
+
+from vllm import _custom_ops as ops
+from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
+                                               set_weight_attrs)
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig)
+
+GPTQ_MARLIN_TILE = 16
+GPTQ_MARLIN_MIN_THREAD_N = 64
+GPTQ_MARLIN_MIN_THREAD_K = 128
+GPTQ_MARLIN_MAX_PARALLEL = 16
+
+GPTQ_MARLIN_SUPPORTED_NUM_BITS = [4, 8]
+GPTQ_MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
+GPTQ_MARLIN_SUPPORTED_SYM = [True]
+
+
+# Permutations for Marlin scale shuffling
+def get_scale_perms(num_bits):
+    scale_perm = []
+    for i in range(8):
+        scale_perm.extend([i + 8 * j for j in range(8)])
+    scale_perm_single = []
+    for i in range(4):
+        scale_perm_single.extend(
+            [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
+    return scale_perm, scale_perm_single
+
+
+def get_pack_factor(num_bits):
+    assert (num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS
+            ), f"Unsupported num_bits = {num_bits}"
+    return 32 // num_bits
+
+
+def marlin_permute_scales(s, size_k, size_n, group_size, num_bits):
+    scale_perm, scale_perm_single = get_scale_perms(num_bits)
+    if group_size < size_k and group_size != -1:
+        s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
+    else:
+        s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
+    s = s.reshape((-1, size_n)).contiguous()
+
+    return s
+
+
+class GPTQMarlinConfig(QuantizationConfig):
+    """Config class for GPTQ Marlin"""
+
+    def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
+                 is_sym: bool) -> None:
+        if desc_act and group_size == -1:
+            # In this case, act_order == True is the same as act_order == False
+            # (since we have only one group per output channel)
+            desc_act = False
+
+        self.weight_bits = weight_bits
+        self.group_size = group_size
+        self.desc_act = desc_act
+        self.is_sym = is_sym
+
+        # Verify
+        if self.weight_bits not in GPTQ_MARLIN_SUPPORTED_NUM_BITS:
+            raise ValueError(
+                f"Marlin does not support weight_bits = {self.weight_bits}. "
+                f"Only weight_bits = {GPTQ_MARLIN_SUPPORTED_NUM_BITS} "
+                "are supported.")
+        if self.group_size not in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES:
+            raise ValueError(
+                f"Marlin does not support group_size = {self.group_size}. "
+                f"Only group_sizes = {GPTQ_MARLIN_SUPPORTED_GROUP_SIZES} "
+                "are supported.")
+        if self.is_sym not in GPTQ_MARLIN_SUPPORTED_SYM:
+            raise ValueError(
+                f"Marlin does not support is_sym = {self.is_sym}. "
+                f"Only sym = {GPTQ_MARLIN_SUPPORTED_SYM} are supported.")
+
+        # Init
+        self.pack_factor = get_pack_factor(weight_bits)
+        self.tile_size = GPTQ_MARLIN_TILE
+        self.min_thread_n = GPTQ_MARLIN_MIN_THREAD_N
+        self.min_thread_k = GPTQ_MARLIN_MIN_THREAD_K
+        self.max_parallel = GPTQ_MARLIN_MAX_PARALLEL
+
+    def __repr__(self) -> str:
+        return (f"GPTQMarlinConfig(weight_bits={self.weight_bits}, "
+                f"group_size={self.group_size}, "
+                f"desc_act={self.desc_act})")
+
+    @classmethod
+    def get_name(cls) -> str:
+        return "gptq_marlin"
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+        return [torch.half]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 80
+
+    @classmethod
+    def get_config_filenames(cls) -> List[str]:
+        return ["quantize_config.json"]
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig":
+        weight_bits = cls.get_from_keys(config, ["bits"])
+        group_size = cls.get_from_keys(config, ["group_size"])
+        desc_act = cls.get_from_keys(config, ["desc_act"])
+        is_sym = cls.get_from_keys(config, ["sym"])
+        return cls(weight_bits, group_size, desc_act, is_sym)
+
+    def get_quant_method(
+            self,
+            layer: torch.nn.Module) -> Optional["GPTQMarlinLinearMethod"]:
+        if isinstance(layer, LinearBase):
+            return GPTQMarlinLinearMethod(self)
+        return None
+
+    def get_scaled_act_names(self) -> List[str]:
+        return []
+
+    @classmethod
+    def is_marlin_compatible(cls, quant_config: Dict[str, Any]):
+        # Extract data from quant config.
+        num_bits = quant_config.get("bits", None)
+        group_size = quant_config.get("group_size", None)
+        sym = quant_config.get("sym", None)
+        desc_act = quant_config.get("desc_act", None)
+
+        # If we cannot find the info needed in the config, cannot convert.
+        if (num_bits is None or group_size is None or sym is None
+                or desc_act is None):
+            return False
+
+        # If the capability of the device is too low, cannot convert.
+        major, minor = torch.cuda.get_device_capability()
+        device_capability = major * 10 + minor
+        if device_capability < cls.get_min_capability():
+            return False
+
+        # Otherwise, can convert if model satisfies marlin constraints.
+        return (num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS
+                and group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
+                and sym in GPTQ_MARLIN_SUPPORTED_SYM)
+
+
+class GPTQMarlinState(Enum):
+    REPACK = enum.auto()
+    READY = enum.auto()
+
+
+class GPTQMarlinLinearMethod(LinearMethodBase):
+    """Linear method for GPTQ Marlin.
+
+    Args:
+        quant_config: The GPTQ Marlin quantization config.
+    """
+
+    def __init__(self, quant_config: GPTQMarlinConfig) -> None:
+        self.quant_config = quant_config
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size_per_partition: int,
+        output_partition_sizes: List[int],
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ) -> None:
+        del output_size
+
+        # Normalize group_size
+        if self.quant_config.group_size != -1:
+            group_size = self.quant_config.group_size
+        else:
+            group_size = input_size
+
+        # Validate dtype
+        if params_dtype != torch.float16:
+            raise ValueError(
+                f"The params dtype must be float16, but got {params_dtype}")
+
+        # Validate output_size_per_partition
+        output_size_per_partition = sum(output_partition_sizes)
+        if output_size_per_partition % self.quant_config.min_thread_n != 0:
+            raise ValueError(
+                f"Weight output_size_per_partition = "
+                f"{output_size_per_partition} is not divisible by "
+                f" min_thread_n = {self.quant_config.min_thread_n}.")
+
+        # Validate input_size_per_partition
+        if input_size_per_partition % self.quant_config.min_thread_k != 0:
+            raise ValueError(
+                f"Weight input_size_per_partition = "
+                f"{input_size_per_partition} is not divisible "
+                f"by min_thread_k = {self.quant_config.min_thread_k}.")
+
+        if (group_size < input_size
+                and input_size_per_partition % group_size != 0):
+            raise ValueError(
+                f"Weight input_size_per_partition = {input_size_per_partition}"
+                f" is not divisible by group_size = {group_size}.")
+
+        # Detect sharding of scales/zp
+
+        # By default, no sharding over "input dim"
+        scales_and_zp_size = input_size // group_size
+        scales_and_zp_input_dim = None
+
+        if self.quant_config.desc_act:
+            # Act-order case
+            assert self.quant_config.group_size != -1
+
+            is_k_full = input_size_per_partition == input_size
+
+        else:
+            # No act-order case
+
+            # K is always full due to full alignment with
+            # group-size and shard of scales/zp
+            is_k_full = True
+
+            # If this is a row-parallel case, then shard scales/zp
+            if (input_size != input_size_per_partition
+                    and self.quant_config.group_size != -1):
+                scales_and_zp_size = input_size_per_partition // group_size
+                scales_and_zp_input_dim = 0
+
+        # Init buffers
+
+        # Quantized weights
+        qweight = Parameter(
+            torch.empty(
+                input_size_per_partition // self.quant_config.pack_factor,
+                output_size_per_partition,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        set_weight_attrs(
+            qweight,
+            {
+                **extra_weight_attrs,
+                "input_dim": 0,
+                "output_dim": 1,
+                "packed_dim": 0,
+                "pack_factor": self.quant_config.pack_factor,
+            },
+        )
+
+        # Activation order
+        g_idx = Parameter(
+            torch.empty(
+                input_size_per_partition,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        # Ignore warning from fused linear layers such as QKVParallelLinear.
+        set_weight_attrs(
+            g_idx,
+            {
+                **extra_weight_attrs, "input_dim": 0,
+                "ignore_warning": True
+            },
+        )
+
+        g_idx_sort_indices = Parameter(
+            torch.empty(
+                g_idx.shape,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        set_weight_attrs(g_idx_sort_indices, extra_weight_attrs)
+
+        # Scales
+        scales = Parameter(
+            torch.empty(
+                scales_and_zp_size,
+                output_size_per_partition,
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+        set_weight_attrs(
+            scales,
+            {
+                **extra_weight_attrs,
+                "input_dim": scales_and_zp_input_dim,
+                "output_dim": 1,
+            },
+        )
+
+        # Quantized zero-points
+        qzeros = Parameter(
+            torch.empty(
+                scales_and_zp_size,
+                output_size_per_partition // self.quant_config.pack_factor,
+                dtype=torch.int32,
+                device="meta",
+            ),
+            requires_grad=False,
+        )
+        set_weight_attrs(
+            qzeros,
+            {
+                **extra_weight_attrs,
+                "input_dim": scales_and_zp_input_dim,
+                "output_dim": 1,
+                "packed_dim": 1,
+                "pack_factor": self.quant_config.pack_factor,
+            },
+        )
+
+        # Allocate marlin workspace
+        max_workspace_size = (
+            output_size_per_partition //
+            self.quant_config.min_thread_n) * self.quant_config.max_parallel
+        workspace = torch.zeros(max_workspace_size,
+                                dtype=torch.int,
+                                requires_grad=False)
+
+        layer.register_parameter("qweight", qweight)
+        layer.register_parameter("g_idx", g_idx)
+        layer.register_parameter("g_idx_sort_indices", g_idx_sort_indices)
+        layer.register_parameter("scales", scales)
+        layer.register_parameter("qzeros", qzeros)
+        layer.workspace = workspace
+        layer.input_size_per_partition = input_size_per_partition
+        layer.output_size_per_partition = output_size_per_partition
+        layer.input_size = input_size
+        layer.is_k_full = is_k_full
+        layer.marlin_state = GPTQMarlinState.REPACK
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        reshaped_x = x.reshape(-1, x.shape[-1])
+
+        size_m = reshaped_x.shape[0]
+        part_size_n = layer.output_size_per_partition
+        part_size_k = layer.input_size_per_partition
+        full_size_k = layer.input_size
+
+        out_shape = x.shape[:-1] + (part_size_n, )
+
+        if layer.marlin_state == GPTQMarlinState.REPACK:
+            layer.marlin_state = GPTQMarlinState.READY
+
+            # Newly generated tensors need to replace existing tensors that are
+            # already registered as parameters by vLLM (and won't be freed)
+            def replace_tensor(name, new_t):
+                # It is important to use resize_() here since it ensures
+                # the same buffer is reused
+                getattr(layer, name).resize_(new_t.shape)
+                getattr(layer, name).copy_(new_t)
+                del new_t
+
+            cur_device = layer.qweight.device
+
+            # Process act_order
+            if self.quant_config.desc_act:
+                # Get sorting based on g_idx
+                g_idx_sort_indices = torch.argsort(layer.g_idx).to(torch.int)
+
+                sorted_g_idx = layer.g_idx[g_idx_sort_indices]
+
+                replace_tensor("g_idx", sorted_g_idx)
+                replace_tensor("g_idx_sort_indices", g_idx_sort_indices)
+
+            else:
+                # Reset g_idx related tensors
+                layer.g_idx = Parameter(
+                    torch.empty(0, dtype=torch.int, device=cur_device),
+                    requires_grad=False,
+                )
+                layer.g_idx_sort_indices = Parameter(
+                    torch.empty(0, dtype=torch.int, device=cur_device),
+                    requires_grad=False,
+                )
+
+            # Repack weights
+            marlin_qweight = ops.gptq_marlin_repack(
+                layer.qweight,
+                layer.g_idx_sort_indices,
+                part_size_k,
+                part_size_n,
+                self.quant_config.weight_bits,
+            )
+            replace_tensor("qweight", marlin_qweight)
+
+            # Permute scales
+            scales_size_k = part_size_k
+            scales_size_n = part_size_n
+            if self.quant_config.desc_act:
+                scales_size_k = full_size_k
+
+            marlin_scales = marlin_permute_scales(
+                layer.scales,
+                scales_size_k,
+                scales_size_n,
+                self.quant_config.group_size,
+                self.quant_config.weight_bits,
+            )
+            replace_tensor("scales", marlin_scales)
+
+        output = ops.gptq_marlin_gemm(
+            reshaped_x,
+            layer.qweight,
+            layer.scales,
+            layer.g_idx,
+            layer.g_idx_sort_indices,
+            layer.workspace,
+            self.quant_config.weight_bits,
+            size_m,
+            part_size_n,
+            part_size_k,
+            layer.is_k_full,
+        )
+
+        if bias is not None:
+            output.add_(bias)  # In-place add
+
+        return output.reshape(out_shape)
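
For reference, GPTQMarlinConfig.from_config above reads the standard GPTQ quantize_config.json keys ("bits", "group_size", "desc_act", "sym"), and the derived pack factor drives the qweight allocation in create_weights. A small sketch with illustrative values, not taken from any particular checkpoint:

from vllm.model_executor.layers.quantization.gptq_marlin import (
    GPTQMarlinConfig, get_pack_factor)

# Illustrative quantize_config.json payload: 4-bit, group size 128,
# symmetric, no act_order. Values are an example only.
quant_cfg = {"bits": 4, "group_size": 128, "desc_act": False, "sym": True}

cfg = GPTQMarlinConfig.from_config(quant_cfg)
assert cfg.pack_factor == get_pack_factor(4) == 32 // 4 == 8
# Eight 4-bit values fit in one int32, so create_weights() allocates qweight
# with input_size_per_partition // 8 rows before the Marlin repack in apply().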
--- /dev/null
+++ b/vllm/model_executor/layers/quantization/marlin.py
@@ -0,0 +1,227 @@
+from typing import Any, Dict, List, Optional
+
+import torch
+from torch.nn.parameter import Parameter
+
+from vllm import _custom_ops as ops
+from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig)
+from vllm.model_executor.utils import set_weight_attrs
+
+
+class MarlinConfig(QuantizationConfig):
+    """Config class for Marlin.
+
+    Reference: https://github.com/IST-DASLab/marlin/tree/master
+    """
+
+    def __init__(
+        self,
+        group_size: int,
+    ) -> None:
+        # Group size for the quantization.
+        self.group_size = group_size
+        if self.group_size != 128 and self.group_size != -1:
+            raise ValueError(
+                "Currently, only group size 128 and -1 (channelwise) "
+                "is supported for Marlin, but got group_size of "
+                f"{self.group_size}")
+
+        # 4 Bits packed into 32 bit datatype.
+        self.pack_factor = 32 // 4
+
+        # Tile size used by marlin kernels.
+        self.tile_size = 16
+
+        # Min out_features dim
+        self.min_n_threads = 64
+
+        # Min in_features dim
+        self.min_k_threads = 128
+
+        # Max parallel problems to solve at once (improves large
+        # batch performance)
+        self.max_parallel = 16
+
+        # Permutation length used by the marlin kernels.
+        self.perm_len = 1024
+
+    def __repr__(self) -> str:
+        return f"MarlinConfig(group_size={self.group_size})"
+
+    @classmethod
+    def get_name(cls) -> str:
+        return "marlin"
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+        return [torch.half]
+
+    @classmethod
+    # Need to figure it out
+    def get_min_capability(cls) -> int:
+        return 80
+
+    @classmethod
+    def get_config_filenames(cls) -> List[str]:
+        return ["quantize_config.json"]
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "MarlinConfig":
+        group_size = cls.get_from_keys(config, ["group_size"])
+        return cls(group_size)
+
+    def get_quant_method(
+            self, layer: torch.nn.Module) -> Optional["MarlinLinearMethod"]:
+        if isinstance(layer, LinearBase):
+            return MarlinLinearMethod(self)
+        return None
+
+    def get_scaled_act_names(self) -> List[str]:
+        return []
+
+
+class MarlinLinearMethod(LinearMethodBase):
+    """Linear method for Marlin.
+
+    Args:
+        quant_config: The Marlin quantization config.
+    """
+
+    def __init__(self, quant_config: MarlinConfig):
+        self.quant_config = quant_config
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size_per_partition: int,
+        output_partition_sizes: List[int],
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        del output_size  # Unused.
+
+        if params_dtype != torch.float16:
+            raise ValueError(
+                f"The params dtype must be float16, but got {params_dtype}")
+
+        # Validate output_size_per_partition
+        output_size_per_partition = sum(output_partition_sizes)
+        if output_size_per_partition % self.quant_config.min_n_threads != 0:
+            raise ValueError(
+                f"Weight output_size_per_partition = "
+                f"{output_size_per_partition} is not divisible by "
+                f"min_n_threads = {self.quant_config.min_n_threads}.")
+        if output_size_per_partition % self.quant_config.pack_factor != 0:
+            raise ValueError(
+                f"Weight output_size_per_partition = "
+                f"{output_size_per_partition} is not divisible by "
+                f"pack_factor = {self.quant_config.pack_factor}.")
+
+        # Validate input_size_per_partition
+        if input_size_per_partition % self.quant_config.min_k_threads != 0:
+            raise ValueError(
+                f"Weight input_size_per_partition = "
+                f"{input_size_per_partition} is not divisible by "
+                f"min_k_threads = {self.quant_config.min_k_threads}.")
+        if (self.quant_config.group_size != -1 and
+                input_size_per_partition % self.quant_config.group_size != 0):
+            raise ValueError(f"Weight input_size_per_partition = "
+                             f"{input_size_per_partition} is not divisible by "
+                             f"group_size = {self.quant_config.group_size}.")
+
+        # Check that we have at least 4 tiles horizontally in the shard
+        num_tiles_per_perm = self.quant_config.perm_len // (
+            self.quant_config.tile_size**2)
+        if output_size_per_partition % num_tiles_per_perm != 0:
+            raise ValueError(
+                "Each permutation group must reside on the same gpu")
+
+        # Quantized 4Bit weights packed into Int32.
+        qweight = Parameter(
+            torch.empty(
+                input_size_per_partition // self.quant_config.tile_size,
+                output_size_per_partition * self.quant_config.tile_size //
+                self.quant_config.pack_factor,
+                device="cuda",
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        set_weight_attrs(
+            qweight,
+            {
+                "input_dim": 0,
+                "output_dim": 1,
+                "packed_dim": 1,
+                "pack_factor": self.quant_config.pack_factor,
+                "marlin_tile_size": self.quant_config.tile_size,
+            },
+        )
+
+        # Determine if channelwise or not
+        input_groups = (1 if self.quant_config.group_size == -1 else
+                        input_size_per_partition //
+                        self.quant_config.group_size)
+
+        scales = Parameter(
+            torch.empty(
+                input_groups,
+                output_size_per_partition,
+                device="cuda",
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+        set_weight_attrs(
+            scales,
+            {
+                "input_dim": None if input_groups == 1 else 0,
+                "output_dim": 1,
+            },
+        )
+
+        # Allocate workspace (Used for internal locking mechanism)
+        max_workspace_size = (
+            output_size_per_partition //
+            self.quant_config.min_n_threads) * self.quant_config.max_parallel
+        workspace = Parameter(torch.zeros(max_workspace_size,
+                                          device="cuda",
+                                          dtype=torch.int),
+                              requires_grad=False)
+
+        layer.register_parameter("B", qweight)
+        set_weight_attrs(qweight, extra_weight_attrs)
+        layer.register_parameter("s", scales)
+        set_weight_attrs(scales, extra_weight_attrs)
+        layer.register_parameter("workspace", workspace)
+        set_weight_attrs(workspace, extra_weight_attrs)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        qweight = layer.B
+        scales = layer.s
+        workspace = layer.workspace
+
+        x_2d = x.view(-1, x.shape[-1])
+
+        size_m = x_2d.shape[0]
+        size_k = x_2d.shape[1]
+        size_n = scales.shape[1]
+
+        output_2d = ops.marlin_gemm(x_2d, qweight, scales, workspace, size_m,
+                                    size_n, size_k)
+
+        output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))
+
+        if bias is not None:
+            output.add_(bias)  # In-place add
+
+        return output
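
As a closing note on MarlinLinearMethod.create_weights, the tensor shapes follow directly from the 16x16 tile size, the 4-bit pack factor of 8, and the group size. A short sketch with illustrative K and N makes the arithmetic concrete:

# Shape arithmetic behind MarlinLinearMethod.create_weights. The sizes are
# illustrative; any K, N passing the divisibility checks above behave the same.
tile_size, pack_factor, group_size = 16, 8, 128
K, N = 4096, 4096  # input_size_per_partition, output_size_per_partition

qweight_shape = (K // tile_size, N * tile_size // pack_factor)
scales_shape = (K // group_size, N)

print(qweight_shape)  # (256, 8192): 16x16 tiles of 4-bit weights packed into int32
print(scales_shape)   # (32, 4096): one fp16 scale per (group, output channel)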