tpu-inference 0.11.1.dev202512030818__py3-none-any.whl → 0.13.2rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic. Click here for more details.

Files changed (250)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +14 -0
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +405 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +312 -0
  64. tests/layers/vllm/test_unquantized.py +651 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +21 -3
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +78 -1
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +1 -43
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +14 -9
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +38 -7
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +17 -0
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +95 -78
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +26 -19
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/quant_methods.py +15 -0
  154. tpu_inference/layers/common/quantization.py +270 -0
  155. tpu_inference/layers/common/sharding.py +28 -5
  156. tpu_inference/layers/jax/__init__.py +13 -0
  157. tpu_inference/layers/jax/attention/__init__.py +13 -0
  158. tpu_inference/layers/jax/attention/attention.py +19 -6
  159. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  160. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  161. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  162. tpu_inference/layers/jax/base.py +14 -0
  163. tpu_inference/layers/jax/constants.py +13 -0
  164. tpu_inference/layers/jax/layers.py +14 -0
  165. tpu_inference/layers/jax/misc.py +14 -0
  166. tpu_inference/layers/jax/moe/__init__.py +13 -0
  167. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  168. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  169. tpu_inference/layers/jax/moe/moe.py +43 -3
  170. tpu_inference/layers/jax/pp_utils.py +53 -0
  171. tpu_inference/layers/jax/rope.py +14 -0
  172. tpu_inference/layers/jax/rope_interface.py +14 -0
  173. tpu_inference/layers/jax/sample/__init__.py +13 -0
  174. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  175. tpu_inference/layers/jax/sample/sampling.py +15 -1
  176. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  177. tpu_inference/layers/jax/transformer_block.py +14 -0
  178. tpu_inference/layers/vllm/__init__.py +13 -0
  179. tpu_inference/layers/vllm/attention.py +4 -4
  180. tpu_inference/layers/vllm/fused_moe.py +210 -260
  181. tpu_inference/layers/vllm/linear_common.py +57 -22
  182. tpu_inference/layers/vllm/quantization/__init__.py +16 -0
  183. tpu_inference/layers/vllm/quantization/awq.py +15 -1
  184. tpu_inference/layers/vllm/quantization/common.py +33 -18
  185. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  186. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
  187. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
  188. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  189. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
  191. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  192. tpu_inference/layers/vllm/quantization/mxfp4.py +278 -209
  193. tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
  194. tpu_inference/layers/vllm/sharding.py +21 -4
  195. tpu_inference/lora/__init__.py +13 -0
  196. tpu_inference/lora/torch_lora_ops.py +8 -13
  197. tpu_inference/models/__init__.py +13 -0
  198. tpu_inference/models/common/__init__.py +13 -0
  199. tpu_inference/models/common/model_loader.py +74 -35
  200. tpu_inference/models/jax/__init__.py +13 -0
  201. tpu_inference/models/jax/deepseek_v3.py +267 -157
  202. tpu_inference/models/jax/gpt_oss.py +26 -10
  203. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  204. tpu_inference/models/jax/llama3.py +99 -36
  205. tpu_inference/models/jax/llama4.py +14 -0
  206. tpu_inference/models/jax/llama_eagle3.py +14 -0
  207. tpu_inference/models/jax/llama_guard_4.py +15 -1
  208. tpu_inference/models/jax/qwen2.py +17 -2
  209. tpu_inference/models/jax/qwen2_5_vl.py +18 -4
  210. tpu_inference/models/jax/qwen3.py +17 -2
  211. tpu_inference/models/jax/utils/__init__.py +13 -0
  212. tpu_inference/models/jax/utils/file_utils.py +14 -0
  213. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  214. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  215. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +89 -26
  216. tpu_inference/models/jax/utils/weight_utils.py +39 -2
  217. tpu_inference/models/vllm/__init__.py +13 -0
  218. tpu_inference/models/vllm/vllm_model_wrapper.py +20 -3
  219. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  220. tpu_inference/platforms/__init__.py +14 -0
  221. tpu_inference/platforms/tpu_platform.py +47 -64
  222. tpu_inference/runner/__init__.py +13 -0
  223. tpu_inference/runner/compilation_manager.py +72 -37
  224. tpu_inference/runner/kv_cache.py +54 -20
  225. tpu_inference/runner/kv_cache_manager.py +46 -17
  226. tpu_inference/runner/lora_utils.py +14 -0
  227. tpu_inference/runner/multimodal_manager.py +15 -1
  228. tpu_inference/runner/persistent_batch_manager.py +14 -0
  229. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  230. tpu_inference/runner/structured_decoding_manager.py +14 -0
  231. tpu_inference/runner/tpu_runner.py +44 -17
  232. tpu_inference/spec_decode/__init__.py +13 -0
  233. tpu_inference/spec_decode/jax/__init__.py +13 -0
  234. tpu_inference/spec_decode/jax/eagle3.py +13 -0
  235. tpu_inference/tpu_info.py +14 -0
  236. tpu_inference/utils.py +42 -36
  237. tpu_inference/worker/__init__.py +13 -0
  238. tpu_inference/worker/tpu_worker.py +63 -50
  239. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/METADATA +7 -9
  240. tpu_inference-0.13.2rc3.dist-info/RECORD +261 -0
  241. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  242. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  243. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  244. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  245. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  246. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  247. tpu_inference-0.11.1.dev202512030818.dist-info/RECORD +0 -174
  248. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/WHEEL +0 -0
  249. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/licenses/LICENSE +0 -0
  250. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,169 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from unittest.mock import MagicMock
