tpu-inference 0.11.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (168)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_adapters.py +83 -0
  4. tests/core/test_core_tpu.py +523 -0
  5. tests/core/test_disagg_executor.py +60 -0
  6. tests/core/test_disagg_utils.py +53 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  10. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  11. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  12. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  13. tests/lora/__init__.py +0 -0
  14. tests/lora/test_lora.py +123 -0
  15. tests/test_base.py +201 -0
  16. tests/test_quantization.py +836 -0
  17. tests/test_tpu_info.py +120 -0
  18. tests/test_utils.py +218 -0
  19. tests/tpu_backend_test.py +59 -0
  20. tpu_inference/__init__.py +30 -0
  21. tpu_inference/adapters/__init__.py +0 -0
  22. tpu_inference/adapters/vllm_adapters.py +42 -0
  23. tpu_inference/adapters/vllm_config_adapters.py +134 -0
  24. tpu_inference/backend.py +69 -0
  25. tpu_inference/core/__init__.py +0 -0
  26. tpu_inference/core/adapters.py +153 -0
  27. tpu_inference/core/core_tpu.py +776 -0
  28. tpu_inference/core/disagg_executor.py +117 -0
  29. tpu_inference/core/disagg_utils.py +51 -0
  30. tpu_inference/di/__init__.py +0 -0
  31. tpu_inference/di/abstracts.py +28 -0
  32. tpu_inference/di/host.py +76 -0
  33. tpu_inference/di/interfaces.py +51 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/tpu_connector.py +699 -0
  36. tpu_inference/distributed/utils.py +59 -0
  37. tpu_inference/executors/__init__.py +0 -0
  38. tpu_inference/executors/ray_distributed_executor.py +346 -0
  39. tpu_inference/experimental/__init__.py +0 -0
  40. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  41. tpu_inference/interfaces/__init__.py +0 -0
  42. tpu_inference/interfaces/cache.py +31 -0
  43. tpu_inference/interfaces/config.py +47 -0
  44. tpu_inference/interfaces/config_parts.py +117 -0
  45. tpu_inference/interfaces/engine.py +51 -0
  46. tpu_inference/interfaces/outputs.py +22 -0
  47. tpu_inference/interfaces/params.py +21 -0
  48. tpu_inference/interfaces/platform.py +74 -0
  49. tpu_inference/interfaces/request.py +39 -0
  50. tpu_inference/interfaces/scheduler.py +31 -0
  51. tpu_inference/kernels/__init__.py +0 -0
  52. tpu_inference/kernels/collectives/__init__.py +0 -0
  53. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  54. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  55. tpu_inference/kernels/collectives/util.py +47 -0
  56. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  57. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  58. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  59. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  60. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  61. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  62. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  66. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1447 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3834 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/util.py +47 -0
  71. tpu_inference/layers/__init__.py +0 -0
  72. tpu_inference/layers/common/__init__.py +0 -0
  73. tpu_inference/layers/common/attention_metadata.py +34 -0
  74. tpu_inference/layers/jax/__init__.py +0 -0
  75. tpu_inference/layers/jax/attention/__init__.py +0 -0
  76. tpu_inference/layers/jax/attention/attention.py +254 -0
  77. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  78. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  79. tpu_inference/layers/jax/attention_interface.py +356 -0
  80. tpu_inference/layers/jax/base.py +151 -0
  81. tpu_inference/layers/jax/binary_search.py +295 -0
  82. tpu_inference/layers/jax/constants.py +88 -0
  83. tpu_inference/layers/jax/layers.py +301 -0
  84. tpu_inference/layers/jax/misc.py +16 -0
  85. tpu_inference/layers/jax/moe/__init__.py +0 -0
  86. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  87. tpu_inference/layers/jax/moe/moe.py +209 -0
  88. tpu_inference/layers/jax/rope.py +172 -0
  89. tpu_inference/layers/jax/rope_interface.py +214 -0
  90. tpu_inference/layers/jax/sample/__init__.py +0 -0
  91. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  92. tpu_inference/layers/jax/sample/sampling.py +95 -0
  93. tpu_inference/layers/jax/sample/sampling_metadata.py +69 -0
  94. tpu_inference/layers/jax/sharding.py +406 -0
  95. tpu_inference/layers/jax/transformer_block.py +76 -0
  96. tpu_inference/layers/vllm/__init__.py +0 -0
  97. tpu_inference/layers/vllm/attention.py +184 -0
  98. tpu_inference/layers/vllm/fused_moe.py +399 -0
  99. tpu_inference/layers/vllm/linear_common.py +186 -0
  100. tpu_inference/layers/vllm/quantization/__init__.py +34 -0
  101. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  102. tpu_inference/layers/vllm/quantization/common.py +105 -0
  103. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  104. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +121 -0
  105. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  106. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  108. tpu_inference/layers/vllm/quantization/unquantized.py +263 -0
  109. tpu_inference/layers/vllm/sharding.py +151 -0
  110. tpu_inference/logger.py +10 -0
  111. tpu_inference/lora/__init__.py +0 -0
  112. tpu_inference/lora/torch_lora_ops.py +103 -0
  113. tpu_inference/lora/torch_punica_tpu.py +308 -0
  114. tpu_inference/mock/__init__.py +0 -0
  115. tpu_inference/mock/vllm_config_utils.py +28 -0
  116. tpu_inference/mock/vllm_envs.py +1233 -0
  117. tpu_inference/mock/vllm_logger.py +212 -0
  118. tpu_inference/mock/vllm_logging_utils.py +15 -0
  119. tpu_inference/models/__init__.py +0 -0
  120. tpu_inference/models/common/__init__.py +0 -0
  121. tpu_inference/models/common/model_loader.py +433 -0
  122. tpu_inference/models/jax/__init__.py +0 -0
  123. tpu_inference/models/jax/deepseek_v3.py +868 -0
  124. tpu_inference/models/jax/llama3.py +366 -0
  125. tpu_inference/models/jax/llama4.py +473 -0
  126. tpu_inference/models/jax/llama_eagle3.py +333 -0
  127. tpu_inference/models/jax/phi3.py +376 -0
  128. tpu_inference/models/jax/qwen2.py +375 -0
  129. tpu_inference/models/jax/qwen2_5_vl.py +976 -0
  130. tpu_inference/models/jax/qwen3.py +302 -0
  131. tpu_inference/models/jax/utils/__init__.py +0 -0
  132. tpu_inference/models/jax/utils/file_utils.py +96 -0
  133. tpu_inference/models/jax/utils/multi_modal_utils.py +164 -0
  134. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  135. tpu_inference/models/jax/utils/quantization/quantization_utils.py +588 -0
  136. tpu_inference/models/jax/utils/weight_utils.py +510 -0
  137. tpu_inference/models/vllm/__init__.py +0 -0
  138. tpu_inference/models/vllm/vllm_model_wrapper.py +272 -0
  139. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  140. tpu_inference/platforms/__init__.py +2 -0
  141. tpu_inference/platforms/tpu_jax.py +257 -0
  142. tpu_inference/runner/__init__.py +0 -0
  143. tpu_inference/runner/block_table_jax.py +122 -0
  144. tpu_inference/runner/compilation_manager.py +672 -0
  145. tpu_inference/runner/input_batch_jax.py +435 -0
  146. tpu_inference/runner/kv_cache.py +119 -0
  147. tpu_inference/runner/kv_cache_manager.py +460 -0
  148. tpu_inference/runner/lora_utils.py +92 -0
  149. tpu_inference/runner/multimodal_manager.py +208 -0
  150. tpu_inference/runner/persistent_batch_manager.py +244 -0
  151. tpu_inference/runner/speculative_decoding_manager.py +250 -0
  152. tpu_inference/runner/structured_decoding_manager.py +89 -0
  153. tpu_inference/runner/tpu_jax_runner.py +771 -0
  154. tpu_inference/runner/utils.py +426 -0
  155. tpu_inference/spec_decode/__init__.py +0 -0
  156. tpu_inference/spec_decode/jax/__init__.py +0 -0
  157. tpu_inference/spec_decode/jax/eagle3.py +334 -0
  158. tpu_inference/tpu_info.py +77 -0
  159. tpu_inference/utils.py +294 -0
  160. tpu_inference/worker/__init__.py +0 -0
  161. tpu_inference/worker/_temporary_vllm_compat.py +129 -0
  162. tpu_inference/worker/base.py +100 -0
  163. tpu_inference/worker/tpu_worker_jax.py +321 -0
  164. tpu_inference-0.11.1.dist-info/METADATA +101 -0
  165. tpu_inference-0.11.1.dist-info/RECORD +168 -0
  166. tpu_inference-0.11.1.dist-info/WHEEL +5 -0
  167. tpu_inference-0.11.1.dist-info/licenses/LICENSE +201 -0
  168. tpu_inference-0.11.1.dist-info/top_level.txt +2 -0
tpu_inference/models/jax/qwen2_5_vl.py
@@ -0,0 +1,976 @@
+ import math
+ from functools import partial
+ from typing import (Callable, List, Literal, NamedTuple, Optional, TypedDict,
+                     Union)
+
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+ from flax import nnx
+ from jax.sharding import Mesh
+ from transformers import modeling_flax_utils
+ from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
+     Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig)
+ from vllm.config import VllmConfig
+ from vllm.model_executor.models.qwen2_5_vl import \
+     Qwen2_5_VLForConditionalGeneration as vllm_model_cls
+
+ from tpu_inference import utils as utils
+ from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+ from tpu_inference.layers.jax.attention_interface import \
+     sharded_flash_attention
+ from tpu_inference.logger import init_logger
+ from tpu_inference.models.jax.qwen2 import Qwen2ForCausalLM
+ # from vllm.model_executor.models.interfaces import MultiModalEmbeddings
+ from tpu_inference.models.jax.utils.multi_modal_utils import (
+     MultiModalEmbeddings, merge_multimodal_embeddings)
+ from tpu_inference.models.jax.utils.weight_utils import (get_default_maps,
+                                                          load_hf_weights)
+
+ logger = init_logger(__name__)
+
+ init_fn = nnx.initializers.uniform()
+
+ DEFAULT_BLOCK_K_MAJOR = 128
+
+
+ class SegmentIds(NamedTuple):
+     """SegmentIds for Q and KV sequences.
+
+     SegmentIds are used to generate a segment mask, which prevents attention
+     between different segments in the input sequence. Each array is a list of
+     ids (integers).
+     Only tokens with the same id can attend to each other.
+
+     Attributes:
+       q: segment ids along the Q sequence.
+       kv: segment ids along the KV sequence.
+     """
+
+     q: jax.Array  # [batch_size, q_seq_len]
+     kv: jax.Array  # [batch_size, kv_seq_len]
+
+
+ class Qwen2_5_VLImagePixelInputs(TypedDict):
+     type: Literal["pixel_values"]
+     pixel_values: jax.Array
+     """Shape:
+     `(num_patches, num_channels * patch_size * patch_size)`
+     """
+
+     image_grid_thw: tuple[tuple[int, int, int], ...]
+     """Shape: `(num_images, 3)`
+     This should be in `(grid_t, grid_h, grid_w)` format.
+     """
+
+
+ # NOTE: We are not supporting embedding inputs for now.
+ # The code here keeps the structure consistent and
+ # makes it easier for a future implementation.
+ class Qwen2_5_VLImageEmbeddingInputs(TypedDict):
+     type: Literal["image_embeds"]
+     image_embeds: jax.Array
+     """Supported types:
+     - list[`jax.Array`]: A list of tensors holding all images' features.
+         Each tensor holds an image's features.
+     - `jax.Array`: A tensor holding all images' features (concatenation of
+         all images' feature tensors).
+
+     Tensor shape: `(num_image_features, hidden_size)`
+     - `num_image_features` varies based on
+         the number and resolution of the images.
+     - `hidden_size` must match the hidden size of the language model backbone.
+     """
+
+     image_grid_thw: jax.Array
+     """Shape: `(num_images, 3)`
+     This should be in `(grid_t, grid_h, grid_w)` format.
+     """
+
+
+ Qwen2_5_VLImageInputs = Union[Qwen2_5_VLImagePixelInputs,
+                               Qwen2_5_VLImageEmbeddingInputs]
+
+
+ class Qwen2_5_VisionMLP(nnx.Module):
+
+     def __init__(self, config: Qwen2_5_VLVisionConfig, dtype: jnp.dtype,
+                  rngs: nnx.Rngs):
+         in_features = config.hidden_size
+         hidden_features = config.intermediate_size
+         act_fn = modeling_flax_utils.ACT2FN[config.hidden_act]
+         self.gate_proj = nnx.Linear(
+             in_features,
+             hidden_features,
+             use_bias=True,
+             param_dtype=dtype,
+             kernel_init=nnx.with_partitioning(init_fn, (None, "model")),
+             bias_init=nnx.with_partitioning(init_fn, ("model", )),
+             rngs=rngs,
+         )
+         self.up_proj = nnx.Linear(
+             in_features,
+             hidden_features,
+             use_bias=True,
+             param_dtype=dtype,
+             kernel_init=nnx.with_partitioning(init_fn, (None, "model")),
+             bias_init=nnx.with_partitioning(init_fn, ("model", )),
+             rngs=rngs,
+         )
+         self.down_proj = nnx.Linear(
+             hidden_features,
+             in_features,
+             use_bias=True,
+             param_dtype=dtype,
+             kernel_init=nnx.with_partitioning(init_fn, ("model", None)),
+             bias_init=nnx.with_partitioning(init_fn, (None, )),
+             rngs=rngs,
+         )
+         self.act_fn = act_fn
+
+     def __call__(self, x: jax.Array) -> jax.Array:
+         gate = self.act_fn(self.gate_proj(x))
+         up = self.up_proj(x)
+         fuse = gate * up
+         result = self.down_proj(fuse)
+         return result
+
+
+ def apply_rotary_pos_emb_vision(x: jax.Array,
+                                 rotary_pos_emb: jax.Array) -> jax.Array:
+     # x: [B, T, N, H]
+     # rotary_pos_emb: [T, H//2]
+     _, _, _, H = x.shape
+     half_dim = H // 2
+
+     # [B, T, N, H//2]
+     x_real = x[..., :half_dim]
+     x_imag = x[..., half_dim:]
+
+     # [T, H//2]
+     cos_emb = jnp.cos(rotary_pos_emb)
+     sin_emb = jnp.sin(rotary_pos_emb)
+
+     # [1, T, 1, H//2]
+     cos_emb = cos_emb[None, :, None, :]
+     sin_emb = sin_emb[None, :, None, :]
+
+     # [B, T, N, H//2]
+     x_rotated_real = x_real * cos_emb - x_imag * sin_emb
+     x_rotated_imag = x_real * sin_emb + x_imag * cos_emb
+
+     # [B, T, N, H]
+     x_rotated = jnp.concatenate([x_rotated_real, x_rotated_imag], axis=-1)
+
+     return x_rotated
+
+
+ def generate_window_segment_ids(cu_seqlens: jax.Array, seq_len: int,
+                                 padded_seq_len: int) -> SegmentIds:
+     """Generates segment IDs for windowed attention.
+
+     Args:
+         cu_seqlens: A 1D array of cumulative sequence lengths for each window,
+             e.g., [0, len_win0, len_win0+len_win1, ...]
+
+     Returns:
+         A SegmentIds object for flash_attention.
+     """
+     indices = jnp.arange(seq_len, dtype=jnp.int32)
+     segment_ids = jnp.searchsorted(cu_seqlens[1:], indices, side='right') + 1
+     padding_segment_ids = jnp.zeros(padded_seq_len - seq_len, dtype=jnp.int32)
+     segment_ids = jnp.concatenate([segment_ids, padding_segment_ids])
+     segment_ids = segment_ids.reshape(1, -1)
+
+     return SegmentIds(q=segment_ids, kv=segment_ids)
+
+
+ class Qwen2_5_VisionAttention(nnx.Module):
+
+     def __init__(self, config: Qwen2_5_VLConfig, dtype: jnp.dtype,
+                  rngs: nnx.Rngs, mesh: Mesh):
+         vision_config = config.vision_config
+         self.hidden_size = vision_config.hidden_size
+         self.num_heads = vision_config.num_heads
+         self.num_kv_heads = self.num_heads
+         self.rope_theta = config.rope_theta
+         self.rope_scaling = getattr(config, "rope_scaling", None)
+         self.head_dim_original = self.hidden_size // self.num_heads
+
+         sharding_size = mesh.shape["model"]
+         self.num_heads = utils.get_padded_num_heads(self.num_heads,
+                                                     sharding_size)
+         self.num_kv_heads = utils.get_padded_num_heads(self.num_kv_heads,
+                                                        sharding_size)
+         self.head_dim = utils.get_padded_head_dim(self.head_dim_original)
+
+         # TODO: Wenlong: Do not consider padding for now
+         self.head_dim = self.head_dim_original
+
+         self.mesh = mesh
+
+         self.qkv_proj = nnx.Linear(
+             self.hidden_size,
+             3 * self.hidden_size,
+             use_bias=True,
+             param_dtype=dtype,
+             kernel_init=nnx.with_partitioning(init_fn, (None, "model")),
+             bias_init=nnx.with_partitioning(init_fn, ("model", )),
+             rngs=rngs,
+         )
+
+         self.proj = nnx.Linear(
+             self.hidden_size,
+             self.hidden_size,
+             use_bias=True,
+             param_dtype=dtype,
+             kernel_init=nnx.with_partitioning(init_fn, ("model", None)),
+             bias_init=nnx.with_partitioning(init_fn, (None, )),
+             rngs=rngs,
+         )
+         self.flash_attention = sharded_flash_attention(
+             mesh=mesh,
+             causal=False,
+             sm_scale=1.0 / math.sqrt(self.head_dim),
+             vmem_limit_bytes=128 * 1024 * 1024,
+         )
+
+     def __call__(
+         self,
+         x: jax.Array,
+         rotary_pos_emb: jax.Array,
+         cu_window_seqlens: Optional[jax.Array] = None,
+         use_fullattn: bool = True,
+     ) -> jax.Array:
+         T, B, D = x.shape
+         assert B == 1, "Vision attention currently only supports batch size 1"
+         # [T, B, D] -> [T, B, 3 * D]
+         qkv = self.qkv_proj(x)
+
+         # Split into Q, K, V.
+         # NOTE: simplified from vLLM's split_qkv,
+         # may need to revisit for tp>1
+         # [T, B, 3 * D] -> 3 * [T, B, D]
+         q, k, v = jnp.split(qkv, 3, axis=-1)
+
+         # [T, B, N, H]
+         q = q.reshape(T, B, self.num_heads, self.head_dim)
+         k = k.reshape(T, B, self.num_heads, self.head_dim)
+         v = v.reshape(T, B, self.num_heads, self.head_dim)
+
+         # [T, B, N, H] -> [B, T, N, H]
+         q = jnp.transpose(q, (1, 0, 2, 3))
+         k = jnp.transpose(k, (1, 0, 2, 3))
+         v = jnp.transpose(v, (1, 0, 2, 3))
+
+         # rotary_pos_emb shape: (T, H)
+         q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
+         k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)
+
+         # NOTE: an extra transpose because we need to
+         # match vLLM's design for correctness.
+         # Might be able to remove one once implemented.
+         # [B, T, N, H] -> [B, N, T, H]
+         q = jnp.transpose(q, (0, 2, 1, 3))
+         k = jnp.transpose(k, (0, 2, 1, 3))
+         v = jnp.transpose(v, (0, 2, 1, 3))
+
+         # Pad the sequence length to be a multiple of 128 for flash_attention
+         block_k_major = DEFAULT_BLOCK_K_MAJOR
+         T_attn = q.shape[2]
+         padded_T = (T_attn + block_k_major -
+                     1) // block_k_major * block_k_major
+         pad_width = ((0, 0), (0, 0), (0, padded_T - T_attn), (0, 0))
+
+         q = jnp.pad(q, pad_width, 'constant')
+         k = jnp.pad(k, pad_width, 'constant')
+         v = jnp.pad(v, pad_width, 'constant')
+
+         segment_ids = generate_window_segment_ids(cu_window_seqlens, T_attn,
+                                                   padded_T)
+
+         # TODO (jacobplatin): add support for quantized KV cache?
+         output = self.flash_attention(q, k, v, segment_ids)
+
+         # Unpad the output
+         output = output[:, :, :T_attn, :]
+
+         # [B, N, T, H] -> [T, B, N, H]
+         output = jnp.transpose(output, (2, 0, 1, 3))
+
+         output = output.reshape(T, B, D)
+
+         output = self.proj(output)
+
+         return output
+
+
+ class Qwen2_5_VisionBlock(nnx.Module):
+
+     def __init__(self, config: Qwen2_5_VLConfig, dtype: jnp.dtype,
+                  rngs: nnx.Rngs, mesh: Mesh):
+         vision_config = config.vision_config
+         dim = vision_config.hidden_size
+         norm_layer = partial(nnx.RMSNorm,
+                              epsilon=config.rms_norm_eps,
+                              scale_init=nnx.with_partitioning(
+                                  init_fn, (None, )))
+
+         self.norm1 = norm_layer(dim, dtype=dtype, rngs=rngs)
+         self.norm2 = norm_layer(dim, dtype=dtype, rngs=rngs)
+         self.attn = Qwen2_5_VisionAttention(config=config,
+                                             dtype=dtype,
+                                             rngs=rngs,
+                                             mesh=mesh)
+         self.mlp = Qwen2_5_VisionMLP(config=vision_config,
+                                      dtype=dtype,
+                                      rngs=rngs)
+
+     def __call__(self,
+                  x: jax.Array,
+                  rotary_pos_emb: jax.Array,
+                  cu_window_seqlens: Optional[jax.Array] = None,
+                  use_fullattn: bool = True) -> jax.Array:
+
+         x = x + self.attn(self.norm1(x), rotary_pos_emb, cu_window_seqlens,
+                           use_fullattn)
+         x = x + self.mlp(self.norm2(x))
+
+         return x
+
+
+ class Qwen2_5_VisionPatchEmbed(nnx.Module):
+
+     def __init__(
+         self,
+         rngs: nnx.Rngs,
+         patch_size: int = 14,
+         temporal_patch_size: int = 2,
+         in_channels: int = 3,
+         hidden_size: int = 1152,
+         dtype: jnp.dtype = jnp.bfloat16,
+     ) -> None:
+         self.patch_size = patch_size
+         self.temporal_patch_size = temporal_patch_size
+         self.hidden_size = hidden_size
+         kernel_size = (temporal_patch_size, patch_size, patch_size)
+         self.proj = nnx.Conv(in_features=in_channels,
+                              out_features=hidden_size,
+                              kernel_size=kernel_size,
+                              strides=kernel_size,
+                              use_bias=False,
+                              param_dtype=dtype,
+                              kernel_init=nnx.with_partitioning(
+                                  init_fn, (None, None, None, None, "model")),
+                              rngs=rngs)
+
+     def __call__(self, x: jax.Array) -> jax.Array:
+         # x is (L, C * T * H * W)
+         L, dim = x.shape
+         C = dim // (self.temporal_patch_size * self.patch_size *
+                     self.patch_size)
+         # Reshape to (L, C, T, H, W) before moving channels last for the conv
+         x = x.reshape(L, C, self.temporal_patch_size, self.patch_size,
+                       self.patch_size)
+         # (L, T, H, W, C)
+         x = jnp.transpose(x, (0, 2, 3, 4, 1))
+         x = self.proj(x)
+         # After conv, shape is (L, T_out, H_out, W_out, C_out)
+         # With stride=kernel_size, T_out=H_out=W_out=1.
+         # So shape is (L, 1, 1, 1, hidden_size)
+         x = x.reshape(L, self.hidden_size)
+         return x
+
+
+ class Qwen2_5_VisionPatchMerger(nnx.Module):
+
+     def __init__(self, d_model: int, context_dim: int, norm_layer: Callable,
+                  spatial_merge_size: int, dtype: jnp.dtype, rngs: nnx.Rngs):
+         self.hidden_size = context_dim * (spatial_merge_size**2)
+         self.ln_q = norm_layer(context_dim,
+                                dtype=dtype,
+                                rngs=rngs,
+                                scale_init=nnx.with_partitioning(
+                                    init_fn, (None, )))
+         self.mlp_fc1 = nnx.Linear(
+             self.hidden_size,
+             self.hidden_size,
+             use_bias=True,
+             param_dtype=dtype,
+             kernel_init=nnx.with_partitioning(init_fn, (None, "model")),
+             bias_init=nnx.with_partitioning(init_fn, ("model", )),
+             rngs=rngs)
+         self.mlp_act = modeling_flax_utils.ACT2FN["gelu"]
+         self.mlp_fc2 = nnx.Linear(
+             self.hidden_size,
+             d_model,
+             use_bias=True,
+             param_dtype=dtype,
+             kernel_init=nnx.with_partitioning(init_fn, ("model", None)),
+             bias_init=nnx.with_partitioning(init_fn, (None, )),
+             rngs=rngs)
+
+     def __call__(self, x: jax.Array) -> jax.Array:
+         x = self.ln_q(x)
+         x = x.reshape(-1, self.hidden_size)
+         x = self.mlp_fc1(x)
+         x = self.mlp_act(x)
+         x = self.mlp_fc2(x)
+         return x
+
+
+ class Qwen2_5_VisionRotaryEmbedding(nnx.Module):
+
+     def __init__(self, dim: int, theta: float = 10000.0):
+         self.dim = dim
+         self.theta = theta
+
+     def __call__(self, seqlen: int) -> jax.Array:
+         inv_freq = 1.0 / (self.theta**(
+             jnp.arange(0, self.dim, 2, dtype=jnp.float32) / self.dim))
+         seq = jnp.arange(seqlen, dtype=jnp.float32)
+         freqs = jnp.outer(seq, inv_freq)
+         return freqs.astype(jnp.bfloat16)
+
+
+ class Qwen2_5_VisionTransformer(nnx.Module):
+
+     def __init__(self,
+                  vllm_config: VllmConfig,
+                  rngs: nnx.Rngs,
+                  mesh: Mesh,
+                  norm_eps: float = 1e-6):
+         model_config = vllm_config.model_config
+         hf_config = model_config.hf_config
+         vision_config = hf_config.vision_config
+         dtype = model_config.dtype
+
+         self.config = vision_config
+         self.dtype = dtype
+
+         patch_size = vision_config.patch_size
+         temporal_patch_size = vision_config.temporal_patch_size
+         in_channels = vision_config.in_channels
+         self.hidden_size = vision_config.hidden_size
+         self.num_heads = vision_config.num_heads
+
+         # args for get_window_index_thw
+         self.window_size = vision_config.window_size
+         self.patch_size = vision_config.patch_size
+         self.spatial_merge_size = vision_config.spatial_merge_size
+         self.fullatt_block_indexes = vision_config.fullatt_block_indexes
+         self.spatial_merge_unit = self.spatial_merge_size**2
+
+         self.patch_embed = Qwen2_5_VisionPatchEmbed(
+             patch_size=patch_size,
+             temporal_patch_size=temporal_patch_size,
+             in_channels=in_channels,
+             hidden_size=self.hidden_size,
+             dtype=dtype,
+             rngs=rngs)
+
+         head_dim = vision_config.hidden_size // vision_config.num_heads
+         self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
+
+         self.blocks = [
+             Qwen2_5_VisionBlock(
+                 config=hf_config,
+                 dtype=dtype,
+                 rngs=rngs,
+                 mesh=mesh,
+             ) for _ in range(vision_config.depth)
+         ]
+         self.merger = Qwen2_5_VisionPatchMerger(
+             d_model=vision_config.out_hidden_size,
+             context_dim=vision_config.hidden_size,
+             norm_layer=partial(nnx.RMSNorm, epsilon=norm_eps),
+             spatial_merge_size=vision_config.spatial_merge_size,
+             dtype=dtype,
+             rngs=rngs)
+
+     def rotary_pos_emb_thw(self, t, h, w):
+         hpos_ids, wpos_ids = jnp.indices((h, w))
+         hpos_ids = hpos_ids.reshape(
+             h // self.spatial_merge_size,
+             self.spatial_merge_size,
+             w // self.spatial_merge_size,
+             self.spatial_merge_size,
+         ).transpose(0, 2, 1, 3).flatten()
+         wpos_ids = wpos_ids.reshape(
+             h // self.spatial_merge_size,
+             self.spatial_merge_size,
+             w // self.spatial_merge_size,
+             self.spatial_merge_size,
+         ).transpose(0, 2, 1, 3).flatten()
+         pos_ids = jnp.stack([hpos_ids, wpos_ids], axis=-1)
+         pos_ids = jnp.tile(pos_ids, (t, 1))
+
+         max_size = max(h, w)
+         rotary_pos_emb_full = self.rotary_pos_emb(max_size)
+
+         rotary_pos_emb = rotary_pos_emb_full[pos_ids].reshape(
+             pos_ids.shape[0], -1)
+         rotary_pos_emb = rotary_pos_emb.reshape(
+             rotary_pos_emb.shape[0] // self.spatial_merge_unit,
+             self.spatial_merge_unit, -1)
+
+         return rotary_pos_emb
+
+     def get_window_index_thw(self, grid_t, grid_h, grid_w):
+         vit_merger_window_size = (self.window_size //
+                                   self.spatial_merge_size // self.patch_size)
+
+         llm_grid_h = grid_h // self.spatial_merge_size
+         llm_grid_w = grid_w // self.spatial_merge_size
+
+         index = jnp.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
+             grid_t, llm_grid_h, llm_grid_w)
+
+         pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
+         pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
+         num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
+         num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
+
+         index_padded = jnp.pad(index, ((0, 0), (0, pad_h), (0, pad_w)),
+                                constant_values=-100)
+         index_padded = index_padded.reshape(grid_t, num_windows_h,
+                                             vit_merger_window_size,
+                                             num_windows_w,
+                                             vit_merger_window_size)
+         index_padded = jnp.transpose(index_padded, (0, 1, 3, 2, 4)).reshape(
+             grid_t, num_windows_h * num_windows_w, vit_merger_window_size,
+             vit_merger_window_size)
+         seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
+         index_padded = index_padded.reshape(-1)
+         # The number of valid indices is static because grid_t, grid_h, grid_w
+         # are static.
+         num_valid_indices = grid_t * llm_grid_h * llm_grid_w
+         valid_indices = jnp.nonzero(index_padded != -100,
+                                     size=num_valid_indices)[0]
+         index_new = index_padded[valid_indices]
+         cu_seqlens_tmp = jnp.cumsum(seqlens) * self.spatial_merge_unit
+         cu_seqlens_tmp = cu_seqlens_tmp.astype(jnp.int32)
+
+         # NOTE (wenlong): the PyTorch code uses this to reduce replication,
+         # but I don't think it is needed here, plus it would cause problems
+         # under JIT. Refer back here if there is a problem downstream.
+         # cu_seqlens_tmp = jnp.unique(cu_seqlens_tmp)
+
+         return index_new, cu_seqlens_tmp
+
+     def get_rope_by_thw(self, t, h, w):
+         window_index_thw, cu_seqlens_window_thw = self.get_window_index_thw(
+             t, h, w)
+
+         rotary_pos_emb_thw = self.rotary_pos_emb_thw(t, h, w)
+
+         rotary_pos_emb_thw = rotary_pos_emb_thw[window_index_thw, :, :]
+         rotary_pos_emb_thw = rotary_pos_emb_thw.reshape(
+             -1, rotary_pos_emb_thw.shape[-1])
+         cu_seqlens_thw = jnp.full(t, h * w, dtype=jnp.int32)
+
+         return (rotary_pos_emb_thw, window_index_thw, cu_seqlens_window_thw,
+                 cu_seqlens_thw)
+
+     def compute_attn_mask_seqlen(
+         self,
+         cu_seqlens: jax.Array,
+     ) -> tuple[Optional[int], Optional[list[int]]]:
+         max_seqlen, seqlens = None, None
+         max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+         seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
+         return max_seqlen, seqlens
+
+     def __call__(self, x: jax.Array, grid_thw: tuple[tuple[int, int,
+                                                            int]]) -> jax.Array:
+         # x: pixel_values: jax.Array
+         # """Shape:
+         # `(num_patches, num_channels * patch_size * patch_size)`
+         # """
+
+         # grid_thw: image_grid_thw: jax.Array
+         # """Shape: `(num_images, 3)`
+         # This should be in `(grid_t, grid_h, grid_w)` format.
+         # """
+         hidden_states = self.patch_embed(x)
+
+         # num of patches
+         seq_len = x.shape[0]
+         # num of images/videos
+         num_grids = len(grid_thw)
+
+         rotary_pos_emb = []
+         window_index: list = []
+         cu_window_seqlens: list = [jnp.array([0], dtype=jnp.int32)]
+         cu_seqlens: list = []
+
+         window_index_id = 0
+         cu_window_seqlens_last = 0
+         for i in range(num_grids):
+             t, h, w = grid_thw[i]
+
+             llm_h = h // self.spatial_merge_size
+             llm_w = w // self.spatial_merge_size
+
+             (
+                 rotary_pos_emb_thw,
+                 window_index_thw,
+                 cu_seqlens_window_thw,
+                 cu_seqlens_thw,
+             ) = self.get_rope_by_thw(t, h, w)
+
+             window_index.append(window_index_thw + window_index_id)
+             window_index_id += (t * llm_h * llm_w)
+
+             cu_seqlens_window_thw = (cu_seqlens_window_thw +
+                                      cu_window_seqlens_last)
+             cu_window_seqlens_last = cu_seqlens_window_thw[-1]
+             cu_window_seqlens.append(cu_seqlens_window_thw)
+
+             rotary_pos_emb.append(rotary_pos_emb_thw)
+
+             cu_seqlens.append(cu_seqlens_thw)
+
+         rotary_pos_emb = jnp.concatenate(rotary_pos_emb, axis=0)
+         window_index = jnp.concatenate(window_index, axis=0)
+         cu_window_seqlens = jnp.concatenate(cu_window_seqlens, axis=0)
+
+         cu_seqlens = jnp.concatenate(cu_seqlens, axis=0)
+         cu_seqlens = jnp.cumsum(cu_seqlens, axis=0, dtype=jnp.int32)
+         cu_seqlens = jnp.pad(cu_seqlens, ((1, 0), ),
+                              mode='constant',
+                              constant_values=0)
+
+         hidden_states = hidden_states.reshape(
+             seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
+         hidden_states = hidden_states[window_index, :, :]
+         hidden_states = hidden_states.reshape(seq_len, -1)
+
+         hidden_states = jnp.expand_dims(hidden_states, axis=1)
+
+         for layer_num, blk in enumerate(self.blocks):
+             if layer_num in self.fullatt_block_indexes:
+                 hidden_states = blk(hidden_states,
+                                     rotary_pos_emb=rotary_pos_emb,
+                                     cu_window_seqlens=cu_seqlens,
+                                     use_fullattn=True)
+             else:
+                 hidden_states = blk(hidden_states,
+                                     rotary_pos_emb=rotary_pos_emb,
+                                     cu_window_seqlens=cu_window_seqlens,
+                                     use_fullattn=False)
+
+         # adapter
+         hidden_states = self.merger(hidden_states)
+         reverse_indices = jnp.argsort(window_index)
+         hidden_states = hidden_states[reverse_indices, :]
+         return hidden_states
+
+
+ class Qwen2_5_VLForConditionalGeneration(nnx.Module):
+
+     def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array,
+                  mesh: Mesh) -> None:
+         config: Qwen2_5_VLConfig = vllm_config.model_config.hf_config
+         multimodal_config = vllm_config.model_config.multimodal_config
+
+         self.vllm_config = vllm_config
+         self.rng = nnx.Rngs(rng_key)
+         self.mesh = mesh
+
+         self.config = config
+         self.multimodal_config = multimodal_config
+
+         self.visual = Qwen2_5_VisionTransformer(
+             vllm_config=vllm_config,
+             rngs=self.rng,
+             mesh=mesh,
+             norm_eps=getattr(config, "rms_norm_eps", 1e-6),
+         )
+         self.language_model = Qwen2ForCausalLM(vllm_config, rng_key, mesh)
+
+     @classmethod
+     def get_mrope_input_positions(
+         cls,
+         input_tokens: list[int],
+         hf_config,
+         image_grid_thw,
+         video_grid_thw,
+         second_per_grid_ts: list[float],
+         context_len: int = 0,
+         seq_len: int | None = None,
+         audio_feature_lengths=None,
+         use_audio_in_video: bool = False,
+     ):
+         return vllm_model_cls.get_mrope_input_positions(
+             input_tokens=input_tokens,
+             hf_config=hf_config,
+             image_grid_thw=image_grid_thw,
+             video_grid_thw=video_grid_thw,
+             second_per_grid_ts=second_per_grid_ts,
+             context_len=context_len,
+             seq_len=seq_len,
+             audio_feature_lengths=audio_feature_lengths,
+             use_audio_in_video=use_audio_in_video,
+         )
+
+     def _validate_and_reshape_mm_tensor(self, mm_input: object,
+                                         name: str) -> jax.Array:
+         if isinstance(mm_input, list):
+             # Assuming it's a list of arrays (e.g., np.ndarray, torch.Tensor)
+             # that can be concatenated.
+             arrays_to_concat = [jnp.asarray(item) for item in mm_input]
+             return jnp.concatenate(arrays_to_concat, axis=0)
+
+         # Handle single array-like objects (np.ndarray, torch.Tensor, jax.Array)
+         if hasattr(mm_input, 'ndim'):
+             array_input = jnp.asarray(mm_input)
+             if array_input.ndim == 2:
+                 return array_input
+             if array_input.ndim == 3:
+                 # This reshapes the batched 3D tensor to a 2D tensor.
+                 return array_input.reshape(-1, array_input.shape[-1])
+
+         raise ValueError(f"Incorrect type of {name}. "
+                          f"Got type: {type(mm_input)}")
+
+     def _parse_and_validate_image_input(
+             self, image_grid_thw: tuple[tuple[int, int, int], ...],
+             **kwargs: object) -> Optional[Qwen2_5_VLImageInputs]:
+         pixel_values = kwargs.pop("pixel_values", None)
+         image_embeds = kwargs.pop("image_embeds", None)
+         # image_grid_thw = kwargs.pop("image_grid_thw", None)
+
+         if pixel_values is None and image_embeds is None:
+             return None
+
+         if pixel_values is not None:
+             pixel_values = self._validate_and_reshape_mm_tensor(
+                 pixel_values, "image pixel values")
+             # image_grid_thw = self._validate_and_reshape_mm_tensor(
+             #     image_grid_thw, "image grid_thw")
+
+             if not isinstance(pixel_values, jax.Array):
+                 raise ValueError("Incorrect type of image pixel values. "
+                                  f"Got type: {type(pixel_values)}")
+
+             return Qwen2_5_VLImagePixelInputs(type="pixel_values",
+                                               pixel_values=pixel_values,
+                                               image_grid_thw=image_grid_thw)
+
+         # Note: comment them out for now and save for future support
+         # if image_embeds is not None:
+         #     image_embeds = self._validate_and_reshape_mm_tensor(
+         #         image_embeds, "image embeds")
+         #     image_grid_thw = self._validate_and_reshape_mm_tensor(
+         #         image_grid_thw, "image grid_thw")
+
+         #     if not isinstance(image_embeds, jax.Array):
+         #         raise ValueError("Incorrect type of image embeddings. "
+         #                          f"Got type: {type(image_embeds)}")
+         #     return Qwen2_5_VLImageEmbeddingInputs(
+         #         type="image_embeds",
+         #         image_embeds=image_embeds,
+         #         image_grid_thw=image_grid_thw)
+
+     def _parse_and_validate_multimodal_inputs(self,
+                                               image_grid_thw: tuple[tuple[int,
+                                                                           int,
+                                                                           int],
+                                                                     ...],
+                                               **kwargs: object) -> dict:
+         mm_input_by_modality = {}
+
+         # Preserve the order of modalities if there are multiple of them
+         # from the order of kwargs.
+         for input_key in kwargs:
+             if input_key in ("pixel_values", "image_embeds"
+                              ) and "image" not in mm_input_by_modality:
+                 mm_input_by_modality[
+                     "image"] = self._parse_and_validate_image_input(
+                         image_grid_thw, **kwargs)
+             # if input_key in ("pixel_values_videos", "video_embeds"
+             #                  ) and "video" not in mm_input_by_modality:
+             #     mm_input_by_modality[
+             #         "video"] = self._parse_and_validate_video_input(**kwargs)
+         return mm_input_by_modality
+
+     @partial(
+         jax.jit,
+         static_argnames=("image_grid_thw", ),
+     )
+     def get_single_image_embedding(self, image_pixel_values, image_grid_thw):
+         return self.visual(image_pixel_values, (image_grid_thw, ))
+
+     def _process_image_input(
+             self, image_input: Qwen2_5_VLImageInputs) -> tuple[jax.Array, ...]:
+
+         grid_thw = image_input["image_grid_thw"]
+
+         if image_input["type"] == "image_embeds":
+             image_embeds = image_input["image_embeds"].astype(
+                 self.visual.dtype)
+         else:
+             pixel_values = image_input["pixel_values"]
+             image_embeds = []
+             current_idx = 0
+             for image_thw in grid_thw:
+                 t, h, w = image_thw
+                 image_size = t * h * w
+                 end_idx = current_idx + image_size
+                 image_pixel_values = pixel_values[current_idx:end_idx, :]
+                 image_embeds.append(
+                     self.get_single_image_embedding(image_pixel_values,
+                                                     image_thw))
+                 current_idx = end_idx
+             image_embeds = jnp.concatenate(image_embeds, axis=0)
+
+         # Split concatenated embeddings for each image item.
+         merge_size = self.visual.config.spatial_merge_size
+         sizes = np.prod(np.array(grid_thw, dtype=np.int64),
+                         axis=-1) // merge_size // merge_size
+
+         if sizes.size == 0:
+             return ()
+         if sizes.size == 1:
+             return (image_embeds, )
+
+         split_indices = np.cumsum(sizes)[:-1]
+         return tuple(jnp.split(image_embeds, split_indices))
+
+     def get_multimodal_embeddings(self, image_grid_thw: tuple[tuple[int, int,
+                                                                     int], ...],
+                                   **kwargs: object) -> MultiModalEmbeddings:
+
+         mm_input_by_modality = self._parse_and_validate_multimodal_inputs(
+             image_grid_thw, **kwargs)
+         if not mm_input_by_modality:
+             return []
+
+         # The result multimodal_embeddings is a tuple of tensors, with each
+         # tensor corresponding to a multimodal data item (image or video).
+         multimodal_embeddings: tuple[jax.Array, ...] = ()
+
+         # NOTE: It is important to iterate over the keys in this dictionary
+         # to preserve the order of the modalities.
+         for modality in mm_input_by_modality:
+             multimodal_input = mm_input_by_modality[modality]
+             if modality == "image":
+                 vision_embeddings = self._process_image_input(multimodal_input)
+                 multimodal_embeddings += vision_embeddings
+             # if modality == "video":
+             #     video_embeddings = self._process_video_input(multimodal_input)
+             #     multimodal_embeddings += video_embeddings
+
+         return multimodal_embeddings
+
+     def get_input_embeddings(
+             self, input_ids: jax.Array,
+             multimodal_embeddings: Optional[MultiModalEmbeddings]
+     ) -> jax.Array:
+
+         inputs_embeds = self.language_model.model.embed(input_ids)
+
+         if multimodal_embeddings is not None \
+                 and len(multimodal_embeddings) != 0:
+             inputs_embeds = merge_multimodal_embeddings(
+                 input_ids, inputs_embeds, multimodal_embeddings,
+                 [self.config.image_token_id, self.config.video_token_id])
+
+         return inputs_embeds
+
+     def __call__(
+         self,
+         kv_caches: list[jax.Array],
+         input_ids: Optional[jax.Array],
+         attention_metadata: AttentionMetadata,
+         inputs_embeds: Optional[jax.Array] = None,
+         *args,
+     ) -> tuple[list[jax.Array], jax.Array, List[jax.Array]]:
+         # The logic of choosing between input_ids and inputs_embeds is
+         # handled inside self.language_model.__call__
+         kv_caches, x, [] = self.language_model(
+             kv_caches=kv_caches,
+             input_ids=input_ids,
+             attention_metadata=attention_metadata,
+             inputs_embeds=inputs_embeds,
+         )
+         return kv_caches, x, []
+
+     def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
+         return self.language_model.compute_logits(hidden_states)
+
+     def load_weights(self, rng_key: jax.Array) -> None:
+         self.rng = nnx.Rngs(rng_key)
+         self.language_model.rng = self.rng
+
+         # Key: path to a HF layer weight
+         # Value: a tuple of (path to a nnx layer weight, nnx weight sharding)
+
+         mappings = {
+             "model.embed_tokens": "language_model.model.embed.embedding",
+             "model.layers.*.input_layernorm":
+             "language_model.model.layers.*.input_layernorm.scale",
+             "model.layers.*.mlp.down_proj":
+             "language_model.model.layers.*.mlp.down_proj.kernel",
+             "model.layers.*.mlp.gate_proj":
+             "language_model.model.layers.*.mlp.gate_proj.kernel",
+             "model.layers.*.mlp.up_proj":
+             "language_model.model.layers.*.mlp.up_proj.kernel",
+             "model.layers.*.post_attention_layernorm":
+             "language_model.model.layers.*.post_attention_layernorm.scale",
+             "model.layers.*.self_attn.k_proj":
+             "language_model.model.layers.*.self_attn.k_proj.kernel",
+             "model.layers.*.self_attn.o_proj":
+             "language_model.model.layers.*.self_attn.o_proj.kernel",
+             "model.layers.*.self_attn.q_proj":
+             "language_model.model.layers.*.self_attn.q_proj.kernel",
+             "model.layers.*.self_attn.v_proj":
+             "language_model.model.layers.*.self_attn.v_proj.kernel",
+             "model.layers.*.self_attn.q_proj.bias":
+             "language_model.model.layers.*.self_attn.q_proj.bias",
+             "model.layers.*.self_attn.k_proj.bias":
+             "language_model.model.layers.*.self_attn.k_proj.bias",
+             "model.layers.*.self_attn.v_proj.bias":
+             "language_model.model.layers.*.self_attn.v_proj.bias",
+             "model.norm": "language_model.model.norm.scale",
+             "visual.blocks.*.attn.proj.bias": "visual.blocks.*.attn.proj.bias",
+             "visual.blocks.*.attn.proj": "visual.blocks.*.attn.proj.kernel",
+             "visual.blocks.*.attn.qkv.bias":
+             "visual.blocks.*.attn.qkv_proj.bias",
+             "visual.blocks.*.attn.qkv": "visual.blocks.*.attn.qkv_proj.kernel",
+             "visual.blocks.*.mlp.down_proj.bias":
+             "visual.blocks.*.mlp.down_proj.bias",
+             "visual.blocks.*.mlp.down_proj":
+             "visual.blocks.*.mlp.down_proj.kernel",
+             "visual.blocks.*.mlp.gate_proj.bias":
+             "visual.blocks.*.mlp.gate_proj.bias",
+             "visual.blocks.*.mlp.gate_proj":
+             "visual.blocks.*.mlp.gate_proj.kernel",
+             "visual.blocks.*.mlp.up_proj.bias":
+             "visual.blocks.*.mlp.up_proj.bias",
+             "visual.blocks.*.mlp.up_proj":
+             "visual.blocks.*.mlp.up_proj.kernel",
+             "visual.blocks.*.norm1": "visual.blocks.*.norm1.scale",
+             "visual.blocks.*.norm2": "visual.blocks.*.norm2.scale",
+             "visual.merger.ln_q": "visual.merger.ln_q.scale",
+             "visual.merger.mlp.0.bias": "visual.merger.mlp_fc1.bias",
+             "visual.merger.mlp.0": "visual.merger.mlp_fc1.kernel",
+             "visual.merger.mlp.2.bias": "visual.merger.mlp_fc2.bias",
+             "visual.merger.mlp.2": "visual.merger.mlp_fc2.kernel",
+             "visual.patch_embed.proj": "visual.patch_embed.proj.kernel",
+         }
+
+         # Add lm_head mapping only if it's not tied to embeddings
+         hf_config = self.vllm_config.model_config.hf_config
+         if not hf_config.tie_word_embeddings:
+             mappings.update({
+                 "lm_head": "language_model.model.lm_head",
+             })
+
+         metadata_map = get_default_maps(self.vllm_config, self.mesh, mappings)
+         load_hf_weights(vllm_config=self.vllm_config,
+                         model=self,
+                         metadata_map=metadata_map,
+                         mesh=self.mesh)
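
Editor's note: the windowed-attention path above hinges on generate_window_segment_ids, where tokens that share a window id may attend to each other and padded positions get id 0. The following standalone sketch (an illustration only, not part of the wheel) mirrors that helper on a toy input to show the resulting ids:

import jax.numpy as jnp

def window_segment_ids(cu_seqlens, seq_len, padded_seq_len):
    # Mirrors generate_window_segment_ids above: window i gets id i + 1,
    # and padded positions get id 0 so they never share an id with real tokens.
    indices = jnp.arange(seq_len, dtype=jnp.int32)
    ids = jnp.searchsorted(cu_seqlens[1:], indices, side='right') + 1
    pad = jnp.zeros(padded_seq_len - seq_len, dtype=jnp.int32)
    return jnp.concatenate([ids, pad]).reshape(1, -1)

# Two windows of lengths 3 and 2, padded to 8 tokens (the real kernel pads to a
# multiple of DEFAULT_BLOCK_K_MAJOR = 128).
cu = jnp.array([0, 3, 5], dtype=jnp.int32)
print(window_segment_ids(cu, seq_len=5, padded_seq_len=8))
# [[1 1 1 2 2 0 0 0]]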
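
Similarly, apply_rotary_pos_emb_vision rotates the two halves of the head dimension by per-position angles, the real/imaginary form of a complex rotation. The sketch below (again illustrative only, with made-up toy shapes and frequencies) shows the same arithmetic on a single [T, H] slice:

import jax.numpy as jnp

def rotate_halves(x, freqs):
    # x: [T, H]; freqs: [T, H // 2]. Same rotation as the model code above,
    # written for a single batch/head slice.
    half = x.shape[-1] // 2
    xr, xi = x[..., :half], x[..., half:]
    cos, sin = jnp.cos(freqs), jnp.sin(freqs)
    return jnp.concatenate([xr * cos - xi * sin,
                            xr * sin + xi * cos], axis=-1)

T, H = 4, 8
freqs = jnp.outer(jnp.arange(T, dtype=jnp.float32),
                  jnp.array([1.0, 0.5, 0.25, 0.125]))  # [T, H // 2]
out = rotate_halves(jnp.ones((T, H)), freqs)           # [T, H], same shape as x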