tpu-inference 0.11.1.dev202512030818__py3-none-any.whl → 0.13.0rc2.post7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic. See the registry's advisory for this release for more details.

Files changed (250)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +14 -0
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +405 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +312 -0
  64. tests/layers/vllm/test_unquantized.py +651 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +21 -3
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +78 -1
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +1 -43
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +14 -9
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +38 -7
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +17 -0
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +95 -78
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +26 -19
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/quant_methods.py +15 -0
  154. tpu_inference/layers/common/quantization.py +270 -0
  155. tpu_inference/layers/common/sharding.py +28 -5
  156. tpu_inference/layers/jax/__init__.py +13 -0
  157. tpu_inference/layers/jax/attention/__init__.py +13 -0
  158. tpu_inference/layers/jax/attention/attention.py +19 -6
  159. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  160. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  161. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  162. tpu_inference/layers/jax/base.py +14 -0
  163. tpu_inference/layers/jax/constants.py +13 -0
  164. tpu_inference/layers/jax/layers.py +14 -0
  165. tpu_inference/layers/jax/misc.py +14 -0
  166. tpu_inference/layers/jax/moe/__init__.py +13 -0
  167. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  168. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  169. tpu_inference/layers/jax/moe/moe.py +43 -3
  170. tpu_inference/layers/jax/pp_utils.py +53 -0
  171. tpu_inference/layers/jax/rope.py +14 -0
  172. tpu_inference/layers/jax/rope_interface.py +14 -0
  173. tpu_inference/layers/jax/sample/__init__.py +13 -0
  174. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  175. tpu_inference/layers/jax/sample/sampling.py +15 -1
  176. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  177. tpu_inference/layers/jax/transformer_block.py +14 -0
  178. tpu_inference/layers/vllm/__init__.py +13 -0
  179. tpu_inference/layers/vllm/attention.py +4 -4
  180. tpu_inference/layers/vllm/fused_moe.py +210 -260
  181. tpu_inference/layers/vllm/linear_common.py +57 -22
  182. tpu_inference/layers/vllm/quantization/__init__.py +16 -0
  183. tpu_inference/layers/vllm/quantization/awq.py +15 -1
  184. tpu_inference/layers/vllm/quantization/common.py +33 -18
  185. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  186. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
  187. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
  188. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  189. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
  191. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  192. tpu_inference/layers/vllm/quantization/mxfp4.py +278 -209
  193. tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
  194. tpu_inference/layers/vllm/sharding.py +21 -4
  195. tpu_inference/lora/__init__.py +13 -0
  196. tpu_inference/lora/torch_lora_ops.py +8 -13
  197. tpu_inference/models/__init__.py +13 -0
  198. tpu_inference/models/common/__init__.py +13 -0
  199. tpu_inference/models/common/model_loader.py +74 -35
  200. tpu_inference/models/jax/__init__.py +13 -0
  201. tpu_inference/models/jax/deepseek_v3.py +267 -157
  202. tpu_inference/models/jax/gpt_oss.py +26 -10
  203. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  204. tpu_inference/models/jax/llama3.py +99 -36
  205. tpu_inference/models/jax/llama4.py +14 -0
  206. tpu_inference/models/jax/llama_eagle3.py +14 -0
  207. tpu_inference/models/jax/llama_guard_4.py +15 -1
  208. tpu_inference/models/jax/qwen2.py +17 -2
  209. tpu_inference/models/jax/qwen2_5_vl.py +18 -4
  210. tpu_inference/models/jax/qwen3.py +17 -2
  211. tpu_inference/models/jax/utils/__init__.py +13 -0
  212. tpu_inference/models/jax/utils/file_utils.py +14 -0
  213. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  214. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  215. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +88 -25
  216. tpu_inference/models/jax/utils/weight_utils.py +39 -2
  217. tpu_inference/models/vllm/__init__.py +13 -0
  218. tpu_inference/models/vllm/vllm_model_wrapper.py +20 -3
  219. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  220. tpu_inference/platforms/__init__.py +14 -0
  221. tpu_inference/platforms/tpu_platform.py +47 -64
  222. tpu_inference/runner/__init__.py +13 -0
  223. tpu_inference/runner/compilation_manager.py +72 -37
  224. tpu_inference/runner/kv_cache.py +54 -20
  225. tpu_inference/runner/kv_cache_manager.py +45 -15
  226. tpu_inference/runner/lora_utils.py +14 -0
  227. tpu_inference/runner/multimodal_manager.py +15 -1
  228. tpu_inference/runner/persistent_batch_manager.py +14 -0
  229. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  230. tpu_inference/runner/structured_decoding_manager.py +14 -0
  231. tpu_inference/runner/tpu_runner.py +41 -16
  232. tpu_inference/spec_decode/__init__.py +13 -0
  233. tpu_inference/spec_decode/jax/__init__.py +13 -0
  234. tpu_inference/spec_decode/jax/eagle3.py +13 -0
  235. tpu_inference/tpu_info.py +14 -0
  236. tpu_inference/utils.py +42 -36
  237. tpu_inference/worker/__init__.py +13 -0
  238. tpu_inference/worker/tpu_worker.py +63 -50
  239. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/METADATA +11 -9
  240. tpu_inference-0.13.0rc2.post7.dist-info/RECORD +261 -0
  241. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  242. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  243. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  244. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  245. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  246. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  247. tpu_inference-0.11.1.dev202512030818.dist-info/RECORD +0 -174
  248. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/WHEEL +0 -0
  249. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/licenses/LICENSE +0 -0
  250. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/top_level.txt +0 -0
