tpu-inference 0.12.0.dev20251222__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (260)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +67 -0
  6. tests/core/test_dp_scheduler.py +724 -0
  7. tests/core/test_init.py +63 -0
  8. tests/distributed/__init__.py +13 -0
  9. tests/distributed/test_distributed_utils.py +120 -0
  10. tests/distributed/test_tpu_connector.py +478 -0
  11. tests/e2e/__init__.py +13 -0
  12. tests/e2e/test_async_scheduler.py +211 -0
  13. tests/e2e/test_data_parallel.py +393 -0
  14. tests/e2e/test_local_disagg.py +257 -0
  15. tests/e2e/test_model_loader.py +268 -0
  16. tests/e2e/test_multi_modal_inference.py +111 -0
  17. tests/e2e/test_pipeline_parallel.py +265 -0
  18. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  19. tests/e2e/test_sampling_params.py +269 -0
  20. tests/e2e/test_speculative_decoding.py +291 -0
  21. tests/e2e/test_structured_decoding.py +46 -0
  22. tests/executors/__init__.py +13 -0
  23. tests/executors/test_ray_distributed_executor.py +199 -0
  24. tests/experimental/__init__.py +13 -0
  25. tests/experimental/test_llama3_jax_stashed.py +208 -0
  26. tests/kernels/__init__.py +13 -0
  27. tests/kernels/collectives/__init__.py +13 -0
  28. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  29. tests/kernels/fused_moe_v1_test.py +388 -0
  30. tests/kernels/gmm_test.py +205 -0
  31. tests/kernels/mla_v1_test.py +498 -0
  32. tests/kernels/quantized_matmul_kernel_test.py +159 -0
  33. tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
  34. tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
  35. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
  36. tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
  37. tests/layers/__init__.py +13 -0
  38. tests/layers/common/__init__.py +13 -0
  39. tests/layers/common/test_attention_interface.py +156 -0
  40. tests/layers/common/test_quantization.py +149 -0
  41. tests/layers/jax/__init__.py +13 -0
  42. tests/layers/jax/attention/__init__.py +13 -0
  43. tests/layers/jax/attention/test_common_attention.py +103 -0
  44. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  45. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  46. tests/layers/jax/moe/__init__.py +13 -0
  47. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  48. tests/layers/jax/sample/__init__.py +13 -0
  49. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  50. tests/layers/jax/sample/test_sampling.py +115 -0
  51. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  52. tests/layers/jax/test_layers.py +155 -0
  53. tests/layers/jax/test_qwix.py +969 -0
  54. tests/layers/jax/test_rope.py +93 -0
  55. tests/layers/jax/test_sharding.py +159 -0
  56. tests/layers/jax/test_transformer_block.py +152 -0
  57. tests/layers/vllm/__init__.py +13 -0
  58. tests/layers/vllm/test_attention.py +363 -0
  59. tests/layers/vllm/test_awq.py +405 -0
  60. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
  62. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
  63. tests/layers/vllm/test_fp8.py +17 -0
  64. tests/layers/vllm/test_mxfp4.py +297 -0
  65. tests/layers/vllm/test_unquantized.py +621 -0
  66. tests/layers/vllm/utils.py +72 -0
  67. tests/lora/__init__.py +13 -0
  68. tests/lora/conftest.py +46 -0
  69. tests/lora/test_bgmv.py +57 -0
  70. tests/lora/test_layers.py +666 -0
  71. tests/lora/test_lora.py +147 -0
  72. tests/lora/test_lora_perf.py +67 -0
  73. tests/lora/utils.py +88 -0
  74. tests/models/__init__.py +13 -0
  75. tests/models/common/__init__.py +13 -0
  76. tests/models/common/test_model_loader.py +455 -0
  77. tests/models/jax/__init__.py +13 -0
  78. tests/models/jax/test_deepseek_v3.py +401 -0
  79. tests/models/jax/test_llama3.py +184 -0
  80. tests/models/jax/test_llama4.py +298 -0
  81. tests/models/jax/test_llama_eagle3.py +197 -0
  82. tests/models/jax/test_llama_guard_4.py +242 -0
  83. tests/models/jax/test_qwen2.py +172 -0
  84. tests/models/jax/test_qwen2_5_vl.py +606 -0
  85. tests/models/jax/test_qwen3.py +169 -0
  86. tests/models/jax/test_weight_loading.py +180 -0
  87. tests/models/jax/utils/__init__.py +13 -0
  88. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  89. tests/platforms/__init__.py +13 -0
  90. tests/platforms/test_tpu_platform.py +54 -0
  91. tests/runner/__init__.py +13 -0
  92. tests/runner/test_block_table.py +395 -0
  93. tests/runner/test_input_batch.py +226 -0
  94. tests/runner/test_kv_cache.py +220 -0
  95. tests/runner/test_kv_cache_manager.py +498 -0
  96. tests/runner/test_multimodal_manager.py +429 -0
  97. tests/runner/test_persistent_batch_manager.py +84 -0
  98. tests/runner/test_speculative_decoding_manager.py +368 -0
  99. tests/runner/test_structured_decoding_manager.py +220 -0
  100. tests/runner/test_tpu_runner.py +202 -0
  101. tests/runner/test_tpu_runner_dp.py +1033 -0
  102. tests/runner/test_tpu_runner_mesh.py +200 -0
  103. tests/runner/test_utils.py +411 -0
  104. tests/spec_decode/__init__.py +13 -0
  105. tests/spec_decode/test_eagle3.py +311 -0
  106. tests/test_base.py +215 -0
  107. tests/test_envs.py +280 -0
  108. tests/test_tpu_info.py +134 -0
  109. tests/test_utils.py +193 -0
  110. tests/worker/__init__.py +13 -0
  111. tests/worker/tpu_worker_test.py +414 -0
  112. tpu_inference/__init__.py +67 -0
  113. tpu_inference/core/__init__.py +13 -0
  114. tpu_inference/core/core_tpu.py +786 -0
  115. tpu_inference/core/disagg_executor.py +118 -0
  116. tpu_inference/core/disagg_utils.py +49 -0
  117. tpu_inference/core/sched/__init__.py +13 -0
  118. tpu_inference/core/sched/dp_scheduler.py +814 -0
  119. tpu_inference/distributed/__init__.py +13 -0
  120. tpu_inference/distributed/jax_parallel_state.py +81 -0
  121. tpu_inference/distributed/tpu_connector.py +732 -0
  122. tpu_inference/distributed/utils.py +112 -0
  123. tpu_inference/env_override.py +9 -0
  124. tpu_inference/envs.py +191 -0
  125. tpu_inference/executors/__init__.py +13 -0
  126. tpu_inference/executors/ray_distributed_executor.py +399 -0
  127. tpu_inference/experimental/__init__.py +13 -0
  128. tpu_inference/experimental/llama3_jax_stashed.py +272 -0
  129. tpu_inference/kernels/__init__.py +13 -0
  130. tpu_inference/kernels/collectives/__init__.py +13 -0
  131. tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
  132. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
  133. tpu_inference/kernels/collectives/util.py +47 -0
  134. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  135. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  136. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  137. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  138. tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
  139. tpu_inference/kernels/megablox/__init__.py +13 -0
  140. tpu_inference/kernels/megablox/common.py +54 -0
  141. tpu_inference/kernels/megablox/gmm.py +646 -0
  142. tpu_inference/kernels/mla/__init__.py +13 -0
  143. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  144. tpu_inference/kernels/mla/v1/kernel.py +1340 -0
  145. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  146. tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
  147. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  148. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  149. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  150. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  151. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
  152. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
  153. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  154. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  155. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
  156. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
  157. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
  158. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
  159. tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
  160. tpu_inference/layers/__init__.py +13 -0
  161. tpu_inference/layers/common/__init__.py +13 -0
  162. tpu_inference/layers/common/attention_interface.py +403 -0
  163. tpu_inference/layers/common/attention_metadata.py +48 -0
  164. tpu_inference/layers/common/binary_search.py +295 -0
  165. tpu_inference/layers/common/quant_methods.py +23 -0
  166. tpu_inference/layers/common/quantization.py +270 -0
  167. tpu_inference/layers/common/sharding.py +600 -0
  168. tpu_inference/layers/jax/__init__.py +13 -0
  169. tpu_inference/layers/jax/attention/__init__.py +13 -0
  170. tpu_inference/layers/jax/attention/attention.py +268 -0
  171. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
  172. tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
  173. tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
  174. tpu_inference/layers/jax/base.py +165 -0
  175. tpu_inference/layers/jax/constants.py +101 -0
  176. tpu_inference/layers/jax/layers.py +315 -0
  177. tpu_inference/layers/jax/misc.py +30 -0
  178. tpu_inference/layers/jax/moe/__init__.py +13 -0
  179. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
  180. tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
  181. tpu_inference/layers/jax/moe/moe.py +249 -0
  182. tpu_inference/layers/jax/pp_utils.py +53 -0
  183. tpu_inference/layers/jax/rope.py +294 -0
  184. tpu_inference/layers/jax/rope_interface.py +228 -0
  185. tpu_inference/layers/jax/sample/__init__.py +13 -0
  186. tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
  187. tpu_inference/layers/jax/sample/sampling.py +110 -0
  188. tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
  189. tpu_inference/layers/jax/transformer_block.py +121 -0
  190. tpu_inference/layers/vllm/__init__.py +13 -0
  191. tpu_inference/layers/vllm/attention.py +221 -0
  192. tpu_inference/layers/vllm/fused_moe.py +502 -0
  193. tpu_inference/layers/vllm/linear_common.py +221 -0
  194. tpu_inference/layers/vllm/quantization/__init__.py +55 -0
  195. tpu_inference/layers/vllm/quantization/awq.py +221 -0
  196. tpu_inference/layers/vllm/quantization/common.py +124 -0
  197. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  198. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
  199. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
  200. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  201. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
  202. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
  203. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  204. tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
  205. tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
  206. tpu_inference/layers/vllm/sharding.py +244 -0
  207. tpu_inference/logger.py +10 -0
  208. tpu_inference/lora/__init__.py +13 -0
  209. tpu_inference/lora/torch_lora_ops.py +98 -0
  210. tpu_inference/lora/torch_punica_tpu.py +310 -0
  211. tpu_inference/models/__init__.py +13 -0
  212. tpu_inference/models/common/__init__.py +13 -0
  213. tpu_inference/models/common/model_loader.py +520 -0
  214. tpu_inference/models/jax/__init__.py +13 -0
  215. tpu_inference/models/jax/deepseek_v3.py +978 -0
  216. tpu_inference/models/jax/gpt_oss.py +508 -0
  217. tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
  218. tpu_inference/models/jax/llama3.py +436 -0
  219. tpu_inference/models/jax/llama4.py +643 -0
  220. tpu_inference/models/jax/llama_eagle3.py +350 -0
  221. tpu_inference/models/jax/llama_guard_4.py +375 -0
  222. tpu_inference/models/jax/qwen2.py +390 -0
  223. tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
  224. tpu_inference/models/jax/qwen3.py +318 -0
  225. tpu_inference/models/jax/utils/__init__.py +13 -0
  226. tpu_inference/models/jax/utils/file_utils.py +110 -0
  227. tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
  228. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  229. tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
  230. tpu_inference/models/jax/utils/weight_utils.py +621 -0
  231. tpu_inference/models/vllm/__init__.py +13 -0
  232. tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
  233. tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
  234. tpu_inference/platforms/__init__.py +16 -0
  235. tpu_inference/platforms/tpu_platform.py +258 -0
  236. tpu_inference/runner/__init__.py +13 -0
  237. tpu_inference/runner/block_table.py +122 -0
  238. tpu_inference/runner/compilation_manager.py +890 -0
  239. tpu_inference/runner/input_batch.py +435 -0
  240. tpu_inference/runner/kv_cache.py +166 -0
  241. tpu_inference/runner/kv_cache_manager.py +508 -0
  242. tpu_inference/runner/lora_utils.py +106 -0
  243. tpu_inference/runner/multimodal_manager.py +231 -0
  244. tpu_inference/runner/persistent_batch_manager.py +296 -0
  245. tpu_inference/runner/speculative_decoding_manager.py +262 -0
  246. tpu_inference/runner/structured_decoding_manager.py +101 -0
  247. tpu_inference/runner/tpu_runner.py +1768 -0
  248. tpu_inference/runner/utils.py +426 -0
  249. tpu_inference/spec_decode/__init__.py +13 -0
  250. tpu_inference/spec_decode/jax/__init__.py +13 -0
  251. tpu_inference/spec_decode/jax/eagle3.py +430 -0
  252. tpu_inference/tpu_info.py +92 -0
  253. tpu_inference/utils.py +345 -0
  254. tpu_inference/worker/__init__.py +13 -0
  255. tpu_inference/worker/tpu_worker.py +468 -0
  256. tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
  257. tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
  258. tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
  259. tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
  260. tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
