vllm-npu 0.4.2 (vllm_npu-0.4.2-py3-none-any.whl)

Files changed (219)
  1. vllm/__init__.py +23 -0
  2. vllm/_custom_ops.py +251 -0
  3. vllm/attention/__init__.py +13 -0
  4. vllm/attention/backends/__init__.py +0 -0
  5. vllm/attention/backends/abstract.py +127 -0
  6. vllm/attention/backends/flash_attn.py +271 -0
  7. vllm/attention/backends/flashinfer.py +220 -0
  8. vllm/attention/backends/rocm_flash_attn.py +374 -0
  9. vllm/attention/backends/torch_sdpa.py +250 -0
  10. vllm/attention/backends/xformers.py +393 -0
  11. vllm/attention/layer.py +56 -0
  12. vllm/attention/ops/__init__.py +0 -0
  13. vllm/attention/ops/paged_attn.py +216 -0
  14. vllm/attention/ops/prefix_prefill.py +792 -0
  15. vllm/attention/ops/triton_flash_attention.py +810 -0
  16. vllm/attention/selector.py +91 -0
  17. vllm/block.py +84 -0
  18. vllm/config.py +1225 -0
  19. vllm/core/__init__.py +0 -0
  20. vllm/core/block/__init__.py +0 -0
  21. vllm/core/block/block_table.py +295 -0
  22. vllm/core/block/common.py +199 -0
  23. vllm/core/block/cpu_gpu_block_allocator.py +228 -0
  24. vllm/core/block/interfaces.py +205 -0
  25. vllm/core/block/naive_block.py +318 -0
  26. vllm/core/block/prefix_caching_block.py +606 -0
  27. vllm/core/block_manager_v1.py +625 -0
  28. vllm/core/block_manager_v2.py +258 -0
  29. vllm/core/evictor_v1.py +105 -0
  30. vllm/core/evictor_v2.py +127 -0
  31. vllm/core/interfaces.py +113 -0
  32. vllm/core/policy.py +45 -0
  33. vllm/core/scheduler.py +1163 -0
  34. vllm/distributed/__init__.py +3 -0
  35. vllm/distributed/communication_op.py +237 -0
  36. vllm/distributed/device_communicators/__init__.py +0 -0
  37. vllm/distributed/device_communicators/custom_all_reduce.py +274 -0
  38. vllm/distributed/device_communicators/pynccl.py +287 -0
  39. vllm/distributed/device_communicators/pynccl_utils.py +66 -0
  40. vllm/distributed/parallel_state.py +339 -0
  41. vllm/distributed/utils.py +136 -0
  42. vllm/engine/__init__.py +0 -0
  43. vllm/engine/arg_utils.py +649 -0
  44. vllm/engine/async_llm_engine.py +737 -0
  45. vllm/engine/llm_engine.py +784 -0
  46. vllm/engine/metrics.py +368 -0
  47. vllm/engine/output_processor/__init__.py +0 -0
  48. vllm/engine/output_processor/interfaces.py +76 -0
  49. vllm/engine/output_processor/multi_step.py +142 -0
  50. vllm/engine/output_processor/single_step.py +284 -0
  51. vllm/engine/output_processor/stop_checker.py +101 -0
  52. vllm/engine/output_processor/util.py +19 -0
  53. vllm/entrypoints/__init__.py +0 -0
  54. vllm/entrypoints/api_server.py +119 -0
  55. vllm/entrypoints/llm.py +259 -0
  56. vllm/entrypoints/openai/__init__.py +0 -0
  57. vllm/entrypoints/openai/api_server.py +186 -0
  58. vllm/entrypoints/openai/cli_args.py +115 -0
  59. vllm/entrypoints/openai/protocol.py +460 -0
  60. vllm/entrypoints/openai/serving_chat.py +392 -0
  61. vllm/entrypoints/openai/serving_completion.py +347 -0
  62. vllm/entrypoints/openai/serving_engine.py +234 -0
  63. vllm/envs.py +217 -0
  64. vllm/executor/__init__.py +0 -0
  65. vllm/executor/cpu_executor.py +152 -0
  66. vllm/executor/distributed_gpu_executor.py +115 -0
  67. vllm/executor/executor_base.py +115 -0
  68. vllm/executor/gpu_executor.py +150 -0
  69. vllm/executor/multiproc_worker_utils.py +263 -0
  70. vllm/executor/neuron_executor.py +91 -0
  71. vllm/executor/ray_gpu_executor.py +327 -0
  72. vllm/executor/ray_utils.py +119 -0
  73. vllm/logger.py +153 -0
  74. vllm/logging/__init__.py +5 -0
  75. vllm/logging/formatter.py +15 -0
  76. vllm/lora/__init__.py +0 -0
  77. vllm/lora/fully_sharded_layers.py +262 -0
  78. vllm/lora/layers.py +1181 -0
  79. vllm/lora/lora.py +167 -0
  80. vllm/lora/models.py +645 -0
  81. vllm/lora/punica.py +213 -0
  82. vllm/lora/request.py +32 -0
  83. vllm/lora/utils.py +98 -0
  84. vllm/lora/worker_manager.py +251 -0
  85. vllm/model_executor/__init__.py +7 -0
  86. vllm/model_executor/guided_decoding/__init__.py +25 -0
  87. vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py +70 -0
  88. vllm/model_executor/guided_decoding/outlines_decoding.py +130 -0
  89. vllm/model_executor/guided_decoding/outlines_logits_processors.py +184 -0
  90. vllm/model_executor/layers/__init__.py +0 -0
  91. vllm/model_executor/layers/activation.py +173 -0
  92. vllm/model_executor/layers/fused_moe/__init__.py +7 -0
  93. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  94. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  95. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  96. vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  97. vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  98. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  99. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  100. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  101. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  102. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  103. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  104. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  105. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +140 -0
  106. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  107. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  108. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  109. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  110. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +146 -0
  111. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  112. vllm/model_executor/layers/fused_moe/fused_moe.py +479 -0
  113. vllm/model_executor/layers/layernorm.py +71 -0
  114. vllm/model_executor/layers/linear.py +709 -0
  115. vllm/model_executor/layers/logits_processor.py +115 -0
  116. vllm/model_executor/layers/ops/__init__.py +0 -0
  117. vllm/model_executor/layers/ops/rand.py +157 -0
  118. vllm/model_executor/layers/ops/sample.py +406 -0
  119. vllm/model_executor/layers/quantization/__init__.py +35 -0
  120. vllm/model_executor/layers/quantization/aqlm.py +376 -0
  121. vllm/model_executor/layers/quantization/awq.py +175 -0
  122. vllm/model_executor/layers/quantization/base_config.py +97 -0
  123. vllm/model_executor/layers/quantization/fp8.py +265 -0
  124. vllm/model_executor/layers/quantization/gptq.py +224 -0
  125. vllm/model_executor/layers/quantization/gptq_marlin.py +438 -0
  126. vllm/model_executor/layers/quantization/marlin.py +227 -0
  127. vllm/model_executor/layers/quantization/schema.py +84 -0
  128. vllm/model_executor/layers/quantization/squeezellm.py +137 -0
  129. vllm/model_executor/layers/rejection_sampler.py +405 -0
  130. vllm/model_executor/layers/rotary_embedding.py +525 -0
  131. vllm/model_executor/layers/sampler.py +1051 -0
  132. vllm/model_executor/layers/vocab_parallel_embedding.py +155 -0
  133. vllm/model_executor/model_loader/__init__.py +30 -0
  134. vllm/model_executor/model_loader/loader.py +362 -0
  135. vllm/model_executor/model_loader/neuron.py +136 -0
  136. vllm/model_executor/model_loader/tensorizer.py +368 -0
  137. vllm/model_executor/model_loader/utils.py +41 -0
  138. vllm/model_executor/model_loader/weight_utils.py +372 -0
  139. vllm/model_executor/models/__init__.py +119 -0
  140. vllm/model_executor/models/baichuan.py +410 -0
  141. vllm/model_executor/models/bloom.py +327 -0
  142. vllm/model_executor/models/chatglm.py +386 -0
  143. vllm/model_executor/models/commandr.py +373 -0
  144. vllm/model_executor/models/dbrx.py +413 -0
  145. vllm/model_executor/models/decilm.py +122 -0
  146. vllm/model_executor/models/deepseek.py +438 -0
  147. vllm/model_executor/models/falcon.py +444 -0
  148. vllm/model_executor/models/gemma.py +393 -0
  149. vllm/model_executor/models/gpt2.py +266 -0
  150. vllm/model_executor/models/gpt_bigcode.py +274 -0
  151. vllm/model_executor/models/gpt_j.py +281 -0
  152. vllm/model_executor/models/gpt_neox.py +295 -0
  153. vllm/model_executor/models/internlm2.py +323 -0
  154. vllm/model_executor/models/jais.py +333 -0
  155. vllm/model_executor/models/llama.py +442 -0
  156. vllm/model_executor/models/llava.py +239 -0
  157. vllm/model_executor/models/minicpm.py +531 -0
  158. vllm/model_executor/models/mixtral.py +583 -0
  159. vllm/model_executor/models/mixtral_quant.py +404 -0
  160. vllm/model_executor/models/mpt.py +295 -0
  161. vllm/model_executor/models/olmo.py +356 -0
  162. vllm/model_executor/models/opt.py +349 -0
  163. vllm/model_executor/models/orion.py +319 -0
  164. vllm/model_executor/models/phi.py +300 -0
  165. vllm/model_executor/models/qwen.py +284 -0
  166. vllm/model_executor/models/qwen2.py +367 -0
  167. vllm/model_executor/models/qwen2_moe.py +447 -0
  168. vllm/model_executor/models/stablelm.py +301 -0
  169. vllm/model_executor/models/starcoder2.py +302 -0
  170. vllm/model_executor/models/xverse.py +366 -0
  171. vllm/model_executor/sampling_metadata.py +588 -0
  172. vllm/model_executor/utils.py +35 -0
  173. vllm/outputs.py +150 -0
  174. vllm/py.typed +2 -0
  175. vllm/sampling_params.py +340 -0
  176. vllm/sequence.py +766 -0
  177. vllm/spec_decode/__init__.py +0 -0
  178. vllm/spec_decode/batch_expansion.py +397 -0
  179. vllm/spec_decode/interfaces.py +73 -0
  180. vllm/spec_decode/metrics.py +191 -0
  181. vllm/spec_decode/multi_step_worker.py +203 -0
  182. vllm/spec_decode/ngram_worker.py +176 -0
  183. vllm/spec_decode/spec_decode_worker.py +472 -0
  184. vllm/spec_decode/top1_proposer.py +200 -0
  185. vllm/spec_decode/util.py +228 -0
  186. vllm/test_utils.py +41 -0
  187. vllm/transformers_utils/__init__.py +0 -0
  188. vllm/transformers_utils/config.py +58 -0
  189. vllm/transformers_utils/configs/__init__.py +16 -0
  190. vllm/transformers_utils/configs/chatglm.py +68 -0
  191. vllm/transformers_utils/configs/dbrx.py +278 -0
  192. vllm/transformers_utils/configs/falcon.py +87 -0
  193. vllm/transformers_utils/configs/jais.py +236 -0
  194. vllm/transformers_utils/configs/mpt.py +178 -0
  195. vllm/transformers_utils/detokenizer.py +313 -0
  196. vllm/transformers_utils/tokenizer.py +149 -0
  197. vllm/transformers_utils/tokenizer_group/__init__.py +33 -0
  198. vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py +55 -0
  199. vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py +169 -0
  200. vllm/transformers_utils/tokenizer_group/tokenizer_group.py +78 -0
  201. vllm/transformers_utils/tokenizers/__init__.py +5 -0
  202. vllm/transformers_utils/tokenizers/baichuan.py +255 -0
  203. vllm/usage/__init__.py +0 -0
  204. vllm/usage/usage_lib.py +209 -0
  205. vllm/utils.py +677 -0
  206. vllm/worker/__init__.py +0 -0
  207. vllm/worker/cache_engine.py +105 -0
  208. vllm/worker/cpu_model_runner.py +346 -0
  209. vllm/worker/cpu_worker.py +321 -0
  210. vllm/worker/model_runner.py +1168 -0
  211. vllm/worker/neuron_model_runner.py +196 -0
  212. vllm/worker/neuron_worker.py +98 -0
  213. vllm/worker/worker.py +345 -0
  214. vllm/worker/worker_base.py +146 -0
  215. vllm_npu-0.4.2.dist-info/LICENSE +201 -0
  216. vllm_npu-0.4.2.dist-info/METADATA +173 -0
  217. vllm_npu-0.4.2.dist-info/RECORD +219 -0
  218. vllm_npu-0.4.2.dist-info/WHEEL +5 -0
  219. vllm_npu-0.4.2.dist-info/top_level.txt +1 -0
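
The listing above is the complete package inventory; it includes the standard vLLM Python entrypoints (vllm/__init__.py, vllm/entrypoints/llm.py, vllm/sampling_params.py, vllm/outputs.py). As a quick orientation, a minimal offline-inference sketch against those entrypoints follows; the model name and sampling settings are illustrative assumptions, not taken from this wheel.

# Minimal sketch, assuming the packaged entrypoints expose the usual
# vLLM 0.4.2 API surface; the model name below is only an example.
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")
params = SamplingParams(temperature=0.8, max_tokens=16)
outputs = llm.generate(["Hello, my name is"], params)
for out in outputs:
    # Each RequestOutput carries the generated completions for one prompt.
    print(out.outputs[0].text)
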
vllm/model_executor/layers/linear.py
@@ -0,0 +1,709 @@
+from abc import abstractmethod
+from typing import List, Optional
+
+import torch
+import torch.nn.functional as F
+from torch.nn.parameter import Parameter
+
+from vllm.distributed import (divide, get_tensor_model_parallel_rank,
+                              get_tensor_model_parallel_world_size,
+                              split_tensor_along_last_dim,
+                              tensor_model_parallel_all_gather,
+                              tensor_model_parallel_all_reduce)
+from vllm.logger import init_logger
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig, QuantizeMethodBase)
+from vllm.model_executor.utils import set_weight_attrs
+
+logger = init_logger(__name__)
+
+
+def adjust_marlin_shard(param, shard_size, shard_offset):
+    marlin_tile_size = getattr(param, "marlin_tile_size", None)
+    if marlin_tile_size is None:
+        return shard_size, shard_offset
+
+    return shard_size * marlin_tile_size, shard_offset * marlin_tile_size
+
+
+class LinearMethodBase(QuantizeMethodBase):
+    """Base class for different (maybe quantized) linear methods."""
+
+    @abstractmethod
+    def create_weights(self, layer: torch.nn.Module,
+                       input_size_per_partition: int,
+                       output_partition_sizes: List[int], input_size: int,
+                       output_size: int, params_dtype: torch.dtype,
+                       **extra_weight_attrs):
+        """Create weights for a linear layer.
+           The weights will be set as attributes of the layer.
+
+        Args:
+            layer: The layer that is using the LinearMethodBase factory.
+            input_size_per_partition: Size of the weight input dim on rank X.
+            output_partition_sizes: Sizes of the output dim of each logical
+                weight on rank X. E.g., output_partition_sizes for QKVLinear
+                is a list containing the widths of Wq, Wk, Wv on rank X.
+            input_size: Size of the input dim of the weight across all ranks.
+            output_size: Size of the output dim of the weight across all ranks.
+            params_dtype: Datatype of the parameters.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def apply(self,
+              layer: torch.nn.Module,
+              x: torch.Tensor,
+              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        """Apply the weights in layer to the input tensor.
+
+        Expects create_weights to have been called before on the layer."""
+        raise NotImplementedError
+
+
+class UnquantizedLinearMethod(LinearMethodBase):
+    """Linear method without quantization.
+
+    Args:
+        separate_bias_add: If true, add bias separately after matrix
+                           multiplication.
+    """
+
+    def __init__(self, separate_bias_add: bool = False):
+        self.separate_bias_add = separate_bias_add
+
+    def create_weights(self, layer: torch.nn.Module,
+                       input_size_per_partition: int,
+                       output_partition_sizes: List[int], input_size: int,
+                       output_size: int, params_dtype: torch.dtype,
+                       **extra_weight_attrs):
+        output_size_per_partition = sum(output_partition_sizes)
+        weight = Parameter(torch.empty(output_size_per_partition,
+                                       input_size_per_partition,
+                                       dtype=params_dtype),
+                           requires_grad=False)
+        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
+        layer.register_parameter("weight", weight)
+        set_weight_attrs(weight, extra_weight_attrs)
+
+    def apply(self,
+              layer: torch.nn.Module,
+              x: torch.Tensor,
+              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        weight = layer.weight
+        if self.separate_bias_add:
+            if bias is not None:
+                return F.linear(x, weight) + bias
+            return F.linear(x, weight)
+        return F.linear(x, weight, bias)
+
+
+class LinearBase(torch.nn.Module):
+    """Base linear layer.
+
+    Args:
+        input_size: input dimension of the linear layer.
+        output_size: output dimension of the linear layer.
+        bias: If true, add bias.
+        skip_bias_add: If true, skip adding bias but instead return it.
+        params_dtype: Data type for the parameters.
+        quant_config: Quantization configuration.
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int,
+        skip_bias_add: bool = False,
+        params_dtype: Optional[torch.dtype] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+
+        # Keep input parameters
+        self.input_size = input_size
+        self.output_size = output_size
+        self.skip_bias_add = skip_bias_add
+        if params_dtype is None:
+            params_dtype = torch.get_default_dtype()
+        self.params_dtype = params_dtype
+        if quant_config is None:
+            self.quant_method: Optional[
+                QuantizeMethodBase] = UnquantizedLinearMethod()
+        else:
+            self.quant_method = quant_config.get_quant_method(self)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        raise NotImplementedError
+
+
+class ReplicatedLinear(LinearBase):
+    """Replicated linear layer.
+
+    Args:
+        input_size: input dimension of the linear layer.
+        output_size: output dimension of the linear layer.
+        bias: If true, add bias.
+        skip_bias_add: If true, skip adding bias but instead return it.
+        params_dtype: Data type for the parameters.
+        quant_config: Quantization configuration.
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int,
+        bias: bool = True,
+        skip_bias_add: bool = False,
+        params_dtype: Optional[torch.dtype] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__(input_size, output_size, skip_bias_add, params_dtype,
+                         quant_config)
+
+        # All the linear layers support a quant method.
+        assert self.quant_method is not None
+        self.quant_method.create_weights(self, self.input_size,
+                                         [self.output_size], self.input_size,
+                                         self.output_size, self.params_dtype)
+
+        if bias:
+            self.bias = Parameter(
+                torch.empty(self.output_size, dtype=self.params_dtype))
+            set_weight_attrs(self.bias, {"output_dim": 0})
+        else:
+            self.register_parameter("bias", None)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        bias = self.bias if not self.skip_bias_add else None
+        assert self.quant_method is not None
+        output = self.quant_method.apply(self, x, bias)
+        output_bias = self.bias if self.skip_bias_add else None
+        return output, output_bias
+
+    def extra_repr(self) -> str:
+        s = f"in_features={self.input_size}"
+        s += f", output_features={self.output_size}"
+        s += f", bias={self.bias is not None}"
+        return s
+
+
+class ColumnParallelLinear(LinearBase):
+    """Linear layer with column parallelism.
+
+    The linear layer is defined as Y = XA + b. A is parallelized along
+    its second dimension as A = [A_1, ..., A_p].
+
+    Args:
+        input_size: first dimension of matrix A.
+        output_size: second dimension of matrix A.
+        bias: If true, add bias.
+        gather_output: If true, call all-gather on output and make Y available
+                       to all GPUs, otherwise, every GPU will have its output
+                       which is Y_i = XA_i
+        skip_bias_add: This was added to enable performance optimizations where
+                       bias can be fused with other element-wise operations. We
+                       skip adding bias but instead return it.
+        params_dtype: Data type for the parameters.
+        quant_config: Quantization configuration.
+        output_sizes: list of output sizes packed into one output; e.g., for QKV
+                      the list would have size 3.
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int,
+        bias: bool = True,
+        gather_output: bool = False,
+        skip_bias_add: bool = False,
+        params_dtype: Optional[torch.dtype] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        output_sizes: Optional[List[int]] = None,
+    ):
+        super().__init__(input_size, output_size, skip_bias_add, params_dtype,
+                         quant_config)
+
+        self.gather_output = gather_output
+
+        # Divide the weight matrix along the last dimension.
+        tp_size = get_tensor_model_parallel_world_size()
+        self.output_size_per_partition = divide(output_size, tp_size)
+        if output_sizes is None:
+            output_sizes = [output_size]
+        # All the linear layers support a quant method.
+        assert self.quant_method is not None
+        self.quant_method.create_weights(self,
+                                         self.input_size,
+                                         [x // tp_size for x in output_sizes],
+                                         self.input_size,
+                                         self.output_size,
+                                         self.params_dtype,
+                                         weight_loader=self.weight_loader)
+        if bias:
+            self.bias = Parameter(
+                torch.empty(self.output_size_per_partition,
+                            dtype=params_dtype))
+            set_weight_attrs(self.bias, {
+                "output_dim": 0,
+                "weight_loader": self.weight_loader,
+            })
+        else:
+            self.register_parameter("bias", None)
+
+    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
+        # Special case for Fp8 scales.
+        fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
+                                           None)
+
+        tp_rank = get_tensor_model_parallel_rank()
+        output_dim = getattr(param, "output_dim", None)
+        param_data = param.data
+        if output_dim is not None:
+            shard_size = param_data.shape[output_dim]
+            start_idx = tp_rank * shard_size
+            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
+                                                 shard_size)
+        # Special case for Fp8 scales.
+        elif fp8_scales_shard_indexer is not None:
+            param_data, loaded_weight = fp8_scales_shard_indexer(param_data,
+                                                                 loaded_weight,
+                                                                 shard_id=0)
+
+        assert param_data.shape == loaded_weight.shape
+        param_data.copy_(loaded_weight)
+
+    def forward(self, input_):
+        bias = self.bias if not self.skip_bias_add else None
+
+        # Matrix multiply.
+        assert self.quant_method is not None
+        output_parallel = self.quant_method.apply(self, input_, bias)
+        if self.gather_output:
+            # All-gather across the partitions.
+            output = tensor_model_parallel_all_gather(output_parallel)
+        else:
+            output = output_parallel
+        output_bias = self.bias if self.skip_bias_add else None
+        return output, output_bias
+
+    def extra_repr(self) -> str:
+        s = f"in_features={self.input_size}"
+        s += f", output_features={self.output_size_per_partition}"
+        s += f", bias={self.bias is not None}"
+        s += f", tp_size={get_tensor_model_parallel_world_size()}"
+        s += f", gather_output={self.gather_output}"
+        return s
+
+
+class MergedColumnParallelLinear(ColumnParallelLinear):
+    """Packed linear layers with column parallelism.
+
+    Similar to ColumnParallelLinear, but the weight matrix is concatenated
+    along the output dimension. When the weight matrix is loaded, the
+    different partitions are sharded separately.
+
+    Args:
+        input_size: input dimension of the linear layer.
+        output_sizes: list of output dimensions of the linear layer.
+        bias: If true, add bias.
+        gather_output: If true, call all-gather on output and make the output
+                       available to all GPUs, otherwise, every GPU will have
+                       its own output.
+        skip_bias_add: This was added to enable performance optimizations where
+                       bias can be fused with other element-wise operations. We
+                       skip adding bias but instead return it.
+        params_dtype: Data type for the parameters.
+        quant_config: Quantization configuration.
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        output_sizes: List[int],
+        bias: bool = True,
+        gather_output: bool = False,
+        skip_bias_add: bool = False,
+        params_dtype: Optional[torch.dtype] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        self.output_sizes = output_sizes
+        tp_size = get_tensor_model_parallel_world_size()
+        assert all(output_size % tp_size == 0 for output_size in output_sizes)
+        super().__init__(input_size, sum(output_sizes), bias, gather_output,
+                         skip_bias_add, params_dtype, quant_config,
+                         self.output_sizes)
+
+    def weight_loader(self,
+                      param: Parameter,
+                      loaded_weight: torch.Tensor,
+                      loaded_shard_id: Optional[int] = None):
+
+        param_data = param.data
+        output_dim = getattr(param, "output_dim", None)
+        # Special case for AQLM codebooks.
+        is_metadata = getattr(param, "is_metadata", False)
+        # Special case for Fp8 scales.
+        fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
+                                           None)
+
+        if loaded_shard_id is None:
+            # Loaded weight is already packed.
+            if output_dim is None:
+                assert param_data.shape == loaded_weight.shape
+                param_data.copy_(loaded_weight)
+                return
+            current_shard_offset = 0
+            shard_offsets = []
+            for i, output_size in enumerate(self.output_sizes):
+                shard_offsets.append((i, current_shard_offset, output_size))
+                current_shard_offset += output_size
+            packed_dim = getattr(param, "packed_dim", None)
+            for shard_id, shard_offset, shard_size in shard_offsets:
+                # Special case for Quantization.
+                # If quantized, we need to adjust the offset and size to account
+                # for the packing.
+                if packed_dim == output_dim:
+                    shard_size = shard_size // param.pack_factor
+                    shard_offset = shard_offset // param.pack_factor
+                    # Special case for Marlin.
+                    shard_size, shard_offset = adjust_marlin_shard(
+                        param, shard_size, shard_offset)
+
+                loaded_weight_shard = loaded_weight.narrow(
+                    output_dim, shard_offset, shard_size)
+                self.weight_loader(param, loaded_weight_shard, shard_id)
+            return
+
+        assert loaded_shard_id < len(self.output_sizes)
+        tp_rank = get_tensor_model_parallel_rank()
+        tp_size = get_tensor_model_parallel_world_size()
+        if output_dim is not None:
+            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
+            shard_size = self.output_sizes[loaded_shard_id] // tp_size
+            # Special case for quantization.
+            # If quantized, we need to adjust the offset and size to account
+            # for the packing.
+            packed_dim = getattr(param, "packed_dim", None)
+            if packed_dim == output_dim:
+                shard_size = shard_size // param.pack_factor
+                shard_offset = shard_offset // param.pack_factor
+                # Special case for Marlin.
+                shard_size, shard_offset = adjust_marlin_shard(
+                    param, shard_size, shard_offset)
+
+            param_data = param_data.narrow(output_dim, shard_offset,
+                                           shard_size)
+            start_idx = tp_rank * shard_size
+            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
+                                                 shard_size)
+        # Special case for AQLM codebooks.
+        elif is_metadata:
+            # metadata indicates fixed size concatenated along dim 0
+            shard_size = loaded_weight.shape[0]
+            shard_offset = loaded_shard_id * shard_size
+            param_data = param_data.narrow(0, shard_offset, shard_size)
+        # Special case for Fp8 scales.
+        elif fp8_scales_shard_indexer is not None:
+            param_data, loaded_weight = fp8_scales_shard_indexer(
+                param_data, loaded_weight, loaded_shard_id)
+
+        else:
+            ignore_warning = getattr(param, "ignore_warning", False)
+            if not ignore_warning:
+                logger.warning(
+                    "Loading a weight without `output_dim` attribute in "
+                    "MergedColumnParallelLinear, assume the weight is "
+                    "the same for all partitions.")
+        assert param_data.shape == loaded_weight.shape
+        param_data.copy_(loaded_weight)
+
+
+class QKVParallelLinear(ColumnParallelLinear):
+    """Linear layers for the attention's QKV transformation.
+
+    Linear layers for the linear transformation of the query, key, and value
+    vectors in the attention layer. The weight matrix is concatenated along
+    the output dimension. The layer is parallelized along the head dimension.
+    When the number of key/value heads is smaller than the number of query
+    heads (e.g., multi-query/grouped-query attention), the key/value head may
+    be replicated while the query heads are partitioned.
+
+    Args:
+        hidden_size: input hidden state size of the transformer.
+        head_size: size of each attention head.
+        total_num_heads: total number of attention query heads.
+        total_num_kv_heads: total number of attention key/value heads. If
+                            None, assume total_num_kv_heads = total_num_heads.
+        bias: If true, add bias.
+        skip_bias_add: This was added to enable performance optimizations where
+                       bias can be fused with other element-wise operations. We
+                       skip adding bias but instead return it.
+        params_dtype: Data type for the parameters.
+        quant_config: Quantization configuration.
+    """
+
+    def __init__(
+        self,
+        hidden_size: int,
+        head_size: int,
+        total_num_heads: int,
+        total_num_kv_heads: Optional[int] = None,
+        bias: bool = True,
+        skip_bias_add: bool = False,
+        params_dtype: Optional[torch.dtype] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        self.hidden_size = hidden_size
+        self.head_size = head_size
+        self.total_num_heads = total_num_heads
+        if total_num_kv_heads is None:
+            total_num_kv_heads = total_num_heads
+        self.total_num_kv_heads = total_num_kv_heads
+        # Divide the weight matrix along the last dimension.
+        tp_size = get_tensor_model_parallel_world_size()
+        self.num_heads = divide(self.total_num_heads, tp_size)
+        if tp_size >= self.total_num_kv_heads:
+            self.num_kv_heads = 1
+            self.num_kv_head_replicas = divide(tp_size,
+                                               self.total_num_kv_heads)
+        else:
+            self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
+            self.num_kv_head_replicas = 1
+        input_size = self.hidden_size
+        output_size = (self.num_heads +
+                       2 * self.num_kv_heads) * tp_size * self.head_size
+        output_sizes = [
+            self.num_heads * tp_size * self.head_size,
+            self.num_kv_heads * tp_size * self.head_size,
+            self.num_kv_heads * tp_size * self.head_size
+        ]
+
+        super().__init__(input_size, output_size, bias, False, skip_bias_add,
+                         params_dtype, quant_config, output_sizes)
+
+    def weight_loader(self,
+                      param: Parameter,
+                      loaded_weight: torch.Tensor,
+                      loaded_shard_id: Optional[str] = None):
+        param_data = param.data
+        output_dim = getattr(param, "output_dim", None)
+        # Special case for AQLM codebooks.
+        is_metadata = getattr(param, "is_metadata", False)
+        # Special case for Fp8 scales.
+        fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
+                                           None)
+
+        if loaded_shard_id is None:
+            # Loaded weight is already packed.
+            if output_dim is None:
+                assert param_data.shape == loaded_weight.shape
+                param_data.copy_(loaded_weight)
+                return
+            shard_offsets = [
+                # (shard_id, shard_offset, shard_size)
+                ("q", 0, self.total_num_heads * self.head_size),
+                ("k", self.total_num_heads * self.head_size,
+                 self.total_num_kv_heads * self.head_size),
+                ("v", (self.total_num_heads + self.total_num_kv_heads) *
+                 self.head_size, self.total_num_kv_heads * self.head_size),
+            ]
+            packed_dim = getattr(param, "packed_dim", None)
+            for shard_id, shard_offset, shard_size in shard_offsets:
+                # Special case for Quantized Weights.
+                # If quantized, we need to adjust the offset and size to account
+                # for the packing.
+                if packed_dim == output_dim:
+                    shard_size = shard_size // param.pack_factor
+                    shard_offset = shard_offset // param.pack_factor
+
+                    # Special case for Marlin.
+                    shard_size, shard_offset = adjust_marlin_shard(
+                        param, shard_size, shard_offset)
+
+                loaded_weight_shard = loaded_weight.narrow(
+                    output_dim, shard_offset, shard_size)
+                self.weight_loader(param, loaded_weight_shard, shard_id)
+            return
+
+        tp_rank = get_tensor_model_parallel_rank()
+        assert loaded_shard_id in ["q", "k", "v"]
+        if output_dim is not None:
+            if loaded_shard_id == "q":
+                shard_offset = 0
+                shard_size = self.num_heads * self.head_size
+            elif loaded_shard_id == "k":
+                shard_offset = self.num_heads * self.head_size
+                shard_size = self.num_kv_heads * self.head_size
+            elif loaded_shard_id == "v":
+                shard_offset = (self.num_heads +
+                                self.num_kv_heads) * self.head_size
+                shard_size = self.num_kv_heads * self.head_size
+            # Special case for Quantized Weights.
+            # If quantized, we need to adjust the offset and size to account
+            # for the packing.
+            packed_dim = getattr(param, "packed_dim", None)
+            if packed_dim == output_dim:
+                shard_size = shard_size // param.pack_factor
+                shard_offset = shard_offset // param.pack_factor
+
+                # Special case for Marlin.
+                shard_size, shard_offset = adjust_marlin_shard(
+                    param, shard_size, shard_offset)
+
+            param_data = param_data.narrow(output_dim, shard_offset,
+                                           shard_size)
+            if loaded_shard_id == "q":
+                shard_id = tp_rank
+            else:
+                shard_id = tp_rank // self.num_kv_head_replicas
+            start_idx = shard_id * shard_size
+            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
+                                                 shard_size)
+        # Special case for AQLM codebooks.
+        elif is_metadata:
+            # metadata indicates fixed size concatenated along dim 0
+            shard_size = loaded_weight.shape[0]
+            shard_index = ["q", "k", "v"].index(loaded_shard_id)
+            param_data = param_data.narrow(0, shard_index * shard_size,
+                                           shard_size)
+        # Special case for Fp8 scales.
+        elif fp8_scales_shard_indexer is not None:
+            param_data, loaded_weight = fp8_scales_shard_indexer(
+                param_data, loaded_weight, loaded_shard_id)
+        else:
+            ignore_warning = getattr(param, "ignore_warning", False)
+            if not ignore_warning:
+                logger.warning(
+                    "Loading a weight without `output_dim` attribute in "
+                    "QKVParallelLinear, assume the weight is the same "
+                    "for all partitions.")
+        assert param_data.shape == loaded_weight.shape
+        param_data.copy_(loaded_weight)
+
+
+class RowParallelLinear(LinearBase):
+    """Linear layer with row parallelism.
+
+    The linear layer is defined as Y = XA + b. A is parallelized along
+    its first dimension and X along its second dimension as:
+               -   -
+              | A_1 |
+              | .   |
+          A = | .   |        X = [X_1, ..., X_p]
+              | .   |
+              | A_p |
+               -   -
+    Arguments:
+        input_size: first dimension of matrix A.
+        output_size: second dimension of matrix A.
+        bias: If true, add bias. Note that bias is not parallelized.
+        input_is_parallel: If true, we assume that the input is already
+                           split across the GPUs and we do not split
+                           again.
+        skip_bias_add: This was added to enable performance optimization where
+                       bias can be fused with other element-wise operations.
+                       We skip adding bias but instead return it.
+        params_dtype: Data type for the parameters.
+        quant_config: Quantization configuration.
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int,
+        bias: bool = True,
+        input_is_parallel: bool = True,
+        skip_bias_add: bool = False,
+        params_dtype: Optional[torch.dtype] = None,
+        reduce_results: bool = True,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__(input_size, output_size, skip_bias_add, params_dtype,
+                         quant_config)
+
+        self.input_is_parallel = input_is_parallel
+        self.reduce_results = reduce_results
+
+        # Divide the weight matrix along the last dimension.
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.input_size_per_partition = divide(input_size, self.tp_size)
+        # All the linear layers support a quant method.
+        assert self.quant_method is not None
+        self.quant_method.create_weights(self,
+                                         self.input_size_per_partition,
+                                         [self.output_size],
+                                         self.input_size,
+                                         self.output_size,
+                                         self.params_dtype,
+                                         weight_loader=self.weight_loader)
+
+        if not reduce_results and (bias and not skip_bias_add):
+            raise ValueError("When not reducing the results, adding bias to "
+                             "the results can lead to incorrect results")
+
+        if bias:
+            self.bias = Parameter(
+                torch.empty(self.output_size, dtype=params_dtype))
+            set_weight_attrs(self.bias, {
+                "output_dim": 0,
+                "weight_loader": self.weight_loader,
+            })
+        else:
+            self.register_parameter("bias", None)
+
+    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
+        # Special case for Fp8 scales.
+        fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
+                                           None)
+
+        tp_rank = get_tensor_model_parallel_rank()
+        input_dim = getattr(param, "input_dim", None)
+        param_data = param.data
+        if input_dim is not None:
+            shard_size = param_data.shape[input_dim]
+            start_idx = tp_rank * shard_size
+            loaded_weight = loaded_weight.narrow(input_dim, start_idx,
+                                                 shard_size)
+        # Special case for Fp8 scales.
+        elif fp8_scales_shard_indexer is not None:
+            param_data, loaded_weight = fp8_scales_shard_indexer(param_data,
+                                                                 loaded_weight,
+                                                                 shard_id=0)
+
+        assert param_data.shape == loaded_weight.shape
+        param_data.copy_(loaded_weight)
+
+    def forward(self, input_):
+        # Set up backprop all-reduce.
+        if self.input_is_parallel:
+            input_parallel = input_
+        else:
+            tp_rank = get_tensor_model_parallel_rank()
+            splitted_input = split_tensor_along_last_dim(
+                input_, num_partitions=self.tp_size)
+            input_parallel = splitted_input[tp_rank].contiguous()
+
+        # Matrix multiply.
+        assert self.quant_method is not None
+        output_parallel = self.quant_method.apply(self, input_parallel)
+        if self.reduce_results and self.tp_size > 1:
+            output_ = tensor_model_parallel_all_reduce(output_parallel)
+        else:
+            output_ = output_parallel
+
+        if not self.skip_bias_add:
+            output = output_ + self.bias if self.bias is not None else output_
+            output_bias = None
+        else:
+            output = output_
+            output_bias = self.bias
+        return output, output_bias
+
+    def extra_repr(self) -> str:
+        s = f"input_features={self.input_size_per_partition}"
+        s += f", output_features={self.output_size}"
+        s += f", bias={self.bias is not None}"
+        s += f", tp_size={self.tp_size}"
+        s += f", reduce_results={self.reduce_results}"
+        return s
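
For context on how the sharded layers in this diff are typically composed, below is a minimal sketch of a tensor-parallel attention projection pair: a column-parallel QKV projection followed by a row-parallel output projection. It assumes vLLM's tensor-parallel process groups are already initialized (normally done by the engine) and that checkpoints would be loaded through the weight_loader methods shown above; the sizes are illustrative.

# Minimal composition sketch (not part of the wheel). Assumes tensor-parallel
# state is already initialized; hidden_size and head counts are illustrative.
import torch

from vllm.model_executor.layers.linear import (QKVParallelLinear,
                                               RowParallelLinear)

hidden_size, head_size, num_heads, num_kv_heads = 4096, 128, 32, 8

# Column-parallel QKV projection: each rank holds its shard of Wq, Wk, Wv.
qkv_proj = QKVParallelLinear(hidden_size, head_size, num_heads, num_kv_heads,
                             bias=False)
# Row-parallel output projection: per-rank partial sums are all-reduced.
o_proj = RowParallelLinear(num_heads * head_size, hidden_size, bias=False)

x = torch.randn(1, hidden_size, dtype=qkv_proj.params_dtype)
qkv, _ = qkv_proj(x)  # both layer types return (output, output_bias)
q, k, v = qkv.split([qkv_proj.num_heads * head_size,
                     qkv_proj.num_kv_heads * head_size,
                     qkv_proj.num_kv_heads * head_size],
                    dim=-1)
attn_out = q  # placeholder standing in for the per-rank attention output
out, _ = o_proj(attn_out)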