tpu-inference 0.11.1.dev202511270815__py3-none-any.whl → 0.13.0rc2.post7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (251)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +14 -0
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +405 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +312 -0
  64. tests/layers/vllm/test_unquantized.py +651 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +21 -3
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +110 -12
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +2 -45
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +15 -10
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +92 -8
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +22 -1
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +167 -97
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +26 -19
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/quant_methods.py +15 -0
  154. tpu_inference/layers/common/quantization.py +270 -0
  155. tpu_inference/layers/common/sharding.py +31 -9
  156. tpu_inference/layers/jax/__init__.py +13 -0
  157. tpu_inference/layers/jax/attention/__init__.py +13 -0
  158. tpu_inference/layers/jax/attention/attention.py +19 -6
  159. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  160. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  161. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  162. tpu_inference/layers/jax/base.py +14 -0
  163. tpu_inference/layers/jax/constants.py +13 -0
  164. tpu_inference/layers/jax/layers.py +14 -0
  165. tpu_inference/layers/jax/misc.py +14 -0
  166. tpu_inference/layers/jax/moe/__init__.py +13 -0
  167. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  168. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  169. tpu_inference/layers/jax/moe/moe.py +43 -3
  170. tpu_inference/layers/jax/pp_utils.py +53 -0
  171. tpu_inference/layers/jax/rope.py +14 -0
  172. tpu_inference/layers/jax/rope_interface.py +14 -0
  173. tpu_inference/layers/jax/sample/__init__.py +13 -0
  174. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  175. tpu_inference/layers/jax/sample/sampling.py +15 -1
  176. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  177. tpu_inference/layers/jax/transformer_block.py +14 -0
  178. tpu_inference/layers/vllm/__init__.py +13 -0
  179. tpu_inference/layers/vllm/attention.py +4 -4
  180. tpu_inference/layers/vllm/fused_moe.py +210 -260
  181. tpu_inference/layers/vllm/linear_common.py +57 -22
  182. tpu_inference/layers/vllm/quantization/__init__.py +16 -0
  183. tpu_inference/layers/vllm/quantization/awq.py +15 -1
  184. tpu_inference/layers/vllm/quantization/common.py +33 -18
  185. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  186. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
  187. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
  188. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  189. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
  191. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  192. tpu_inference/layers/vllm/quantization/mxfp4.py +280 -210
  193. tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
  194. tpu_inference/layers/vllm/sharding.py +21 -4
  195. tpu_inference/lora/__init__.py +13 -0
  196. tpu_inference/lora/torch_lora_ops.py +8 -13
  197. tpu_inference/models/__init__.py +13 -0
  198. tpu_inference/models/common/__init__.py +13 -0
  199. tpu_inference/models/common/model_loader.py +77 -36
  200. tpu_inference/models/jax/__init__.py +13 -0
  201. tpu_inference/models/jax/deepseek_v3.py +267 -157
  202. tpu_inference/models/jax/gpt_oss.py +26 -10
  203. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  204. tpu_inference/models/jax/llama3.py +99 -36
  205. tpu_inference/models/jax/llama4.py +14 -0
  206. tpu_inference/models/jax/llama_eagle3.py +14 -0
  207. tpu_inference/models/jax/llama_guard_4.py +15 -1
  208. tpu_inference/models/jax/qwen2.py +17 -2
  209. tpu_inference/models/jax/qwen2_5_vl.py +18 -4
  210. tpu_inference/models/jax/qwen3.py +17 -2
  211. tpu_inference/models/jax/utils/__init__.py +13 -0
  212. tpu_inference/models/jax/utils/file_utils.py +14 -0
  213. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  214. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  215. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +91 -31
  216. tpu_inference/models/jax/utils/weight_utils.py +39 -2
  217. tpu_inference/models/vllm/__init__.py +13 -0
  218. tpu_inference/models/vllm/vllm_model_wrapper.py +20 -4
  219. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  220. tpu_inference/platforms/__init__.py +14 -0
  221. tpu_inference/platforms/tpu_platform.py +47 -71
  222. tpu_inference/runner/__init__.py +13 -0
  223. tpu_inference/runner/compilation_manager.py +158 -63
  224. tpu_inference/runner/kv_cache.py +54 -20
  225. tpu_inference/runner/kv_cache_manager.py +53 -30
  226. tpu_inference/runner/lora_utils.py +14 -0
  227. tpu_inference/runner/multimodal_manager.py +15 -1
  228. tpu_inference/runner/persistent_batch_manager.py +54 -2
  229. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  230. tpu_inference/runner/structured_decoding_manager.py +14 -0
  231. tpu_inference/runner/tpu_runner.py +105 -57
  232. tpu_inference/runner/utils.py +2 -2
  233. tpu_inference/spec_decode/__init__.py +13 -0
  234. tpu_inference/spec_decode/jax/__init__.py +13 -0
  235. tpu_inference/spec_decode/jax/eagle3.py +65 -19
  236. tpu_inference/tpu_info.py +14 -0
  237. tpu_inference/utils.py +72 -44
  238. tpu_inference/worker/__init__.py +13 -0
  239. tpu_inference/worker/tpu_worker.py +65 -52
  240. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/METADATA +11 -9
  241. tpu_inference-0.13.0rc2.post7.dist-info/RECORD +261 -0
  242. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  243. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  244. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  245. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  246. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  247. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  248. tpu_inference-0.11.1.dev202511270815.dist-info/RECORD +0 -174
  249. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/WHEEL +0 -0
  250. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/licenses/LICENSE +0 -0
  251. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/top_level.txt +0 -0
tpu_inference/models/jax/gpt_oss.py

@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import re
 from dataclasses import dataclass
 from typing import List, Optional, Tuple
@@ -11,6 +25,9 @@ from jax.sharding import Mesh, NamedSharding
 from jax.sharding import PartitionSpec as P
 from vllm.config import VllmConfig
 
+from tpu_inference.layers.common.quant_methods import MXFP4
+from tpu_inference.layers.common.quantization import (
+    dequantize_tensor_from_mxfp4_packed, e8m0_to_fp32, u8_unpack_e2m1)
 from tpu_inference.layers.jax.attention.gpt_oss_attention import (
     AttentionMetadata, GptOssAttention)
 from tpu_inference.layers.jax.constants import KVCacheType
@@ -18,8 +35,6 @@ from tpu_inference.layers.jax.layers import Embedder, LMhead, RMSNorm
 from tpu_inference.layers.jax.moe.gpt_oss_moe import GptOssMoE, GptOssRouter
 from tpu_inference.layers.jax.transformer_block import TransformerBlock
 from tpu_inference.logger import init_logger
-from tpu_inference.models.jax.utils.quantization.mxfp4_utils import (
-    MXFP4_QUANT_METHOD, dequant_mxfp4_to_bf16, unpack_mxfp4_to_fp32)
 from tpu_inference.models.jax.utils.weight_utils import (
     get_param, model_weights_generator, print_param_info)
 
@@ -102,9 +117,9 @@ class GptOss(nnx.Module):
             rope_ntk_beta=rope_ntk_beta,
             rngs=self.rng,
             random_init=self.random_init,
-            query_tnh=P(None, 'model', None),
-            keyvalue_skh=P(None, 'model', None),
-            attn_o_tnh=P(None, 'model', None),
+            query_tnh=P("data", 'model', None),
+            keyvalue_skh=P("data", 'model', None),
+            attn_o_tnh=P("data", 'model', None),
             dnh_sharding=P(None, 'model', None),
             dkh_sharding=P(None, 'model', None),
             nhd_sharding=P('model', None, None),
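
The sharding hunk above switches the attention activation specs from replicating the token axis to sharding it over the "data" mesh axis. A minimal sketch of what that PartitionSpec change means, assuming a 2D ("data", "model") device mesh as named in the diff; the mesh and array shapes here are illustrative, not what tpu-inference builds:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P

# Toy mesh: all local devices on the "data" axis, a single "model" shard.
devices = np.array(jax.devices()).reshape(-1, 1)
mesh = Mesh(devices, axis_names=("data", "model"))

old_spec = P(None, "model", None)    # token axis replicated across "data"
new_spec = P("data", "model", None)  # token axis sharded across "data"

x = jnp.zeros((8, 4, 128))  # (tokens, num_heads, head_dim), toy sizes
x_sharded = jax.device_put(x, NamedSharding(mesh, new_spec))
```
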
@@ -205,7 +220,7 @@ class GptOss(nnx.Module):
 
         # MXFP4 checkpoints swap last two dims for MoE to place packed dim at most minor
         swap_mlp_transform = transforms[
-            "swap_last2"] if quant_method == MXFP4_QUANT_METHOD else None
+            "swap_last2"] if quant_method == MXFP4 else None
 
         mappings = {
             # Embeddings, Norms, and LM Head
@@ -285,7 +300,7 @@ class GptOss(nnx.Module):
         # Build a pool of weights with MXFP4 experts combined if neededs
         pool: dict[str, torch.Tensor | tuple] = (self._build_mxfp4_pool(
             names_and_weights_generator,
-            mappings) if quant_method == MXFP4_QUANT_METHOD else {
+            mappings) if quant_method == MXFP4 else {
                 loaded_name: loaded_weight
                 for loaded_name, loaded_weight in names_and_weights_generator
             })
@@ -316,8 +331,9 @@
                 blocks_u8, scales_u8 = loaded_weight
                 # Quantized param (QArray): set qvalue/scale directly and skip regular path
                 if hasattr(model_weight, "array"):  # QArray check
-                    codes_fp32_t, scales_fp32_t = unpack_mxfp4_to_fp32(
-                        blocks_u8, scales_u8)
+                    codes_fp32_t = u8_unpack_e2m1(blocks_u8).astype(
+                        jnp.float32)
+                    scales_fp32_t = e8m0_to_fp32(scales_u8)
                     self._load_mxfp4(
                         model_weight=model_weight,
                         codes_fp32_t=codes_fp32_t,
@@ -328,7 +344,7 @@
                     print_param_info(model_weight, loaded_name)
                     continue
                 # Not a QArray: dequantize MXFP4 to BF16 full weights
-                prepared_weight = dequant_mxfp4_to_bf16(
+                prepared_weight = dequantize_tensor_from_mxfp4_packed(
                     blocks_u8, scales_u8)
 
                 # Single regular-tensor load call (BF16 or dequantized MXFP4)
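
The hunks above replace the removed mxfp4_utils helpers with u8_unpack_e2m1, e8m0_to_fp32, and dequantize_tensor_from_mxfp4_packed from layers.common.quantization. A hedged sketch of the arithmetic those names imply, based on the MXFP4 format itself (two 4-bit e2m1 codes packed per uint8, biased e8m0 power-of-two block scales); the lookup table and nibble order below are assumptions, not the tpu_inference implementation:

```python
import jax.numpy as jnp

# All 16 e2m1 values: 1 sign bit, 2 exponent bits, 1 mantissa bit.
_E2M1 = jnp.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
                   -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0],
                  dtype=jnp.float32)

def u8_unpack_e2m1_sketch(blocks_u8):
    """Split each uint8 into two 4-bit codes and look up their fp32 values."""
    lo, hi = blocks_u8 & 0x0F, blocks_u8 >> 4  # nibble order is an assumption
    codes = jnp.stack([lo, hi], axis=-1).reshape(*blocks_u8.shape[:-1], -1)
    return _E2M1[codes]

def e8m0_to_fp32_sketch(scales_u8):
    """Decode biased-127 e8m0 exponents into power-of-two fp32 scales."""
    return jnp.exp2(scales_u8.astype(jnp.float32) - 127.0)
```

Full dequantization then multiplies each unpacked code by its block's decoded scale (MXFP4 groups 32 elements per scale), which is presumably what the dequantize_tensor_from_mxfp4_packed path materializes in bf16.
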
tpu_inference/models/jax/jax_intermediate_tensor.py

@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Dict, Union
 
tpu_inference/models/jax/llama3.py

@@ -1,3 +1,18 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from itertools import islice
 from typing import List, Optional, Tuple
 
 import jax
@@ -8,13 +23,19 @@ from transformers import LlamaConfig, modeling_flax_utils
 from vllm.config import VllmConfig
 
 from tpu_inference import utils
+from tpu_inference.distributed.jax_parallel_state import get_pp_group
 from tpu_inference.layers.common.attention_interface import attention
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.common.quantization import quantize_kv
 from tpu_inference.layers.common.sharding import ShardingAxisName
+from tpu_inference.layers.jax.pp_utils import PPMissingLayer, make_layers
 from tpu_inference.layers.jax.rope_interface import apply_rope
 from tpu_inference.logger import init_logger
+from tpu_inference.models.jax.jax_intermediate_tensor import \
+    JaxIntermediateTensors
 from tpu_inference.models.jax.utils.weight_utils import (get_default_maps,
                                                          load_hf_weights)
+from tpu_inference.utils import get_mesh_shape_product
 
 logger = init_logger(__name__)
 
@@ -79,7 +100,8 @@ class LlamaAttention(nnx.Module):
                                              self.hidden_size // self.num_heads)
         self.head_dim = utils.get_padded_head_dim(self.head_dim_original)
 
-        sharding_size = mesh.shape["model"] * mesh.shape.get("attn_dp", 1)
+        sharding_size = get_mesh_shape_product(mesh,
+                                               ShardingAxisName.MLP_TENSOR)
         self.num_heads = utils.get_padded_num_heads(self.num_heads,
                                                     sharding_size)
         self.num_kv_heads = utils.get_padded_num_heads(self.num_kv_heads,
@@ -152,8 +174,8 @@ class LlamaAttention(nnx.Module):
             # q_scale = self._q_scale
             k_scale = self._k_scale
             v_scale = self._v_scale
-            k, v = utils.quantize_kv(k, v, self.kv_cache_quantized_dtype,
-                                     k_scale, v_scale)
+            k, v = quantize_kv(self.kv_cache_quantized_dtype, k, v, k_scale,
+                               v_scale)
         new_kv_cache, outputs = attention(
             kv_cache,
             q,
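
quantize_kv moves from tpu_inference.utils into layers.common.quantization, and the target dtype becomes the leading positional argument. A rough stand-in that matches the new call shape; the rounding and clipping details are assumptions, not the library's actual kernel:

```python
import jax.numpy as jnp

def quantize_kv_sketch(kv_dtype, k, v, k_scale, v_scale):
    """Hedged stand-in with the new argument order: dtype first, then the
    K/V tensors, then their per-tensor scales."""
    info = jnp.iinfo(kv_dtype)

    def quantize(x, scale):
        return jnp.clip(jnp.round(x / scale), info.min,
                        info.max).astype(kv_dtype)

    return quantize(k, k_scale), quantize(v, v_scale)

k = jnp.ones((2, 8, 128), jnp.bfloat16)
v = jnp.ones((2, 8, 128), jnp.bfloat16)
k_q, v_q = quantize_kv_sketch(jnp.int8, k, v, k_scale=0.05, v_scale=0.05)
```
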
@@ -235,38 +257,52 @@ class LlamaModel(nnx.Module):
         rms_norm_eps = hf_config.rms_norm_eps
         hidden_size = hf_config.hidden_size
 
-        self.embed = nnx.Embed(
-            num_embeddings=vocab_size,
-            features=hidden_size,
-            param_dtype=dtype,
-            embedding_init=nnx.with_partitioning(
-                init_fn, (ShardingAxisName.VOCAB, None)),
-            rngs=rng,
-        )
-        self.layers = [
-            LlamaDecoderLayer(
+        self.is_first_rank = get_pp_group().is_first_rank
+        self.is_last_rank = get_pp_group().is_last_rank
+
+        if self.is_first_rank or (hf_config.tie_word_embeddings
+                                  and self.is_last_rank):
+            self.embed = nnx.Embed(
+                num_embeddings=vocab_size,
+                features=hidden_size,
+                param_dtype=dtype,
+                embedding_init=nnx.with_partitioning(
+                    init_fn, (ShardingAxisName.VOCAB, None)),
+                rngs=rng,
+            )
+        else:
+            self.embed = PPMissingLayer()
+
+        self.start_layer, self.end_layer, self.layers = make_layers(
+            hf_config.num_hidden_layers,
+            lambda: LlamaDecoderLayer(
                 config=hf_config,
                 dtype=dtype,
                 rng=rng,
                 mesh=mesh,
                 # TODO (jacobplatin): we should refactor this to pass a dtype (or config) directly
-                kv_cache_dtype=vllm_config.cache_config.cache_dtype)
-            for _ in range(hf_config.num_hidden_layers)
-        ]
-        self.norm = nnx.RMSNorm(
-            hidden_size,
-            epsilon=rms_norm_eps,
-            param_dtype=dtype,
-            scale_init=nnx.with_partitioning(init_fn, (None, )),
-            rngs=rng,
-        )
-        if model_config.hf_config.tie_word_embeddings:
-            self.lm_head = self.embed.embedding
-        else:
-            self.lm_head = nnx.Param(
-                init_fn(rng.params(), (hidden_size, vocab_size), dtype),
-                sharding=(None, ShardingAxisName.VOCAB),
+                kv_cache_dtype=vllm_config.cache_config.cache_dtype))
+        if self.is_last_rank:
+            self.norm = nnx.RMSNorm(
+                hidden_size,
+                epsilon=rms_norm_eps,
+                param_dtype=dtype,
+                scale_init=nnx.with_partitioning(init_fn, (None, )),
+                rngs=rng,
             )
+        else:
+            self.norm = PPMissingLayer()
+
+        if self.is_last_rank:
+            if model_config.hf_config.tie_word_embeddings:
+                self.lm_head = self.embed.embedding
+            else:
+                self.lm_head = nnx.Param(
+                    init_fn(rng.params(), (hidden_size, vocab_size), dtype),
+                    sharding=(None, ShardingAxisName.VOCAB),
+                )
+        else:
+            self.lm_head = PPMissingLayer()
 
         self.aux_hidden_state_layers = []
         if vllm_config.speculative_config and vllm_config.speculative_config.method == "eagle3":
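
This hunk makes LlamaModel pipeline-parallel aware: each rank instantiates only its contiguous slice of decoder layers (and the embedding, final norm, and LM head only on the ranks that own them), filling every other position with a PPMissingLayer placeholder so the module tree keeps a uniform shape. A toy sketch of the pattern, assuming make_layers returns (start_layer, end_layer, layers) as the hunk suggests; the real pp_utils helpers may differ:

```python
class PPMissingLayerSketch:
    """Stand-in for a module owned by another pipeline-parallel rank."""

    def __call__(self, x, *args, **kwargs):
        return x  # identity; this rank never runs the real layer

def make_layers_sketch(num_layers, create_fn, pp_rank, pp_size):
    """Create real layers only for this rank's [start, end) slice."""
    per_rank = num_layers // pp_size
    start, end = pp_rank * per_rank, (pp_rank + 1) * per_rank
    layers = [create_fn() if start <= i < end else PPMissingLayerSketch()
              for i in range(num_layers)]
    return start, end, layers

start, end, layers = make_layers_sketch(8, lambda: object(),
                                        pp_rank=1, pp_size=2)
print(start, end)  # 4 8: rank 1 owns the second half of the stack
```
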
@@ -282,10 +318,18 @@ class LlamaModel(nnx.Module):
         kv_caches: List[jax.Array],
         input_ids: jax.Array,
         attention_metadata: AttentionMetadata,
-    ) -> Tuple[List[jax.Array], jax.Array, List[jax.Array]]:
-        x = self.embed(input_ids)
+        intermediate_tensors: JaxIntermediateTensors | None,
+    ) -> Tuple[List[jax.Array], jax.Array, List[jax.Array]] | Tuple[
+            List[jax.Array], JaxIntermediateTensors]:
+        if self.is_first_rank:
+            x = self.embed(input_ids)
+        else:
+            assert intermediate_tensors is not None
+            x = intermediate_tensors["hidden_states"]
+
         aux_hidden_states = []
-        for i, layer in enumerate(self.layers):
+        for i, layer in enumerate(
+                islice(self.layers, self.start_layer, self.end_layer)):
             if i in self.aux_hidden_state_layers:
                 aux_hidden_states.append(x)
             kv_cache = kv_caches[i]
@@ -295,6 +339,10 @@ class LlamaModel(nnx.Module):
                 attention_metadata,
             )
             kv_caches[i] = kv_cache
+        if not self.is_last_rank:
+            # Note: add aux_hidden_states to make the output spec consistent.
+            return kv_caches, JaxIntermediateTensors({"hidden_states":
+                                                      x}), aux_hidden_states
         x = self.norm(x)
         return kv_caches, x, aux_hidden_states
 
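The forward pass now runs only this rank's layers via itertools.islice, and non-final ranks return the running hidden states wrapped in JaxIntermediateTensors instead of applying the final norm. A self-contained toy of the islice pattern (closures stand in for decoder layers):

```python
from itertools import islice

layers = [lambda x, i=i: x + i for i in range(8)]  # stand-in "layers"
start_layer, end_layer = 2, 5                      # this rank owns [2, 5)

x = 0
for i, layer in enumerate(islice(layers, start_layer, end_layer)):
    x = layer(x)  # i restarts at 0, matching per-rank kv_caches indexing
print(x)  # 0 + 2 + 3 + 4 = 9
```
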
@@ -313,19 +361,33 @@ class LlamaForCausalLM(nnx.Module):
             mesh=mesh,
         )
 
+        self.pp_missing_layers = []
+        for path, module in nnx.iter_graph(self.model):
+            if isinstance(module, PPMissingLayer):
+                # the path should be sth like ('layers', '0')
+                self.pp_missing_layers.append('.'.join([str(s) for s in path]))
+
     def __call__(
         self,
         kv_caches: List[jax.Array],
         input_ids: jax.Array,
         attention_metadata: AttentionMetadata,
+        _input_embeds=None,
+        _input_positions=None,
+        _layer_name_to_kv_cache=None,
+        _lora_metadata=None,
+        intermediate_tensors: JaxIntermediateTensors | None = None,
+        _is_first_rank: bool | None = None,
+        _is_last_rank: bool | None = None,
         *args,
-    ) -> Tuple[List[jax.Array], jax.Array, List[jax.Array]]:
-        kv_caches, x, aux_hidden_states = self.model(
+    ) -> Tuple[List[jax.Array], jax.Array, List[jax.Array]] | Tuple[
+            List[jax.Array], JaxIntermediateTensors]:
+        return self.model(
             kv_caches,
             input_ids,
             attention_metadata,
+            intermediate_tensors,
         )
-        return kv_caches, x, aux_hidden_states
 
     def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
         if self.vllm_config.model_config.hf_config.tie_word_embeddings:
@@ -373,4 +435,5 @@ class LlamaForCausalLM(nnx.Module):
         load_hf_weights(vllm_config=self.vllm_config,
                         model=self,
                         metadata_map=metadata_map,
-                        mesh=self.mesh)
+                        mesh=self.mesh,
+                        pp_missing_layers=self.pp_missing_layers)
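
The wrapper also records a dotted path for every placeholder so that load_hf_weights can skip weights owned by other ranks. A self-contained toy of that path-flattening step; the real code walks the flax nnx module graph with nnx.iter_graph rather than a plain list:

```python
class _MissingStandIn:
    """Placeholder type, standing in for PPMissingLayer."""

def collect_missing(paths_and_modules):
    """Join each placeholder's path tuple, e.g. ('layers', 0) -> 'layers.0'."""
    return ['.'.join(str(s) for s in path)
            for path, module in paths_and_modules
            if isinstance(module, _MissingStandIn)]

tree = [(("layers", 0), _MissingStandIn()),
        (("layers", 1), object()),
        (("norm",), object())]
print(collect_missing(tree))  # ['layers.0']
```
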
tpu_inference/models/jax/llama4.py

@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import re
 from typing import List, Optional, Tuple
 
tpu_inference/models/jax/llama_eagle3.py

@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from typing import List, Tuple
 
 import jax
tpu_inference/models/jax/llama_guard_4.py

@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import re
 from typing import Any, List, Optional, Tuple
 
@@ -242,7 +256,7 @@ class LlamaGuard4ForCausalLM(nnx.Module):
                 self.lm_head.input_embedding_table_DV.value)
         return logits_TV
 
-    def get_input_embeddings(
+    def embed_input_ids(
         self,
         input_ids: jax.Array,
         multimodal_embeddings: Optional[List[jax.Array]] = None
tpu_inference/models/jax/qwen2.py

@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from typing import List, Optional, Tuple
 
 import jax
@@ -10,6 +24,7 @@ from vllm.config import VllmConfig
 from tpu_inference import utils
 from tpu_inference.layers.common.attention_interface import attention
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.common.quantization import quantize_kv
 from tpu_inference.layers.jax.rope_interface import apply_rope
 from tpu_inference.logger import init_logger
 from tpu_inference.models.jax.utils.weight_utils import (get_default_maps,
@@ -152,8 +167,8 @@ class Qwen2Attention(nnx.Module):
             # q_scale = self._q_scale
             k_scale = self._k_scale
             v_scale = self._v_scale
-            k, v = utils.quantize_kv(k, v, self.kv_cache_quantized_dtype,
-                                     k_scale, v_scale)
+            k, v = quantize_kv(self.kv_cache_quantized_dtype, k, v, k_scale,
+                               v_scale)
         new_kv_cache, outputs = attention(
             kv_cache,
             q,
tpu_inference/models/jax/qwen2_5_vl.py

@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import math
 from functools import partial
 from typing import (Callable, List, Literal, NamedTuple, Optional, TypedDict,
@@ -996,9 +1010,9 @@ class Qwen2_5_VLForConditionalGeneration(nnx.Module):
         split_indices = np.cumsum(sizes)[:-1]
         return tuple(jnp.split(image_embeds, split_indices))
 
-    def get_multimodal_embeddings(self, image_grid_thw: tuple[tuple[int, int,
-                                                                    int], ...],
-                                  **kwargs: object) -> MultiModalEmbeddings:
+    def embed_multimodal(self, image_grid_thw: tuple[tuple[int, int, int],
+                                                     ...],
+                         **kwargs: object) -> MultiModalEmbeddings:
 
         mm_input_by_modality = self._parse_and_validate_multimodal_inputs(
             image_grid_thw, **kwargs)
@@ -1022,7 +1036,7 @@ class Qwen2_5_VLForConditionalGeneration(nnx.Module):
 
         return multimodal_embeddings
 
-    def get_input_embeddings(
+    def embed_input_ids(
         self, input_ids: jax.Array,
         multimodal_embeddings: Optional[jax.Array]) -> jax.Array:
 
tpu_inference/models/jax/qwen3.py

@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from typing import List, Optional, Tuple
 
 import jax
@@ -10,6 +24,7 @@ from vllm.config import VllmConfig
 from tpu_inference import utils
 from tpu_inference.layers.common.attention_interface import attention
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.common.quantization import quantize_kv
 from tpu_inference.layers.jax.rope_interface import apply_rope
 from tpu_inference.logger import init_logger
 from tpu_inference.models.jax.qwen2 import Qwen2DecoderLayer
@@ -125,8 +140,8 @@ class Qwen3Attention(nnx.Module):
             # q_scale = self._q_scale
             k_scale = self._k_scale
             v_scale = self._v_scale
-            k, v = utils.quantize_kv(k, v, self.kv_cache_quantized_dtype,
-                                     k_scale, v_scale)
+            k, v = quantize_kv(self.kv_cache_quantized_dtype, k, v, k_scale,
+                               v_scale)
         new_kv_cache, outputs = attention(
             kv_cache,
             q,
tpu_inference/models/jax/utils/__init__.py

@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
tpu_inference/models/jax/utils/file_utils.py

@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import glob
 import hashlib
 import os
tpu_inference/models/jax/utils/multi_modal_utils.py

@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from typing import Union
 
 import jax
@@ -29,25 +43,25 @@ def sanity_check_mm_encoder_outputs(
 ) -> None:
     """
     Perform sanity checks for the result of
-    [`vllm.model_executor.models.SupportsMultiModal.get_multimodal_embeddings`][].
+    [`vllm.model_executor.models.SupportsMultiModal.embed_multimodal`][].
     """
     assert isinstance(mm_embeddings, (list, tuple, jax.Array)), (
         "Expected multimodal embeddings to be a list/tuple of 2D tensors, "
        f"or a single 3D tensor, but got {type(mm_embeddings)} "
         "instead. This is most likely due to incorrect implementation "
-        "of the model's `get_multimodal_embeddings` method.")
+        "of the model's `embed_multimodal` method.")
 
     assert len(mm_embeddings) == expected_num_items, (
         "Expected number of multimodal embeddings to match number of "
        f"input items: {expected_num_items}, but got {len(mm_embeddings)=} "
         "instead. This is most likely due to incorrect implementation "
-        "of the model's `get_multimodal_embeddings` method.")
+        "of the model's `embed_multimodal` method.")
 
     assert all(e.ndim == 2 for e in mm_embeddings), (
         "Expected multimodal embeddings to be a sequence of 2D tensors, "
        f"but got tensors with shapes {[e.shape for e in mm_embeddings]} "
         "instead. This is most likely due to incorrect implementation "
-        "of the model's `get_multimodal_embeddings` method.")
+        "of the model's `embed_multimodal` method.")
 
 
 def flatten_embeddings(embeddings: NestedTensors) -> jax.Array:
tpu_inference/models/jax/utils/qwix/__init__.py

@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.