tpu-inference 0.12.0.dev20251213__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.
Files changed (248)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +14 -0
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +14 -0
  31. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  32. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_test.py +14 -0
  35. tests/layers/__init__.py +13 -0
  36. tests/layers/common/__init__.py +13 -0
  37. tests/layers/common/test_attention_interface.py +156 -0
  38. tests/layers/common/test_quantization.py +149 -0
  39. tests/layers/jax/__init__.py +13 -0
  40. tests/layers/jax/attention/__init__.py +13 -0
  41. tests/layers/jax/attention/test_common_attention.py +103 -0
  42. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  43. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  44. tests/layers/jax/moe/__init__.py +13 -0
  45. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  46. tests/layers/jax/sample/__init__.py +13 -0
  47. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  48. tests/layers/jax/sample/test_sampling.py +115 -0
  49. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  50. tests/layers/jax/test_layers.py +155 -0
  51. tests/{test_quantization.py → layers/jax/test_qwix.py} +180 -50
  52. tests/layers/jax/test_rope.py +93 -0
  53. tests/layers/jax/test_sharding.py +159 -0
  54. tests/layers/jax/test_transformer_block.py +152 -0
  55. tests/layers/vllm/__init__.py +13 -0
  56. tests/layers/vllm/test_attention.py +363 -0
  57. tests/layers/vllm/test_awq.py +406 -0
  58. tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
  59. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
  61. tests/layers/vllm/test_fp8.py +17 -0
  62. tests/layers/vllm/test_mxfp4.py +320 -0
  63. tests/layers/vllm/test_unquantized.py +662 -0
  64. tests/layers/vllm/utils.py +87 -0
  65. tests/lora/__init__.py +13 -0
  66. tests/lora/conftest.py +14 -0
  67. tests/lora/test_bgmv.py +14 -0
  68. tests/lora/test_layers.py +25 -8
  69. tests/lora/test_lora.py +15 -1
  70. tests/lora/test_lora_perf.py +14 -0
  71. tests/models/__init__.py +13 -0
  72. tests/models/common/__init__.py +13 -0
  73. tests/models/common/test_model_loader.py +455 -0
  74. tests/models/jax/__init__.py +13 -0
  75. tests/models/jax/test_deepseek_v3.py +401 -0
  76. tests/models/jax/test_llama3.py +184 -0
  77. tests/models/jax/test_llama4.py +298 -0
  78. tests/models/jax/test_llama_eagle3.py +197 -0
  79. tests/models/jax/test_llama_guard_4.py +242 -0
  80. tests/models/jax/test_qwen2.py +172 -0
  81. tests/models/jax/test_qwen2_5_vl.py +605 -0
  82. tests/models/jax/test_qwen3.py +169 -0
  83. tests/models/jax/test_weight_loading.py +180 -0
  84. tests/models/jax/utils/__init__.py +13 -0
  85. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  86. tests/platforms/__init__.py +13 -0
  87. tests/platforms/test_tpu_platform.py +54 -0
  88. tests/runner/__init__.py +13 -0
  89. tests/runner/test_block_table.py +395 -0
  90. tests/runner/test_input_batch.py +226 -0
  91. tests/runner/test_kv_cache.py +220 -0
  92. tests/runner/test_kv_cache_manager.py +498 -0
  93. tests/runner/test_multimodal_manager.py +429 -0
  94. tests/runner/test_persistent_batch_manager.py +84 -0
  95. tests/runner/test_speculative_decoding_manager.py +368 -0
  96. tests/runner/test_structured_decoding_manager.py +220 -0
  97. tests/runner/test_tpu_runner.py +261 -0
  98. tests/runner/test_tpu_runner_dp.py +1099 -0
  99. tests/runner/test_tpu_runner_mesh.py +200 -0
  100. tests/runner/test_utils.py +411 -0
  101. tests/spec_decode/__init__.py +13 -0
  102. tests/spec_decode/test_eagle3.py +311 -0
  103. tests/test_base.py +14 -0
  104. tests/test_tpu_info.py +14 -0
  105. tests/test_utils.py +1 -43
  106. tests/worker/__init__.py +13 -0
  107. tests/worker/tpu_worker_test.py +414 -0
  108. tpu_inference/__init__.py +14 -0
  109. tpu_inference/core/__init__.py +13 -0
  110. tpu_inference/core/sched/__init__.py +13 -0
  111. tpu_inference/core/sched/dp_scheduler.py +372 -56
  112. tpu_inference/distributed/__init__.py +13 -0
  113. tpu_inference/distributed/jax_parallel_state.py +14 -0
  114. tpu_inference/distributed/tpu_connector.py +14 -9
  115. tpu_inference/distributed/utils.py +56 -4
  116. tpu_inference/executors/__init__.py +13 -0
  117. tpu_inference/executors/ray_distributed_executor.py +20 -3
  118. tpu_inference/experimental/__init__.py +13 -0
  119. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  120. tpu_inference/kernels/__init__.py +13 -0
  121. tpu_inference/kernels/collectives/__init__.py +13 -0
  122. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  123. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  124. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  125. tpu_inference/kernels/fused_moe/v1/kernel.py +171 -163
  126. tpu_inference/kernels/megablox/__init__.py +13 -0
  127. tpu_inference/kernels/megablox/common.py +54 -0
  128. tpu_inference/kernels/megablox/gmm.py +646 -0
  129. tpu_inference/kernels/mla/__init__.py +13 -0
  130. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  131. tpu_inference/kernels/mla/v1/kernel.py +20 -26
  132. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  133. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  134. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  135. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  136. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +112 -69
  137. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +85 -65
  138. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  139. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +374 -194
  140. tpu_inference/kernels/ragged_paged_attention/v3/util.py +13 -0
  141. tpu_inference/layers/__init__.py +13 -0
  142. tpu_inference/layers/common/__init__.py +13 -0
  143. tpu_inference/layers/common/attention_interface.py +26 -19
  144. tpu_inference/layers/common/attention_metadata.py +14 -0
  145. tpu_inference/layers/common/fused_moe_gmm.py +506 -0
  146. tpu_inference/layers/common/quant_methods.py +15 -0
  147. tpu_inference/layers/common/quantization.py +282 -0
  148. tpu_inference/layers/common/sharding.py +22 -3
  149. tpu_inference/layers/common/utils.py +94 -0
  150. tpu_inference/layers/jax/__init__.py +13 -0
  151. tpu_inference/layers/jax/attention/__init__.py +13 -0
  152. tpu_inference/layers/jax/attention/attention.py +19 -6
  153. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +52 -27
  154. tpu_inference/layers/jax/attention/gpt_oss_attention.py +19 -6
  155. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  156. tpu_inference/layers/jax/base.py +14 -0
  157. tpu_inference/layers/jax/constants.py +13 -0
  158. tpu_inference/layers/jax/layers.py +14 -0
  159. tpu_inference/layers/jax/misc.py +14 -0
  160. tpu_inference/layers/jax/moe/__init__.py +13 -0
  161. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  162. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  163. tpu_inference/layers/jax/moe/moe.py +43 -3
  164. tpu_inference/layers/jax/pp_utils.py +53 -0
  165. tpu_inference/layers/jax/rope.py +14 -0
  166. tpu_inference/layers/jax/rope_interface.py +14 -0
  167. tpu_inference/layers/jax/sample/__init__.py +13 -0
  168. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  169. tpu_inference/layers/jax/sample/sampling.py +15 -1
  170. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  171. tpu_inference/layers/jax/transformer_block.py +14 -0
  172. tpu_inference/layers/vllm/__init__.py +13 -0
  173. tpu_inference/layers/vllm/attention.py +4 -4
  174. tpu_inference/layers/vllm/fused_moe.py +100 -455
  175. tpu_inference/layers/vllm/linear.py +64 -0
  176. tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
  177. tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
  178. tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
  179. tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
  180. tpu_inference/layers/vllm/quantization/__init__.py +19 -3
  181. tpu_inference/layers/vllm/quantization/awq.py +96 -82
  182. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  183. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +19 -5
  184. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +119 -132
  185. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  186. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
  187. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
  188. tpu_inference/layers/vllm/quantization/{common.py → configs.py} +38 -26
  189. tpu_inference/layers/vllm/quantization/fp8.py +119 -0
  190. tpu_inference/layers/vllm/quantization/mxfp4.py +133 -220
  191. tpu_inference/layers/vllm/quantization/unquantized.py +154 -253
  192. tpu_inference/lora/__init__.py +13 -0
  193. tpu_inference/lora/torch_lora_ops.py +8 -13
  194. tpu_inference/models/__init__.py +13 -0
  195. tpu_inference/models/common/__init__.py +13 -0
  196. tpu_inference/models/common/model_loader.py +37 -16
  197. tpu_inference/models/jax/__init__.py +13 -0
  198. tpu_inference/models/jax/deepseek_v3.py +113 -124
  199. tpu_inference/models/jax/gpt_oss.py +23 -7
  200. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  201. tpu_inference/models/jax/llama3.py +99 -36
  202. tpu_inference/models/jax/llama4.py +14 -0
  203. tpu_inference/models/jax/llama_eagle3.py +14 -0
  204. tpu_inference/models/jax/llama_guard_4.py +15 -1
  205. tpu_inference/models/jax/qwen2.py +17 -2
  206. tpu_inference/models/jax/qwen2_5_vl.py +18 -4
  207. tpu_inference/models/jax/qwen3.py +17 -2
  208. tpu_inference/models/jax/utils/__init__.py +13 -0
  209. tpu_inference/models/jax/utils/file_utils.py +14 -0
  210. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  211. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  212. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +85 -24
  213. tpu_inference/models/jax/utils/weight_utils.py +32 -1
  214. tpu_inference/models/vllm/__init__.py +13 -0
  215. tpu_inference/models/vllm/vllm_model_wrapper.py +22 -4
  216. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  217. tpu_inference/platforms/__init__.py +14 -0
  218. tpu_inference/platforms/tpu_platform.py +27 -29
  219. tpu_inference/runner/__init__.py +13 -0
  220. tpu_inference/runner/compilation_manager.py +69 -35
  221. tpu_inference/runner/kv_cache.py +14 -0
  222. tpu_inference/runner/kv_cache_manager.py +15 -2
  223. tpu_inference/runner/lora_utils.py +16 -1
  224. tpu_inference/runner/multimodal_manager.py +16 -2
  225. tpu_inference/runner/persistent_batch_manager.py +14 -0
  226. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  227. tpu_inference/runner/structured_decoding_manager.py +14 -0
  228. tpu_inference/runner/tpu_runner.py +30 -10
  229. tpu_inference/spec_decode/__init__.py +13 -0
  230. tpu_inference/spec_decode/jax/__init__.py +13 -0
  231. tpu_inference/spec_decode/jax/eagle3.py +13 -0
  232. tpu_inference/tpu_info.py +14 -0
  233. tpu_inference/utils.py +31 -30
  234. tpu_inference/worker/__init__.py +13 -0
  235. tpu_inference/worker/tpu_worker.py +23 -7
  236. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +1 -1
  237. tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
  238. tpu_inference/layers/vllm/linear_common.py +0 -208
  239. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  240. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  241. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  242. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  243. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  244. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  245. tpu_inference-0.12.0.dev20251213.dist-info/RECORD +0 -175
  246. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
  247. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
  248. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
tpu_inference/models/jax/llama3.py

@@ -1,3 +1,18 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from itertools import islice
 from typing import List, Optional, Tuple
 
 import jax
@@ -8,13 +23,19 @@ from transformers import LlamaConfig, modeling_flax_utils
 from vllm.config import VllmConfig
 
 from tpu_inference import utils
+from tpu_inference.distributed.jax_parallel_state import get_pp_group
 from tpu_inference.layers.common.attention_interface import attention
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.common.quantization import quantize_kv
 from tpu_inference.layers.common.sharding import ShardingAxisName
+from tpu_inference.layers.jax.pp_utils import PPMissingLayer, make_layers
 from tpu_inference.layers.jax.rope_interface import apply_rope
 from tpu_inference.logger import init_logger
+from tpu_inference.models.jax.jax_intermediate_tensor import \
+    JaxIntermediateTensors
 from tpu_inference.models.jax.utils.weight_utils import (get_default_maps,
                                                          load_hf_weights)
+from tpu_inference.utils import get_mesh_shape_product
 
 logger = init_logger(__name__)
 
@@ -79,7 +100,8 @@ class LlamaAttention(nnx.Module):
                                     self.hidden_size // self.num_heads)
         self.head_dim = utils.get_padded_head_dim(self.head_dim_original)
 
-        sharding_size = mesh.shape["model"] * mesh.shape.get("attn_dp", 1)
+        sharding_size = get_mesh_shape_product(mesh,
+                                               ShardingAxisName.MLP_TENSOR)
         self.num_heads = utils.get_padded_num_heads(self.num_heads,
                                                     sharding_size)
         self.num_kv_heads = utils.get_padded_num_heads(self.num_kv_heads,
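The definition of `get_mesh_shape_product` is not part of this diff; below is a minimal sketch of the behavior its call site implies, generalizing the inline expression it replaces (`mesh.shape["model"] * mesh.shape.get("attn_dp", 1)`). The assumption that `ShardingAxisName.MLP_TENSOR` names one or more mesh axes, and the exact signature, are inferred, not confirmed by the source.

```python
# Sketch only: assumed behavior of tpu_inference.utils.get_mesh_shape_product.
import jax
import numpy as np
from jax.sharding import Mesh


def get_mesh_shape_product(mesh: Mesh, axis_names) -> int:
    if isinstance(axis_names, str):
        axis_names = (axis_names,)
    product = 1
    for name in axis_names:
        product *= mesh.shape.get(name, 1)  # a missing axis counts as size 1
    return product


devices = np.array(jax.devices()).reshape(-1, 1)
mesh = Mesh(devices, ("model", "attn_dp"))
assert get_mesh_shape_product(mesh, ("model", "attn_dp")) == len(jax.devices())
```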
@@ -152,8 +174,8 @@ class LlamaAttention(nnx.Module):
             # q_scale = self._q_scale
             k_scale = self._k_scale
             v_scale = self._v_scale
-            k, v = utils.quantize_kv(k, v, self.kv_cache_quantized_dtype,
-                                     k_scale, v_scale)
+            k, v = quantize_kv(self.kv_cache_quantized_dtype, k, v, k_scale,
+                               v_scale)
         new_kv_cache, outputs = attention(
             kv_cache,
             q,
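`quantize_kv` now lives in `tpu_inference.layers.common.quantization` rather than `tpu_inference.utils`, and the target dtype has moved to the first argument. The implementation is not shown in this diff; the sketch below assumes scale-then-cast semantics, with the attention kernel rescaling on read — that semantic detail is an assumption.

```python
# Hedged sketch of the assumed quantize_kv semantics; the packaged
# implementation in tpu_inference/layers/common/quantization.py is not shown.
import jax.numpy as jnp


def quantize_kv(quantized_dtype, k, v, k_scale, v_scale):
    # Assumed: scale down, then cast to the cache dtype (e.g. int8 or fp8).
    return ((k / k_scale).astype(quantized_dtype),
            (v / v_scale).astype(quantized_dtype))


k = jnp.ones((2, 4, 8), dtype=jnp.bfloat16)
v = jnp.ones((2, 4, 8), dtype=jnp.bfloat16)
k_q, v_q = quantize_kv(jnp.int8, k, v, k_scale=0.05, v_scale=0.05)
assert k_q.dtype == jnp.int8 and v_q.dtype == jnp.int8
```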
@@ -235,38 +257,52 @@ class LlamaModel(nnx.Module):
         rms_norm_eps = hf_config.rms_norm_eps
         hidden_size = hf_config.hidden_size
 
-        self.embed = nnx.Embed(
-            num_embeddings=vocab_size,
-            features=hidden_size,
-            param_dtype=dtype,
-            embedding_init=nnx.with_partitioning(
-                init_fn, (ShardingAxisName.VOCAB, None)),
-            rngs=rng,
-        )
-        self.layers = [
-            LlamaDecoderLayer(
+        self.is_first_rank = get_pp_group().is_first_rank
+        self.is_last_rank = get_pp_group().is_last_rank
+
+        if self.is_first_rank or (hf_config.tie_word_embeddings
+                                  and self.is_last_rank):
+            self.embed = nnx.Embed(
+                num_embeddings=vocab_size,
+                features=hidden_size,
+                param_dtype=dtype,
+                embedding_init=nnx.with_partitioning(
+                    init_fn, (ShardingAxisName.VOCAB, None)),
+                rngs=rng,
+            )
+        else:
+            self.embed = PPMissingLayer()
+
+        self.start_layer, self.end_layer, self.layers = make_layers(
+            hf_config.num_hidden_layers,
+            lambda: LlamaDecoderLayer(
                 config=hf_config,
                 dtype=dtype,
                 rng=rng,
                 mesh=mesh,
                 # TODO (jacobplatin): we should refactor this to pass a dtype (or config) directly
-                kv_cache_dtype=vllm_config.cache_config.cache_dtype)
-            for _ in range(hf_config.num_hidden_layers)
-        ]
-        self.norm = nnx.RMSNorm(
-            hidden_size,
-            epsilon=rms_norm_eps,
-            param_dtype=dtype,
-            scale_init=nnx.with_partitioning(init_fn, (None, )),
-            rngs=rng,
-        )
-        if model_config.hf_config.tie_word_embeddings:
-            self.lm_head = self.embed.embedding
-        else:
-            self.lm_head = nnx.Param(
-                init_fn(rng.params(), (hidden_size, vocab_size), dtype),
-                sharding=(None, ShardingAxisName.VOCAB),
+                kv_cache_dtype=vllm_config.cache_config.cache_dtype))
+        if self.is_last_rank:
+            self.norm = nnx.RMSNorm(
+                hidden_size,
+                epsilon=rms_norm_eps,
+                param_dtype=dtype,
+                scale_init=nnx.with_partitioning(init_fn, (None, )),
+                rngs=rng,
             )
+        else:
+            self.norm = PPMissingLayer()
+
+        if self.is_last_rank:
+            if model_config.hf_config.tie_word_embeddings:
+                self.lm_head = self.embed.embedding
+            else:
+                self.lm_head = nnx.Param(
+                    init_fn(rng.params(), (hidden_size, vocab_size), dtype),
+                    sharding=(None, ShardingAxisName.VOCAB),
+                )
+        else:
+            self.lm_head = PPMissingLayer()
 
         self.aux_hidden_state_layers = []
         if vllm_config.speculative_config and vllm_config.speculative_config.method == "eagle3":
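`make_layers` and `PPMissingLayer` come from the new tpu_inference/layers/jax/pp_utils.py, whose body this diff does not include. Below is a self-contained sketch of the partitioning the call sites imply: the list keeps its full length so index i still lines up with kv_caches[i], and non-local slots hold placeholders. The explicit rank/world_size parameters are illustrative assumptions; the real helper presumably reads them from get_pp_group().

```python
# Hypothetical sketch of a make_layers/PPMissingLayer pair, matching the
# call pattern in this diff; not the packaged implementation.
from itertools import islice
from typing import Callable, List, Tuple


class PPMissingLayer:
    """Placeholder for a decoder layer owned by another pipeline stage."""


def make_layers(num_layers: int, layer_fn: Callable[[], object], rank: int,
                world_size: int) -> Tuple[int, int, List[object]]:
    # Even split; the last rank absorbs any remainder.
    per_rank = num_layers // world_size
    start = rank * per_rank
    end = num_layers if rank == world_size - 1 else start + per_rank
    layers = [
        layer_fn() if start <= i < end else PPMissingLayer()
        for i in range(num_layers)
    ]
    return start, end, layers


start, end, layers = make_layers(8, object, rank=1, world_size=2)
local = list(islice(layers, start, end))  # what the forward loop iterates
assert not any(isinstance(m, PPMissingLayer) for m in local)
```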
@@ -282,10 +318,18 @@ class LlamaModel(nnx.Module):
         kv_caches: List[jax.Array],
         input_ids: jax.Array,
         attention_metadata: AttentionMetadata,
-    ) -> Tuple[List[jax.Array], jax.Array, List[jax.Array]]:
-        x = self.embed(input_ids)
+        intermediate_tensors: JaxIntermediateTensors | None,
+    ) -> Tuple[List[jax.Array], jax.Array, List[jax.Array]] | Tuple[
+            List[jax.Array], JaxIntermediateTensors]:
+        if self.is_first_rank:
+            x = self.embed(input_ids)
+        else:
+            assert intermediate_tensors is not None
+            x = intermediate_tensors["hidden_states"]
+
         aux_hidden_states = []
-        for i, layer in enumerate(self.layers):
+        for i, layer in enumerate(
+                islice(self.layers, self.start_layer, self.end_layer)):
             if i in self.aux_hidden_state_layers:
                 aux_hidden_states.append(x)
             kv_cache = kv_caches[i]
@@ -295,6 +339,10 @@ class LlamaModel(nnx.Module):
                 attention_metadata,
             )
             kv_caches[i] = kv_cache
+        if not self.is_last_rank:
+            # Note: add aux_hidden_states to make the output spec consistent.
+            return kv_caches, JaxIntermediateTensors({"hidden_states":
+                                                      x}), aux_hidden_states
         x = self.norm(x)
         return kv_caches, x, aux_hidden_states
 
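`JaxIntermediateTensors` appears in this diff only through dict-style construction and indexing. Below is a hypothetical minimal stand-in with just that surface, to make the stage-to-stage handoff concrete; the real class lives in tpu_inference/models/jax/jax_intermediate_tensor.py and is not shown here.

```python
# Assumed shape of the carrier: a thin dict wrapper (sketch only).
import jax.numpy as jnp


class JaxIntermediateTensors:
    def __init__(self, tensors):
        self.tensors = tensors

    def __getitem__(self, key):
        return self.tensors[key]


# A non-last rank returns this from the forward pass; the next rank reads
# it back in place of the token embeddings.
it = JaxIntermediateTensors({"hidden_states": jnp.zeros((4, 16))})
x = it["hidden_states"]
```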
@@ -313,19 +361,33 @@ class LlamaForCausalLM(nnx.Module):
             mesh=mesh,
         )
 
+        self.pp_missing_layers = []
+        for path, module in nnx.iter_graph(self.model):
+            if isinstance(module, PPMissingLayer):
+                # the path should be sth like ('layers', '0')
+                self.pp_missing_layers.append('.'.join([str(s) for s in path]))
+
     def __call__(
         self,
         kv_caches: List[jax.Array],
         input_ids: jax.Array,
         attention_metadata: AttentionMetadata,
+        _input_embeds=None,
+        _input_positions=None,
+        _layer_name_to_kv_cache=None,
+        _lora_metadata=None,
+        intermediate_tensors: JaxIntermediateTensors | None = None,
+        _is_first_rank: bool | None = None,
+        _is_last_rank: bool | None = None,
         *args,
-    ) -> Tuple[List[jax.Array], jax.Array, List[jax.Array]]:
-        kv_caches, x, aux_hidden_states = self.model(
+    ) -> Tuple[List[jax.Array], jax.Array, List[jax.Array]] | Tuple[
+            List[jax.Array], JaxIntermediateTensors]:
+        return self.model(
             kv_caches,
             input_ids,
             attention_metadata,
+            intermediate_tensors,
         )
-        return kv_caches, x, aux_hidden_states
 
     def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
         if self.vllm_config.model_config.hf_config.tie_word_embeddings:
@@ -373,4 +435,5 @@ class LlamaForCausalLM(nnx.Module):
         load_hf_weights(vllm_config=self.vllm_config,
                         model=self,
                         metadata_map=metadata_map,
-                        mesh=self.mesh)
+                        mesh=self.mesh,
+                        pp_missing_layers=self.pp_missing_layers)
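The new `pp_missing_layers` argument suggests `load_hf_weights` skips checkpoint entries owned by other pipeline stages, keyed by the dotted module paths collected above (e.g. 'layers.0'). A hypothetical illustration of such a filter follows; `should_load` is an invented name, not part of the package.

```python
# Sketch of an assumed prefix filter over dotted parameter paths.
from typing import List


def should_load(param_path: str, pp_missing_layers: List[str]) -> bool:
    # Skip any weight whose module path falls under a missing layer.
    return not any(
        param_path == missing or param_path.startswith(missing + ".")
        for missing in pp_missing_layers)


assert should_load("layers.5.mlp.w_up", ["layers.0", "layers.1"])
assert not should_load("layers.0.attn.q_proj", ["layers.0", "layers.1"])
```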
tpu_inference/models/jax/llama4.py

@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import re
 from typing import List, Optional, Tuple
 
tpu_inference/models/jax/llama_eagle3.py

@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from typing import List, Tuple
 
 import jax
tpu_inference/models/jax/llama_guard_4.py

@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import re
 from typing import Any, List, Optional, Tuple
 
@@ -242,7 +256,7 @@ class LlamaGuard4ForCausalLM(nnx.Module):
                               self.lm_head.input_embedding_table_DV.value)
         return logits_TV
 
-    def get_input_embeddings(
+    def embed_input_ids(
             self,
             input_ids: jax.Array,
             multimodal_embeddings: Optional[List[jax.Array]] = None
tpu_inference/models/jax/qwen2.py

@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from typing import List, Optional, Tuple
 
 import jax
@@ -10,6 +24,7 @@ from vllm.config import VllmConfig
 from tpu_inference import utils
 from tpu_inference.layers.common.attention_interface import attention
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.common.quantization import quantize_kv
 from tpu_inference.layers.jax.rope_interface import apply_rope
 from tpu_inference.logger import init_logger
 from tpu_inference.models.jax.utils.weight_utils import (get_default_maps,
@@ -152,8 +167,8 @@ class Qwen2Attention(nnx.Module):
             # q_scale = self._q_scale
             k_scale = self._k_scale
             v_scale = self._v_scale
-            k, v = utils.quantize_kv(k, v, self.kv_cache_quantized_dtype,
-                                     k_scale, v_scale)
+            k, v = quantize_kv(self.kv_cache_quantized_dtype, k, v, k_scale,
+                               v_scale)
         new_kv_cache, outputs = attention(
             kv_cache,
             q,
tpu_inference/models/jax/qwen2_5_vl.py

@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import math
 from functools import partial
 from typing import (Callable, List, Literal, NamedTuple, Optional, TypedDict,
@@ -996,9 +1010,9 @@ class Qwen2_5_VLForConditionalGeneration(nnx.Module):
         split_indices = np.cumsum(sizes)[:-1]
         return tuple(jnp.split(image_embeds, split_indices))
 
-    def get_multimodal_embeddings(self, image_grid_thw: tuple[tuple[int, int,
-                                                                    int], ...],
-                                  **kwargs: object) -> MultiModalEmbeddings:
+    def embed_multimodal(self, image_grid_thw: tuple[tuple[int, int, int],
+                                                     ...],
+                         **kwargs: object) -> MultiModalEmbeddings:
 
         mm_input_by_modality = self._parse_and_validate_multimodal_inputs(
             image_grid_thw, **kwargs)
@@ -1022,7 +1036,7 @@ class Qwen2_5_VLForConditionalGeneration(nnx.Module):
 
         return multimodal_embeddings
 
-    def get_input_embeddings(
+    def embed_input_ids(
             self, input_ids: jax.Array,
             multimodal_embeddings: Optional[jax.Array]) -> jax.Array:
 
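Together with the llama_guard_4.py hunk above, these renames (get_multimodal_embeddings → embed_multimodal, get_input_embeddings → embed_input_ids) imply the following shape for the multimodal model interface; a hedged Protocol sketch, not the package's actual definition.

```python
# Sketch of the renamed interface inferred from the call sites in this diff.
from typing import Optional, Protocol

import jax


class SupportsMultiModal(Protocol):
    # Formerly get_multimodal_embeddings().
    def embed_multimodal(self, image_grid_thw: tuple,
                         **kwargs: object) -> tuple:
        ...

    # Formerly get_input_embeddings(); merges text and multimodal embeddings.
    def embed_input_ids(
            self, input_ids: jax.Array,
            multimodal_embeddings: Optional[jax.Array]) -> jax.Array:
        ...
```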
tpu_inference/models/jax/qwen3.py

@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from typing import List, Optional, Tuple
 
 import jax
@@ -10,6 +24,7 @@ from vllm.config import VllmConfig
 from tpu_inference import utils
 from tpu_inference.layers.common.attention_interface import attention
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.common.quantization import quantize_kv
 from tpu_inference.layers.jax.rope_interface import apply_rope
 from tpu_inference.logger import init_logger
 from tpu_inference.models.jax.qwen2 import Qwen2DecoderLayer
@@ -125,8 +140,8 @@ class Qwen3Attention(nnx.Module):
             # q_scale = self._q_scale
             k_scale = self._k_scale
             v_scale = self._v_scale
-            k, v = utils.quantize_kv(k, v, self.kv_cache_quantized_dtype,
-                                     k_scale, v_scale)
+            k, v = quantize_kv(self.kv_cache_quantized_dtype, k, v, k_scale,
+                               v_scale)
         new_kv_cache, outputs = attention(
             kv_cache,
             q,
tpu_inference/models/jax/utils/__init__.py (new file)

@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
tpu_inference/models/jax/utils/file_utils.py

@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import glob
 import hashlib
 import os
tpu_inference/models/jax/utils/multi_modal_utils.py

@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from typing import Union
 
 import jax
@@ -29,25 +43,25 @@ def sanity_check_mm_encoder_outputs(
 ) -> None:
     """
     Perform sanity checks for the result of
-    [`vllm.model_executor.models.SupportsMultiModal.get_multimodal_embeddings`][].
+    [`vllm.model_executor.models.SupportsMultiModal.embed_multimodal`][].
     """
     assert isinstance(mm_embeddings, (list, tuple, jax.Array)), (
         "Expected multimodal embeddings to be a list/tuple of 2D tensors, "
         f"or a single 3D tensor, but got {type(mm_embeddings)} "
         "instead. This is most likely due to incorrect implementation "
-        "of the model's `get_multimodal_embeddings` method.")
+        "of the model's `embed_multimodal` method.")
 
     assert len(mm_embeddings) == expected_num_items, (
         "Expected number of multimodal embeddings to match number of "
         f"input items: {expected_num_items}, but got {len(mm_embeddings)=} "
         "instead. This is most likely due to incorrect implementation "
-        "of the model's `get_multimodal_embeddings` method.")
+        "of the model's `embed_multimodal` method.")
 
     assert all(e.ndim == 2 for e in mm_embeddings), (
         "Expected multimodal embeddings to be a sequence of 2D tensors, "
         f"but got tensors with shapes {[e.shape for e in mm_embeddings]} "
         "instead. This is most likely due to incorrect implementation "
-        "of the model's `get_multimodal_embeddings` method.")
+        "of the model's `embed_multimodal` method.")
 
 
 def flatten_embeddings(embeddings: NestedTensors) -> jax.Array:
tpu_inference/models/jax/utils/qwix/__init__.py (new file)

@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.