tpu-inference 0.12.0.dev20251222__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (260)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +67 -0
  6. tests/core/test_dp_scheduler.py +724 -0
  7. tests/core/test_init.py +63 -0
  8. tests/distributed/__init__.py +13 -0
  9. tests/distributed/test_distributed_utils.py +120 -0
  10. tests/distributed/test_tpu_connector.py +478 -0
  11. tests/e2e/__init__.py +13 -0
  12. tests/e2e/test_async_scheduler.py +211 -0
  13. tests/e2e/test_data_parallel.py +393 -0
  14. tests/e2e/test_local_disagg.py +257 -0
  15. tests/e2e/test_model_loader.py +268 -0
  16. tests/e2e/test_multi_modal_inference.py +111 -0
  17. tests/e2e/test_pipeline_parallel.py +265 -0
  18. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  19. tests/e2e/test_sampling_params.py +269 -0
  20. tests/e2e/test_speculative_decoding.py +291 -0
  21. tests/e2e/test_structured_decoding.py +46 -0
  22. tests/executors/__init__.py +13 -0
  23. tests/executors/test_ray_distributed_executor.py +199 -0
  24. tests/experimental/__init__.py +13 -0
  25. tests/experimental/test_llama3_jax_stashed.py +208 -0
  26. tests/kernels/__init__.py +13 -0
  27. tests/kernels/collectives/__init__.py +13 -0
  28. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  29. tests/kernels/fused_moe_v1_test.py +388 -0
  30. tests/kernels/gmm_test.py +205 -0
  31. tests/kernels/mla_v1_test.py +498 -0
  32. tests/kernels/quantized_matmul_kernel_test.py +159 -0
  33. tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
  34. tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
  35. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
  36. tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
  37. tests/layers/__init__.py +13 -0
  38. tests/layers/common/__init__.py +13 -0
  39. tests/layers/common/test_attention_interface.py +156 -0
  40. tests/layers/common/test_quantization.py +149 -0
  41. tests/layers/jax/__init__.py +13 -0
  42. tests/layers/jax/attention/__init__.py +13 -0
  43. tests/layers/jax/attention/test_common_attention.py +103 -0
  44. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  45. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  46. tests/layers/jax/moe/__init__.py +13 -0
  47. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  48. tests/layers/jax/sample/__init__.py +13 -0
  49. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  50. tests/layers/jax/sample/test_sampling.py +115 -0
  51. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  52. tests/layers/jax/test_layers.py +155 -0
  53. tests/layers/jax/test_qwix.py +969 -0
  54. tests/layers/jax/test_rope.py +93 -0
  55. tests/layers/jax/test_sharding.py +159 -0
  56. tests/layers/jax/test_transformer_block.py +152 -0
  57. tests/layers/vllm/__init__.py +13 -0
  58. tests/layers/vllm/test_attention.py +363 -0
  59. tests/layers/vllm/test_awq.py +405 -0
  60. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
  62. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
  63. tests/layers/vllm/test_fp8.py +17 -0
  64. tests/layers/vllm/test_mxfp4.py +297 -0
  65. tests/layers/vllm/test_unquantized.py +621 -0
  66. tests/layers/vllm/utils.py +72 -0
  67. tests/lora/__init__.py +13 -0
  68. tests/lora/conftest.py +46 -0
  69. tests/lora/test_bgmv.py +57 -0
  70. tests/lora/test_layers.py +666 -0
  71. tests/lora/test_lora.py +147 -0
  72. tests/lora/test_lora_perf.py +67 -0
  73. tests/lora/utils.py +88 -0
  74. tests/models/__init__.py +13 -0
  75. tests/models/common/__init__.py +13 -0
  76. tests/models/common/test_model_loader.py +455 -0
  77. tests/models/jax/__init__.py +13 -0
  78. tests/models/jax/test_deepseek_v3.py +401 -0
  79. tests/models/jax/test_llama3.py +184 -0
  80. tests/models/jax/test_llama4.py +298 -0
  81. tests/models/jax/test_llama_eagle3.py +197 -0
  82. tests/models/jax/test_llama_guard_4.py +242 -0
  83. tests/models/jax/test_qwen2.py +172 -0
  84. tests/models/jax/test_qwen2_5_vl.py +606 -0
  85. tests/models/jax/test_qwen3.py +169 -0
  86. tests/models/jax/test_weight_loading.py +180 -0
  87. tests/models/jax/utils/__init__.py +13 -0
  88. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  89. tests/platforms/__init__.py +13 -0
  90. tests/platforms/test_tpu_platform.py +54 -0
  91. tests/runner/__init__.py +13 -0
  92. tests/runner/test_block_table.py +395 -0
  93. tests/runner/test_input_batch.py +226 -0
  94. tests/runner/test_kv_cache.py +220 -0
  95. tests/runner/test_kv_cache_manager.py +498 -0
  96. tests/runner/test_multimodal_manager.py +429 -0
  97. tests/runner/test_persistent_batch_manager.py +84 -0
  98. tests/runner/test_speculative_decoding_manager.py +368 -0
  99. tests/runner/test_structured_decoding_manager.py +220 -0
  100. tests/runner/test_tpu_runner.py +202 -0
  101. tests/runner/test_tpu_runner_dp.py +1033 -0
  102. tests/runner/test_tpu_runner_mesh.py +200 -0
  103. tests/runner/test_utils.py +411 -0
  104. tests/spec_decode/__init__.py +13 -0
  105. tests/spec_decode/test_eagle3.py +311 -0
  106. tests/test_base.py +215 -0
  107. tests/test_envs.py +280 -0
  108. tests/test_tpu_info.py +134 -0
  109. tests/test_utils.py +193 -0
  110. tests/worker/__init__.py +13 -0
  111. tests/worker/tpu_worker_test.py +414 -0
  112. tpu_inference/__init__.py +67 -0
  113. tpu_inference/core/__init__.py +13 -0
  114. tpu_inference/core/core_tpu.py +786 -0
  115. tpu_inference/core/disagg_executor.py +118 -0
  116. tpu_inference/core/disagg_utils.py +49 -0
  117. tpu_inference/core/sched/__init__.py +13 -0
  118. tpu_inference/core/sched/dp_scheduler.py +814 -0
  119. tpu_inference/distributed/__init__.py +13 -0
  120. tpu_inference/distributed/jax_parallel_state.py +81 -0
  121. tpu_inference/distributed/tpu_connector.py +732 -0
  122. tpu_inference/distributed/utils.py +112 -0
  123. tpu_inference/env_override.py +9 -0
  124. tpu_inference/envs.py +191 -0
  125. tpu_inference/executors/__init__.py +13 -0
  126. tpu_inference/executors/ray_distributed_executor.py +399 -0
  127. tpu_inference/experimental/__init__.py +13 -0
  128. tpu_inference/experimental/llama3_jax_stashed.py +272 -0
  129. tpu_inference/kernels/__init__.py +13 -0
  130. tpu_inference/kernels/collectives/__init__.py +13 -0
  131. tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
  132. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
  133. tpu_inference/kernels/collectives/util.py +47 -0
  134. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  135. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  136. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  137. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  138. tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
  139. tpu_inference/kernels/megablox/__init__.py +13 -0
  140. tpu_inference/kernels/megablox/common.py +54 -0
  141. tpu_inference/kernels/megablox/gmm.py +646 -0
  142. tpu_inference/kernels/mla/__init__.py +13 -0
  143. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  144. tpu_inference/kernels/mla/v1/kernel.py +1340 -0
  145. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  146. tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
  147. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  148. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  149. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  150. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  151. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
  152. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
  153. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  154. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  155. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
  156. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
  157. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
  158. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
  159. tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
  160. tpu_inference/layers/__init__.py +13 -0
  161. tpu_inference/layers/common/__init__.py +13 -0
  162. tpu_inference/layers/common/attention_interface.py +403 -0
  163. tpu_inference/layers/common/attention_metadata.py +48 -0
  164. tpu_inference/layers/common/binary_search.py +295 -0
  165. tpu_inference/layers/common/quant_methods.py +23 -0
  166. tpu_inference/layers/common/quantization.py +270 -0
  167. tpu_inference/layers/common/sharding.py +600 -0
  168. tpu_inference/layers/jax/__init__.py +13 -0
  169. tpu_inference/layers/jax/attention/__init__.py +13 -0
  170. tpu_inference/layers/jax/attention/attention.py +268 -0
  171. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
  172. tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
  173. tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
  174. tpu_inference/layers/jax/base.py +165 -0
  175. tpu_inference/layers/jax/constants.py +101 -0
  176. tpu_inference/layers/jax/layers.py +315 -0
  177. tpu_inference/layers/jax/misc.py +30 -0
  178. tpu_inference/layers/jax/moe/__init__.py +13 -0
  179. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
  180. tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
  181. tpu_inference/layers/jax/moe/moe.py +249 -0
  182. tpu_inference/layers/jax/pp_utils.py +53 -0
  183. tpu_inference/layers/jax/rope.py +294 -0
  184. tpu_inference/layers/jax/rope_interface.py +228 -0
  185. tpu_inference/layers/jax/sample/__init__.py +13 -0
  186. tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
  187. tpu_inference/layers/jax/sample/sampling.py +110 -0
  188. tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
  189. tpu_inference/layers/jax/transformer_block.py +121 -0
  190. tpu_inference/layers/vllm/__init__.py +13 -0
  191. tpu_inference/layers/vllm/attention.py +221 -0
  192. tpu_inference/layers/vllm/fused_moe.py +502 -0
  193. tpu_inference/layers/vllm/linear_common.py +221 -0
  194. tpu_inference/layers/vllm/quantization/__init__.py +55 -0
  195. tpu_inference/layers/vllm/quantization/awq.py +221 -0
  196. tpu_inference/layers/vllm/quantization/common.py +124 -0
  197. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  198. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
  199. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
  200. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  201. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
  202. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
  203. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  204. tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
  205. tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
  206. tpu_inference/layers/vllm/sharding.py +244 -0
  207. tpu_inference/logger.py +10 -0
  208. tpu_inference/lora/__init__.py +13 -0
  209. tpu_inference/lora/torch_lora_ops.py +98 -0
  210. tpu_inference/lora/torch_punica_tpu.py +310 -0
  211. tpu_inference/models/__init__.py +13 -0
  212. tpu_inference/models/common/__init__.py +13 -0
  213. tpu_inference/models/common/model_loader.py +520 -0
  214. tpu_inference/models/jax/__init__.py +13 -0
  215. tpu_inference/models/jax/deepseek_v3.py +978 -0
  216. tpu_inference/models/jax/gpt_oss.py +508 -0
  217. tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
  218. tpu_inference/models/jax/llama3.py +436 -0
  219. tpu_inference/models/jax/llama4.py +643 -0
  220. tpu_inference/models/jax/llama_eagle3.py +350 -0
  221. tpu_inference/models/jax/llama_guard_4.py +375 -0
  222. tpu_inference/models/jax/qwen2.py +390 -0
  223. tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
  224. tpu_inference/models/jax/qwen3.py +318 -0
  225. tpu_inference/models/jax/utils/__init__.py +13 -0
  226. tpu_inference/models/jax/utils/file_utils.py +110 -0
  227. tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
  228. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  229. tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
  230. tpu_inference/models/jax/utils/weight_utils.py +621 -0
  231. tpu_inference/models/vllm/__init__.py +13 -0
  232. tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
  233. tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
  234. tpu_inference/platforms/__init__.py +16 -0
  235. tpu_inference/platforms/tpu_platform.py +258 -0
  236. tpu_inference/runner/__init__.py +13 -0
  237. tpu_inference/runner/block_table.py +122 -0
  238. tpu_inference/runner/compilation_manager.py +890 -0
  239. tpu_inference/runner/input_batch.py +435 -0
  240. tpu_inference/runner/kv_cache.py +166 -0
  241. tpu_inference/runner/kv_cache_manager.py +508 -0
  242. tpu_inference/runner/lora_utils.py +106 -0
  243. tpu_inference/runner/multimodal_manager.py +231 -0
  244. tpu_inference/runner/persistent_batch_manager.py +296 -0
  245. tpu_inference/runner/speculative_decoding_manager.py +262 -0
  246. tpu_inference/runner/structured_decoding_manager.py +101 -0
  247. tpu_inference/runner/tpu_runner.py +1768 -0
  248. tpu_inference/runner/utils.py +426 -0
  249. tpu_inference/spec_decode/__init__.py +13 -0
  250. tpu_inference/spec_decode/jax/__init__.py +13 -0
  251. tpu_inference/spec_decode/jax/eagle3.py +430 -0
  252. tpu_inference/tpu_info.py +92 -0
  253. tpu_inference/utils.py +345 -0
  254. tpu_inference/worker/__init__.py +13 -0
  255. tpu_inference/worker/tpu_worker.py +468 -0
  256. tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
  257. tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
  258. tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
  259. tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
  260. tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
tpu_inference/models/jax/qwen3.py
@@ -0,0 +1,318 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional, Tuple
+
+import jax
+import jax.numpy as jnp
+from flax import nnx
+from jax.sharding import Mesh
+from transformers import Qwen3Config
+from vllm.config import VllmConfig
+
+from tpu_inference import utils
+from tpu_inference.layers.common.attention_interface import attention
+from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.common.quantization import quantize_kv
+from tpu_inference.layers.jax.rope_interface import apply_rope
+from tpu_inference.logger import init_logger
+from tpu_inference.models.jax.qwen2 import Qwen2DecoderLayer
+from tpu_inference.models.jax.qwen2 import Qwen2MLP as Qwen3MLP
+from tpu_inference.models.jax.qwen2 import Qwen2Model
+from tpu_inference.models.jax.utils.weight_utils import (get_default_maps,
+                                                         load_hf_weights)
+
+logger = init_logger(__name__)
+
+init_fn = nnx.initializers.uniform()
+
+
+class Qwen3Attention(nnx.Module):
+
+    def __init__(self, config: Qwen3Config, dtype: jnp.dtype, rng: nnx.Rngs,
+                 mesh: Mesh, kv_cache_dtype: str):
+        self.hidden_size = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.num_kv_heads = config.num_key_value_heads
+        self.rope_theta = config.rope_theta
+        self.rope_scaling = getattr(config, "rope_scaling", None)
+        self.rms_norm_eps = config.rms_norm_eps
+
+        self.head_dim_original = getattr(config, "head_dim",
+                                         self.hidden_size // self.num_heads)
+        self.head_dim = utils.get_padded_head_dim(self.head_dim_original)
+
+        sharding_size = mesh.shape["model"]
+        self.num_heads = utils.get_padded_num_heads(self.num_heads,
+                                                    sharding_size)
+        self.num_kv_heads = utils.get_padded_num_heads(self.num_kv_heads,
+                                                       sharding_size)
+
+        self.mesh = mesh
+
+        self.q_proj = nnx.Einsum(
+            "TD,DNH->TNH",
+            (self.hidden_size, self.num_heads, self.head_dim),
+            param_dtype=dtype,
+            kernel_init=nnx.with_partitioning(init_fn, (None, "model", None)),
+            rngs=rng,
+        )
+        self.q_norm = nnx.RMSNorm(
+            self.head_dim,
+            epsilon=self.rms_norm_eps,
+            param_dtype=dtype,
+            scale_init=nnx.with_partitioning(init_fn, (None, )),
+            rngs=rng,
+        )
+        self.k_proj = nnx.Einsum(
+            "TD,DKH->TKH",
+            (self.hidden_size, self.num_kv_heads, self.head_dim),
+            param_dtype=dtype,
+            kernel_init=nnx.with_partitioning(init_fn, (None, "model", None)),
+            rngs=rng,
+        )
+        self.k_norm = nnx.RMSNorm(
+            self.head_dim,
+            epsilon=self.rms_norm_eps,
+            param_dtype=dtype,
+            scale_init=nnx.with_partitioning(init_fn, (None, )),
+            rngs=rng,
+        )
+        self.v_proj = nnx.Einsum(
+            "TD,DKH->TKH",
+            (self.hidden_size, self.num_kv_heads, self.head_dim),
+            param_dtype=dtype,
+            kernel_init=nnx.with_partitioning(init_fn, (None, "model", None)),
+            rngs=rng,
+        )
+        self.o_proj = nnx.Einsum(
+            "TNH,NHD->TD",
+            (self.num_heads, self.head_dim, self.hidden_size),
+            param_dtype=dtype,
+            kernel_init=nnx.with_partitioning(init_fn, ("model", None, None)),
+            rngs=rng,
+        )
+
+        self._q_scale = 1.0
+        self._k_scale = 1.0
+        self._v_scale = 1.0
+        self.kv_cache_quantized_dtype = None
+        if kv_cache_dtype != "auto":
+            self.kv_cache_quantized_dtype = utils.get_jax_dtype_from_str_dtype(
+                kv_cache_dtype)
+
+    def __call__(
+        self,
+        kv_cache: Optional[jax.Array],
+        x: jax.Array,
+        attention_metadata: AttentionMetadata,
+    ) -> Tuple[jax.Array, jax.Array]:
+        md = attention_metadata
+        # q: (T, N, H)
+        q = self.q_proj(x)
+        q = self.q_norm(q)
+        q = apply_rope(q, md.input_positions, self.head_dim_original,
+                       self.rope_theta, self.rope_scaling)
+
+        # k: (T, K, H)
+        k = self.k_proj(x)
+        k = self.k_norm(k)
+        k = apply_rope(k, md.input_positions, self.head_dim_original,
+                       self.rope_theta, self.rope_scaling)
+
+        # v: (T, K, H)
+        v = self.v_proj(x)
+        # o: (T, N, H)
+        q_scale = k_scale = v_scale = None
+        if self.kv_cache_quantized_dtype:
+            # TODO(kyuyeunk/jacobplatin): Enable w8a8 when VREG spill issue is resolved.
+            # q_scale = self._q_scale
+            k_scale = self._k_scale
+            v_scale = self._v_scale
+            k, v = quantize_kv(self.kv_cache_quantized_dtype, k, v, k_scale,
+                               v_scale)
+        new_kv_cache, outputs = attention(
+            kv_cache,
+            q,
+            k,
+            v,
+            attention_metadata,
+            self.mesh,
+            self.head_dim_original,
+            q_scale=q_scale,
+            k_scale=k_scale,
+            v_scale=v_scale,
+        )
+        # (T, D)
+        o = self.o_proj(outputs)
+        return new_kv_cache, o
+
+
+class Qwen3DecoderLayer(Qwen2DecoderLayer):
+
+    def __init__(self, config: Qwen3Config, dtype: jnp.dtype, rng: nnx.Rngs,
+                 mesh: Mesh, kv_cache_dtype: str):
+        rms_norm_eps = config.rms_norm_eps
+        hidden_size = config.hidden_size
+
+        self.input_layernorm = nnx.RMSNorm(
+            hidden_size,
+            epsilon=rms_norm_eps,
+            param_dtype=dtype,
+            scale_init=nnx.with_partitioning(init_fn, (None, )),
+            rngs=rng,
+        )
+        self.self_attn = Qwen3Attention(config=config,
+                                        dtype=dtype,
+                                        rng=rng,
+                                        mesh=mesh,
+                                        kv_cache_dtype=kv_cache_dtype)
+        self.post_attention_layernorm = nnx.RMSNorm(
+            hidden_size,
+            epsilon=rms_norm_eps,
+            param_dtype=dtype,
+            scale_init=nnx.with_partitioning(init_fn, (None, )),
+            rngs=rng,
+        )
+        self.mlp = Qwen3MLP(
+            config=config,
+            dtype=dtype,
+            rng=rng,
+        )
+
+
+class Qwen3Model(Qwen2Model):
+
+    def __init__(self, vllm_config: VllmConfig, rng: nnx.Rngs,
+                 mesh: Mesh) -> None:
+        model_config = vllm_config.model_config
+        hf_config = model_config.hf_config
+        vocab_size = model_config.get_vocab_size()
+        dtype = model_config.dtype
+        rms_norm_eps = hf_config.rms_norm_eps
+        hidden_size = hf_config.hidden_size
+
+        self.embed = nnx.Embed(
+            num_embeddings=vocab_size,
+            features=hidden_size,
+            param_dtype=dtype,
+            embedding_init=nnx.with_partitioning(init_fn, ("model", None)),
+            rngs=rng,
+        )
+        self.layers = [
+            Qwen3DecoderLayer(
+                config=hf_config,
+                dtype=dtype,
+                rng=rng,
+                mesh=mesh,
+                # TODO (jacobplatin): we should refactor this to pass a dtype (or config) directly
+                kv_cache_dtype=vllm_config.cache_config.cache_dtype)
+            for _ in range(hf_config.num_hidden_layers)
+        ]
+        self.norm = nnx.RMSNorm(
+            hidden_size,
+            epsilon=rms_norm_eps,
+            param_dtype=dtype,
+            scale_init=nnx.with_partitioning(init_fn, (None, )),
+            rngs=rng,
+        )
+        if model_config.hf_config.tie_word_embeddings:
+            self.lm_head = self.embed.embedding
+        else:
+            self.lm_head = nnx.Param(
+                init_fn(rng.params(), (hidden_size, vocab_size), dtype),
+                sharding=(None, "model"),
+            )
+
+
+class Qwen3ForCausalLM(nnx.Module):
+
+    def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array,
+                 mesh: Mesh) -> None:
+        self.vllm_config = vllm_config
+        self.rng = nnx.Rngs(rng_key)
+        self.mesh = mesh
+
+        self.model = Qwen3Model(
+            vllm_config=vllm_config,
+            rng=self.rng,
+            mesh=mesh,
+        )
+
+    def __call__(
+        self,
+        kv_caches: List[jax.Array],
+        input_ids: jax.Array,
+        attention_metadata: AttentionMetadata,
+        *args,
+    ) -> Tuple[List[jax.Array], jax.Array, List[jax.Array]]:
+        kv_caches, x = self.model(
+            kv_caches,
+            input_ids,
+            attention_metadata,
+        )
+        return kv_caches, x, []
+
+    def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
+        if self.vllm_config.model_config.hf_config.tie_word_embeddings:
+            logits = jnp.dot(hidden_states, self.model.lm_head.value.T)
+        else:
+            logits = jnp.dot(hidden_states, self.model.lm_head.value)
+        return logits
+
+    def load_weights(self, rng_key: jax.Array):
+        # NOTE: Since we are using nnx.eval_shape to init the model,
+        # we have to pass dynamic arrays here for __call__'s usage.
+        self.rng = nnx.Rngs(rng_key)
+
+        # Key: path to a HF layer weight
+        # Value: path to a nnx layer weight
+        mappings = {
+            "model.embed_tokens": "model.embed.embedding",
+            "model.layers.*.input_layernorm":
+            "model.layers.*.input_layernorm.scale",
+            "model.layers.*.mlp.down_proj":
+            "model.layers.*.mlp.down_proj.kernel",
+            "model.layers.*.mlp.gate_proj":
+            "model.layers.*.mlp.gate_proj.kernel",
+            "model.layers.*.mlp.up_proj": "model.layers.*.mlp.up_proj.kernel",
+            "model.layers.*.post_attention_layernorm":
+            "model.layers.*.post_attention_layernorm.scale",
+            "model.layers.*.self_attn.k_norm":
+            "model.layers.*.self_attn.k_norm.scale",
+            "model.layers.*.self_attn.k_proj":
+            "model.layers.*.self_attn.k_proj.kernel",
+            "model.layers.*.self_attn.o_proj":
+            "model.layers.*.self_attn.o_proj.kernel",
+            "model.layers.*.self_attn.q_norm":
+            "model.layers.*.self_attn.q_norm.scale",
+            "model.layers.*.self_attn.q_proj":
+            "model.layers.*.self_attn.q_proj.kernel",
+            "model.layers.*.self_attn.v_proj":
+            "model.layers.*.self_attn.v_proj.kernel",
+            "model.norm": "model.norm.scale",
+        }
+
+        # Add lm_head mapping only if it's not tied to embeddings
+        if not self.vllm_config.model_config.hf_config.tie_word_embeddings:
+            mappings.update({
+                "lm_head": "model.lm_head",
+            })
+
+        metadata_map = get_default_maps(self.vllm_config.model_config,
+                                        self.mesh, mappings)
+        load_hf_weights(vllm_config=self.vllm_config,
+                        model=self,
+                        metadata_map=metadata_map,
+                        mesh=self.mesh)
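
The only output head in this file is `compute_logits`, which reuses the `(vocab, hidden)` embedding table transposed when `tie_word_embeddings` is set, and a separate `(hidden, vocab)` `lm_head` parameter otherwise. A minimal sketch of the two branches with toy shapes (all names and sizes below are illustrative, not part of the package):

```python
# Sketch of the tied- vs. untied-embedding logits paths in compute_logits.
# Toy shapes; `embedding` stands in for model.embed.embedding.
import jax
import jax.numpy as jnp

vocab_size, hidden_size, num_tokens = 32, 8, 4
key = jax.random.PRNGKey(0)
embedding = jax.random.normal(key, (vocab_size, hidden_size))      # (V, D)
hidden_states = jax.random.normal(key, (num_tokens, hidden_size))  # (T, D)

# tie_word_embeddings=True: reuse the embedding table, transposed.
logits_tied = jnp.dot(hidden_states, embedding.T)                  # (T, V)

# tie_word_embeddings=False: a separate (D, V) lm_head parameter.
lm_head = jax.random.normal(key, (hidden_size, vocab_size))
logits_untied = jnp.dot(hidden_states, lm_head)                    # (T, V)

assert logits_tied.shape == logits_untied.shape == (num_tokens, vocab_size)
```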
tpu_inference/models/jax/utils/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
tpu_inference/models/jax/utils/file_utils.py
@@ -0,0 +1,110 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import glob
+import hashlib
+import os
+import shutil
+import subprocess
+from typing import List, Optional
+
+import filelock
+import huggingface_hub.constants
+from huggingface_hub import HfFileSystem, snapshot_download
+from tqdm.auto import tqdm
+
+from tpu_inference.logger import init_logger
+
+logger = init_logger(__name__)
+# Do not set the HuggingFace token here; it should be set via the env `HF_TOKEN`.
+hfs = HfFileSystem()
+
+LOCK_DIR = "/tmp/lock"
+
+##### Local file utils #####
+
+
+def run_cmd(cmd: str, *args, **kwargs) -> subprocess.CompletedProcess:
+    return subprocess.run(cmd.split(), *args, **kwargs)
+
+
+def delete_file(path: str) -> None:
+    if os.path.isfile(path):
+        os.remove(path)
+    else:
+        logger.error(f"Trying to delete non-existing file: {path}")
+
+
+def list_files(dir: str, pattern: str = "*") -> List[str]:
+    files = glob.glob(os.path.join(dir, pattern))
+    return files
+
+
+def get_lock(model_name_or_path: str):
+    lock_dir = LOCK_DIR
+    model_name_or_path = str(model_name_or_path)
+    os.makedirs(lock_dir, exist_ok=True)
+    model_name = model_name_or_path.replace("/", "-")
+    hash_name = hashlib.sha256(model_name.encode()).hexdigest()
+    # add hash to avoid conflict with old users' lock files
+    lock_file_name = hash_name + model_name + ".lock"
+    # mode 0o666 is required for the filelock to be shared across users
+    lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name),
+                             mode=0o666)
+    return lock
+
+
+def get_free_disk_size(path: str = "/") -> int:
+    free_bytes = shutil.disk_usage(path)[2]
+    return free_bytes
+
+
+##### HuggingFace file utils #####
+
+
+def is_hf_repo(repo_id: str) -> bool:
+    return hfs.exists(repo_id)
+
+
+def list_hf_repo(repo_id: str, pattern: str = "**") -> List[str]:
+    repo_files = hfs.glob(os.path.join(repo_id, pattern))
+    return repo_files
+
+
+def get_hf_model_weights_size(repo_id: str, weights_format: str) -> int:
+    weights_paths = list_hf_repo(repo_id, weights_format)
+    weights_size = 0
+    for weights_path in weights_paths:
+        weights_size += int(hfs.info(weights_path)["size"])
+    return weights_size
+
+
+class DisabledTqdm(tqdm):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs, disable=True)
+
+
+def download_model_weights_from_hf(model_path: str, cache_dir: Optional[str],
+                                   weights_format: str) -> List[str]:
+    with get_lock(model_path):
+        local_dir = snapshot_download(
+            model_path,
+            cache_dir=cache_dir,  # can be specified by HF_HOME or HF_HUB_CACHE
+            allow_patterns=weights_format,
+            tqdm_class=DisabledTqdm,
+            local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
+        )
+        local_files = list_files(local_dir, weights_format)
+        return local_files
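
A hypothetical usage sketch of these helpers; the repo id and glob pattern are illustrative only, and the download is serialized across processes by the file lock from `get_lock`:

```python
# Hypothetical usage of the helpers above; repo id and pattern are
# illustrative, not defaults used anywhere in the package.
from tpu_inference.models.jax.utils.file_utils import (
    download_model_weights_from_hf, get_free_disk_size,
    get_hf_model_weights_size, is_hf_repo)

repo_id = "Qwen/Qwen3-0.6B"  # assumed example repo
pattern = "*.safetensors"

if is_hf_repo(repo_id):
    needed = get_hf_model_weights_size(repo_id, pattern)
    if needed < get_free_disk_size("/"):
        local_files = download_model_weights_from_hf(repo_id, None, pattern)
        print(local_files)
```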
tpu_inference/models/jax/utils/multi_modal_utils.py
@@ -0,0 +1,177 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Union
+
+import jax
+import jax.numpy as jnp
+from typing_extensions import TypeAlias
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+NestedTensors: TypeAlias = Union[list["NestedTensors"], list["jax.Array"],
+                                 "jax.Array", tuple["jax.Array", ...]]
+"""
+Uses a list instead of a tensor if the dimensions of each element do not match.
+"""
+
+MultiModalEmbeddings = Union[list[jax.Array], jax.Array, tuple[jax.Array, ...]]
+"""
+The output embeddings must be one of the following formats:
+
+- A list or tuple of 2D tensors, where each tensor corresponds to
+  each input multimodal data item (e.g., image).
+- A single 3D tensor, with the batch dimension grouping the 2D tensors.
+"""
+
+
+def sanity_check_mm_encoder_outputs(
+    mm_embeddings: MultiModalEmbeddings,
+    expected_num_items: int,
+) -> None:
+    """
+    Perform sanity checks for the result of
+    [`vllm.model_executor.models.SupportsMultiModal.get_multimodal_embeddings`][].
+    """
+    assert isinstance(mm_embeddings, (list, tuple, jax.Array)), (
+        "Expected multimodal embeddings to be a list/tuple of 2D tensors, "
+        f"or a single 3D tensor, but got {type(mm_embeddings)} "
+        "instead. This is most likely due to incorrect implementation "
+        "of the model's `get_multimodal_embeddings` method.")
+
+    assert len(mm_embeddings) == expected_num_items, (
+        "Expected number of multimodal embeddings to match number of "
+        f"input items: {expected_num_items}, but got {len(mm_embeddings)=} "
+        "instead. This is most likely due to incorrect implementation "
+        "of the model's `get_multimodal_embeddings` method.")
+
+    assert all(e.ndim == 2 for e in mm_embeddings), (
+        "Expected multimodal embeddings to be a sequence of 2D tensors, "
+        f"but got tensors with shapes {[e.shape for e in mm_embeddings]} "
+        "instead. This is most likely due to incorrect implementation "
+        "of the model's `get_multimodal_embeddings` method.")
+
+
+def flatten_embeddings(embeddings: NestedTensors) -> jax.Array:
+    """
+    Recursively flattens and concatenates NestedTensors on all but the last
+    dimension.
+    """
+
+    if isinstance(embeddings, jax.Array):
+        return embeddings.reshape(-1, embeddings.shape[-1])
+
+    return jnp.concatenate([flatten_embeddings(t) for t in embeddings], axis=0)
+
+
+def _embedding_count_expression(embeddings: NestedTensors) -> str:
+    """
+    Constructs a debugging representation of the number of embeddings in the
+    NestedTensors.
+    """
+
+    if isinstance(embeddings, jax.Array):
+        return " x ".join([str(dim) for dim in embeddings.shape[:-1]])
+
+    return " + ".join(
+        _embedding_count_expression(inner) for inner in embeddings)
+
+
+def _merge_multimodal_embeddings(
+    inputs_embeds: jax.Array,
+    is_multimodal: jax.Array,
+    multimodal_embeddings: jax.Array,
+) -> jax.Array:
+    """
+    Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
+    positions in ``inputs_embeds`` that ``is_multimodal`` marks as
+    placeholder tokens.
+
+    This returns a new array with the updated values.
+    """
+    # The check for a matching number of tokens is omitted because it is not
+    # JIT-compatible. If the shapes mismatch, JAX will raise an error
+    # during execution anyway. The user-friendly error message is
+    # sacrificed for JIT compatibility.
+
+    # JIT-compatible implementation using jnp.where to avoid
+    # NonConcreteBooleanIndexError.
+    # Create a dummy row to handle indices for non-multimodal tokens.
+    # The content of the dummy row does not matter as it will be masked out.
+    dummy_row = jnp.zeros_like(multimodal_embeddings[0:1])
+
+    # Prepend the dummy row to the flattened embeddings.
+    flattened_padded = jnp.concatenate([dummy_row, multimodal_embeddings],
+                                       axis=0)
+
+    # Create gather indices. For each token in the input sequence, this gives
+    # the index into `flattened_padded`.
+    # For non-multimodal tokens, the index will be 0 (pointing to the dummy
+    # row). For the k-th multimodal token, the index will be k.
+    gather_indices = jnp.cumsum(is_multimodal)
+
+    # Gather the embeddings to be placed.
+    update_values = flattened_padded[gather_indices]
+
+    # Use jnp.where to select between original and new embeddings.
+    condition = jnp.expand_dims(is_multimodal, axis=-1)
+    return jnp.where(condition, update_values, inputs_embeds)
+
+
+def merge_multimodal_embeddings(
+    input_ids: jax.Array,
+    inputs_embeds: jax.Array,
+    multimodal_embeddings: jax.Array,
+    placeholder_token_id: Union[int, list[int]],
+) -> jax.Array:
+    """
+    Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
+    positions in ``inputs_embeds`` corresponding to placeholder tokens in
+    ``input_ids``.
+
+    ``placeholder_token_id`` can be a list of token ids (e.g., token ids
+    of img_start, img_break, and img_end tokens) when needed: this means
+    the order of these tokens in the ``input_ids`` MUST MATCH the order of
+    their embeddings in ``multimodal_embeddings`` since we need to
+    slice-merge instead of individually scattering.
+
+    For example, if input_ids is "TTTTTSIIIBIIIBIIIETTT", where
+    - T is a text token
+    - S is the image start token
+    - I is an image embedding token
+    - B is the image break token
+    - E is the image end token,
+
+    then the image embeddings (that correspond to I's) from the vision
+    encoder must be padded with embeddings of S, B, and E in the same order
+    as in input_ids for a correct embedding merge.
+
+    This returns a new array with the updated values.
+    """
+    if isinstance(placeholder_token_id, list):
+        placeholder_token_id = jnp.array(placeholder_token_id)
+
+        return _merge_multimodal_embeddings(
+            inputs_embeds,
+            jnp.isin(input_ids, placeholder_token_id),
+            multimodal_embeddings,
+        )
+
+    return _merge_multimodal_embeddings(
+        inputs_embeds,
+        (input_ids == placeholder_token_id),
+        multimodal_embeddings,
+    )
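
A small, self-contained sketch of the gather-and-`jnp.where` merge implemented above; the placeholder id and embedding dimensions are toy values, not anything the package defines:

```python
# Toy demonstration of merge_multimodal_embeddings; IMG is an assumed
# placeholder token id and the dimensions are illustrative.
import jax.numpy as jnp

from tpu_inference.models.jax.utils.multi_modal_utils import (
    merge_multimodal_embeddings)

IMG = 9
input_ids = jnp.array([1, IMG, IMG, 2])  # two placeholder positions
inputs_embeds = jnp.zeros((4, 3))        # (T, D) text embeddings
image_embeds = jnp.ones((2, 3))          # one row per IMG token, in order

merged = merge_multimodal_embeddings(input_ids, inputs_embeds,
                                     image_embeds, IMG)
# Rows 1 and 2 now hold the image rows; rows 0 and 3 are untouched.
assert bool((merged[1] == 1).all()) and bool((merged[0] == 0).all())
```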
tpu_inference/models/jax/utils/qwix/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.