16
+
17
+ import jax
18
+ import jax.numpy as jnp
19
+ import numpy as np
20
+ import pytest
21
+ from flax import nnx
22
+ from flax.typing import PRNGKey
23
+ from jax.sharding import Mesh
24
+ from vllm.config import ModelConfig
25
+
26
+ from tpu_inference.layers.common.attention_metadata import AttentionMetadata
27
+ from tpu_inference.models.jax.qwen3 import Qwen3ForCausalLM
28
+ from tpu_inference.runner.kv_cache import create_kv_caches
29
+
30
+
31
class MockVllmConfig:
    """Minimal stand-in for ``VllmConfig`` used by the Qwen3 model tests.

    Populates three attributes: a real ``ModelConfig`` for ``model`` with its
    dtype forced to bfloat16, a mocked load config whose ``download_dir`` is
    None, and a mocked cache config carrying ``kv_cache_dtype``.
    """

    def __init__(self, model: str, kv_cache_dtype: str):
        model_cfg = ModelConfig(model)
        model_cfg.dtype = jnp.bfloat16
        self.model_config = model_cfg

        load_cfg = MagicMock()
        load_cfg.download_dir = None
        self.load_config = load_cfg

        self.cache_config = MagicMock(cache_dtype=kv_cache_dtype)
39
+
40
+
41
@pytest.fixture(scope="module")
def mesh():
    """Yield a single-device mesh with axes (data, attn_dp, expert, model).

    Every axis has size 1. The module's tests are skipped when JAX reports
    no devices at all.
    """
    if not jax.devices():
        pytest.skip("No JAX devices available for mesh creation.")

    first_device = np.array(jax.local_devices()[:1])
    assert len(first_device) == 1
    grid = first_device.reshape((len(first_device), 1, 1, 1))
    axes = ('data', 'attn_dp', 'expert', 'model')

    with Mesh(grid, axis_names=axes) as active_mesh:
        yield active_mesh
