vllm_npu-0.4.2-py3-none-any.whl

Files changed (219)
  1. vllm/__init__.py +23 -0
  2. vllm/_custom_ops.py +251 -0
  3. vllm/attention/__init__.py +13 -0
  4. vllm/attention/backends/__init__.py +0 -0
  5. vllm/attention/backends/abstract.py +127 -0
  6. vllm/attention/backends/flash_attn.py +271 -0
  7. vllm/attention/backends/flashinfer.py +220 -0
  8. vllm/attention/backends/rocm_flash_attn.py +374 -0
  9. vllm/attention/backends/torch_sdpa.py +250 -0
  10. vllm/attention/backends/xformers.py +393 -0
  11. vllm/attention/layer.py +56 -0
  12. vllm/attention/ops/__init__.py +0 -0
  13. vllm/attention/ops/paged_attn.py +216 -0
  14. vllm/attention/ops/prefix_prefill.py +792 -0
  15. vllm/attention/ops/triton_flash_attention.py +810 -0
  16. vllm/attention/selector.py +91 -0
  17. vllm/block.py +84 -0
  18. vllm/config.py +1225 -0
  19. vllm/core/__init__.py +0 -0
  20. vllm/core/block/__init__.py +0 -0
  21. vllm/core/block/block_table.py +295 -0
  22. vllm/core/block/common.py +199 -0
  23. vllm/core/block/cpu_gpu_block_allocator.py +228 -0
  24. vllm/core/block/interfaces.py +205 -0
  25. vllm/core/block/naive_block.py +318 -0
  26. vllm/core/block/prefix_caching_block.py +606 -0
  27. vllm/core/block_manager_v1.py +625 -0
  28. vllm/core/block_manager_v2.py +258 -0
  29. vllm/core/evictor_v1.py +105 -0
  30. vllm/core/evictor_v2.py +127 -0
  31. vllm/core/interfaces.py +113 -0
  32. vllm/core/policy.py +45 -0
  33. vllm/core/scheduler.py +1163 -0
  34. vllm/distributed/__init__.py +3 -0
  35. vllm/distributed/communication_op.py +237 -0
  36. vllm/distributed/device_communicators/__init__.py +0 -0
  37. vllm/distributed/device_communicators/custom_all_reduce.py +274 -0
  38. vllm/distributed/device_communicators/pynccl.py +287 -0
  39. vllm/distributed/device_communicators/pynccl_utils.py +66 -0
  40. vllm/distributed/parallel_state.py +339 -0
  41. vllm/distributed/utils.py +136 -0
  42. vllm/engine/__init__.py +0 -0
  43. vllm/engine/arg_utils.py +649 -0
  44. vllm/engine/async_llm_engine.py +737 -0
  45. vllm/engine/llm_engine.py +784 -0
  46. vllm/engine/metrics.py +368 -0
  47. vllm/engine/output_processor/__init__.py +0 -0
  48. vllm/engine/output_processor/interfaces.py +76 -0
  49. vllm/engine/output_processor/multi_step.py +142 -0
  50. vllm/engine/output_processor/single_step.py +284 -0
  51. vllm/engine/output_processor/stop_checker.py +101 -0
  52. vllm/engine/output_processor/util.py +19 -0
  53. vllm/entrypoints/__init__.py +0 -0
  54. vllm/entrypoints/api_server.py +119 -0
  55. vllm/entrypoints/llm.py +259 -0
  56. vllm/entrypoints/openai/__init__.py +0 -0
  57. vllm/entrypoints/openai/api_server.py +186 -0
  58. vllm/entrypoints/openai/cli_args.py +115 -0
  59. vllm/entrypoints/openai/protocol.py +460 -0
  60. vllm/entrypoints/openai/serving_chat.py +392 -0
  61. vllm/entrypoints/openai/serving_completion.py +347 -0
  62. vllm/entrypoints/openai/serving_engine.py +234 -0
  63. vllm/envs.py +217 -0
  64. vllm/executor/__init__.py +0 -0
  65. vllm/executor/cpu_executor.py +152 -0
  66. vllm/executor/distributed_gpu_executor.py +115 -0
  67. vllm/executor/executor_base.py +115 -0
  68. vllm/executor/gpu_executor.py +150 -0
  69. vllm/executor/multiproc_worker_utils.py +263 -0
  70. vllm/executor/neuron_executor.py +91 -0
  71. vllm/executor/ray_gpu_executor.py +327 -0
  72. vllm/executor/ray_utils.py +119 -0
  73. vllm/logger.py +153 -0
  74. vllm/logging/__init__.py +5 -0
  75. vllm/logging/formatter.py +15 -0
  76. vllm/lora/__init__.py +0 -0
  77. vllm/lora/fully_sharded_layers.py +262 -0
  78. vllm/lora/layers.py +1181 -0
  79. vllm/lora/lora.py +167 -0
  80. vllm/lora/models.py +645 -0
  81. vllm/lora/punica.py +213 -0
  82. vllm/lora/request.py +32 -0
  83. vllm/lora/utils.py +98 -0
  84. vllm/lora/worker_manager.py +251 -0
  85. vllm/model_executor/__init__.py +7 -0
  86. vllm/model_executor/guided_decoding/__init__.py +25 -0
  87. vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py +70 -0
  88. vllm/model_executor/guided_decoding/outlines_decoding.py +130 -0
  89. vllm/model_executor/guided_decoding/outlines_logits_processors.py +184 -0
  90. vllm/model_executor/layers/__init__.py +0 -0
  91. vllm/model_executor/layers/activation.py +173 -0
  92. vllm/model_executor/layers/fused_moe/__init__.py +7 -0
  93. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  94. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  95. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  96. vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  97. vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  98. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  99. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  100. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  101. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  102. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  103. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  104. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  105. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +140 -0
  106. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  107. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  108. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  109. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  110. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +146 -0
  111. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  112. vllm/model_executor/layers/fused_moe/fused_moe.py +479 -0
  113. vllm/model_executor/layers/layernorm.py +71 -0
  114. vllm/model_executor/layers/linear.py +709 -0
  115. vllm/model_executor/layers/logits_processor.py +115 -0
  116. vllm/model_executor/layers/ops/__init__.py +0 -0
  117. vllm/model_executor/layers/ops/rand.py +157 -0
  118. vllm/model_executor/layers/ops/sample.py +406 -0
  119. vllm/model_executor/layers/quantization/__init__.py +35 -0
  120. vllm/model_executor/layers/quantization/aqlm.py +376 -0
  121. vllm/model_executor/layers/quantization/awq.py +175 -0
  122. vllm/model_executor/layers/quantization/base_config.py +97 -0
  123. vllm/model_executor/layers/quantization/fp8.py +265 -0
  124. vllm/model_executor/layers/quantization/gptq.py +224 -0
  125. vllm/model_executor/layers/quantization/gptq_marlin.py +438 -0
  126. vllm/model_executor/layers/quantization/marlin.py +227 -0
  127. vllm/model_executor/layers/quantization/schema.py +84 -0
  128. vllm/model_executor/layers/quantization/squeezellm.py +137 -0
  129. vllm/model_executor/layers/rejection_sampler.py +405 -0
  130. vllm/model_executor/layers/rotary_embedding.py +525 -0
  131. vllm/model_executor/layers/sampler.py +1051 -0
  132. vllm/model_executor/layers/vocab_parallel_embedding.py +155 -0
  133. vllm/model_executor/model_loader/__init__.py +30 -0
  134. vllm/model_executor/model_loader/loader.py +362 -0
  135. vllm/model_executor/model_loader/neuron.py +136 -0
  136. vllm/model_executor/model_loader/tensorizer.py +368 -0
  137. vllm/model_executor/model_loader/utils.py +41 -0
  138. vllm/model_executor/model_loader/weight_utils.py +372 -0
  139. vllm/model_executor/models/__init__.py +119 -0
  140. vllm/model_executor/models/baichuan.py +410 -0
  141. vllm/model_executor/models/bloom.py +327 -0
  142. vllm/model_executor/models/chatglm.py +386 -0
  143. vllm/model_executor/models/commandr.py +373 -0
  144. vllm/model_executor/models/dbrx.py +413 -0
  145. vllm/model_executor/models/decilm.py +122 -0
  146. vllm/model_executor/models/deepseek.py +438 -0
  147. vllm/model_executor/models/falcon.py +444 -0
  148. vllm/model_executor/models/gemma.py +393 -0
  149. vllm/model_executor/models/gpt2.py +266 -0
  150. vllm/model_executor/models/gpt_bigcode.py +274 -0
  151. vllm/model_executor/models/gpt_j.py +281 -0
  152. vllm/model_executor/models/gpt_neox.py +295 -0
  153. vllm/model_executor/models/internlm2.py +323 -0
  154. vllm/model_executor/models/jais.py +333 -0
  155. vllm/model_executor/models/llama.py +442 -0
  156. vllm/model_executor/models/llava.py +239 -0
  157. vllm/model_executor/models/minicpm.py +531 -0
  158. vllm/model_executor/models/mixtral.py +583 -0
  159. vllm/model_executor/models/mixtral_quant.py +404 -0
  160. vllm/model_executor/models/mpt.py +295 -0
  161. vllm/model_executor/models/olmo.py +356 -0
  162. vllm/model_executor/models/opt.py +349 -0
  163. vllm/model_executor/models/orion.py +319 -0
  164. vllm/model_executor/models/phi.py +300 -0
  165. vllm/model_executor/models/qwen.py +284 -0
  166. vllm/model_executor/models/qwen2.py +367 -0
  167. vllm/model_executor/models/qwen2_moe.py +447 -0
  168. vllm/model_executor/models/stablelm.py +301 -0
  169. vllm/model_executor/models/starcoder2.py +302 -0
  170. vllm/model_executor/models/xverse.py +366 -0
  171. vllm/model_executor/sampling_metadata.py +588 -0
  172. vllm/model_executor/utils.py +35 -0
  173. vllm/outputs.py +150 -0
  174. vllm/py.typed +2 -0
  175. vllm/sampling_params.py +340 -0
  176. vllm/sequence.py +766 -0
  177. vllm/spec_decode/__init__.py +0 -0
  178. vllm/spec_decode/batch_expansion.py +397 -0
  179. vllm/spec_decode/interfaces.py +73 -0
  180. vllm/spec_decode/metrics.py +191 -0
  181. vllm/spec_decode/multi_step_worker.py +203 -0
  182. vllm/spec_decode/ngram_worker.py +176 -0
  183. vllm/spec_decode/spec_decode_worker.py +472 -0
  184. vllm/spec_decode/top1_proposer.py +200 -0
  185. vllm/spec_decode/util.py +228 -0
  186. vllm/test_utils.py +41 -0
  187. vllm/transformers_utils/__init__.py +0 -0
  188. vllm/transformers_utils/config.py +58 -0
  189. vllm/transformers_utils/configs/__init__.py +16 -0
  190. vllm/transformers_utils/configs/chatglm.py +68 -0
  191. vllm/transformers_utils/configs/dbrx.py +278 -0
  192. vllm/transformers_utils/configs/falcon.py +87 -0
  193. vllm/transformers_utils/configs/jais.py +236 -0
  194. vllm/transformers_utils/configs/mpt.py +178 -0
  195. vllm/transformers_utils/detokenizer.py +313 -0
  196. vllm/transformers_utils/tokenizer.py +149 -0
  197. vllm/transformers_utils/tokenizer_group/__init__.py +33 -0
  198. vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py +55 -0
  199. vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py +169 -0
  200. vllm/transformers_utils/tokenizer_group/tokenizer_group.py +78 -0
  201. vllm/transformers_utils/tokenizers/__init__.py +5 -0
  202. vllm/transformers_utils/tokenizers/baichuan.py +255 -0
  203. vllm/usage/__init__.py +0 -0
  204. vllm/usage/usage_lib.py +209 -0
  205. vllm/utils.py +677 -0
  206. vllm/worker/__init__.py +0 -0
  207. vllm/worker/cache_engine.py +105 -0
  208. vllm/worker/cpu_model_runner.py +346 -0
  209. vllm/worker/cpu_worker.py +321 -0
  210. vllm/worker/model_runner.py +1168 -0
  211. vllm/worker/neuron_model_runner.py +196 -0
  212. vllm/worker/neuron_worker.py +98 -0
  213. vllm/worker/worker.py +345 -0
  214. vllm/worker/worker_base.py +146 -0
  215. vllm_npu-0.4.2.dist-info/LICENSE +201 -0
  216. vllm_npu-0.4.2.dist-info/METADATA +173 -0
  217. vllm_npu-0.4.2.dist-info/RECORD +219 -0
  218. vllm_npu-0.4.2.dist-info/WHEEL +5 -0
  219. vllm_npu-0.4.2.dist-info/top_level.txt +1 -0
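For orientation, here is a minimal smoke-test sketch of how a wheel like this is typically exercised once installed. It assumes the standard vLLM offline-inference entrypoints shipped in this package (vllm/__init__.py, vllm/entrypoints/llm.py, vllm/sampling_params.py); the wheel filename and the model identifier below are placeholders, not taken from this diff.

# Hypothetical usage sketch -- filename and model id are placeholders.
# pip install ./vllm_npu-0.4.2-py3-none-any.whl
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # any architecture under vllm/model_executor/models/
params = SamplingParams(temperature=0.8, max_tokens=32)
for request_output in llm.generate(["Hello, my name is"], params):
    print(request_output.outputs[0].text)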
vllm/model_executor/models/olmo.py (new file)
@@ -0,0 +1,356 @@
+# coding=utf-8
+# Adapted from
+# https://github.com/huggingface/transformers/blob/v4.40.1/src/transformers/models/olmo/modeling_olmo.py
+# Copyright 2024 The vLLM team.
+# Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference-only OLMo model compatible with HuggingFace weights."""
+from typing import Iterable, List, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import OlmoConfig
+
+from vllm.attention import Attention, AttentionMetadata
+from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
+                                               QKVParallelLinear,
+                                               RowParallelLinear)
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig)
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.sampler import Sampler
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    ParallelLMHead, VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.sequence import SamplerOutput
+
+
+class OlmoAttention(nn.Module):
+    """
+    This is the attention block where the output is computed as
+    ``Attention(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))``
+    (plus another skip connection).
+    """
+
+    def __init__(
+        self,
+        config: OlmoConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        tensor_model_parallel_world_size = (
+            get_tensor_model_parallel_world_size())
+        self.total_num_heads = config.num_attention_heads
+
+        assert self.hidden_size % self.total_num_heads == 0
+        assert self.total_num_heads % tensor_model_parallel_world_size == 0
+
+        self.num_heads = (self.total_num_heads //
+                          tensor_model_parallel_world_size)
+        self.head_dim = self.hidden_size // self.total_num_heads
+        self.max_position_embeddings = config.max_position_embeddings
+        self.rope_theta = config.rope_theta
+        self.clip_qkv = config.clip_qkv
+
+        # Attention input projection. Projects x -> (q, k, v)
+        self.qkv_proj = QKVParallelLinear(
+            self.hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            bias=config.attention_bias,
+            quant_config=quant_config,
+        )
+
+        # Rotary embeddings.
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=self.max_position_embeddings,
+            base=self.rope_theta,
+        )
+        self.scaling = self.head_dim**-0.5
+        self.attn = Attention(self.num_heads,
+                              self.head_dim,
+                              scale=self.scaling)
+
+        # Attention output projection.
+        self.o_proj = RowParallelLinear(
+            self.hidden_size,
+            self.hidden_size,
+            bias=config.attention_bias,
+            quant_config=quant_config,
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        if self.clip_qkv is not None:
+            qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
+        q, k, v = qkv.chunk(chunks=3, dim=-1)
+        q, k = self.rotary_emb(positions, q, k)
+        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class OlmoMLP(nn.Module):
+    """
+    This is the MLP block where the output is computed as
+    ``MLP(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))``
+    (plus another skip connection).
+    """
+
+    def __init__(
+        self,
+        config: OlmoConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = config.intermediate_size
+
+        # Feed-forward input projection.
+        self.gate_up_proj = MergedColumnParallelLinear(
+            self.hidden_size,
+            [self.intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+        )
+
+        # Activation function.
+        self.act_fn = SiluAndMul()
+
+        # Feed-forward output projection.
+        self.down_proj = RowParallelLinear(
+            self.intermediate_size,
+            self.hidden_size,
+            bias=False,
+            quant_config=quant_config,
+        )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+    ) -> torch.Tensor:
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        x, _ = self.down_proj(x)
+        return x
+
+
+class OlmoDecoderLayer(nn.Module):
+    """
+    This is a typical transformer block where the output is
+    computed as ``MLP(LN(x + Attention(LN(x))))``
+    (plus another skip connection).
+    """
+
+    def __init__(self,
+                 config: OlmoConfig,
+                 quant_config: Optional[QuantizationConfig] = None):
+        super().__init__()
+        # Attention block.
+        self.self_attn = OlmoAttention(config, quant_config)
+
+        # MLP block.
+        self.mlp = OlmoMLP(config, quant_config)
+
+        # LayerNorm
+        self.input_layernorm = nn.LayerNorm(config.hidden_size,
+                                            elementwise_affine=False,
+                                            bias=False)
+        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
+                                                     elementwise_affine=False,
+                                                     bias=False)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
+        # Attention block.
+        residual = hidden_states
+        hidden_states = self.input_layernorm(hidden_states)
+        hidden_states = self.self_attn(positions, hidden_states, kv_cache,
+                                       attn_metadata)
+        hidden_states = hidden_states + residual
+
+        # MLP block.
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+        return hidden_states
+
+
+class OlmoModel(nn.Module):
+
+    def __init__(self,
+                 config: OlmoConfig,
+                 quant_config: Optional[QuantizationConfig] = None):
+        super().__init__()
+        self.config = config
+
+        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
+                                                   config.hidden_size)
+        self.layers = nn.ModuleList([
+            OlmoDecoderLayer(config, quant_config)
+            for layer_idx in range(config.num_hidden_layers)
+        ])
+        self.norm = nn.LayerNorm(config.hidden_size,
+                                 elementwise_affine=False,
+                                 bias=False)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+    ) -> torch.Tensor:
+        """
+        :param input_ids: A tensor of shape `(batch_size, seq_len)`.
+        """
+        # Get embeddings of input.
+        # shape: (batch_size, seq_len, d_model)
+        inputs_embeds = self.embed_tokens(input_ids)
+
+        # embed positions
+        hidden_states = inputs_embeds
+
+        # Apply blocks one-by-one.
+        for layer_idx, decoder_layer in enumerate(self.layers):
+            # shape: (batch_size, seq_len, d_model)
+            hidden_states = decoder_layer(
+                positions,
+                hidden_states,
+                kv_caches[layer_idx],
+                attn_metadata,
+            )
+
+        # Apply final layer norm.
+        # shape: (batch_size, seq_len or 1, d_model)
+        hidden_states = self.norm(hidden_states)
+        return hidden_states
+
+
+class OlmoForCausalLM(nn.Module):
+    """
+    Extremely barebones HF model wrapper.
+    """
+
+    def __init__(self,
+                 config: OlmoConfig,
+                 quant_config: Optional[QuantizationConfig] = None):
+        super().__init__()
+        self.config = config
+        self.model = OlmoModel(config, quant_config)
+        if config.tie_word_embeddings:
+            self.lm_head_weight = self.model.embed_tokens.weight
+        else:
+            self.unpadded_vocab_size = config.vocab_size
+            self.lm_head = ParallelLMHead(
+                self.unpadded_vocab_size,
+                config.hidden_size,
+                org_num_embeddings=config.vocab_size,
+            )
+            self.lm_head_weight = self.lm_head.weight
+        self.logits_processor = LogitsProcessor(config.vocab_size)
+        self.sampler = Sampler()
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+    ) -> torch.Tensor:
+        hidden_states = self.model(
+            input_ids=input_ids,
+            positions=positions,
+            kv_caches=kv_caches,
+            attn_metadata=attn_metadata,
+        )
+        return hidden_states
+
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        logits = self.logits_processor(self.lm_head_weight, hidden_states,
+                                       sampling_metadata)
+        return logits
+
+    def sample(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(logits, sampling_metadata)
+        return next_tokens
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+            if ("rotary_emb.cos_cached" in name
+                    or "rotary_emb.sin_cached" in name):
+                # Models trained using ColossalAI may include these tensors in
+                # the checkpoint. Skip them.
+                continue
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
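The load_weights method above folds the separate q_proj/k_proj/v_proj and gate_proj/up_proj checkpoint tensors into the fused qkv_proj and gate_up_proj parameters via stacked_params_mapping. Below is a standalone sketch of just the name-rewriting step; the resolve helper is illustrative only and is not part of the wheel.

# Illustration: how a HuggingFace checkpoint tensor name is mapped to a fused
# vLLM parameter name plus a shard id, mirroring stacked_params_mapping above.
stacked_params_mapping = [
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]

def resolve(name: str):
    """Return (fused parameter name, shard id) for a checkpoint tensor name."""
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in name:
            return name.replace(weight_name, param_name), shard_id
    return name, None  # unfused tensors load as-is

print(resolve("model.layers.0.self_attn.k_proj.weight"))
# ('model.layers.0.self_attn.qkv_proj.weight', 'k')
print(resolve("model.layers.0.mlp.up_proj.weight"))
# ('model.layers.0.mlp.gate_up_proj.weight', 1)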
vllm/model_executor/models/opt.py (new file)
@@ -0,0 +1,349 @@
+# coding=utf-8
+# Adapted from
+# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/opt/modeling_opt.py
+# Copyright 2023 The vLLM team.
+# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights
+# reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference-only OPT model compatible with HuggingFace weights."""
+from typing import Iterable, List, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import OPTConfig
+
+from vllm.attention import Attention, AttentionMetadata
+from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.model_executor.layers.activation import get_act_fn
+from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+                                               QKVParallelLinear,
+                                               ReplicatedLinear,
+                                               RowParallelLinear)
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig)
+from vllm.model_executor.layers.sampler import Sampler
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.sequence import SamplerOutput
+
+
+class OPTLearnedPositionalEmbedding(nn.Embedding):
+
+    def __init__(self, num_embeddings: int, embedding_dim: int):
+        # OPT is set up so that if padding_idx is specified then offset the
+        # embedding ids by 2 and adjust num_embeddings appropriately. Other
+        # models don't have this hack
+        self.offset = 2
+        super().__init__(num_embeddings + self.offset, embedding_dim)
+
+    def forward(self, positions: torch.Tensor):
+        return super().forward(positions + self.offset)
+
+
+class OPTAttention(nn.Module):
+
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        bias: bool = True,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.embed_dim = embed_dim
+        tensor_model_parallel_world_size = (
+            get_tensor_model_parallel_world_size())
+        total_num_heads = num_heads
+        assert num_heads % tensor_model_parallel_world_size == 0
+        self.num_heads = total_num_heads // tensor_model_parallel_world_size
+        self.head_dim = embed_dim // total_num_heads
+        self.scaling = self.head_dim**-0.5
+
+        self.qkv_proj = QKVParallelLinear(
+            embed_dim,
+            self.head_dim,
+            total_num_heads,
+            bias=bias,
+            quant_config=quant_config,
+        )
+        self.out_proj = RowParallelLinear(
+            embed_dim,
+            embed_dim,
+            bias=bias,
+            quant_config=quant_config,
+        )
+        self.attn = Attention(self.num_heads,
+                              self.head_dim,
+                              scale=self.scaling)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.chunk(chunks=3, dim=-1)
+        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        output, _ = self.out_proj(attn_output)
+        return output
+
+
+class OPTDecoderLayer(nn.Module):
+
+    def __init__(
+        self,
+        config: OPTConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+        self.config = config
+        self.embed_dim = config.hidden_size
+        self.self_attn = OPTAttention(
+            embed_dim=self.embed_dim,
+            num_heads=config.num_attention_heads,
+            bias=config.enable_bias,
+            quant_config=quant_config,
+        )
+        self.do_layer_norm_before = config.do_layer_norm_before
+
+        self.self_attn_layer_norm = nn.LayerNorm(
+            self.embed_dim,
+            elementwise_affine=config.layer_norm_elementwise_affine)
+        self.fc1 = ColumnParallelLinear(
+            self.embed_dim,
+            config.ffn_dim,
+            bias=config.enable_bias,
+            quant_config=quant_config,
+        )
+        self.activation_fn = get_act_fn(config.activation_function,
+                                        quant_config, config.ffn_dim)
+        self.fc2 = RowParallelLinear(
+            config.ffn_dim,
+            self.embed_dim,
+            bias=config.enable_bias,
+            quant_config=quant_config,
+        )
+        self.final_layer_norm = nn.LayerNorm(
+            self.embed_dim,
+            elementwise_affine=config.layer_norm_elementwise_affine)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+    ) -> torch.Tensor:
+        # Self Attention
+        residual = hidden_states
+        # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
+        if self.do_layer_norm_before:
+            hidden_states = self.self_attn_layer_norm(hidden_states)
+        hidden_states = self.self_attn(hidden_states=hidden_states,
+                                       kv_cache=kv_cache,
+                                       attn_metadata=attn_metadata)
+        hidden_states = residual + hidden_states
+        # 350m applies layer norm AFTER attention
+        if not self.do_layer_norm_before:
+            hidden_states = self.self_attn_layer_norm(hidden_states)
+
+        # Fully Connected
+        residual = hidden_states
+        # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
+        if self.do_layer_norm_before:
+            hidden_states = self.final_layer_norm(hidden_states)
+        hidden_states, _ = self.fc1(hidden_states)
+        hidden_states = self.activation_fn(hidden_states)
+        hidden_states, _ = self.fc2(hidden_states)
+        hidden_states = residual + hidden_states
+        # 350m applies layer norm AFTER attention
+        if not self.do_layer_norm_before:
+            hidden_states = self.final_layer_norm(hidden_states)
+        return hidden_states
+
+
+class OPTDecoder(nn.Module):
+
+    def __init__(
+        self,
+        config: OPTConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+        self.config = config
+        self.padding_idx = config.pad_token_id
+        self.max_target_positions = config.max_position_embeddings
+        self.vocab_size = config.vocab_size
+
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.word_embed_proj_dim,
+        )
+        # Positional embeddings are replicated (not sharded).
+        self.embed_positions = OPTLearnedPositionalEmbedding(
+            config.max_position_embeddings, config.hidden_size)
+
+        # Project out & in will be replicated if they exist.
+        if config.word_embed_proj_dim != config.hidden_size:
+            self.project_out = ReplicatedLinear(config.hidden_size,
+                                                config.word_embed_proj_dim,
+                                                bias=False,
+                                                quant_config=quant_config)
+        else:
+            self.project_out = None
+
+        if config.word_embed_proj_dim != config.hidden_size:
+            self.project_in = ReplicatedLinear(config.word_embed_proj_dim,
+                                               config.hidden_size,
+                                               bias=False,
+                                               quant_config=quant_config)
+        else:
+            self.project_in = None
+
+        # Note that the only purpose of `config._remove_final_layer_norm` is to
+        # keep backward compatibility with checkpoints that have been fine-tuned
+        # before transformers v4.20.1
+        # see https://github.com/facebookresearch/metaseq/pull/164
+        if config.do_layer_norm_before and not config._remove_final_layer_norm:
+            self.final_layer_norm = nn.LayerNorm(
+                config.hidden_size,
+                elementwise_affine=config.layer_norm_elementwise_affine)
+        else:
+            self.final_layer_norm = None
+
+        self.layers = nn.ModuleList([
+            OPTDecoderLayer(config, quant_config)
+            for _ in range(config.num_hidden_layers)
+        ])
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+    ) -> torch.Tensor:
+        inputs_embeds = self.embed_tokens(input_ids)
+        pos_embeds = self.embed_positions(positions)
+        if self.project_in is not None:
+            inputs_embeds, _ = self.project_in(inputs_embeds)
+        hidden_states = inputs_embeds + pos_embeds
+
+        for i in range(len(self.layers)):
+            layer = self.layers[i]
+            hidden_states = layer(hidden_states, kv_caches[i], attn_metadata)
+
+        if self.final_layer_norm is not None:
+            hidden_states = self.final_layer_norm(hidden_states)
+        if self.project_out is not None:
+            hidden_states, _ = self.project_out(hidden_states)
+        return hidden_states
+
+
+class OPTModel(nn.Module):
+
+    def __init__(
+        self,
+        config: OPTConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+        self.decoder = OPTDecoder(config, quant_config)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+    ) -> torch.Tensor:
+        return self.decoder(input_ids, positions, kv_caches, attn_metadata)
+
+
+class OPTForCausalLM(nn.Module):
+
+    def __init__(
+        self,
+        config,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+        self.config = config
+        self.quant_config = quant_config
+        self.model = OPTModel(config, quant_config)
+        self.lm_head_weight = self.model.decoder.embed_tokens.weight
+        self.logits_processor = LogitsProcessor(config.vocab_size)
+        self.sampler = Sampler()
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+    ) -> torch.Tensor:
+        hidden_states = self.model(input_ids, positions, kv_caches,
+                                   attn_metadata)
+        return hidden_states
+
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        logits = self.logits_processor(self.lm_head_weight, hidden_states,
+                                       sampling_metadata)
+        return logits
+
+    def sample(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(logits, sampling_metadata)
+        return next_tokens
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+        ]
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        for name, loaded_weight in weights:
+            if "lm_head.weight" in name:
+                continue
+            if name.startswith("decoder."):
+                name = "model." + name
+
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
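One OPT-specific detail worth noting is the learned positional-embedding offset: the embedding table is allocated with two extra rows and every position id is shifted by 2 before lookup, matching how OPT checkpoints store their position embeddings. The following self-contained sketch reproduces that behaviour for illustration; the class here is a copy, not an import from the wheel.

import torch
from torch import nn

class LearnedPositionalEmbedding(nn.Embedding):
    """Same offset trick as OPTLearnedPositionalEmbedding in the diff above."""

    def __init__(self, num_embeddings: int, embedding_dim: int):
        self.offset = 2  # OPT reserves two extra rows in the table
        super().__init__(num_embeddings + self.offset, embedding_dim)

    def forward(self, positions: torch.Tensor) -> torch.Tensor:
        return super().forward(positions + self.offset)

emb = LearnedPositionalEmbedding(num_embeddings=2048, embedding_dim=8)
print(emb.weight.shape)                    # torch.Size([2050, 8])
print(emb(torch.tensor([0, 1, 2])).shape)  # torch.Size([3, 8])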