tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl

This diff compares two publicly available versions of the package as released to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.

Potentially problematic release.

Files changed (257)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +317 -34
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +406 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +320 -0
  64. tests/layers/vllm/test_unquantized.py +662 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +26 -6
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +110 -12
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +2 -45
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +15 -10
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +92 -8
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +25 -4
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +807 -230
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +218 -137
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +25 -12
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/fused_moe_gmm.py +506 -0
  154. tpu_inference/layers/common/quant_methods.py +15 -0
  155. tpu_inference/layers/common/quantization.py +282 -0
  156. tpu_inference/layers/common/sharding.py +32 -9
  157. tpu_inference/layers/common/utils.py +94 -0
  158. tpu_inference/layers/jax/__init__.py +13 -0
  159. tpu_inference/layers/jax/attention/__init__.py +13 -0
  160. tpu_inference/layers/jax/attention/attention.py +19 -6
  161. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  162. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  163. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  164. tpu_inference/layers/jax/base.py +14 -0
  165. tpu_inference/layers/jax/constants.py +13 -0
  166. tpu_inference/layers/jax/layers.py +14 -0
  167. tpu_inference/layers/jax/misc.py +14 -0
  168. tpu_inference/layers/jax/moe/__init__.py +13 -0
  169. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  170. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  171. tpu_inference/layers/jax/moe/moe.py +43 -3
  172. tpu_inference/layers/jax/pp_utils.py +53 -0
  173. tpu_inference/layers/jax/rope.py +14 -0
  174. tpu_inference/layers/jax/rope_interface.py +14 -0
  175. tpu_inference/layers/jax/sample/__init__.py +13 -0
  176. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  177. tpu_inference/layers/jax/sample/sampling.py +15 -1
  178. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  179. tpu_inference/layers/jax/transformer_block.py +14 -0
  180. tpu_inference/layers/vllm/__init__.py +13 -0
  181. tpu_inference/layers/vllm/attention.py +4 -4
  182. tpu_inference/layers/vllm/fused_moe.py +101 -494
  183. tpu_inference/layers/vllm/linear.py +64 -0
  184. tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
  185. tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
  186. tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
  187. tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
  188. tpu_inference/layers/vllm/quantization/__init__.py +19 -3
  189. tpu_inference/layers/vllm/quantization/awq.py +96 -82
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  191. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +23 -8
  192. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +172 -176
  193. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  194. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
  195. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
  196. tpu_inference/layers/vllm/quantization/{common.py → configs.py} +42 -25
  197. tpu_inference/layers/vllm/quantization/fp8.py +119 -0
  198. tpu_inference/layers/vllm/quantization/mxfp4.py +137 -178
  199. tpu_inference/layers/vllm/quantization/unquantized.py +157 -233
  200. tpu_inference/lora/__init__.py +13 -0
  201. tpu_inference/lora/torch_lora_ops.py +8 -13
  202. tpu_inference/models/__init__.py +13 -0
  203. tpu_inference/models/common/__init__.py +13 -0
  204. tpu_inference/models/common/model_loader.py +112 -35
  205. tpu_inference/models/jax/__init__.py +13 -0
  206. tpu_inference/models/jax/deepseek_v3.py +267 -157
  207. tpu_inference/models/jax/gpt_oss.py +26 -10
  208. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  209. tpu_inference/models/jax/llama3.py +99 -36
  210. tpu_inference/models/jax/llama4.py +14 -0
  211. tpu_inference/models/jax/llama_eagle3.py +18 -5
  212. tpu_inference/models/jax/llama_guard_4.py +15 -1
  213. tpu_inference/models/jax/qwen2.py +17 -2
  214. tpu_inference/models/jax/qwen2_5_vl.py +179 -51
  215. tpu_inference/models/jax/qwen3.py +17 -2
  216. tpu_inference/models/jax/utils/__init__.py +13 -0
  217. tpu_inference/models/jax/utils/file_utils.py +14 -0
  218. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  219. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  220. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +92 -32
  221. tpu_inference/models/jax/utils/weight_utils.py +234 -155
  222. tpu_inference/models/vllm/__init__.py +13 -0
  223. tpu_inference/models/vllm/vllm_model_wrapper.py +32 -8
  224. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  225. tpu_inference/platforms/__init__.py +14 -0
  226. tpu_inference/platforms/tpu_platform.py +51 -72
  227. tpu_inference/runner/__init__.py +13 -0
  228. tpu_inference/runner/compilation_manager.py +180 -80
  229. tpu_inference/runner/kv_cache.py +54 -20
  230. tpu_inference/runner/kv_cache_manager.py +55 -33
  231. tpu_inference/runner/lora_utils.py +16 -1
  232. tpu_inference/runner/multimodal_manager.py +16 -2
  233. tpu_inference/runner/persistent_batch_manager.py +54 -2
  234. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  235. tpu_inference/runner/structured_decoding_manager.py +16 -3
  236. tpu_inference/runner/tpu_runner.py +124 -61
  237. tpu_inference/runner/utils.py +2 -2
  238. tpu_inference/spec_decode/__init__.py +13 -0
  239. tpu_inference/spec_decode/jax/__init__.py +13 -0
  240. tpu_inference/spec_decode/jax/eagle3.py +84 -22
  241. tpu_inference/tpu_info.py +14 -0
  242. tpu_inference/utils.py +72 -44
  243. tpu_inference/worker/__init__.py +13 -0
  244. tpu_inference/worker/tpu_worker.py +66 -52
  245. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +8 -9
  246. tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
  247. tpu_inference/layers/vllm/linear_common.py +0 -186
  248. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  249. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  250. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  251. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  252. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  253. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  254. tpu_inference-0.11.1.dev202511220812.dist-info/RECORD +0 -174
  255. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
  256. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
  257. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import functools
 from typing import Any, Optional
 
@@ -5,22 +19,31 @@ import jax
 import torch
 from flax import nnx
 from jax.sharding import Mesh, NamedSharding, PartitionSpec
-from torchax.ops.mappings import j2t_dtype
 from transformers import PretrainedConfig
 from vllm.config import VllmConfig
+from vllm.model_executor.model_loader import get_model_loader
+from vllm.model_executor.model_loader.runai_streamer_loader import \
+    RunaiModelStreamerLoader
 from vllm.utils.func_utils import supports_kw
 
 from tpu_inference import envs
 from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.logger import init_logger
-from tpu_inference.models.jax.utils.quantization.quantization_utils import (
+from tpu_inference.models.jax.utils.qwix.qwix_utils import (
     apply_qwix_on_abstract_model, apply_qwix_quantization,
-    load_random_weights_into_qwix_abstract_model)
+    load_random_weights_into_qwix_abstract_model,
+    update_vllm_config_for_qwix_quantization)
+from tpu_inference.utils import to_jax_dtype, to_torch_dtype
 
 logger = init_logger(__name__)
 
 _MODEL_REGISTRY = {}
 
+# List of architectures that are preferred to use "vllm" implementation over
+# "flax_nnx" implementation due to various factors such as performance.
+_VLLM_PREFERRED_ARCHITECTURES: frozenset[str] = frozenset(
+    {"GptOssForCausalLM"})
+
 
 class UnsupportedArchitectureError(ValueError):
     """Raised when a model architecture is not supported in the registry."""
@@ -177,7 +200,23 @@ def _get_nnx_model(
     # the model creation again, otherwise the model forward will have
     # non-trivial overhead in PjitFunction.
     with mesh:
-        model.load_weights(rng)
+        loader = get_model_loader(vllm_config.load_config)
+        if isinstance(loader, RunaiModelStreamerLoader):
+            model_weights = vllm_config.model_config.model
+            if hasattr(vllm_config.model_config, "model_weights"):
+                model_weights = vllm_config.model_config.model_weights
+            weights_iterator = loader._get_weights_iterator(
+                model_weights, vllm_config.model_config.revision)
+            # We set the weights iterator at runtime, to prevent having to change
+            # every model's load_weights signature. This also prevents us from hitting
+            # a TypeError at runtime if you use the RunaiModelStreamerLoader with any
+            # flax_nnx model whose load_weights function does not accept the
+            # weights_iterator keyword argument.
+            vllm_config.model_config.model_weights_iterator = weights_iterator
+            model.load_weights(rng)
+            del vllm_config.model_config.model_weights_iterator
+        else:
+            model.load_weights(rng)
     jit_model = create_jit_model(
         model,
         use_qwix_on_abstract_model=should_apply_qwix_on_abstract_model)
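
The set/delete pair around model.load_weights(rng) is a general trick for smuggling an optional input past a fixed method signature: stash it on an object both sides can see, then clean it up. A minimal sketch of the same pattern, using a hypothetical temporarily_set helper and a stand-in config class (neither name exists in tpu-inference):

    import contextlib

    @contextlib.contextmanager
    def temporarily_set(obj, name, value):
        # Attach the attribute for the duration of the block, then remove it,
        # mirroring the set/del pair around model.load_weights(rng) above.
        setattr(obj, name, value)
        try:
            yield obj
        finally:
            delattr(obj, name)

    class FakeModelConfig:  # hypothetical stand-in for vLLM's ModelConfig
        pass

    cfg = FakeModelConfig()
    with temporarily_set(cfg, "model_weights_iterator", iter([("layer.w", None)])):
        assert hasattr(cfg, "model_weights_iterator")  # visible to load_weights
    assert not hasattr(cfg, "model_weights_iterator")  # cleaned up afterwards

A context manager also guarantees cleanup if load_weights raises, which the inline set/del in the hunk above does not.
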
@@ -191,6 +230,13 @@ def get_flax_model(
     mesh: Mesh,
     is_draft_model: bool = False,
 ) -> nnx.Module:
+    model_dtype = to_jax_dtype(vllm_config.model_config.dtype)
+    vllm_config.model_config.dtype = model_dtype
+
+    # Only perform qwix quantization if it is jax model.
+    if vllm_config.model_config:
+        update_vllm_config_for_qwix_quantization(vllm_config)
+
     if is_draft_model:
         model_class = _get_model_architecture(
             vllm_config.speculative_config.draft_model_config.hf_config)
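
get_flax_model now coerces model_config.dtype to a JAX dtype up front, and get_vllm_model (further below) mirrors this with to_torch_dtype, replacing the old one-way j2t_dtype fallback conversion. A table-driven sketch of what such converters might look like; the real to_jax_dtype/to_torch_dtype live in tpu_inference.utils and may handle more inputs (strings, numpy dtypes):

    import jax.numpy as jnp
    import torch

    # Hypothetical mapping; the _sketch suffix marks these as illustrations.
    _TORCH_TO_JAX = {
        torch.float32: jnp.float32,
        torch.float16: jnp.float16,
        torch.bfloat16: jnp.bfloat16,
    }
    _JAX_TO_TORCH = {v: k for k, v in _TORCH_TO_JAX.items()}

    def to_jax_dtype_sketch(dtype):
        return _TORCH_TO_JAX.get(dtype, dtype)  # pass through if already JAX-side

    def to_torch_dtype_sketch(dtype):
        return _JAX_TO_TORCH.get(dtype, dtype)  # pass through if already torch-side

    assert to_jax_dtype_sketch(torch.bfloat16) is jnp.bfloat16
    assert to_torch_dtype_sketch(jnp.bfloat16) is torch.bfloat16
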
@@ -199,7 +245,9 @@ def get_flax_model(
             vllm_config.model_config.hf_config)
     jit_model = _get_nnx_model(model_class, vllm_config, rng, mesh)
     kv_cache_sharding = NamedSharding(
-        mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, None, "model"))
+        mesh,
+        PartitionSpec(ShardingAxisName.ATTN_DATA, None,
+                      ShardingAxisName.ATTN_HEAD))
     hidden_states_sharding = NamedSharding(mesh,
                                            PartitionSpec(
                                                ShardingAxisName.ATTN_DATA,
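
The kv-cache spec drops the hard-coded mesh-axis string "model" in favor of the ShardingAxisName.ATTN_HEAD constant, so the partitioning survives mesh-axis renames. A self-contained sketch of how NamedSharding ties array dimensions to named mesh axes; the 2x4 CPU mesh and axis names here are illustrative, not the package's real mesh:

    import os
    # Make the host expose 8 fake devices so this runs on CPU.
    os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=8")

    import jax
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec

    devices = np.array(jax.devices()[:8]).reshape(2, 4)
    mesh = Mesh(devices, axis_names=("attn_dp", "attn_head"))

    # Dim 0 is split over the data axis, dim 2 over the head axis;
    # dim 1 (None) stays replicated, as in kv_cache_sharding above.
    spec = PartitionSpec("attn_dp", None, "attn_head")
    kv = jax.device_put(np.zeros((8, 16, 128), np.float32),
                        NamedSharding(mesh, spec))
    print(kv.sharding.spec)  # PartitionSpec('attn_dp', None, 'attn_head')
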
@@ -217,14 +265,17 @@
             hidden_states_sharding,  # aux hidden states
         ),
         donate_argnums=2,  # 0 is graphdef, 1 is state, 2 is kv_cache
-        static_argnums=7,  #7 is layer_name_to_kvcache_index
+        static_argnums=(
+            7, 10, 11
+        ),  #7 is layer_name_to_kvcache_index, 10 is is_first_rank, 11 is is_last_rank
     )
     def run_model(graphdef, state, *args):
         model = nnx.merge(graphdef, state)
         return model(*args)
 
     logits_sharding = NamedSharding(
-        mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, "model"))
+        mesh,
+        PartitionSpec(ShardingAxisName.MLP_DATA, ShardingAxisName.MLP_TENSOR))
 
     @functools.partial(
         jax.jit,
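
Arguments 10 and 11 (is_first_rank/is_last_rank) join static_argnums because they steer Python-level control flow during tracing; each distinct static value produces its own compiled executable, while donate_argnums lets XLA reuse donated input buffers for outputs. A toy sketch of both knobs (the function and shapes are made up, not the runner's signature):

    import functools

    import jax
    import jax.numpy as jnp

    @functools.partial(
        jax.jit,
        static_argnums=(2,),  # Python branch selector: must be static
        donate_argnums=(0,),  # arg 0's buffer may be reused for the output
    )
    def update(buf, x, is_add):
        # `is_add` is a concrete bool at trace time, so a plain `if` works.
        return buf + x if is_add else buf - x

    buf = jnp.zeros(4)
    buf = update(buf, jnp.ones(4), True)   # compiles for is_add=True
    buf = update(buf, jnp.ones(4), True)   # cache hit
    buf = update(buf, jnp.ones(4), False)  # new static value -> recompile
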
@@ -237,10 +288,9 @@
 
     # Multi-modal support only
     # This function calculates the image token's embeddings by VIT
-    def run_get_multimodal_embeddings(graphdef, state, image_grid_thw,
-                                      **kwargs):
+    def run_embed_multimodal(graphdef, state, image_grid_thw, **kwargs):
         model = nnx.merge(graphdef, state)
-        return model.get_multimodal_embeddings(image_grid_thw, **kwargs)
+        return model.embed_multimodal(image_grid_thw, **kwargs)
 
     embed_sharding = NamedSharding(mesh, PartitionSpec(None))
     # This function will calculates the embeddings of input texts and then merge with the image embeddings
@@ -248,9 +298,9 @@
         jax.jit,
         out_shardings=(embed_sharding),
     )
-    def run_get_input_embeddings(graphdef, state, *args, **kwargs):
+    def run_embed_input_ids(graphdef, state, *args, **kwargs):
         model = nnx.merge(graphdef, state)
-        return model.get_input_embeddings(*args, **kwargs)
+        return model.embed_input_ids(*args, **kwargs)
 
     # For models that want to work with EAGLE-3 speculative decoding
     @functools.partial(
@@ -266,10 +316,8 @@
                                  None)
     model_fn = functools.partial(run_model, graphdef)
     compute_logits_fn = functools.partial(run_compute_logits, graphdef)
-    get_multimodal_embeddings_fn = functools.partial(
-        run_get_multimodal_embeddings, graphdef)
-    get_input_embeddings_fn = functools.partial(run_get_input_embeddings,
-                                                graphdef)
+    embed_multimodal_fn = functools.partial(run_embed_multimodal, graphdef)
+    embed_input_ids_fn = functools.partial(run_embed_input_ids, graphdef)
     lora_manager, model = None, None
     combine_hidden_states_fn = functools.partial(combine_hidden_states,
                                                  graphdef)
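
All of these functools.partial(run_..., graphdef) bindings follow the standard flax.nnx idiom: nnx.split separates a module into a hashable graph definition and an array-only state pytree, the graphdef is bound once, and the module is re-merged inside the traced function. A compact, runnable sketch (Tiny is a made-up module, not one from the package):

    import functools

    import jax
    import jax.numpy as jnp
    from flax import nnx

    class Tiny(nnx.Module):
        def __init__(self, rngs: nnx.Rngs):
            self.linear = nnx.Linear(4, 4, rngs=rngs)

        def __call__(self, x):
            return self.linear(x)

    graphdef, state = nnx.split(Tiny(nnx.Rngs(0)))

    @jax.jit
    def run_model(graphdef, state, x):
        model = nnx.merge(graphdef, state)  # cheap re-assembly inside the trace
        return model(x)

    model_fn = functools.partial(run_model, graphdef)  # graphdef is static
    y = model_fn(state, jnp.ones((1, 4)))
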
@@ -280,8 +328,8 @@
 
     multimodal_fns = {
         "precompile_vision_encoder_fn": precompile_vision_encoder_fn,
-        "get_multimodal_embeddings_fn": get_multimodal_embeddings_fn,
-        "get_input_embeddings_fn": get_input_embeddings_fn,
+        "embed_multimodal_fn": embed_multimodal_fn,
+        "embed_input_ids_fn": embed_input_ids_fn,
         "get_mrope_input_positions_fn": get_mrope_input_positions_fn,
     }
 
@@ -293,6 +341,8 @@ def get_vllm_model(
     rng: jax.Array,
     mesh: Mesh,
 ):
+    model_dtype = to_torch_dtype(vllm_config.model_config.dtype)
+    vllm_config.model_config.dtype = model_dtype
     from tpu_inference.models.vllm.vllm_model_wrapper import VllmModelWrapper
 
     model = VllmModelWrapper(
@@ -318,24 +368,39 @@ def get_model(
     impl = envs.MODEL_IMPL_TYPE
     logger.info(f"Loading model with MODEL_IMPL_TYPE={impl}")
 
-    if impl == "flax_nnx":
-        try:
-            # Try to load the flax model first
-            return get_flax_model(vllm_config, rng, mesh, is_draft_model)
-        except UnsupportedArchitectureError as e:
-            # Convert the error message to a string to check its contents
-            error_msg = str(e)
-
-            logger.warning(error_msg)
-
-            # Fall back to the vLLM model and updating the dtype accordingly
-            vllm_config.model_config.dtype = j2t_dtype(
-                vllm_config.model_config.dtype.dtype)
+    if impl == "auto":
+        # Resolve "auto" based on architecture
+        architectures = getattr(vllm_config.model_config.hf_config,
+                                "architectures", [])
+        assert len(architectures) == 1, (
+            f"Expected exactly one architecture, got {len(architectures)}: "
+            f"{architectures}")
+        arch = architectures[0]
+        impl = "vllm" if arch in _VLLM_PREFERRED_ARCHITECTURES else "flax_nnx"
+        logger.info(f"Resolved MODEL_IMPL_TYPE 'auto' to '{impl}'")
+
+    match impl:
+        case "flax_nnx":
+            if vllm_config.parallel_config.pipeline_parallel_size > 1:
+                logger.warning(
+                    "PP is not fully supported on Jax flax_nnx models yet, fallback to vllm models."
+                )
+                return get_vllm_model(vllm_config, rng, mesh)
+            try:
+                # Try to load the flax model first
+                return get_flax_model(vllm_config, rng, mesh, is_draft_model)
+            except UnsupportedArchitectureError as e:
+                # Convert the error message to a string to check its contents
+                error_msg = str(e)
+
+                logger.warning(error_msg)
+
+                # Fall back to the vLLM model and updating the dtype accordingly
+                return get_vllm_model(vllm_config, rng, mesh)
+        case "vllm":
             return get_vllm_model(vllm_config, rng, mesh)
-    elif impl == "vllm":
-        return get_vllm_model(vllm_config, rng, mesh)
-    else:
-        raise NotImplementedError("Unsupported MODEL_IMPL_TYPE")
+        case _:
+            raise NotImplementedError(f"Unsupported MODEL_IMPL_TYPE: {impl}")
 
 
 def _validate_model_interface(model: Any) -> None:
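
The "auto" branch boils down to a pure lookup: exactly one HF architecture is expected, and membership in _VLLM_PREFERRED_ARCHITECTURES decides the backend. The rule in isolation, as a hypothetical resolve_impl helper (not a function in the package):

    _VLLM_PREFERRED = frozenset({"GptOssForCausalLM"})

    def resolve_impl(impl: str, architectures: list[str]) -> str:
        if impl != "auto":
            return impl
        assert len(architectures) == 1, architectures
        return "vllm" if architectures[0] in _VLLM_PREFERRED else "flax_nnx"

    assert resolve_impl("auto", ["GptOssForCausalLM"]) == "vllm"
    assert resolve_impl("auto", ["LlamaForCausalLM"]) == "flax_nnx"
    assert resolve_impl("flax_nnx", ["GptOssForCausalLM"]) == "flax_nnx"  # explicit wins
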
@@ -421,6 +486,17 @@ def register_model(arch: str, model: Any) -> None:
             "This is a JAX model and does not implement the PyTorch forward method."
         )
 
+    # Same as `forward`, this is a dummy method to satisfy vLLM's type checks.
+    def unimplemented_embed_input_ids(
+        self,
+        input_ids: "torch.Tensor",
+        positions: "torch.Tensor",
+        inputs_embeds: Optional["torch.Tensor"] = None,
+    ) -> "torch.Tensor":
+        raise NotImplementedError(
+            "This is a JAX model and does not implement the PyTorch embed_input_ids method."
+        )
+
     # We need a custom __init__ that only calls torch.nn.Module's init,
     # to avoid triggering JAX logic when vLLM inspects the class.
     def wrapper_init(self, *args, **kwargs):
@@ -434,6 +510,7 @@ def register_model(arch: str, model: Any) -> None:
         {
             "__init__": wrapper_init,
             "forward": unimplemented_forward,
+            "embed_input_ids": unimplemented_embed_input_ids,
             # Prevent vLLM from trying to load weights into this dummy class.
             "load_weights": lambda self, *args, **kwargs: None,
         })
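
register_model keeps vLLM's registry happy by synthesizing a torch.nn.Module subclass at runtime with type(); every required method exists but either no-ops or raises, since the real model runs under JAX. A stripped-down sketch of the trick (make_dummy and its messages are illustrative, not the package's code):

    import torch

    def make_dummy(arch: str) -> type:
        def wrapper_init(self, *args, **kwargs):
            torch.nn.Module.__init__(self)  # skip any heavier framework init

        def unimplemented_forward(self, *args, **kwargs):
            raise NotImplementedError(f"{arch} is served by JAX, not PyTorch.")

        return type(
            arch, (torch.nn.Module,), {
                "__init__": wrapper_init,
                "forward": unimplemented_forward,
                "load_weights": lambda self, *args, **kwargs: None,
            })

    Dummy = make_dummy("MyJaxModelForCausalLM")
    m = Dummy()       # instantiable for registry/type checks
    m.load_weights()  # deliberately a no-op
    # m.forward(...)  # would raise NotImplementedError by design
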
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.