tpu-inference 0.12.0.dev20251213__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic. (The original page links to more details here; the link target is not preserved in this text.)

Files changed (248)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +14 -0
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +14 -0
  31. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  32. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_test.py +14 -0
  35. tests/layers/__init__.py +13 -0
  36. tests/layers/common/__init__.py +13 -0
  37. tests/layers/common/test_attention_interface.py +156 -0
  38. tests/layers/common/test_quantization.py +149 -0
  39. tests/layers/jax/__init__.py +13 -0
  40. tests/layers/jax/attention/__init__.py +13 -0
  41. tests/layers/jax/attention/test_common_attention.py +103 -0
  42. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  43. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  44. tests/layers/jax/moe/__init__.py +13 -0
  45. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  46. tests/layers/jax/sample/__init__.py +13 -0
  47. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  48. tests/layers/jax/sample/test_sampling.py +115 -0
  49. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  50. tests/layers/jax/test_layers.py +155 -0
  51. tests/{test_quantization.py → layers/jax/test_qwix.py} +180 -50
  52. tests/layers/jax/test_rope.py +93 -0
  53. tests/layers/jax/test_sharding.py +159 -0
  54. tests/layers/jax/test_transformer_block.py +152 -0
  55. tests/layers/vllm/__init__.py +13 -0
  56. tests/layers/vllm/test_attention.py +363 -0
  57. tests/layers/vllm/test_awq.py +406 -0
  58. tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
  59. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
  61. tests/layers/vllm/test_fp8.py +17 -0
  62. tests/layers/vllm/test_mxfp4.py +320 -0
  63. tests/layers/vllm/test_unquantized.py +662 -0
  64. tests/layers/vllm/utils.py +87 -0
  65. tests/lora/__init__.py +13 -0
  66. tests/lora/conftest.py +14 -0
  67. tests/lora/test_bgmv.py +14 -0
  68. tests/lora/test_layers.py +25 -8
  69. tests/lora/test_lora.py +15 -1
  70. tests/lora/test_lora_perf.py +14 -0
  71. tests/models/__init__.py +13 -0
  72. tests/models/common/__init__.py +13 -0
  73. tests/models/common/test_model_loader.py +455 -0
  74. tests/models/jax/__init__.py +13 -0
  75. tests/models/jax/test_deepseek_v3.py +401 -0
  76. tests/models/jax/test_llama3.py +184 -0
  77. tests/models/jax/test_llama4.py +298 -0
  78. tests/models/jax/test_llama_eagle3.py +197 -0
  79. tests/models/jax/test_llama_guard_4.py +242 -0
  80. tests/models/jax/test_qwen2.py +172 -0
  81. tests/models/jax/test_qwen2_5_vl.py +605 -0
  82. tests/models/jax/test_qwen3.py +169 -0
  83. tests/models/jax/test_weight_loading.py +180 -0
  84. tests/models/jax/utils/__init__.py +13 -0
  85. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  86. tests/platforms/__init__.py +13 -0
  87. tests/platforms/test_tpu_platform.py +54 -0
  88. tests/runner/__init__.py +13 -0
  89. tests/runner/test_block_table.py +395 -0
  90. tests/runner/test_input_batch.py +226 -0
  91. tests/runner/test_kv_cache.py +220 -0
  92. tests/runner/test_kv_cache_manager.py +498 -0
  93. tests/runner/test_multimodal_manager.py +429 -0
  94. tests/runner/test_persistent_batch_manager.py +84 -0
  95. tests/runner/test_speculative_decoding_manager.py +368 -0
  96. tests/runner/test_structured_decoding_manager.py +220 -0
  97. tests/runner/test_tpu_runner.py +261 -0
  98. tests/runner/test_tpu_runner_dp.py +1099 -0
  99. tests/runner/test_tpu_runner_mesh.py +200 -0
  100. tests/runner/test_utils.py +411 -0
  101. tests/spec_decode/__init__.py +13 -0
  102. tests/spec_decode/test_eagle3.py +311 -0
  103. tests/test_base.py +14 -0
  104. tests/test_tpu_info.py +14 -0
  105. tests/test_utils.py +1 -43
  106. tests/worker/__init__.py +13 -0
  107. tests/worker/tpu_worker_test.py +414 -0
  108. tpu_inference/__init__.py +14 -0
  109. tpu_inference/core/__init__.py +13 -0
  110. tpu_inference/core/sched/__init__.py +13 -0
  111. tpu_inference/core/sched/dp_scheduler.py +372 -56
  112. tpu_inference/distributed/__init__.py +13 -0
  113. tpu_inference/distributed/jax_parallel_state.py +14 -0
  114. tpu_inference/distributed/tpu_connector.py +14 -9
  115. tpu_inference/distributed/utils.py +56 -4
  116. tpu_inference/executors/__init__.py +13 -0
  117. tpu_inference/executors/ray_distributed_executor.py +20 -3
  118. tpu_inference/experimental/__init__.py +13 -0
  119. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  120. tpu_inference/kernels/__init__.py +13 -0
  121. tpu_inference/kernels/collectives/__init__.py +13 -0
  122. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  123. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  124. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  125. tpu_inference/kernels/fused_moe/v1/kernel.py +171 -163
  126. tpu_inference/kernels/megablox/__init__.py +13 -0
  127. tpu_inference/kernels/megablox/common.py +54 -0
  128. tpu_inference/kernels/megablox/gmm.py +646 -0
  129. tpu_inference/kernels/mla/__init__.py +13 -0
  130. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  131. tpu_inference/kernels/mla/v1/kernel.py +20 -26
  132. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  133. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  134. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  135. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  136. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +112 -69
  137. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +85 -65
  138. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  139. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +374 -194
  140. tpu_inference/kernels/ragged_paged_attention/v3/util.py +13 -0
  141. tpu_inference/layers/__init__.py +13 -0
  142. tpu_inference/layers/common/__init__.py +13 -0
  143. tpu_inference/layers/common/attention_interface.py +26 -19
  144. tpu_inference/layers/common/attention_metadata.py +14 -0
  145. tpu_inference/layers/common/fused_moe_gmm.py +506 -0
  146. tpu_inference/layers/common/quant_methods.py +15 -0
  147. tpu_inference/layers/common/quantization.py +282 -0
  148. tpu_inference/layers/common/sharding.py +22 -3
  149. tpu_inference/layers/common/utils.py +94 -0
  150. tpu_inference/layers/jax/__init__.py +13 -0
  151. tpu_inference/layers/jax/attention/__init__.py +13 -0
  152. tpu_inference/layers/jax/attention/attention.py +19 -6
  153. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +52 -27
  154. tpu_inference/layers/jax/attention/gpt_oss_attention.py +19 -6
  155. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  156. tpu_inference/layers/jax/base.py +14 -0
  157. tpu_inference/layers/jax/constants.py +13 -0
  158. tpu_inference/layers/jax/layers.py +14 -0
  159. tpu_inference/layers/jax/misc.py +14 -0
  160. tpu_inference/layers/jax/moe/__init__.py +13 -0
  161. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  162. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  163. tpu_inference/layers/jax/moe/moe.py +43 -3
  164. tpu_inference/layers/jax/pp_utils.py +53 -0
  165. tpu_inference/layers/jax/rope.py +14 -0
  166. tpu_inference/layers/jax/rope_interface.py +14 -0
  167. tpu_inference/layers/jax/sample/__init__.py +13 -0
  168. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  169. tpu_inference/layers/jax/sample/sampling.py +15 -1
  170. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  171. tpu_inference/layers/jax/transformer_block.py +14 -0
  172. tpu_inference/layers/vllm/__init__.py +13 -0
  173. tpu_inference/layers/vllm/attention.py +4 -4
  174. tpu_inference/layers/vllm/fused_moe.py +100 -455
  175. tpu_inference/layers/vllm/linear.py +64 -0
  176. tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
  177. tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
  178. tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
  179. tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
  180. tpu_inference/layers/vllm/quantization/__init__.py +19 -3
  181. tpu_inference/layers/vllm/quantization/awq.py +96 -82
  182. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  183. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +19 -5
  184. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +119 -132
  185. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  186. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
  187. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
  188. tpu_inference/layers/vllm/quantization/{common.py → configs.py} +38 -26
  189. tpu_inference/layers/vllm/quantization/fp8.py +119 -0
  190. tpu_inference/layers/vllm/quantization/mxfp4.py +133 -220
  191. tpu_inference/layers/vllm/quantization/unquantized.py +154 -253
  192. tpu_inference/lora/__init__.py +13 -0
  193. tpu_inference/lora/torch_lora_ops.py +8 -13
  194. tpu_inference/models/__init__.py +13 -0
  195. tpu_inference/models/common/__init__.py +13 -0
  196. tpu_inference/models/common/model_loader.py +37 -16
  197. tpu_inference/models/jax/__init__.py +13 -0
  198. tpu_inference/models/jax/deepseek_v3.py +113 -124
  199. tpu_inference/models/jax/gpt_oss.py +23 -7
  200. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  201. tpu_inference/models/jax/llama3.py +99 -36
  202. tpu_inference/models/jax/llama4.py +14 -0
  203. tpu_inference/models/jax/llama_eagle3.py +14 -0
  204. tpu_inference/models/jax/llama_guard_4.py +15 -1
  205. tpu_inference/models/jax/qwen2.py +17 -2
  206. tpu_inference/models/jax/qwen2_5_vl.py +18 -4
  207. tpu_inference/models/jax/qwen3.py +17 -2
  208. tpu_inference/models/jax/utils/__init__.py +13 -0
  209. tpu_inference/models/jax/utils/file_utils.py +14 -0
  210. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  211. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  212. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +85 -24
  213. tpu_inference/models/jax/utils/weight_utils.py +32 -1
  214. tpu_inference/models/vllm/__init__.py +13 -0
  215. tpu_inference/models/vllm/vllm_model_wrapper.py +22 -4
  216. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  217. tpu_inference/platforms/__init__.py +14 -0
  218. tpu_inference/platforms/tpu_platform.py +27 -29
  219. tpu_inference/runner/__init__.py +13 -0
  220. tpu_inference/runner/compilation_manager.py +69 -35
  221. tpu_inference/runner/kv_cache.py +14 -0
  222. tpu_inference/runner/kv_cache_manager.py +15 -2
  223. tpu_inference/runner/lora_utils.py +16 -1
  224. tpu_inference/runner/multimodal_manager.py +16 -2
  225. tpu_inference/runner/persistent_batch_manager.py +14 -0
  226. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  227. tpu_inference/runner/structured_decoding_manager.py +14 -0
  228. tpu_inference/runner/tpu_runner.py +30 -10
  229. tpu_inference/spec_decode/__init__.py +13 -0
  230. tpu_inference/spec_decode/jax/__init__.py +13 -0
  231. tpu_inference/spec_decode/jax/eagle3.py +13 -0
  232. tpu_inference/tpu_info.py +14 -0
  233. tpu_inference/utils.py +31 -30
  234. tpu_inference/worker/__init__.py +13 -0
  235. tpu_inference/worker/tpu_worker.py +23 -7
  236. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +1 -1
  237. tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
  238. tpu_inference/layers/vllm/linear_common.py +0 -208
  239. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  240. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  241. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  242. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  243. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  244. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  245. tpu_inference-0.12.0.dev20251213.dist-info/RECORD +0 -175
  246. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
  247. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
  248. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
