vllm-npu 0.4.2__py3-none-any.whl

Files changed (219)
  1. vllm/__init__.py +23 -0
  2. vllm/_custom_ops.py +251 -0
  3. vllm/attention/__init__.py +13 -0
  4. vllm/attention/backends/__init__.py +0 -0
  5. vllm/attention/backends/abstract.py +127 -0
  6. vllm/attention/backends/flash_attn.py +271 -0
  7. vllm/attention/backends/flashinfer.py +220 -0
  8. vllm/attention/backends/rocm_flash_attn.py +374 -0
  9. vllm/attention/backends/torch_sdpa.py +250 -0
  10. vllm/attention/backends/xformers.py +393 -0
  11. vllm/attention/layer.py +56 -0
  12. vllm/attention/ops/__init__.py +0 -0
  13. vllm/attention/ops/paged_attn.py +216 -0
  14. vllm/attention/ops/prefix_prefill.py +792 -0
  15. vllm/attention/ops/triton_flash_attention.py +810 -0
  16. vllm/attention/selector.py +91 -0
  17. vllm/block.py +84 -0
  18. vllm/config.py +1225 -0
  19. vllm/core/__init__.py +0 -0
  20. vllm/core/block/__init__.py +0 -0
  21. vllm/core/block/block_table.py +295 -0
  22. vllm/core/block/common.py +199 -0
  23. vllm/core/block/cpu_gpu_block_allocator.py +228 -0
  24. vllm/core/block/interfaces.py +205 -0
  25. vllm/core/block/naive_block.py +318 -0
  26. vllm/core/block/prefix_caching_block.py +606 -0
  27. vllm/core/block_manager_v1.py +625 -0
  28. vllm/core/block_manager_v2.py +258 -0
  29. vllm/core/evictor_v1.py +105 -0
  30. vllm/core/evictor_v2.py +127 -0
  31. vllm/core/interfaces.py +113 -0
  32. vllm/core/policy.py +45 -0
  33. vllm/core/scheduler.py +1163 -0
  34. vllm/distributed/__init__.py +3 -0
  35. vllm/distributed/communication_op.py +237 -0
  36. vllm/distributed/device_communicators/__init__.py +0 -0
  37. vllm/distributed/device_communicators/custom_all_reduce.py +274 -0
  38. vllm/distributed/device_communicators/pynccl.py +287 -0
  39. vllm/distributed/device_communicators/pynccl_utils.py +66 -0
  40. vllm/distributed/parallel_state.py +339 -0
  41. vllm/distributed/utils.py +136 -0
  42. vllm/engine/__init__.py +0 -0
  43. vllm/engine/arg_utils.py +649 -0
  44. vllm/engine/async_llm_engine.py +737 -0
  45. vllm/engine/llm_engine.py +784 -0
  46. vllm/engine/metrics.py +368 -0
  47. vllm/engine/output_processor/__init__.py +0 -0
  48. vllm/engine/output_processor/interfaces.py +76 -0
  49. vllm/engine/output_processor/multi_step.py +142 -0
  50. vllm/engine/output_processor/single_step.py +284 -0
  51. vllm/engine/output_processor/stop_checker.py +101 -0
  52. vllm/engine/output_processor/util.py +19 -0
  53. vllm/entrypoints/__init__.py +0 -0
  54. vllm/entrypoints/api_server.py +119 -0
  55. vllm/entrypoints/llm.py +259 -0
  56. vllm/entrypoints/openai/__init__.py +0 -0
  57. vllm/entrypoints/openai/api_server.py +186 -0
  58. vllm/entrypoints/openai/cli_args.py +115 -0
  59. vllm/entrypoints/openai/protocol.py +460 -0
  60. vllm/entrypoints/openai/serving_chat.py +392 -0
  61. vllm/entrypoints/openai/serving_completion.py +347 -0
  62. vllm/entrypoints/openai/serving_engine.py +234 -0
  63. vllm/envs.py +217 -0
  64. vllm/executor/__init__.py +0 -0
  65. vllm/executor/cpu_executor.py +152 -0
  66. vllm/executor/distributed_gpu_executor.py +115 -0
  67. vllm/executor/executor_base.py +115 -0
  68. vllm/executor/gpu_executor.py +150 -0
  69. vllm/executor/multiproc_worker_utils.py +263 -0
  70. vllm/executor/neuron_executor.py +91 -0
  71. vllm/executor/ray_gpu_executor.py +327 -0
  72. vllm/executor/ray_utils.py +119 -0
  73. vllm/logger.py +153 -0
  74. vllm/logging/__init__.py +5 -0
  75. vllm/logging/formatter.py +15 -0
  76. vllm/lora/__init__.py +0 -0
  77. vllm/lora/fully_sharded_layers.py +262 -0
  78. vllm/lora/layers.py +1181 -0
  79. vllm/lora/lora.py +167 -0
  80. vllm/lora/models.py +645 -0
  81. vllm/lora/punica.py +213 -0
  82. vllm/lora/request.py +32 -0
  83. vllm/lora/utils.py +98 -0
  84. vllm/lora/worker_manager.py +251 -0
  85. vllm/model_executor/__init__.py +7 -0
  86. vllm/model_executor/guided_decoding/__init__.py +25 -0
  87. vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py +70 -0
  88. vllm/model_executor/guided_decoding/outlines_decoding.py +130 -0
  89. vllm/model_executor/guided_decoding/outlines_logits_processors.py +184 -0
  90. vllm/model_executor/layers/__init__.py +0 -0
  91. vllm/model_executor/layers/activation.py +173 -0
  92. vllm/model_executor/layers/fused_moe/__init__.py +7 -0
  93. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  94. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  95. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  96. vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  97. vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  98. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  99. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  100. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  101. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  102. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  103. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  104. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  105. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +140 -0
  106. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  107. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  108. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  109. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  110. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +146 -0
  111. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  112. vllm/model_executor/layers/fused_moe/fused_moe.py +479 -0
  113. vllm/model_executor/layers/layernorm.py +71 -0
  114. vllm/model_executor/layers/linear.py +709 -0
  115. vllm/model_executor/layers/logits_processor.py +115 -0
  116. vllm/model_executor/layers/ops/__init__.py +0 -0
  117. vllm/model_executor/layers/ops/rand.py +157 -0
  118. vllm/model_executor/layers/ops/sample.py +406 -0
  119. vllm/model_executor/layers/quantization/__init__.py +35 -0
  120. vllm/model_executor/layers/quantization/aqlm.py +376 -0
  121. vllm/model_executor/layers/quantization/awq.py +175 -0
  122. vllm/model_executor/layers/quantization/base_config.py +97 -0
  123. vllm/model_executor/layers/quantization/fp8.py +265 -0
  124. vllm/model_executor/layers/quantization/gptq.py +224 -0
  125. vllm/model_executor/layers/quantization/gptq_marlin.py +438 -0
  126. vllm/model_executor/layers/quantization/marlin.py +227 -0
  127. vllm/model_executor/layers/quantization/schema.py +84 -0
  128. vllm/model_executor/layers/quantization/squeezellm.py +137 -0
  129. vllm/model_executor/layers/rejection_sampler.py +405 -0
  130. vllm/model_executor/layers/rotary_embedding.py +525 -0
  131. vllm/model_executor/layers/sampler.py +1051 -0
  132. vllm/model_executor/layers/vocab_parallel_embedding.py +155 -0
  133. vllm/model_executor/model_loader/__init__.py +30 -0
  134. vllm/model_executor/model_loader/loader.py +362 -0
  135. vllm/model_executor/model_loader/neuron.py +136 -0
  136. vllm/model_executor/model_loader/tensorizer.py +368 -0
  137. vllm/model_executor/model_loader/utils.py +41 -0
  138. vllm/model_executor/model_loader/weight_utils.py +372 -0
  139. vllm/model_executor/models/__init__.py +119 -0
  140. vllm/model_executor/models/baichuan.py +410 -0
  141. vllm/model_executor/models/bloom.py +327 -0
  142. vllm/model_executor/models/chatglm.py +386 -0
  143. vllm/model_executor/models/commandr.py +373 -0
  144. vllm/model_executor/models/dbrx.py +413 -0
  145. vllm/model_executor/models/decilm.py +122 -0
  146. vllm/model_executor/models/deepseek.py +438 -0
  147. vllm/model_executor/models/falcon.py +444 -0
  148. vllm/model_executor/models/gemma.py +393 -0
  149. vllm/model_executor/models/gpt2.py +266 -0
  150. vllm/model_executor/models/gpt_bigcode.py +274 -0
  151. vllm/model_executor/models/gpt_j.py +281 -0
  152. vllm/model_executor/models/gpt_neox.py +295 -0
  153. vllm/model_executor/models/internlm2.py +323 -0
  154. vllm/model_executor/models/jais.py +333 -0
  155. vllm/model_executor/models/llama.py +442 -0
  156. vllm/model_executor/models/llava.py +239 -0
  157. vllm/model_executor/models/minicpm.py +531 -0
  158. vllm/model_executor/models/mixtral.py +583 -0
  159. vllm/model_executor/models/mixtral_quant.py +404 -0
  160. vllm/model_executor/models/mpt.py +295 -0
  161. vllm/model_executor/models/olmo.py +356 -0
  162. vllm/model_executor/models/opt.py +349 -0
  163. vllm/model_executor/models/orion.py +319 -0
  164. vllm/model_executor/models/phi.py +300 -0
  165. vllm/model_executor/models/qwen.py +284 -0
  166. vllm/model_executor/models/qwen2.py +367 -0
  167. vllm/model_executor/models/qwen2_moe.py +447 -0
  168. vllm/model_executor/models/stablelm.py +301 -0
  169. vllm/model_executor/models/starcoder2.py +302 -0
  170. vllm/model_executor/models/xverse.py +366 -0
  171. vllm/model_executor/sampling_metadata.py +588 -0
  172. vllm/model_executor/utils.py +35 -0
  173. vllm/outputs.py +150 -0
  174. vllm/py.typed +2 -0
  175. vllm/sampling_params.py +340 -0
  176. vllm/sequence.py +766 -0
  177. vllm/spec_decode/__init__.py +0 -0
  178. vllm/spec_decode/batch_expansion.py +397 -0
  179. vllm/spec_decode/interfaces.py +73 -0
  180. vllm/spec_decode/metrics.py +191 -0
  181. vllm/spec_decode/multi_step_worker.py +203 -0
  182. vllm/spec_decode/ngram_worker.py +176 -0
  183. vllm/spec_decode/spec_decode_worker.py +472 -0
  184. vllm/spec_decode/top1_proposer.py +200 -0
  185. vllm/spec_decode/util.py +228 -0
  186. vllm/test_utils.py +41 -0
  187. vllm/transformers_utils/__init__.py +0 -0
  188. vllm/transformers_utils/config.py +58 -0
  189. vllm/transformers_utils/configs/__init__.py +16 -0
  190. vllm/transformers_utils/configs/chatglm.py +68 -0
  191. vllm/transformers_utils/configs/dbrx.py +278 -0
  192. vllm/transformers_utils/configs/falcon.py +87 -0
  193. vllm/transformers_utils/configs/jais.py +236 -0
  194. vllm/transformers_utils/configs/mpt.py +178 -0
  195. vllm/transformers_utils/detokenizer.py +313 -0
  196. vllm/transformers_utils/tokenizer.py +149 -0
  197. vllm/transformers_utils/tokenizer_group/__init__.py +33 -0
  198. vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py +55 -0
  199. vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py +169 -0
  200. vllm/transformers_utils/tokenizer_group/tokenizer_group.py +78 -0
  201. vllm/transformers_utils/tokenizers/__init__.py +5 -0
  202. vllm/transformers_utils/tokenizers/baichuan.py +255 -0
  203. vllm/usage/__init__.py +0 -0
  204. vllm/usage/usage_lib.py +209 -0
  205. vllm/utils.py +677 -0
  206. vllm/worker/__init__.py +0 -0
  207. vllm/worker/cache_engine.py +105 -0
  208. vllm/worker/cpu_model_runner.py +346 -0
  209. vllm/worker/cpu_worker.py +321 -0
  210. vllm/worker/model_runner.py +1168 -0
  211. vllm/worker/neuron_model_runner.py +196 -0
  212. vllm/worker/neuron_worker.py +98 -0
  213. vllm/worker/worker.py +345 -0
  214. vllm/worker/worker_base.py +146 -0
  215. vllm_npu-0.4.2.dist-info/LICENSE +201 -0
  216. vllm_npu-0.4.2.dist-info/METADATA +173 -0
  217. vllm_npu-0.4.2.dist-info/RECORD +219 -0
  218. vllm_npu-0.4.2.dist-info/WHEEL +5 -0
  219. vllm_npu-0.4.2.dist-info/top_level.txt +1 -0
vllm/model_executor/layers/rotary_embedding.py
@@ -0,0 +1,525 @@
+# coding=utf-8
+# Adapted from
+# https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/llama/modeling_llama.py
+# Copyright 2023 The vLLM team.
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Rotary Positional Embeddings."""
+import math
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from vllm import _custom_ops as ops
+
+
+def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
+    x1 = x[..., :x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2:]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
+    x1 = x[..., ::2]
+    x2 = x[..., 1::2]
+    x = torch.stack((-x2, x1), dim=-1)
+    return x.flatten(-2)
+
+
+class RotaryEmbedding(nn.Module):
+    """Original rotary positional embedding."""
+
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: int,
+        is_neox_style: bool,
+    ) -> None:
+        super().__init__()
+        self.head_size = head_size
+        self.rotary_dim = rotary_dim
+        self.max_position_embeddings = max_position_embeddings
+        self.base = base
+        self.is_neox_style = is_neox_style
+
+        cache = self._compute_cos_sin_cache()
+        cache = cache.to(torch.get_default_dtype())
+        self.register_buffer("cos_sin_cache", cache, persistent=False)
+
+    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
+        """Compute the inverse frequency."""
+        # NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`.
+        # However, we use `torch.arange(..., dtype=torch.float)` instead to
+        # avoid numerical issues with large base values (e.g., 10000000).
+        # This may cause a slight numerical difference between the HF
+        # implementation and ours.
+        # NOTE(woosuk): To exactly match the HF implementation, we need to
+        # use CPU to compute the cache and then move it to GPU. However, we
+        # create the cache on GPU for faster initialization. This may cause
+        # a slight numerical difference between the HF implementation and ours.
+        inv_freq = 1.0 / (base**(torch.arange(
+            0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))
+        return inv_freq
+
+    def _compute_cos_sin_cache(self) -> torch.Tensor:
+        """Compute the cos and sin cache."""
+        inv_freq = self._compute_inv_freq(self.base)
+        t = torch.arange(self.max_position_embeddings, dtype=torch.float)
+
+        freqs = torch.einsum("i,j -> ij", t, inv_freq)
+        cos = freqs.cos()
+        sin = freqs.sin()
+        cache = torch.cat((cos, sin), dim=-1)
+        return cache
+
+    def _forward(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """PyTorch-native implementation equivalent to forward()."""
+        query = query.view(*query.shape[:-1], -1, self.head_size)
+        key = key.view(*key.shape[:-1], -1, self.head_size)
+
+        query_rot = query[..., :self.rotary_dim]
+        key_rot = key[..., :self.rotary_dim]
+        if self.rotary_dim < self.head_size:
+            query_pass = query[..., self.rotary_dim:]
+            key_pass = key[..., self.rotary_dim:]
+
+        self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(
+            positions.device)
+        cos_sin = self.cos_sin_cache[torch.add(positions, offsets)
+                                     if offsets is not None else positions]
+        cos, sin = cos_sin.chunk(2, dim=-1)
+        if self.is_neox_style:
+            # NOTE(woosuk): Here we assume that the positions tensor has the
+            # shape [batch_size, seq_len].
+            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
+            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
+        else:
+            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
+            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
+
+        rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
+        query_rot = query_rot * cos + rotate_fn(query_rot) * sin
+        key_rot = key_rot * cos + rotate_fn(key_rot) * sin
+
+        if self.rotary_dim < self.head_size:
+            query = torch.cat((query_rot, query_pass), dim=-1)
+            key = torch.cat((key_rot, key_pass), dim=-1)
+        else:
+            query = query_rot
+            key = key_rot
+        query = query.flatten(-2)
+        key = key.flatten(-2)
+        return query, key
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        self.cos_sin_cache = self.cos_sin_cache.to(positions.device)
+        # ops.rotary_embedding()/batched_rotary_embedding()
+        # are in-place operations that update the query and key tensors.
+        if offsets is not None:
+            ops.batched_rotary_embedding(positions, query, key, self.head_size,
+                                         self.cos_sin_cache,
+                                         self.is_neox_style, self.rotary_dim,
+                                         offsets)
+        else:
+            ops.rotary_embedding(positions, query, key, self.head_size,
+                                 self.cos_sin_cache, self.is_neox_style)
+        return query, key
+
+    def extra_repr(self) -> str:
+        s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
+        s += f", max_position_embeddings={self.max_position_embeddings}"
+        s += f", base={self.base}, is_neox_style={self.is_neox_style}"
+        return s
+
+
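To illustrate the class above, here is a minimal sketch (not part of the packaged file) that exercises the PyTorch-native `_forward` path on toy tensors; the batched [batch_size, seq_len] position shape follows the NOTE in the neox branch, and the fused `ops` kernel is deliberately bypassed:

import torch

from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding

# Hypothetical toy sizes: 4 heads of size 64, full-width rotation.
head_size, num_heads, seq_len = 64, 4, 8
rope = RotaryEmbedding(head_size, head_size,
                       max_position_embeddings=2048,
                       base=10000,
                       is_neox_style=True)

positions = torch.arange(seq_len).unsqueeze(0)          # [1, seq_len]
query = torch.randn(1, seq_len, num_heads * head_size)  # heads flattened
key = torch.randn(1, seq_len, num_heads * head_size)

# _forward() is the pure-PyTorch reference; forward() dispatches to the
# in-place fused kernel instead.
q_out, k_out = rope._forward(positions, query, key)
assert q_out.shape == query.shape and k_out.shape == key.shape
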
+class LinearScalingRotaryEmbedding(RotaryEmbedding):
+    """RotaryEmbedding extended with linear scaling.
+
+    Credits to the Reddit user /u/kaiokendev
+    """
+
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: int,
+        is_neox_style: bool,
+        scaling_factors: Union[List[float], float],
+    ) -> None:
+        if isinstance(scaling_factors, float):
+            scaling_factors = [scaling_factors]
+        self.scaling_factors = scaling_factors
+        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
+                         is_neox_style)
+
+    def _compute_cos_sin_cache(self) -> torch.Tensor:
+        inv_freq = self._compute_inv_freq(self.base)
+        cache_list = []
+        for scaling_factor in self.scaling_factors:
+            # NOTE(woosuk): self.max_position_embeddings is the original
+            # maximum length before applying the rope scaling.
+            # Thus, the maximum length after applying the rope scaling is
+            # self.max_position_embeddings * self.scaling_factor.
+            max_len = self.max_position_embeddings * scaling_factor
+            t = torch.arange(max_len, dtype=torch.float)
+            t = t / scaling_factor
+
+            freqs = torch.einsum("i,j -> ij", t, inv_freq)
+            cos = freqs.cos()
+            sin = freqs.sin()
+            cache = torch.cat((cos, sin), dim=-1)
+            cache_list.append(cache)
+        return torch.cat(cache_list, dim=0)
+
+
+class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
+    """RotaryEmbedding extended with Dynamic NTK scaling.
+
+    Credits to the Reddit users /u/bloc97 and /u/emozilla
+    """
+
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: int,
+        is_neox_style: bool,
+        scaling_factor: float,
+    ) -> None:
+        self.scaling_factor = scaling_factor
+        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
+                         is_neox_style)
+
+    def _compute_cos_sin_cache(self) -> torch.Tensor:
+        # NOTE(woosuk): self.max_position_embeddings is the original
+        # maximum length before applying the rope scaling.
+        # Thus, the maximum length after applying the rope scaling is
+        # self.max_position_embeddings * self.scaling_factor.
+        max_len = self.max_position_embeddings * self.scaling_factor
+        base = self.base * (
+            (self.scaling_factor * max_len / self.max_position_embeddings) -
+            (self.scaling_factor - 1))**(self.rotary_dim /
+                                         (self.rotary_dim - 2))
+        inv_freq = self._compute_inv_freq(base)
+        t = torch.arange(max_len, dtype=torch.float)
+
+        freqs = torch.einsum("i,j -> ij", t, inv_freq)
+        cos = freqs.cos()
+        sin = freqs.sin()
+        cache = torch.cat((cos, sin), dim=-1)
+        return cache
+
+
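As a quick sanity check on the dynamic-NTK formula above (again a sketch, not part of the packaged file), the rescaled base can be worked out by hand for hypothetical values rotary_dim=128, base=10000, scaling_factor=2:

base, rotary_dim, max_pos, s = 10000.0, 128, 2048, 2.0
max_len = max_pos * s  # 4096 positions after scaling
# (s * max_len / max_pos) - (s - 1) reduces to s**2 - s + 1 = 3.0 here.
new_base = base * ((s * max_len / max_pos) -
                   (s - 1))**(rotary_dim / (rotary_dim - 2))
print(new_base)  # ~30528: a larger base shrinks every frequency smoothly
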
+# Inverse dim formula to find dim based on number of rotations
+def _yarn_find_correction_dim(num_rotations: int,
+                              dim: int,
+                              base: float = 10000,
+                              max_position_embeddings: int = 2048) -> float:
+    return (dim * math.log(max_position_embeddings /
+                           (num_rotations * 2 * math.pi))) / (2 *
+                                                              math.log(base))
+
+
+# Find dim range bounds based on rotations
+def _yarn_find_correction_range(
+        low_rot: int,
+        high_rot: int,
+        dim: int,
+        base: float = 10000,
+        max_position_embeddings: int = 2048) -> Tuple[int, int]:
+    low = math.floor(
+        _yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
+    high = math.ceil(
+        _yarn_find_correction_dim(high_rot, dim, base,
+                                  max_position_embeddings))
+    return max(low, 0), min(high, dim - 1)  # Clamp values just in case
+
+
+def _yarn_linear_ramp_mask(low: float, high: float, dim: int,
+                           dtype: torch.dtype) -> torch.Tensor:
+    if low == high:
+        high += 0.001  # Prevent singularity
+
+    linear_func = (torch.arange(dim, dtype=dtype) - low) / (high - low)
+    ramp_func = torch.clamp(linear_func, 0, 1)
+    return ramp_func
+
+
+def _yarn_get_mscale(scale: float = 1) -> float:
+    if scale <= 1:
+        return 1.0
+    return 0.1 * math.log(scale) + 1.0
+
+
+class YaRNScalingRotaryEmbedding(RotaryEmbedding):
+    """RotaryEmbedding extended with YaRN method.
+
+    Credits to Peng et al. github.com/jquesnelle/yarn
+    """
+
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: int,
+        is_neox_style: bool,
+        scaling_factor: float,
+        *,
+        extrapolation_factor: float = 1,
+        attn_factor: float = 1,
+        beta_fast: int = 32,
+        beta_slow: int = 1,
+    ) -> None:
+        self.scaling_factor = scaling_factor
+        self.extrapolation_factor = extrapolation_factor
+        self.attn_factor = attn_factor
+        self.beta_fast = beta_fast
+        self.beta_slow = beta_slow
+        # Get n-d magnitude scaling corrected for interpolation
+        self.mscale = float(
+            _yarn_get_mscale(self.scaling_factor) * attn_factor)
+        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
+                         is_neox_style)
+
+    def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
+        pos_freqs = self.base**(
+            torch.arange(0, self.rotary_dim, 2, dtype=torch.float) /
+            self.rotary_dim)
+        inv_freq_extrapolation = 1.0 / pos_freqs
+        inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
+
+        low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow,
+                                                self.rotary_dim, self.base,
+                                                self.max_position_embeddings)
+        # Get n-d rotational scaling corrected for extrapolation
+        inv_freq_mask = (1 - _yarn_linear_ramp_mask(
+            low, high, self.rotary_dim // 2,
+            dtype=torch.float)) * self.extrapolation_factor
+        inv_freq = inv_freq_interpolation * (
+            1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
+        return inv_freq
+
+    def _compute_cos_sin_cache(self) -> torch.Tensor:
+        inv_freq = self._compute_inv_freq(self.scaling_factor)
+        t = torch.arange(self.max_position_embeddings * self.scaling_factor,
+                         dtype=torch.float32)
+        freqs = torch.einsum("i,j -> ij", t, inv_freq)
+        cos = (freqs.cos() * self.mscale)
+        sin = (freqs.sin() * self.mscale)
+        cache = torch.cat((cos, sin), dim=-1)
+        return cache
+
+
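To make the correction-range math concrete (a sketch, not part of the packaged file), here is the _yarn_find_correction_dim formula evaluated for hypothetical Llama-style values: dim=128, base=10000, and a 2048-token original context:

import math

def correction_dim(num_rotations, dim=128, base=10000.0, max_pos=2048):
    # Same formula as _yarn_find_correction_dim above.
    return (dim * math.log(max_pos / (num_rotations * 2 * math.pi))
            ) / (2 * math.log(base))

low = math.floor(correction_dim(32))  # beta_fast=32 -> 16
high = math.ceil(correction_dim(1))   # beta_slow=1  -> 41
# Frequency dims below ~16 keep the extrapolated (original) frequencies,
# dims above ~41 are fully interpolated, and the linear ramp blends between.
print(low, high)
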
+class Phi3SuScaledRotaryEmbedding(nn.Module):
+    """Phi3 family of models scaled rotary embedding.
+
+    Based on the original RotaryEmbedding implementation.
+    """
+
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        original_max_position_embeddings: int,
+        base: int,
+        is_neox_style: bool,
+        short_factor: List[float],
+        long_factor: List[float],
+        short_mscale: float = 1.1,
+        long_mscale: float = 1.225,
+    ):
+        super().__init__()
+
+        if rotary_dim != head_size:
+            raise ValueError(
+                f"`Phi3SuScaledRotaryEmbedding` does not support rotary_dim != \
+                head_size ({rotary_dim}!={head_size}).")
+        if is_neox_style is False:
+            raise ValueError(
+                "`Phi3SuScaledRotaryEmbedding` only supports neox_style.")
+
+        self.head_size = head_size
+        self.max_position_embeddings = max_position_embeddings
+        self.original_max_position_embeddings = original_max_position_embeddings
+        self.base = base
+        self.short_factor = short_factor
+        self.long_factor = long_factor
+        self.short_mscale = short_mscale
+        self.long_mscale = long_mscale
+
+        short_cache = self._compute_cos_sin_cache(
+            original_max_position_embeddings, short_factor, short_mscale)
+        short_cache = short_cache.to(torch.get_default_dtype())
+        self.register_buffer("short_cos_sin_cache",
+                             short_cache,
+                             persistent=False)
+
+        long_cache = self._compute_cos_sin_cache(max_position_embeddings,
+                                                 long_factor, long_mscale)
+        long_cache = long_cache.to(torch.get_default_dtype())
+        self.register_buffer("long_cos_sin_cache",
+                             long_cache,
+                             persistent=False)
+
+        long_short_cache = torch.cat(
+            [self.short_cos_sin_cache, self.long_cos_sin_cache], dim=0)
+        self.register_buffer("long_short_cos_sin_cache",
+                             long_short_cache,
+                             persistent=False)
+
+    def _compute_inv_freq(self, rescale_factors: List[float]) -> torch.Tensor:
+        rescale_factors = torch.tensor(rescale_factors, dtype=torch.float32)
+        inv_freq = 1.0 / (rescale_factors * (self.base**(torch.arange(
+            0, self.head_size, 2, dtype=torch.float) / self.head_size)))
+        return inv_freq
+
+    def _compute_cos_sin_cache(
+        self,
+        max_position_embeddings: int,
+        rescale_factors: List[float],
+        mscale: float,
+    ) -> torch.Tensor:
+        inv_freq = self._compute_inv_freq(rescale_factors)
+        t = torch.arange(max_position_embeddings, dtype=torch.float)
+        freqs = torch.einsum("i,j -> ij", t, inv_freq)
+        cos = freqs.cos() * mscale
+        sin = freqs.sin() * mscale
+        cache = torch.cat((cos, sin), dim=-1)
+        return cache
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        query = query.view(*query.shape[:-1], -1, self.head_size)
+        key = key.view(*key.shape[:-1], -1, self.head_size)
+
+        k = self.original_max_position_embeddings
+        long_prompt_offset = (torch.any(positions > k).float() *
+                              torch.full_like(positions, k)).long()
+        idx = (torch.add(positions, long_prompt_offset)
+               if long_prompt_offset is not None else positions)
+        self.long_short_cos_sin_cache: torch.Tensor = (
+            self.long_short_cos_sin_cache.to(idx.device))
+        idx = torch.add(idx, offsets) if offsets is not None else idx
+        cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx)
+
+        cos, sin = cos_sin.chunk(2, dim=-1)
+        cos = cos.repeat(1, 2).unsqueeze(-2)
+        sin = sin.repeat(1, 2).unsqueeze(-2)
+
+        query = query * cos + _rotate_neox(query) * sin
+        key = key * cos + _rotate_neox(key) * sin
+
+        return query.flatten(-2), key.flatten(-2)
+
+
+_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
+
+
+def get_rope(
+    head_size: int,
+    rotary_dim: int,
+    max_position: int,
+    base: int,
+    is_neox_style: bool = True,
+    rope_scaling: Optional[Dict[str, Any]] = None,
+) -> RotaryEmbedding:
+    if rope_scaling is not None:
+        # Transforms every value that is a list into a tuple for caching calls
+        rope_scaling_tuple = {
+            k: tuple(v) if isinstance(v, list) else v
+            for k, v in rope_scaling.items()
+        }
+        rope_scaling_args = tuple(rope_scaling_tuple.items())
+    else:
+        rope_scaling_args = None
+    key = (head_size, rotary_dim, max_position, base, is_neox_style,
+           rope_scaling_args)
+    if key in _ROPE_DICT:
+        return _ROPE_DICT[key]
+    if rope_scaling is None:
+        rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
+                                     is_neox_style)
+    else:
+        scaling_type = rope_scaling["type"]
+        if scaling_type != "su":
+            scaling_factor = rope_scaling["factor"]
+        if scaling_type == "linear":
+            rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
+                                                      max_position, base,
+                                                      is_neox_style,
+                                                      scaling_factor)
+        elif scaling_type == "dynamic":
+            rotary_emb = DynamicNTKScalingRotaryEmbedding(
+                head_size, rotary_dim, max_position, base, is_neox_style,
+                scaling_factor)
+        elif scaling_type == "yarn":
+            original_max_position = rope_scaling[
+                "original_max_position_embeddings"]
+            extra_kwargs = {
+                k: v
+                for k, v in rope_scaling.items()
+                if k in ("extrapolation_factor", "attn_factor", "beta_fast",
+                         "beta_slow")
+            }
+            rotary_emb = YaRNScalingRotaryEmbedding(head_size, rotary_dim,
+                                                    original_max_position,
+                                                    base, is_neox_style,
+                                                    scaling_factor,
+                                                    **extra_kwargs)
+        elif scaling_type == "su":
+            short_factor = rope_scaling["short_factor"]
+            long_factor = rope_scaling["long_factor"]
+            original_max_position = rope_scaling[
+                "original_max_position_embeddings"]
+            extra_kwargs = {
+                k: v
+                for k, v in rope_scaling.items()
+                if k in ("short_mscale", "long_mscale")
+            }
+            rotary_emb = Phi3SuScaledRotaryEmbedding(
+                head_size, rotary_dim, max_position, original_max_position,
+                base, is_neox_style, short_factor, long_factor, **extra_kwargs)
+        else:
+            raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
+    _ROPE_DICT[key] = rotary_emb
+    return rotary_emb
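Finally, a minimal usage sketch (not part of the packaged file) showing how model code obtains a RoPE module through get_rope; the "linear" rope_scaling dict is a hypothetical config:

from vllm.model_executor.layers.rotary_embedding import (
    LinearScalingRotaryEmbedding, get_rope)

rope = get_rope(head_size=128,
                rotary_dim=128,
                max_position=4096,
                base=10000,
                is_neox_style=True,
                rope_scaling={"type": "linear", "factor": 2.0})
assert isinstance(rope, LinearScalingRotaryEmbedding)

# A second call with identical arguments hits the _ROPE_DICT cache and
# returns the very same module instance.
assert rope is get_rope(128, 128, 4096, 10000, True,
                        {"type": "linear", "factor": 2.0})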