sglang 0.4.1.post6__py3-none-any.whl → 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (141)
  1. sglang/__init__.py +21 -23
  2. sglang/api.py +2 -7
  3. sglang/bench_offline_throughput.py +41 -27
  4. sglang/bench_one_batch.py +60 -4
  5. sglang/bench_one_batch_server.py +1 -1
  6. sglang/bench_serving.py +83 -71
  7. sglang/lang/backend/runtime_endpoint.py +183 -4
  8. sglang/lang/chat_template.py +46 -4
  9. sglang/launch_server.py +1 -1
  10. sglang/srt/_custom_ops.py +80 -42
  11. sglang/srt/configs/device_config.py +1 -1
  12. sglang/srt/configs/load_config.py +1 -0
  13. sglang/srt/configs/model_config.py +1 -0
  14. sglang/srt/constrained/base_grammar_backend.py +21 -0
  15. sglang/srt/constrained/xgrammar_backend.py +8 -4
  16. sglang/srt/conversation.py +14 -1
  17. sglang/srt/distributed/__init__.py +3 -3
  18. sglang/srt/distributed/communication_op.py +2 -1
  19. sglang/srt/distributed/device_communicators/cuda_wrapper.py +2 -1
  20. sglang/srt/distributed/device_communicators/custom_all_reduce.py +112 -42
  21. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  22. sglang/srt/distributed/device_communicators/hpu_communicator.py +2 -1
  23. sglang/srt/distributed/device_communicators/pynccl.py +80 -1
  24. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +112 -2
  25. sglang/srt/distributed/device_communicators/shm_broadcast.py +5 -72
  26. sglang/srt/distributed/device_communicators/xpu_communicator.py +2 -1
  27. sglang/srt/distributed/parallel_state.py +1 -1
  28. sglang/srt/distributed/utils.py +2 -1
  29. sglang/srt/entrypoints/engine.py +452 -0
  30. sglang/srt/entrypoints/http_server.py +603 -0
  31. sglang/srt/function_call_parser.py +494 -0
  32. sglang/srt/layers/activation.py +8 -8
  33. sglang/srt/layers/attention/flashinfer_backend.py +10 -9
  34. sglang/srt/layers/attention/triton_backend.py +4 -6
  35. sglang/srt/layers/attention/vision.py +204 -0
  36. sglang/srt/layers/dp_attention.py +71 -0
  37. sglang/srt/layers/layernorm.py +5 -5
  38. sglang/srt/layers/linear.py +65 -14
  39. sglang/srt/layers/logits_processor.py +49 -64
  40. sglang/srt/layers/moe/ep_moe/layer.py +24 -16
  41. sglang/srt/layers/moe/fused_moe_native.py +84 -1
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  43. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +27 -7
  44. sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -5
  45. sglang/srt/layers/parameter.py +18 -8
  46. sglang/srt/layers/quantization/__init__.py +20 -23
  47. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  48. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  49. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  50. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  51. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  52. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  53. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  54. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  55. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  56. sglang/srt/layers/quantization/fp8.py +10 -4
  57. sglang/srt/layers/quantization/modelopt_quant.py +1 -2
  58. sglang/srt/layers/quantization/w8a8_int8.py +1 -1
  59. sglang/srt/layers/radix_attention.py +2 -2
  60. sglang/srt/layers/rotary_embedding.py +1184 -31
  61. sglang/srt/layers/sampler.py +64 -6
  62. sglang/srt/layers/torchao_utils.py +12 -6
  63. sglang/srt/layers/vocab_parallel_embedding.py +2 -2
  64. sglang/srt/lora/lora.py +1 -9
  65. sglang/srt/managers/configure_logging.py +3 -0
  66. sglang/srt/managers/data_parallel_controller.py +79 -72
  67. sglang/srt/managers/detokenizer_manager.py +24 -6
  68. sglang/srt/managers/image_processor.py +158 -2
  69. sglang/srt/managers/io_struct.py +57 -3
  70. sglang/srt/managers/schedule_batch.py +78 -45
  71. sglang/srt/managers/schedule_policy.py +26 -12
  72. sglang/srt/managers/scheduler.py +326 -201
  73. sglang/srt/managers/session_controller.py +1 -0
  74. sglang/srt/managers/tokenizer_manager.py +210 -121
  75. sglang/srt/managers/tp_worker.py +6 -4
  76. sglang/srt/managers/tp_worker_overlap_thread.py +5 -8
  77. sglang/srt/managers/utils.py +44 -0
  78. sglang/srt/mem_cache/memory_pool.py +10 -32
  79. sglang/srt/metrics/collector.py +15 -6
  80. sglang/srt/model_executor/cuda_graph_runner.py +26 -30
  81. sglang/srt/model_executor/forward_batch_info.py +5 -7
  82. sglang/srt/model_executor/model_runner.py +44 -19
  83. sglang/srt/model_loader/loader.py +83 -6
  84. sglang/srt/model_loader/weight_utils.py +145 -6
  85. sglang/srt/models/baichuan.py +6 -6
  86. sglang/srt/models/chatglm.py +2 -2
  87. sglang/srt/models/commandr.py +17 -5
  88. sglang/srt/models/dbrx.py +13 -5
  89. sglang/srt/models/deepseek.py +3 -3
  90. sglang/srt/models/deepseek_v2.py +11 -11
  91. sglang/srt/models/exaone.py +2 -2
  92. sglang/srt/models/gemma.py +2 -2
  93. sglang/srt/models/gemma2.py +15 -25
  94. sglang/srt/models/gpt2.py +3 -5
  95. sglang/srt/models/gpt_bigcode.py +1 -1
  96. sglang/srt/models/granite.py +2 -2
  97. sglang/srt/models/grok.py +4 -3
  98. sglang/srt/models/internlm2.py +2 -2
  99. sglang/srt/models/llama.py +7 -5
  100. sglang/srt/models/minicpm.py +2 -2
  101. sglang/srt/models/minicpm3.py +9 -9
  102. sglang/srt/models/minicpmv.py +1238 -0
  103. sglang/srt/models/mixtral.py +3 -3
  104. sglang/srt/models/mixtral_quant.py +3 -3
  105. sglang/srt/models/mllama.py +2 -2
  106. sglang/srt/models/olmo.py +3 -3
  107. sglang/srt/models/olmo2.py +4 -4
  108. sglang/srt/models/olmoe.py +7 -13
  109. sglang/srt/models/phi3_small.py +2 -2
  110. sglang/srt/models/qwen.py +2 -2
  111. sglang/srt/models/qwen2.py +41 -4
  112. sglang/srt/models/qwen2_moe.py +3 -3
  113. sglang/srt/models/qwen2_vl.py +22 -122
  114. sglang/srt/models/stablelm.py +2 -2
  115. sglang/srt/models/torch_native_llama.py +20 -7
  116. sglang/srt/models/xverse.py +6 -6
  117. sglang/srt/models/xverse_moe.py +6 -6
  118. sglang/srt/openai_api/adapter.py +139 -37
  119. sglang/srt/openai_api/protocol.py +7 -4
  120. sglang/srt/sampling/custom_logit_processor.py +38 -0
  121. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +11 -14
  122. sglang/srt/sampling/sampling_batch_info.py +143 -18
  123. sglang/srt/sampling/sampling_params.py +3 -1
  124. sglang/srt/server.py +4 -1090
  125. sglang/srt/server_args.py +77 -15
  126. sglang/srt/speculative/eagle_utils.py +37 -15
  127. sglang/srt/speculative/eagle_worker.py +11 -13
  128. sglang/srt/utils.py +164 -129
  129. sglang/test/runners.py +8 -13
  130. sglang/test/test_programs.py +2 -1
  131. sglang/test/test_utils.py +83 -22
  132. sglang/utils.py +12 -2
  133. sglang/version.py +1 -1
  134. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/METADATA +21 -10
  135. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/RECORD +138 -123
  136. sglang/launch_server_llavavid.py +0 -25
  137. sglang/srt/constrained/__init__.py +0 -16
  138. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  139. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/LICENSE +0 -0
  140. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/WHEEL +0 -0
  141. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/top_level.txt +0 -0