@@ -1,3 +1,17 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
1
15
  from typing import Any, List
2
16
 
3
17
  import jax
@@ -1,5 +1,19 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
1
15
  import functools
2
- from typing import TYPE_CHECKING, Dict, List
16
+ from typing import TYPE_CHECKING, List
3
17
 
4
18
  import jax
5
19
  import jax.numpy as jnp
@@ -198,7 +212,6 @@ class KVCacheManager:
198
212
  # uniform page size.
199
213
  representative_spec = kv_cache_config.kv_cache_groups[0].kv_cache_spec
200
214
  page_size_bytes = representative_spec.page_size_bytes
201
- self.runner.layer_name_to_kvcache_index: Dict[str, int] = {}
202
215
  kv_caches = self.runner.kv_caches
203
216
  num_blocks_list = []
204
217
  for i, kv_cache_tensor in enumerate(kv_cache_config.kv_cache_tensors):
@@ -1,3 +1,17 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
1
15
  from __future__ import annotations
2
16
 
3
17
  from typing import TYPE_CHECKING
@@ -7,7 +21,8 @@ from torchax.interop import jax_view
7
21
  from vllm.lora.layers.base_linear import BaseLinearLayerWithLoRA
8
22
  from vllm.lora.request import LoRARequest
9
23
 