57
+
58
+
59
@pytest.fixture
def mock_model_inputs():
    """Build dummy forward-pass inputs for one request of 8 tokens.

    Returns ``(input_ids, attention_metadata, indices_do_sample)``. Every
    tensor is int32 and filled with placeholder values (ones/zeros), so the
    layout is not meant to be a realistic prefill or decode batch.
    """
    n_tokens, n_reqs, blocks_per_req = 8, 1, 4
    token_shape = (n_tokens, )

    metadata = AttentionMetadata(
        input_positions=jnp.ones(token_shape, dtype=jnp.int32),
        block_tables=jnp.zeros((n_reqs, blocks_per_req),
                               dtype=jnp.int32).reshape(-1),
        seq_lens=jnp.ones((n_reqs, ), dtype=jnp.int32),
        query_start_loc=jnp.ones((n_reqs + 1, ), dtype=jnp.int32),
        request_distribution=jnp.array([0, 0, 0], dtype=jnp.int32),
    )
    token_ids = jnp.ones(token_shape, dtype=jnp.int32)
    sample_rows = jnp.ones((n_reqs, ), dtype=jnp.int32)
    return (token_ids, metadata, sample_rows)
82
+
83
+
84
@pytest.fixture
def rng() -> PRNGKey:
    """Return a deterministic JAX PRNG key (seed 42) for model init."""
    seed = 42
    return jax.random.PRNGKey(seed)
88
+
89
+
90
class TestQwen3ForCausalLM:
    """Smoke tests for ``Qwen3ForCausalLM``: init, weight load and forward."""

    @pytest.mark.parametrize("mock_vllm_config", [
        MockVllmConfig("Qwen/Qwen3-0.6B", "auto"),
        MockVllmConfig("Qwen/Qwen3-0.6B", "fp8")
    ])
    def test_qwen3_600M(self, mock_vllm_config, rng, mesh, mock_model_inputs):
        """Tests model init and model forward for the 0.6B model variant.

        Runs once per KV-cache dtype: ``"fp8"`` builds a float8_e4m3fn cache,
        anything else (here ``"auto"``) builds a bfloat16 cache.
        """

        # Test model init
        model = Qwen3ForCausalLM(mock_vllm_config, rng, mesh)

        model_config = mock_vllm_config.model_config
        hf_config = model_config.hf_config

        assert model.mesh.shape == {
            "data": 1,
            "attn_dp": 1,
            "expert": 1,
            "model": 1
        }

        layers = model.model.layers
        assert len(layers) == hf_config.num_hidden_layers
        assert isinstance(model.rng, nnx.Rngs)
        # The LM head is expected to be the same object as the embedding table.
        assert model.model.lm_head == model.model.embed.embedding

        attn = layers[0].self_attn
        hidden_size = hf_config.hidden_size
        num_heads = hf_config.num_attention_heads
        num_kv_heads = hf_config.num_key_value_heads
        rope_theta = hf_config.rope_theta
        original_head_dim = hf_config.head_dim
        # The attention layer tracks both the HF head_dim and the head_dim its
        # projections are built with; the latter is expected to be 128.
        head_dim = 128
        intermediate_size = hf_config.intermediate_size

        assert attn.hidden_size == hidden_size
        assert attn.num_heads == num_heads
        assert attn.num_kv_heads == num_kv_heads
        assert attn.rope_theta == rope_theta
        assert attn.head_dim_original == original_head_dim
        assert attn.head_dim == head_dim
        assert attn.q_proj.kernel.shape == (hidden_size, num_heads, head_dim)
        assert attn.k_proj.kernel.shape == (hidden_size, num_kv_heads,
                                            head_dim)
        assert attn.v_proj.kernel.shape == (hidden_size, num_kv_heads,
                                            head_dim)
        assert attn.o_proj.kernel.shape == (num_heads, head_dim, hidden_size)

        mlp = layers[0].mlp
        assert mlp.gate_proj.kernel.shape == (hidden_size, intermediate_size)
        assert mlp.up_proj.kernel.shape == (hidden_size, intermediate_size)
        assert mlp.down_proj.kernel.shape == (intermediate_size, hidden_size)

        # Test model load
        model.load_weights(rng)

        # Test model forward
        cache_dtype = (jnp.float8_e4m3fn
                       if mock_vllm_config.cache_config.cache_dtype == "fp8"
                       else jnp.bfloat16)
        kv_caches = create_kv_caches(
            num_blocks=4,
            block_size=32,
            num_kv_heads=num_kv_heads,
            head_size=head_dim,
            mesh=mesh,
            layer_names=["layer"] * hf_config.num_hidden_layers,
            cache_dtype=cache_dtype)

        # One request; the token and sample counts come from the fixture
        # rather than being hard-coded here.
        input_ids, attention_metadata, indices_do_sample = mock_model_inputs
        num_tokens = input_ids.shape[0]
        num_sampled = indices_do_sample.shape[0]

        kv_caches, hidden_states, aux_hidden_states = model(
            kv_caches, input_ids, attention_metadata)
        assert hidden_states.shape == (num_tokens, hidden_size)
        assert len(aux_hidden_states) == 0

        hidden_states = hidden_states[indices_do_sample]
        assert hidden_states.shape == (num_sampled, hidden_size)

        logits = model.compute_logits(hidden_states)
        assert logits.shape == (num_sampled, hf_config.vocab_size)