@@ -0,0 +1,520 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import functools
+from typing import Any, Optional
+
+import jax
+import torch
+from flax import nnx
+from jax.sharding import Mesh, NamedSharding, PartitionSpec
+from transformers import PretrainedConfig
+from vllm.config import VllmConfig
+from vllm.model_executor.model_loader import get_model_loader
+from vllm.model_executor.model_loader.runai_streamer_loader import \
+    RunaiModelStreamerLoader
+from vllm.utils.func_utils import supports_kw
+
+from tpu_inference import envs
+from tpu_inference.layers.common.sharding import ShardingAxisName
+from tpu_inference.logger import init_logger
+from tpu_inference.models.jax.utils.qwix.qwix_utils import (
+    apply_qwix_on_abstract_model, apply_qwix_quantization,
+    load_random_weights_into_qwix_abstract_model)
+from tpu_inference.utils import to_jax_dtype, to_torch_dtype
+
+logger = init_logger(__name__)
+
+_MODEL_REGISTRY = {}
+
+# List of architectures that are preferred to use "vllm" implementation over
+# "flax_nnx" implementation due to various factors such as performance.
+_VLLM_PREFERRED_ARCHITECTURES: frozenset[str] = frozenset(
+    {"GptOssForCausalLM"})
+
+
+class UnsupportedArchitectureError(ValueError):
+    """Raised when a model architecture is not supported in the registry."""
+    pass
+
+
+def _get_model_architecture(config: PretrainedConfig) -> nnx.Module:
+    # NOTE: Use inline imports here, otherwise the normal imports
+    # would cause JAX init failure when using multi hosts with Ray.
+
+    from tpu_inference.models.jax.deepseek_v3 import DeepSeekV3
+    from tpu_inference.models.jax.gpt_oss import GptOss
+    from tpu_inference.models.jax.llama3 import LlamaForCausalLM
+    from tpu_inference.models.jax.llama4 import Llama4ForCausalLM
+    from tpu_inference.models.jax.llama_eagle3 import EagleLlama3ForCausalLM
+    from tpu_inference.models.jax.llama_guard_4 import LlamaGuard4ForCausalLM
+    from tpu_inference.models.jax.qwen2_5_vl import \
+        Qwen2_5_VLForConditionalGeneration
+    from tpu_inference.models.jax.qwen3 import Qwen3ForCausalLM
+    _MODEL_REGISTRY["Llama4ForCausalLM"] = Llama4ForCausalLM
+    _MODEL_REGISTRY["DeepseekV3ForCausalLM"] = DeepSeekV3
+    _MODEL_REGISTRY["LlamaForCausalLM"] = LlamaForCausalLM
+    _MODEL_REGISTRY["Llama4ForConditionalGeneration"] = LlamaGuard4ForCausalLM
+    _MODEL_REGISTRY["Qwen3ForCausalLM"] = Qwen3ForCausalLM
+    _MODEL_REGISTRY[
+        "Qwen2_5_VLForConditionalGeneration"] = Qwen2_5_VLForConditionalGeneration
+    _MODEL_REGISTRY["Eagle3LlamaForCausalLM"] = EagleLlama3ForCausalLM
+    _MODEL_REGISTRY["GptOssForCausalLM"] = GptOss
+
+    architectures = getattr(config, "architectures", [])
+    for arch in architectures:
+        if arch in _MODEL_REGISTRY:
+            return _MODEL_REGISTRY[arch]
+    raise UnsupportedArchitectureError(
+        f"Model architectures {architectures} not "
+        "registered in tpu-inference. Falling back to vLLM-native "
+        f"Pytorch definition. JAX-native architectures: {list(_MODEL_REGISTRY.keys())}"
+    )
+
+
+def _get_nnx_model(
+    model_class: Any,
+    vllm_config: VllmConfig,
+    rng: jax.Array,
+    mesh: Mesh,
+) -> nnx.Module:
+
+    def create_abstract_model() -> nnx.Module:
+        """
+        Helper function to create an abstract model for `nnx.eval_shape`.
+
+        Returns:
+            An abstract model instance.
+        """
+        return model_class(vllm_config, rng, mesh)
+
+    @nnx.jit(donate_argnums=(0, ),
+             static_argnames=('use_qwix_on_abstract_model', ))
+    def create_jit_model(
+            model: nnx.Module,
+            use_qwix_on_abstract_model: bool = False) -> nnx.Module:
+        """
+        Create a jit model.
+
+        Args:
+            model: The model to jit.
+            use_qwix_on_abstract_model: Whether to apply Qwix on the abstract model.
+
+        Returns:
+            The jitted model.
+        """
+        state = nnx.state(model)
+        nnx.update(model, state)
+        if not use_qwix_on_abstract_model:
+            # NOTE: if Qwix is not configured, this will be a no-op
+            model = apply_qwix_quantization(vllm_config,
+                                            model,
+                                            rng,
+                                            mesh,
+                                            apply_to_abstract_model=False)
+        return model
+
+    if vllm_config.load_config.load_format == "dummy":
+        # Create a sharded model with randomly initialized weights.
+        # TODO: Qwen2ForCausalLM still uses the legacy model implementation;
+        # merge the random-init logic once all models are migrated to the new implementation.
+
+        # Handle the case where we want to load random weights into a Qwix-quantized model. Here, we
+        # need to run an abstract pass for Qwix first and then load in the random weights.
+        if apply_qwix_on_abstract_model(vllm_config):
+            abstract_model_fn = apply_qwix_quantization(
+                vllm_config,
+                create_abstract_model,
+                rng,
+                mesh,
+                apply_to_abstract_model=True)
+
+            model = nnx.eval_shape(abstract_model_fn)
+            quantization_config = vllm_config.model_config.hf_config.quantization_config if hasattr(
+                vllm_config.model_config.hf_config,
+                "quantization_config") else {}
+            load_random_weights_into_qwix_abstract_model(
+                rng, model, mesh, quantization_config)
+            with mesh:
+                jit_model = create_jit_model(model,
+                                             use_qwix_on_abstract_model=True)
+            return jit_model
+
+        @nnx.jit
+        def create_sharded_model():
+            model = model_class(vllm_config, rng, mesh)
+            state = nnx.state(model)
+            pspecs = nnx.get_partition_spec(state)
+            sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
+            nnx.update(model, sharded_state)
+            # NOTE: we don't support quantization for the old Qwen2ForCausalLM implementation
+            return model
+
+        with mesh:
+            jit_model = create_sharded_model()
+            # In this case, we are applying Qwix quantization to the true, concrete model
+            jit_model = apply_qwix_quantization(vllm_config,
+                                                jit_model,
+                                                rng,
+                                                mesh,
+                                                apply_to_abstract_model=False)
+            if hasattr(jit_model, 'initialize_cache'):
+                jit_model.initialize_cache()
+    else:
+        # We first create an abstract model without allocating any weights,
+        # then fill in its weights during load_weights from HF.
+        # This has two advantages over the normal approach:
+        # 1. The model weights will only be allocated once. Otherwise the normal way
+        #    will random-init the model weights first, then load the real weights.
+        #    The two-pass weight allocation makes model loading slow.
+        # 2. Model loading won't OOM. Otherwise the normal way will hold
+        #    the full model weights after random-init, then duplicate a layer during
+        #    load_weights. This can easily OOM if the layer is very large.
+        abstract_model_fn = create_abstract_model
+        # NOTE: only one of the abstract (this) or concrete Qwix quantization paths should
+        # be taken
+        if should_apply_qwix_on_abstract_model := apply_qwix_on_abstract_model(
+                vllm_config):
+            # NOTE: if Qwix is not configured, this will return `create_abstract_model` and
+            # thus be a no-op
+            abstract_model_fn = apply_qwix_quantization(
+                vllm_config,
+                create_abstract_model,
+                rng,
+                mesh,
+                apply_to_abstract_model=True)
+        model = nnx.eval_shape(abstract_model_fn)
+        # Although the created model can already work, we still need to jit
+        # the model creation again, otherwise the model forward will have
+        # non-trivial overhead in PjitFunction.
+        with mesh:
+            loader = get_model_loader(vllm_config.load_config)
+            if isinstance(loader, RunaiModelStreamerLoader):
+                model_weights = vllm_config.model_config.model
+                if hasattr(vllm_config.model_config, "model_weights"):
+                    model_weights = vllm_config.model_config.model_weights
+                weights_iterator = loader._get_weights_iterator(
+                    model_weights, vllm_config.model_config.revision)
+                # We set the weights iterator at runtime, to prevent having to change
+                # every model's load_weights signature. This also prevents us from hitting
+                # a TypeError at runtime if you use the RunaiModelStreamerLoader with any
+                # flax_nnx model whose load_weights function does not accept the
+                # weights_iterator keyword argument.
+                vllm_config.model_config.model_weights_iterator = weights_iterator
+                model.load_weights(rng)
+                del vllm_config.model_config.model_weights_iterator
+            else:
+                model.load_weights(rng)
+            jit_model = create_jit_model(
+                model,
+                use_qwix_on_abstract_model=should_apply_qwix_on_abstract_model)
+    return jit_model
+
+
+# TODO(pooyam): We need to refactor this. This is returning a bunch of functions that do not work with all models, and this is not very easy to see from the code.
+def get_flax_model(
+    vllm_config: VllmConfig,
+    rng: jax.Array,
+    mesh: Mesh,
+    is_draft_model: bool = False,
+) -> nnx.Module:
+    model_dtype = to_jax_dtype(vllm_config.model_config.dtype)
+    vllm_config.model_config.dtype = model_dtype
+
+    if is_draft_model:
+        model_class = _get_model_architecture(
+            vllm_config.speculative_config.draft_model_config.hf_config)
+    else:
+        model_class = _get_model_architecture(
+            vllm_config.model_config.hf_config)
+    jit_model = _get_nnx_model(model_class, vllm_config, rng, mesh)
+    kv_cache_sharding = NamedSharding(
+        mesh,
+        PartitionSpec(ShardingAxisName.ATTN_DATA, None,
+                      ShardingAxisName.ATTN_HEAD))
+    hidden_states_sharding = NamedSharding(mesh,
+                                           PartitionSpec(
+                                               ShardingAxisName.ATTN_DATA,
+                                               None))  # (T, D)
+
+    # For performance considerations, refer to:
+    # https://flax.readthedocs.io/en/latest/guides/performance.html
+    graphdef, state = nnx.split(jit_model)
+
+    @functools.partial(
+        jax.jit,
+        out_shardings=(
+            kv_cache_sharding,
+            hidden_states_sharding,
+            hidden_states_sharding,  # aux hidden states
+        ),
+        donate_argnums=2,  # 0 is graphdef, 1 is state, 2 is kv_cache
+        static_argnums=(
+            7, 10, 11
+        ),  # 7 is layer_name_to_kvcache_index, 10 is is_first_rank, 11 is is_last_rank
+    )
+    def run_model(graphdef, state, *args):
+        model = nnx.merge(graphdef, state)
+        return model(*args)
+
+    logits_sharding = NamedSharding(
+        mesh,
+        PartitionSpec(ShardingAxisName.MLP_DATA, ShardingAxisName.MLP_TENSOR))
+
+    @functools.partial(
+        jax.jit,
+        out_shardings=(logits_sharding),
+    )
+    def run_compute_logits(graphdef, state, *args):
+        model = nnx.merge(graphdef, state)
+        hidden_state, *_ = args
+        return model.compute_logits(hidden_state)
+
+    # Multi-modal support only.
+    # This function calculates the image tokens' embeddings via the ViT.
+    def run_get_multimodal_embeddings(graphdef, state, image_grid_thw,
+                                      **kwargs):
+        model = nnx.merge(graphdef, state)
+        return model.get_multimodal_embeddings(image_grid_thw, **kwargs)
+
+    embed_sharding = NamedSharding(mesh, PartitionSpec(None))
+    # This function calculates the embeddings of the input texts and then merges them with the image embeddings.
+    @functools.partial(
+        jax.jit,
+        out_shardings=(embed_sharding),
+    )
+    def run_get_input_embeddings(graphdef, state, *args, **kwargs):
+        model = nnx.merge(graphdef, state)
+        return model.get_input_embeddings(*args, **kwargs)
+
+    # For models that want to work with EAGLE-3 speculative decoding
+    @functools.partial(
+        jax.jit,
+        out_shardings=(logits_sharding),
+    )
+    def combine_hidden_states(graphdef, state, hidden_states):
+        model = nnx.merge(graphdef, state)
+        return model.combine_hidden_states(hidden_states)
+
+    model = nnx.merge(graphdef, state)
+    precompile_vision_encoder_fn = getattr(model, "precompile_vision_encoder",
+                                           None)
+    model_fn = functools.partial(run_model, graphdef)
+    compute_logits_fn = functools.partial(run_compute_logits, graphdef)
+    get_multimodal_embeddings_fn = functools.partial(
+        run_get_multimodal_embeddings, graphdef)
+    get_input_embeddings_fn = functools.partial(run_get_input_embeddings,
+                                                graphdef)
+    lora_manager, model = None, None
+    combine_hidden_states_fn = functools.partial(combine_hidden_states,
+                                                 graphdef)
+
+    get_mrope_input_positions_fn = None if not hasattr(
+        jit_model,
+        "get_mrope_input_positions") else jit_model.get_mrope_input_positions
+
+    multimodal_fns = {
+        "precompile_vision_encoder_fn": precompile_vision_encoder_fn,
+        "get_multimodal_embeddings_fn": get_multimodal_embeddings_fn,
+        "get_input_embeddings_fn": get_input_embeddings_fn,
+        "get_mrope_input_positions_fn": get_mrope_input_positions_fn,
+    }
+
+    return model_fn, compute_logits_fn, combine_hidden_states_fn, multimodal_fns, state, lora_manager, model
+
+
+def get_vllm_model(
+    vllm_config: VllmConfig,
+    rng: jax.Array,
+    mesh: Mesh,
+):
+    model_dtype = to_torch_dtype(vllm_config.model_config.dtype)
+    vllm_config.model_config.dtype = model_dtype
+    from tpu_inference.models.vllm.vllm_model_wrapper import VllmModelWrapper
+
+    model = VllmModelWrapper(
+        vllm_config=vllm_config,
+        rng=rng,
+        mesh=mesh,
+    )
+    params, lora_manager = model.load_weights()
+
+    jit_model = model.jit_step_func()
+    compute_logits_fn = model.jit_compute_logits_func()
+    # The model needs to be returned because LoRA weights are neither torch.nn.Parameters
+    # nor buffers. After we load the LoRA weights and set them on the torch.nn.Module,
+    # we can shard them and move them to the TPU.
+    combine_hidden_states_fn = None
+    return jit_model, compute_logits_fn, combine_hidden_states_fn, None, params, lora_manager, model
+
+
+def get_model(
+    vllm_config: VllmConfig,
+    rng: jax.Array,
+    mesh: Mesh,
+    is_draft_model: bool = False,
+) -> Any:
+    impl = envs.MODEL_IMPL_TYPE
+    logger.info(f"Loading model with MODEL_IMPL_TYPE={impl}")
+
+    if impl == "auto":
+        # Resolve "auto" based on the architecture
+        architectures = getattr(vllm_config.model_config.hf_config,
+                                "architectures", [])
+        assert len(architectures) == 1, (
+            f"Expected exactly one architecture, got {len(architectures)}: "
+            f"{architectures}")
+        arch = architectures[0]
+        impl = "vllm" if arch in _VLLM_PREFERRED_ARCHITECTURES else "flax_nnx"
+        logger.info(f"Resolved MODEL_IMPL_TYPE 'auto' to '{impl}'")
+
+    match impl:
+        case "flax_nnx":
+            if vllm_config.parallel_config.pipeline_parallel_size > 1:
+                logger.warning(
+                    "PP is not fully supported on Jax flax_nnx models yet, fallback to vllm models."
+                )
+                return get_vllm_model(vllm_config, rng, mesh)
+            try:
+                # Try to load the flax model first
+                return get_flax_model(vllm_config, rng, mesh, is_draft_model)
+            except UnsupportedArchitectureError as e:
+                # Convert the error message to a string to check its contents
+                error_msg = str(e)
+
+                logger.warning(error_msg)
+
+                # Fall back to the vLLM model and update the dtype accordingly
+                return get_vllm_model(vllm_config, rng, mesh)
+        case "vllm":
+            return get_vllm_model(vllm_config, rng, mesh)
+        case _:
+            raise NotImplementedError(f"Unsupported MODEL_IMPL_TYPE: {impl}")
+
+
+def _validate_model_interface(model: Any) -> None:
+    """Validates that the model class has the required methods and signatures.
+
+    A valid model must have:
+    - An __init__ method that accepts a 'vllm_config' keyword argument.
+    - A __call__ method that accepts 'kv_caches', 'input_ids', and
+      'attention_metadata' keyword arguments.
+
+    Args:
+        model: The model class to validate.
+
+    Raises:
+        TypeError: If the model does not meet the interface requirements.
+    """
+    # Check for __init__ with vllm_config
+    model_init = getattr(model, "__init__", None)
+    if not callable(model_init):
+        raise TypeError(
+            f"Model {model.__name__} must have an __init__ method.")
+
+    if not supports_kw(model_init, "vllm_config"):
+        raise TypeError(
+            f"Model {model.__name__} __init__ method must accept a "
+            "'vllm_config' keyword argument.")
+
+    # Check for __call__ with required arguments
+    model_call = getattr(model, "__call__", None)
+    # A class object is always callable (it produces an instance).
+    # We need to check if the class _explicitly_ defines a __call__ method for its
+    # instance, which is different from `type.__call__`.
+    has_defined_call = False
+    if isinstance(model, type):
+        if any("__call__" in C.__dict__ for C in model.__mro__):
+            has_defined_call = True
+    elif callable(model_call):
+        # For an instance, a simple callable check is sufficient.
+        has_defined_call = True
+
+    if not has_defined_call:
+        raise TypeError(f"Model {model.__name__} must have a __call__ method.")
+
+    required_call_args = ("kv_caches", "input_ids", "attention_metadata")
+    missing_args = tuple(arg for arg in required_call_args
+                         if not supports_kw(model_call, arg))
+
+    if missing_args:
+        raise TypeError(
+            f"Model {model.__name__} __call__ method is missing required "
+            f"keyword arguments: {missing_args}")
+
+
+def register_model(arch: str, model: Any) -> None:
+    """
+    Registers a model class for a given architecture name.
+
+    This function registers the model with both the tpu_inference registry
+    and the vLLM registry. For vLLM, it creates a compatible wrapper
+    around the JAX model.
+
+    Args:
+        arch: The name of the architecture (e.g., "LlamaForCausalLM").
+        model: The JAX model class to register (e.g., a flax.nnx.Module).
+    """
+    _validate_model_interface(model)
+
+    # Register with tpu_inference registry for the JAX backend
+    _MODEL_REGISTRY[arch] = model
+
+    # Create a vLLM-compatible wrapper for the JAX model class.
+    # This wrapper inherits from the JAX model and torch.nn.Module
+    # to pass vLLM's type checks. It is not meant to be instantiated
+    # or executed by vLLM's PyTorch backend.
+    def unimplemented_forward(
+            self,
+            input_ids: "torch.Tensor",
+            positions: "torch.Tensor",
+            intermediate_tensors: Optional[Any] = None,
+            inputs_embeds: Optional["torch.Tensor"] = None,
+    ) -> None:
+        raise NotImplementedError(
+            "This is a JAX model and does not implement the PyTorch forward method."
+        )
+
+    # Same as `forward`, this is a dummy method to satisfy vLLM's type checks.
+    def unimplemented_get_input_embeddings(
+            self,
+            input_ids: "torch.Tensor",
+            positions: "torch.Tensor",
+            inputs_embeds: Optional["torch.Tensor"] = None,
+    ) -> "torch.Tensor":
+        raise NotImplementedError(
+            "This is a JAX model and does not implement the PyTorch get_input_embeddings method."
+        )
+
+    # We need a custom __init__ that only calls torch.nn.Module's init,
+    # to avoid triggering JAX logic when vLLM inspects the class.
+    def wrapper_init(self, *args, **kwargs):
+        torch.nn.Module.__init__(self)
+
+    # Dynamically create the wrapper class that is a subclass of both the
+    # JAX model and torch.nn.Module.
+    VllmCompatibleModel = type(
+        f"VllmCompatible{model.__name__}",
+        (model, torch.nn.Module),
+        {
+            "__init__": wrapper_init,
+            "forward": unimplemented_forward,
+            "get_input_embeddings": unimplemented_get_input_embeddings,
+            # Prevent vLLM from trying to load weights into this dummy class.
+            "load_weights": lambda self, *args, **kwargs: None,
+        })
+
+    # Register the wrapped model with vLLM's registry.
+    from vllm.model_executor.models.registry import ModelRegistry
+    ModelRegistry.register_model(arch, VllmCompatibleModel)
+    logger.info(
+        f"Registered JAX model {arch} with tpu_inference and vLLM registries.")
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.