10
- from tpu_inference.layers.vllm.sharding import update_lora
24
+ from tpu_inference.layers.vllm.process_weights.cleanup_sharding import \
25
+ update_lora
11
26
 
12
27
  if TYPE_CHECKING:
13
28
  from tpu_inference.runner.tpu_runner import TPUModelRunner
@@ -1,3 +1,17 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
1
15
  from typing import TYPE_CHECKING
2
16
 
3
17
  import jax
@@ -98,7 +112,7 @@ class MultiModalManager:
98
112
  # encoder outputs.
99
113
  encoder_outputs = []
100
114
  for _, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
101
- mm_kwargs, merge_by_field_config=False):
115
+ mm_kwargs):
102
116
  batched_mm_inputs = mm_kwargs_group
103
117
  # Convert torch tensors to numpy arrays that JAX can handle.
104
118
  if "pixel_values" in batched_mm_inputs and isinstance(
@@ -134,7 +148,7 @@ class MultiModalManager:
134
148
  # 2. A list or tuple (length: num_items) of tensors, each of shape
135
149
  # (feature_size, hidden_size) in case the feature size is dynamic
136
150
  # depending on the input multimodal items.
137
- curr_group_outputs = self.runner.get_multimodal_embeddings_fn(
151
+ curr_group_outputs = self.runner.embed_multimodal_fn(
138
152
  self.runner.state, image_grid_thw, **batched_mm_inputs)
139
153
 
140
154
  sanity_check_mm_encoder_outputs(
@@ -1,3 +1,17 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
1
15
  from typing import Dict
2
16
 
3
17
  import jax
@@ -1,3 +1,17 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
1
15
  from __future__ import annotations
2
16
 
3
17
  from dataclasses import dataclass
@@ -1,3 +1,17 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
1
15
  import functools
2
16
  from typing import TYPE_CHECKING, Tuple
3
17
 
@@ -1,3 +1,17 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
1
15
  import copy
2
16
  import functools
3
17
  import logging
@@ -268,6 +282,9 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
268
282
  self._substitute_placeholder_token_fn = _substitute_placeholder_token
269
283
  self.execute_model_state: ExecuteModelState | None = None
270
284
 
285
+ self.kv_caches: list[jax.Array] = []
286
+ self.layer_name_to_kvcache_index: dict[str, int] = {}
287
+
271
288
  def _init_random(self):
272
289
  if self.model_config.seed is None:
273
290
  self.model_config.seed = 0
@@ -494,10 +511,10 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
494
511
  multimodal_fns = multimodal_fns or {}
495
512
  self.precompile_vision_encoder_fn = multimodal_fns.get(
496
513
  "precompile_vision_encoder_fn", None)
497
- self.get_multimodal_embeddings_fn = multimodal_fns.get(
498
- "get_multimodal_embeddings_fn", None)
499
- self.get_input_embeddings_fn = multimodal_fns.get(
500
- "get_input_embeddings_fn", None)
514
+ self.embed_multimodal_fn = multimodal_fns.get("embed_multimodal_fn",
515
+ None)
516
+ self.embed_input_ids_fn = multimodal_fns.get("embed_input_ids_fn",
517
+ None)
501
518
  self.get_mrope_input_positions_fn = multimodal_fns.get(
502
519
  "get_mrope_input_positions_fn", None)
503
520
 
@@ -509,7 +526,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
509
526
  jax.random.key(self.model_config.seed)).params()
510
527
  self.is_multimodal_model = (
511
528
  self.model_config.is_multimodal_model
512
- and self.get_multimodal_embeddings_fn is not None and hasattr(
529
+ and self.embed_multimodal_fn is not None and hasattr(
513
530
  self.model_config.hf_config, "architectures"
514
531
  ) #TODO: Remove Llama Guard 4 specific condition once the LG4 Vision portion is implemented
515
532
  and len(self.model_config.hf_config.architectures) >= 1
@@ -525,10 +542,12 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
525
542
  def get_kv_cache_spec(self):
526
543
  return self.kv_cache_manager.get_kv_cache_spec()
527
544
 
528
- def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
545
+ def initialize_kv_cache(self,
546
+ kv_cache_config: KVCacheConfig,
547
+ topology_order_id: int = 0) -> None:
548
+ self.topology_order_id = topology_order_id
529
549
  self.kv_cache_config = kv_cache_config
530
550
  self.use_hybrid_kvcache = len(kv_cache_config.kv_cache_groups) > 1
531
- self.kv_caches = []
532
551
  self.kv_cache_manager.initialize_kv_cache(kv_cache_config)
533
552
  if has_kv_transfer_group():
534
553
  get_kv_transfer_group().register_runner(self)
@@ -810,7 +829,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
810
829
  sharding = None
811
830
  if self.dp_size > 1:
812
831
  sharding = NamedSharding(self.mesh,
813
- PartitionSpec(ShardingAxisName.ATTN_DATA))
832
+ PartitionSpec(ShardingAxisName.MLP_DATA))
814
833
 
815
834
  tpu_sampling_metadata = TPUSupportedSamplingMetadata.from_input_batch(
816
835
  self.mesh, self.input_batch, padded_num_reqs, sharding=sharding)
@@ -1373,7 +1392,8 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
1373
1392
  self.mesh,
1374
1393
  self.input_batch,
1375
1394
  padded_num_reqs,
1376
- sharding=data_parallel_attn_sharding,
1395
+ sharding=NamedSharding(self.mesh,
1396
+ PartitionSpec(ShardingAxisName.MLP_DATA)),
1377
1397
  )
1378
1398
  if self.uses_mrope:
1379
1399
  positions = mrope_positions
@@ -1663,7 +1683,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
1663
1683
  def _get_input_ids_embeds(self, input_ids: jax.Array,
1664
1684
  mm_embeds: list[jax.Array]):
1665
1685
  if self.is_multimodal_model:
1666
- inputs_embeds = self.get_input_embeddings_fn(
1686
+ inputs_embeds = self.embed_input_ids_fn(
1667
1687
  self.state,
1668
1688
  input_ids,
1669
1689
  mm_embeds,
@@ -0,0 +1,13 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
@@ -0,0 +1,13 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
@@ -1,3 +1,16 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
1
14
  """Implements the Eagle3 proposer for speculative decoding on JAX/TPU."""
2
15
  import functools
3
16
  from dataclasses import replace
tpu_inference/tpu_info.py CHANGED
@@ -1,3 +1,17 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
1
15
  import glob
2
16
  import os
3
17
 
tpu_inference/utils.py CHANGED
@@ -3,7 +3,7 @@ import time
3
3
  from collections import defaultdict
4
4
  from collections.abc import Sequence
5
5
  from functools import wraps
6
- from typing import Any, Callable, List, Tuple
6
+ from typing import Any, Callable, List, Tuple, Union
7
7
 
8
8
  import jax
9
9
  import jax.numpy as jnp
@@ -283,35 +283,6 @@ def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], bytes]:
283
283
  return utils.hashing.get_hash_fn_by_name(hash_fn_name)
284
284
 
285
285
 
286
- def quantize_kv(key: jax.Array, value: jax.Array,
287
- kv_cache_quantized_dtype: jnp.dtype, k_scale: float,
288
- v_scale: float) -> Tuple[jax.Array, jax.Array]:
289
- """
290
- Quantize the key and value tensors.
291
-
292
- Args:
293
- key: The key tensor to quantize.
294
- value: The value tensor to quantize.
295
- kv_cache_quantized_dtype: The dtype to quantize the key and value tensors to.
296
- q_scale: The scale to quantize the key and value tensors by.
297
- k_scale: The scale to quantize the key tensor by.
298
- v_scale: The scale to quantize the value tensor by.
299
-
300
- Returns:
301
- Tuple[jax.Array, jax.Array]: The quantized key and value tensors.
302
- """
303
- dtype_info = jnp.finfo(kv_cache_quantized_dtype)
304
- minval, maxval = float(dtype_info.min), float(dtype_info.max)
305
- key = key.astype(jnp.float32) / k_scale
306
- key = jnp.clip(key, minval, maxval)
307
- key = key.astype(kv_cache_quantized_dtype)
308
- value = value.astype(jnp.float32) / v_scale
309
- value = jnp.clip(value, minval, maxval)
310
- value = value.astype(kv_cache_quantized_dtype)
311
-
312
- return key, value
313
-
314
-
315
286
  def get_jax_dtype_from_str_dtype(str_dtype: str) -> jnp.dtype:
316
287
  """
317
288
  Get the JAX dtype from a string dtype.
@@ -326,6 +297,36 @@ def get_jax_dtype_from_str_dtype(str_dtype: str) -> jnp.dtype:
326
297
  return to_jax_dtype(str_dtype)
327
298
 
328
299
 
300
+ def get_mesh_shape_product(
301
+ mesh: Mesh,
302
+ axes: Union[str, list[str], None],
303
+ ) -> int:
304
+ """
305
+ Get the product of mesh dimensions for one or more axes.
306
+
307
+ Examples:
308
+ # Single axis (defaults to 1 if not present)
309
+ get_mesh_shape_product(mesh, "model")
310
+
311
+ # Multiple axes - computes product of their sizes
312
+ get_mesh_shape_product(mesh, ["model", "attn_dp"])
313
+
314
+ # None means no sharding on this dimension
315
+ get_mesh_shape_product(mesh, None) # returns 1
316
+ """
317
+ if axes is None:
318
+ return 1
319
+
320
+ if isinstance(axes, str):
321
+ axes = [axes]
322
+
323
+ product = 1
324
+ for axis in axes:
325
+ product *= mesh.shape.get(axis, 1)
326
+
327
+ return product
328
+
329
+
329
330
  def time_function(func):
330
331
  """
331
332
  A decorator to measure the execution time of a function.
@@ -0,0 +1,13 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
@@ -26,8 +26,8 @@ from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
26
26
 
27
27
  from tpu_inference import envs, utils
28
28
  from tpu_inference.distributed import jax_parallel_state
29
- from tpu_inference.distributed.utils import (get_host_ip, get_kv_transfer_port,
30
- get_node_id)
29
+ from tpu_inference.distributed.utils import (get_device_topology_order_id,
30
+ get_host_ip, get_kv_transfer_port)
31
31
  from tpu_inference.layers.common.sharding import ShardingConfigManager
32
32
  from tpu_inference.logger import init_logger
33
33
  from tpu_inference.models.jax.jax_intermediate_tensor import \
@@ -232,9 +232,16 @@ class TPUWorker:
232
232
 
233
233
  is_first_rank = True
234
234
  is_last_rank = True
235
+ self.topology_order_id = self.rank
235
236
  if self.parallel_config.pipeline_parallel_size > 1:
236
237
  is_first_rank = self.rank == 0
237
238
  is_last_rank = self.rank == self.pp_config.pp_world_size - 1
239
+ else:
240
+ # topology_order_id is used to determine the KV cache
241
+ # mapping between P/D workers
242
+ if multihost_backend == "ray":
243
+ self.topology_order_id = get_device_topology_order_id(
244
+ jax.local_devices(), jax.devices())
238
245
 
239
246
  self.model_runner = TPUModelRunner(self.vllm_config, self.devices,
240
247
  self.rank, is_first_rank,
@@ -243,9 +250,12 @@ class TPUWorker:
243
250
  f"rank={self.rank} | "
244
251
  f"is_first_rank={is_first_rank} | "
245
252
  f"is_last_rank={is_last_rank} | "
246
- f"node_id={get_node_id()} | "
253
+ f"topology_order_id={self.topology_order_id} | "
247
254
  f"is_driver_worker={self.is_driver_worker} | "
248
- f"hbm={utils.hbm_usage_gb(self.devices)}GiB")
255
+ f"hbm={utils.hbm_usage_gb(self.devices)}GiB |"
256
+ f"self.devices={self.devices} | "
257
+ f"total devices={jax.devices()} | "
258
+ f"local_devices={jax.local_devices()}")
249
259
  vllm_utils.report_usage_stats(self.vllm_config)
250
260
 
251
261
  def initialize_pp_transfer_connect(self):
@@ -420,13 +430,19 @@ class TPUWorker:
420
430
  kv_cache_config: KVCacheConfig,
421
431
  ) -> None:
422
432
  """Allocate GPU KV cache with the specified kv_cache_config."""
423
- self.model_runner.initialize_kv_cache(kv_cache_config)
433
+ # Precompile functions with large vocab_size tensors before allocating KV cache to avoid OOM
434
+ if not (envs.SKIP_JAX_PRECOMPILE or
435
+ (hasattr(self.model_runner.model_config, "enforce_eager")
436
+ and self.model_runner.model_config.enforce_eager)):
437
+ self.model_runner.compilation_manager._precompile_sampling()
438
+ self.model_runner.compilation_manager._precompile_gather_logprobs()
439
+ self.model_runner.initialize_kv_cache(kv_cache_config,
440
+ self.topology_order_id)
424
441
 
425
442
  def get_node_kv_ip_port(self) -> tuple[int, str, int]:
426
- node_id = get_node_id()
427
443
  ip = get_host_ip()
428
444
  port = get_kv_transfer_port()
429
- return (int(node_id), ip, int(port))
445
+ return (int(self.topology_order_id), ip, int(port))
430
446
 
431
447
  def check_health(self) -> None:
432
448
  # worker will always be healthy as long as it's running.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: tpu_inference
3
- Version: 0.12.0.dev20251213
3
+ Version: 0.13.2.dev20251230
4
4
  Author: tpu_inference Contributors
5
5
  Classifier: Development Status :: 3 - Alpha
6
6
  Classifier: Intended Audience :: Developers