@@ -0,0 +1,180 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # Test for LoRA weight loading API
3
+
4
+ import os
5
+ import tempfile
6
+ from dataclasses import dataclass
7
+ from typing import Any
8
+
9
+ import jax
10
+ import jax.numpy as jnp
11
+ import numpy as np
12
+ from flax import nnx
13
+ from jax._src import test_util as jtu
14
+ from jax.sharding import Mesh
15
+ from safetensors.numpy import save_file
16
+
17
+ from tpu_inference.models.jax.utils.weight_utils import (
18
+ MetadataMap, load_hf_weights, transfer_state_with_mappings)
19
+
20
+ # ----- nnx.Module Wrappers -----
21
+
22
+
23
class SourceLayer(nnx.Module):
    """Source-side layer with a random (4, 4) kernel and a (4,) bias."""

    def __init__(self, rngs):
        kernel_key = rngs()
        self.kernel = nnx.Param(jax.random.normal(kernel_key, (4, 4)))
        bias_key = rngs()
        self.bias = nnx.Param(jax.random.normal(bias_key, (4, )))
28
+
29
+
30
class SourceModel(nnx.Module):
    """Source model: random (2, 4) ``src_lm_head`` and ``layers[0]`` layer."""

    def __init__(self, rngs):
        head_init = jax.random.normal(rngs(), (2, 4))
        self.src_lm_head = nnx.Param(head_init)
        self.layers = {0: SourceLayer(rngs)}
35
+
36
+
37
class TargetLinear(nnx.Module):
    """Zero-initialised target layer with a (4, 4) kernel and (4,) bias.

    ``rngs`` is accepted for signature parity with the other modules but is
    not used.
    """

    def __init__(self, rngs):
        zero_kernel = jnp.zeros((4, 4))
        zero_bias = jnp.zeros((4, ))
        self.kernel = nnx.Param(zero_kernel)
        self.bias = nnx.Param(zero_bias)
42
+
43
+
44
class TargetBlock(nnx.Module):
    """Target block exposing one ``TargetLinear`` as ``mlp["up_proj"]``."""

    def __init__(self, rngs):
        up_proj = TargetLinear(rngs)
        self.mlp = {"up_proj": up_proj}
48
+
49
+
50
class TargetModel(nnx.Module):
    """Target model: zero (2, 4) ``tgt_lm_head`` and ``model["layers"][0]``."""

    def __init__(self, rngs):
        self.tgt_lm_head = nnx.Param(jnp.zeros((2, 4)))
        first_block = TargetBlock(rngs)
        self.model = {"layers": {0: first_block}}
