tpu-inference 0.11.1.dev202511270815__py3-none-any.whl → 0.13.0rc2.post7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (251)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +14 -0
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +405 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +312 -0
  64. tests/layers/vllm/test_unquantized.py +651 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +21 -3
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +110 -12
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +2 -45
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +15 -10
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +92 -8
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +22 -1
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +167 -97
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +26 -19
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/quant_methods.py +15 -0
  154. tpu_inference/layers/common/quantization.py +270 -0
  155. tpu_inference/layers/common/sharding.py +31 -9
  156. tpu_inference/layers/jax/__init__.py +13 -0
  157. tpu_inference/layers/jax/attention/__init__.py +13 -0
  158. tpu_inference/layers/jax/attention/attention.py +19 -6
  159. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  160. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  161. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  162. tpu_inference/layers/jax/base.py +14 -0
  163. tpu_inference/layers/jax/constants.py +13 -0
  164. tpu_inference/layers/jax/layers.py +14 -0
  165. tpu_inference/layers/jax/misc.py +14 -0
  166. tpu_inference/layers/jax/moe/__init__.py +13 -0
  167. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  168. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  169. tpu_inference/layers/jax/moe/moe.py +43 -3
  170. tpu_inference/layers/jax/pp_utils.py +53 -0
  171. tpu_inference/layers/jax/rope.py +14 -0
  172. tpu_inference/layers/jax/rope_interface.py +14 -0
  173. tpu_inference/layers/jax/sample/__init__.py +13 -0
  174. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  175. tpu_inference/layers/jax/sample/sampling.py +15 -1
  176. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  177. tpu_inference/layers/jax/transformer_block.py +14 -0
  178. tpu_inference/layers/vllm/__init__.py +13 -0
  179. tpu_inference/layers/vllm/attention.py +4 -4
  180. tpu_inference/layers/vllm/fused_moe.py +210 -260
  181. tpu_inference/layers/vllm/linear_common.py +57 -22
  182. tpu_inference/layers/vllm/quantization/__init__.py +16 -0
  183. tpu_inference/layers/vllm/quantization/awq.py +15 -1
  184. tpu_inference/layers/vllm/quantization/common.py +33 -18
  185. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  186. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
  187. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
  188. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  189. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
  191. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  192. tpu_inference/layers/vllm/quantization/mxfp4.py +280 -210
  193. tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
  194. tpu_inference/layers/vllm/sharding.py +21 -4
  195. tpu_inference/lora/__init__.py +13 -0
  196. tpu_inference/lora/torch_lora_ops.py +8 -13
  197. tpu_inference/models/__init__.py +13 -0
  198. tpu_inference/models/common/__init__.py +13 -0
  199. tpu_inference/models/common/model_loader.py +77 -36
  200. tpu_inference/models/jax/__init__.py +13 -0
  201. tpu_inference/models/jax/deepseek_v3.py +267 -157
  202. tpu_inference/models/jax/gpt_oss.py +26 -10
  203. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  204. tpu_inference/models/jax/llama3.py +99 -36
  205. tpu_inference/models/jax/llama4.py +14 -0
  206. tpu_inference/models/jax/llama_eagle3.py +14 -0
  207. tpu_inference/models/jax/llama_guard_4.py +15 -1
  208. tpu_inference/models/jax/qwen2.py +17 -2
  209. tpu_inference/models/jax/qwen2_5_vl.py +18 -4
  210. tpu_inference/models/jax/qwen3.py +17 -2
  211. tpu_inference/models/jax/utils/__init__.py +13 -0
  212. tpu_inference/models/jax/utils/file_utils.py +14 -0
  213. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  214. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  215. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +91 -31
  216. tpu_inference/models/jax/utils/weight_utils.py +39 -2
  217. tpu_inference/models/vllm/__init__.py +13 -0
  218. tpu_inference/models/vllm/vllm_model_wrapper.py +20 -4
  219. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  220. tpu_inference/platforms/__init__.py +14 -0
  221. tpu_inference/platforms/tpu_platform.py +47 -71
  222. tpu_inference/runner/__init__.py +13 -0
  223. tpu_inference/runner/compilation_manager.py +158 -63
  224. tpu_inference/runner/kv_cache.py +54 -20
  225. tpu_inference/runner/kv_cache_manager.py +53 -30
  226. tpu_inference/runner/lora_utils.py +14 -0
  227. tpu_inference/runner/multimodal_manager.py +15 -1
  228. tpu_inference/runner/persistent_batch_manager.py +54 -2
  229. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  230. tpu_inference/runner/structured_decoding_manager.py +14 -0
  231. tpu_inference/runner/tpu_runner.py +105 -57
  232. tpu_inference/runner/utils.py +2 -2
  233. tpu_inference/spec_decode/__init__.py +13 -0
  234. tpu_inference/spec_decode/jax/__init__.py +13 -0
  235. tpu_inference/spec_decode/jax/eagle3.py +65 -19
  236. tpu_inference/tpu_info.py +14 -0
  237. tpu_inference/utils.py +72 -44
  238. tpu_inference/worker/__init__.py +13 -0
  239. tpu_inference/worker/tpu_worker.py +65 -52
  240. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/METADATA +11 -9
  241. tpu_inference-0.13.0rc2.post7.dist-info/RECORD +261 -0
  242. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  243. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  244. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  245. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  246. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  247. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  248. tpu_inference-0.11.1.dev202511270815.dist-info/RECORD +0 -174
  249. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/WHEEL +0 -0
  250. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/licenses/LICENSE +0 -0
  251. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/top_level.txt +0 -0
tpu_inference/models/common/model_loader.py (+77 -36)

@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import functools
 from typing import Any, Optional
 
@@ -5,7 +19,6 @@ import jax
 import torch
 from flax import nnx
 from jax.sharding import Mesh, NamedSharding, PartitionSpec
-from torchax.ops.mappings import j2t_dtype
 from transformers import PretrainedConfig
 from vllm.config import VllmConfig
 from vllm.model_executor.model_loader import get_model_loader
@@ -16,14 +29,20 @@ from vllm.utils.func_utils import supports_kw
 from tpu_inference import envs
 from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.logger import init_logger
-from tpu_inference.models.jax.utils.quantization.quantization_utils import (
+from tpu_inference.models.jax.utils.qwix.qwix_utils import (
     apply_qwix_on_abstract_model, apply_qwix_quantization,
     load_random_weights_into_qwix_abstract_model)
+from tpu_inference.utils import to_jax_dtype, to_torch_dtype
 
 logger = init_logger(__name__)
 
 _MODEL_REGISTRY = {}
 
+# List of architectures that are preferred to use "vllm" implementation over
+# "flax_nnx" implementation due to various factors such as performance.
+_VLLM_PREFERRED_ARCHITECTURES: frozenset[str] = frozenset(
+    {"GptOssForCausalLM"})
+
 
 class UnsupportedArchitectureError(ValueError):
     """Raised when a model architecture is not supported in the registry."""
@@ -210,6 +229,9 @@ def get_flax_model(
     mesh: Mesh,
     is_draft_model: bool = False,
 ) -> nnx.Module:
+    model_dtype = to_jax_dtype(vllm_config.model_config.dtype)
+    vllm_config.model_config.dtype = model_dtype
+
     if is_draft_model:
         model_class = _get_model_architecture(
             vllm_config.speculative_config.draft_model_config.hf_config)
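With this release, `get_flax_model` coerces `model_config.dtype` to a JAX dtype up front, and `get_vllm_model` (further below) coerces it to a torch dtype, which is what allowed the `torchax` `j2t_dtype` import and the fallback-path conversion to be deleted. A plausible sketch of such a helper pair follows; the names match the imports above, but the bodies are assumptions, not the actual `tpu_inference.utils` code:

```python
# Hypothetical sketch only; the real to_jax_dtype / to_torch_dtype live in
# tpu_inference/utils.py and may handle more cases (strings, numpy dtypes).
import jax.numpy as jnp
import torch

_TORCH_TO_JAX = {
    torch.float32: jnp.float32,
    torch.float16: jnp.float16,
    torch.bfloat16: jnp.bfloat16,
}
_JAX_TO_TORCH = {v: k for k, v in _TORCH_TO_JAX.items()}


def to_jax_dtype(dtype):
    """Return the JAX equivalent of a torch.dtype; pass JAX dtypes through."""
    if isinstance(dtype, torch.dtype):
        return _TORCH_TO_JAX[dtype]
    return dtype


def to_torch_dtype(dtype):
    """Return the torch equivalent of a JAX dtype; pass torch dtypes through."""
    if isinstance(dtype, torch.dtype):
        return dtype
    return _JAX_TO_TORCH[dtype]
```

Whatever the real implementation, the effect visible in this diff is that dtype normalization now happens once at each entry point instead of only on the flax-to-vLLM fallback path.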
@@ -218,7 +240,9 @@ def get_flax_model(
         vllm_config.model_config.hf_config)
     jit_model = _get_nnx_model(model_class, vllm_config, rng, mesh)
     kv_cache_sharding = NamedSharding(
-        mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, None, "model"))
+        mesh,
+        PartitionSpec(ShardingAxisName.ATTN_DATA, None,
+                      ShardingAxisName.ATTN_HEAD))
     hidden_states_sharding = NamedSharding(mesh,
                                            PartitionSpec(
                                                ShardingAxisName.ATTN_DATA,
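This hunk swaps the hard-coded mesh-axis string "model" for the named constant `ShardingAxisName.ATTN_HEAD`, so the KV cache's head dimension follows the same axis-naming scheme as the rest of the sharding code. For readers unfamiliar with the pattern, here is a self-contained JAX example of the same `NamedSharding`/`PartitionSpec` mechanics; the "data"/"model" axis names are illustrative stand-ins, not the actual `ShardingAxisName` values:

```python
# Standalone illustration of the NamedSharding pattern used above.
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

devices = np.array(jax.devices()).reshape(-1, 1)
mesh = Mesh(devices, axis_names=("data", "model"))

# Shard dim 0 over the "data" axis, replicate dim 1, shard dim 2 over
# "model"; this mirrors PartitionSpec(ATTN_DATA, None, ATTN_HEAD).
sharding = NamedSharding(mesh, PartitionSpec("data", None, "model"))

x = jax.device_put(np.zeros((8, 4, 16), dtype=np.float32), sharding)
print(x.sharding)  # shows the mesh and spec ('data', None, 'model')
```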
@@ -236,14 +260,17 @@ def get_flax_model(
             hidden_states_sharding,  # aux hidden states
         ),
         donate_argnums=2,  # 0 is graphdef, 1 is state, 2 is kv_cache
-        static_argnums=7,  #7 is layer_name_to_kvcache_index
+        static_argnums=(
+            7, 10, 11
+        ),  #7 is layer_name_to_kvcache_index, 10 is is_first_rank, 11 is is_last_rank
     )
     def run_model(graphdef, state, *args):
         model = nnx.merge(graphdef, state)
         return model(*args)
 
     logits_sharding = NamedSharding(
-        mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, "model"))
+        mesh,
+        PartitionSpec(ShardingAxisName.MLP_DATA, ShardingAxisName.MLP_TENSOR))
 
     @functools.partial(
         jax.jit,
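`static_argnums` grows from `7` to `(7, 10, 11)` here: the new `is_first_rank` and `is_last_rank` pipeline-parallel flags presumably drive Python-level branching inside the model, so `jax.jit` must specialize (and cache) one compiled graph per flag combination rather than trace the flags as array values. A minimal illustration of the mechanism, unrelated to the tpu-inference code itself:

```python
import functools

import jax
import jax.numpy as jnp


@functools.partial(jax.jit, static_argnums=(1,))
def stage(x, is_first_rank):
    # A plain Python `if` on an argument is only allowed when that argument
    # is static; on a traced value it would raise TracerBoolConversionError.
    if is_first_rank:
        x = x * 2.0  # e.g. work done only by the first pipeline stage
    return x + 1.0


print(stage(jnp.ones(4), True))   # compiles the first-rank variant
print(stage(jnp.ones(4), False))  # compiles and caches a second variant
```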
@@ -256,10 +283,9 @@ def get_flax_model(
 
     # Multi-modal support only
     # This function calculates the image token's embeddings by VIT
-    def run_get_multimodal_embeddings(graphdef, state, image_grid_thw,
-                                      **kwargs):
+    def run_embed_multimodal(graphdef, state, image_grid_thw, **kwargs):
         model = nnx.merge(graphdef, state)
-        return model.get_multimodal_embeddings(image_grid_thw, **kwargs)
+        return model.embed_multimodal(image_grid_thw, **kwargs)
 
     embed_sharding = NamedSharding(mesh, PartitionSpec(None))
     # This function will calculates the embeddings of input texts and then merge with the image embeddings
@@ -267,9 +293,9 @@ def get_flax_model(
         jax.jit,
         out_shardings=(embed_sharding),
     )
-    def run_get_input_embeddings(graphdef, state, *args, **kwargs):
+    def run_embed_input_ids(graphdef, state, *args, **kwargs):
         model = nnx.merge(graphdef, state)
-        return model.get_input_embeddings(*args, **kwargs)
+        return model.embed_input_ids(*args, **kwargs)
 
     # For models that want to work with EAGLE-3 speculative decoding
     @functools.partial(
@@ -285,10 +311,8 @@ def get_flax_model(
         None)
     model_fn = functools.partial(run_model, graphdef)
     compute_logits_fn = functools.partial(run_compute_logits, graphdef)
-    get_multimodal_embeddings_fn = functools.partial(
-        run_get_multimodal_embeddings, graphdef)
-    get_input_embeddings_fn = functools.partial(run_get_input_embeddings,
-                                                graphdef)
+    embed_multimodal_fn = functools.partial(run_embed_multimodal, graphdef)
+    embed_input_ids_fn = functools.partial(run_embed_input_ids, graphdef)
     lora_manager, model = None, None
     combine_hidden_states_fn = functools.partial(combine_hidden_states,
                                                  graphdef)
@@ -299,8 +323,8 @@ def get_flax_model(
 
     multimodal_fns = {
         "precompile_vision_encoder_fn": precompile_vision_encoder_fn,
-        "get_multimodal_embeddings_fn": get_multimodal_embeddings_fn,
-        "get_input_embeddings_fn": get_input_embeddings_fn,
+        "embed_multimodal_fn": embed_multimodal_fn,
+        "embed_input_ids_fn": embed_input_ids_fn,
         "get_mrope_input_positions_fn": get_mrope_input_positions_fn,
     }
 
@@ -312,6 +336,8 @@ def get_vllm_model(
     rng: jax.Array,
     mesh: Mesh,
 ):
+    model_dtype = to_torch_dtype(vllm_config.model_config.dtype)
+    vllm_config.model_config.dtype = model_dtype
     from tpu_inference.models.vllm.vllm_model_wrapper import VllmModelWrapper
 
     model = VllmModelWrapper(
@@ -337,24 +363,39 @@ def get_model(
     impl = envs.MODEL_IMPL_TYPE
     logger.info(f"Loading model with MODEL_IMPL_TYPE={impl}")
 
-    if impl == "flax_nnx":
-        try:
-            # Try to load the flax model first
-            return get_flax_model(vllm_config, rng, mesh, is_draft_model)
-        except UnsupportedArchitectureError as e:
-            # Convert the error message to a string to check its contents
-            error_msg = str(e)
-
-            logger.warning(error_msg)
-
-            # Fall back to the vLLM model and updating the dtype accordingly
-            vllm_config.model_config.dtype = j2t_dtype(
-                vllm_config.model_config.dtype.dtype)
+    if impl == "auto":
+        # Resolve "auto" based on architecture
+        architectures = getattr(vllm_config.model_config.hf_config,
+                                "architectures", [])
+        assert len(architectures) == 1, (
+            f"Expected exactly one architecture, got {len(architectures)}: "
+            f"{architectures}")
+        arch = architectures[0]
+        impl = "vllm" if arch in _VLLM_PREFERRED_ARCHITECTURES else "flax_nnx"
+        logger.info(f"Resolved MODEL_IMPL_TYPE 'auto' to '{impl}'")
+
+    match impl:
+        case "flax_nnx":
+            if vllm_config.parallel_config.pipeline_parallel_size > 1:
+                logger.warning(
+                    "PP is not fully supported on Jax flax_nnx models yet, fallback to vllm models."
+                )
+                return get_vllm_model(vllm_config, rng, mesh)
+            try:
+                # Try to load the flax model first
+                return get_flax_model(vllm_config, rng, mesh, is_draft_model)
+            except UnsupportedArchitectureError as e:
+                # Convert the error message to a string to check its contents
+                error_msg = str(e)
+
+                logger.warning(error_msg)
+
+                # Fall back to the vLLM model and updating the dtype accordingly
+                return get_vllm_model(vllm_config, rng, mesh)
+        case "vllm":
             return get_vllm_model(vllm_config, rng, mesh)
-    elif impl == "vllm":
-        return get_vllm_model(vllm_config, rng, mesh)
-    else:
-        raise NotImplementedError("Unsupported MODEL_IMPL_TYPE")
+        case _:
+            raise NotImplementedError(f"Unsupported MODEL_IMPL_TYPE: {impl}")
 
 
 def _validate_model_interface(model: Any) -> None:
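The rewritten `get_model` adds an "auto" mode whose resolution is a pure function of the checkpoint's architecture list: "vllm" for allow-listed architectures, "flax_nnx" otherwise, with further fallbacks to the vLLM implementation when pipeline parallelism is requested or the architecture is unsupported. Distilled into a standalone sketch (the `resolve_impl` helper is hypothetical; in the diff the logic is inline):

```python
_VLLM_PREFERRED_ARCHITECTURES = frozenset({"GptOssForCausalLM"})


def resolve_impl(impl: str, architectures: list[str]) -> str:
    # Mirrors the "auto" branch above: exactly one architecture is expected,
    # and only allow-listed ones resolve to the "vllm" implementation.
    if impl != "auto":
        return impl
    assert len(architectures) == 1, architectures
    return ("vllm" if architectures[0] in _VLLM_PREFERRED_ARCHITECTURES
            else "flax_nnx")


assert resolve_impl("auto", ["GptOssForCausalLM"]) == "vllm"
assert resolve_impl("auto", ["Qwen2ForCausalLM"]) == "flax_nnx"
assert resolve_impl("vllm", ["Qwen2ForCausalLM"]) == "vllm"
```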
@@ -441,14 +482,14 @@ def register_model(arch: str, model: Any) -> None:
     )
 
     # Same as `forward`, this is a dummy method to satisfy vLLM's type checks.
-    def unimplemented_get_input_embeddings(
+    def unimplemented_embed_input_ids(
         self,
         input_ids: "torch.Tensor",
         positions: "torch.Tensor",
         inputs_embeds: Optional["torch.Tensor"] = None,
     ) -> "torch.Tensor":
         raise NotImplementedError(
-            "This is a JAX model and does not implement the PyTorch get_input_embeddings method."
+            "This is a JAX model and does not implement the PyTorch embed_input_ids method."
         )
 
     # We need a custom __init__ that only calls torch.nn.Module's init,
@@ -464,7 +505,7 @@ def register_model(arch: str, model: Any) -> None:
         {
             "__init__": wrapper_init,
             "forward": unimplemented_forward,
-            "get_input_embeddings": unimplemented_get_input_embeddings,
+            "embed_input_ids": unimplemented_embed_input_ids,
             # Prevent vLLM from trying to load weights into this dummy class.
             "load_weights": lambda self, *args, **kwargs: None,
         })

New file (+13 -0): one of the added Apache-2.0 license stubs; the new __init__.py files in the list above each add exactly these 13 lines.

@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.