tpu-inference 0.0.1rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (174)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +53 -0
  6. tests/core/test_dp_scheduler.py +899 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/fused_moe_v1_test.py +374 -0
  10. tests/kernels/mla_v1_test.py +396 -0
  11. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  12. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  13. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  14. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
  15. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  16. tests/lora/__init__.py +0 -0
  17. tests/lora/conftest.py +32 -0
  18. tests/lora/test_bgmv.py +43 -0
  19. tests/lora/test_layers.py +648 -0
  20. tests/lora/test_lora.py +133 -0
  21. tests/lora/utils.py +88 -0
  22. tests/test_base.py +201 -0
  23. tests/test_envs.py +203 -0
  24. tests/test_quantization.py +836 -0
  25. tests/test_tpu_info.py +120 -0
  26. tests/test_utils.py +235 -0
  27. tpu_inference/__init__.py +53 -0
  28. tpu_inference/core/__init__.py +0 -0
  29. tpu_inference/core/core_tpu.py +786 -0
  30. tpu_inference/core/disagg_executor.py +118 -0
  31. tpu_inference/core/disagg_utils.py +49 -0
  32. tpu_inference/core/sched/__init__.py +0 -0
  33. tpu_inference/core/sched/dp_scheduler.py +523 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/jax_parallel_state.py +67 -0
  36. tpu_inference/distributed/tpu_connector.py +727 -0
  37. tpu_inference/distributed/utils.py +60 -0
  38. tpu_inference/env_override.py +9 -0
  39. tpu_inference/envs.py +160 -0
  40. tpu_inference/executors/__init__.py +0 -0
  41. tpu_inference/executors/ray_distributed_executor.py +382 -0
  42. tpu_inference/experimental/__init__.py +0 -0
  43. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  44. tpu_inference/kernels/__init__.py +0 -0
  45. tpu_inference/kernels/collectives/__init__.py +0 -0
  46. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  47. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  48. tpu_inference/kernels/collectives/util.py +47 -0
  49. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  50. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  51. tpu_inference/kernels/fused_moe/__init__.py +0 -0
  52. tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
  53. tpu_inference/kernels/fused_moe/v1/kernel.py +1566 -0
  54. tpu_inference/kernels/mla/__init__.py +0 -0
  55. tpu_inference/kernels/mla/v1/__init__.py +0 -0
  56. tpu_inference/kernels/mla/v1/kernel.py +1349 -0
  57. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  58. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  59. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  60. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  61. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  62. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  66. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1501 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1603 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
  71. tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
  72. tpu_inference/layers/__init__.py +0 -0
  73. tpu_inference/layers/common/__init__.py +0 -0
  74. tpu_inference/layers/common/attention_interface.py +396 -0
  75. tpu_inference/layers/common/attention_metadata.py +34 -0
  76. tpu_inference/layers/common/binary_search.py +295 -0
  77. tpu_inference/layers/common/quant_methods.py +8 -0
  78. tpu_inference/layers/common/sharding.py +582 -0
  79. tpu_inference/layers/jax/__init__.py +0 -0
  80. tpu_inference/layers/jax/attention/__init__.py +0 -0
  81. tpu_inference/layers/jax/attention/attention.py +255 -0
  82. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  83. tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
  84. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  85. tpu_inference/layers/jax/base.py +151 -0
  86. tpu_inference/layers/jax/constants.py +88 -0
  87. tpu_inference/layers/jax/layers.py +301 -0
  88. tpu_inference/layers/jax/misc.py +16 -0
  89. tpu_inference/layers/jax/moe/__init__.py +0 -0
  90. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  91. tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
  92. tpu_inference/layers/jax/moe/moe.py +209 -0
  93. tpu_inference/layers/jax/rope.py +280 -0
  94. tpu_inference/layers/jax/rope_interface.py +214 -0
  95. tpu_inference/layers/jax/sample/__init__.py +0 -0
  96. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  97. tpu_inference/layers/jax/sample/sampling.py +96 -0
  98. tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
  99. tpu_inference/layers/jax/transformer_block.py +107 -0
  100. tpu_inference/layers/vllm/__init__.py +0 -0
  101. tpu_inference/layers/vllm/attention.py +221 -0
  102. tpu_inference/layers/vllm/fused_moe.py +469 -0
  103. tpu_inference/layers/vllm/linear_common.py +186 -0
  104. tpu_inference/layers/vllm/quantization/__init__.py +39 -0
  105. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  106. tpu_inference/layers/vllm/quantization/common.py +110 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  108. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
  109. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
  110. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  111. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  112. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  113. tpu_inference/layers/vllm/quantization/mxfp4.py +331 -0
  114. tpu_inference/layers/vllm/quantization/unquantized.py +368 -0
  115. tpu_inference/layers/vllm/sharding.py +230 -0
  116. tpu_inference/logger.py +10 -0
  117. tpu_inference/lora/__init__.py +0 -0
  118. tpu_inference/lora/torch_lora_ops.py +103 -0
  119. tpu_inference/lora/torch_punica_tpu.py +310 -0
  120. tpu_inference/models/__init__.py +0 -0
  121. tpu_inference/models/common/__init__.py +0 -0
  122. tpu_inference/models/common/model_loader.py +478 -0
  123. tpu_inference/models/jax/__init__.py +0 -0
  124. tpu_inference/models/jax/deepseek_v3.py +868 -0
  125. tpu_inference/models/jax/gpt_oss.py +492 -0
  126. tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
  127. tpu_inference/models/jax/llama3.py +376 -0
  128. tpu_inference/models/jax/llama4.py +629 -0
  129. tpu_inference/models/jax/llama_eagle3.py +336 -0
  130. tpu_inference/models/jax/llama_guard_4.py +361 -0
  131. tpu_inference/models/jax/qwen2.py +376 -0
  132. tpu_inference/models/jax/qwen2_5_vl.py +1218 -0
  133. tpu_inference/models/jax/qwen3.py +303 -0
  134. tpu_inference/models/jax/utils/__init__.py +0 -0
  135. tpu_inference/models/jax/utils/file_utils.py +96 -0
  136. tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
  137. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  138. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
  139. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
  140. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
  141. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
  142. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
  143. tpu_inference/models/jax/utils/quantization/quantization_utils.py +650 -0
  144. tpu_inference/models/jax/utils/weight_utils.py +584 -0
  145. tpu_inference/models/vllm/__init__.py +0 -0
  146. tpu_inference/models/vllm/vllm_model_wrapper.py +293 -0
  147. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  148. tpu_inference/platforms/__init__.py +2 -0
  149. tpu_inference/platforms/tpu_platform.py +275 -0
  150. tpu_inference/runner/__init__.py +0 -0
  151. tpu_inference/runner/block_table.py +122 -0
  152. tpu_inference/runner/compilation_manager.py +865 -0
  153. tpu_inference/runner/input_batch.py +435 -0
  154. tpu_inference/runner/kv_cache.py +132 -0
  155. tpu_inference/runner/kv_cache_manager.py +478 -0
  156. tpu_inference/runner/lora_utils.py +92 -0
  157. tpu_inference/runner/multimodal_manager.py +217 -0
  158. tpu_inference/runner/persistent_batch_manager.py +282 -0
  159. tpu_inference/runner/speculative_decoding_manager.py +248 -0
  160. tpu_inference/runner/structured_decoding_manager.py +87 -0
  161. tpu_inference/runner/tpu_runner.py +1744 -0
  162. tpu_inference/runner/utils.py +426 -0
  163. tpu_inference/spec_decode/__init__.py +0 -0
  164. tpu_inference/spec_decode/jax/__init__.py +0 -0
  165. tpu_inference/spec_decode/jax/eagle3.py +417 -0
  166. tpu_inference/tpu_info.py +78 -0
  167. tpu_inference/utils.py +340 -0
  168. tpu_inference/worker/__init__.py +0 -0
  169. tpu_inference/worker/tpu_worker.py +458 -0
  170. tpu_inference-0.0.1rc1.dist-info/METADATA +108 -0
  171. tpu_inference-0.0.1rc1.dist-info/RECORD +174 -0
  172. tpu_inference-0.0.1rc1.dist-info/WHEEL +5 -0
  173. tpu_inference-0.0.1rc1.dist-info/licenses/LICENSE +201 -0
  174. tpu_inference-0.0.1rc1.dist-info/top_level.txt +2 -0
tpu_inference/models/jax/qwen2_5_vl.py
@@ -0,0 +1,1218 @@
1
+ import math
2
+ from functools import partial
3
+ from typing import (Callable, List, Literal, NamedTuple, Optional, TypedDict,
4
+ Union)
5
+
6
+ import jax
7
+ import jax.numpy as jnp
8
+ import numpy as np
9
+ from flax import nnx
10
+ from jax.sharding import Mesh
11
+ from transformers import modeling_flax_utils
12
+ from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
13
+ Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig)
14
+ from vllm.config import VllmConfig
15
+
16
+ from tpu_inference import utils as utils
17
+ from tpu_inference.layers.common.attention_interface import \
18
+ sharded_flash_attention
19
+ from tpu_inference.layers.common.attention_metadata import AttentionMetadata
20
+ from tpu_inference.logger import init_logger
21
+ from tpu_inference.models.jax.qwen2 import Qwen2ForCausalLM
22
+ # from vllm.model_executor.models.interfaces import MultiModalEmbeddings
23
+ from tpu_inference.models.jax.utils.multi_modal_utils import (
24
+ MultiModalEmbeddings, merge_multimodal_embeddings)
25
+ from tpu_inference.models.jax.utils.weight_utils import (get_default_maps,
26
+ load_hf_weights)
27
+
28
+ logger = init_logger(__name__)
29
+
30
+ init_fn = nnx.initializers.uniform()
31
+
32
+ DEFAULT_BLOCK_K_MAJOR = 128
33
+
34
+
35
+ class SegmentIds(NamedTuple):
36
+ """SegmentIds for Q and KV sequences.
37
+
38
+ SegmentIds are used to generate the segment mask, which prevents attention between
39
+ different segments in the input sequence. Each array is a list of ids
40
+ (integers).
41
+ Only tokens with the same id can attend to each other.
42
+
43
+ Attributes:
44
+ q: segment ids along the Q sequence.
45
+ kv: segment ids along the KV sequence.
46
+ """
47
+
48
+ q: jax.Array # [batch_size, q_seq_len]
49
+ kv: jax.Array # [batch_size, kv_seq_len]
50
+
51
+
52
+ class Qwen2_5_VLImagePixelInputs(TypedDict):
53
+ type: Literal["pixel_values"]
54
+ pixel_values: jax.Array
55
+ """Shape:
56
+ `(num_patches, num_channels * temporal_patch_size * patch_size * patch_size)`
57
+ """
58
+
59
+ image_grid_thw: tuple[tuple[int, int, int], ...]
60
+ """Shape: `(num_images, 3)`
61
+ This should be in `(grid_t, grid_h, grid_w)` format.
62
+ """
63
+
64
+
65
+ # NOTE: We are not supporting embedding inputs for now
66
+ # The code here makes the structure consistent and
67
+ # makes it easier for future implementation
68
+ class Qwen2_5_VLImageEmbeddingInputs(TypedDict):
69
+ type: Literal["image_embeds"]
70
+ image_embeds: jax.Array
71
+ """Supported types:
72
+ - list[`jax.Array`]: A list of tensors holding all images' features.
73
+ Each tensor holds an image's features.
74
+ - `jax.Array`: A tensor holding all images' features (concatenation of
75
+ all images' feature tensors).
76
+
77
+ Tensor shape: `(num_image_features, hidden_size)`
78
+ - `num_image_features` varies based on
79
+ the number and resolution of the images.
80
+ - `hidden_size` must match the hidden size of language model backbone.
81
+ """
82
+
83
+ image_grid_thw: jax.Array
84
+ """Shape: `(num_images, 3)`
85
+ This should be in `(grid_t, grid_h, grid_w)` format.
86
+ """
87
+
88
+
89
+ Qwen2_5_VLImageInputs = Union[Qwen2_5_VLImagePixelInputs,
90
+ Qwen2_5_VLImageEmbeddingInputs]
91
+
92
+
93
+ class Qwen2_5_VisionMLP(nnx.Module):
94
+
95
+ def __init__(self, config: Qwen2_5_VLVisionConfig, dtype: jnp.dtype,
96
+ rngs: nnx.Rngs):
97
+ in_features = config.hidden_size
98
+ hidden_features = config.intermediate_size
99
+ act_fn = modeling_flax_utils.ACT2FN[config.hidden_act]
100
+ self.gate_proj = nnx.Linear(
101
+ in_features,
102
+ hidden_features,
103
+ use_bias=True,
104
+ param_dtype=dtype,
105
+ kernel_init=nnx.with_partitioning(init_fn, (None, "model")),
106
+ bias_init=nnx.with_partitioning(init_fn, ("model", )),
107
+ rngs=rngs,
108
+ )
109
+ self.up_proj = nnx.Linear(
110
+ in_features,
111
+ hidden_features,
112
+ use_bias=True,
113
+ param_dtype=dtype,
114
+ kernel_init=nnx.with_partitioning(init_fn, (None, "model")),
115
+ bias_init=nnx.with_partitioning(init_fn, ("model", )),
116
+ rngs=rngs,
117
+ )
118
+ self.down_proj = nnx.Linear(
119
+ hidden_features,
120
+ in_features,
121
+ use_bias=True,
122
+ param_dtype=dtype,
123
+ kernel_init=nnx.with_partitioning(init_fn, ("model", None)),
124
+ bias_init=nnx.with_partitioning(init_fn, (None, )),
125
+ rngs=rngs,
126
+ )
127
+ self.act_fn = act_fn
128
+
129
+ def __call__(self, x: jax.Array) -> jax.Array:
130
+ gate = self.act_fn(self.gate_proj(x))
131
+ up = self.up_proj(x)
132
+ fuse = gate * up
133
+ result = self.down_proj(fuse)
134
+ return result
135
+
136
+
137
+ def apply_rotary_pos_emb_vision(x: jax.Array,
138
+ rotary_pos_emb: jax.Array) -> jax.Array:
139
+ # x: [B, T, N, H]
140
+ # rotary_pos_emb: [T, H//2]
141
+ _, _, _, H = x.shape
142
+ half_dim = H // 2
143
+
144
+ # [B, T, N, H//2]
145
+ x_real = x[..., :half_dim]
146
+ x_imag = x[..., half_dim:]
147
+
148
+ # [T, H//2]
149
+ cos_emb = jnp.cos(rotary_pos_emb)
150
+ sin_emb = jnp.sin(rotary_pos_emb)
151
+
152
+ # [1, T, 1, H//2]
153
+ cos_emb = cos_emb[None, :, None, :]
154
+ sin_emb = sin_emb[None, :, None, :]
155
+
156
+ # [B, T, N, H//2]
157
+ x_rotated_real = x_real * cos_emb - x_imag * sin_emb
158
+ x_rotated_imag = x_real * sin_emb + x_imag * cos_emb
159
+
160
+ # [B, T, N, H]
161
+ x_rotated = jnp.concatenate([x_rotated_real, x_rotated_imag], axis=-1)
162
+
163
+ return x_rotated
164
+
165
+
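The rotation above is standard rotary position embedding written out on split halves. A minimal, self-contained sketch with toy shapes (plain jax.numpy; illustration only, not part of the package) showing that the half-split form equals a complex-phase rotation:

import jax.numpy as jnp

B, T, N, H = 1, 4, 1, 8
x = jnp.arange(B * T * N * H, dtype=jnp.float32).reshape(B, T, N, H)
# Toy rotary angles of shape [T, H//2]; theta = 10000 as in the module above.
inv_freq = 1.0 / 10000.0 ** (jnp.arange(H // 2, dtype=jnp.float32) / (H // 2))
freqs = jnp.outer(jnp.arange(T, dtype=jnp.float32), inv_freq)

# Half-split rotation, as in apply_rotary_pos_emb_vision.
x_re, x_im = x[..., :H // 2], x[..., H // 2:]
cos, sin = jnp.cos(freqs)[None, :, None, :], jnp.sin(freqs)[None, :, None, :]
rotated = jnp.concatenate([x_re * cos - x_im * sin, x_re * sin + x_im * cos], axis=-1)

# Equivalent complex form: (x_re + i * x_im) * exp(i * freqs).
z = (x_re + 1j * x_im) * jnp.exp(1j * freqs)[None, :, None, :]
assert jnp.allclose(rotated, jnp.concatenate([z.real, z.imag], axis=-1), atol=1e-5)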
166
+ def generate_window_segment_ids(cu_seqlens: jax.Array, seq_len: int,
167
+ padded_seq_len: int) -> SegmentIds:
168
+ """Generates segment IDs for windowed attention
169
+
170
+ Args:
171
+ cu_seqlens: A 1D array of cumulative sequence lengths for each window,
172
+ e.g., [0, len_win0, len_win0+len_win1, ...]
+ seq_len: Number of valid (unpadded) tokens in the sequence.
+ padded_seq_len: Sequence length after padding for flash_attention.
173
+
174
+ Returns:
175
+ A SegmentIds object for flash_attention.
176
+ """
177
+ indices = jnp.arange(seq_len, dtype=jnp.int32)
178
+ segment_ids = jnp.searchsorted(cu_seqlens[1:], indices, side='right') + 1
179
+ padding_segment_ids = jnp.zeros(padded_seq_len - seq_len, dtype=jnp.int32)
180
+ segment_ids = jnp.concatenate([segment_ids, padding_segment_ids])
181
+ segment_ids = segment_ids.reshape(1, -1)
182
+
183
+ return SegmentIds(q=segment_ids, kv=segment_ids)
184
+
185
+
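A small worked example of the segment-id construction above (self-contained, mirroring the body of generate_window_segment_ids with toy window lengths; not part of the package). Padding positions get segment id 0, so they attend to nothing:

import jax.numpy as jnp

cu_seqlens = jnp.array([0, 3, 5], dtype=jnp.int32)   # two windows of lengths 3 and 2
seq_len, padded_seq_len = 5, 8

indices = jnp.arange(seq_len, dtype=jnp.int32)
segment_ids = jnp.searchsorted(cu_seqlens[1:], indices, side="right") + 1
segment_ids = jnp.concatenate(
    [segment_ids, jnp.zeros(padded_seq_len - seq_len, dtype=jnp.int32)])
print(segment_ids)   # [1 1 1 2 2 0 0 0] -> tokens 0-2 and 3-4 form separate windows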
186
+ class Qwen2_5_VisionAttention(nnx.Module):
187
+
188
+ def __init__(self, config: Qwen2_5_VLConfig, dtype: jnp.dtype,
189
+ rngs: nnx.Rngs, mesh: Mesh):
190
+ vision_config = config.vision_config
191
+ self.hidden_size = vision_config.hidden_size
192
+ self.num_heads = vision_config.num_heads
193
+ self.num_kv_heads = self.num_heads
194
+ self.rope_theta = config.rope_theta
195
+ self.rope_scaling = getattr(config, "rope_scaling", None)
196
+ self.head_dim_original = self.hidden_size // self.num_heads
197
+
198
+ sharding_size = mesh.shape["model"]
199
+ self.num_heads = utils.get_padded_num_heads(self.num_heads,
200
+ sharding_size)
201
+ self.num_kv_heads = utils.get_padded_num_heads(self.num_kv_heads,
202
+ sharding_size)
203
+ self.head_dim = utils.get_padded_head_dim(self.head_dim_original)
204
+
205
+ # TODO(wenlong): head_dim padding is not applied for now.
206
+ self.head_dim = self.head_dim_original
207
+
208
+ self.mesh = mesh
209
+
210
+ self.qkv_proj = nnx.Linear(
211
+ self.hidden_size,
212
+ 3 * self.hidden_size,
213
+ use_bias=True,
214
+ param_dtype=dtype,
215
+ kernel_init=nnx.with_partitioning(init_fn, (None, "model")),
216
+ bias_init=nnx.with_partitioning(init_fn, ("model", )),
217
+ rngs=rngs,
218
+ )
219
+
220
+ self.proj = nnx.Linear(
221
+ self.hidden_size,
222
+ self.hidden_size,
223
+ use_bias=True,
224
+ param_dtype=dtype,
225
+ kernel_init=nnx.with_partitioning(init_fn, ("model", None)),
226
+ bias_init=nnx.with_partitioning(init_fn, (None, )),
227
+ rngs=rngs,
228
+ )
229
+ self.flash_attention = sharded_flash_attention(
230
+ mesh=mesh,
231
+ causal=False,
232
+ sm_scale=1.0 / math.sqrt(self.head_dim),
233
+ vmem_limit_bytes=128 * 1024 * 1024,
234
+ )
235
+
236
+ def __call__(
237
+ self,
238
+ x: jax.Array,
239
+ rotary_pos_emb: jax.Array,
240
+ cu_window_seqlens: Optional[jax.Array] = None,
241
+ use_fullattn: bool = True,
242
+ ) -> jax.Array:
243
+ T, B, D = x.shape
244
+ assert B == 1, "Vision attention currently only supports batch size 1"
245
+ # [T, B, D] -> [T, B, 3 * D]
246
+ qkv = self.qkv_proj(x)
247
+
248
+ # Split into Q, K, V.
249
+ # NOTE: simplified from vLLM's split_qkv,
250
+ # may need to revisit for tp>1
251
+ # [T, B, 3 * D] -> 3 *[T, B, D]
252
+ q, k, v = jnp.split(qkv, 3, axis=-1)
253
+
254
+ # [T, B, N, H]
255
+ q = q.reshape(T, B, self.num_heads, self.head_dim)
256
+ k = k.reshape(T, B, self.num_heads, self.head_dim)
257
+ v = v.reshape(T, B, self.num_heads, self.head_dim)
258
+
259
+ # [T, B, N, H] -> [B, T, N, H]
260
+ q = jnp.transpose(q, (1, 0, 2, 3))
261
+ k = jnp.transpose(k, (1, 0, 2, 3))
262
+ v = jnp.transpose(v, (1, 0, 2, 3))
263
+
264
+ # rotary_pos_emb shape: (T, H)
265
+ q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
266
+ k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)
267
+
268
+ # NOTE: an extra transpose because we need to
269
+ # match the layout used by vLLM's implementation.
270
+ # One of them may become removable in a later implementation.
271
+ # [B, T, N, H] -> [B, N, T, H]
272
+ q = jnp.transpose(q, (0, 2, 1, 3))
273
+ k = jnp.transpose(k, (0, 2, 1, 3))
274
+ v = jnp.transpose(v, (0, 2, 1, 3))
275
+
276
+ # Pad the sequence length to be a multiple of 128 for flash_attention
277
+ block_k_major = DEFAULT_BLOCK_K_MAJOR
278
+ T_attn = q.shape[2]
279
+ padded_T = (T_attn + block_k_major -
280
+ 1) // block_k_major * block_k_major
281
+ pad_width = ((0, 0), (0, 0), (0, padded_T - T_attn), (0, 0))
282
+
283
+ q = jnp.pad(q, pad_width, 'constant')
284
+ k = jnp.pad(k, pad_width, 'constant')
285
+ v = jnp.pad(v, pad_width, 'constant')
286
+
287
+ segment_ids = generate_window_segment_ids(cu_window_seqlens, T_attn,
288
+ padded_T)
289
+
290
+ # TODO (jacobplatin): add support for quantized KV cache?
291
+ output = self.flash_attention(q, k, v, segment_ids)
292
+
293
+ # Unpad the output
294
+ output = output[:, :, :T_attn, :]
295
+
296
+ # [B, N, T, H] -> [T, B, N, H]
297
+ output = jnp.transpose(output, (2, 0, 1, 3))
298
+
299
+ output = output.reshape(T, B, D)
300
+
301
+ output = self.proj(output)
302
+
303
+ return output
304
+
305
+
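The attention call above moves the fused QKV projection through several layouts before the flash-attention kernel, then pads the sequence to a multiple of DEFAULT_BLOCK_K_MAJOR. A toy-shape sketch of that reshuffle (illustration only, not part of the package):

import jax.numpy as jnp

T, B, N, H = 5, 1, 2, 4          # tokens, batch, heads, head_dim (toy sizes)
D = N * H
qkv = jnp.ones((T, B, 3 * D))    # stand-in for the qkv_proj output

q, k, v = jnp.split(qkv, 3, axis=-1)       # 3 x [T, B, D]
q = q.reshape(T, B, N, H)                  # [T, B, N, H]
q = jnp.transpose(q, (1, 0, 2, 3))         # [B, T, N, H] for the rotary embedding
q = jnp.transpose(q, (0, 2, 1, 3))         # [B, N, T, H] for flash attention

block_k_major = 128
padded_T = (T + block_k_major - 1) // block_k_major * block_k_major
q = jnp.pad(q, ((0, 0), (0, 0), (0, padded_T - T), (0, 0)))
print(q.shape)                             # (1, 2, 128, 4)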
306
+ class Qwen2_5_VisionBlock(nnx.Module):
307
+
308
+ def __init__(self, config: Qwen2_5_VLConfig, dtype: jnp.dtype,
309
+ rngs: nnx.Rngs, mesh: Mesh):
310
+ vision_config = config.vision_config
311
+ dim = vision_config.hidden_size
312
+ norm_layer = partial(nnx.RMSNorm,
313
+ epsilon=config.rms_norm_eps,
314
+ scale_init=nnx.with_partitioning(
315
+ init_fn, (None, )))
316
+
317
+ self.norm1 = norm_layer(dim, dtype=dtype, rngs=rngs)
318
+ self.norm2 = norm_layer(dim, dtype=dtype, rngs=rngs)
319
+ self.attn = Qwen2_5_VisionAttention(config=config,
320
+ dtype=dtype,
321
+ rngs=rngs,
322
+ mesh=mesh)
323
+ self.mlp = Qwen2_5_VisionMLP(config=vision_config,
324
+ dtype=dtype,
325
+ rngs=rngs)
326
+
327
+ def __call__(self,
328
+ x: jax.Array,
329
+ rotary_pos_emb: jax.Array,
330
+ cu_window_seqlens: Optional[jax.Array] = None,
331
+ use_fullattn: bool = True) -> jax.Array:
332
+
333
+ x = x + self.attn(self.norm1(x), rotary_pos_emb, cu_window_seqlens,
334
+ use_fullattn)
335
+ x = x + self.mlp(self.norm2(x))
336
+
337
+ return x
338
+
339
+
340
+ class Qwen2_5_VisionPatchEmbed(nnx.Module):
341
+
342
+ def __init__(
343
+ self,
344
+ rngs: nnx.Rngs,
345
+ patch_size: int = 14,
346
+ temporal_patch_size: int = 2,
347
+ in_channels: int = 3,
348
+ hidden_size: int = 1152,
349
+ dtype: jnp.dtype = jnp.bfloat16,
350
+ ) -> None:
351
+ self.patch_size = patch_size
352
+ self.temporal_patch_size = temporal_patch_size
353
+ self.hidden_size = hidden_size
354
+ kernel_size = (temporal_patch_size, patch_size, patch_size)
355
+ self.proj = nnx.Conv(in_features=in_channels,
356
+ out_features=hidden_size,
357
+ kernel_size=kernel_size,
358
+ strides=kernel_size,
359
+ use_bias=False,
360
+ param_dtype=dtype,
361
+ kernel_init=nnx.with_partitioning(
362
+ init_fn, (None, None, None, None, "model")),
363
+ rngs=rngs)
364
+
365
+ def __call__(self, x: jax.Array) -> jax.Array:
366
+ # x is (L, C * T * H * W)
367
+ L, dim = x.shape
368
+ C = dim // (self.temporal_patch_size * self.patch_size *
369
+ self.patch_size)
370
+ # Reshape to (L, T, H, W, C) for Conv3D with channels_last
371
+ x = x.reshape(L, C, self.temporal_patch_size, self.patch_size,
372
+ self.patch_size)
373
+ # L,T,H,W,C
374
+ x = jnp.transpose(x, (0, 2, 3, 4, 1))
375
+ x = self.proj(x)
376
+ # After conv, shape is (L, T_out, H_out, W_out, C_out)
377
+ # With stride=kernel_size, T_out=H_out=W_out=1.
378
+ # So shape is (L, 1, 1, 1, hidden_size)
379
+ x = x.reshape(L, self.hidden_size)
380
+ return x
381
+
382
+
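A quick sketch of the patch layout Qwen2_5_VisionPatchEmbed expects, using the constructor defaults above (patch_size=14, temporal_patch_size=2, in_channels=3). Illustration only, not part of the package; the reshape mirrors __call__:

import jax.numpy as jnp

patch_size, temporal_patch_size, in_channels = 14, 2, 3
patch_input_dim = in_channels * temporal_patch_size * patch_size * patch_size
print(patch_input_dim)        # 1176 -> each row of x is one flattened 3 x 2 x 14 x 14 patch

L = 4                         # number of patches (toy value)
x = jnp.zeros((L, patch_input_dim))
x = x.reshape(L, in_channels, temporal_patch_size, patch_size, patch_size)
x = jnp.transpose(x, (0, 2, 3, 4, 1))   # (L, T, H, W, C), channels-last for nnx.Conv
print(x.shape)                # (4, 2, 14, 14, 3); stride == kernel collapses T, H, W to 1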
383
+ class Qwen2_5_VisionPatchMerger(nnx.Module):
384
+
385
+ def __init__(self, d_model: int, context_dim: int, norm_layer: Callable,
386
+ spatial_merge_size: int, dtype: jnp.dtype, rngs: nnx.Rngs):
387
+ self.hidden_size = context_dim * (spatial_merge_size**2)
388
+ self.ln_q = norm_layer(context_dim,
389
+ dtype=dtype,
390
+ rngs=rngs,
391
+ scale_init=nnx.with_partitioning(
392
+ init_fn, (None, )))
393
+ self.mlp_fc1 = nnx.Linear(
394
+ self.hidden_size,
395
+ self.hidden_size,
396
+ use_bias=True,
397
+ param_dtype=dtype,
398
+ kernel_init=nnx.with_partitioning(init_fn, (None, "model")),
399
+ bias_init=nnx.with_partitioning(init_fn, ("model", )),
400
+ rngs=rngs)
401
+ self.mlp_act = modeling_flax_utils.ACT2FN["gelu"]
402
+ self.mlp_fc2 = nnx.Linear(
403
+ self.hidden_size,
404
+ d_model,
405
+ use_bias=True,
406
+ param_dtype=dtype,
407
+ kernel_init=nnx.with_partitioning(init_fn, ("model", None)),
408
+ bias_init=nnx.with_partitioning(init_fn, (None, )),
409
+ rngs=rngs)
410
+
411
+ def __call__(self, x: jax.Array) -> jax.Array:
412
+ x = self.ln_q(x)
413
+ x = x.reshape(-1, self.hidden_size)
414
+ x = self.mlp_fc1(x)
415
+ x = self.mlp_act(x)
416
+ x = self.mlp_fc2(x)
417
+ return x
418
+
419
+
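The merger concatenates spatial_merge_size**2 neighbouring patch features into a single vector before the two-layer MLP projects to the language-model hidden size. A toy-shape sketch of that reshape (illustration only, not part of the package):

import jax.numpy as jnp

context_dim, spatial_merge_size = 8, 2
hidden_size = context_dim * spatial_merge_size**2    # 32, as in __init__ above
x = jnp.ones((12, context_dim))                      # 12 normalized patch features
x = x.reshape(-1, hidden_size)
print(x.shape)    # (3, 32) -> 3 merged tokens go through mlp_fc1 / gelu / mlp_fc2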
420
+ class Qwen2_5_VisionRotaryEmbedding(nnx.Module):
421
+
422
+ def __init__(self, dim: int, theta: float = 10000.0):
423
+ self.dim = dim
424
+ self.theta = theta
425
+
426
+ def __call__(self, seqlen: int) -> jax.Array:
427
+ inv_freq = 1.0 / (self.theta**(
428
+ jnp.arange(0, self.dim, 2, dtype=jnp.float32) / self.dim))
429
+ seq = jnp.arange(seqlen, dtype=jnp.float32)
430
+ freqs = jnp.outer(seq, inv_freq)
431
+ return freqs.astype(jnp.bfloat16)
432
+
433
+
434
+ class Qwen2_5_VisionTransformer(nnx.Module):
435
+
436
+ def __init__(self,
437
+ vllm_config: VllmConfig,
438
+ rngs: nnx.Rngs,
439
+ mesh: Mesh,
440
+ norm_eps: float = 1e-6):
441
+ model_config = vllm_config.model_config
442
+ hf_config = model_config.hf_config
443
+ vision_config = hf_config.vision_config
444
+ dtype = model_config.dtype
445
+
446
+ self.config = vision_config
447
+ self.dtype = dtype
448
+
449
+ patch_size = vision_config.patch_size
450
+ temporal_patch_size = vision_config.temporal_patch_size
451
+ in_channels = vision_config.in_channels
452
+ self.hidden_size = vision_config.hidden_size
453
+ self.num_heads = vision_config.num_heads
454
+
455
+ # args for get_window_index_thw
456
+ self.window_size = vision_config.window_size
457
+ self.patch_size = vision_config.patch_size
458
+ self.spatial_merge_size = vision_config.spatial_merge_size
459
+ self.fullatt_block_indexes = vision_config.fullatt_block_indexes
460
+ self.spatial_merge_unit = self.spatial_merge_size**2
461
+
462
+ self.patch_embed = Qwen2_5_VisionPatchEmbed(
463
+ patch_size=patch_size,
464
+ temporal_patch_size=temporal_patch_size,
465
+ in_channels=in_channels,
466
+ hidden_size=self.hidden_size,
467
+ dtype=dtype,
468
+ rngs=rngs)
469
+
470
+ head_dim = vision_config.hidden_size // vision_config.num_heads
471
+ self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
472
+
473
+ self.blocks = [
474
+ Qwen2_5_VisionBlock(
475
+ config=hf_config,
476
+ dtype=dtype,
477
+ rngs=rngs,
478
+ mesh=mesh,
479
+ ) for _ in range(vision_config.depth)
480
+ ]
481
+ self.merger = Qwen2_5_VisionPatchMerger(
482
+ d_model=vision_config.out_hidden_size,
483
+ context_dim=vision_config.hidden_size,
484
+ norm_layer=partial(nnx.RMSNorm, epsilon=norm_eps),
485
+ spatial_merge_size=vision_config.spatial_merge_size,
486
+ dtype=dtype,
487
+ rngs=rngs)
488
+
489
+ additional_config = getattr(vllm_config, "additional_config",
490
+ None) or {}
491
+ self.enable_dynamic_image_sizes = additional_config.get(
492
+ "enable_dynamic_image_sizes", False)
493
+
494
+ def rotary_pos_emb_thw(self, t, h, w):
495
+ hpos_ids, wpos_ids = jnp.indices((h, w))
496
+ hpos_ids = hpos_ids.reshape(
497
+ h // self.spatial_merge_size,
498
+ self.spatial_merge_size,
499
+ w // self.spatial_merge_size,
500
+ self.spatial_merge_size,
501
+ ).transpose(0, 2, 1, 3).flatten()
502
+ wpos_ids = wpos_ids.reshape(
503
+ h // self.spatial_merge_size,
504
+ self.spatial_merge_size,
505
+ w // self.spatial_merge_size,
506
+ self.spatial_merge_size,
507
+ ).transpose(0, 2, 1, 3).flatten()
508
+ pos_ids = jnp.stack([hpos_ids, wpos_ids], axis=-1)
509
+ pos_ids = jnp.tile(pos_ids, (t, 1))
510
+
511
+ max_size = max(h, w)
512
+ rotary_pos_emb_full = self.rotary_pos_emb(max_size)
513
+
514
+ rotary_pos_emb = rotary_pos_emb_full[pos_ids].reshape(
515
+ pos_ids.shape[0], -1)
516
+ rotary_pos_emb = rotary_pos_emb.reshape(
517
+ rotary_pos_emb.shape[0] // self.spatial_merge_unit,
518
+ self.spatial_merge_unit, -1)
519
+
520
+ return rotary_pos_emb
521
+
522
+ def get_window_index_thw(self, grid_t, grid_h, grid_w):
523
+ vit_merger_window_size = (self.window_size //
524
+ self.spatial_merge_size // self.patch_size)
525
+
526
+ llm_grid_h = grid_h // self.spatial_merge_size
527
+ llm_grid_w = grid_w // self.spatial_merge_size
528
+
529
+ index = jnp.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
530
+ grid_t, llm_grid_h, llm_grid_w)
531
+
532
+ pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
533
+ pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
534
+ num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
535
+ num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
536
+
537
+ index_padded = jnp.pad(index, ((0, 0), (0, pad_h), (0, pad_w)),
538
+ constant_values=-100)
539
+ index_padded = index_padded.reshape(grid_t, num_windows_h,
540
+ vit_merger_window_size,
541
+ num_windows_w,
542
+ vit_merger_window_size)
543
+ index_padded = jnp.transpose(index_padded, (0, 1, 3, 2, 4)).reshape(
544
+ grid_t, num_windows_h * num_windows_w, vit_merger_window_size,
545
+ vit_merger_window_size)
546
+ seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
547
+ index_padded = index_padded.reshape(-1)
548
+ # The number of valid indices is static because grid_t, grid_h, grid_w
549
+ # are static.
550
+ num_valid_indices = grid_t * llm_grid_h * llm_grid_w
551
+ valid_indices = jnp.nonzero(index_padded != -100,
552
+ size=num_valid_indices)[0]
553
+ index_new = index_padded[valid_indices]
554
+ cu_seqlens_tmp = jnp.cumsum(seqlens) * self.spatial_merge_unit
555
+ cu_seqlens_tmp = cu_seqlens_tmp.astype(jnp.int32)
556
+
557
+ # NOTE (wenlong): the PyTorch code deduplicates these cumulative lengths,
558
+ # but that is not needed here, and it would cause problems under JIT.
559
+ # Please refer back here if there is a problem downstream.
560
+ # cu_seqlens_tmp = jnp.unique(cu_seqlens_tmp)
561
+
562
+ return index_new, cu_seqlens_tmp
563
+
564
+ def get_rope_by_thw(self, t, h, w):
565
+ window_index_thw, cu_seqlens_window_thw = self.get_window_index_thw(
566
+ t, h, w)
567
+
568
+ rotary_pos_emb_thw = self.rotary_pos_emb_thw(t, h, w)
569
+
570
+ rotary_pos_emb_thw = rotary_pos_emb_thw[window_index_thw, :, :]
571
+ rotary_pos_emb_thw = rotary_pos_emb_thw.reshape(
572
+ -1, rotary_pos_emb_thw.shape[-1])
573
+ cu_seqlens_thw = jnp.full(t, h * w, dtype=jnp.int32)
574
+
575
+ return (rotary_pos_emb_thw, window_index_thw, cu_seqlens_window_thw,
576
+ cu_seqlens_thw)
577
+
578
+ def compute_attn_mask_seqlen(
579
+ self,
580
+ cu_seqlens: jax.Array,
581
+ ) -> tuple[Optional[int], Optional[list[int]]]:
582
+ max_seqlen, seqlens = None, None
583
+ max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
584
+ seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
585
+ return max_seqlen, seqlens
586
+
587
+ def compute_aux_arrays(self, grid_thw: tuple[tuple[int, int, int]]):
588
+ # number of images/videos
589
+ num_grids = len(grid_thw)
590
+
591
+ rotary_pos_emb = []
592
+ window_index: list = []
593
+ cu_window_seqlens: list = [jnp.array([0], dtype=jnp.int32)]
594
+ cu_seqlens: list = []
595
+
596
+ window_index_id = 0
597
+ cu_window_seqlens_last = 0
598
+ for i in range(num_grids):
599
+ t, h, w = grid_thw[i]
600
+
601
+ llm_h = h // self.spatial_merge_size
602
+ llm_w = w // self.spatial_merge_size
603
+
604
+ (
605
+ rotary_pos_emb_thw,
606
+ window_index_thw,
607
+ cu_seqlens_window_thw,
608
+ cu_seqlens_thw,
609
+ ) = self.get_rope_by_thw(t, h, w)
610
+
611
+ window_index.append(window_index_thw + window_index_id)
612
+ window_index_id += (t * llm_h * llm_w)
613
+
614
+ cu_seqlens_window_thw = (cu_seqlens_window_thw +
615
+ cu_window_seqlens_last)
616
+ cu_window_seqlens_last = cu_seqlens_window_thw[-1]
617
+ cu_window_seqlens.append(cu_seqlens_window_thw)
618
+
619
+ rotary_pos_emb.append(rotary_pos_emb_thw)
620
+
621
+ cu_seqlens.append(cu_seqlens_thw)
622
+
623
+ rotary_pos_emb = jnp.concatenate(rotary_pos_emb, axis=0)
624
+ window_index = jnp.concatenate(window_index, axis=0)
625
+ cu_window_seqlens = jnp.concatenate(cu_window_seqlens, axis=0)
626
+
627
+ cu_seqlens = jnp.concatenate(cu_seqlens, axis=0)
628
+ cu_seqlens = jnp.cumsum(cu_seqlens, axis=0, dtype=jnp.int32)
629
+ cu_seqlens = jnp.pad(cu_seqlens, ((1, 0), ),
630
+ mode='constant',
631
+ constant_values=0)
632
+ return window_index, rotary_pos_emb, cu_seqlens, cu_window_seqlens
633
+
634
+ def pad_inputs(self, x, window_index, rotary_pos_emb, cu_seqlens,
635
+ cu_window_seqlens):
636
+ # padding
637
+ num_patches = int(rotary_pos_emb.shape[0])
638
+ bucket_num_patches = 1 << (num_patches - 1).bit_length()
639
+ num_tokens = window_index.shape[0]
640
+ bucket_num_tokens = bucket_num_patches // self.spatial_merge_unit
641
+ vit_merger_window_size = (self.window_size //
642
+ self.spatial_merge_size // self.patch_size)
643
+ max_windows = (bucket_num_tokens // vit_merger_window_size) + 2
644
+
645
+ rotary_pos_emb = jnp.pad(rotary_pos_emb,
646
+ ((0, bucket_num_patches - num_patches),
647
+ (0, 0)))
648
+ window_index = jnp.concatenate([
649
+ window_index,
650
+ jnp.arange(num_tokens, bucket_num_tokens, dtype=jnp.int32)
651
+ ])
652
+ cu_window_seqlens = jnp.append(cu_window_seqlens, bucket_num_patches)
653
+ pad_w = max(0, max_windows + 1 - cu_window_seqlens.shape[0])
654
+ cu_window_seqlens = jnp.pad(cu_window_seqlens, (0, pad_w), mode='edge')
655
+ cu_seqlens = jnp.append(cu_seqlens, bucket_num_patches)
656
+
657
+ x_padded = jnp.pad(x, ((0, bucket_num_patches - x.shape[0]), (0, 0)))
658
+
659
+ return x_padded, window_index, rotary_pos_emb, cu_seqlens, cu_window_seqlens, num_tokens
660
+
661
+ def compute_hidden_states(self, x: jax.Array, window_index: jax.Array,
662
+ rotary_pos_emb: jax.Array, cu_seqlens: jax.Array,
663
+ cu_window_seqlens: jax.Array) -> jax.Array:
664
+ hidden_states = self.patch_embed(x)
665
+
666
+ # num of patches
667
+ seq_len = x.shape[0]
668
+
669
+ hidden_states = hidden_states.reshape(
670
+ seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
671
+ hidden_states = hidden_states[window_index, :, :]
672
+ hidden_states = hidden_states.reshape(seq_len, -1)
673
+
674
+ hidden_states = jnp.expand_dims(hidden_states, axis=1)
675
+
676
+ for layer_num, blk in enumerate(self.blocks):
677
+ if layer_num in self.fullatt_block_indexes:
678
+ hidden_states = blk(hidden_states,
679
+ rotary_pos_emb=rotary_pos_emb,
680
+ cu_window_seqlens=cu_seqlens,
681
+ use_fullattn=True)
682
+ else:
683
+ hidden_states = blk(hidden_states,
684
+ rotary_pos_emb=rotary_pos_emb,
685
+ cu_window_seqlens=cu_window_seqlens,
686
+ use_fullattn=False)
687
+
688
+ # adapter
689
+ hidden_states = self.merger(hidden_states)
690
+ reverse_indices = jnp.argsort(window_index)
691
+ hidden_states = hidden_states[reverse_indices, :]
692
+ return hidden_states
693
+
694
+ @jax.jit
695
+ def encode_padded_jit(self, x_padded, window_index, rotary_pos_emb,
696
+ cu_seqlens, cu_window_seqlens):
697
+ return self.compute_hidden_states(x_padded, window_index,
698
+ rotary_pos_emb, cu_seqlens,
699
+ cu_window_seqlens)
700
+
701
+ @partial(
702
+ jax.jit,
703
+ static_argnames=("grid_thw", ),
704
+ )
705
+ def encode_jit(self, x, grid_thw):
706
+ window_index, rotary_pos_emb, cu_seqlens, cu_window_seqlens = self.compute_aux_arrays(
707
+ grid_thw)
708
+ return self.compute_hidden_states(x, window_index, rotary_pos_emb,
709
+ cu_seqlens, cu_window_seqlens)
710
+
711
+ def __call__(self, x: jax.Array, grid_thw: tuple[tuple[int, int,
712
+ int]]) -> jax.Array:
713
+ # x: pixel_values: jax.Array
714
+ # """Shape:
715
+ # `(num_patches, num_channels * patch_size * patch_size)`
716
+ # """
717
+
718
+ # grid_thw: image_grid_thw: jax.Array
719
+ # """Shape: `(num_images, 3)`
720
+ # This should be in `(grid_t, grid_h, grid_w)` format.
721
+ # """
722
+ if self.enable_dynamic_image_sizes:
723
+ window_index, rotary_pos_emb, cu_seqlens, cu_window_seqlens = self.compute_aux_arrays(
724
+ grid_thw)
725
+ x_padded, window_index, rotary_pos_emb, cu_seqlens, cu_window_seqlens, num_tokens = self.pad_inputs(
726
+ x, window_index, rotary_pos_emb, cu_seqlens, cu_window_seqlens)
727
+
728
+ hidden_states = self.encode_padded_jit(x_padded, window_index,
729
+ rotary_pos_emb, cu_seqlens,
730
+ cu_window_seqlens)
731
+ return hidden_states[:num_tokens]
732
+
733
+ else:
734
+ return self.encode_jit(x, grid_thw)
735
+
736
+
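Both pad_inputs and precompile_vision_encoder round the patch count up to the next power of two so that only a bounded set of shapes is ever compiled. A short sketch of that bucketing rule (illustration only, not part of the package):

def next_pow2(n: int) -> int:
    # Same expression as `1 << (num_patches - 1).bit_length()` in pad_inputs.
    return 1 << (n - 1).bit_length()

for n in (1, 300, 1024, 1025):
    print(n, "->", next_pow2(n))   # 1 -> 1, 300 -> 512, 1024 -> 1024, 1025 -> 2048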
737
+ class Qwen2_5_VLForConditionalGeneration(nnx.Module):
738
+
739
+ def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array,
740
+ mesh: Mesh) -> None:
741
+ config: Qwen2_5_VLConfig = vllm_config.model_config.hf_config
742
+ multimodal_config = vllm_config.model_config.multimodal_config
743
+
744
+ self.vllm_config = vllm_config
745
+ self.rng = nnx.Rngs(rng_key)
746
+ self.mesh = mesh
747
+
748
+ self.config = config
749
+ self.multimodal_config = multimodal_config
750
+
751
+ self.visual = Qwen2_5_VisionTransformer(
752
+ vllm_config=vllm_config,
753
+ rngs=self.rng,
754
+ mesh=mesh,
755
+ norm_eps=getattr(config, "rms_norm_eps", 1e-6),
756
+ )
757
+ self.language_model = Qwen2ForCausalLM(vllm_config, rng_key, mesh)
758
+
759
+ def get_mrope_input_positions(
760
+ self,
761
+ input_tokens: list[int],
762
+ hf_config,
763
+ image_grid_thw,
764
+ video_grid_thw,
765
+ second_per_grid_ts: list[float],
766
+ context_len: int = 0,
767
+ seq_len: int | None = None,
768
+ audio_feature_lengths=None,
769
+ use_audio_in_video: bool = False,
770
+ ) -> tuple[jax.Array, int]:
771
+ """Get mrope input positions and delta value."""
772
+
773
+ image_token_id = hf_config.image_token_id
774
+ video_token_id = hf_config.video_token_id
775
+ vision_start_token_id = hf_config.vision_start_token_id
776
+ spatial_merge_size = hf_config.vision_config.spatial_merge_size
777
+ tokens_per_second = getattr(hf_config.vision_config,
778
+ "tokens_per_second", 1.0)
779
+
780
+ input_tokens_tensor = np.array(input_tokens)
781
+ vision_start_indices = np.argwhere(
782
+ input_tokens_tensor == vision_start_token_id).squeeze(1)
783
+ vision_tokens = input_tokens_tensor[vision_start_indices + 1]
784
+ image_nums = np.sum(vision_tokens == image_token_id)
785
+ video_nums = np.sum(vision_tokens == video_token_id)
786
+ llm_pos_ids_list: list = []
787
+
788
+ st = 0
789
+ remain_images, remain_videos = image_nums, video_nums
790
+
791
+ image_index, video_index = 0, 0
792
+ for _ in range(image_nums + video_nums):
793
+ video_second_per_grid_t = 0.0
794
+ if remain_images > 0:
795
+ try:
796
+ ed_image = input_tokens.index(image_token_id, st)
797
+ except ValueError:
798
+ ed_image = len(input_tokens) + 1
799
+ else:
800
+ ed_image = len(input_tokens) + 1
801
+ if remain_videos > 0:
802
+ try:
803
+ ed_video = input_tokens.index(video_token_id, st)
804
+ except ValueError:
805
+ ed_video = len(input_tokens) + 1
806
+ else:
807
+ ed_video = len(input_tokens) + 1
808
+ if ed_image < ed_video:
809
+ t, h, w = (
810
+ image_grid_thw[image_index][0],
811
+ image_grid_thw[image_index][1],
812
+ image_grid_thw[image_index][2],
813
+ )
814
+ image_index += 1
815
+ remain_images -= 1
816
+ ed = ed_image
817
+ else:
818
+ t, h, w = (
819
+ video_grid_thw[video_index][0],
820
+ video_grid_thw[video_index][1],
821
+ video_grid_thw[video_index][2],
822
+ )
823
+ video_second_per_grid_t = 1.0
824
+ if second_per_grid_ts:
825
+ video_second_per_grid_t = second_per_grid_ts[video_index]
826
+ video_index += 1
827
+ remain_videos -= 1
828
+ ed = ed_video
829
+
830
+ llm_grid_t, llm_grid_h, llm_grid_w = (
831
+ t,
832
+ h // spatial_merge_size,
833
+ w // spatial_merge_size,
834
+ )
835
+ text_len = ed - st
836
+
837
+ st_idx = llm_pos_ids_list[-1].max().item() + 1 if len(
838
+ llm_pos_ids_list) > 0 else 0
839
+ llm_pos_ids_list.append(
840
+ jnp.broadcast_to(
841
+ jnp.arange(text_len, dtype=jnp.int32).reshape(1, -1),
842
+ (3, text_len)) + st_idx)
843
+
844
+ t_index = ((jnp.broadcast_to(
845
+ jnp.arange(llm_grid_t, dtype=jnp.int32).reshape(-1, 1),
846
+ (llm_grid_t, llm_grid_h * llm_grid_w)) *
847
+ video_second_per_grid_t * tokens_per_second).astype(
848
+ jnp.int32).flatten())
849
+
850
+ h_index = (jnp.broadcast_to(
851
+ jnp.arange(llm_grid_h, dtype=jnp.int32).reshape(1, -1, 1),
852
+ (llm_grid_t, llm_grid_h, llm_grid_w)).flatten())
853
+ w_index = (jnp.broadcast_to(
854
+ jnp.arange(llm_grid_w, dtype=jnp.int32).reshape(1, 1, -1),
855
+ (llm_grid_t, llm_grid_h, llm_grid_w)).flatten())
856
+
857
+ llm_pos_ids_list.append(
858
+ jnp.stack([t_index, h_index, w_index]) + text_len + st_idx)
859
+ st = ed + llm_grid_t * llm_grid_h * llm_grid_w
860
+
861
+ if st < len(input_tokens):
862
+ st_idx = llm_pos_ids_list[-1].max().item() + 1 if len(
863
+ llm_pos_ids_list) > 0 else 0
864
+ text_len = len(input_tokens) - st
865
+
866
+ llm_pos_ids_list.append(
867
+ jnp.broadcast_to(
868
+ jnp.arange(text_len, dtype=jnp.int32).reshape(1, -1),
869
+ (3, text_len)) + st_idx)
870
+
871
+ llm_positions = jnp.concatenate(llm_pos_ids_list,
872
+ axis=1).reshape(3, -1)
873
+ mrope_position_delta = (llm_positions.max() + 1 -
874
+ len(input_tokens)).item()
875
+ llm_positions = llm_positions[:, context_len:seq_len]
876
+
877
+ return llm_positions, mrope_position_delta
878
+
879
+ def _validate_and_reshape_mm_tensor(self, mm_input: object,
880
+ name: str) -> jax.Array:
881
+ if isinstance(mm_input, list):
882
+ # Assuming it's a list of arrays (e.g., np.ndarray, torch.Tensor)
883
+ # that can be concatenated.
884
+ arrays_to_concat = [jnp.asarray(item) for item in mm_input]
885
+ return jnp.concatenate(arrays_to_concat, axis=0)
886
+
887
+ # Handle single array-like objects (np.ndarray, torch.Tensor, jax.Array)
888
+ if hasattr(mm_input, 'ndim'):
889
+ array_input = jnp.asarray(mm_input)
890
+ if array_input.ndim == 2:
891
+ return array_input
892
+ if array_input.ndim == 3:
893
+ # This reshapes the batched 3D tensor to a 2D tensor.
894
+ return array_input.reshape(-1, array_input.shape[-1])
895
+
896
+ raise ValueError(f"Incorrect type of {name}. "
897
+ f"Got type: {type(mm_input)}")
898
+
899
+ def _parse_and_validate_image_input(
900
+ self, image_grid_thw: tuple[tuple[int, int, int], ...],
901
+ **kwargs: object) -> Optional[Qwen2_5_VLImageInputs]:
902
+ pixel_values = kwargs.pop("pixel_values", None)
903
+ image_embeds = kwargs.pop("image_embeds", None)
904
+ # image_grid_thw = kwargs.pop("image_grid_thw", None)
905
+
906
+ if pixel_values is None and image_embeds is None:
907
+ return None
908
+
909
+ if pixel_values is not None:
910
+ pixel_values = self._validate_and_reshape_mm_tensor(
911
+ pixel_values, "image pixel values")
912
+ # image_grid_thw = self._validate_and_reshape_mm_tensor(
913
+ # image_grid_thw, "image grid_thw")
914
+
915
+ if not isinstance(pixel_values, jax.Array):
916
+ raise ValueError("Incorrect type of image pixel values. "
917
+ f"Got type: {type(pixel_values)}")
918
+
919
+ return Qwen2_5_VLImagePixelInputs(type="pixel_values",
920
+ pixel_values=pixel_values,
921
+ image_grid_thw=image_grid_thw)
922
+
923
+ # NOTE: commented out for now; kept for future support.
924
+ # if image_embeds is not None:
925
+ # image_embeds = self._validate_and_reshape_mm_tensor(
926
+ # image_embeds, "image embeds")
927
+ # image_grid_thw = self._validate_and_reshape_mm_tensor(
928
+ # image_grid_thw, "image grid_thw")
929
+
930
+ # if not isinstance(image_embeds, jax.Array):
931
+ # raise ValueError("Incorrect type of image embeddings. "
932
+ # f"Got type: {type(image_embeds)}")
933
+ # return Qwen2_5_VLImageEmbeddingInputs(
934
+ # type="image_embeds",
935
+ # image_embeds=image_embeds,
936
+ # image_grid_thw=image_grid_thw)
937
+
938
+ def _parse_and_validate_multimodal_inputs(self,
939
+ image_grid_thw: tuple[tuple[int,
940
+ int,
941
+ int],
942
+ ...],
943
+ **kwargs: object) -> dict:
944
+ mm_input_by_modality = {}
945
+
946
+ # Preserve the order of modalities if there are multiple of them
947
+ # from the order of kwargs.
948
+ for input_key in kwargs:
949
+ if input_key in ("pixel_values", "image_embeds"
950
+ ) and "image" not in mm_input_by_modality:
951
+ mm_input_by_modality[
952
+ "image"] = self._parse_and_validate_image_input(
953
+ image_grid_thw, **kwargs)
954
+ # if input_key in ("pixel_values_videos", "video_embeds"
955
+ # ) and "video" not in mm_input_by_modality:
956
+ # mm_input_by_modality[
957
+ # "video"] = self._parse_and_validate_video_input(**kwargs)
958
+ return mm_input_by_modality
959
+
960
+ def get_single_image_embedding(self, image_pixel_values, image_grid_thw):
961
+ return self.visual(image_pixel_values, (image_grid_thw, ))
962
+
963
+ def _process_image_input(
964
+ self, image_input: Qwen2_5_VLImageInputs) -> tuple[jax.Array, ...]:
965
+
966
+ grid_thw = image_input["image_grid_thw"]
967
+
968
+ if image_input["type"] == "image_embeds":
969
+ image_embeds = image_input["image_embeds"].astype(
970
+ self.visual.dtype)
971
+ else:
972
+ pixel_values = image_input["pixel_values"]
973
+ image_embeds = []
974
+ current_idx = 0
975
+ for image_thw in grid_thw:
976
+ t, h, w = image_thw
977
+ image_size = t * h * w
978
+ end_idx = current_idx + image_size
979
+ image_pixel_values = pixel_values[current_idx:end_idx, :]
980
+ image_embeds.append(
981
+ self.get_single_image_embedding(image_pixel_values,
982
+ image_thw))
983
+ current_idx = end_idx
984
+ image_embeds = jnp.concatenate(image_embeds, axis=0)
985
+
986
+ # Split concatenated embeddings for each image item.
987
+ merge_size = self.visual.config.spatial_merge_size
988
+ sizes = np.prod(np.array(grid_thw, dtype=np.int64),
989
+ axis=-1) // merge_size // merge_size
990
+
991
+ if sizes.size == 0:
992
+ return ()
993
+ if sizes.size == 1:
994
+ return (image_embeds, )
995
+
996
+ split_indices = np.cumsum(sizes)[:-1]
997
+ return tuple(jnp.split(image_embeds, split_indices))
998
+
999
+ def get_multimodal_embeddings(self, image_grid_thw: tuple[tuple[int, int,
1000
+ int], ...],
1001
+ **kwargs: object) -> MultiModalEmbeddings:
1002
+
1003
+ mm_input_by_modality = self._parse_and_validate_multimodal_inputs(
1004
+ image_grid_thw, **kwargs)
1005
+ if not mm_input_by_modality:
1006
+ return []
1007
+
1008
+ # The result multimodal_embeddings is tuple of tensors, with each
1009
+ # tensor corresponding to a multimodal data item (image or video).
1010
+ multimodal_embeddings: tuple[jax.Array, ...] = ()
1011
+
1012
+ # NOTE: It is important to iterate over the keys in this dictionary
1013
+ # to preserve the order of the modalities.
1014
+ for modality in mm_input_by_modality:
1015
+ multimodal_input = mm_input_by_modality[modality]
1016
+ if modality == "image":
1017
+ vision_embeddings = self._process_image_input(multimodal_input)
1018
+ multimodal_embeddings += vision_embeddings
1019
+ # if modality == "video":
1020
+ # video_embeddings = self._process_video_input(multimodal_input)
1021
+ # multimodal_embeddings += video_embeddings
1022
+
1023
+ return multimodal_embeddings
1024
+
1025
+ def get_input_embeddings(
1026
+ self, input_ids: jax.Array,
1027
+ multimodal_embeddings: Optional[jax.Array]) -> jax.Array:
1028
+
1029
+ inputs_embeds = self.language_model.model.embed(input_ids)
1030
+
1031
+
1032
+ if multimodal_embeddings is not None \
1033
+ and multimodal_embeddings.shape[0] != 0:
1034
+ inputs_embeds = merge_multimodal_embeddings(
1035
+ input_ids, inputs_embeds, multimodal_embeddings,
1036
+ [self.config.image_token_id, self.config.video_token_id])
1037
+
1038
+ return inputs_embeds
1039
+
1040
+ def __call__(
1041
+ self,
1042
+ kv_caches: list[jax.Array],
1043
+ input_ids: Optional[jax.Array],
1044
+ attention_metadata: AttentionMetadata,
1045
+ inputs_embeds: Optional[jax.Array] = None,
1046
+ *args,
1047
+ ) -> tuple[list[jax.Array], jax.Array, List[jax.Array]]:
1048
+ # The logic of choosing between input_ids and inputs_embeds is
1049
+ # handled inside self.language_model.__call__
1050
+ kv_caches, x, [] = self.language_model(
1051
+ kv_caches=kv_caches,
1052
+ input_ids=input_ids,
1053
+ attention_metadata=attention_metadata,
1054
+ inputs_embeds=inputs_embeds,
1055
+ )
1056
+ return kv_caches, x, []
1057
+
1058
+ def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
1059
+ return self.language_model.compute_logits(hidden_states)
1060
+
1061
+ def load_weights(self, rng_key: jax.Array) -> None:
1062
+ self.rng = nnx.Rngs(rng_key)
1063
+ self.language_model.rng = self.rng
1064
+
1065
+ # Key: path to a HF layer weight
1066
+ # Value: a tuple of (path to a nnx layer weight, nnx weight sharding)
1067
+
1068
+ mappings = {
1069
+ "model.embed_tokens": "language_model.model.embed.embedding",
1070
+ "model.layers.*.input_layernorm":
1071
+ "language_model.model.layers.*.input_layernorm.scale",
1072
+ "model.layers.*.mlp.down_proj":
1073
+ "language_model.model.layers.*.mlp.down_proj.kernel",
1074
+ "model.layers.*.mlp.gate_proj":
1075
+ "language_model.model.layers.*.mlp.gate_proj.kernel",
1076
+ "model.layers.*.mlp.up_proj":
1077
+ "language_model.model.layers.*.mlp.up_proj.kernel",
1078
+ "model.layers.*.post_attention_layernorm":
1079
+ "language_model.model.layers.*.post_attention_layernorm.scale",
1080
+ "model.layers.*.self_attn.k_proj":
1081
+ "language_model.model.layers.*.self_attn.k_proj.kernel",
1082
+ "model.layers.*.self_attn.o_proj":
1083
+ "language_model.model.layers.*.self_attn.o_proj.kernel",
1084
+ "model.layers.*.self_attn.q_proj":
1085
+ "language_model.model.layers.*.self_attn.q_proj.kernel",
1086
+ "model.layers.*.self_attn.v_proj":
1087
+ "language_model.model.layers.*.self_attn.v_proj.kernel",
1088
+ "model.layers.*.self_attn.q_proj.bias":
1089
+ "language_model.model.layers.*.self_attn.q_proj.bias",
1090
+ "model.layers.*.self_attn.k_proj.bias":
1091
+ "language_model.model.layers.*.self_attn.k_proj.bias",
1092
+ "model.layers.*.self_attn.v_proj.bias":
1093
+ "language_model.model.layers.*.self_attn.v_proj.bias",
1094
+ "model.norm": "language_model.model.norm.scale",
1095
+ "visual.blocks.*.attn.proj.bias": "visual.blocks.*.attn.proj.bias",
1096
+ "visual.blocks.*.attn.proj": "visual.blocks.*.attn.proj.kernel",
1097
+ "visual.blocks.*.attn.qkv.bias":
1098
+ "visual.blocks.*.attn.qkv_proj.bias",
1099
+ "visual.blocks.*.attn.qkv": "visual.blocks.*.attn.qkv_proj.kernel",
1100
+ "visual.blocks.*.mlp.down_proj.bias":
1101
+ "visual.blocks.*.mlp.down_proj.bias",
1102
+ "visual.blocks.*.mlp.down_proj":
1103
+ "visual.blocks.*.mlp.down_proj.kernel",
1104
+ "visual.blocks.*.mlp.gate_proj.bias":
1105
+ "visual.blocks.*.mlp.gate_proj.bias",
1106
+ "visual.blocks.*.mlp.gate_proj":
1107
+ "visual.blocks.*.mlp.gate_proj.kernel",
1108
+ "visual.blocks.*.mlp.up_proj.bias":
1109
+ "visual.blocks.*.mlp.up_proj.bias",
1110
+ "visual.blocks.*.mlp.up_proj":
1111
+ "visual.blocks.*.mlp.up_proj.kernel",
1112
+ "visual.blocks.*.norm1": "visual.blocks.*.norm1.scale",
1113
+ "visual.blocks.*.norm2": "visual.blocks.*.norm2.scale",
1114
+ "visual.merger.ln_q": "visual.merger.ln_q.scale",
1115
+ "visual.merger.mlp.0.bias": "visual.merger.mlp_fc1.bias",
1116
+ "visual.merger.mlp.0": "visual.merger.mlp_fc1.kernel",
1117
+ "visual.merger.mlp.2.bias": "visual.merger.mlp_fc2.bias",
1118
+ "visual.merger.mlp.2": "visual.merger.mlp_fc2.kernel",
1119
+ "visual.patch_embed.proj": "visual.patch_embed.proj.kernel",
1120
+ }
1121
+
1122
+ # Add lm_head mapping only if it's not tied to embeddings
1123
+ hf_config = self.vllm_config.model_config.hf_config
1124
+ if not hf_config.tie_word_embeddings:
1125
+ mappings.update({
1126
+ "lm_head": "language_model.model.lm_head",
1127
+ })
1128
+
1129
+ metadata_map = get_default_maps(self.vllm_config.model_config,
1130
+ self.mesh, mappings)
1131
+ load_hf_weights(vllm_config=self.vllm_config,
1132
+ model=self,
1133
+ metadata_map=metadata_map,
1134
+ mesh=self.mesh)
1135
+
1136
+ def precompile_vision_encoder(
1137
+ self,
1138
+ run_compilation_fn: Callable,
1139
+ ) -> None:
1140
+ vc = self.vllm_config.model_config.hf_config.vision_config
1141
+ patch_input_dim = vc.in_channels * vc.temporal_patch_size * vc.patch_size * vc.patch_size
1142
+ if self.visual.enable_dynamic_image_sizes:
1143
+ spatial_merge_unit = vc.spatial_merge_size**2
1144
+ max_num_batched_tokens = self.vllm_config.scheduler_config.max_num_batched_tokens
1145
+ mm_kwargs = self.vllm_config.model_config.multimodal_config.mm_processor_kwargs or {}
1146
+ limit_pixels = float(mm_kwargs.get("max_pixels", float('inf')))
1147
+
1148
+ max_patches = int(
1149
+ min(max_num_batched_tokens * spatial_merge_unit,
1150
+ limit_pixels / (vc.patch_size**2)))
1151
+
1152
+ num_patches_paddings = [
1153
+ 1 << i for i in range(4, (max_patches - 1).bit_length() + 1)
1154
+ ]
1155
+ rotary_dim = vc.hidden_size // vc.num_heads // 2
1156
+ vit_merger_window_size = (vc.window_size //
1157
+ vc.spatial_merge_size // vc.patch_size)
1158
+
1159
+ for num_patches in num_patches_paddings:
1160
+ dummy_x_padded = jnp.ones(
1161
+ (num_patches, patch_input_dim),
1162
+ dtype=self.vllm_config.model_config.dtype)
1163
+
1164
+ num_tokens = num_patches // spatial_merge_unit
1165
+ dummy_window_index = jnp.arange(num_tokens, dtype=jnp.int32)
1166
+
1167
+ dummy_rotary_pos_emb = jnp.ones(
1168
+ (num_patches, rotary_dim),
1169
+ dtype=self.vllm_config.model_config.dtype)
1170
+
1171
+ dummy_cu_seqlens = jnp.array([0, num_patches, num_patches],
1172
+ dtype=jnp.int32)
1173
+
1174
+ max_windows = (num_tokens // vit_merger_window_size) + 2
1175
+ patches_per_window = (vit_merger_window_size**
1176
+ 2) * spatial_merge_unit
1177
+ dummy_cu_window_seqlens = jnp.arange(
1178
+ max_windows + 1, dtype=jnp.int32) * patches_per_window
1179
+ dummy_cu_window_seqlens = jnp.minimum(dummy_cu_window_seqlens,
1180
+ num_patches)
1181
+
1182
+ run_compilation_fn("vision_encoder_padded",
1183
+ self.visual.encode_padded_jit,
1184
+ dummy_x_padded,
1185
+ dummy_window_index,
1186
+ dummy_rotary_pos_emb,
1187
+ dummy_cu_seqlens,
1188
+ dummy_cu_window_seqlens,
1189
+ num_patches=num_patches)
1190
+ else:
1191
+ image_shapes = []
1192
+ if (warmup_config := self.vllm_config.additional_config.get(
1193
+ "vision_warmup_config")):
1194
+ image_shapes = warmup_config.get("image_shapes")
1195
+
1196
+ factor = vc.patch_size * vc.spatial_merge_size
1197
+ for input_hw in image_shapes:
1198
+ if not isinstance(input_hw, list) or len(input_hw) != 2:
1199
+ logger.warning(f"Skipping invalid shape {input_hw}.")
1200
+ continue
1201
+ h_input, w_input = input_hw
1202
+ h_processed = round(h_input / factor) * factor
1203
+ w_processed = round(w_input / factor) * factor
1204
+ t, h, w = 1, h_processed // vc.patch_size, w_processed // vc.patch_size
1205
+ grid_thw = (t, h, w)
1206
+ num_patches = t * h * w
1207
+
1208
+ dummy_pixel_values = jnp.ones(
1209
+ (num_patches, patch_input_dim),
1210
+ self.vllm_config.model_config.dtype,
1211
+ )
1212
+ dummy_grid_thw = (grid_thw, )
1213
+
1214
+ run_compilation_fn("vision_encoder",
1215
+ self.visual.encode_jit,
1216
+ dummy_pixel_values,
1217
+ dummy_grid_thw,
1218
+ image_shape=input_hw)
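For reference, the per-image split sizes computed in _process_image_input follow prod(grid_thw) // spatial_merge_size**2, i.e. each image contributes that many merged embeddings. A small numeric sketch (toy grid values and merge size; illustration only, not part of the package):

import numpy as np

spatial_merge_size = 2
grid_thw = ((1, 28, 28), (1, 14, 42))     # (t, h, w) per image, toy values
sizes = (np.prod(np.array(grid_thw, dtype=np.int64), axis=-1)
         // spatial_merge_size // spatial_merge_size)
print(sizes)   # [196 147] -> the concatenated embeddings are split into chunks of these lengths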