55
+
56
+
57
+ # ----- Test -----
58
class WeightTransfer(jtu.JaxTestCase):
    """Checks that ``transfer_state_with_mappings`` follows wildcard mappings."""

    def test_transfer_state(self):
        rngs = nnx.Rngs(0)
        src_model = SourceModel(rngs)
        tgt_model = TargetModel(rngs)

        # Only the State half of each split is needed.
        src_state = nnx.split(src_model)[1]
        tgt_state = nnx.split(tgt_model)[1]

        # Seed the source with recognisable constants.
        src_state["layers"][0]["kernel"].value = jnp.ones((4, 4)) * 42.0
        src_state["layers"][0]["bias"].value = jnp.ones((4, )) * 7.0
        src_state["src_lm_head"].value = jnp.ones((2, 4)) * 6.0

        # source key -> (target key, per-axis tuple); "*" stands for the layer
        # index. NOTE(review): the tuple is presumably a sharding spec; confirm
        # against transfer_state_with_mappings.
        mappings = {
            "layers.*.kernel": ("model.layers.*.mlp.up_proj.kernel", (None, )),
            "layers.*.bias": ("model.layers.*.mlp.up_proj.bias", (None, )),
            "src_lm_head": ("tgt_lm_head", (None, None)),
        }

        new_tgt_state = transfer_state_with_mappings(src_state, tgt_state,
                                                     mappings)

        up_proj = new_tgt_state["model"]["layers"][0]["mlp"]["up_proj"]
        assert jnp.allclose(up_proj["kernel"].value, 42.0)
        assert jnp.allclose(up_proj["bias"].value, 7.0)
        assert jnp.allclose(new_tgt_state["tgt_lm_head"].value, 6.0)
92
+
93
+
94
+ # ----- Mocks for dtype test -----
95
+
96
+
97
class DtypeTestModel(nnx.Module):
    """Two zero-filled (2, 2) params of ``dtype``; ``rngs`` is unused."""

    def __init__(self, dtype: jnp.dtype, rngs: nnx.Rngs):
        shape = (2, 2)
        self.weight_to_cast = nnx.Param(jnp.zeros(shape, dtype=dtype))
        self.weight_to_keep = nnx.Param(jnp.zeros(shape, dtype=dtype))
102
+
103
+
104
@dataclass
class MockModelConfig:
    """Minimal stub of a model config for the weight-loading test.

    Every size getter returns 1; ``hf_config`` defaults to None and
    ``is_multimodal_model`` to False.
    """

    model: str
    dtype: jnp.dtype
    hf_config: Any = None
    is_multimodal_model: bool = False

    def get_vocab_size(self):
        return 1

    def get_hidden_size(self):
        return 1

    def get_head_size(self):
        return 1
120
+
121
+
122
@dataclass
class MockLoadConfig:
    """Stub load config carrying only the checkpoint ``download_dir``."""

    download_dir: str
125
+
126
+
127
@dataclass
class MockVllmConfig:
    """Stub VllmConfig bundling the mock model and load configs.

    ``speculative_config`` defaults to None.
    """

    model_config: MockModelConfig
    load_config: MockLoadConfig
    speculative_config: Any = None
132
+
133
+
134
class WeightLoadingDtypeTest(jtu.JaxTestCase):
    """Checks ``keep_original_dtype_keys_regex`` handling in ``load_hf_weights``.

    A float32 safetensors checkpoint is loaded into a bfloat16 model. The
    weight matched by the regex must stay float32; the other must take the
    model dtype.
    """

    def setUp(self):
        super().setUp()
        tmp = tempfile.TemporaryDirectory()
        self.tempdir = tmp
        self.addCleanup(tmp.cleanup)

        # Write a tiny float32 checkpoint with one entry per model weight.
        checkpoint = {
            f"{name}.weight": np.ones((2, 2), dtype=np.float32)
            for name in ("weight_to_cast", "weight_to_keep")
        }
        self.safetensors_path = os.path.join(tmp.name, "model.safetensors")
        save_file(checkpoint, self.safetensors_path)

    def test_keep_original_dtype(self):
        model_dtype = jnp.bfloat16
        model = DtypeTestModel(dtype=model_dtype, rngs=nnx.Rngs(0))

        vllm_config = MockVllmConfig(
            model_config=MockModelConfig(model=self.tempdir.name,
                                         dtype=model_dtype),
            load_config=MockLoadConfig(download_dir=self.tempdir.name),
        )
        mesh = Mesh(jax.devices(), ("model", ))

        weight_names = ("weight_to_cast", "weight_to_keep")
        metadata_map = MetadataMap(name_map={n: n for n in weight_names})

        load_hf_weights(
            vllm_config=vllm_config,
            model=model,
            metadata_map=metadata_map,
            mesh=mesh,
            keep_original_dtype_keys_regex=[r"weight_to_keep.*"],
        )

        self.assertEqual(model.weight_to_cast.value.dtype, model_dtype)
        self.assertEqual(model.weight_to_keep.value.dtype, jnp.float32)