@@ -1,3 +1,17 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
1
15
  import functools
2
16
  from typing import Any, Optional
3
17
 
@@ -5,7 +19,6 @@ import jax
5
19
  import torch
6
20
  from flax import nnx
7
21
  from jax.sharding import Mesh, NamedSharding, PartitionSpec
8
- from torchax.ops.mappings import j2t_dtype
9
22
  from transformers import PretrainedConfig
10
23
  from vllm.config import VllmConfig
11
24
  from vllm.model_executor.model_loader import get_model_loader
@@ -16,14 +29,20 @@ from vllm.utils.func_utils import supports_kw
16
29
  from tpu_inference import envs
17
30
  from tpu_inference.layers.common.sharding import ShardingAxisName
18
31
  from tpu_inference.logger import init_logger
19
- from tpu_inference.models.jax.utils.quantization.quantization_utils import (
32
+ from tpu_inference.models.jax.utils.qwix.qwix_utils import (
20
33
  apply_qwix_on_abstract_model, apply_qwix_quantization,
21
34
  load_random_weights_into_qwix_abstract_model)
35
+ from tpu_inference.utils import to_jax_dtype, to_torch_dtype
22
36
 
23
37
  logger = init_logger(__name__)
24
38
 
25
39
  _MODEL_REGISTRY = {}
26
40
 
41
+ # List of architectures that are preferred to use "vllm" implementation over
42
+ # "flax_nnx" implementation due to various factors such as performance.
43
+ _VLLM_PREFERRED_ARCHITECTURES: frozenset[str] = frozenset(
44
+ {"GptOssForCausalLM"})
45
+
27
46
 
28
47
  class UnsupportedArchitectureError(ValueError):
29
48
  """Raised when a model architecture is not supported in the registry."""
@@ -210,6 +229,9 @@ def get_flax_model(
210
229
  mesh: Mesh,
211
230
  is_draft_model: bool = False,
212
231
  ) -> nnx.Module:
232
+ model_dtype = to_jax_dtype(vllm_config.model_config.dtype)
233
+ vllm_config.model_config.dtype = model_dtype
234
+
213
235
  if is_draft_model:
214
236
  model_class = _get_model_architecture(
215
237
  vllm_config.speculative_config.draft_model_config.hf_config)
@@ -218,7 +240,9 @@ def get_flax_model(
218
240
  vllm_config.model_config.hf_config)
219
241
  jit_model = _get_nnx_model(model_class, vllm_config, rng, mesh)
220
242
  kv_cache_sharding = NamedSharding(
221
- mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, None, "model"))
243
+ mesh,
244
+ PartitionSpec(ShardingAxisName.ATTN_DATA, None,
245
+ ShardingAxisName.ATTN_HEAD))
222
246
  hidden_states_sharding = NamedSharding(mesh,
223
247
  PartitionSpec(
224
248
  ShardingAxisName.ATTN_DATA,
@@ -245,7 +269,8 @@ def get_flax_model(
245
269
  return model(*args)
246
270
 
247
271
  logits_sharding = NamedSharding(
248
- mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, "model"))
272
+ mesh,
273
+ PartitionSpec(ShardingAxisName.MLP_DATA, ShardingAxisName.MLP_TENSOR))
249
274
 
250
275
  @functools.partial(
251
276
  jax.jit,
@@ -258,10 +283,9 @@ def get_flax_model(
258
283
 
259
284
  # Multi-modal support only
260
285
  # This function calculates the image token's embeddings by VIT
261
- def run_get_multimodal_embeddings(graphdef, state, image_grid_thw,
262
- **kwargs):
286
+ def run_embed_multimodal(graphdef, state, image_grid_thw, **kwargs):
263
287
  model = nnx.merge(graphdef, state)
264
- return model.get_multimodal_embeddings(image_grid_thw, **kwargs)
288
+ return model.embed_multimodal(image_grid_thw, **kwargs)
265
289
 
266
290
  embed_sharding = NamedSharding(mesh, PartitionSpec(None))
267
291
  # This function will calculates the embeddings of input texts and then merge with the image embeddings
@@ -269,9 +293,9 @@ def get_flax_model(
269
293
  jax.jit,
270
294
  out_shardings=(embed_sharding),
271
295
  )
272
- def run_get_input_embeddings(graphdef, state, *args, **kwargs):
296
+ def run_embed_input_ids(graphdef, state, *args, **kwargs):
273
297
  model = nnx.merge(graphdef, state)
274
- return model.get_input_embeddings(*args, **kwargs)
298
+ return model.embed_input_ids(*args, **kwargs)
275
299
 
276
300
  # For models that want to work with EAGLE-3 speculative decoding
277
301
  @functools.partial(
@@ -287,10 +311,8 @@ def get_flax_model(
287
311
  None)
288
312
  model_fn = functools.partial(run_model, graphdef)
289
313
  compute_logits_fn = functools.partial(run_compute_logits, graphdef)
290
- get_multimodal_embeddings_fn = functools.partial(
291
- run_get_multimodal_embeddings, graphdef)
292
- get_input_embeddings_fn = functools.partial(run_get_input_embeddings,
293
- graphdef)
314
+ embed_multimodal_fn = functools.partial(run_embed_multimodal, graphdef)
315
+ embed_input_ids_fn = functools.partial(run_embed_input_ids, graphdef)
294
316
  lora_manager, model = None, None
295
317
  combine_hidden_states_fn = functools.partial(combine_hidden_states,
296
318
  graphdef)
@@ -301,8 +323,8 @@ def get_flax_model(
301
323
 
302
324
  multimodal_fns = {
303
325
  "precompile_vision_encoder_fn": precompile_vision_encoder_fn,
304
- "get_multimodal_embeddings_fn": get_multimodal_embeddings_fn,
305
- "get_input_embeddings_fn": get_input_embeddings_fn,
326
+ "embed_multimodal_fn": embed_multimodal_fn,
327
+ "embed_input_ids_fn": embed_input_ids_fn,
306
328
  "get_mrope_input_positions_fn": get_mrope_input_positions_fn,
307
329
  }
308
330
 
@@ -314,6 +336,8 @@ def get_vllm_model(
314
336
  rng: jax.Array,
315
337
  mesh: Mesh,
316
338
  ):
339
+ model_dtype = to_torch_dtype(vllm_config.model_config.dtype)
340
+ vllm_config.model_config.dtype = model_dtype
317
341
  from tpu_inference.models.vllm.vllm_model_wrapper import VllmModelWrapper
318
342
 
319
343
  model = VllmModelWrapper(
@@ -339,24 +363,39 @@ def get_model(
339
363
  impl = envs.MODEL_IMPL_TYPE
340
364
  logger.info(f"Loading model with MODEL_IMPL_TYPE={impl}")
341
365
 
342
- if impl == "flax_nnx":
343
- try:
344
- # Try to load the flax model first
345
- return get_flax_model(vllm_config, rng, mesh, is_draft_model)
346
- except UnsupportedArchitectureError as e:
347
- # Convert the error message to a string to check its contents
348
- error_msg = str(e)
349
-
350
- logger.warning(error_msg)
351
-
352
- # Fall back to the vLLM model and updating the dtype accordingly
353
- vllm_config.model_config.dtype = j2t_dtype(
354
- vllm_config.model_config.dtype.dtype)
366
+ if impl == "auto":
367
+ # Resolve "auto" based on architecture
368
+ architectures = getattr(vllm_config.model_config.hf_config,
369
+ "architectures", [])
370
+ assert len(architectures) == 1, (
371
+ f"Expected exactly one architecture, got {len(architectures)}: "
372
+ f"{architectures}")
373
+ arch = architectures[0]
374
+ impl = "vllm" if arch in _VLLM_PREFERRED_ARCHITECTURES else "flax_nnx"
375
+ logger.info(f"Resolved MODEL_IMPL_TYPE 'auto' to '{impl}'")
376
+
377
+ match impl:
378
+ case "flax_nnx":
379
+ if vllm_config.parallel_config.pipeline_parallel_size > 1:
380
+ logger.warning(
381
+ "PP is not fully supported on Jax flax_nnx models yet, fallback to vllm models."
382
+ )
383
+ return get_vllm_model(vllm_config, rng, mesh)
384
+ try:
385
+ # Try to load the flax model first
386
+ return get_flax_model(vllm_config, rng, mesh, is_draft_model)
387
+ except UnsupportedArchitectureError as e:
388
+ # Convert the error message to a string to check its contents
389
+ error_msg = str(e)
390
+
391
+ logger.warning(error_msg)
392
+
393
+ # Fall back to the vLLM model and updating the dtype accordingly
394
+ return get_vllm_model(vllm_config, rng, mesh)
395
+ case "vllm":
355
396
  return get_vllm_model(vllm_config, rng, mesh)
356
- elif impl == "vllm":
357
- return get_vllm_model(vllm_config, rng, mesh)
358
- else:
359
- raise NotImplementedError("Unsupported MODEL_IMPL_TYPE")
397
+ case _:
398
+ raise NotImplementedError(f"Unsupported MODEL_IMPL_TYPE: {impl}")
360
399
 
361
400
 
362
401
  def _validate_model_interface(model: Any) -> None:
@@ -443,14 +482,14 @@ def register_model(arch: str, model: Any) -> None:
443
482
  )
444
483
 
445
484
  # Same as `forward`, this is a dummy method to satisfy vLLM's type checks.
446
- def unimplemented_get_input_embeddings(
485
+ def unimplemented_embed_input_ids(
447
486
  self,
448
487
  input_ids: "torch.Tensor",
449
488
  positions: "torch.Tensor",
450
489
  inputs_embeds: Optional["torch.Tensor"] = None,
451
490
  ) -> "torch.Tensor":
452
491
  raise NotImplementedError(
453
- "This is a JAX model and does not implement the PyTorch get_input_embeddings method."
492
+ "This is a JAX model and does not implement the PyTorch embed_input_ids method."
454
493
  )
455
494
 
456
495
  # We need a custom __init__ that only calls torch.nn.Module's init,
@@ -466,7 +505,7 @@ def register_model(arch: str, model: Any) -> None:
466
505
  {
467
506
  "__init__": wrapper_init,
468
507
  "forward": unimplemented_forward,
469
- "get_input_embeddings": unimplemented_get_input_embeddings,
508
+ "embed_input_ids": unimplemented_embed_input_ids,
470
509
  # Prevent vLLM from trying to load weights into this dummy class.
471
510
  "load_weights": lambda self, *args, **kwargs: None,
472
511
  })
@@ -0,0 +1,13 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.