tpu-inference 0.11.1.dev202511150811__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic.

Files changed (179)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +53 -0
  6. tests/core/test_dp_scheduler.py +899 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/fused_moe_v1_test.py +105 -0
  10. tests/kernels/mla_v1_test.py +396 -0
  11. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  12. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  13. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  14. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
  15. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  16. tests/lora/__init__.py +0 -0
  17. tests/lora/conftest.py +32 -0
  18. tests/lora/test_bgmv.py +43 -0
  19. tests/lora/test_layers.py +654 -0
  20. tests/lora/test_lora.py +133 -0
  21. tests/lora/utils.py +96 -0
  22. tests/test_base.py +201 -0
  23. tests/test_envs.py +182 -0
  24. tests/test_quantization.py +836 -0
  25. tests/test_tpu_info.py +120 -0
  26. tests/test_utils.py +236 -0
  27. tpu_inference/__init__.py +34 -0
  28. tpu_inference/core/__init__.py +0 -0
  29. tpu_inference/core/core_tpu.py +786 -0
  30. tpu_inference/core/disagg_executor.py +118 -0
  31. tpu_inference/core/disagg_utils.py +51 -0
  32. tpu_inference/core/sched/__init__.py +0 -0
  33. tpu_inference/core/sched/dp_scheduler.py +523 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/jax_parallel_state.py +67 -0
  36. tpu_inference/distributed/tpu_connector.py +728 -0
  37. tpu_inference/distributed/utils.py +59 -0
  38. tpu_inference/env_override.py +9 -0
  39. tpu_inference/envs.py +107 -0
  40. tpu_inference/executors/__init__.py +0 -0
  41. tpu_inference/executors/ray_distributed_executor.py +362 -0
  42. tpu_inference/experimental/__init__.py +0 -0
  43. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  44. tpu_inference/kernels/__init__.py +0 -0
  45. tpu_inference/kernels/collectives/__init__.py +0 -0
  46. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  47. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  48. tpu_inference/kernels/collectives/util.py +47 -0
  49. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  50. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  51. tpu_inference/kernels/fused_moe/__init__.py +0 -0
  52. tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
  53. tpu_inference/kernels/fused_moe/v1/kernel.py +1035 -0
  54. tpu_inference/kernels/mla/__init__.py +0 -0
  55. tpu_inference/kernels/mla/v1/__init__.py +0 -0
  56. tpu_inference/kernels/mla/v1/kernel.py +1349 -0
  57. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  58. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  59. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  60. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  61. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  62. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  66. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1478 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1482 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
  71. tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
  72. tpu_inference/layers/__init__.py +0 -0
  73. tpu_inference/layers/common/__init__.py +0 -0
  74. tpu_inference/layers/common/attention_interface.py +390 -0
  75. tpu_inference/layers/common/attention_metadata.py +34 -0
  76. tpu_inference/layers/common/binary_search.py +295 -0
  77. tpu_inference/layers/common/quant_methods.py +8 -0
  78. tpu_inference/layers/common/sharding.py +582 -0
  79. tpu_inference/layers/jax/__init__.py +0 -0
  80. tpu_inference/layers/jax/attention/__init__.py +0 -0
  81. tpu_inference/layers/jax/attention/attention.py +255 -0
  82. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  83. tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
  84. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  85. tpu_inference/layers/jax/base.py +151 -0
  86. tpu_inference/layers/jax/constants.py +88 -0
  87. tpu_inference/layers/jax/layers.py +301 -0
  88. tpu_inference/layers/jax/misc.py +16 -0
  89. tpu_inference/layers/jax/moe/__init__.py +0 -0
  90. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  91. tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
  92. tpu_inference/layers/jax/moe/moe.py +209 -0
  93. tpu_inference/layers/jax/rope.py +280 -0
  94. tpu_inference/layers/jax/rope_interface.py +214 -0
  95. tpu_inference/layers/jax/sample/__init__.py +0 -0
  96. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  97. tpu_inference/layers/jax/sample/sampling.py +96 -0
  98. tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
  99. tpu_inference/layers/jax/transformer_block.py +107 -0
  100. tpu_inference/layers/vllm/__init__.py +0 -0
  101. tpu_inference/layers/vllm/attention.py +221 -0
  102. tpu_inference/layers/vllm/fused_moe.py +507 -0
  103. tpu_inference/layers/vllm/linear_common.py +186 -0
  104. tpu_inference/layers/vllm/quantization/__init__.py +39 -0
  105. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  106. tpu_inference/layers/vllm/quantization/common.py +105 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  108. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
  109. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
  110. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  111. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  112. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  113. tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
  114. tpu_inference/layers/vllm/quantization/unquantized.py +386 -0
  115. tpu_inference/layers/vllm/sharding.py +230 -0
  116. tpu_inference/logger.py +10 -0
  117. tpu_inference/lora/__init__.py +0 -0
  118. tpu_inference/lora/torch_lora_ops.py +103 -0
  119. tpu_inference/lora/torch_punica_tpu.py +311 -0
  120. tpu_inference/mock/__init__.py +0 -0
  121. tpu_inference/mock/vllm_config_utils.py +28 -0
  122. tpu_inference/mock/vllm_envs.py +1219 -0
  123. tpu_inference/mock/vllm_logger.py +212 -0
  124. tpu_inference/mock/vllm_logging_utils.py +15 -0
  125. tpu_inference/models/__init__.py +0 -0
  126. tpu_inference/models/common/__init__.py +0 -0
  127. tpu_inference/models/common/model_loader.py +444 -0
  128. tpu_inference/models/jax/__init__.py +0 -0
  129. tpu_inference/models/jax/deepseek_v3.py +868 -0
  130. tpu_inference/models/jax/gpt_oss.py +492 -0
  131. tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
  132. tpu_inference/models/jax/llama3.py +375 -0
  133. tpu_inference/models/jax/llama4.py +629 -0
  134. tpu_inference/models/jax/llama_eagle3.py +333 -0
  135. tpu_inference/models/jax/phi3.py +376 -0
  136. tpu_inference/models/jax/qwen2.py +375 -0
  137. tpu_inference/models/jax/qwen2_5_vl.py +1103 -0
  138. tpu_inference/models/jax/qwen3.py +302 -0
  139. tpu_inference/models/jax/utils/__init__.py +0 -0
  140. tpu_inference/models/jax/utils/file_utils.py +96 -0
  141. tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
  142. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  143. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
  144. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
  145. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
  146. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
  147. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
  148. tpu_inference/models/jax/utils/quantization/quantization_utils.py +653 -0
  149. tpu_inference/models/jax/utils/weight_utils.py +529 -0
  150. tpu_inference/models/vllm/__init__.py +0 -0
  151. tpu_inference/models/vllm/vllm_model_wrapper.py +286 -0
  152. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  153. tpu_inference/platforms/__init__.py +2 -0
  154. tpu_inference/platforms/tpu_platform.py +269 -0
  155. tpu_inference/runner/__init__.py +0 -0
  156. tpu_inference/runner/block_table.py +122 -0
  157. tpu_inference/runner/compilation_manager.py +780 -0
  158. tpu_inference/runner/input_batch.py +435 -0
  159. tpu_inference/runner/kv_cache.py +132 -0
  160. tpu_inference/runner/kv_cache_manager.py +479 -0
  161. tpu_inference/runner/lora_utils.py +92 -0
  162. tpu_inference/runner/multimodal_manager.py +217 -0
  163. tpu_inference/runner/persistent_batch_manager.py +244 -0
  164. tpu_inference/runner/speculative_decoding_manager.py +248 -0
  165. tpu_inference/runner/structured_decoding_manager.py +88 -0
  166. tpu_inference/runner/tpu_runner.py +1620 -0
  167. tpu_inference/runner/utils.py +426 -0
  168. tpu_inference/spec_decode/__init__.py +0 -0
  169. tpu_inference/spec_decode/jax/__init__.py +0 -0
  170. tpu_inference/spec_decode/jax/eagle3.py +367 -0
  171. tpu_inference/tpu_info.py +77 -0
  172. tpu_inference/utils.py +317 -0
  173. tpu_inference/worker/__init__.py +0 -0
  174. tpu_inference/worker/tpu_worker.py +321 -0
  175. tpu_inference-0.11.1.dev202511150811.dist-info/METADATA +107 -0
  176. tpu_inference-0.11.1.dev202511150811.dist-info/RECORD +179 -0
  177. tpu_inference-0.11.1.dev202511150811.dist-info/WHEEL +5 -0
  178. tpu_inference-0.11.1.dev202511150811.dist-info/licenses/LICENSE +201 -0
  179. tpu_inference-0.11.1.dev202511150811.dist-info/top_level.txt +2 -0
tpu_inference/models/jax/qwen2_5_vl.py
@@ -0,0 +1,1103 @@
1
+ import math
2
+ from functools import partial
3
+ from typing import (Callable, List, Literal, NamedTuple, Optional, TypedDict,
4
+ Union)
5
+
6
+ import jax
7
+ import jax.numpy as jnp
8
+ import numpy as np
9
+ from flax import nnx
10
+ from jax.sharding import Mesh
11
+ from transformers import modeling_flax_utils
12
+ from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
13
+ Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig)
14
+ from vllm.config import VllmConfig
15
+
16
+ from tpu_inference import utils as utils
17
+ from tpu_inference.layers.common.attention_interface import \
18
+ sharded_flash_attention
19
+ from tpu_inference.layers.common.attention_metadata import AttentionMetadata
20
+ from tpu_inference.logger import init_logger
21
+ from tpu_inference.models.jax.qwen2 import Qwen2ForCausalLM
22
+ # from vllm.model_executor.models.interfaces import MultiModalEmbeddings
23
+ from tpu_inference.models.jax.utils.multi_modal_utils import (
24
+ MultiModalEmbeddings, merge_multimodal_embeddings)
25
+ from tpu_inference.models.jax.utils.weight_utils import (get_default_maps,
26
+ load_hf_weights)
27
+
28
+ logger = init_logger(__name__)
29
+
30
+ init_fn = nnx.initializers.uniform()
31
+
32
+ DEFAULT_BLOCK_K_MAJOR = 128
33
+
34
+
35
+ class SegmentIds(NamedTuple):
36
+ """SegmentIds for Q and KV sequences.
37
+
38
+ SegmentIds are used to generate the segment mask, which prevents attention between
39
+ different segments in the input sequence. Each array is a list of ids
40
+ (integers).
41
+ Only tokens with the same id can attend to each other.
42
+
43
+ Attributes:
44
+ q: segment ids along the Q sequence.
45
+ kv: segment ids along the KV sequence.
46
+ """
47
+
48
+ q: jax.Array # [batch_size, q_seq_len]
49
+ kv: jax.Array # [batch_size, kv_seq_len]
50
+
51
+
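# [Illustrative sketch, not part of the package diff] One way a kernel can turn
# these SegmentIds into an attention mask: tokens may only attend to tokens that
# share their (1-based) segment id, and padding carries id 0 so it never matches
# a real window.
import jax.numpy as jnp

q_ids = jnp.array([[1, 1, 2, 2, 0]])   # [batch, q_seq_len]
kv_ids = jnp.array([[1, 1, 2, 2, 0]])  # [batch, kv_seq_len]
allowed = q_ids[:, :, None] == kv_ids[:, None, :]  # [batch, q_seq_len, kv_seq_len]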
52
+ class Qwen2_5_VLImagePixelInputs(TypedDict):
53
+ type: Literal["pixel_values"]
54
+ pixel_values: jax.Array
55
+ """Shape:
56
+ `(num_patches, num_channels * patch_size * patch_size)`
57
+ """
58
+
59
+ image_grid_thw: tuple[tuple[int, int, int], ...]
60
+ """Shape: `(num_images, 3)`
61
+ This should be in `(grid_t, grid_h, grid_w)` format.
62
+ """
63
+
64
+
65
+ # NOTE: We do not support embedding inputs for now.
66
+ # The code here makes the structure consistent and
67
+ # makes it easier for future implementation.
68
+ class Qwen2_5_VLImageEmbeddingInputs(TypedDict):
69
+ type: Literal["image_embeds"]
70
+ image_embeds: jax.Array
71
+ """Supported types:
72
+ - list[`jax.Array`]: A list of tensors holding all images' features.
73
+ Each tensor holds an image's features.
74
+ - `jax.Array`: A tensor holding all images' features (concatenation of
75
+ all images' feature tensors).
76
+
77
+ Tensor shape: `(num_image_features, hidden_size)`
78
+ - `num_image_features` varies based on
79
+ the number and resolution of the images.
80
+ - `hidden_size` must match the hidden size of language model backbone.
81
+ """
82
+
83
+ image_grid_thw: jax.Array
84
+ """Shape: `(num_images, 3)`
85
+ This should be in `(grid_t, grid_h, grid_w)` format.
86
+ """
87
+
88
+
89
+ Qwen2_5_VLImageInputs = Union[Qwen2_5_VLImagePixelInputs,
90
+ Qwen2_5_VLImageEmbeddingInputs]
91
+
92
+
93
+ class Qwen2_5_VisionMLP(nnx.Module):
94
+
95
+ def __init__(self, config: Qwen2_5_VLVisionConfig, dtype: jnp.dtype,
96
+ rngs: nnx.Rngs):
97
+ in_features = config.hidden_size
98
+ hidden_features = config.intermediate_size
99
+ act_fn = modeling_flax_utils.ACT2FN[config.hidden_act]
100
+ self.gate_proj = nnx.Linear(
101
+ in_features,
102
+ hidden_features,
103
+ use_bias=True,
104
+ param_dtype=dtype,
105
+ kernel_init=nnx.with_partitioning(init_fn, (None, "model")),
106
+ bias_init=nnx.with_partitioning(init_fn, ("model", )),
107
+ rngs=rngs,
108
+ )
109
+ self.up_proj = nnx.Linear(
110
+ in_features,
111
+ hidden_features,
112
+ use_bias=True,
113
+ param_dtype=dtype,
114
+ kernel_init=nnx.with_partitioning(init_fn, (None, "model")),
115
+ bias_init=nnx.with_partitioning(init_fn, ("model", )),
116
+ rngs=rngs,
117
+ )
118
+ self.down_proj = nnx.Linear(
119
+ hidden_features,
120
+ in_features,
121
+ use_bias=True,
122
+ param_dtype=dtype,
123
+ kernel_init=nnx.with_partitioning(init_fn, ("model", None)),
124
+ bias_init=nnx.with_partitioning(init_fn, (None, )),
125
+ rngs=rngs,
126
+ )
127
+ self.act_fn = act_fn
128
+
129
+ def __call__(self, x: jax.Array) -> jax.Array:
130
+ gate = self.act_fn(self.gate_proj(x))
131
+ up = self.up_proj(x)
132
+ fuse = gate * up
133
+ result = self.down_proj(fuse)
134
+ return result
135
+
136
+
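# [Illustrative note, not part of the package diff] Qwen2_5_VisionMLP above is
# the gated (SwiGLU-style) MLP used across Qwen models:
#   y = down_proj(act(gate_proj(x)) * up_proj(x))
# with the activation looked up from config.hidden_act via ACT2FN.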
137
+ def apply_rotary_pos_emb_vision(x: jax.Array,
138
+ rotary_pos_emb: jax.Array) -> jax.Array:
139
+ # x: [B, T, N, H]
140
+ # rotary_pos_emb: [T, H//2]
141
+ _, _, _, H = x.shape
142
+ half_dim = H // 2
143
+
144
+ # [B, T, N, H//2]
145
+ x_real = x[..., :half_dim]
146
+ x_imag = x[..., half_dim:]
147
+
148
+ # [T, H//2]
149
+ cos_emb = jnp.cos(rotary_pos_emb)
150
+ sin_emb = jnp.sin(rotary_pos_emb)
151
+
152
+ # [1, T, 1, H//2]
153
+ cos_emb = cos_emb[None, :, None, :]
154
+ sin_emb = sin_emb[None, :, None, :]
155
+
156
+ # [B, T, N, H//2]
157
+ x_rotated_real = x_real * cos_emb - x_imag * sin_emb
158
+ x_rotated_imag = x_real * sin_emb + x_imag * cos_emb
159
+
160
+ # [B, T, N, H]
161
+ x_rotated = jnp.concatenate([x_rotated_real, x_rotated_imag], axis=-1)
162
+
163
+ return x_rotated
164
+
165
+
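# [Illustrative sketch, not part of the package diff] apply_rotary_pos_emb_vision
# treats the head dim as (real, imag) halves and rotates each pair by the angle
# in rotary_pos_emb, i.e. (x_r + i*x_i) * exp(i*theta). A zero angle is the
# identity rotation, which gives a cheap sanity check (assumes the function
# above is in scope):
import jax.numpy as jnp

x = jnp.ones((1, 4, 2, 8))   # [B, T, N, H]
angles = jnp.zeros((4, 4))   # [T, H // 2], all-zero angles
assert jnp.allclose(apply_rotary_pos_emb_vision(x, angles), x)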
166
+ def generate_window_segment_ids(cu_seqlens: jax.Array, seq_len: int,
167
+ padded_seq_len: int) -> SegmentIds:
168
+ """Generates segment IDs for windowed attention
169
+
170
+ Args:
171
+ cu_seqlens: A 1D array of cumulative sequence lengths for each window.
172
+ e.g., [0, len_win0, len_win0+len_win1, ...]
173
+
174
+ Returns:
175
+ A SegmentIds object for flash_attention.
176
+ """
177
+ indices = jnp.arange(seq_len, dtype=jnp.int32)
178
+ segment_ids = jnp.searchsorted(cu_seqlens[1:], indices, side='right') + 1
179
+ padding_segment_ids = jnp.zeros(padded_seq_len - seq_len, dtype=jnp.int32)
180
+ segment_ids = jnp.concatenate([segment_ids, padding_segment_ids])
181
+ segment_ids = segment_ids.reshape(1, -1)
182
+
183
+ return SegmentIds(q=segment_ids, kv=segment_ids)
184
+
185
+
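# [Illustrative sketch, not part of the package diff] A toy call of
# generate_window_segment_ids (assumes the function above is in scope): two
# windows of lengths 3 and 2, padded out to 8 tokens.
import jax.numpy as jnp

cu = jnp.array([0, 3, 5], dtype=jnp.int32)
seg = generate_window_segment_ids(cu, seq_len=5, padded_seq_len=8)
# seg.q == seg.kv == [[1, 1, 1, 2, 2, 0, 0, 0]]; trailing zeros mark padding.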
186
+ class Qwen2_5_VisionAttention(nnx.Module):
187
+
188
+ def __init__(self, config: Qwen2_5_VLConfig, dtype: jnp.dtype,
189
+ rngs: nnx.Rngs, mesh: Mesh):
190
+ vision_config = config.vision_config
191
+ self.hidden_size = vision_config.hidden_size
192
+ self.num_heads = vision_config.num_heads
193
+ self.num_kv_heads = self.num_heads
194
+ self.rope_theta = config.rope_theta
195
+ self.rope_scaling = getattr(config, "rope_scaling", None)
196
+ self.head_dim_original = self.hidden_size // self.num_heads
197
+
198
+ sharding_size = mesh.shape["model"]
199
+ self.num_heads = utils.get_padded_num_heads(self.num_heads,
200
+ sharding_size)
201
+ self.num_kv_heads = utils.get_padded_num_heads(self.num_kv_heads,
202
+ sharding_size)
203
+ self.head_dim = utils.get_padded_head_dim(self.head_dim_original)
204
+
205
+ # TODO: Wenlong: Do not consider padding for now
206
+ self.head_dim = self.head_dim_original
207
+
208
+ self.mesh = mesh
209
+
210
+ self.qkv_proj = nnx.Linear(
211
+ self.hidden_size,
212
+ 3 * self.hidden_size,
213
+ use_bias=True,
214
+ param_dtype=dtype,
215
+ kernel_init=nnx.with_partitioning(init_fn, (None, "model")),
216
+ bias_init=nnx.with_partitioning(init_fn, ("model", )),
217
+ rngs=rngs,
218
+ )
219
+
220
+ self.proj = nnx.Linear(
221
+ self.hidden_size,
222
+ self.hidden_size,
223
+ use_bias=True,
224
+ param_dtype=dtype,
225
+ kernel_init=nnx.with_partitioning(init_fn, ("model", None)),
226
+ bias_init=nnx.with_partitioning(init_fn, (None, )),
227
+ rngs=rngs,
228
+ )
229
+ self.flash_attention = sharded_flash_attention(
230
+ mesh=mesh,
231
+ causal=False,
232
+ sm_scale=1.0 / math.sqrt(self.head_dim),
233
+ vmem_limit_bytes=128 * 1024 * 1024,
234
+ )
235
+
236
+ def __call__(
237
+ self,
238
+ x: jax.Array,
239
+ rotary_pos_emb: jax.Array,
240
+ cu_window_seqlens: Optional[jax.Array] = None,
241
+ use_fullattn: bool = True,
242
+ ) -> jax.Array:
243
+ T, B, D = x.shape
244
+ assert B == 1, "Vision attention currently only supports batch size 1"
245
+ # [T, B, D] -> [T, B, 3 * D]
246
+ qkv = self.qkv_proj(x)
247
+
248
+ # Split into Q, K, V.
249
+ # NOTE: simplified from vLLM's split_qkv,
250
+ # may need to revisit for tp>1
251
+ # [T, B, 3 * D] -> 3 *[T, B, D]
252
+ q, k, v = jnp.split(qkv, 3, axis=-1)
253
+
254
+ # [T, B, N, H]
255
+ q = q.reshape(T, B, self.num_heads, self.head_dim)
256
+ k = k.reshape(T, B, self.num_heads, self.head_dim)
257
+ v = v.reshape(T, B, self.num_heads, self.head_dim)
258
+
259
+ # [T, B, N, H] -> [B, T, N, H]
260
+ q = jnp.transpose(q, (1, 0, 2, 3))
261
+ k = jnp.transpose(k, (1, 0, 2, 3))
262
+ v = jnp.transpose(v, (1, 0, 2, 3))
263
+
264
+ # rotary_pos_emb shape: (T, H)
265
+ q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
266
+ k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)
267
+
268
+ # NOTE: an extra transpose because we need to
269
+ # align the correctness with vLLM's design.
270
+ # Might be able to remove one once implemented.
271
+ # [B, T, N, H] -> [B, N, T, H]
272
+ q = jnp.transpose(q, (0, 2, 1, 3))
273
+ k = jnp.transpose(k, (0, 2, 1, 3))
274
+ v = jnp.transpose(v, (0, 2, 1, 3))
275
+
276
+ # Pad the sequence length to be a multiple of 128 for flash_attention
277
+ block_k_major = DEFAULT_BLOCK_K_MAJOR
278
+ T_attn = q.shape[2]
279
+ padded_T = (T_attn + block_k_major -
280
+ 1) // block_k_major * block_k_major
281
+ pad_width = ((0, 0), (0, 0), (0, padded_T - T_attn), (0, 0))
282
+
283
+ q = jnp.pad(q, pad_width, 'constant')
284
+ k = jnp.pad(k, pad_width, 'constant')
285
+ v = jnp.pad(v, pad_width, 'constant')
286
+
287
+ segment_ids = generate_window_segment_ids(cu_window_seqlens, T_attn,
288
+ padded_T)
289
+
290
+ # TODO (jacobplatin): add support for quantized KV cache?
291
+ output = self.flash_attention(q, k, v, segment_ids)
292
+
293
+ # Unpad the output
294
+ output = output[:, :, :T_attn, :]
295
+
296
+ # [B, N, T, H] -> [T, B, N, H]
297
+ output = jnp.transpose(output, (2, 0, 1, 3))
298
+
299
+ output = output.reshape(T, B, D)
300
+
301
+ output = self.proj(output)
302
+
303
+ return output
304
+
305
+
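# [Illustrative sketch, not part of the package diff] The attention above pads
# the token count up to the next multiple of DEFAULT_BLOCK_K_MAJOR (128) before
# calling flash attention; _round_up below is a hypothetical helper showing the
# same arithmetic.
def _round_up(t: int, block: int = 128) -> int:
    return (t + block - 1) // block * block

assert _round_up(300) == 384   # 300 tokens are padded to 384
assert _round_up(128) == 128   # already aligned, no padding added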
306
+ class Qwen2_5_VisionBlock(nnx.Module):
307
+
308
+ def __init__(self, config: Qwen2_5_VLConfig, dtype: jnp.dtype,
309
+ rngs: nnx.Rngs, mesh: Mesh):
310
+ vision_config = config.vision_config
311
+ dim = vision_config.hidden_size
312
+ norm_layer = partial(nnx.RMSNorm,
313
+ epsilon=config.rms_norm_eps,
314
+ scale_init=nnx.with_partitioning(
315
+ init_fn, (None, )))
316
+
317
+ self.norm1 = norm_layer(dim, dtype=dtype, rngs=rngs)
318
+ self.norm2 = norm_layer(dim, dtype=dtype, rngs=rngs)
319
+ self.attn = Qwen2_5_VisionAttention(config=config,
320
+ dtype=dtype,
321
+ rngs=rngs,
322
+ mesh=mesh)
323
+ self.mlp = Qwen2_5_VisionMLP(config=vision_config,
324
+ dtype=dtype,
325
+ rngs=rngs)
326
+
327
+ def __call__(self,
328
+ x: jax.Array,
329
+ rotary_pos_emb: jax.Array,
330
+ cu_window_seqlens: Optional[jax.Array] = None,
331
+ use_fullattn: bool = True) -> jax.Array:
332
+
333
+ x = x + self.attn(self.norm1(x), rotary_pos_emb, cu_window_seqlens,
334
+ use_fullattn)
335
+ x = x + self.mlp(self.norm2(x))
336
+
337
+ return x
338
+
339
+
340
+ class Qwen2_5_VisionPatchEmbed(nnx.Module):
341
+
342
+ def __init__(
343
+ self,
344
+ rngs: nnx.Rngs,
345
+ patch_size: int = 14,
346
+ temporal_patch_size: int = 2,
347
+ in_channels: int = 3,
348
+ hidden_size: int = 1152,
349
+ dtype: jnp.dtype = jnp.bfloat16,
350
+ ) -> None:
351
+ self.patch_size = patch_size
352
+ self.temporal_patch_size = temporal_patch_size
353
+ self.hidden_size = hidden_size
354
+ kernel_size = (temporal_patch_size, patch_size, patch_size)
355
+ self.proj = nnx.Conv(in_features=in_channels,
356
+ out_features=hidden_size,
357
+ kernel_size=kernel_size,
358
+ strides=kernel_size,
359
+ use_bias=False,
360
+ param_dtype=dtype,
361
+ kernel_init=nnx.with_partitioning(
362
+ init_fn, (None, None, None, None, "model")),
363
+ rngs=rngs)
364
+
365
+ def __call__(self, x: jax.Array) -> jax.Array:
366
+ # x is (L, C * T * H * W)
367
+ L, dim = x.shape
368
+ C = dim // (self.temporal_patch_size * self.patch_size *
369
+ self.patch_size)
370
+ # Reshape to (L, T, H, W, C) for Conv3D with channels_last
371
+ x = x.reshape(L, C, self.temporal_patch_size, self.patch_size,
372
+ self.patch_size)
373
+ # L,T,H,W,C
374
+ x = jnp.transpose(x, (0, 2, 3, 4, 1))
375
+ x = self.proj(x)
376
+ # After conv, shape is (L, T_out, H_out, W_out, C_out)
377
+ # With stride=kernel_size, T_out=H_out=W_out=1.
378
+ # So shape is (L, 1, 1, 1, hidden_size)
379
+ x = x.reshape(L, self.hidden_size)
380
+ return x
381
+
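# [Illustrative sketch, not part of the package diff] Expected patch-embed
# shapes for the defaults above (patch_size=14, temporal_patch_size=2,
# in_channels=3): each flattened patch carries 3 * 2 * 14 * 14 = 1176 values.
# hidden_size=1280 is just an assumed example value.
from flax import nnx
import jax.numpy as jnp

embed = Qwen2_5_VisionPatchEmbed(rngs=nnx.Rngs(0), hidden_size=1280)
patches = jnp.ones((100, 3 * 2 * 14 * 14), dtype=jnp.bfloat16)
out = embed(patches)   # -> shape (100, 1280)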
382
+
383
+ class Qwen2_5_VisionPatchMerger(nnx.Module):
384
+
385
+ def __init__(self, d_model: int, context_dim: int, norm_layer: Callable,
386
+ spatial_merge_size: int, dtype: jnp.dtype, rngs: nnx.Rngs):
387
+ self.hidden_size = context_dim * (spatial_merge_size**2)
388
+ self.ln_q = norm_layer(context_dim,
389
+ dtype=dtype,
390
+ rngs=rngs,
391
+ scale_init=nnx.with_partitioning(
392
+ init_fn, (None, )))
393
+ self.mlp_fc1 = nnx.Linear(
394
+ self.hidden_size,
395
+ self.hidden_size,
396
+ use_bias=True,
397
+ param_dtype=dtype,
398
+ kernel_init=nnx.with_partitioning(init_fn, (None, "model")),
399
+ bias_init=nnx.with_partitioning(init_fn, ("model", )),
400
+ rngs=rngs)
401
+ self.mlp_act = modeling_flax_utils.ACT2FN["gelu"]
402
+ self.mlp_fc2 = nnx.Linear(
403
+ self.hidden_size,
404
+ d_model,
405
+ use_bias=True,
406
+ param_dtype=dtype,
407
+ kernel_init=nnx.with_partitioning(init_fn, ("model", None)),
408
+ bias_init=nnx.with_partitioning(init_fn, (None, )),
409
+ rngs=rngs)
410
+
411
+ def __call__(self, x: jax.Array) -> jax.Array:
412
+ x = self.ln_q(x)
413
+ x = x.reshape(-1, self.hidden_size)
414
+ x = self.mlp_fc1(x)
415
+ x = self.mlp_act(x)
416
+ x = self.mlp_fc2(x)
417
+ return x
418
+
419
+
420
+ class Qwen2_5_VisionRotaryEmbedding(nnx.Module):
421
+
422
+ def __init__(self, dim: int, theta: float = 10000.0):
423
+ self.dim = dim
424
+ self.theta = theta
425
+
426
+ def __call__(self, seqlen: int) -> jax.Array:
427
+ inv_freq = 1.0 / (self.theta**(
428
+ jnp.arange(0, self.dim, 2, dtype=jnp.float32) / self.dim))
429
+ seq = jnp.arange(seqlen, dtype=jnp.float32)
430
+ freqs = jnp.outer(seq, inv_freq)
431
+ return freqs.astype(jnp.bfloat16)
432
+
433
+
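# [Illustrative sketch, not part of the package diff] The rotary table above is
# freqs[p, i] = p / theta ** (2 * i / dim) for positions p and i < dim // 2,
# i.e. the standard RoPE frequencies (assumes the class above is in scope).
rope = Qwen2_5_VisionRotaryEmbedding(dim=8)
table = rope(4)
assert table.shape == (4, 4)   # (seqlen, dim // 2)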
434
+ class Qwen2_5_VisionTransformer(nnx.Module):
435
+
436
+ def __init__(self,
437
+ vllm_config: VllmConfig,
438
+ rngs: nnx.Rngs,
439
+ mesh: Mesh,
440
+ norm_eps: float = 1e-6):
441
+ model_config = vllm_config.model_config
442
+ hf_config = model_config.hf_config
443
+ vision_config = hf_config.vision_config
444
+ dtype = model_config.dtype
445
+
446
+ self.config = vision_config
447
+ self.dtype = dtype
448
+
449
+ patch_size = vision_config.patch_size
450
+ temporal_patch_size = vision_config.temporal_patch_size
451
+ in_channels = vision_config.in_channels
452
+ self.hidden_size = vision_config.hidden_size
453
+ self.num_heads = vision_config.num_heads
454
+
455
+ # args for get_window_index_thw
456
+ self.window_size = vision_config.window_size
457
+ self.patch_size = vision_config.patch_size
458
+ self.spatial_merge_size = vision_config.spatial_merge_size
459
+ self.fullatt_block_indexes = vision_config.fullatt_block_indexes
460
+ self.spatial_merge_unit = self.spatial_merge_size**2
461
+
462
+ self.patch_embed = Qwen2_5_VisionPatchEmbed(
463
+ patch_size=patch_size,
464
+ temporal_patch_size=temporal_patch_size,
465
+ in_channels=in_channels,
466
+ hidden_size=self.hidden_size,
467
+ dtype=dtype,
468
+ rngs=rngs)
469
+
470
+ head_dim = vision_config.hidden_size // vision_config.num_heads
471
+ self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
472
+
473
+ self.blocks = [
474
+ Qwen2_5_VisionBlock(
475
+ config=hf_config,
476
+ dtype=dtype,
477
+ rngs=rngs,
478
+ mesh=mesh,
479
+ ) for _ in range(vision_config.depth)
480
+ ]
481
+ self.merger = Qwen2_5_VisionPatchMerger(
482
+ d_model=vision_config.out_hidden_size,
483
+ context_dim=vision_config.hidden_size,
484
+ norm_layer=partial(nnx.RMSNorm, epsilon=norm_eps),
485
+ spatial_merge_size=vision_config.spatial_merge_size,
486
+ dtype=dtype,
487
+ rngs=rngs)
488
+
489
+ def rotary_pos_emb_thw(self, t, h, w):
490
+ hpos_ids, wpos_ids = jnp.indices((h, w))
491
+ hpos_ids = hpos_ids.reshape(
492
+ h // self.spatial_merge_size,
493
+ self.spatial_merge_size,
494
+ w // self.spatial_merge_size,
495
+ self.spatial_merge_size,
496
+ ).transpose(0, 2, 1, 3).flatten()
497
+ wpos_ids = wpos_ids.reshape(
498
+ h // self.spatial_merge_size,
499
+ self.spatial_merge_size,
500
+ w // self.spatial_merge_size,
501
+ self.spatial_merge_size,
502
+ ).transpose(0, 2, 1, 3).flatten()
503
+ pos_ids = jnp.stack([hpos_ids, wpos_ids], axis=-1)
504
+ pos_ids = jnp.tile(pos_ids, (t, 1))
505
+
506
+ max_size = max(h, w)
507
+ rotary_pos_emb_full = self.rotary_pos_emb(max_size)
508
+
509
+ rotary_pos_emb = rotary_pos_emb_full[pos_ids].reshape(
510
+ pos_ids.shape[0], -1)
511
+ rotary_pos_emb = rotary_pos_emb.reshape(
512
+ rotary_pos_emb.shape[0] // self.spatial_merge_unit,
513
+ self.spatial_merge_unit, -1)
514
+
515
+ return rotary_pos_emb
516
+
517
+ def get_window_index_thw(self, grid_t, grid_h, grid_w):
518
+ vit_merger_window_size = (self.window_size //
519
+ self.spatial_merge_size // self.patch_size)
520
+
521
+ llm_grid_h = grid_h // self.spatial_merge_size
522
+ llm_grid_w = grid_w // self.spatial_merge_size
523
+
524
+ index = jnp.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
525
+ grid_t, llm_grid_h, llm_grid_w)
526
+
527
+ pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
528
+ pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
529
+ num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
530
+ num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
531
+
532
+ index_padded = jnp.pad(index, ((0, 0), (0, pad_h), (0, pad_w)),
533
+ constant_values=-100)
534
+ index_padded = index_padded.reshape(grid_t, num_windows_h,
535
+ vit_merger_window_size,
536
+ num_windows_w,
537
+ vit_merger_window_size)
538
+ index_padded = jnp.transpose(index_padded, (0, 1, 3, 2, 4)).reshape(
539
+ grid_t, num_windows_h * num_windows_w, vit_merger_window_size,
540
+ vit_merger_window_size)
541
+ seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
542
+ index_padded = index_padded.reshape(-1)
543
+ # The number of valid indices is static because grid_t, grid_h, grid_w
544
+ # are static.
545
+ num_valid_indices = grid_t * llm_grid_h * llm_grid_w
546
+ valid_indices = jnp.nonzero(index_padded != -100,
547
+ size=num_valid_indices)[0]
548
+ index_new = index_padded[valid_indices]
549
+ cu_seqlens_tmp = jnp.cumsum(seqlens) * self.spatial_merge_unit
550
+ cu_seqlens_tmp = cu_seqlens_tmp.astype(jnp.int32)
551
+
552
+ # NOTE (wenlong): The PyTorch code uses this to reduce replication,
553
+ # but I don't think there is a need here, plus it would cause problems under JIT.
554
+ # Please refer back here if there is a problem downstream:
555
+ # cu_seqlens_tmp = jnp.unique(cu_seqlens_tmp)
556
+
557
+ return index_new, cu_seqlens_tmp
558
+
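# [Illustrative note, not part of the package diff] For a single 4x4 patch grid
# with the usual Qwen2.5-VL values (patch_size=14, spatial_merge_size=2,
# window_size=112, so vit_merger_window_size=4), the llm grid is 2x2 and fits in
# one window: get_window_index_thw returns index_new == [0, 1, 2, 3] and
# cu_seqlens_tmp == [16] (4 merged positions * spatial_merge_unit of 4 patches).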
559
+ def get_rope_by_thw(self, t, h, w):
560
+ window_index_thw, cu_seqlens_window_thw = self.get_window_index_thw(
561
+ t, h, w)
562
+
563
+ rotary_pos_emb_thw = self.rotary_pos_emb_thw(t, h, w)
564
+
565
+ rotary_pos_emb_thw = rotary_pos_emb_thw[window_index_thw, :, :]
566
+ rotary_pos_emb_thw = rotary_pos_emb_thw.reshape(
567
+ -1, rotary_pos_emb_thw.shape[-1])
568
+ cu_seqlens_thw = jnp.full(t, h * w, dtype=jnp.int32)
569
+
570
+ return (rotary_pos_emb_thw, window_index_thw, cu_seqlens_window_thw,
571
+ cu_seqlens_thw)
572
+
573
+ def compute_attn_mask_seqlen(
574
+ self,
575
+ cu_seqlens: jax.Array,
576
+ ) -> tuple[Optional[int], Optional[list[int]]]:
577
+ max_seqlen, seqlens = None, None
578
+ max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
579
+ seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
580
+ return max_seqlen, seqlens
581
+
582
+ def __call__(self, x: jax.Array, grid_thw: tuple[tuple[int, int,
583
+ int]]) -> jax.Array:
584
+ # x: pixel_values: jax.Array
585
+ # """Shape:
586
+ # `(num_patches, num_channels * patch_size * patch_size)`
587
+ # """
588
+
589
+ # grid_thw: image_grid_thw: jax.Array
590
+ # """Shape: `(num_images, 3)`
591
+ # This should be in `(grid_t, grid_h, grid_w)` format.
592
+ # """
593
+ hidden_states = self.patch_embed(x)
594
+
595
+ # num of patches
596
+ seq_len = x.shape[0]
597
+ # num of images/videos
598
+ num_grids = len(grid_thw)
599
+
600
+ rotary_pos_emb = []
601
+ window_index: list = []
602
+ cu_window_seqlens: list = [jnp.array([0], dtype=jnp.int32)]
603
+ cu_seqlens: list = []
604
+
605
+ window_index_id = 0
606
+ cu_window_seqlens_last = 0
607
+ for i in range(num_grids):
608
+ t, h, w = grid_thw[i]
609
+
610
+ llm_h = h // self.spatial_merge_size
611
+ llm_w = w // self.spatial_merge_size
612
+
613
+ (
614
+ rotary_pos_emb_thw,
615
+ window_index_thw,
616
+ cu_seqlens_window_thw,
617
+ cu_seqlens_thw,
618
+ ) = self.get_rope_by_thw(t, h, w)
619
+
620
+ window_index.append(window_index_thw + window_index_id)
621
+ window_index_id += (t * llm_h * llm_w)
622
+
623
+ cu_seqlens_window_thw = (cu_seqlens_window_thw +
624
+ cu_window_seqlens_last)
625
+ cu_window_seqlens_last = cu_seqlens_window_thw[-1]
626
+ cu_window_seqlens.append(cu_seqlens_window_thw)
627
+
628
+ rotary_pos_emb.append(rotary_pos_emb_thw)
629
+
630
+ cu_seqlens.append(cu_seqlens_thw)
631
+
632
+ rotary_pos_emb = jnp.concatenate(rotary_pos_emb, axis=0)
633
+ window_index = jnp.concatenate(window_index, axis=0)
634
+ cu_window_seqlens = jnp.concatenate(cu_window_seqlens, axis=0)
635
+
636
+ cu_seqlens = jnp.concatenate(cu_seqlens, axis=0)
637
+ cu_seqlens = jnp.cumsum(cu_seqlens, axis=0, dtype=jnp.int32)
638
+ cu_seqlens = jnp.pad(cu_seqlens, ((1, 0), ),
639
+ mode='constant',
640
+ constant_values=0)
641
+
642
+ hidden_states = hidden_states.reshape(
643
+ seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
644
+ hidden_states = hidden_states[window_index, :, :]
645
+ hidden_states = hidden_states.reshape(seq_len, -1)
646
+
647
+ hidden_states = jnp.expand_dims(hidden_states, axis=1)
648
+
649
+ for layer_num, blk in enumerate(self.blocks):
650
+ if layer_num in self.fullatt_block_indexes:
651
+ hidden_states = blk(hidden_states,
652
+ rotary_pos_emb=rotary_pos_emb,
653
+ cu_window_seqlens=cu_seqlens,
654
+ use_fullattn=True)
655
+ else:
656
+ hidden_states = blk(hidden_states,
657
+ rotary_pos_emb=rotary_pos_emb,
658
+ cu_window_seqlens=cu_window_seqlens,
659
+ use_fullattn=False)
660
+
661
+ # adapter
662
+ hidden_states = self.merger(hidden_states)
663
+ reverse_indices = jnp.argsort(window_index)
664
+ hidden_states = hidden_states[reverse_indices, :]
665
+ return hidden_states
666
+
667
+
668
+ class Qwen2_5_VLForConditionalGeneration(nnx.Module):
669
+
670
+ def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array,
671
+ mesh: Mesh) -> None:
672
+ config: Qwen2_5_VLConfig = vllm_config.model_config.hf_config
673
+ multimodal_config = vllm_config.model_config.multimodal_config
674
+
675
+ self.vllm_config = vllm_config
676
+ self.rng = nnx.Rngs(rng_key)
677
+ self.mesh = mesh
678
+
679
+ self.config = config
680
+ self.multimodal_config = multimodal_config
681
+
682
+ self.visual = Qwen2_5_VisionTransformer(
683
+ vllm_config=vllm_config,
684
+ rngs=self.rng,
685
+ mesh=mesh,
686
+ norm_eps=getattr(config, "rms_norm_eps", 1e-6),
687
+ )
688
+ self.language_model = Qwen2ForCausalLM(vllm_config, rng_key, mesh)
689
+
690
+ def get_mrope_input_positions(
691
+ self,
692
+ input_tokens: list[int],
693
+ hf_config,
694
+ image_grid_thw,
695
+ video_grid_thw,
696
+ second_per_grid_ts: list[float],
697
+ context_len: int = 0,
698
+ seq_len: int | None = None,
699
+ audio_feature_lengths=None,
700
+ use_audio_in_video: bool = False,
701
+ ) -> tuple[jax.Array, int]:
702
+ """Get mrope input positions and delta value."""
703
+
704
+ image_token_id = hf_config.image_token_id
705
+ video_token_id = hf_config.video_token_id
706
+ vision_start_token_id = hf_config.vision_start_token_id
707
+ spatial_merge_size = hf_config.vision_config.spatial_merge_size
708
+ tokens_per_second = getattr(hf_config.vision_config,
709
+ "tokens_per_second", 1.0)
710
+
711
+ input_tokens_tensor = np.array(input_tokens)
712
+ vision_start_indices = np.argwhere(
713
+ input_tokens_tensor == vision_start_token_id).squeeze(1)
714
+ vision_tokens = input_tokens_tensor[vision_start_indices + 1]
715
+ image_nums = np.sum(vision_tokens == image_token_id)
716
+ video_nums = np.sum(vision_tokens == video_token_id)
717
+ llm_pos_ids_list: list = []
718
+
719
+ st = 0
720
+ remain_images, remain_videos = image_nums, video_nums
721
+
722
+ image_index, video_index = 0, 0
723
+ for _ in range(image_nums + video_nums):
724
+ video_second_per_grid_t = 0.0
725
+ if remain_images > 0:
726
+ try:
727
+ ed_image = input_tokens.index(image_token_id, st)
728
+ except ValueError:
729
+ ed_image = len(input_tokens) + 1
730
+ else:
731
+ ed_image = len(input_tokens) + 1
732
+ if remain_videos > 0:
733
+ try:
734
+ ed_video = input_tokens.index(video_token_id, st)
735
+ except ValueError:
736
+ ed_video = len(input_tokens) + 1
737
+ else:
738
+ ed_video = len(input_tokens) + 1
739
+ if ed_image < ed_video:
740
+ t, h, w = (
741
+ image_grid_thw[image_index][0],
742
+ image_grid_thw[image_index][1],
743
+ image_grid_thw[image_index][2],
744
+ )
745
+ image_index += 1
746
+ remain_images -= 1
747
+ ed = ed_image
748
+ else:
749
+ t, h, w = (
750
+ video_grid_thw[video_index][0],
751
+ video_grid_thw[video_index][1],
752
+ video_grid_thw[video_index][2],
753
+ )
754
+ video_second_per_grid_t = 1.0
755
+ if second_per_grid_ts:
756
+ video_second_per_grid_t = second_per_grid_ts[video_index]
757
+ video_index += 1
758
+ remain_videos -= 1
759
+ ed = ed_video
760
+
761
+ llm_grid_t, llm_grid_h, llm_grid_w = (
762
+ t,
763
+ h // spatial_merge_size,
764
+ w // spatial_merge_size,
765
+ )
766
+ text_len = ed - st
767
+
768
+ st_idx = llm_pos_ids_list[-1].max().item() + 1 if len(
769
+ llm_pos_ids_list) > 0 else 0
770
+ llm_pos_ids_list.append(
771
+ jnp.broadcast_to(
772
+ jnp.arange(text_len, dtype=jnp.int32).reshape(1, -1),
773
+ (3, text_len)) + st_idx)
774
+
775
+ t_index = ((jnp.broadcast_to(
776
+ jnp.arange(llm_grid_t, dtype=jnp.int32).reshape(-1, 1),
777
+ (llm_grid_t, llm_grid_h * llm_grid_w)) *
778
+ video_second_per_grid_t * tokens_per_second).astype(
779
+ jnp.int32).flatten())
780
+
781
+ h_index = (jnp.broadcast_to(
782
+ jnp.arange(llm_grid_h, dtype=jnp.int32).reshape(1, -1, 1),
783
+ (llm_grid_t, llm_grid_h, llm_grid_w)).flatten())
784
+ w_index = (jnp.broadcast_to(
785
+ jnp.arange(llm_grid_w, dtype=jnp.int32).reshape(1, 1, -1),
786
+ (llm_grid_t, llm_grid_h, llm_grid_w)).flatten())
787
+
788
+ llm_pos_ids_list.append(
789
+ jnp.stack([t_index, h_index, w_index]) + text_len + st_idx)
790
+ st = ed + llm_grid_t * llm_grid_h * llm_grid_w
791
+
792
+ if st < len(input_tokens):
793
+ st_idx = llm_pos_ids_list[-1].max().item() + 1 if len(
794
+ llm_pos_ids_list) > 0 else 0
795
+ text_len = len(input_tokens) - st
796
+
797
+ llm_pos_ids_list.append(
798
+ jnp.broadcast_to(
799
+ jnp.arange(text_len, dtype=jnp.int32).reshape(1, -1),
800
+ (3, text_len)) + st_idx)
801
+
802
+ llm_positions = jnp.concatenate(llm_pos_ids_list,
803
+ axis=1).reshape(3, -1)
804
+ mrope_position_delta = (llm_positions.max() + 1 -
805
+ len(input_tokens)).item()
806
+ llm_positions = llm_positions[:, context_len:seq_len]
807
+
808
+ return llm_positions, mrope_position_delta
809
+
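# [Illustrative note, not part of the package diff] get_mrope_input_positions
# returns a (3, seq_len) array of [temporal, height, width] positions: plain
# text tokens advance all three axes together, while each image/video block
# spreads its positions over its (t, h, w) grid. mrope_position_delta is
# llm_positions.max() + 1 - len(input_tokens), i.e. how far the position counter
# has run ahead of the token count, so later decode steps can stay aligned.
# A hypothetical call (names are placeholders, not from this file):
# positions, delta = model.get_mrope_input_positions(
#     input_tokens=prompt_ids, hf_config=hf_cfg,
#     image_grid_thw=[(1, 8, 8)], video_grid_thw=[], second_per_grid_ts=[])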
810
+ def _validate_and_reshape_mm_tensor(self, mm_input: object,
811
+ name: str) -> jax.Array:
812
+ if isinstance(mm_input, list):
813
+ # Assuming it's a list of arrays (e.g., np.ndarray, torch.Tensor)
814
+ # that can be concatenated.
815
+ arrays_to_concat = [jnp.asarray(item) for item in mm_input]
816
+ return jnp.concatenate(arrays_to_concat, axis=0)
817
+
818
+ # Handle single array-like objects (np.ndarray, torch.Tensor, jax.Array)
819
+ if hasattr(mm_input, 'ndim'):
820
+ array_input = jnp.asarray(mm_input)
821
+ if array_input.ndim == 2:
822
+ return array_input
823
+ if array_input.ndim == 3:
824
+ # This reshapes the batched 3D tensor to a 2D tensor.
825
+ return array_input.reshape(-1, array_input.shape[-1])
826
+
827
+ raise ValueError(f"Incorrect type of {name}. "
828
+ f"Got type: {type(mm_input)}")
829
+
830
+ def _parse_and_validate_image_input(
831
+ self, image_grid_thw: tuple[tuple[int, int, int], ...],
832
+ **kwargs: object) -> Optional[Qwen2_5_VLImageInputs]:
833
+ pixel_values = kwargs.pop("pixel_values", None)
834
+ image_embeds = kwargs.pop("image_embeds", None)
835
+ # image_grid_thw = kwargs.pop("image_grid_thw", None)
836
+
837
+ if pixel_values is None and image_embeds is None:
838
+ return None
839
+
840
+ if pixel_values is not None:
841
+ pixel_values = self._validate_and_reshape_mm_tensor(
842
+ pixel_values, "image pixel values")
843
+ # image_grid_thw = self._validate_and_reshape_mm_tensor(
844
+ # image_grid_thw, "image grid_thw")
845
+
846
+ if not isinstance(pixel_values, jax.Array):
847
+ raise ValueError("Incorrect type of image pixel values. "
848
+ f"Got type: {type(pixel_values)}")
849
+
850
+ return Qwen2_5_VLImagePixelInputs(type="pixel_values",
851
+ pixel_values=pixel_values,
852
+ image_grid_thw=image_grid_thw)
853
+
854
+ # Note: comment them out for now and save for future support
855
+ # if image_embeds is not None:
856
+ # image_embeds = self._validate_and_reshape_mm_tensor(
857
+ # image_embeds, "image embeds")
858
+ # image_grid_thw = self._validate_and_reshape_mm_tensor(
859
+ # image_grid_thw, "image grid_thw")
860
+
861
+ # if not isinstance(image_embeds, jax.Array):
862
+ # raise ValueError("Incorrect type of image embeddings. "
863
+ # f"Got type: {type(image_embeds)}")
864
+ # return Qwen2_5_VLImageEmbeddingInputs(
865
+ # type="image_embeds",
866
+ # image_embeds=image_embeds,
867
+ # image_grid_thw=image_grid_thw)
868
+
869
+ def _parse_and_validate_multimodal_inputs(self,
870
+ image_grid_thw: tuple[tuple[int,
871
+ int,
872
+ int],
873
+ ...],
874
+ **kwargs: object) -> dict:
875
+ mm_input_by_modality = {}
876
+
877
+ # Preserve the order of modalities if there are multiple of them
878
+ # from the order of kwargs.
879
+ for input_key in kwargs:
880
+ if input_key in ("pixel_values", "image_embeds"
881
+ ) and "image" not in mm_input_by_modality:
882
+ mm_input_by_modality[
883
+ "image"] = self._parse_and_validate_image_input(
884
+ image_grid_thw, **kwargs)
885
+ # if input_key in ("pixel_values_videos", "video_embeds"
886
+ # ) and "video" not in mm_input_by_modality:
887
+ # mm_input_by_modality[
888
+ # "video"] = self._parse_and_validate_video_input(**kwargs)
889
+ return mm_input_by_modality
890
+
891
+ @partial(
892
+ jax.jit,
893
+ static_argnames=("image_grid_thw", ),
894
+ )
895
+ def get_single_image_embedding(self, image_pixel_values, image_grid_thw):
896
+ return self.visual(image_pixel_values, (image_grid_thw, ))
897
+
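# [Illustrative note, not part of the package diff] image_grid_thw is declared a
# static jit argument above, so it must be a hashable (t, h, w) tuple and every
# distinct grid shape triggers a fresh trace/compile of the vision tower;
# precompile_vision_encoder further below warms these shapes up ahead of time.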
898
+ def _process_image_input(
899
+ self, image_input: Qwen2_5_VLImageInputs) -> tuple[jax.Array, ...]:
900
+
901
+ grid_thw = image_input["image_grid_thw"]
902
+
903
+ if image_input["type"] == "image_embeds":
904
+ image_embeds = image_input["image_embeds"].astype(
905
+ self.visual.dtype)
906
+ else:
907
+ pixel_values = image_input["pixel_values"]
908
+ image_embeds = []
909
+ current_idx = 0
910
+ for image_thw in grid_thw:
911
+ t, h, w = image_thw
912
+ image_size = t * h * w
913
+ end_idx = current_idx + image_size
914
+ image_pixel_values = pixel_values[current_idx:end_idx, :]
915
+ image_embeds.append(
916
+ self.get_single_image_embedding(image_pixel_values,
917
+ image_thw))
918
+ current_idx = end_idx
919
+ image_embeds = jnp.concatenate(image_embeds, axis=0)
920
+
921
+ # Split concatenated embeddings for each image item.
922
+ merge_size = self.visual.config.spatial_merge_size
923
+ sizes = np.prod(np.array(grid_thw, dtype=np.int64),
924
+ axis=-1) // merge_size // merge_size
925
+
926
+ if sizes.size == 0:
927
+ return ()
928
+ if sizes.size == 1:
929
+ return (image_embeds, )
930
+
931
+ split_indices = np.cumsum(sizes)[:-1]
932
+ return tuple(jnp.split(image_embeds, split_indices))
933
+
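# [Illustrative sketch, not part of the package diff] The per-image split sizes
# computed above, for an assumed spatial_merge_size of 2.
import numpy as np

grid_thw = np.array([(1, 8, 8), (1, 4, 6)])    # (t, h, w) per image
sizes = np.prod(grid_thw, axis=-1) // 2 // 2   # -> [16, 6]
# 64 and 24 patches collapse 4-to-1 in the patch merger, so the two images
# contribute 16 and 6 embedding rows respectively.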
934
+ def get_multimodal_embeddings(self, image_grid_thw: tuple[tuple[int, int,
935
+ int], ...],
936
+ **kwargs: object) -> MultiModalEmbeddings:
937
+
938
+ mm_input_by_modality = self._parse_and_validate_multimodal_inputs(
939
+ image_grid_thw, **kwargs)
940
+ if not mm_input_by_modality:
941
+ return []
942
+
943
+ # The resulting multimodal_embeddings is a tuple of tensors, with each
944
+ # tensor corresponding to a multimodal data item (image or video).
945
+ multimodal_embeddings: tuple[jax.Array, ...] = ()
946
+
947
+ # NOTE: It is important to iterate over the keys in this dictionary
948
+ # to preserve the order of the modalities.
949
+ for modality in mm_input_by_modality:
950
+ multimodal_input = mm_input_by_modality[modality]
951
+ if modality == "image":
952
+ vision_embeddings = self._process_image_input(multimodal_input)
953
+ multimodal_embeddings += vision_embeddings
954
+ # if modality == "video":
955
+ # video_embeddings = self._process_video_input(multimodal_input)
956
+ # multimodal_embeddings += video_embeddings
957
+
958
+ return multimodal_embeddings
959
+
960
+ def get_input_embeddings(
961
+ self, input_ids: jax.Array,
962
+ multimodal_embeddings: Optional[jax.Array]) -> jax.Array:
963
+
964
+ inputs_embeds = self.language_model.model.embed(input_ids)
965
+
966
+
967
+ if multimodal_embeddings is not None \
968
+ and multimodal_embeddings.shape[0] != 0:
969
+ inputs_embeds = merge_multimodal_embeddings(
970
+ input_ids, inputs_embeds, multimodal_embeddings,
971
+ [self.config.image_token_id, self.config.video_token_id])
972
+
973
+ return inputs_embeds
974
+
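# [Assumption, not part of the package diff] merge_multimodal_embeddings is
# imported from multi_modal_utils and not shown in this diff; judging from its
# arguments it presumably writes the rows of multimodal_embeddings into
# inputs_embeds at the positions where input_ids equals the image or video
# placeholder token id, leaving text positions untouched.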
975
+ def __call__(
976
+ self,
977
+ kv_caches: list[jax.Array],
978
+ input_ids: Optional[jax.Array],
979
+ attention_metadata: AttentionMetadata,
980
+ inputs_embeds: Optional[jax.Array] = None,
981
+ *args,
982
+ ) -> tuple[list[jax.Array], jax.Array, List[jax.Array]]:
983
+ # The logic of choosing between input_ids and inputs_embeds is
984
+ # handled inside self.language_model.__call__
985
+ kv_caches, x, [] = self.language_model(
986
+ kv_caches=kv_caches,
987
+ input_ids=input_ids,
988
+ attention_metadata=attention_metadata,
989
+ inputs_embeds=inputs_embeds,
990
+ )
991
+ return kv_caches, x, []
992
+
993
+ def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
994
+ return self.language_model.compute_logits(hidden_states)
995
+
996
+ def load_weights(self, rng_key: jax.Array) -> None:
997
+ self.rng = nnx.Rngs(rng_key)
998
+ self.language_model.rng = self.rng
999
+
1000
+ # Key: path to a HF layer weight
1001
+ # Value: a tuple of (path to a nnx layer weight, nnx weight sharding)
1002
+
1003
+ mappings = {
1004
+ "model.embed_tokens": "language_model.model.embed.embedding",
1005
+ "model.layers.*.input_layernorm":
1006
+ "language_model.model.layers.*.input_layernorm.scale",
1007
+ "model.layers.*.mlp.down_proj":
1008
+ "language_model.model.layers.*.mlp.down_proj.kernel",
1009
+ "model.layers.*.mlp.gate_proj":
1010
+ "language_model.model.layers.*.mlp.gate_proj.kernel",
1011
+ "model.layers.*.mlp.up_proj":
1012
+ "language_model.model.layers.*.mlp.up_proj.kernel",
1013
+ "model.layers.*.post_attention_layernorm":
1014
+ "language_model.model.layers.*.post_attention_layernorm.scale",
1015
+ "model.layers.*.self_attn.k_proj":
1016
+ "language_model.model.layers.*.self_attn.k_proj.kernel",
1017
+ "model.layers.*.self_attn.o_proj":
1018
+ "language_model.model.layers.*.self_attn.o_proj.kernel",
1019
+ "model.layers.*.self_attn.q_proj":
1020
+ "language_model.model.layers.*.self_attn.q_proj.kernel",
1021
+ "model.layers.*.self_attn.v_proj":
1022
+ "language_model.model.layers.*.self_attn.v_proj.kernel",
1023
+ "model.layers.*.self_attn.q_proj.bias":
1024
+ "language_model.model.layers.*.self_attn.q_proj.bias",
1025
+ "model.layers.*.self_attn.k_proj.bias":
1026
+ "language_model.model.layers.*.self_attn.k_proj.bias",
1027
+ "model.layers.*.self_attn.v_proj.bias":
1028
+ "language_model.model.layers.*.self_attn.v_proj.bias",
1029
+ "model.norm": "language_model.model.norm.scale",
1030
+ "visual.blocks.*.attn.proj.bias": "visual.blocks.*.attn.proj.bias",
1031
+ "visual.blocks.*.attn.proj": "visual.blocks.*.attn.proj.kernel",
1032
+ "visual.blocks.*.attn.qkv.bias":
1033
+ "visual.blocks.*.attn.qkv_proj.bias",
1034
+ "visual.blocks.*.attn.qkv": "visual.blocks.*.attn.qkv_proj.kernel",
1035
+ "visual.blocks.*.mlp.down_proj.bias":
1036
+ "visual.blocks.*.mlp.down_proj.bias",
1037
+ "visual.blocks.*.mlp.down_proj":
1038
+ "visual.blocks.*.mlp.down_proj.kernel",
1039
+ "visual.blocks.*.mlp.gate_proj.bias":
1040
+ "visual.blocks.*.mlp.gate_proj.bias",
1041
+ "visual.blocks.*.mlp.gate_proj":
1042
+ "visual.blocks.*.mlp.gate_proj.kernel",
1043
+ "visual.blocks.*.mlp.up_proj.bias":
1044
+ "visual.blocks.*.mlp.up_proj.bias",
1045
+ "visual.blocks.*.mlp.up_proj":
1046
+ "visual.blocks.*.mlp.up_proj.kernel",
1047
+ "visual.blocks.*.norm1": "visual.blocks.*.norm1.scale",
1048
+ "visual.blocks.*.norm2": "visual.blocks.*.norm2.scale",
1049
+ "visual.merger.ln_q": "visual.merger.ln_q.scale",
1050
+ "visual.merger.mlp.0.bias": "visual.merger.mlp_fc1.bias",
1051
+ "visual.merger.mlp.0": "visual.merger.mlp_fc1.kernel",
1052
+ "visual.merger.mlp.2.bias": "visual.merger.mlp_fc2.bias",
1053
+ "visual.merger.mlp.2": "visual.merger.mlp_fc2.kernel",
1054
+ "visual.patch_embed.proj": "visual.patch_embed.proj.kernel",
1055
+ }
1056
+
1057
+ # Add lm_head mapping only if it's not tied to embeddings
1058
+ hf_config = self.vllm_config.model_config.hf_config
1059
+ if not hf_config.tie_word_embeddings:
1060
+ mappings.update({
1061
+ "lm_head": "language_model.model.lm_head",
1062
+ })
1063
+
1064
+ metadata_map = get_default_maps(self.vllm_config, self.mesh, mappings)
1065
+ load_hf_weights(vllm_config=self.vllm_config,
1066
+ model=self,
1067
+ metadata_map=metadata_map,
1068
+ mesh=self.mesh)
1069
+
1070
+ def precompile_vision_encoder(
1071
+ self,
1072
+ run_compilation_fn: Callable,
1073
+ ) -> None:
1074
+ image_shapes = []
1075
+ if (warmup_config := self.vllm_config.additional_config.get(
1076
+ "vision_warmup_config")):
1077
+ image_shapes = warmup_config.get("image_shapes")
1078
+
1079
+ vc = self.vllm_config.model_config.hf_config.vision_config
1080
+ factor = vc.patch_size * vc.spatial_merge_size
1081
+ for input_hw in image_shapes:
1082
+ if not isinstance(input_hw, list) or len(input_hw) != 2:
1083
+ logger.warning(f"Skipping invalid shape {input_hw}.")
1084
+ continue
1085
+ h_input, w_input = input_hw
1086
+ h_processed = round(h_input / factor) * factor
1087
+ w_processed = round(w_input / factor) * factor
1088
+ t, h, w = 1, h_processed // vc.patch_size, w_processed // vc.patch_size
1089
+ grid_thw = (t, h, w)
1090
+ num_patches = t * h * w
1091
+ patch_input_dim = vc.in_channels * vc.temporal_patch_size * vc.patch_size * vc.patch_size
1092
+
1093
+ dummy_pixel_values = jnp.ones(
1094
+ (num_patches, patch_input_dim),
1095
+ self.vllm_config.model_config.dtype,
1096
+ )
1097
+ dummy_grid_thw = grid_thw
1098
+
1099
+ run_compilation_fn("single_image_encoder",
1100
+ self.get_single_image_embedding,
1101
+ dummy_pixel_values,
1102
+ dummy_grid_thw,
1103
+ image_shape=input_hw)
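# [Illustrative sketch, not part of the package diff] A plausible shape for the
# additional_config consumed by precompile_vision_encoder; the key layout is
# inferred only from the .get(...) calls above.
additional_config = {
    "vision_warmup_config": {
        "image_shapes": [[1080, 1920], [512, 512]],   # [height, width] pairs
    },
}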