@@ -0,0 +1,13 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
@@ -0,0 +1,212 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # test_multi_modal_utils.py
16
+ import jax.numpy as jnp
17
+ import numpy as np
18
+ import pytest
19
+
20
+ from tpu_inference.models.jax.utils.multi_modal_utils import (
21
+ MultiModalEmbeddings, NestedTensors, flatten_embeddings,
22
+ merge_multimodal_embeddings, sanity_check_mm_encoder_outputs)
23
+
24
+ # --- Tests for sanity_check_mm_encoder_outputs ---
25
+
26
+
27
def test_sanity_check_valid_list():
    """A list of two 2D embeddings passes when two items are expected."""
    first, second = jnp.ones((10, 128)), jnp.ones((15, 128))
    embeddings: MultiModalEmbeddings = [first, second]
    # Success means the call returns without raising AssertionError.
    sanity_check_mm_encoder_outputs(embeddings, 2)
34
+
35
+
36
def test_sanity_check_valid_tuple():
    """A tuple of two 2D embeddings passes when two items are expected."""
    embeddings: MultiModalEmbeddings = tuple(
        jnp.ones((rows, 128)) for rows in (10, 15))
    sanity_check_mm_encoder_outputs(embeddings, 2)  # must not raise
42
+
43
+
44
def test_sanity_check_valid_3d_jax_array():
    """A single (2, 10, 128) array passes as two embeddings.

    Iterating the array yields its leading-axis slices, and each slice is 2D,
    which satisfies the per-item check.
    """
    embeddings: MultiModalEmbeddings = jnp.ones((2, 10, 128))
    sanity_check_mm_encoder_outputs(embeddings, 2)  # must not raise
51
+
52
+
53
def test_sanity_check_invalid_type():
    """A plain string is rejected by the container-type assertion."""
    expected_msg = (
        "Expected multimodal embeddings to be a list/tuple of 2D tensors")
    with pytest.raises(AssertionError, match=expected_msg):
        sanity_check_mm_encoder_outputs("not a tensor", 1)
60
+
61
+
62
def test_sanity_check_wrong_num_items():
    """Supplying one embedding when two are expected triggers an assertion."""
    embeddings: MultiModalEmbeddings = [jnp.ones((10, 128))]
    expected_msg = "Expected number of multimodal embeddings to match number of"
    with pytest.raises(AssertionError, match=expected_msg):
        sanity_check_mm_encoder_outputs(embeddings, 2)
70
+
71
+
72
def test_sanity_check_wrong_dimensions_in_list():
    """A 3D tensor inside the list is rejected by the per-item rank check."""
    bad_embedding = jnp.ones((10, 128, 1))
    expected_msg = (
        "Expected multimodal embeddings to be a sequence of 2D tensors")
    with pytest.raises(AssertionError, match=expected_msg):
        sanity_check_mm_encoder_outputs([bad_embedding], 1)
80
+
81
+
82
+ # --- Tests for flatten_embeddings ---
83
+
84
+
85
def test_flatten_single_array():
    """A 2D array comes back from flatten_embeddings unchanged."""
    emb: NestedTensors = jnp.arange(3 * 4).reshape((3, 4))
    np.testing.assert_array_equal(flatten_embeddings(emb), emb)
