tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.
Files changed (257)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +317 -34
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +406 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +320 -0
  64. tests/layers/vllm/test_unquantized.py +662 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +26 -6
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +110 -12
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +2 -45
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +15 -10
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +92 -8
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +25 -4
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +807 -230
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +218 -137
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +25 -12
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/fused_moe_gmm.py +506 -0
  154. tpu_inference/layers/common/quant_methods.py +15 -0
  155. tpu_inference/layers/common/quantization.py +282 -0
  156. tpu_inference/layers/common/sharding.py +32 -9
  157. tpu_inference/layers/common/utils.py +94 -0
  158. tpu_inference/layers/jax/__init__.py +13 -0
  159. tpu_inference/layers/jax/attention/__init__.py +13 -0
  160. tpu_inference/layers/jax/attention/attention.py +19 -6
  161. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  162. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  163. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  164. tpu_inference/layers/jax/base.py +14 -0
  165. tpu_inference/layers/jax/constants.py +13 -0
  166. tpu_inference/layers/jax/layers.py +14 -0
  167. tpu_inference/layers/jax/misc.py +14 -0
  168. tpu_inference/layers/jax/moe/__init__.py +13 -0
  169. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  170. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  171. tpu_inference/layers/jax/moe/moe.py +43 -3
  172. tpu_inference/layers/jax/pp_utils.py +53 -0
  173. tpu_inference/layers/jax/rope.py +14 -0
  174. tpu_inference/layers/jax/rope_interface.py +14 -0
  175. tpu_inference/layers/jax/sample/__init__.py +13 -0
  176. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  177. tpu_inference/layers/jax/sample/sampling.py +15 -1
  178. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  179. tpu_inference/layers/jax/transformer_block.py +14 -0
  180. tpu_inference/layers/vllm/__init__.py +13 -0
  181. tpu_inference/layers/vllm/attention.py +4 -4
  182. tpu_inference/layers/vllm/fused_moe.py +101 -494
  183. tpu_inference/layers/vllm/linear.py +64 -0
  184. tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
  185. tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
  186. tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
  187. tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
  188. tpu_inference/layers/vllm/quantization/__init__.py +19 -3
  189. tpu_inference/layers/vllm/quantization/awq.py +96 -82
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  191. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +23 -8
  192. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +172 -176
  193. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  194. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
  195. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
  196. tpu_inference/layers/vllm/quantization/{common.py → configs.py} +42 -25
  197. tpu_inference/layers/vllm/quantization/fp8.py +119 -0
  198. tpu_inference/layers/vllm/quantization/mxfp4.py +137 -178
  199. tpu_inference/layers/vllm/quantization/unquantized.py +157 -233
  200. tpu_inference/lora/__init__.py +13 -0
  201. tpu_inference/lora/torch_lora_ops.py +8 -13
  202. tpu_inference/models/__init__.py +13 -0
  203. tpu_inference/models/common/__init__.py +13 -0
  204. tpu_inference/models/common/model_loader.py +112 -35
  205. tpu_inference/models/jax/__init__.py +13 -0
  206. tpu_inference/models/jax/deepseek_v3.py +267 -157
  207. tpu_inference/models/jax/gpt_oss.py +26 -10
  208. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  209. tpu_inference/models/jax/llama3.py +99 -36
  210. tpu_inference/models/jax/llama4.py +14 -0
  211. tpu_inference/models/jax/llama_eagle3.py +18 -5
  212. tpu_inference/models/jax/llama_guard_4.py +15 -1
  213. tpu_inference/models/jax/qwen2.py +17 -2
  214. tpu_inference/models/jax/qwen2_5_vl.py +179 -51
  215. tpu_inference/models/jax/qwen3.py +17 -2
  216. tpu_inference/models/jax/utils/__init__.py +13 -0
  217. tpu_inference/models/jax/utils/file_utils.py +14 -0
  218. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  219. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  220. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +92 -32
  221. tpu_inference/models/jax/utils/weight_utils.py +234 -155
  222. tpu_inference/models/vllm/__init__.py +13 -0
  223. tpu_inference/models/vllm/vllm_model_wrapper.py +32 -8
  224. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  225. tpu_inference/platforms/__init__.py +14 -0
  226. tpu_inference/platforms/tpu_platform.py +51 -72
  227. tpu_inference/runner/__init__.py +13 -0
  228. tpu_inference/runner/compilation_manager.py +180 -80
  229. tpu_inference/runner/kv_cache.py +54 -20
  230. tpu_inference/runner/kv_cache_manager.py +55 -33
  231. tpu_inference/runner/lora_utils.py +16 -1
  232. tpu_inference/runner/multimodal_manager.py +16 -2
  233. tpu_inference/runner/persistent_batch_manager.py +54 -2
  234. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  235. tpu_inference/runner/structured_decoding_manager.py +16 -3
  236. tpu_inference/runner/tpu_runner.py +124 -61
  237. tpu_inference/runner/utils.py +2 -2
  238. tpu_inference/spec_decode/__init__.py +13 -0
  239. tpu_inference/spec_decode/jax/__init__.py +13 -0
  240. tpu_inference/spec_decode/jax/eagle3.py +84 -22
  241. tpu_inference/tpu_info.py +14 -0
  242. tpu_inference/utils.py +72 -44
  243. tpu_inference/worker/__init__.py +13 -0
  244. tpu_inference/worker/tpu_worker.py +66 -52
  245. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +8 -9
  246. tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
  247. tpu_inference/layers/vllm/linear_common.py +0 -186
  248. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  249. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  250. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  251. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  252. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  253. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  254. tpu_inference-0.11.1.dev202511220812.dist-info/RECORD +0 -174
  255. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
  256. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
  257. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
--- a/tpu_inference/models/jax/gpt_oss.py
+++ b/tpu_inference/models/jax/gpt_oss.py
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import re
 from dataclasses import dataclass
 from typing import List, Optional, Tuple
@@ -11,6 +25,9 @@ from jax.sharding import Mesh, NamedSharding
 from jax.sharding import PartitionSpec as P
 from vllm.config import VllmConfig

+from tpu_inference.layers.common.quant_methods import MXFP4
+from tpu_inference.layers.common.quantization import (
+    dequantize_tensor_from_mxfp4_packed, e8m0_to_fp32, u8_unpack_e2m1)
 from tpu_inference.layers.jax.attention.gpt_oss_attention import (
     AttentionMetadata, GptOssAttention)
 from tpu_inference.layers.jax.constants import KVCacheType
@@ -18,8 +35,6 @@ from tpu_inference.layers.jax.layers import Embedder, LMhead, RMSNorm
 from tpu_inference.layers.jax.moe.gpt_oss_moe import GptOssMoE, GptOssRouter
 from tpu_inference.layers.jax.transformer_block import TransformerBlock
 from tpu_inference.logger import init_logger
-from tpu_inference.models.jax.utils.quantization.mxfp4_utils import (
-    MXFP4_QUANT_METHOD, dequant_mxfp4_to_bf16, unpack_mxfp4_to_fp32)
 from tpu_inference.models.jax.utils.weight_utils import (
     get_param, model_weights_generator, print_param_info)

@@ -102,9 +117,9 @@ class GptOss(nnx.Module):
             rope_ntk_beta=rope_ntk_beta,
             rngs=self.rng,
             random_init=self.random_init,
-            query_tnh=P(None, 'model', None),
-            keyvalue_skh=P(None, 'model', None),
-            attn_o_tnh=P(None, 'model', None),
+            query_tnh=P("data", 'model', None),
+            keyvalue_skh=P("data", 'model', None),
+            attn_o_tnh=P("data", 'model', None),
             dnh_sharding=P(None, 'model', None),
             dkh_sharding=P(None, 'model', None),
             nhd_sharding=P('model', None, None),
@@ -205,7 +220,7 @@ class GptOss(nnx.Module):

         # MXFP4 checkpoints swap last two dims for MoE to place packed dim at most minor
         swap_mlp_transform = transforms[
-            "swap_last2"] if quant_method == MXFP4_QUANT_METHOD else None
+            "swap_last2"] if quant_method == MXFP4 else None

         mappings = {
             # Embeddings, Norms, and LM Head
@@ -285,7 +300,7 @@ class GptOss(nnx.Module):
         # Build a pool of weights with MXFP4 experts combined if neededs
         pool: dict[str, torch.Tensor | tuple] = (self._build_mxfp4_pool(
             names_and_weights_generator,
-            mappings) if quant_method == MXFP4_QUANT_METHOD else {
+            mappings) if quant_method == MXFP4 else {
                 loaded_name: loaded_weight
                 for loaded_name, loaded_weight in names_and_weights_generator
             })
@@ -316,8 +331,9 @@ class GptOss(nnx.Module):
                 blocks_u8, scales_u8 = loaded_weight
                 # Quantized param (QArray): set qvalue/scale directly and skip regular path
                 if hasattr(model_weight, "array"):  # QArray check
-                    codes_fp32_t, scales_fp32_t = unpack_mxfp4_to_fp32(
-                        blocks_u8, scales_u8)
+                    codes_fp32_t = u8_unpack_e2m1(blocks_u8).astype(
+                        jnp.float32)
+                    scales_fp32_t = e8m0_to_fp32(scales_u8)
                     self._load_mxfp4(
                         model_weight=model_weight,
                         codes_fp32_t=codes_fp32_t,
@@ -328,7 +344,7 @@ class GptOss(nnx.Module):
                     print_param_info(model_weight, loaded_name)
                     continue
                 # Not a QArray: dequantize MXFP4 to BF16 full weights
-                prepared_weight = dequant_mxfp4_to_bf16(
+                prepared_weight = dequantize_tensor_from_mxfp4_packed(
                     blocks_u8, scales_u8)

                 # Single regular-tensor load call (BF16 or dequantized MXFP4)
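The gpt_oss hunks above migrate from the removed `mxfp4_utils` module to the shared helpers in `tpu_inference.layers.common.quantization`. For orientation, here is a minimal sketch of what `u8_unpack_e2m1`, `e8m0_to_fp32`, and `dequantize_tensor_from_mxfp4_packed` plausibly compute, assuming the standard OCP MX layout (two 4-bit e2m1 codes packed per uint8, one e8m0 power-of-two scale per block); the packaged implementations may differ in axis layout and dtype handling:

```python
# Hypothetical sketch only -- not the package's implementation.
import jax.numpy as jnp

# All 16 e2m1 values (1 sign bit, 2 exponent bits, 1 mantissa bit).
_E2M1_LUT = jnp.array(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
     -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0],
    dtype=jnp.float32)


def u8_unpack_e2m1(blocks_u8):
    """Splits each uint8 into two 4-bit e2m1 codes and decodes them."""
    lo = blocks_u8 & 0x0F  # low nibble first (nibble order is an assumption)
    hi = blocks_u8 >> 4
    codes = jnp.stack([lo, hi], axis=-1).reshape(*blocks_u8.shape[:-1], -1)
    return _E2M1_LUT[codes]


def e8m0_to_fp32(scales_u8):
    """Decodes e8m0 block scales: an unsigned exponent with bias 127."""
    return jnp.exp2(scales_u8.astype(jnp.float32) - 127.0)


def dequantize_tensor_from_mxfp4_packed(blocks_u8, scales_u8, block=32):
    """Dequantizes packed codes against per-block scales, yielding bf16."""
    vals = u8_unpack_e2m1(blocks_u8)
    vals = vals.reshape(*vals.shape[:-1], -1, block)
    out = vals * e8m0_to_fp32(scales_u8)[..., None]
    return out.reshape(*vals.shape[:-2], -1).astype(jnp.bfloat16)
```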
--- a/tpu_inference/models/jax/jax_intermediate_tensor.py
+++ b/tpu_inference/models/jax/jax_intermediate_tensor.py
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Dict, Union

--- a/tpu_inference/models/jax/llama3.py
+++ b/tpu_inference/models/jax/llama3.py
@@ -1,3 +1,18 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from itertools import islice
 from typing import List, Optional, Tuple

 import jax
@@ -8,13 +23,19 @@ from transformers import LlamaConfig, modeling_flax_utils
 from vllm.config import VllmConfig

 from tpu_inference import utils
+from tpu_inference.distributed.jax_parallel_state import get_pp_group
 from tpu_inference.layers.common.attention_interface import attention
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.common.quantization import quantize_kv
 from tpu_inference.layers.common.sharding import ShardingAxisName
+from tpu_inference.layers.jax.pp_utils import PPMissingLayer, make_layers
 from tpu_inference.layers.jax.rope_interface import apply_rope
 from tpu_inference.logger import init_logger
+from tpu_inference.models.jax.jax_intermediate_tensor import \
+    JaxIntermediateTensors
 from tpu_inference.models.jax.utils.weight_utils import (get_default_maps,
                                                          load_hf_weights)
+from tpu_inference.utils import get_mesh_shape_product

 logger = init_logger(__name__)

@@ -79,7 +100,8 @@ class LlamaAttention(nnx.Module):
                                             self.hidden_size // self.num_heads)
         self.head_dim = utils.get_padded_head_dim(self.head_dim_original)

-        sharding_size = mesh.shape["model"] * mesh.shape.get("attn_dp", 1)
+        sharding_size = get_mesh_shape_product(mesh,
+                                               ShardingAxisName.MLP_TENSOR)
         self.num_heads = utils.get_padded_num_heads(self.num_heads,
                                                     sharding_size)
         self.num_kv_heads = utils.get_padded_num_heads(self.num_kv_heads,
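`get_mesh_shape_product` generalizes the hard-coded `mesh.shape["model"] * mesh.shape.get("attn_dp", 1)` product: a sharding axis name such as `ShardingAxisName.MLP_TENSOR` may map to a single mesh axis or a tuple of axes. A minimal sketch of the helper, written from the replaced expression (an assumption; the real one lives in `tpu_inference/utils.py`):

```python
from jax.sharding import Mesh


def get_mesh_shape_product(mesh: Mesh, axis_names) -> int:
    """Product of the mesh sizes of one axis name or a tuple of names.

    Axes absent from the mesh contribute a factor of 1, mirroring the old
    mesh.shape.get("attn_dp", 1) fallback.
    """
    if isinstance(axis_names, str):
        axis_names = (axis_names,)
    size = 1
    for name in axis_names:
        size *= mesh.shape.get(name, 1)  # mesh.shape maps axis name -> size
    return size
```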
@@ -152,8 +174,8 @@ class LlamaAttention(nnx.Module):
             # q_scale = self._q_scale
             k_scale = self._k_scale
             v_scale = self._v_scale
-            k, v = utils.quantize_kv(k, v, self.kv_cache_quantized_dtype,
-                                     k_scale, v_scale)
+            k, v = quantize_kv(self.kv_cache_quantized_dtype, k, v, k_scale,
+                               v_scale)
         new_kv_cache, outputs = attention(
             kv_cache,
             q,
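`quantize_kv` moved from `tpu_inference.utils` into the shared quantization module, and its signature now leads with the target dtype. A hypothetical sketch of the scale-and-cast step such a helper performs, assuming a divide-by-scale convention (the packaged version may choose ranges, rounding, and scale direction differently):

```python
import jax.numpy as jnp


def quantize_kv(quant_dtype, k, v, k_scale, v_scale):
    """Divides K/V by their scales and casts to the KV-cache dtype."""

    def _quantize(x, scale):
        x = x.astype(jnp.float32) / scale
        if jnp.issubdtype(quant_dtype, jnp.integer):
            info = jnp.iinfo(quant_dtype)
            x = jnp.round(x)
        else:
            info = jnp.finfo(quant_dtype)
        # Clamp to the representable range before the cast.
        x = jnp.clip(x, float(info.min), float(info.max))
        return x.astype(quant_dtype)

    return _quantize(k, k_scale), _quantize(v, v_scale)
```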
@@ -235,38 +257,52 @@ class LlamaModel(nnx.Module):
         rms_norm_eps = hf_config.rms_norm_eps
         hidden_size = hf_config.hidden_size

-        self.embed = nnx.Embed(
-            num_embeddings=vocab_size,
-            features=hidden_size,
-            param_dtype=dtype,
-            embedding_init=nnx.with_partitioning(
-                init_fn, (ShardingAxisName.VOCAB, None)),
-            rngs=rng,
-        )
-        self.layers = [
-            LlamaDecoderLayer(
+        self.is_first_rank = get_pp_group().is_first_rank
+        self.is_last_rank = get_pp_group().is_last_rank
+
+        if self.is_first_rank or (hf_config.tie_word_embeddings
+                                  and self.is_last_rank):
+            self.embed = nnx.Embed(
+                num_embeddings=vocab_size,
+                features=hidden_size,
+                param_dtype=dtype,
+                embedding_init=nnx.with_partitioning(
+                    init_fn, (ShardingAxisName.VOCAB, None)),
+                rngs=rng,
+            )
+        else:
+            self.embed = PPMissingLayer()
+
+        self.start_layer, self.end_layer, self.layers = make_layers(
+            hf_config.num_hidden_layers,
+            lambda: LlamaDecoderLayer(
                 config=hf_config,
                 dtype=dtype,
                 rng=rng,
                 mesh=mesh,
                 # TODO (jacobplatin): we should refactor this to pass a dtype (or config) directly
-                kv_cache_dtype=vllm_config.cache_config.cache_dtype)
-            for _ in range(hf_config.num_hidden_layers)
-        ]
-        self.norm = nnx.RMSNorm(
-            hidden_size,
-            epsilon=rms_norm_eps,
-            param_dtype=dtype,
-            scale_init=nnx.with_partitioning(init_fn, (None, )),
-            rngs=rng,
-        )
-        if model_config.hf_config.tie_word_embeddings:
-            self.lm_head = self.embed.embedding
-        else:
-            self.lm_head = nnx.Param(
-                init_fn(rng.params(), (hidden_size, vocab_size), dtype),
-                sharding=(None, ShardingAxisName.VOCAB),
+                kv_cache_dtype=vllm_config.cache_config.cache_dtype))
+        if self.is_last_rank:
+            self.norm = nnx.RMSNorm(
+                hidden_size,
+                epsilon=rms_norm_eps,
+                param_dtype=dtype,
+                scale_init=nnx.with_partitioning(init_fn, (None, )),
+                rngs=rng,
             )
+        else:
+            self.norm = PPMissingLayer()
+
+        if self.is_last_rank:
+            if model_config.hf_config.tie_word_embeddings:
+                self.lm_head = self.embed.embedding
+            else:
+                self.lm_head = nnx.Param(
+                    init_fn(rng.params(), (hidden_size, vocab_size), dtype),
+                    sharding=(None, ShardingAxisName.VOCAB),
+                )
+        else:
+            self.lm_head = PPMissingLayer()

         self.aux_hidden_state_layers = []
         if vllm_config.speculative_config and vllm_config.speculative_config.method == "eagle3":
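`make_layers` and `PPMissingLayer` come from the new `tpu_inference/layers/jax/pp_utils.py` (+53 lines) and follow the pipeline-parallel pattern vLLM uses: each rank materializes only its contiguous slice of the decoder stack and fills the rest with inert placeholders so layer indices stay globally aligned. A rough sketch under those assumptions (the real helper may split remainder layers differently, and the `rank_in_group`/`world_size` attributes on the PP group are assumed, as in vLLM):

```python
from flax import nnx

from tpu_inference.distributed.jax_parallel_state import get_pp_group


class PPMissingLayer(nnx.Module):
    """Inert stand-in for a layer owned by another pipeline rank."""

    def __call__(self, *args, **kwargs):
        raise RuntimeError("PPMissingLayer must never be executed.")


def make_layers(num_layers: int, layer_fn):
    """Returns (start, end, layers) with real layers only in [start, end)."""
    pp = get_pp_group()
    per_rank = num_layers // pp.world_size
    start = pp.rank_in_group * per_rank
    # The last rank absorbs any remainder layers.
    end = num_layers if pp.is_last_rank else start + per_rank
    layers = [
        layer_fn() if start <= i < end else PPMissingLayer()
        for i in range(num_layers)
    ]
    return start, end, layers
```

This is why the forward pass below iterates `islice(self.layers, self.start_layer, self.end_layer)`: only the locally owned slice is executed, while the placeholder entries keep the global numbering intact.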
@@ -282,10 +318,18 @@ class LlamaModel(nnx.Module):
         kv_caches: List[jax.Array],
         input_ids: jax.Array,
         attention_metadata: AttentionMetadata,
-    ) -> Tuple[List[jax.Array], jax.Array, List[jax.Array]]:
-        x = self.embed(input_ids)
+        intermediate_tensors: JaxIntermediateTensors | None,
+    ) -> Tuple[List[jax.Array], jax.Array, List[jax.Array]] | Tuple[
+            List[jax.Array], JaxIntermediateTensors]:
+        if self.is_first_rank:
+            x = self.embed(input_ids)
+        else:
+            assert intermediate_tensors is not None
+            x = intermediate_tensors["hidden_states"]
+
         aux_hidden_states = []
-        for i, layer in enumerate(self.layers):
+        for i, layer in enumerate(
+                islice(self.layers, self.start_layer, self.end_layer)):
             if i in self.aux_hidden_state_layers:
                 aux_hidden_states.append(x)
             kv_cache = kv_caches[i]
@@ -295,6 +339,10 @@ class LlamaModel(nnx.Module):
                 attention_metadata,
             )
             kv_caches[i] = kv_cache
+        if not self.is_last_rank:
+            # Note: add aux_hidden_states to make the output spec consistent.
+            return kv_caches, JaxIntermediateTensors({"hidden_states":
+                                                      x}), aux_hidden_states
         x = self.norm(x)
         return kv_caches, x, aux_hidden_states

@@ -313,19 +361,33 @@ class LlamaForCausalLM(nnx.Module):
             mesh=mesh,
         )

+        self.pp_missing_layers = []
+        for path, module in nnx.iter_graph(self.model):
+            if isinstance(module, PPMissingLayer):
+                # the path should be sth like ('layers', '0')
+                self.pp_missing_layers.append('.'.join([str(s) for s in path]))
+
     def __call__(
         self,
         kv_caches: List[jax.Array],
         input_ids: jax.Array,
         attention_metadata: AttentionMetadata,
+        _input_embeds=None,
+        _input_positions=None,
+        _layer_name_to_kv_cache=None,
+        _lora_metadata=None,
+        intermediate_tensors: JaxIntermediateTensors | None = None,
+        _is_first_rank: bool | None = None,
+        _is_last_rank: bool | None = None,
         *args,
-    ) -> Tuple[List[jax.Array], jax.Array, List[jax.Array]]:
-        kv_caches, x, aux_hidden_states = self.model(
+    ) -> Tuple[List[jax.Array], jax.Array, List[jax.Array]] | Tuple[
+            List[jax.Array], JaxIntermediateTensors]:
+        return self.model(
             kv_caches,
             input_ids,
             attention_metadata,
+            intermediate_tensors,
         )
-        return kv_caches, x, aux_hidden_states

     def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
         if self.vllm_config.model_config.hf_config.tie_word_embeddings:
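The `pp_missing_layers` list collects dotted module paths (e.g. "layers.3") of the placeholders so that, in the next hunk, `load_hf_weights` can skip checkpoint tensors owned by other pipeline ranks. A hypothetical sketch of the kind of filter this enables; the actual logic is internal to `weight_utils.load_hf_weights`:

```python
def is_pp_missing(param_path: str, pp_missing_layers: list[str]) -> bool:
    """True if the parameter lives under one of the collected paths.

    param_path is a dotted name such as "layers.3.self_attn.q_proj.kernel".
    Padding with dots avoids false prefix matches ("layers.3" vs "layers.30").
    """
    padded = f".{param_path}."
    return any(f".{prefix}." in padded for prefix in pp_missing_layers)
```

A loader can then `continue` past any checkpoint tensor for which `is_pp_missing(...)` is true before materializing it on device.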
@@ -373,4 +435,5 @@ class LlamaForCausalLM(nnx.Module):
         load_hf_weights(vllm_config=self.vllm_config,
                         model=self,
                         metadata_map=metadata_map,
-                        mesh=self.mesh)
+                        mesh=self.mesh,
+                        pp_missing_layers=self.pp_missing_layers)
--- a/tpu_inference/models/jax/llama4.py
+++ b/tpu_inference/models/jax/llama4.py
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import re
 from typing import List, Optional, Tuple

--- a/tpu_inference/models/jax/llama_eagle3.py
+++ b/tpu_inference/models/jax/llama_eagle3.py
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from typing import List, Tuple

 import jax
@@ -304,6 +318,8 @@ class EagleLlama3ForCausalLM(nnx.Module):
             "fc": "model.fc.kernel",
             "lm_head": "lm_head.kernel",
             "d2t": "draft_id_to_target_id",
+            "embed_tokens":
+            "model.embed_tokens.embedding",  # Some checkpoints need this
         }

         # Define keys to keep in original dtype (e.g., float32 for stability)
@@ -311,8 +327,6 @@ class EagleLlama3ForCausalLM(nnx.Module):
             r".*d2t.*",
         ]

-        # `embed_tokens` is shared between target and draft.
-        exclude_regex = [r".*embed_tokens.*"]
         metadata_map = get_default_maps(
             self.vllm_config.speculative_config.draft_model_config, self.mesh,
             mappings)
@@ -325,10 +339,9 @@ class EagleLlama3ForCausalLM(nnx.Module):
                         metadata_map=metadata_map,
                         mesh=self.mesh,
                         is_draft_model=True,
-                        keep_original_dtype_keys_regex=keep_original_dtype_keys_regex,
-                        exclude_regex=exclude_regex if exclude_regex else None)
+                        keep_original_dtype_keys_regex=keep_original_dtype_keys_regex)

-        # If the embedding is not initialized, initialize it with a dummpy array here to pass jit compilation. The real weights will be shared from the target model in eagle3 class.
+        # If the embedding is not initialized, initialize it with a dummy array here to pass jit compilation. The real weights will be shared from the target model in eagle3 class.
         if isinstance(self.model.embed_tokens.embedding.value,
                       jax.ShapeDtypeStruct):
             self.model.embed_tokens.embedding.value = jnp.zeros(
--- a/tpu_inference/models/jax/llama_guard_4.py
+++ b/tpu_inference/models/jax/llama_guard_4.py
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import re
 from typing import Any, List, Optional, Tuple

@@ -242,7 +256,7 @@ class LlamaGuard4ForCausalLM(nnx.Module):
             self.lm_head.input_embedding_table_DV.value)
         return logits_TV

-    def get_input_embeddings(
+    def embed_input_ids(
         self,
         input_ids: jax.Array,
         multimodal_embeddings: Optional[List[jax.Array]] = None
--- a/tpu_inference/models/jax/qwen2.py
+++ b/tpu_inference/models/jax/qwen2.py
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from typing import List, Optional, Tuple

 import jax
@@ -10,6 +24,7 @@ from vllm.config import VllmConfig
 from tpu_inference import utils
 from tpu_inference.layers.common.attention_interface import attention
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.common.quantization import quantize_kv
 from tpu_inference.layers.jax.rope_interface import apply_rope
 from tpu_inference.logger import init_logger
 from tpu_inference.models.jax.utils.weight_utils import (get_default_maps,
@@ -152,8 +167,8 @@ class Qwen2Attention(nnx.Module):
             # q_scale = self._q_scale
             k_scale = self._k_scale
             v_scale = self._v_scale
-            k, v = utils.quantize_kv(k, v, self.kv_cache_quantized_dtype,
-                                     k_scale, v_scale)
+            k, v = quantize_kv(self.kv_cache_quantized_dtype, k, v, k_scale,
+                               v_scale)
         new_kv_cache, outputs = attention(
             kv_cache,
             q,