vllm_npu-0.4.2-py3-none-any.whl

Files changed (219)
  1. vllm/__init__.py +23 -0
  2. vllm/_custom_ops.py +251 -0
  3. vllm/attention/__init__.py +13 -0
  4. vllm/attention/backends/__init__.py +0 -0
  5. vllm/attention/backends/abstract.py +127 -0
  6. vllm/attention/backends/flash_attn.py +271 -0
  7. vllm/attention/backends/flashinfer.py +220 -0
  8. vllm/attention/backends/rocm_flash_attn.py +374 -0
  9. vllm/attention/backends/torch_sdpa.py +250 -0
  10. vllm/attention/backends/xformers.py +393 -0
  11. vllm/attention/layer.py +56 -0
  12. vllm/attention/ops/__init__.py +0 -0
  13. vllm/attention/ops/paged_attn.py +216 -0
  14. vllm/attention/ops/prefix_prefill.py +792 -0
  15. vllm/attention/ops/triton_flash_attention.py +810 -0
  16. vllm/attention/selector.py +91 -0
  17. vllm/block.py +84 -0
  18. vllm/config.py +1225 -0
  19. vllm/core/__init__.py +0 -0
  20. vllm/core/block/__init__.py +0 -0
  21. vllm/core/block/block_table.py +295 -0
  22. vllm/core/block/common.py +199 -0
  23. vllm/core/block/cpu_gpu_block_allocator.py +228 -0
  24. vllm/core/block/interfaces.py +205 -0
  25. vllm/core/block/naive_block.py +318 -0
  26. vllm/core/block/prefix_caching_block.py +606 -0
  27. vllm/core/block_manager_v1.py +625 -0
  28. vllm/core/block_manager_v2.py +258 -0
  29. vllm/core/evictor_v1.py +105 -0
  30. vllm/core/evictor_v2.py +127 -0
  31. vllm/core/interfaces.py +113 -0
  32. vllm/core/policy.py +45 -0
  33. vllm/core/scheduler.py +1163 -0
  34. vllm/distributed/__init__.py +3 -0
  35. vllm/distributed/communication_op.py +237 -0
  36. vllm/distributed/device_communicators/__init__.py +0 -0
  37. vllm/distributed/device_communicators/custom_all_reduce.py +274 -0
  38. vllm/distributed/device_communicators/pynccl.py +287 -0
  39. vllm/distributed/device_communicators/pynccl_utils.py +66 -0
  40. vllm/distributed/parallel_state.py +339 -0
  41. vllm/distributed/utils.py +136 -0
  42. vllm/engine/__init__.py +0 -0
  43. vllm/engine/arg_utils.py +649 -0
  44. vllm/engine/async_llm_engine.py +737 -0
  45. vllm/engine/llm_engine.py +784 -0
  46. vllm/engine/metrics.py +368 -0
  47. vllm/engine/output_processor/__init__.py +0 -0
  48. vllm/engine/output_processor/interfaces.py +76 -0
  49. vllm/engine/output_processor/multi_step.py +142 -0
  50. vllm/engine/output_processor/single_step.py +284 -0
  51. vllm/engine/output_processor/stop_checker.py +101 -0
  52. vllm/engine/output_processor/util.py +19 -0
  53. vllm/entrypoints/__init__.py +0 -0
  54. vllm/entrypoints/api_server.py +119 -0
  55. vllm/entrypoints/llm.py +259 -0
  56. vllm/entrypoints/openai/__init__.py +0 -0
  57. vllm/entrypoints/openai/api_server.py +186 -0
  58. vllm/entrypoints/openai/cli_args.py +115 -0
  59. vllm/entrypoints/openai/protocol.py +460 -0
  60. vllm/entrypoints/openai/serving_chat.py +392 -0
  61. vllm/entrypoints/openai/serving_completion.py +347 -0
  62. vllm/entrypoints/openai/serving_engine.py +234 -0
  63. vllm/envs.py +217 -0
  64. vllm/executor/__init__.py +0 -0
  65. vllm/executor/cpu_executor.py +152 -0
  66. vllm/executor/distributed_gpu_executor.py +115 -0
  67. vllm/executor/executor_base.py +115 -0
  68. vllm/executor/gpu_executor.py +150 -0
  69. vllm/executor/multiproc_worker_utils.py +263 -0
  70. vllm/executor/neuron_executor.py +91 -0
  71. vllm/executor/ray_gpu_executor.py +327 -0
  72. vllm/executor/ray_utils.py +119 -0
  73. vllm/logger.py +153 -0
  74. vllm/logging/__init__.py +5 -0
  75. vllm/logging/formatter.py +15 -0
  76. vllm/lora/__init__.py +0 -0
  77. vllm/lora/fully_sharded_layers.py +262 -0
  78. vllm/lora/layers.py +1181 -0
  79. vllm/lora/lora.py +167 -0
  80. vllm/lora/models.py +645 -0
  81. vllm/lora/punica.py +213 -0
  82. vllm/lora/request.py +32 -0
  83. vllm/lora/utils.py +98 -0
  84. vllm/lora/worker_manager.py +251 -0
  85. vllm/model_executor/__init__.py +7 -0
  86. vllm/model_executor/guided_decoding/__init__.py +25 -0
  87. vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py +70 -0
  88. vllm/model_executor/guided_decoding/outlines_decoding.py +130 -0
  89. vllm/model_executor/guided_decoding/outlines_logits_processors.py +184 -0
  90. vllm/model_executor/layers/__init__.py +0 -0
  91. vllm/model_executor/layers/activation.py +173 -0
  92. vllm/model_executor/layers/fused_moe/__init__.py +7 -0
  93. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  94. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  95. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  96. vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  97. vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  98. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  99. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  100. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  101. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  102. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  103. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  104. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  105. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +140 -0
  106. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  107. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  108. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  109. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  110. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +146 -0
  111. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  112. vllm/model_executor/layers/fused_moe/fused_moe.py +479 -0
  113. vllm/model_executor/layers/layernorm.py +71 -0
  114. vllm/model_executor/layers/linear.py +709 -0
  115. vllm/model_executor/layers/logits_processor.py +115 -0
  116. vllm/model_executor/layers/ops/__init__.py +0 -0
  117. vllm/model_executor/layers/ops/rand.py +157 -0
  118. vllm/model_executor/layers/ops/sample.py +406 -0
  119. vllm/model_executor/layers/quantization/__init__.py +35 -0
  120. vllm/model_executor/layers/quantization/aqlm.py +376 -0
  121. vllm/model_executor/layers/quantization/awq.py +175 -0
  122. vllm/model_executor/layers/quantization/base_config.py +97 -0
  123. vllm/model_executor/layers/quantization/fp8.py +265 -0
  124. vllm/model_executor/layers/quantization/gptq.py +224 -0
  125. vllm/model_executor/layers/quantization/gptq_marlin.py +438 -0
  126. vllm/model_executor/layers/quantization/marlin.py +227 -0
  127. vllm/model_executor/layers/quantization/schema.py +84 -0
  128. vllm/model_executor/layers/quantization/squeezellm.py +137 -0
  129. vllm/model_executor/layers/rejection_sampler.py +405 -0
  130. vllm/model_executor/layers/rotary_embedding.py +525 -0
  131. vllm/model_executor/layers/sampler.py +1051 -0
  132. vllm/model_executor/layers/vocab_parallel_embedding.py +155 -0
  133. vllm/model_executor/model_loader/__init__.py +30 -0
  134. vllm/model_executor/model_loader/loader.py +362 -0
  135. vllm/model_executor/model_loader/neuron.py +136 -0
  136. vllm/model_executor/model_loader/tensorizer.py +368 -0
  137. vllm/model_executor/model_loader/utils.py +41 -0
  138. vllm/model_executor/model_loader/weight_utils.py +372 -0
  139. vllm/model_executor/models/__init__.py +119 -0
  140. vllm/model_executor/models/baichuan.py +410 -0
  141. vllm/model_executor/models/bloom.py +327 -0
  142. vllm/model_executor/models/chatglm.py +386 -0
  143. vllm/model_executor/models/commandr.py +373 -0
  144. vllm/model_executor/models/dbrx.py +413 -0
  145. vllm/model_executor/models/decilm.py +122 -0
  146. vllm/model_executor/models/deepseek.py +438 -0
  147. vllm/model_executor/models/falcon.py +444 -0
  148. vllm/model_executor/models/gemma.py +393 -0
  149. vllm/model_executor/models/gpt2.py +266 -0
  150. vllm/model_executor/models/gpt_bigcode.py +274 -0
  151. vllm/model_executor/models/gpt_j.py +281 -0
  152. vllm/model_executor/models/gpt_neox.py +295 -0
  153. vllm/model_executor/models/internlm2.py +323 -0
  154. vllm/model_executor/models/jais.py +333 -0
  155. vllm/model_executor/models/llama.py +442 -0
  156. vllm/model_executor/models/llava.py +239 -0
  157. vllm/model_executor/models/minicpm.py +531 -0
  158. vllm/model_executor/models/mixtral.py +583 -0
  159. vllm/model_executor/models/mixtral_quant.py +404 -0
  160. vllm/model_executor/models/mpt.py +295 -0
  161. vllm/model_executor/models/olmo.py +356 -0
  162. vllm/model_executor/models/opt.py +349 -0
  163. vllm/model_executor/models/orion.py +319 -0
  164. vllm/model_executor/models/phi.py +300 -0
  165. vllm/model_executor/models/qwen.py +284 -0
  166. vllm/model_executor/models/qwen2.py +367 -0
  167. vllm/model_executor/models/qwen2_moe.py +447 -0
  168. vllm/model_executor/models/stablelm.py +301 -0
  169. vllm/model_executor/models/starcoder2.py +302 -0
  170. vllm/model_executor/models/xverse.py +366 -0
  171. vllm/model_executor/sampling_metadata.py +588 -0
  172. vllm/model_executor/utils.py +35 -0
  173. vllm/outputs.py +150 -0
  174. vllm/py.typed +2 -0
  175. vllm/sampling_params.py +340 -0
  176. vllm/sequence.py +766 -0
  177. vllm/spec_decode/__init__.py +0 -0
  178. vllm/spec_decode/batch_expansion.py +397 -0
  179. vllm/spec_decode/interfaces.py +73 -0
  180. vllm/spec_decode/metrics.py +191 -0
  181. vllm/spec_decode/multi_step_worker.py +203 -0
  182. vllm/spec_decode/ngram_worker.py +176 -0
  183. vllm/spec_decode/spec_decode_worker.py +472 -0
  184. vllm/spec_decode/top1_proposer.py +200 -0
  185. vllm/spec_decode/util.py +228 -0
  186. vllm/test_utils.py +41 -0
  187. vllm/transformers_utils/__init__.py +0 -0
  188. vllm/transformers_utils/config.py +58 -0
  189. vllm/transformers_utils/configs/__init__.py +16 -0
  190. vllm/transformers_utils/configs/chatglm.py +68 -0
  191. vllm/transformers_utils/configs/dbrx.py +278 -0
  192. vllm/transformers_utils/configs/falcon.py +87 -0
  193. vllm/transformers_utils/configs/jais.py +236 -0
  194. vllm/transformers_utils/configs/mpt.py +178 -0
  195. vllm/transformers_utils/detokenizer.py +313 -0
  196. vllm/transformers_utils/tokenizer.py +149 -0
  197. vllm/transformers_utils/tokenizer_group/__init__.py +33 -0
  198. vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py +55 -0
  199. vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py +169 -0
  200. vllm/transformers_utils/tokenizer_group/tokenizer_group.py +78 -0
  201. vllm/transformers_utils/tokenizers/__init__.py +5 -0
  202. vllm/transformers_utils/tokenizers/baichuan.py +255 -0
  203. vllm/usage/__init__.py +0 -0
  204. vllm/usage/usage_lib.py +209 -0
  205. vllm/utils.py +677 -0
  206. vllm/worker/__init__.py +0 -0
  207. vllm/worker/cache_engine.py +105 -0
  208. vllm/worker/cpu_model_runner.py +346 -0
  209. vllm/worker/cpu_worker.py +321 -0
  210. vllm/worker/model_runner.py +1168 -0
  211. vllm/worker/neuron_model_runner.py +196 -0
  212. vllm/worker/neuron_worker.py +98 -0
  213. vllm/worker/worker.py +345 -0
  214. vllm/worker/worker_base.py +146 -0
  215. vllm_npu-0.4.2.dist-info/LICENSE +201 -0
  216. vllm_npu-0.4.2.dist-info/METADATA +173 -0
  217. vllm_npu-0.4.2.dist-info/RECORD +219 -0
  218. vllm_npu-0.4.2.dist-info/WHEEL +5 -0
  219. vllm_npu-0.4.2.dist-info/top_level.txt +1 -0
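The wheel appears to repackage the vLLM 0.4.2 sources under the `vllm` namespace (see the `vllm_npu-0.4.2.dist-info` entries above), so it is driven through the usual vLLM Python entrypoint. A minimal usage sketch, assuming the standard `vllm.LLM` API; the model name and sampling values are illustrative only:

# Minimal sketch of driving the packaged engine through vllm.LLM.
# Model name and sampling settings are placeholders, not recommendations.
from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Llama-2-7b-hf")              # any supported HF model
params = SamplingParams(temperature=0.8, max_tokens=64)
outputs = llm.generate(["Hello, my name is"], params)
for out in outputs:
    print(out.outputs[0].text)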
vllm/model_executor/models/llama.py
@@ -0,0 +1,442 @@
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers import LlamaConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.config import LoRAConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, kv_cache_scales_loader)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
from vllm.utils import is_hip


class LlamaMLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config)
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           quant_config=quant_config)
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class LlamaAttention(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        sliding_window: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        # This will be overwritten by model initialization if we are using it.
        # N.B. currently we only support per tensor scalar scaling factors
        # & only applicable to ROCm (AMD GPU).
        # The scaling factor convention we are assuming is
        # quantized_value * scaling_factor ~= true_value
        # which is consistent with the practice of setting
        # scaling_factor = tensor_amax / FPtype_max
        self.kv_scale = 1.0

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=bias,
            quant_config=quant_config,
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              sliding_window=sliding_window)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata,
                                self.kv_scale)
        output, _ = self.o_proj(attn_output)
        return output


class LlamaDecoderLayer(nn.Module):

    def __init__(
        self,
        config: LlamaConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is not None and getattr(
                config, "original_max_position_embeddings", None):
            rope_scaling["original_max_position_embeddings"] = (
                config.original_max_position_embeddings)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        sliding_window = getattr(config, "sliding_window", None)
        # Support abacusai/Smaug-72B-v0.1 with attention_bias
        # Support internlm/internlm-7b with bias
        attention_bias = getattr(config, "attention_bias", False) or getattr(
            config, "bias", False)
        self.self_attn = LlamaAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=getattr(config, "num_key_value_heads",
                                 config.num_attention_heads),
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=attention_bias,
            sliding_window=sliding_window,
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


class LlamaModel(nn.Module):

    def __init__(
        self,
        config: LlamaConfig,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1)) if lora_config else 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )
        self.layers = nn.ModuleList([
            LlamaDecoderLayer(config, quant_config)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.get_input_embeddings(input_ids)
        residual = None
        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i],
                attn_metadata,
                residual,
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class LlamaForCausalLM(nn.Module):
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "gate_up_proj",
        "down_proj",
        "embed_tokens",
        "lm_head",
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    def __init__(
        self,
        config: LlamaConfig,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.model = LlamaModel(config, quant_config, lora_config=lora_config)
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE
            # We need bigger padding if using lora for kernel
            # compatibility
            if not lora_config else lora_config.lora_vocab_padding_size,
        )

        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size, logit_scale)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head.weight, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)

    # If this function is called, it should always initialize KV cache scale
    # factors (or else raise an exception). Thus, handled exceptions should
    # make sure to leave KV cache scale factors in a known good (dummy) state
    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
        tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()
        for layer_idx, scaling_factor in kv_cache_scales_loader(
                quantization_param_path, tp_rank, tp_size,
                self.config.num_hidden_layers,
                self.config.__class__.model_type):
            layer_self_attn = self.model.layers[layer_idx].self_attn

            if is_hip():
                # The scaling factor convention we are assuming is
                # quantized_value * scaling_factor ~= true_value
                # which is consistent with the practice of setting
                # scaling_factor = tensor_amax / FPtype_max
                scaling_factor *= 2
            if hasattr(layer_self_attn, "kv_scale"):
                layer_self_attn.kv_scale = scaling_factor
            else:
                raise RuntimeError("Self attention has no KV cache scaling "
                                   "factor attribute!")
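The partition-or-replicate branch in `LlamaAttention.__init__` above decides how many query and KV heads each tensor-parallel rank ends up with. A standalone sketch of that arithmetic for reference; the helper name and example sizes are illustrative and not part of the wheel:

# Illustrative helper mirroring the head-splitting logic in
# LlamaAttention.__init__; not shipped in this package.
def heads_per_rank(total_num_heads: int, total_num_kv_heads: int,
                   tp_size: int) -> tuple:
    assert total_num_heads % tp_size == 0
    num_heads = total_num_heads // tp_size
    if total_num_kv_heads >= tp_size:
        # More KV heads than ranks: partition them across ranks.
        assert total_num_kv_heads % tp_size == 0
    else:
        # Fewer KV heads than ranks: replicate so every rank holds one.
        assert tp_size % total_num_kv_heads == 0
    num_kv_heads = max(1, total_num_kv_heads // tp_size)
    return num_heads, num_kv_heads

print(heads_per_rank(64, 8, 8))  # (8, 1): one KV head per rank (GQA)
print(heads_per_rank(32, 8, 4))  # (8, 2): KV heads partitioned
print(heads_per_rank(32, 2, 8))  # (4, 1): KV heads replicated across ranks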
vllm/model_executor/models/llava.py
@@ -0,0 +1,239 @@
from typing import Iterable, List, Optional, Tuple

import torch
from torch import nn
# TODO(xwjiang): We should port CLIPVisionModel's code over to not depend on
# transformers' impl.
from transformers import CLIPVisionModel, LlavaConfig

from vllm.attention import AttentionMetadata
from vllm.config import VisionLanguageConfig
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput

_KEYS_TO_MODIFY_MAPPING = {
    "language_model.lm_head": "lm_head",
    "language_model.model": "language_model",
}


# TODO(xwjiang): Run benchmark and decide if TP.
class LlavaMultiModalProjector(nn.Module):

    def __init__(self, vision_hidden_size: int, text_hidden_size: int,
                 projector_hidden_act: str):
        super().__init__()

        self.linear_1 = nn.Linear(vision_hidden_size,
                                  text_hidden_size,
                                  bias=True)
        self.act = get_act_fn(projector_hidden_act)
        self.linear_2 = nn.Linear(text_hidden_size,
                                  text_hidden_size,
                                  bias=True)

    def forward(self, image_features):
        hidden_states = self.linear_1(image_features)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states


def _merge_vision_embeddings(input_ids: torch.Tensor,
                             inputs_embeds: torch.Tensor,
                             vision_embeddings: torch.Tensor,
                             image_token_id: int):
    """In place merges in vision_embeddings with inputs_embeds."""
    mask = (input_ids == image_token_id)
    inputs_embeds[mask] = vision_embeddings.view(-1,
                                                 vision_embeddings.shape[-1])


class LlavaForConditionalGeneration(nn.Module):

    def __init__(self,
                 config: "LlavaConfig",
                 vision_language_config: VisionLanguageConfig,
                 quant_config: Optional["QuantizationConfig"] = None) -> None:
        super().__init__()
        self.config = config

        self.vision_language_config = vision_language_config

        assert self.vision_language_config, (
            "Provide `image_input_type` and other vision "
            "related configurations through LLM entrypoint "
            "or engine arguments.")

        if self.vision_language_config.image_input_type == (
                VisionLanguageConfig.ImageInputType.PIXEL_VALUES):
            self.vision_tower = CLIPVisionModel(config.vision_config)
        else:
            self.vision_tower = None

        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act)

        self.quant_config = quant_config
        self.language_model = LlamaModel(config.text_config, quant_config)
        self.unpadded_vocab_size = config.text_config.vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.text_config.hidden_size,
            org_num_embeddings=self.language_model.org_vocab_size)
        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size, logit_scale)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        image_input: Optional[torch.Tensor] = None
    ) -> SamplerOutput:  # noqa: E501
        """Run forward pass for Llava 1.5.

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted image embeddings.
        Concretely, consider a text prompt:
        "<image>\nUSER: What's the content of the image?\nASSISTANT:".
        Tokenizer outputs:
        [1, 32000, 29871, 13, 11889, 29901, 1724, 29915, 29879, 278,
        2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901].
        The to-be-inserted image has a size of 576 (24 * 24) along the context
        length dimension.
        `input_ids` is thus [1, 32000, ..., 32000, 29871, 13, 11889, 29901,
        1724, 29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933,
        9047, 13566, 29901].
        There will be 576 `32000` in the `input_ids`.
        (32000 is the token id for `<image>`.)

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        The model takes two types of image inputs:
        PIXEL_VALUES and IMAGE_FEATURES.
        The following shows how each maps to huggingface implementation.
        PIXEL_VALUES:
        - https://github.com/huggingface/transformers/blob/07bdbeb/src/transformers/models/llava/modeling_llava.py#L353
        IMAGE_FEATURES:
        - https://github.com/huggingface/transformers/blob/07bdbeb/src/transformers/models/llava/modeling_llava.py#L430
        before going through the multi modal projector.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            image_input: A batch of image inputs.
                For PIXEL_VALUES, expecting [1, 3, 336, 336].
                For IMAGE_FEATURES, expecting [1, 576, 1024].
        """
        if image_input is not None:
            if list(image_input.shape[1:]) != list(
                    self.vision_language_config.image_input_shape[1:]):
                raise ValueError(
                    f"The expected image tensor shape is batch dimension "
                    f"plus "
                    f"{self.vision_language_config.image_input_shape[1:]}."
                    f" You supplied {image_input.shape}. "
                    f"If you are using vLLM's entrypoint, make sure your "
                    f"supplied image input is consistent with "
                    f"image_input_shape in engine args.")
            if self.vision_tower is not None:
                # TODO(xwjiang): Maybe port minimal CLIPVisionModel over.
                image_outputs = self.vision_tower(image_input,
                                                  output_hidden_states=True)
                image_features = image_outputs.hidden_states[
                    self.config.vision_feature_layer]
                # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
                if self.config.vision_feature_select_strategy == "default":
                    image_features = image_features[:, 1:]
                elif self.config.vision_feature_select_strategy == "full":
                    image_features = image_features
                else:
                    raise ValueError(
                        f"Unexpected select feature strategy: "
                        f"{self.config.vision_feature_select_strategy}")
            else:
                image_features = image_input
            vision_embeddings = self.multi_modal_projector(image_features)
            inputs_embeds = self.language_model.get_input_embeddings(input_ids)
            _merge_vision_embeddings(
                input_ids, inputs_embeds, vision_embeddings,
                self.vision_language_config.image_token_id)
            input_ids = None
        else:
            inputs_embeds = None
        hidden_states = self.language_model(input_ids,
                                            positions,
                                            kv_caches,
                                            attn_metadata,
                                            inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head.weight, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        # only doing this for language model part for now.
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
                if key_to_modify in name:
                    name = name.replace(key_to_modify, new_key)
            use_default_weight_loading = False
            if "vision" in name:
                if self.vision_tower is not None:
                    # We only do sharding for language model and
                    # not vision model for now.
                    use_default_weight_loading = True
            else:
                for (param_name, weight_name,
                     shard_id) in stacked_params_mapping:
                    if weight_name not in name:
                        continue
                    param = params_dict[name.replace(weight_name, param_name)]
                    weight_loader = param.weight_loader
                    weight_loader(param, loaded_weight, shard_id)
                    break
                else:
                    use_default_weight_loading = True
            if use_default_weight_loading:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
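For reference, a toy demonstration of the in-place merge that `_merge_vision_embeddings` performs on the input embeddings; token id 32000 stands in for the `<image>` placeholder as in the forward docstring, and the tensor shapes are illustrative only:

# Illustrative only: reproduces the masked assignment done by
# _merge_vision_embeddings on toy tensors.
import torch

IMAGE_TOKEN_ID = 32000                                    # <image> placeholder id
input_ids = torch.tensor([1, 32000, 32000, 32000, 13])    # three image slots
inputs_embeds = torch.zeros(5, 4)                         # [seq_len, hidden_size]
vision_embeddings = torch.ones(1, 3, 4)                   # [batch, image_tokens, hidden]

mask = (input_ids == IMAGE_TOKEN_ID)
inputs_embeds[mask] = vision_embeddings.view(-1, vision_embeddings.shape[-1])
print(inputs_embeds[:, 0])                                # tensor([0., 1., 1., 1., 0.])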