90
+
91
+
92
def test_flatten_single_3d_array():
    """A (2, 3, 4) array has its leading axes collapsed into (6, 4)."""
    values = jnp.arange(24)
    emb: NestedTensors = values.reshape((2, 3, 4))
    np.testing.assert_array_equal(flatten_embeddings(emb),
                                  values.reshape((6, 4)))
98
+
99
+
100
def test_flatten_list_of_arrays():
    """A list of 2D arrays is concatenated row-wise, preserving order."""
    emb: NestedTensors = [
        jnp.arange(start, stop).reshape((-1, 4))
        for start, stop in ((0, 12), (12, 20))
    ]
    expected = jnp.arange(20).reshape((5, 4))
    np.testing.assert_array_equal(flatten_embeddings(emb), expected)
109
+
110
+
111
def test_flatten_nested_list():
    """Nested lists are flattened and concatenated in their original order."""
    inner = [
        jnp.arange(6, 12).reshape((2, 3)),
        jnp.arange(12, 15).reshape((1, 3)),
    ]
    emb: NestedTensors = [jnp.arange(6).reshape((2, 3)), inner]
    expected = jnp.arange(15).reshape((5, 3))
    np.testing.assert_array_equal(flatten_embeddings(emb), expected)
123
+
124
+
125
+ # --- Tests for merge_multimodal_embeddings ---
126
+
127
# Width of every embedding row used by the merge tests below.
EMBED_DIM = 4


@pytest.fixture
def base_embeds():
    """Return an all-zero (8, EMBED_DIM) text-embedding buffer."""
    shape = (8, EMBED_DIM)
    return jnp.zeros(shape)
133
+
134
+
135
def test_merge_single_placeholder(base_embeds):
    """Rows whose token id equals the integer placeholder get mm embeddings."""
    placeholder = -1
    input_ids = jnp.array(
        [1, 2, placeholder, placeholder, 3, 4, placeholder, 5])
    inputs_embeds = base_embeds[:input_ids.shape[0]]
    mm_embeds: NestedTensors = jnp.arange(3 * EMBED_DIM).reshape(
        (3, EMBED_DIM))

    result = merge_multimodal_embeddings(input_ids,
                                         inputs_embeds,
                                         mm_embeds,
                                         placeholder_token_id=placeholder)

    # Expected: placeholder rows replaced by mm_embeds in order, rest intact.
    expected = np.array(inputs_embeds)
    expected[input_ids == placeholder] = mm_embeds
    np.testing.assert_array_equal(result, expected)
148
+
149
+
150
def test_merge_no_placeholders(base_embeds):
    """Merging an empty mm_embeds into ids with no placeholder tokens raises.

    With zero placeholder positions and an empty ``(0, EMBED_DIM)``
    ``mm_embeds``, the call is expected to fail with a ``TypeError`` raised
    from JAX's gather op. The ``match`` pattern pins a JAX-internal error
    message, so this test may need updating if JAX rewords that error.
    """
    input_ids = jnp.array([1, 2, 3, 4])
    inputs_embeds = jnp.arange(len(input_ids) * EMBED_DIM).reshape(
        (len(input_ids), EMBED_DIM))
    mm_embeds: NestedTensors = jnp.empty((0, EMBED_DIM))

    with pytest.raises(
            TypeError,
            match="Slice size at index 0 in gather op is out of range"):
        merge_multimodal_embeddings(input_ids,
                                    inputs_embeds,
                                    mm_embeds,
                                    placeholder_token_id=-1)
165
+
166
+
167
@pytest.mark.parametrize("placeholder_id", [-1, [-1, -2]])
def test_merge_mm_embeds_count_too_few(placeholder_id, base_embeds):
    """Fewer mm embeddings than placeholder tokens does not raise.

    Only the absence of an exception is checked. The merged values in this
    under-supplied case are unspecified and deliberately not asserted.
    """
    input_ids = jnp.array([1, 2, -1, -1, 3])  # 2 placeholders
    inputs_embeds = base_embeds[:len(input_ids)]
    mm_embeds_too_few: NestedTensors = jnp.ones((1, EMBED_DIM))

    # No try/except: pytest already fails the test on any uncaught exception
    # and reports the original traceback, which pytest.fail would hide.
    merge_multimodal_embeddings(input_ids,
                                inputs_embeds,
                                mm_embeds_too_few,
                                placeholder_token_id=placeholder_id)
