tpu-inference 0.12.0.dev20251222__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (260)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +67 -0
  6. tests/core/test_dp_scheduler.py +724 -0
  7. tests/core/test_init.py +63 -0
  8. tests/distributed/__init__.py +13 -0
  9. tests/distributed/test_distributed_utils.py +120 -0
  10. tests/distributed/test_tpu_connector.py +478 -0
  11. tests/e2e/__init__.py +13 -0
  12. tests/e2e/test_async_scheduler.py +211 -0
  13. tests/e2e/test_data_parallel.py +393 -0
  14. tests/e2e/test_local_disagg.py +257 -0
  15. tests/e2e/test_model_loader.py +268 -0
  16. tests/e2e/test_multi_modal_inference.py +111 -0
  17. tests/e2e/test_pipeline_parallel.py +265 -0
  18. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  19. tests/e2e/test_sampling_params.py +269 -0
  20. tests/e2e/test_speculative_decoding.py +291 -0
  21. tests/e2e/test_structured_decoding.py +46 -0
  22. tests/executors/__init__.py +13 -0
  23. tests/executors/test_ray_distributed_executor.py +199 -0
  24. tests/experimental/__init__.py +13 -0
  25. tests/experimental/test_llama3_jax_stashed.py +208 -0
  26. tests/kernels/__init__.py +13 -0
  27. tests/kernels/collectives/__init__.py +13 -0
  28. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  29. tests/kernels/fused_moe_v1_test.py +388 -0
  30. tests/kernels/gmm_test.py +205 -0
  31. tests/kernels/mla_v1_test.py +498 -0
  32. tests/kernels/quantized_matmul_kernel_test.py +159 -0
  33. tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
  34. tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
  35. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
  36. tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
  37. tests/layers/__init__.py +13 -0
  38. tests/layers/common/__init__.py +13 -0
  39. tests/layers/common/test_attention_interface.py +156 -0
  40. tests/layers/common/test_quantization.py +149 -0
  41. tests/layers/jax/__init__.py +13 -0
  42. tests/layers/jax/attention/__init__.py +13 -0
  43. tests/layers/jax/attention/test_common_attention.py +103 -0
  44. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  45. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  46. tests/layers/jax/moe/__init__.py +13 -0
  47. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  48. tests/layers/jax/sample/__init__.py +13 -0
  49. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  50. tests/layers/jax/sample/test_sampling.py +115 -0
  51. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  52. tests/layers/jax/test_layers.py +155 -0
  53. tests/layers/jax/test_qwix.py +969 -0
  54. tests/layers/jax/test_rope.py +93 -0
  55. tests/layers/jax/test_sharding.py +159 -0
  56. tests/layers/jax/test_transformer_block.py +152 -0
  57. tests/layers/vllm/__init__.py +13 -0
  58. tests/layers/vllm/test_attention.py +363 -0
  59. tests/layers/vllm/test_awq.py +405 -0
  60. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
  62. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
  63. tests/layers/vllm/test_fp8.py +17 -0
  64. tests/layers/vllm/test_mxfp4.py +297 -0
  65. tests/layers/vllm/test_unquantized.py +621 -0
  66. tests/layers/vllm/utils.py +72 -0
  67. tests/lora/__init__.py +13 -0
  68. tests/lora/conftest.py +46 -0
  69. tests/lora/test_bgmv.py +57 -0
  70. tests/lora/test_layers.py +666 -0
  71. tests/lora/test_lora.py +147 -0
  72. tests/lora/test_lora_perf.py +67 -0
  73. tests/lora/utils.py +88 -0
  74. tests/models/__init__.py +13 -0
  75. tests/models/common/__init__.py +13 -0
  76. tests/models/common/test_model_loader.py +455 -0
  77. tests/models/jax/__init__.py +13 -0
  78. tests/models/jax/test_deepseek_v3.py +401 -0
  79. tests/models/jax/test_llama3.py +184 -0
  80. tests/models/jax/test_llama4.py +298 -0
  81. tests/models/jax/test_llama_eagle3.py +197 -0
  82. tests/models/jax/test_llama_guard_4.py +242 -0
  83. tests/models/jax/test_qwen2.py +172 -0
  84. tests/models/jax/test_qwen2_5_vl.py +606 -0
  85. tests/models/jax/test_qwen3.py +169 -0
  86. tests/models/jax/test_weight_loading.py +180 -0
  87. tests/models/jax/utils/__init__.py +13 -0
  88. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  89. tests/platforms/__init__.py +13 -0
  90. tests/platforms/test_tpu_platform.py +54 -0
  91. tests/runner/__init__.py +13 -0
  92. tests/runner/test_block_table.py +395 -0
  93. tests/runner/test_input_batch.py +226 -0
  94. tests/runner/test_kv_cache.py +220 -0
  95. tests/runner/test_kv_cache_manager.py +498 -0
  96. tests/runner/test_multimodal_manager.py +429 -0
  97. tests/runner/test_persistent_batch_manager.py +84 -0
  98. tests/runner/test_speculative_decoding_manager.py +368 -0
  99. tests/runner/test_structured_decoding_manager.py +220 -0
  100. tests/runner/test_tpu_runner.py +202 -0
  101. tests/runner/test_tpu_runner_dp.py +1033 -0
  102. tests/runner/test_tpu_runner_mesh.py +200 -0
  103. tests/runner/test_utils.py +411 -0
  104. tests/spec_decode/__init__.py +13 -0
  105. tests/spec_decode/test_eagle3.py +311 -0
  106. tests/test_base.py +215 -0
  107. tests/test_envs.py +280 -0
  108. tests/test_tpu_info.py +134 -0
  109. tests/test_utils.py +193 -0
  110. tests/worker/__init__.py +13 -0
  111. tests/worker/tpu_worker_test.py +414 -0
  112. tpu_inference/__init__.py +67 -0
  113. tpu_inference/core/__init__.py +13 -0
  114. tpu_inference/core/core_tpu.py +786 -0
  115. tpu_inference/core/disagg_executor.py +118 -0
  116. tpu_inference/core/disagg_utils.py +49 -0
  117. tpu_inference/core/sched/__init__.py +13 -0
  118. tpu_inference/core/sched/dp_scheduler.py +814 -0
  119. tpu_inference/distributed/__init__.py +13 -0
  120. tpu_inference/distributed/jax_parallel_state.py +81 -0
  121. tpu_inference/distributed/tpu_connector.py +732 -0
  122. tpu_inference/distributed/utils.py +112 -0
  123. tpu_inference/env_override.py +9 -0
  124. tpu_inference/envs.py +191 -0
  125. tpu_inference/executors/__init__.py +13 -0
  126. tpu_inference/executors/ray_distributed_executor.py +399 -0
  127. tpu_inference/experimental/__init__.py +13 -0
  128. tpu_inference/experimental/llama3_jax_stashed.py +272 -0
  129. tpu_inference/kernels/__init__.py +13 -0
  130. tpu_inference/kernels/collectives/__init__.py +13 -0
  131. tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
  132. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
  133. tpu_inference/kernels/collectives/util.py +47 -0
  134. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  135. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  136. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  137. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  138. tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
  139. tpu_inference/kernels/megablox/__init__.py +13 -0
  140. tpu_inference/kernels/megablox/common.py +54 -0
  141. tpu_inference/kernels/megablox/gmm.py +646 -0
  142. tpu_inference/kernels/mla/__init__.py +13 -0
  143. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  144. tpu_inference/kernels/mla/v1/kernel.py +1340 -0
  145. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  146. tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
  147. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  148. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  149. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  150. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  151. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
  152. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
  153. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  154. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  155. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
  156. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
  157. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
  158. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
  159. tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
  160. tpu_inference/layers/__init__.py +13 -0
  161. tpu_inference/layers/common/__init__.py +13 -0
  162. tpu_inference/layers/common/attention_interface.py +403 -0
  163. tpu_inference/layers/common/attention_metadata.py +48 -0
  164. tpu_inference/layers/common/binary_search.py +295 -0
  165. tpu_inference/layers/common/quant_methods.py +23 -0
  166. tpu_inference/layers/common/quantization.py +270 -0
  167. tpu_inference/layers/common/sharding.py +600 -0
  168. tpu_inference/layers/jax/__init__.py +13 -0
  169. tpu_inference/layers/jax/attention/__init__.py +13 -0
  170. tpu_inference/layers/jax/attention/attention.py +268 -0
  171. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
  172. tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
  173. tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
  174. tpu_inference/layers/jax/base.py +165 -0
  175. tpu_inference/layers/jax/constants.py +101 -0
  176. tpu_inference/layers/jax/layers.py +315 -0
  177. tpu_inference/layers/jax/misc.py +30 -0
  178. tpu_inference/layers/jax/moe/__init__.py +13 -0
  179. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
  180. tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
  181. tpu_inference/layers/jax/moe/moe.py +249 -0
  182. tpu_inference/layers/jax/pp_utils.py +53 -0
  183. tpu_inference/layers/jax/rope.py +294 -0
  184. tpu_inference/layers/jax/rope_interface.py +228 -0
  185. tpu_inference/layers/jax/sample/__init__.py +13 -0
  186. tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
  187. tpu_inference/layers/jax/sample/sampling.py +110 -0
  188. tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
  189. tpu_inference/layers/jax/transformer_block.py +121 -0
  190. tpu_inference/layers/vllm/__init__.py +13 -0
  191. tpu_inference/layers/vllm/attention.py +221 -0
  192. tpu_inference/layers/vllm/fused_moe.py +502 -0
  193. tpu_inference/layers/vllm/linear_common.py +221 -0
  194. tpu_inference/layers/vllm/quantization/__init__.py +55 -0
  195. tpu_inference/layers/vllm/quantization/awq.py +221 -0
  196. tpu_inference/layers/vllm/quantization/common.py +124 -0
  197. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  198. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
  199. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
  200. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  201. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
  202. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
  203. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  204. tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
  205. tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
  206. tpu_inference/layers/vllm/sharding.py +244 -0
  207. tpu_inference/logger.py +10 -0
  208. tpu_inference/lora/__init__.py +13 -0
  209. tpu_inference/lora/torch_lora_ops.py +98 -0
  210. tpu_inference/lora/torch_punica_tpu.py +310 -0
  211. tpu_inference/models/__init__.py +13 -0
  212. tpu_inference/models/common/__init__.py +13 -0
  213. tpu_inference/models/common/model_loader.py +520 -0
  214. tpu_inference/models/jax/__init__.py +13 -0
  215. tpu_inference/models/jax/deepseek_v3.py +978 -0
  216. tpu_inference/models/jax/gpt_oss.py +508 -0
  217. tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
  218. tpu_inference/models/jax/llama3.py +436 -0
  219. tpu_inference/models/jax/llama4.py +643 -0
  220. tpu_inference/models/jax/llama_eagle3.py +350 -0
  221. tpu_inference/models/jax/llama_guard_4.py +375 -0
  222. tpu_inference/models/jax/qwen2.py +390 -0
  223. tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
  224. tpu_inference/models/jax/qwen3.py +318 -0
  225. tpu_inference/models/jax/utils/__init__.py +13 -0
  226. tpu_inference/models/jax/utils/file_utils.py +110 -0
  227. tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
  228. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  229. tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
  230. tpu_inference/models/jax/utils/weight_utils.py +621 -0
  231. tpu_inference/models/vllm/__init__.py +13 -0
  232. tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
  233. tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
  234. tpu_inference/platforms/__init__.py +16 -0
  235. tpu_inference/platforms/tpu_platform.py +258 -0
  236. tpu_inference/runner/__init__.py +13 -0
  237. tpu_inference/runner/block_table.py +122 -0
  238. tpu_inference/runner/compilation_manager.py +890 -0
  239. tpu_inference/runner/input_batch.py +435 -0
  240. tpu_inference/runner/kv_cache.py +166 -0
  241. tpu_inference/runner/kv_cache_manager.py +508 -0
  242. tpu_inference/runner/lora_utils.py +106 -0
  243. tpu_inference/runner/multimodal_manager.py +231 -0
  244. tpu_inference/runner/persistent_batch_manager.py +296 -0
  245. tpu_inference/runner/speculative_decoding_manager.py +262 -0
  246. tpu_inference/runner/structured_decoding_manager.py +101 -0
  247. tpu_inference/runner/tpu_runner.py +1768 -0
  248. tpu_inference/runner/utils.py +426 -0
  249. tpu_inference/spec_decode/__init__.py +13 -0
  250. tpu_inference/spec_decode/jax/__init__.py +13 -0
  251. tpu_inference/spec_decode/jax/eagle3.py +430 -0
  252. tpu_inference/tpu_info.py +92 -0
  253. tpu_inference/utils.py +345 -0
  254. tpu_inference/worker/__init__.py +13 -0
  255. tpu_inference/worker/tpu_worker.py +468 -0
  256. tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
  257. tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
  258. tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
  259. tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
  260. tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
tpu_inference/models/jax/qwen2_5_vl.py
@@ -0,0 +1,1232 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from functools import partial
from typing import (Callable, List, Literal, NamedTuple, Optional, TypedDict,
                    Union)

import jax
import jax.numpy as jnp
import numpy as np
from flax import nnx
from jax.sharding import Mesh
from transformers import modeling_flax_utils
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
    Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig)
from vllm.config import VllmConfig

from tpu_inference import utils as utils
from tpu_inference.layers.common.attention_interface import \
    sharded_flash_attention
from tpu_inference.layers.common.attention_metadata import AttentionMetadata
from tpu_inference.logger import init_logger
from tpu_inference.models.jax.qwen2 import Qwen2ForCausalLM
# from vllm.model_executor.models.interfaces import MultiModalEmbeddings
from tpu_inference.models.jax.utils.multi_modal_utils import (
    MultiModalEmbeddings, merge_multimodal_embeddings)
from tpu_inference.models.jax.utils.weight_utils import (get_default_maps,
                                                         load_hf_weights)

logger = init_logger(__name__)

init_fn = nnx.initializers.uniform()

DEFAULT_BLOCK_K_MAJOR = 128

class SegmentIds(NamedTuple):
    """SegmentIds for Q and KV sequences.

    SegmentIds are used to generate the segment mask, which prevents attention
    between different segments in the input sequence. Each array is a list of
    ids (integers).
    Only tokens with the same id can attend to each other.

    Attributes:
      q: segment ids along the Q sequence.
      kv: segment ids along the KV sequence.
    """

    q: jax.Array  # [batch_size, q_seq_len]
    kv: jax.Array  # [batch_size, kv_seq_len]

class Qwen2_5_VLImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    pixel_values: jax.Array
    """Shape:
    `(num_patches, num_channels * patch_size * patch_size)`
    """

    image_grid_thw: tuple[tuple[int, int, int], ...]
    """Shape: `(num_images, 3)`
    This should be in `(grid_t, grid_h, grid_w)` format.
    """


# NOTE: We are not supporting embedding inputs for now.
# The code here keeps the structure consistent and
# makes it easier for a future implementation.
class Qwen2_5_VLImageEmbeddingInputs(TypedDict):
    type: Literal["image_embeds"]
    image_embeds: jax.Array
    """Supported types:
    - list[`jax.Array`]: A list of tensors holding all images' features.
        Each tensor holds an image's features.
    - `jax.Array`: A tensor holding all images' features (concatenation of
        all images' feature tensors).

    Tensor shape: `(num_image_features, hidden_size)`
    - `num_image_features` varies based on
        the number and resolution of the images.
    - `hidden_size` must match the hidden size of the language model backbone.
    """

    image_grid_thw: jax.Array
    """Shape: `(num_images, 3)`
    This should be in `(grid_t, grid_h, grid_w)` format.
    """


Qwen2_5_VLImageInputs = Union[Qwen2_5_VLImagePixelInputs,
                              Qwen2_5_VLImageEmbeddingInputs]

class Qwen2_5_VisionMLP(nnx.Module):

    def __init__(self, config: Qwen2_5_VLVisionConfig, dtype: jnp.dtype,
                 rngs: nnx.Rngs):
        in_features = config.hidden_size
        hidden_features = config.intermediate_size
        act_fn = modeling_flax_utils.ACT2FN[config.hidden_act]
        self.gate_proj = nnx.Linear(
            in_features,
            hidden_features,
            use_bias=True,
            param_dtype=dtype,
            kernel_init=nnx.with_partitioning(init_fn, (None, "model")),
            bias_init=nnx.with_partitioning(init_fn, ("model", )),
            rngs=rngs,
        )
        self.up_proj = nnx.Linear(
            in_features,
            hidden_features,
            use_bias=True,
            param_dtype=dtype,
            kernel_init=nnx.with_partitioning(init_fn, (None, "model")),
            bias_init=nnx.with_partitioning(init_fn, ("model", )),
            rngs=rngs,
        )
        self.down_proj = nnx.Linear(
            hidden_features,
            in_features,
            use_bias=True,
            param_dtype=dtype,
            kernel_init=nnx.with_partitioning(init_fn, ("model", None)),
            bias_init=nnx.with_partitioning(init_fn, (None, )),
            rngs=rngs,
        )
        self.act_fn = act_fn

    def __call__(self, x: jax.Array) -> jax.Array:
        gate = self.act_fn(self.gate_proj(x))
        up = self.up_proj(x)
        fuse = gate * up
        result = self.down_proj(fuse)
        return result

def apply_rotary_pos_emb_vision(x: jax.Array,
                                rotary_pos_emb: jax.Array) -> jax.Array:
    # x: [B, T, N, H]
    # rotary_pos_emb: [T, H//2]
    _, _, _, H = x.shape
    half_dim = H // 2

    # [B, T, N, H//2]
    x_real = x[..., :half_dim]
    x_imag = x[..., half_dim:]

    # [T, H//2]
    cos_emb = jnp.cos(rotary_pos_emb)
    sin_emb = jnp.sin(rotary_pos_emb)

    # [1, T, 1, H//2]
    cos_emb = cos_emb[None, :, None, :]
    sin_emb = sin_emb[None, :, None, :]

    # [B, T, N, H//2]
    x_rotated_real = x_real * cos_emb - x_imag * sin_emb
    x_rotated_imag = x_real * sin_emb + x_imag * cos_emb

    # [B, T, N, H]
    x_rotated = jnp.concatenate([x_rotated_real, x_rotated_imag], axis=-1)

    return x_rotated
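
# Illustrative note (not part of the original file): the first H//2 channels of
# each head are treated as the real component and the last H//2 as the
# imaginary component, and every pair is rotated by the per-position angle in
# `rotary_pos_emb`. A minimal sanity check, assuming toy shapes:
#
#     x = jnp.ones((1, 4, 2, 8))        # [B=1, T=4, N=2, H=8]
#     angles = jnp.zeros((4, 4))        # [T, H//2]; zero angles -> identity
#     out = apply_rotary_pos_emb_vision(x, angles)
#     assert out.shape == (1, 4, 2, 8)  # and out equals x for zero angles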


def generate_window_segment_ids(cu_seqlens: jax.Array, seq_len: int,
                                padded_seq_len: int) -> SegmentIds:
    """Generates segment IDs for windowed attention.

    Args:
        cu_seqlens: A 1D array of cumulative sequence lengths for each window,
            e.g., [0, len_win0, len_win0+len_win1, ...]
        seq_len: Number of valid (unpadded) tokens in the sequence.
        padded_seq_len: Total sequence length after padding; padded positions
            are assigned segment id 0.

    Returns:
        A SegmentIds object for flash_attention.
    """
    indices = jnp.arange(seq_len, dtype=jnp.int32)
    segment_ids = jnp.searchsorted(cu_seqlens[1:], indices, side='right') + 1
    padding_segment_ids = jnp.zeros(padded_seq_len - seq_len, dtype=jnp.int32)
    segment_ids = jnp.concatenate([segment_ids, padding_segment_ids])
    segment_ids = segment_ids.reshape(1, -1)

    return SegmentIds(q=segment_ids, kv=segment_ids)
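
# For example (illustrative, not in the original file): with two windows of
# sizes 4 and 2 packed into a sequence of 6 valid tokens padded to 8,
#
#     cu = jnp.array([0, 4, 6], dtype=jnp.int32)
#     generate_window_segment_ids(cu, seq_len=6, padded_seq_len=8)
#
# yields q == kv == [[1, 1, 1, 1, 2, 2, 0, 0]], so attention stays confined to
# each window and the padded positions (id 0) never attend to real tokens.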


class Qwen2_5_VisionAttention(nnx.Module):

    def __init__(self, config: Qwen2_5_VLConfig, dtype: jnp.dtype,
                 rngs: nnx.Rngs, mesh: Mesh):
        vision_config = config.vision_config
        self.hidden_size = vision_config.hidden_size
        self.num_heads = vision_config.num_heads
        self.num_kv_heads = self.num_heads
        self.rope_theta = config.rope_theta
        self.rope_scaling = getattr(config, "rope_scaling", None)
        self.head_dim_original = self.hidden_size // self.num_heads

        sharding_size = mesh.shape["model"]
        self.num_heads = utils.get_padded_num_heads(self.num_heads,
                                                    sharding_size)
        self.num_kv_heads = utils.get_padded_num_heads(self.num_kv_heads,
                                                       sharding_size)
        self.head_dim = utils.get_padded_head_dim(self.head_dim_original)

        # TODO: Wenlong: Do not consider padding for now
        self.head_dim = self.head_dim_original

        self.mesh = mesh

        self.qkv_proj = nnx.Linear(
            self.hidden_size,
            3 * self.hidden_size,
            use_bias=True,
            param_dtype=dtype,
            kernel_init=nnx.with_partitioning(init_fn, (None, "model")),
            bias_init=nnx.with_partitioning(init_fn, ("model", )),
            rngs=rngs,
        )

        self.proj = nnx.Linear(
            self.hidden_size,
            self.hidden_size,
            use_bias=True,
            param_dtype=dtype,
            kernel_init=nnx.with_partitioning(init_fn, ("model", None)),
            bias_init=nnx.with_partitioning(init_fn, (None, )),
            rngs=rngs,
        )
        self.flash_attention = sharded_flash_attention(
            mesh=mesh,
            causal=False,
            sm_scale=1.0 / math.sqrt(self.head_dim),
            vmem_limit_bytes=128 * 1024 * 1024,
        )

    def __call__(
        self,
        x: jax.Array,
        rotary_pos_emb: jax.Array,
        cu_window_seqlens: Optional[jax.Array] = None,
        use_fullattn: bool = True,
    ) -> jax.Array:
        T, B, D = x.shape
        assert B == 1, "Vision attention currently only supports batch size 1"
        # [T, B, D] -> [T, B, 3 * D]
        qkv = self.qkv_proj(x)

        # Split into Q, K, V.
        # NOTE: simplified from vLLM's split_qkv,
        # may need to revisit for tp>1
        # [T, B, 3 * D] -> 3 * [T, B, D]
        q, k, v = jnp.split(qkv, 3, axis=-1)

        # [T, B, N, H]
        q = q.reshape(T, B, self.num_heads, self.head_dim)
        k = k.reshape(T, B, self.num_heads, self.head_dim)
        v = v.reshape(T, B, self.num_heads, self.head_dim)

        # [T, B, N, H] -> [B, T, N, H]
        q = jnp.transpose(q, (1, 0, 2, 3))
        k = jnp.transpose(k, (1, 0, 2, 3))
        v = jnp.transpose(v, (1, 0, 2, 3))

        # rotary_pos_emb shape: (T, H//2)
        q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
        k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)

        # NOTE: an extra transpose because we need to
        # align the correctness with vLLM's design.
        # Might be able to remove one once implemented.
        # [B, T, N, H] -> [B, N, T, H]
        q = jnp.transpose(q, (0, 2, 1, 3))
        k = jnp.transpose(k, (0, 2, 1, 3))
        v = jnp.transpose(v, (0, 2, 1, 3))

        # Pad the sequence length to be a multiple of 128 for flash_attention
        block_k_major = DEFAULT_BLOCK_K_MAJOR
        T_attn = q.shape[2]
        padded_T = (T_attn + block_k_major - 1) // block_k_major * block_k_major
        pad_width = ((0, 0), (0, 0), (0, padded_T - T_attn), (0, 0))

        q = jnp.pad(q, pad_width, 'constant')
        k = jnp.pad(k, pad_width, 'constant')
        v = jnp.pad(v, pad_width, 'constant')

        segment_ids = generate_window_segment_ids(cu_window_seqlens, T_attn,
                                                  padded_T)

        # TODO (jacobplatin): add support for quantized KV cache?
        output = self.flash_attention(q, k, v, segment_ids)

        # Unpad the output
        output = output[:, :, :T_attn, :]

        # [B, N, T, H] -> [T, B, N, H]
        output = jnp.transpose(output, (2, 0, 1, 3))

        output = output.reshape(T, B, D)

        output = self.proj(output)

        return output


class Qwen2_5_VisionBlock(nnx.Module):

    def __init__(self, config: Qwen2_5_VLConfig, dtype: jnp.dtype,
                 rngs: nnx.Rngs, mesh: Mesh):
        vision_config = config.vision_config
        dim = vision_config.hidden_size
        norm_layer = partial(nnx.RMSNorm,
                             epsilon=config.rms_norm_eps,
                             scale_init=nnx.with_partitioning(
                                 init_fn, (None, )))

        self.norm1 = norm_layer(dim, dtype=dtype, rngs=rngs)
        self.norm2 = norm_layer(dim, dtype=dtype, rngs=rngs)
        self.attn = Qwen2_5_VisionAttention(config=config,
                                            dtype=dtype,
                                            rngs=rngs,
                                            mesh=mesh)
        self.mlp = Qwen2_5_VisionMLP(config=vision_config,
                                     dtype=dtype,
                                     rngs=rngs)

    def __call__(self,
                 x: jax.Array,
                 rotary_pos_emb: jax.Array,
                 cu_window_seqlens: Optional[jax.Array] = None,
                 use_fullattn: bool = True) -> jax.Array:

        x = x + self.attn(self.norm1(x), rotary_pos_emb, cu_window_seqlens,
                          use_fullattn)
        x = x + self.mlp(self.norm2(x))

        return x


class Qwen2_5_VisionPatchEmbed(nnx.Module):

    def __init__(
        self,
        rngs: nnx.Rngs,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        in_channels: int = 3,
        hidden_size: int = 1152,
        dtype: jnp.dtype = jnp.bfloat16,
    ) -> None:
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.hidden_size = hidden_size
        kernel_size = (temporal_patch_size, patch_size, patch_size)
        self.proj = nnx.Conv(in_features=in_channels,
                             out_features=hidden_size,
                             kernel_size=kernel_size,
                             strides=kernel_size,
                             use_bias=False,
                             param_dtype=dtype,
                             kernel_init=nnx.with_partitioning(
                                 init_fn, (None, None, None, None, "model")),
                             rngs=rngs)

    def __call__(self, x: jax.Array) -> jax.Array:
        # x is (L, C * T * H * W)
        L, dim = x.shape
        C = dim // (self.temporal_patch_size * self.patch_size *
                    self.patch_size)
        # Reshape to (L, T, H, W, C) for Conv3D with channels_last
        x = x.reshape(L, C, self.temporal_patch_size, self.patch_size,
                      self.patch_size)
        # L,T,H,W,C
        x = jnp.transpose(x, (0, 2, 3, 4, 1))
        x = self.proj(x)
        # After conv, shape is (L, T_out, H_out, W_out, C_out)
        # With stride=kernel_size, T_out=H_out=W_out=1.
        # So shape is (L, 1, 1, 1, hidden_size)
        x = x.reshape(L, self.hidden_size)
        return x


class Qwen2_5_VisionPatchMerger(nnx.Module):

    def __init__(self, d_model: int, context_dim: int, norm_layer: Callable,
                 spatial_merge_size: int, dtype: jnp.dtype, rngs: nnx.Rngs):
        self.hidden_size = context_dim * (spatial_merge_size**2)
        self.ln_q = norm_layer(context_dim,
                               dtype=dtype,
                               rngs=rngs,
                               scale_init=nnx.with_partitioning(
                                   init_fn, (None, )))
        self.mlp_fc1 = nnx.Linear(
            self.hidden_size,
            self.hidden_size,
            use_bias=True,
            param_dtype=dtype,
            kernel_init=nnx.with_partitioning(init_fn, (None, "model")),
            bias_init=nnx.with_partitioning(init_fn, ("model", )),
            rngs=rngs)
        self.mlp_act = modeling_flax_utils.ACT2FN["gelu"]
        self.mlp_fc2 = nnx.Linear(
            self.hidden_size,
            d_model,
            use_bias=True,
            param_dtype=dtype,
            kernel_init=nnx.with_partitioning(init_fn, ("model", None)),
            bias_init=nnx.with_partitioning(init_fn, (None, )),
            rngs=rngs)

    def __call__(self, x: jax.Array) -> jax.Array:
        x = self.ln_q(x)
        x = x.reshape(-1, self.hidden_size)
        x = self.mlp_fc1(x)
        x = self.mlp_act(x)
        x = self.mlp_fc2(x)
        return x


class Qwen2_5_VisionRotaryEmbedding(nnx.Module):

    def __init__(self, dim: int, theta: float = 10000.0):
        self.dim = dim
        self.theta = theta

    def __call__(self, seqlen: int) -> jax.Array:
        inv_freq = 1.0 / (self.theta**(
            jnp.arange(0, self.dim, 2, dtype=jnp.float32) / self.dim))
        seq = jnp.arange(seqlen, dtype=jnp.float32)
        freqs = jnp.outer(seq, inv_freq)
        return freqs.astype(jnp.bfloat16)


class Qwen2_5_VisionTransformer(nnx.Module):

    def __init__(self,
                 vllm_config: VllmConfig,
                 rngs: nnx.Rngs,
                 mesh: Mesh,
                 norm_eps: float = 1e-6):
        model_config = vllm_config.model_config
        hf_config = model_config.hf_config
        vision_config = hf_config.vision_config
        dtype = model_config.dtype

        self.config = vision_config
        self.dtype = dtype

        patch_size = vision_config.patch_size
        temporal_patch_size = vision_config.temporal_patch_size
        in_channels = vision_config.in_channels
        self.hidden_size = vision_config.hidden_size
        self.num_heads = vision_config.num_heads

        # args for get_window_index_thw
        self.window_size = vision_config.window_size
        self.patch_size = vision_config.patch_size
        self.spatial_merge_size = vision_config.spatial_merge_size
        self.fullatt_block_indexes = vision_config.fullatt_block_indexes
        self.spatial_merge_unit = self.spatial_merge_size**2

        self.patch_embed = Qwen2_5_VisionPatchEmbed(
            patch_size=patch_size,
            temporal_patch_size=temporal_patch_size,
            in_channels=in_channels,
            hidden_size=self.hidden_size,
            dtype=dtype,
            rngs=rngs)

        head_dim = vision_config.hidden_size // vision_config.num_heads
        self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)

        self.blocks = [
            Qwen2_5_VisionBlock(
                config=hf_config,
                dtype=dtype,
                rngs=rngs,
                mesh=mesh,
            ) for _ in range(vision_config.depth)
        ]
        self.merger = Qwen2_5_VisionPatchMerger(
            d_model=vision_config.out_hidden_size,
            context_dim=vision_config.hidden_size,
            norm_layer=partial(nnx.RMSNorm, epsilon=norm_eps),
            spatial_merge_size=vision_config.spatial_merge_size,
            dtype=dtype,
            rngs=rngs)

        additional_config = getattr(vllm_config, "additional_config",
                                    None) or {}
        self.enable_dynamic_image_sizes = additional_config.get(
            "enable_dynamic_image_sizes", False)

    def rotary_pos_emb_thw(self, t, h, w):
        hpos_ids, wpos_ids = jnp.indices((h, w))
        hpos_ids = hpos_ids.reshape(
            h // self.spatial_merge_size,
            self.spatial_merge_size,
            w // self.spatial_merge_size,
            self.spatial_merge_size,
        ).transpose(0, 2, 1, 3).flatten()
        wpos_ids = wpos_ids.reshape(
            h // self.spatial_merge_size,
            self.spatial_merge_size,
            w // self.spatial_merge_size,
            self.spatial_merge_size,
        ).transpose(0, 2, 1, 3).flatten()
        pos_ids = jnp.stack([hpos_ids, wpos_ids], axis=-1)
        pos_ids = jnp.tile(pos_ids, (t, 1))

        max_size = max(h, w)
        rotary_pos_emb_full = self.rotary_pos_emb(max_size)

        rotary_pos_emb = rotary_pos_emb_full[pos_ids].reshape(
            pos_ids.shape[0], -1)
        rotary_pos_emb = rotary_pos_emb.reshape(
            rotary_pos_emb.shape[0] // self.spatial_merge_unit,
            self.spatial_merge_unit, -1)

        return rotary_pos_emb

    def get_window_index_thw(self, grid_t, grid_h, grid_w):
        vit_merger_window_size = (self.window_size //
                                  self.spatial_merge_size // self.patch_size)

        llm_grid_h = grid_h // self.spatial_merge_size
        llm_grid_w = grid_w // self.spatial_merge_size

        index = jnp.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
            grid_t, llm_grid_h, llm_grid_w)

        pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
        pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
        num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
        num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size

        index_padded = jnp.pad(index, ((0, 0), (0, pad_h), (0, pad_w)),
                               constant_values=-100)
        index_padded = index_padded.reshape(grid_t, num_windows_h,
                                            vit_merger_window_size,
                                            num_windows_w,
                                            vit_merger_window_size)
        index_padded = jnp.transpose(index_padded, (0, 1, 3, 2, 4)).reshape(
            grid_t, num_windows_h * num_windows_w, vit_merger_window_size,
            vit_merger_window_size)
        seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
        index_padded = index_padded.reshape(-1)
        # The number of valid indices is static because grid_t, grid_h, grid_w
        # are static.
        num_valid_indices = grid_t * llm_grid_h * llm_grid_w
        valid_indices = jnp.nonzero(index_padded != -100,
                                    size=num_valid_indices)[0]
        index_new = index_padded[valid_indices]
        cu_seqlens_tmp = jnp.cumsum(seqlens) * self.spatial_merge_unit
        cu_seqlens_tmp = cu_seqlens_tmp.astype(jnp.int32)

        # NOTE (wenlong): The PyTorch code uses this to reduce replication,
        # but I don't think it is needed here, and it would cause problems
        # under JIT. Revisit this if there is a problem downstream.
        # cu_seqlens_tmp = jnp.unique(cu_seqlens_tmp)

        return index_new, cu_seqlens_tmp

    def get_rope_by_thw(self, t, h, w):
        window_index_thw, cu_seqlens_window_thw = self.get_window_index_thw(
            t, h, w)

        rotary_pos_emb_thw = self.rotary_pos_emb_thw(t, h, w)

        rotary_pos_emb_thw = rotary_pos_emb_thw[window_index_thw, :, :]
        rotary_pos_emb_thw = rotary_pos_emb_thw.reshape(
            -1, rotary_pos_emb_thw.shape[-1])
        cu_seqlens_thw = jnp.full(t, h * w, dtype=jnp.int32)

        return (rotary_pos_emb_thw, window_index_thw, cu_seqlens_window_thw,
                cu_seqlens_thw)

    def compute_attn_mask_seqlen(
        self,
        cu_seqlens: jax.Array,
    ) -> tuple[Optional[int], Optional[list[int]]]:
        max_seqlen, seqlens = None, None
        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
        seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
        return max_seqlen, seqlens

    def compute_aux_arrays(self, grid_thw: tuple[tuple[int, int, int]]):
        # num of images/videos
        num_grids = len(grid_thw)

        rotary_pos_emb = []
        window_index: list = []
        cu_window_seqlens: list = [jnp.array([0], dtype=jnp.int32)]
        cu_seqlens: list = []

        window_index_id = 0
        cu_window_seqlens_last = 0
        for i in range(num_grids):
            t, h, w = grid_thw[i]

            llm_h = h // self.spatial_merge_size
            llm_w = w // self.spatial_merge_size

            (
                rotary_pos_emb_thw,
                window_index_thw,
                cu_seqlens_window_thw,
                cu_seqlens_thw,
            ) = self.get_rope_by_thw(t, h, w)

            window_index.append(window_index_thw + window_index_id)
            window_index_id += (t * llm_h * llm_w)

            cu_seqlens_window_thw = (cu_seqlens_window_thw +
                                     cu_window_seqlens_last)
            cu_window_seqlens_last = cu_seqlens_window_thw[-1]
            cu_window_seqlens.append(cu_seqlens_window_thw)

            rotary_pos_emb.append(rotary_pos_emb_thw)

            cu_seqlens.append(cu_seqlens_thw)

        rotary_pos_emb = jnp.concatenate(rotary_pos_emb, axis=0)
        window_index = jnp.concatenate(window_index, axis=0)
        cu_window_seqlens = jnp.concatenate(cu_window_seqlens, axis=0)

        cu_seqlens = jnp.concatenate(cu_seqlens, axis=0)
        cu_seqlens = jnp.cumsum(cu_seqlens, axis=0, dtype=jnp.int32)
        cu_seqlens = jnp.pad(cu_seqlens, ((1, 0), ),
                             mode='constant',
                             constant_values=0)
        return window_index, rotary_pos_emb, cu_seqlens, cu_window_seqlens

    def pad_inputs(self, x, window_index, rotary_pos_emb, cu_seqlens,
                   cu_window_seqlens):
        # padding
        num_patches = int(rotary_pos_emb.shape[0])
        bucket_num_patches = 1 << (num_patches - 1).bit_length()
        num_tokens = window_index.shape[0]
        bucket_num_tokens = bucket_num_patches // self.spatial_merge_unit
        vit_merger_window_size = (self.window_size //
                                  self.spatial_merge_size // self.patch_size)
        max_windows = (bucket_num_tokens // vit_merger_window_size) + 2

        rotary_pos_emb = jnp.pad(rotary_pos_emb,
                                 ((0, bucket_num_patches - num_patches),
                                  (0, 0)))
        window_index = jnp.concatenate([
            window_index,
            jnp.arange(num_tokens, bucket_num_tokens, dtype=jnp.int32)
        ])
        cu_window_seqlens = jnp.append(cu_window_seqlens, bucket_num_patches)
        pad_w = max(0, max_windows + 1 - cu_window_seqlens.shape[0])
        cu_window_seqlens = jnp.pad(cu_window_seqlens, (0, pad_w), mode='edge')
        cu_seqlens = jnp.append(cu_seqlens, bucket_num_patches)

        x_padded = jnp.pad(x, ((0, bucket_num_patches - x.shape[0]), (0, 0)))

        return x_padded, window_index, rotary_pos_emb, cu_seqlens, cu_window_seqlens, num_tokens
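
    # Illustrative note (not part of the original file): inputs are padded up
    # to the next power-of-two patch count so that only a bounded set of
    # shapes needs to be compiled. For example, assuming
    # spatial_merge_unit == 4:
    #
    #     num_patches = 300
    #     bucket_num_patches = 1 << (300 - 1).bit_length()   # -> 512
    #     bucket_num_tokens = 512 // 4                        # -> 128
    #
    # so a 300-patch image reuses the 512-patch compilation bucket.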

    def compute_hidden_states(self, x: jax.Array, window_index: jax.Array,
                              rotary_pos_emb: jax.Array, cu_seqlens: jax.Array,
                              cu_window_seqlens: jax.Array) -> jax.Array:
        hidden_states = self.patch_embed(x)

        # num of patches
        seq_len = x.shape[0]

        hidden_states = hidden_states.reshape(
            seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
        hidden_states = hidden_states[window_index, :, :]
        hidden_states = hidden_states.reshape(seq_len, -1)

        hidden_states = jnp.expand_dims(hidden_states, axis=1)

        for layer_num, blk in enumerate(self.blocks):
            if layer_num in self.fullatt_block_indexes:
                hidden_states = blk(hidden_states,
                                    rotary_pos_emb=rotary_pos_emb,
                                    cu_window_seqlens=cu_seqlens,
                                    use_fullattn=True)
            else:
                hidden_states = blk(hidden_states,
                                    rotary_pos_emb=rotary_pos_emb,
                                    cu_window_seqlens=cu_window_seqlens,
                                    use_fullattn=False)

        # adapter
        hidden_states = self.merger(hidden_states)
        reverse_indices = jnp.argsort(window_index)
        hidden_states = hidden_states[reverse_indices, :]
        return hidden_states

    @jax.jit
    def encode_padded_jit(self, x_padded, window_index, rotary_pos_emb,
                          cu_seqlens, cu_window_seqlens):
        return self.compute_hidden_states(x_padded, window_index,
                                          rotary_pos_emb, cu_seqlens,
                                          cu_window_seqlens)

    @partial(
        jax.jit,
        static_argnames=("grid_thw", ),
    )
    def encode_jit(self, x, grid_thw):
        window_index, rotary_pos_emb, cu_seqlens, cu_window_seqlens = self.compute_aux_arrays(
            grid_thw)
        return self.compute_hidden_states(x, window_index, rotary_pos_emb,
                                          cu_seqlens, cu_window_seqlens)

    def __call__(self, x: jax.Array,
                 grid_thw: tuple[tuple[int, int, int]]) -> jax.Array:
        # x: pixel_values: jax.Array
        # """Shape:
        # `(num_patches, num_channels * patch_size * patch_size)`
        # """

        # grid_thw: image_grid_thw: jax.Array
        # """Shape: `(num_images, 3)`
        # This should be in `(grid_t, grid_h, grid_w)` format.
        # """
        if self.enable_dynamic_image_sizes:
            window_index, rotary_pos_emb, cu_seqlens, cu_window_seqlens = self.compute_aux_arrays(
                grid_thw)
            x_padded, window_index, rotary_pos_emb, cu_seqlens, cu_window_seqlens, num_tokens = self.pad_inputs(
                x, window_index, rotary_pos_emb, cu_seqlens, cu_window_seqlens)

            hidden_states = self.encode_padded_jit(x_padded, window_index,
                                                   rotary_pos_emb, cu_seqlens,
                                                   cu_window_seqlens)
            return hidden_states[:num_tokens]

        else:
            return self.encode_jit(x, grid_thw)
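
    # Minimal usage sketch (illustrative, not in the original file), assuming
    # a single image with a 28x28 patch grid and spatial_merge_size == 2:
    #
    #     pixel_values: jax.Array   # (1 * 28 * 28, C * temporal_patch * P * P)
    #     feats = visual(pixel_values, ((1, 28, 28), ))
    #     # feats.shape == (784 // 4, out_hidden_size) == (196, ...)
    #
    # i.e. the patch merger reduces the 784 patches to 196 embeddings that are
    # later spliced into the language model's input sequence.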


class Qwen2_5_VLForConditionalGeneration(nnx.Module):

    def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array,
                 mesh: Mesh) -> None:
        config: Qwen2_5_VLConfig = vllm_config.model_config.hf_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.vllm_config = vllm_config
        self.rng = nnx.Rngs(rng_key)
        self.mesh = mesh

        self.config = config
        self.multimodal_config = multimodal_config

        self.visual = Qwen2_5_VisionTransformer(
            vllm_config=vllm_config,
            rngs=self.rng,
            mesh=mesh,
            norm_eps=getattr(config, "rms_norm_eps", 1e-6),
        )
        self.language_model = Qwen2ForCausalLM(vllm_config, rng_key, mesh)

    def get_mrope_input_positions(
        self,
        input_tokens: list[int],
        hf_config,
        image_grid_thw,
        video_grid_thw,
        second_per_grid_ts: list[float],
        context_len: int = 0,
        seq_len: int | None = None,
        audio_feature_lengths=None,
        use_audio_in_video: bool = False,
    ) -> tuple[jax.Array, int]:
        """Get mrope input positions and delta value."""

        image_token_id = hf_config.image_token_id
        video_token_id = hf_config.video_token_id
        vision_start_token_id = hf_config.vision_start_token_id
        spatial_merge_size = hf_config.vision_config.spatial_merge_size
        tokens_per_second = getattr(hf_config.vision_config,
                                    "tokens_per_second", 1.0)

        input_tokens_tensor = np.array(input_tokens)
        vision_start_indices = np.argwhere(
            input_tokens_tensor == vision_start_token_id).squeeze(1)
        vision_tokens = input_tokens_tensor[vision_start_indices + 1]
        image_nums = np.sum(vision_tokens == image_token_id)
        video_nums = np.sum(vision_tokens == video_token_id)
        llm_pos_ids_list: list = []

        st = 0
        remain_images, remain_videos = image_nums, video_nums

        image_index, video_index = 0, 0
        for _ in range(image_nums + video_nums):
            video_second_per_grid_t = 0.0
            if remain_images > 0:
                try:
                    ed_image = input_tokens.index(image_token_id, st)
                except ValueError:
                    ed_image = len(input_tokens) + 1
            else:
                ed_image = len(input_tokens) + 1
            if remain_videos > 0:
                try:
                    ed_video = input_tokens.index(video_token_id, st)
                except ValueError:
                    ed_video = len(input_tokens) + 1
            else:
                ed_video = len(input_tokens) + 1
            if ed_image < ed_video:
                t, h, w = (
                    image_grid_thw[image_index][0],
                    image_grid_thw[image_index][1],
                    image_grid_thw[image_index][2],
                )
                image_index += 1
                remain_images -= 1
                ed = ed_image
            else:
                t, h, w = (
                    video_grid_thw[video_index][0],
                    video_grid_thw[video_index][1],
                    video_grid_thw[video_index][2],
                )
                video_second_per_grid_t = 1.0
                if second_per_grid_ts:
                    video_second_per_grid_t = second_per_grid_ts[video_index]
                video_index += 1
                remain_videos -= 1
                ed = ed_video

            llm_grid_t, llm_grid_h, llm_grid_w = (
                t,
                h // spatial_merge_size,
                w // spatial_merge_size,
            )
            text_len = ed - st

            st_idx = llm_pos_ids_list[-1].max().item() + 1 if len(
                llm_pos_ids_list) > 0 else 0
            llm_pos_ids_list.append(
                jnp.broadcast_to(
                    jnp.arange(text_len, dtype=jnp.int32).reshape(1, -1),
                    (3, text_len)) + st_idx)

            t_index = ((jnp.broadcast_to(
                jnp.arange(llm_grid_t, dtype=jnp.int32).reshape(-1, 1),
                (llm_grid_t, llm_grid_h * llm_grid_w)) *
                        video_second_per_grid_t * tokens_per_second).astype(
                            jnp.int32).flatten())

            h_index = (jnp.broadcast_to(
                jnp.arange(llm_grid_h, dtype=jnp.int32).reshape(1, -1, 1),
                (llm_grid_t, llm_grid_h, llm_grid_w)).flatten())
            w_index = (jnp.broadcast_to(
                jnp.arange(llm_grid_w, dtype=jnp.int32).reshape(1, 1, -1),
                (llm_grid_t, llm_grid_h, llm_grid_w)).flatten())

            llm_pos_ids_list.append(
                jnp.stack([t_index, h_index, w_index]) + text_len + st_idx)
            st = ed + llm_grid_t * llm_grid_h * llm_grid_w

        if st < len(input_tokens):
            st_idx = llm_pos_ids_list[-1].max().item() + 1 if len(
                llm_pos_ids_list) > 0 else 0
            text_len = len(input_tokens) - st

            llm_pos_ids_list.append(
                jnp.broadcast_to(
                    jnp.arange(text_len, dtype=jnp.int32).reshape(1, -1),
                    (3, text_len)) + st_idx)

        llm_positions = jnp.concatenate(llm_pos_ids_list,
                                        axis=1).reshape(3, -1)
        mrope_position_delta = (llm_positions.max() + 1 -
                                len(input_tokens)).item()
        llm_positions = llm_positions[:, context_len:seq_len]

        return llm_positions, mrope_position_delta
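
    # Worked example (illustrative, not in the original file): for a prompt of
    # 3 text tokens followed by one image placeholder whose grid is
    # (t, h, w) = (1, 2, 2) with spatial_merge_size == 2, the three position
    # rows (temporal, height, width) would be
    #
    #     text:  [[0, 1, 2], [0, 1, 2], [0, 1, 2]]   # identical across rows
    #     image: [[3], [3], [3]]                     # single merged patch
    #
    # and any trailing text continues from position 4 on all three rows.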

    def _validate_and_reshape_mm_tensor(self, mm_input: object,
                                        name: str) -> jax.Array:
        if isinstance(mm_input, list):
            # Assuming it's a list of arrays (e.g., np.ndarray, torch.Tensor)
            # that can be concatenated.
            arrays_to_concat = [jnp.asarray(item) for item in mm_input]
            return jnp.concatenate(arrays_to_concat, axis=0)

        # Handle single array-like objects (np.ndarray, torch.Tensor, jax.Array)
        if hasattr(mm_input, 'ndim'):
            array_input = jnp.asarray(mm_input)
            if array_input.ndim == 2:
                return array_input
            if array_input.ndim == 3:
                # This reshapes the batched 3D tensor to a 2D tensor.
                return array_input.reshape(-1, array_input.shape[-1])

        raise ValueError(f"Incorrect type of {name}. "
                         f"Got type: {type(mm_input)}")

    def _parse_and_validate_image_input(
            self, image_grid_thw: tuple[tuple[int, int, int], ...],
            **kwargs: object) -> Optional[Qwen2_5_VLImageInputs]:
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)
        # image_grid_thw = kwargs.pop("image_grid_thw", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            pixel_values = self._validate_and_reshape_mm_tensor(
                pixel_values, "image pixel values")
            # image_grid_thw = self._validate_and_reshape_mm_tensor(
            #     image_grid_thw, "image grid_thw")

            if not isinstance(pixel_values, jax.Array):
                raise ValueError("Incorrect type of image pixel values. "
                                 f"Got type: {type(pixel_values)}")

            return Qwen2_5_VLImagePixelInputs(type="pixel_values",
                                              pixel_values=pixel_values,
                                              image_grid_thw=image_grid_thw)

        # Note: comment them out for now and save for future support
        # if image_embeds is not None:
        #     image_embeds = self._validate_and_reshape_mm_tensor(
        #         image_embeds, "image embeds")
        #     image_grid_thw = self._validate_and_reshape_mm_tensor(
        #         image_grid_thw, "image grid_thw")

        #     if not isinstance(image_embeds, jax.Array):
        #         raise ValueError("Incorrect type of image embeddings. "
        #                          f"Got type: {type(image_embeds)}")
        #     return Qwen2_5_VLImageEmbeddingInputs(
        #         type="image_embeds",
        #         image_embeds=image_embeds,
        #         image_grid_thw=image_grid_thw)

    def _parse_and_validate_multimodal_inputs(
            self, image_grid_thw: tuple[tuple[int, int, int], ...],
            **kwargs: object) -> dict:
        mm_input_by_modality = {}

        # Preserve the order of modalities if there are multiple of them
        # from the order of kwargs.
        for input_key in kwargs:
            if input_key in ("pixel_values", "image_embeds"
                             ) and "image" not in mm_input_by_modality:
                mm_input_by_modality[
                    "image"] = self._parse_and_validate_image_input(
                        image_grid_thw, **kwargs)
            # if input_key in ("pixel_values_videos", "video_embeds"
            #                  ) and "video" not in mm_input_by_modality:
            #     mm_input_by_modality[
            #         "video"] = self._parse_and_validate_video_input(**kwargs)
        return mm_input_by_modality

    def get_single_image_embedding(self, image_pixel_values, image_grid_thw):
        return self.visual(image_pixel_values, (image_grid_thw, ))

    def _process_image_input(
            self, image_input: Qwen2_5_VLImageInputs) -> tuple[jax.Array, ...]:

        grid_thw = image_input["image_grid_thw"]

        if image_input["type"] == "image_embeds":
            image_embeds = image_input["image_embeds"].astype(
                self.visual.dtype)
        else:
            pixel_values = image_input["pixel_values"]
            image_embeds = []
            current_idx = 0
            for image_thw in grid_thw:
                t, h, w = image_thw
                image_size = t * h * w
                end_idx = current_idx + image_size
                image_pixel_values = pixel_values[current_idx:end_idx, :]
                image_embeds.append(
                    self.get_single_image_embedding(image_pixel_values,
                                                    image_thw))
                current_idx = end_idx
            image_embeds = jnp.concatenate(image_embeds, axis=0)

        # Split concatenated embeddings for each image item.
        merge_size = self.visual.config.spatial_merge_size
        sizes = np.prod(np.array(grid_thw, dtype=np.int64),
                        axis=-1) // merge_size // merge_size

        if sizes.size == 0:
            return ()
        if sizes.size == 1:
            return (image_embeds, )

        split_indices = np.cumsum(sizes)[:-1]
        return tuple(jnp.split(image_embeds, split_indices))
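
    # For example (illustrative, not in the original file): two images with
    # grids (1, 28, 28) and (1, 14, 14) and spatial_merge_size == 2 give
    # sizes == [784 // 4, 196 // 4] == [196, 49], so the concatenated
    # embeddings are split back into one (196, hidden) and one (49, hidden)
    # array, matching the per-image placeholder counts in the prompt.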

    def get_multimodal_embeddings(
            self, image_grid_thw: tuple[tuple[int, int, int], ...],
            **kwargs: object) -> MultiModalEmbeddings:

        mm_input_by_modality = self._parse_and_validate_multimodal_inputs(
            image_grid_thw, **kwargs)
        if not mm_input_by_modality:
            return []

        # The result multimodal_embeddings is a tuple of tensors, with each
        # tensor corresponding to a multimodal data item (image or video).
        multimodal_embeddings: tuple[jax.Array, ...] = ()

        # NOTE: It is important to iterate over the keys in this dictionary
        # to preserve the order of the modalities.
        for modality in mm_input_by_modality:
            multimodal_input = mm_input_by_modality[modality]
            if modality == "image":
                vision_embeddings = self._process_image_input(multimodal_input)
                multimodal_embeddings += vision_embeddings
            # if modality == "video":
            #     video_embeddings = self._process_video_input(multimodal_input)
            #     multimodal_embeddings += video_embeddings

        return multimodal_embeddings

    def get_input_embeddings(
            self, input_ids: jax.Array,
            multimodal_embeddings: Optional[jax.Array]) -> jax.Array:

        inputs_embeds = self.language_model.model.embed(input_ids)

        if multimodal_embeddings is not None \
                and multimodal_embeddings.shape[0] != 0:
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, multimodal_embeddings,
                [self.config.image_token_id, self.config.video_token_id])

        return inputs_embeds

    def __call__(
        self,
        kv_caches: list[jax.Array],
        input_ids: Optional[jax.Array],
        attention_metadata: AttentionMetadata,
        inputs_embeds: Optional[jax.Array] = None,
        *args,
    ) -> tuple[list[jax.Array], jax.Array, List[jax.Array]]:
        # The logic of choosing between input_ids and inputs_embeds is
        # handled inside self.language_model.__call__
        kv_caches, x, [] = self.language_model(
            kv_caches=kv_caches,
            input_ids=input_ids,
            attention_metadata=attention_metadata,
            inputs_embeds=inputs_embeds,
        )
        return kv_caches, x, []

    def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, rng_key: jax.Array) -> None:
        self.rng = nnx.Rngs(rng_key)
        self.language_model.rng = self.rng

        # Key: path to a HF layer weight
        # Value: a tuple of (path to a nnx layer weight, nnx weight sharding)

        mappings = {
            "model.embed_tokens": "language_model.model.embed.embedding",
            "model.layers.*.input_layernorm":
            "language_model.model.layers.*.input_layernorm.scale",
            "model.layers.*.mlp.down_proj":
            "language_model.model.layers.*.mlp.down_proj.kernel",
            "model.layers.*.mlp.gate_proj":
            "language_model.model.layers.*.mlp.gate_proj.kernel",
            "model.layers.*.mlp.up_proj":
            "language_model.model.layers.*.mlp.up_proj.kernel",
            "model.layers.*.post_attention_layernorm":
            "language_model.model.layers.*.post_attention_layernorm.scale",
            "model.layers.*.self_attn.k_proj":
            "language_model.model.layers.*.self_attn.k_proj.kernel",
            "model.layers.*.self_attn.o_proj":
            "language_model.model.layers.*.self_attn.o_proj.kernel",
            "model.layers.*.self_attn.q_proj":
            "language_model.model.layers.*.self_attn.q_proj.kernel",
            "model.layers.*.self_attn.v_proj":
            "language_model.model.layers.*.self_attn.v_proj.kernel",
            "model.layers.*.self_attn.q_proj.bias":
            "language_model.model.layers.*.self_attn.q_proj.bias",
            "model.layers.*.self_attn.k_proj.bias":
            "language_model.model.layers.*.self_attn.k_proj.bias",
            "model.layers.*.self_attn.v_proj.bias":
            "language_model.model.layers.*.self_attn.v_proj.bias",
            "model.norm": "language_model.model.norm.scale",
            "visual.blocks.*.attn.proj.bias": "visual.blocks.*.attn.proj.bias",
            "visual.blocks.*.attn.proj": "visual.blocks.*.attn.proj.kernel",
            "visual.blocks.*.attn.qkv.bias":
            "visual.blocks.*.attn.qkv_proj.bias",
            "visual.blocks.*.attn.qkv": "visual.blocks.*.attn.qkv_proj.kernel",
            "visual.blocks.*.mlp.down_proj.bias":
            "visual.blocks.*.mlp.down_proj.bias",
            "visual.blocks.*.mlp.down_proj":
            "visual.blocks.*.mlp.down_proj.kernel",
            "visual.blocks.*.mlp.gate_proj.bias":
            "visual.blocks.*.mlp.gate_proj.bias",
            "visual.blocks.*.mlp.gate_proj":
            "visual.blocks.*.mlp.gate_proj.kernel",
            "visual.blocks.*.mlp.up_proj.bias":
            "visual.blocks.*.mlp.up_proj.bias",
            "visual.blocks.*.mlp.up_proj":
            "visual.blocks.*.mlp.up_proj.kernel",
            "visual.blocks.*.norm1": "visual.blocks.*.norm1.scale",
            "visual.blocks.*.norm2": "visual.blocks.*.norm2.scale",
            "visual.merger.ln_q": "visual.merger.ln_q.scale",
            "visual.merger.mlp.0.bias": "visual.merger.mlp_fc1.bias",
            "visual.merger.mlp.0": "visual.merger.mlp_fc1.kernel",
            "visual.merger.mlp.2.bias": "visual.merger.mlp_fc2.bias",
            "visual.merger.mlp.2": "visual.merger.mlp_fc2.kernel",
            "visual.patch_embed.proj": "visual.patch_embed.proj.kernel",
        }

        # Add lm_head mapping only if it's not tied to embeddings
        hf_config = self.vllm_config.model_config.hf_config
        if not hf_config.tie_word_embeddings:
            mappings.update({
                "lm_head": "language_model.model.lm_head",
            })

        metadata_map = get_default_maps(self.vllm_config.model_config,
                                        self.mesh, mappings)
        load_hf_weights(vllm_config=self.vllm_config,
                        model=self,
                        metadata_map=metadata_map,
                        mesh=self.mesh)

    def precompile_vision_encoder(
        self,
        run_compilation_fn: Callable,
    ) -> None:
        vc = self.vllm_config.model_config.hf_config.vision_config
        patch_input_dim = vc.in_channels * vc.temporal_patch_size * vc.patch_size * vc.patch_size
        if self.visual.enable_dynamic_image_sizes:
            spatial_merge_unit = vc.spatial_merge_size**2
            max_num_batched_tokens = self.vllm_config.scheduler_config.max_num_batched_tokens
            mm_kwargs = self.vllm_config.model_config.multimodal_config.mm_processor_kwargs or {}
            limit_pixels = float(mm_kwargs.get("max_pixels", float('inf')))

            max_patches = int(
                min(max_num_batched_tokens * spatial_merge_unit,
                    limit_pixels / (vc.patch_size**2)))

            num_patches_paddings = [
                1 << i for i in range(4, (max_patches - 1).bit_length() + 1)
            ]
            rotary_dim = vc.hidden_size // vc.num_heads // 2
            vit_merger_window_size = (vc.window_size //
                                      vc.spatial_merge_size // vc.patch_size)

            for num_patches in num_patches_paddings:
                dummy_x_padded = jnp.ones(
                    (num_patches, patch_input_dim),
                    dtype=self.vllm_config.model_config.dtype)

                num_tokens = num_patches // spatial_merge_unit
                dummy_window_index = jnp.arange(num_tokens, dtype=jnp.int32)

                dummy_rotary_pos_emb = jnp.ones(
                    (num_patches, rotary_dim),
                    dtype=self.vllm_config.model_config.dtype)

                dummy_cu_seqlens = jnp.array([0, num_patches, num_patches],
                                             dtype=jnp.int32)

                max_windows = (num_tokens // vit_merger_window_size) + 2
                patches_per_window = (vit_merger_window_size**2) * spatial_merge_unit
                dummy_cu_window_seqlens = jnp.arange(
                    max_windows + 1, dtype=jnp.int32) * patches_per_window
                dummy_cu_window_seqlens = jnp.minimum(dummy_cu_window_seqlens,
                                                      num_patches)

                run_compilation_fn("vision_encoder_padded",
                                   self.visual.encode_padded_jit,
                                   dummy_x_padded,
                                   dummy_window_index,
                                   dummy_rotary_pos_emb,
                                   dummy_cu_seqlens,
                                   dummy_cu_window_seqlens,
                                   num_patches=num_patches)
        else:
            image_shapes = []
            if (warmup_config := self.vllm_config.additional_config.get(
                    "vision_warmup_config")):
                image_shapes = warmup_config.get("image_shapes")

            factor = vc.patch_size * vc.spatial_merge_size
            for input_hw in image_shapes:
                if not isinstance(input_hw, list) or len(input_hw) != 2:
                    logger.warning(f"Skipping invalid shape {input_hw}.")
                    continue
                h_input, w_input = input_hw
                h_processed = round(h_input / factor) * factor
                w_processed = round(w_input / factor) * factor
                t, h, w = 1, h_processed // vc.patch_size, w_processed // vc.patch_size
                grid_thw = (t, h, w)
                num_patches = t * h * w

                dummy_pixel_values = jnp.ones(
                    (num_patches, patch_input_dim),
                    self.vllm_config.model_config.dtype,
                )
                dummy_grid_thw = (grid_thw, )

                run_compilation_fn("vision_encoder",
                                   self.visual.encode_jit,
                                   dummy_pixel_values,
                                   dummy_grid_thw,
                                   image_shape=input_hw)