@@ -1,54 +1,917 @@
1
- # Copyright 2023-2024 SGLang Team
2
- # Licensed under the Apache License, Version 2.0 (the "License");
3
- # you may not use this file except in compliance with the License.
4
- # You may obtain a copy of the License at
5
- #
6
- # http://www.apache.org/licenses/LICENSE-2.0
7
- #
8
- # Unless required by applicable law or agreed to in writing, software
9
- # distributed under the License is distributed on an "AS IS" BASIS,
10
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
- # See the License for the specific language governing permissions and
12
- # limitations under the License.
13
- # ==============================================================================
14
- """MRotaryEmbedding"""
1
+ # Adapted from https://raw.githubusercontent.com/vllm-project/vllm/refs/tags/v0.6.6.post1/vllm/model_executor/layers/rotary_embedding.py
2
+
3
+ """Rotary Positional Embeddings."""
4
+ import math
15
5
  from typing import Any, Dict, List, Optional, Tuple, Union
16
6
 
17
7
  import torch
8
+ import torch.nn as nn
9
+ from vllm.model_executor.custom_op import CustomOp
10
+
11
+ from sglang.srt.layers.custom_op_util import register_custom_op
12
+
13
+
14
+ def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
15
+ x1 = x[..., : x.shape[-1] // 2]
16
+ x2 = x[..., x.shape[-1] // 2 :]
17
+ return torch.cat((-x2, x1), dim=-1)
18
+
19
+
20
+ def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
21
+ x1 = x[..., ::2]
22
+ x2 = x[..., 1::2]
23
+ x = torch.stack((-x2, x1), dim=-1)
24
+ return x.flatten(-2)
25
+
26
+
27
+ def _apply_rotary_emb(
28
+ x: torch.Tensor,
29
+ cos: torch.Tensor,
30
+ sin: torch.Tensor,
31
+ is_neox_style: bool,
32
+ ) -> torch.Tensor:
33
+ """
34
+ Args:
35
+ x: [num_tokens, num_heads, head_size]
36
+ cos: [num_tokens, head_size // 2]
37
+ sin: [num_tokens, head_size // 2]
38
+ is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
39
+ positional embeddings.
40
+ """
41
+ cos = cos.unsqueeze(-2).to(x.dtype)
42
+ sin = sin.unsqueeze(-2).to(x.dtype)
43
+ if is_neox_style:
44
+ x1, x2 = torch.chunk(x, 2, dim=-1)
45
+ else:
46
+ x1 = x[..., ::2]
47
+ x2 = x[..., 1::2]
48
+ o1 = x1 * cos - x2 * sin
49
+ o2 = x2 * cos + x1 * sin
50
+ if is_neox_style:
51
+ return torch.cat((o1, o2), dim=-1)
52
+ else:
53
+ return torch.stack((o1, o2), dim=-1).flatten(-2)
54
+
55
+
56
@register_custom_op("sglang_rotary_embedding")
class RotaryEmbedding(CustomOp):
    """Original rotary positional embedding.

    A table of cosines and sines is built once in ``__init__`` and registered
    as the non-persistent buffer ``cos_sin_cache``. Each row corresponds to one
    position; the first half of a row holds the cosines and the second half
    the sines. Subclasses customise the table by overriding
    ``_compute_inv_freq`` and/or ``_compute_cos_sin_cache``.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        dtype: torch.dtype,
    ) -> None:
        super().__init__()
        self.head_size = head_size
        self.rotary_dim = rotary_dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.is_neox_style = is_neox_style
        self.dtype = dtype

        # The table is computed in float and then cast to the requested dtype.
        self.cos_sin_cache: torch.Tensor
        self.register_buffer(
            "cos_sin_cache",
            self._compute_cos_sin_cache().to(dtype),
            persistent=False,
        )

    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
        """Compute the inverse frequency."""
        # NOTE(woosuk): To exactly match the HF implementation, we need to
        # use CPU to compute the cache and then move it to GPU. However, we
        # create the cache on GPU for faster initialization. This may cause
        # a slight numerical difference between the HF implementation and ours.
        exponents = (
            torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim
        )
        return 1.0 / (base**exponents)

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Compute the cos and sin cache.

        Returns a tensor with one row per position in
        ``range(max_position_embeddings)``; cosines and sines are concatenated
        along the last dimension.
        """
        inv_freq = self._compute_inv_freq(self.base)
        position_ids = torch.arange(self.max_position_embeddings, dtype=torch.float)

        angles = torch.einsum("i,j -> ij", position_ids, inv_freq)
        return torch.cat((angles.cos(), angles.sin()), dim=-1)

    def forward_native(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """A PyTorch-native implementation of forward().

        ``offsets``, when given, is added to ``positions`` before the cache
        lookup. Only the first ``rotary_dim`` features of every head are
        rotated; the remainder is passed through unchanged. The returned
        tensors have the same shapes as ``query`` and ``key``.
        """
        if offsets is not None:
            positions = positions + offsets
        flat_positions = positions.flatten()
        num_tokens = flat_positions.shape[0]
        cos, sin = self.cos_sin_cache.index_select(0, flat_positions).chunk(2, dim=-1)

        def _rotate(tensor: torch.Tensor) -> torch.Tensor:
            original_shape = tensor.shape
            per_head = tensor.view(num_tokens, -1, self.head_size)
            rotated = _apply_rotary_emb(
                per_head[..., : self.rotary_dim], cos, sin, self.is_neox_style
            )
            passthrough = per_head[..., self.rotary_dim :]
            return torch.cat((rotated, passthrough), dim=-1).reshape(original_shape)

        return _rotate(query), _rotate(key)

    def forward_cuda(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the rotation through vLLM's ``rotary_embedding`` custom op.

        The cache is moved to the device and dtype of ``query`` first. The
        op's return value is not used and the ``query``/``key`` objects passed
        in are returned (presumably the kernel updates them in place - confirm
        against the vLLM op). NOTE: ``offsets`` is accepted but not used on
        this path.
        """
        from vllm import _custom_ops as ops

        self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
        ops.rotary_embedding(
            positions,
            query,
            key,
            self.head_size,
            self.cos_sin_cache,
            self.is_neox_style,
        )
        return query, key

    def forward_xpu(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the rotation through vLLM's IPEX ``rotary_embedding`` op.

        The cache is moved to the device of ``positions`` and the dtype of
        ``query`` first. The ``query``/``key`` objects passed in are returned;
        ``offsets`` is accepted but not used on this path.
        """
        from vllm._ipex_ops import ipex_ops as ops

        self.cos_sin_cache = self.cos_sin_cache.to(positions.device, dtype=query.dtype)
        ops.rotary_embedding(
            positions,
            query,
            key,
            self.head_size,
            self.cos_sin_cache,
            self.is_neox_style,
        )
        return query, key

    def forward_hpu(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the rotation with the Habana ``apply_rotary_pos_emb`` kernel.

        Positions are flattened, then ``offsets`` (if any) is added, and the
        matching cache rows are gathered. As in ``forward_native`` only the
        first ``rotary_dim`` features of each head are rotated.
        """
        from habana_frameworks.torch.hpex.kernels import (
            RotaryPosEmbeddingMode,
            apply_rotary_pos_emb,
        )

        positions = positions.flatten()
        if offsets is not None:
            positions = positions + offsets
        num_tokens = positions.shape[0]
        cos_sin = self.cos_sin_cache.index_select(0, positions).view(num_tokens, 1, -1)
        cos, sin = cos_sin.chunk(2, dim=-1)
        # HPU RoPE kernel requires hidden dimension for cos and sin to be equal
        # to query hidden dimension, so the original tensors need to be
        # expanded
        # GPT-NeoX kernel requires position_ids = None, offset, mode = BLOCKWISE
        # and expansion of cos/sin tensors via concatenation
        # GPT-J kernel requires position_ids = None, offset = 0, mode = PAIRWISE
        # and expansion of cos/sin tensors via repeat_interleave
        if self.is_neox_style:
            rope_mode = RotaryPosEmbeddingMode.BLOCKWISE
            cos = torch.cat((cos, cos), dim=-1)
            sin = torch.cat((sin, sin), dim=-1)
        else:
            rope_mode = RotaryPosEmbeddingMode.PAIRWISE
            expanded_width = cos_sin.shape[-1]
            sin = torch.repeat_interleave(sin, 2, dim=-1, output_size=expanded_width)
            cos = torch.repeat_interleave(cos, 2, dim=-1, output_size=expanded_width)

        def _rotate(tensor: torch.Tensor) -> torch.Tensor:
            original_shape = tensor.shape
            per_head = tensor.view(num_tokens, -1, self.head_size)
            rotated = apply_rotary_pos_emb(
                per_head[..., : self.rotary_dim], cos, sin, None, 0, rope_mode
            )
            passthrough = per_head[..., self.rotary_dim :]
            return torch.cat((rotated, passthrough), dim=-1).reshape(original_shape)

        return _rotate(query), _rotate(key)

    def extra_repr(self) -> str:
        """Return the configuration summary shown in the module's repr."""
        return (
            f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
            f", max_position_embeddings={self.max_position_embeddings}"
            f", base={self.base}, is_neox_style={self.is_neox_style}"
        )
232
+
233
+
234
class LinearScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with linear scaling.

    It supports multiple scaling factors. Since multiple LoRA adapters may have
    different scaling factors, we need multiple cos/sin caches. In this way,
    instead of running rotary embedding kernel per lora, we can run multiple
    lora in a batched way.

    In addition to that, we also keep the cos/sin cache for the scaling factor
    of 1 (default) at all times.

    For example, for three scaling factors x=1, y and z with embeddings
    [[x11, x12, ... x1m], ..., [xn1, xn2, ..., xnm]] and
    [[y11, y12, ... y1o], ..., [yn1, yn2, ..., yno]], and
    [[z11, z12, ... z1p], ..., [zn1, zn2, ..., znp]],

    we construct the cos/sin cache as follows:
    [[x11, x12, ... x1m, y11, y12, ... y1o, z11, z12, ... z1p],
    ...
    [xn1, xn2, ... xnm, yn1, yn2, ... yno, zn1, zn2, ... znp]]

    We then use offsets to index into the cos/sin cache for
    the respective scaling factors.

    The offset to cache can be accessed via `scaling_factor_to_offset` API.

    Credits to the Reddit user /u/kaiokendev
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        scaling_factors: Union[List[float], float],
        dtype: torch.dtype,
    ) -> None:
        # A single factor may arrive as a bare number. Accept ints as well as
        # floats (e.g. a config value of ``2``); otherwise the int would be
        # iterated as if it were a list and raise TypeError.
        if isinstance(scaling_factors, (int, float)):
            scaling_factors = [scaling_factors]
        # Must be set before super().__init__(), which calls
        # _compute_cos_sin_cache() and reads this attribute.
        self.scaling_factors: List[float] = scaling_factors  # noqa
        super().__init__(
            head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
        )
        # Populated by _compute_cos_sin_cache() during super().__init__().
        self._scaling_factor_to_offset: Dict[float, int]

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Build one cos/sin table per scaling factor, stacked along dim 0.

        Also records, for every scaling factor, the row at which its table
        starts inside the stacked cache (``scaling_factor_to_offset``).
        """
        inv_freq = self._compute_inv_freq(self.base)
        cache_list: List[torch.Tensor] = []
        # offsets[i] is the first row of the table for scaling_factors[i].
        offsets: List[int] = []
        next_offset = 0
        for scaling_factor in self.scaling_factors:
            # NOTE(woosuk): self.max_position_embeddings is the original
            # maximum length before applying the rope scaling.
            # Thus, the maximum length after applying the rope scaling is
            # self.max_position_embeddings * self.scaling_factor.
            max_len = self.max_position_embeddings * scaling_factor
            t = torch.arange(max_len, dtype=torch.float)
            t = t / scaling_factor

            freqs = torch.einsum("i,j -> ij", t, inv_freq)
            cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1)

            offsets.append(next_offset)
            cache_list.append(cache)
            next_offset += cache.shape[0]
        self._scaling_factor_to_offset = {
            float(scaling_factor): offset
            for scaling_factor, offset in zip(self.scaling_factors, offsets)
        }
        return torch.cat(cache_list, dim=0)

    @property
    def scaling_factor_to_offset(self) -> Dict[float, int]:
        """Mapping from scaling factor to its starting row in the cache."""
        return self._scaling_factor_to_offset
319
+
320
+
321
class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with Dynamic NTK scaling.

    Credits to the Reddit users /u/bloc97 and /u/emozilla
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        scaling_factor: float,
        dtype: torch.dtype,
    ) -> None:
        # Stored before the base constructor runs because the base constructor
        # calls _compute_cos_sin_cache(), which reads it.
        self.scaling_factor = scaling_factor
        super().__init__(
            head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
        )

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Compute the cos/sin table using a rescaled base.

        The table has one row per position in ``range(max_len)`` where
        ``max_len = max_position_embeddings * scaling_factor``.
        """
        factor = self.scaling_factor
        # NOTE(woosuk): self.max_position_embeddings is the original
        # maximum length before applying the rope scaling.
        # Thus, the maximum length after applying the rope scaling is
        # self.max_position_embeddings * self.scaling_factor.
        max_len = self.max_position_embeddings * factor
        exponent = self.rotary_dim / (self.rotary_dim - 2)
        base = self.base * (
            (factor * max_len / self.max_position_embeddings) - (factor - 1)
        ) ** exponent
        inv_freq = self._compute_inv_freq(base)
        position_ids = torch.arange(max_len, dtype=torch.float)

        angles = torch.einsum("i,j -> ij", position_ids, inv_freq)
        return torch.cat((angles.cos(), angles.sin()), dim=-1)
360
+
361
+
362
+ # Inverse dim formula to find dim based on number of rotations
363
+ def _yarn_find_correction_dim(
364
+ num_rotations: int,
365
+ dim: int,
366
+ base: float = 10000,
367
+ max_position_embeddings: int = 2048,
368
+ ) -> float:
369
+ return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
370
+ 2 * math.log(base)
371
+ )
372
+
373
+
374
+ # Find dim range bounds based on rotations
375
def _yarn_find_correction_range(
    low_rot: int,
    high_rot: int,
    dim: int,
    base: float = 10000,
    max_position_embeddings: int = 2048,
) -> Tuple[int, int]:
    """Find dim range bounds based on rotations.

    The lower bound is the floor of the correction dim for ``low_rot`` and the
    upper bound the ceiling of the correction dim for ``high_rot``; the pair is
    clamped to ``[0, dim - 1]``.
    """
    lower = math.floor(
        _yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
    )
    upper = math.ceil(
        _yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
    )
    # Clamp values just in case
    lower = max(lower, 0)
    upper = min(upper, dim - 1)
    return lower, upper
389
+
390
+
391
+ def _yarn_linear_ramp_mask(
392
+ low: float, high: float, dim: int, dtype: torch.dtype
393
+ ) -> torch.Tensor:
394
+ if low == high:
395
+ high += 0.001 # Prevent singularity
396
+
397
+ linear_func = (torch.arange(dim, dtype=dtype) - low) / (high - low)
398
+ ramp_func = torch.clamp(linear_func, 0, 1)
399
+ return ramp_func
400
+
401
+
402
+ def _yarn_get_mscale(scale: float = 1) -> float:
403
+ if scale <= 1:
404
+ return 1.0
405
+ return 0.1 * math.log(scale) + 1.0
406
+
407
+
408
class YaRNScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with YaRN method.

    Credits to Peng et al. github.com/jquesnelle/yarn
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        scaling_factor: float,
        dtype: torch.dtype,
        *,
        extrapolation_factor: float = 1,
        attn_factor: float = 1,
        beta_fast: int = 32,
        beta_slow: int = 1,
    ) -> None:
        # All YaRN settings are stored before the base constructor runs,
        # because the base constructor builds the cache from them.
        self.scaling_factor = scaling_factor
        self.extrapolation_factor = extrapolation_factor
        self.attn_factor = attn_factor
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        # Get n-d magnitude scaling corrected for interpolation
        self.mscale = float(_yarn_get_mscale(self.scaling_factor) * attn_factor)
        super().__init__(
            head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
        )

    def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
        """Blend interpolated and extrapolated inverse frequencies.

        Note that, unlike the base class, the argument here is the scaling
        factor rather than the base; ``self.base`` is always used as the base.
        """
        exponents = (
            torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim
        )
        pos_freqs = self.base**exponents
        extrapolated = 1.0 / pos_freqs
        interpolated = 1.0 / (scaling_factor * pos_freqs)

        low, high = _yarn_find_correction_range(
            self.beta_fast,
            self.beta_slow,
            self.rotary_dim,
            self.base,
            self.max_position_embeddings,
        )
        # Get n-d rotational scaling corrected for extrapolation
        ramp = _yarn_linear_ramp_mask(low, high, self.rotary_dim // 2, dtype=torch.float)
        extrapolation_weight = (1 - ramp) * self.extrapolation_factor
        return (
            interpolated * (1 - extrapolation_weight)
            + extrapolated * extrapolation_weight
        )

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Compute the cos/sin table, scaled by ``self.mscale``.

        The table has one row per position in
        ``range(max_position_embeddings * scaling_factor)``.
        """
        inv_freq = self._compute_inv_freq(self.scaling_factor)
        position_ids = torch.arange(
            self.max_position_embeddings * self.scaling_factor, dtype=torch.float32
        )
        angles = torch.einsum("i,j -> ij", position_ids, inv_freq)
        scaled_cos = angles.cos() * self.mscale
        scaled_sin = angles.sin() * self.mscale
        return torch.cat((scaled_cos, scaled_sin), dim=-1)
475
+
476
+
477
class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
    """Phi3 family of models scaled rotary embedding.

    Based on the original RotaryEmbedding implementation.

    Two cos/sin tables are built: a "short" one with
    ``original_max_position_embeddings`` rows (from ``short_factor``) and a
    "long" one with ``max_position_embeddings`` rows (from ``long_factor``).
    They are stacked, short first, into ``long_short_cos_sin_cache``.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        original_max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        dtype: torch.dtype,
        short_factor: List[float],
        long_factor: List[float],
        short_mscale: Optional[float] = None,
        long_mscale: Optional[float] = None,
    ):
        """Build and register the short, long and stacked cos/sin caches.

        Raises:
            ValueError: if ``rotary_dim != head_size`` or ``is_neox_style`` is
                ``False``.
        """
        super().__init__()

        if rotary_dim != head_size:
            # Adjacent literals instead of a backslash continuation, which
            # would embed the source indentation inside the message.
            raise ValueError(
                "`Phi3LongRoPEScaledRotaryEmbedding` does not support "
                f"rotary_dim != head_size ({rotary_dim}!={head_size})."
            )
        if is_neox_style is False:
            raise ValueError(
                "`Phi3LongRoPEScaledRotaryEmbedding` only supports neox_style."
            )

        self.head_size = head_size
        self.max_position_embeddings = max_position_embeddings
        self.original_max_position_embeddings = original_max_position_embeddings
        self.base = base
        self.short_factor = short_factor
        self.long_factor = long_factor

        # Default magnitude scale, used when no explicit mscale is supplied.
        scale = self.max_position_embeddings / self.original_max_position_embeddings
        if scale <= 1.0:
            scaling_factor = 1.0
        else:
            scaling_factor = math.sqrt(
                1 + math.log(scale) / math.log(self.original_max_position_embeddings)
            )
        if short_mscale is None:
            short_mscale = scaling_factor
        if long_mscale is None:
            long_mscale = scaling_factor

        self.short_mscale = short_mscale
        self.long_mscale = long_mscale

        short_cache = self._compute_cos_sin_cache(
            original_max_position_embeddings, short_factor, short_mscale
        )
        short_cache = short_cache.to(dtype)
        self.register_buffer("short_cos_sin_cache", short_cache, persistent=False)

        long_cache = self._compute_cos_sin_cache(
            max_position_embeddings, long_factor, long_mscale
        )
        long_cache = long_cache.to(dtype)
        self.register_buffer("long_cos_sin_cache", long_cache, persistent=False)

        # Short rows first, so long-table row i sits at index
        # original_max_position_embeddings + i.
        long_short_cache = torch.cat(
            [self.short_cos_sin_cache, self.long_cos_sin_cache], dim=0
        )
        self.register_buffer(
            "long_short_cos_sin_cache", long_short_cache, persistent=False
        )

    def _compute_inv_freq(self, rescale_factors: List[float]) -> torch.Tensor:
        """Compute inverse frequencies, each divided by its rescale factor."""
        rescale_factors = torch.tensor(rescale_factors, dtype=torch.float32)
        inv_freq = 1.0 / (
            rescale_factors
            * (
                self.base
                ** (
                    torch.arange(0, self.head_size, 2, dtype=torch.float)
                    / self.head_size
                )
            )
        )
        return inv_freq

    def _compute_cos_sin_cache(
        self,
        max_position_embeddings: int,
        rescale_factors: List[float],
        mscale: float,
    ) -> torch.Tensor:
        """Compute a cos/sin table with ``max_position_embeddings`` rows.

        Cosines and sines are multiplied by ``mscale`` and concatenated along
        the last dimension.
        """
        inv_freq = self._compute_inv_freq(rescale_factors)
        t = torch.arange(max_position_embeddings, dtype=torch.float)
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = freqs.cos() * mscale
        sin = freqs.sin() * mscale
        cache = torch.cat((cos, sin), dim=-1)
        return cache

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Rotate ``query`` and ``key`` using the short or long table.

        If any position exceeds ``original_max_position_embeddings``, every
        position is shifted by that amount, so the lookup lands in the long
        table. ``offsets``, when given, is added to the lookup indices.
        """
        query = query.view(*query.shape[:-1], -1, self.head_size)
        key = key.view(*key.shape[:-1], -1, self.head_size)

        k = self.original_max_position_embeddings
        # Either all zeros (short table) or all ``k`` (long table).
        long_prompt_offset = (
            torch.any(positions > k).float() * torch.full_like(positions, k)
        ).long()
        idx = torch.add(positions, long_prompt_offset)
        self.long_short_cos_sin_cache: torch.Tensor = self.long_short_cos_sin_cache.to(
            idx.device
        )
        idx = torch.add(idx, offsets) if offsets is not None else idx
        cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx)

        cos, sin = cos_sin.chunk(2, dim=-1)
        # Duplicate to full head width and add a head axis for broadcasting.
        cos = cos.repeat(1, 2).unsqueeze(-2)
        sin = sin.repeat(1, 2).unsqueeze(-2)

        query = query * cos + _rotate_neox(query) * sin
        key = key * cos + _rotate_neox(key) * sin

        return query.flatten(-2), key.flatten(-2)
611
+
612
+
613
def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    """Return 1.0 for ``scale <= 1``, else ``0.1 * mscale * ln(scale) + 1.0``."""
    return 1.0 if scale <= 1 else 0.1 * mscale * math.log(scale) + 1.0
617
+
618
+
619
+ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
620
+ """RotaryEmbedding extended with YaRN method.
18
621
 
622
+ Credits to Peng et al. github.com/jquesnelle/yarn
623
+ """
19
624
 
20
- class MRotaryEmbedding:
625
def __init__(
    self,
    head_size: int,
    rotary_dim: int,
    max_position_embeddings: int,
    base: int,
    is_neox_style: bool,
    scaling_factor: float,
    dtype: torch.dtype,
    *,
    extrapolation_factor: float = 1,
    attn_factor: float = 1,
    beta_fast: int = 32,
    beta_slow: int = 1,
    mscale: float = 1,
    mscale_all_dim: float = 0,
    device: Optional[str] = "cuda",
) -> None:
    """Store the YaRN settings, derive ``self.mscale`` and build the cache.

    Every attribute is assigned before the base constructor is called; the
    cache-building methods of this class read them.
    """
    self.scaling_factor = scaling_factor
    self.extrapolation_factor = extrapolation_factor
    self.attn_factor = attn_factor
    self.beta_fast = beta_fast
    self.beta_slow = beta_slow
    # Get n-d magnitude scaling corrected for interpolation.
    mscale_numerator = yarn_get_mscale(self.scaling_factor, float(mscale))
    mscale_denominator = yarn_get_mscale(self.scaling_factor, float(mscale_all_dim))
    self.mscale = float(mscale_numerator / mscale_denominator * attn_factor)
    # Device on which the frequency and position tensors are created.
    self.device = device
    super().__init__(
        head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
    )
658
+
659
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
    """Blend interpolated and extrapolated inverse frequencies.

    The argument is the scaling factor; ``self.base`` is used as the base.
    The result lives on ``self.device``.
    """
    pos_freqs = self.base ** (
        torch.arange(0, self.rotary_dim, 2, dtype=torch.float, device=self.device)
        / self.rotary_dim
    )
    inv_freq_extrapolation = 1.0 / pos_freqs
    inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)

    low, high = _yarn_find_correction_range(
        self.beta_fast,
        self.beta_slow,
        self.rotary_dim,
        self.base,
        self.max_position_embeddings,
    )
    # Get n-d rotational scaling corrected for extrapolation
    inv_freq_mask = (
        1
        - _yarn_linear_ramp_mask(low, high, self.rotary_dim // 2, dtype=torch.float)
    ) * self.extrapolation_factor
    # The ramp mask is created on the default device; move it next to the
    # frequencies so the blend below does not fail with a device mismatch
    # when ``self.device`` differs from the default device. This is a no-op
    # when both already match.
    inv_freq_mask = inv_freq_mask.to(pos_freqs.device)
    inv_freq = (
        inv_freq_interpolation * (1 - inv_freq_mask)
        + inv_freq_extrapolation * inv_freq_mask
    )
    return inv_freq
684
+
685
+ def _compute_cos_sin_cache(self) -> torch.Tensor:
686
+ inv_freq = self._compute_inv_freq(self.scaling_factor)
687
+ t = torch.arange(
688
+ self.max_position_embeddings * self.scaling_factor,
689
+ device=self.device,
690
+ dtype=torch.float32,
691
+ )
692
+ freqs = torch.einsum("i,j -> ij", t, inv_freq)
693
+ cos = freqs.cos() * self.mscale
694
+ sin = freqs.sin() * self.mscale
695
+ cache = torch.cat((cos, sin), dim=-1)
696
+ print("Cache shape", cache.shape)
697
+ return cache
698
+
699
+ def forward(
700
+ self,
701
+ positions: torch.Tensor,
702
+ query: torch.Tensor,
703
+ key: torch.Tensor,
704
+ offsets: Optional[torch.Tensor] = None,
705
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
706
+ """PyTorch-native implementation equivalent to forward()."""
707
+ query_rot = query[..., : self.rotary_dim]
708
+ key_rot = key[..., : self.rotary_dim]
709
+ if self.rotary_dim < self.head_size:
710
+ query_pass = query[..., self.rotary_dim :]
711
+ key_pass = key[..., self.rotary_dim :]
712
+
713
+ self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(positions.device)
714
+ cos_sin = self.cos_sin_cache[
715
+ torch.add(positions, offsets) if offsets is not None else positions
716
+ ]
717
+ cos, sin = cos_sin.chunk(2, dim=-1)
718
+ if self.is_neox_style:
719
+ # NOTE(woosuk): Here we assume that the positions tensor has the
720
+ # shape [batch_size, seq_len].
721
+ cos = cos.repeat(1, 1, 2).unsqueeze(-2)
722
+ sin = sin.repeat(1, 1, 2).unsqueeze(-2)
723
+ else:
724
+ cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
725
+ sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
726
+
727
+ rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
728
+ query_rot = query_rot * cos + rotate_fn(query_rot) * sin
729
+ key_rot = key_rot * cos + rotate_fn(key_rot) * sin
730
+
731
+ if self.rotary_dim < self.head_size:
732
+ query = torch.cat((query_rot, query_pass), dim=-1)
733
+ key = torch.cat((key_rot, key_pass), dim=-1)
734
+ else:
735
+ query = query_rot
736
+ key = key_rot
737
+ return query, key
738
+
739
+
740
class Llama3RotaryEmbedding(RotaryEmbedding):
    """Rotary embedding whose inverse frequencies are rescaled by wavelength.

    Short wavelengths keep their frequency, long wavelengths are divided by
    ``scaling_factor``, and the band in between is linearly blended.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        dtype: torch.dtype,
        scaling_factor: float,
        low_freq_factor: float,
        high_freq_factor: float,
        orig_max_position: int,
    ) -> None:
        """Record the scaling parameters, then run the base constructor."""
        # Assigned first: the overridden _compute_inv_freq below reads them
        # (presumably invoked from the base constructor -- confirm there).
        self.scaling_factor = scaling_factor
        self.low_freq_factor = low_freq_factor
        self.high_freq_factor = high_freq_factor
        self.orig_max_position = orig_max_position
        base_args = (
            head_size,
            rotary_dim,
            max_position_embeddings,
            base,
            is_neox_style,
            dtype,
        )
        super().__init__(*base_args)

    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
        """Return the base inverse frequencies after band-wise rescaling."""
        base_freqs = super()._compute_inv_freq(base)
        long_cutoff = self.orig_max_position / self.low_freq_factor
        short_cutoff = self.orig_max_position / self.high_freq_factor
        wavelengths = 2 * math.pi / base_freqs

        if self.low_freq_factor != self.high_freq_factor:
            span = self.high_freq_factor - self.low_freq_factor
            blend = (
                self.orig_max_position / wavelengths - self.low_freq_factor
            ) / span
        else:
            # Degenerate band: fall back to the fully scaled frequency.
            blend = 0

        fully_scaled = base_freqs / self.scaling_factor
        blended = (1 - blend) * base_freqs / self.scaling_factor + blend * base_freqs
        beyond_short_band = torch.where(
            wavelengths > long_cutoff, fully_scaled, blended
        )
        return torch.where(wavelengths < short_cutoff, base_freqs, beyond_short_band)
785
+
786
+
787
+ class MRotaryEmbedding(RotaryEmbedding):
21
788
  """Rotary Embedding with Multimodal Sections."""
22
789
 
790
def __init__(
    self,
    head_size: int,
    rotary_dim: int,
    max_position_embeddings: int,
    base: int,
    is_neox_style: bool,
    dtype: torch.dtype,
    mrope_section: Optional[List[int]] = None,
) -> None:
    """Initialise the base embedding and remember the multimodal sections.

    When ``mrope_section`` is given and non-empty, its entries must sum to
    ``rotary_dim // 2``.
    """
    base_args = (
        head_size,
        rotary_dim,
        max_position_embeddings,
        base,
        is_neox_style,
        dtype,
    )
    super().__init__(*base_args)

    self.mrope_section = mrope_section
    if not self.mrope_section:
        return
    assert sum(self.mrope_section) == rotary_dim // 2
807
+
808
def forward(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate ``query`` and ``key`` using (optionally multimodal) positions.

    Args:
        positions:
            [num_tokens,] (text only) or
            [3, num_tokens] (T/H/W positions with multimodal inputs)
        query: [num_tokens, num_heads * head_size]
        key: [num_tokens, num_kv_heads * head_size]

    Returns:
        The rotated ``(query, key)`` pair, each in its input shape.
    """
    assert positions.ndim in (1, 2)

    token_count = positions.shape[-1]
    cos, sin = self.cos_sin_cache[positions].chunk(2, dim=-1)

    if positions.ndim == 2:
        assert self.mrope_section

        def _pick_sections(table: torch.Tensor) -> torch.Tensor:
            # The i-th section along the last dim is taken from the i-th
            # position row, then the sections are stitched back together.
            pieces = table.split(self.mrope_section, dim=-1)
            return torch.cat(
                [piece[axis] for axis, piece in enumerate(pieces)],
                dim=-1,
            )

        cos = _pick_sections(cos)
        sin = _pick_sections(sin)

    def _rotate(states: torch.Tensor) -> torch.Tensor:
        # View as [tokens, heads, head_size], rotate the leading rotary_dim
        # features, re-attach the rest, and restore the caller's shape.
        original_shape = states.shape
        per_head = states.view(token_count, -1, self.head_size)
        rotary_part = per_head[..., : self.rotary_dim]
        passthrough = per_head[..., self.rotary_dim :]
        rotary_part = _apply_rotary_emb(rotary_part, cos, sin, self.is_neox_style)
        return torch.cat((rotary_part, passthrough), dim=-1).reshape(original_shape)

    query = _rotate(query)
    key = _rotate(key)
    return query, key
854
+
23
855
  @staticmethod
24
856
  def get_input_positions(
25
- input_tokens: torch.Tensor,
857
+ input_tokens: List[int],
26
858
  image_grid_thw: Union[List[List[int]], torch.Tensor],
859
+ video_grid_thw: Union[List[List[int]], torch.Tensor],
860
+ image_token_id: int,
861
+ video_token_id: int,
27
862
  vision_start_token_id: int,
863
+ vision_end_token_id: int,
28
864
  spatial_merge_size: int,
29
865
  context_len: int = 0,
866
+ seq_len: Optional[int] = None,
30
867
  ) -> Tuple[List[List[int]], int]:
31
868
  """Get mrope input positions and delta value."""
32
869
 
33
870
  if isinstance(image_grid_thw, torch.Tensor):
34
871
  image_grid_thw = image_grid_thw.tolist()
872
+ if isinstance(video_grid_thw, torch.Tensor):
873
+ video_grid_thw = video_grid_thw.tolist()
35
874
 
875
+ input_tokens_tensor = torch.tensor(input_tokens)
36
876
  vision_start_indices = torch.argwhere(
37
- input_tokens == vision_start_token_id
877
+ input_tokens_tensor == vision_start_token_id
38
878
  ).squeeze(1)
39
- image_indices = vision_start_indices + 1
40
- image_nums = image_indices.shape[0]
879
+ vision_tokens = input_tokens_tensor[vision_start_indices + 1]
880
+ image_nums = (vision_tokens == image_token_id).sum()
881
+ video_nums = (vision_tokens == video_token_id).sum()
41
882
  llm_pos_ids_list: list = []
42
883
 
43
884
  st = 0
44
- input_tokens_len = input_tokens.shape[0]
45
- for image_index in range(image_nums):
46
- ed = image_indices[image_index].item()
47
- t, h, w = (
48
- image_grid_thw[image_index][0],
49
- image_grid_thw[image_index][1],
50
- image_grid_thw[image_index][2],
51
- )
885
+ remain_images, remain_videos = image_nums, video_nums
886
+
887
+ image_index, video_index = 0, 0
888
+ for _ in range(image_nums + video_nums):
889
+ if image_token_id in input_tokens and remain_images > 0:
890
+ ed_image = input_tokens.index(image_token_id, st)
891
+ else:
892
+ ed_image = len(input_tokens) + 1
893
+ if video_token_id in input_tokens and remain_videos > 0:
894
+ ed_video = input_tokens.index(video_token_id, st)
895
+ else:
896
+ ed_video = len(input_tokens) + 1
897
+ if ed_image < ed_video:
898
+ t, h, w = (
899
+ image_grid_thw[image_index][0],
900
+ image_grid_thw[image_index][1],
901
+ image_grid_thw[image_index][2],
902
+ )
903
+ image_index += 1
904
+ remain_images -= 1
905
+ ed = ed_image
906
+ else:
907
+ t, h, w = (
908
+ video_grid_thw[video_index][0],
909
+ video_grid_thw[video_index][1],
910
+ video_grid_thw[video_index][2],
911
+ )
912
+ video_index += 1
913
+ remain_videos -= 1
914
+ ed = ed_video
52
915
  llm_grid_t, llm_grid_h, llm_grid_w = (
53
916
  t,
54
917
  h // spatial_merge_size,
@@ -84,16 +947,17 @@ class MRotaryEmbedding:
84
947
  )
85
948
  st = ed + llm_grid_t * llm_grid_h * llm_grid_w
86
949
 
87
- if st < input_tokens_len:
950
+ if st < len(input_tokens):
88
951
  st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
89
- text_len = input_tokens_len - st
952
+ text_len = len(input_tokens) - st
90
953
  llm_pos_ids_list.append(
91
954
  torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
92
955
  )
93
956
 
94
957
  llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
95
- llm_positions = llm_positions[:, context_len:]
96
- mrope_position_delta = (llm_positions.max() + 1 - input_tokens_len).item()
958
+ mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
959
+ llm_positions = llm_positions[:, context_len:seq_len]
960
+
97
961
  return llm_positions.tolist(), mrope_position_delta
98
962
 
99
963
  @staticmethod
@@ -110,3 +974,292 @@ class MRotaryEmbedding:
110
974
  )
111
975
  for _ in range(3)
112
976
  ]
977
+
978
+
979
# Module-level cache of rotary-embedding instances, keyed by the tuple of
# construction arguments assembled in get_rope() / get_rope_cpu().
_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
980
+
981
+
982
def get_rope(
    head_size: int,
    rotary_dim: int,
    max_position: int,
    base: int,
    is_neox_style: bool = True,
    rope_scaling: Optional[Dict[str, Any]] = None,
    dtype: Optional[torch.dtype] = None,
    partial_rotary_factor: float = 1.0,
) -> RotaryEmbedding:
    """Return a (cached) rotary embedding matching the given configuration.

    The scaling variant is chosen from ``rope_scaling["rope_type"]`` (or the
    legacy ``"type"`` key). Instances are memoised in ``_ROPE_DICT`` so that
    identical configurations share one object.

    Raises:
        ValueError: if ``rope_scaling`` names no type or an unknown type.
        KeyError: if a parameter required by the chosen type is missing.
    """
    dtype = torch.get_default_dtype() if dtype is None else dtype

    if rope_scaling is None:
        rope_scaling_args = None
    else:
        # Lists are unhashable, so freeze them into tuples for the cache key.
        rope_scaling_args = tuple(
            (name, tuple(value) if isinstance(value, list) else value)
            for name, value in rope_scaling.items()
        )

    if partial_rotary_factor < 1.0:
        rotary_dim = int(rotary_dim * partial_rotary_factor)

    cache_key = (
        head_size,
        rotary_dim,
        max_position,
        base,
        is_neox_style,
        rope_scaling_args,
        dtype,
    )
    try:
        return _ROPE_DICT[cache_key]
    except KeyError:
        pass

    def _optional_params(*names: str) -> Dict[str, Any]:
        # Forward only the optional knobs that the config actually sets.
        return {name: value for name, value in rope_scaling.items() if name in names}

    if rope_scaling is None:
        rotary_emb = RotaryEmbedding(
            head_size, rotary_dim, max_position, base, is_neox_style, dtype
        )
    else:
        for type_key in ("rope_type", "type"):
            if type_key in rope_scaling:
                scaling_type = rope_scaling[type_key]
                break
        else:
            raise ValueError("Unknown RoPE scaling type")

        if scaling_type == "llama3":
            rotary_emb = Llama3RotaryEmbedding(
                head_size,
                rotary_dim,
                max_position,
                base,
                is_neox_style,
                dtype,
                rope_scaling["factor"],
                rope_scaling["low_freq_factor"],
                rope_scaling["high_freq_factor"],
                rope_scaling["original_max_position_embeddings"],
            )
        elif scaling_type == "default":
            if "mrope_section" in rope_scaling:
                rotary_emb = MRotaryEmbedding(
                    head_size,
                    rotary_dim,
                    max_position,
                    base,
                    is_neox_style,
                    dtype,
                    mrope_section=rope_scaling["mrope_section"],
                )
            else:
                rotary_emb = RotaryEmbedding(
                    head_size,
                    rotary_dim,
                    max_position,
                    base,
                    is_neox_style,
                    dtype,
                )
        elif scaling_type == "linear":
            rotary_emb = LinearScalingRotaryEmbedding(
                head_size,
                rotary_dim,
                max_position,
                base,
                is_neox_style,
                rope_scaling["factor"],
                dtype,
            )
        elif scaling_type == "dynamic":
            rotary_emb = DynamicNTKScalingRotaryEmbedding(
                head_size,
                rotary_dim,
                max_position,
                base,
                is_neox_style,
                rope_scaling["factor"],
                dtype,
            )
        elif scaling_type == "yarn":
            factor = rope_scaling["factor"]
            original_max = rope_scaling["original_max_position_embeddings"]
            rotary_emb = YaRNScalingRotaryEmbedding(
                head_size,
                rotary_dim,
                original_max,
                base,
                is_neox_style,
                factor,
                dtype,
                **_optional_params(
                    "extrapolation_factor", "attn_factor", "beta_fast", "beta_slow"
                ),
            )
        elif scaling_type == "deepseek_yarn":
            factor = rope_scaling["factor"]
            original_max = rope_scaling["original_max_position_embeddings"]
            # NOTE: the embedding is built from the original max position;
            # max_position itself is not checked against original_max * factor.
            rotary_emb = DeepseekScalingRotaryEmbedding(
                head_size,
                rotary_dim,
                original_max,
                base,
                is_neox_style,
                factor,
                dtype,
                **_optional_params(
                    "extrapolation_factor",
                    "attn_factor",
                    "beta_fast",
                    "beta_slow",
                    "mscale",
                    "mscale_all_dim",
                ),
            )
        elif scaling_type == "longrope":
            short_factor = rope_scaling["short_factor"]
            long_factor = rope_scaling["long_factor"]
            original_max = rope_scaling["original_max_position_embeddings"]
            rotary_emb = Phi3LongRoPEScaledRotaryEmbedding(
                head_size,
                rotary_dim,
                max_position,
                original_max,
                base,
                is_neox_style,
                dtype,
                short_factor,
                long_factor,
                **_optional_params("short_mscale", "long_mscale"),
            )
        else:
            raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    _ROPE_DICT[cache_key] = rotary_emb
    return rotary_emb
1158
+
1159
+
1160
def get_rope_cpu(
    head_size: int,
    rotary_dim: int,
    max_position: int,
    base: int,
    is_neox_style: bool = True,
    rope_scaling: Optional[Dict[str, Any]] = None,
    dtype: Optional[torch.dtype] = None,
    partial_rotary_factor: float = 1.0,
    device: Optional[str] = None,
) -> RotaryEmbedding:
    """Return a (cached) ``DeepseekScalingRotaryEmbedding`` built on ``device``.

    Only ``rope_scaling["rope_type"] == "deepseek_yarn"`` is supported.

    Args:
        head_size, rotary_dim, base, is_neox_style, dtype: forwarded to the
            embedding constructor (``dtype`` defaults to torch's default).
        max_position: used only as part of the cache key; the embedding is
            built from ``rope_scaling["original_max_position_embeddings"]``.
        rope_scaling: scaling configuration; must not be ``None``.
        partial_rotary_factor: when below 1.0, shrinks ``rotary_dim``.
        device: device passed to the embedding constructor.

    Raises:
        AssertionError: if ``rope_scaling`` is ``None`` or names another type.
        KeyError: if a required ``rope_scaling`` entry is missing.
    """
    if dtype is None:
        dtype = torch.get_default_dtype()
    if rope_scaling is not None:
        # Transforms every value that is a list into a tuple for caching calls
        rope_scaling_tuple = {
            k: tuple(v) if isinstance(v, list) else v for k, v in rope_scaling.items()
        }
        rope_scaling_args = tuple(rope_scaling_tuple.items())
    else:
        rope_scaling_args = None
    if partial_rotary_factor < 1.0:
        rotary_dim = int(rotary_dim * partial_rotary_factor)
    # The device is part of the key: the instance is constructed on that
    # device, and _ROPE_DICT is shared with get_rope(), whose keys omit it.
    # Without it, a lookup could return an embedding built for another device.
    key = (
        head_size,
        rotary_dim,
        max_position,
        base,
        is_neox_style,
        rope_scaling_args,
        dtype,
        device,
    )
    if key in _ROPE_DICT:
        return _ROPE_DICT[key]

    assert rope_scaling is not None
    scaling_type = rope_scaling["rope_type"]
    assert (
        scaling_type == "deepseek_yarn"
    ), "Only deepseek_yarn is supported for CPU for now"

    scaling_factor = rope_scaling["factor"]
    original_max_position = rope_scaling["original_max_position_embeddings"]
    # Forward only the optional YaRN knobs that the config actually sets.
    extra_kwargs = {
        k: v
        for k, v in rope_scaling.items()
        if k
        in (
            "extrapolation_factor",
            "attn_factor",
            "beta_fast",
            "beta_slow",
            "mscale",
            "mscale_all_dim",
        )
    }
    extra_kwargs["device"] = device
    rotary_emb = DeepseekScalingRotaryEmbedding(
        head_size,
        rotary_dim,
        original_max_position,
        base,
        is_neox_style,
        scaling_factor,
        dtype,
        **extra_kwargs,
    )

    _ROPE_DICT[key] = rotary_emb
    return rotary_emb
1230
+
1231
+
1232
def get_rope_wrapper(
    head_size: int,
    rotary_dim: int,
    max_position: int,
    base: int,
    is_neox_style: bool = True,
    rope_scaling: Optional[Dict[str, Any]] = None,
    dtype: Optional[torch.dtype] = None,
    partial_rotary_factor: float = 1.0,
    device: Optional[str] = None,
):
    """Build a rotary embedding via the CPU or the default factory.

    ``get_rope_cpu`` is used only when ``device`` equals ``"cpu"``; every
    other value (including ``None``) goes through ``get_rope``, which does
    not receive ``device``.
    """
    shared_args = (
        head_size,
        rotary_dim,
        max_position,
        base,
        is_neox_style,
        rope_scaling,
        dtype,
        partial_rotary_factor,
    )
    if device != "cpu":
        return get_rope(*shared_args)
    return get_rope_cpu(*shared_args, device)