189
+
190
+
191
@pytest.mark.parametrize("placeholder_id", [-1, [-1, -2]])
def test_merge_mm_embeds_count_too_many_no_raise(placeholder_id, base_embeds):
    """Surplus mm embeddings are ignored: only the first N are merged.

    N is the number of placeholder tokens in ``input_ids`` (two here). Any
    exception or mismatch fails the test directly with its own traceback or
    numpy diff, rather than being rewrapped by ``pytest.fail``.
    """
    input_ids = jnp.array([1, 2, -1, -1, 3])  # 2 placeholders
    inputs_embeds = base_embeds[:len(input_ids)]
    mm_embeds_too_many: NestedTensors = jnp.arange(3 * EMBED_DIM).reshape(
        (3, EMBED_DIM))

    result = merge_multimodal_embeddings(input_ids,
                                         inputs_embeds,
                                         mm_embeds_too_many,
                                         placeholder_token_id=placeholder_id)

    # Expected: placeholder rows receive the leading mm embeddings, in order.
    expected = np.array(inputs_embeds)
    is_mm = np.isin(input_ids, np.array(placeholder_id))
    num_placeholders = int(is_mm.sum())
    expected[is_mm] = flatten_embeddings(
        mm_embeds_too_many)[:num_placeholders]
    np.testing.assert_array_equal(result, expected)
@@ -0,0 +1,13 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
@@ -0,0 +1,54 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from unittest.mock import MagicMock, patch
16
+
17
+ import pytest
18
+ import torch
19
+ from vllm.config import CacheConfig, VllmConfig
20
+
21
+ from tpu_inference.platforms.tpu_platform import TpuPlatform
22
+
23
+
24
class TestTpuPlatform:
    """Tests for platform-level helpers exposed by ``TpuPlatform``."""

    @pytest.fixture
    def vllm_config(self):
        """Build a ``VllmConfig`` mock carrying a real fp8 ``CacheConfig``.

        Only the sections assigned below are populated; the remaining config
        sections are bare ``MagicMock`` objects, and ``kv_transfer_config``
        is ``None``.
        NOTE(review): no test in this class currently requests this fixture.
        """
        fp8_cache = CacheConfig(block_size=16,
                                gpu_memory_utilization=0.9,
                                swap_space=4,
                                cache_dtype="fp8")

        config = MagicMock(spec=VllmConfig)
        config.cache_config = fp8_cache
        config.model_config = MagicMock(dtype='bfloat16')
        config.scheduler_config = MagicMock(is_multimodal_model=False)
        config.parallel_config = MagicMock()
        config.compilation_config = MagicMock(mode="dynamo_trace_once",
                                              backend="openxla")
        config.kv_transfer_config = None
        return config

    @pytest.mark.parametrize("chip_name,expected_dtype", [
        ("v6e", torch.float8_e5m2),
        ("v5e", torch.float8_e4m3fn),
    ])
    def test_fp8_dtype(self, chip_name, expected_dtype):
        """``fp8_dtype`` must return the fp8 variant mapped to the local chip."""
        fake_chip = MagicMock()
        # ``name`` is consumed by the MagicMock constructor itself, so the
        # attribute has to be assigned after construction.
        fake_chip.name = chip_name

        with patch('tpu_inference.platforms.tpu_platform.init_logger'):
            with patch(
                    'tpu_inference.platforms.tpu_platform.device.get_local_chips',
                    return_value=(fake_chip, None)):
                with patch('vllm.envs.VLLM_TPU_USING_PATHWAYS', False):
                    assert TpuPlatform.fp8_dtype() == expected_dtype
@@ -0,0 +1,13 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.