tpu-inference 0.12.0.dev20251213__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (248)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +14 -0
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +14 -0
  31. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  32. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_test.py +14 -0
  35. tests/layers/__init__.py +13 -0
  36. tests/layers/common/__init__.py +13 -0
  37. tests/layers/common/test_attention_interface.py +156 -0
  38. tests/layers/common/test_quantization.py +149 -0
  39. tests/layers/jax/__init__.py +13 -0
  40. tests/layers/jax/attention/__init__.py +13 -0
  41. tests/layers/jax/attention/test_common_attention.py +103 -0
  42. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  43. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  44. tests/layers/jax/moe/__init__.py +13 -0
  45. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  46. tests/layers/jax/sample/__init__.py +13 -0
  47. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  48. tests/layers/jax/sample/test_sampling.py +115 -0
  49. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  50. tests/layers/jax/test_layers.py +155 -0
  51. tests/{test_quantization.py → layers/jax/test_qwix.py} +180 -50
  52. tests/layers/jax/test_rope.py +93 -0
  53. tests/layers/jax/test_sharding.py +159 -0
  54. tests/layers/jax/test_transformer_block.py +152 -0
  55. tests/layers/vllm/__init__.py +13 -0
  56. tests/layers/vllm/test_attention.py +363 -0
  57. tests/layers/vllm/test_awq.py +406 -0
  58. tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
  59. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
  61. tests/layers/vllm/test_fp8.py +17 -0
  62. tests/layers/vllm/test_mxfp4.py +320 -0
  63. tests/layers/vllm/test_unquantized.py +662 -0
  64. tests/layers/vllm/utils.py +87 -0
  65. tests/lora/__init__.py +13 -0
  66. tests/lora/conftest.py +14 -0
  67. tests/lora/test_bgmv.py +14 -0
  68. tests/lora/test_layers.py +25 -8
  69. tests/lora/test_lora.py +15 -1
  70. tests/lora/test_lora_perf.py +14 -0
  71. tests/models/__init__.py +13 -0
  72. tests/models/common/__init__.py +13 -0
  73. tests/models/common/test_model_loader.py +455 -0
  74. tests/models/jax/__init__.py +13 -0
  75. tests/models/jax/test_deepseek_v3.py +401 -0
  76. tests/models/jax/test_llama3.py +184 -0
  77. tests/models/jax/test_llama4.py +298 -0
  78. tests/models/jax/test_llama_eagle3.py +197 -0
  79. tests/models/jax/test_llama_guard_4.py +242 -0
  80. tests/models/jax/test_qwen2.py +172 -0
  81. tests/models/jax/test_qwen2_5_vl.py +605 -0
  82. tests/models/jax/test_qwen3.py +169 -0
  83. tests/models/jax/test_weight_loading.py +180 -0
  84. tests/models/jax/utils/__init__.py +13 -0
  85. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  86. tests/platforms/__init__.py +13 -0
  87. tests/platforms/test_tpu_platform.py +54 -0
  88. tests/runner/__init__.py +13 -0
  89. tests/runner/test_block_table.py +395 -0
  90. tests/runner/test_input_batch.py +226 -0
  91. tests/runner/test_kv_cache.py +220 -0
  92. tests/runner/test_kv_cache_manager.py +498 -0
  93. tests/runner/test_multimodal_manager.py +429 -0
  94. tests/runner/test_persistent_batch_manager.py +84 -0
  95. tests/runner/test_speculative_decoding_manager.py +368 -0
  96. tests/runner/test_structured_decoding_manager.py +220 -0
  97. tests/runner/test_tpu_runner.py +261 -0
  98. tests/runner/test_tpu_runner_dp.py +1099 -0
  99. tests/runner/test_tpu_runner_mesh.py +200 -0
  100. tests/runner/test_utils.py +411 -0
  101. tests/spec_decode/__init__.py +13 -0
  102. tests/spec_decode/test_eagle3.py +311 -0
  103. tests/test_base.py +14 -0
  104. tests/test_tpu_info.py +14 -0
  105. tests/test_utils.py +1 -43
  106. tests/worker/__init__.py +13 -0
  107. tests/worker/tpu_worker_test.py +414 -0
  108. tpu_inference/__init__.py +14 -0
  109. tpu_inference/core/__init__.py +13 -0
  110. tpu_inference/core/sched/__init__.py +13 -0
  111. tpu_inference/core/sched/dp_scheduler.py +372 -56
  112. tpu_inference/distributed/__init__.py +13 -0
  113. tpu_inference/distributed/jax_parallel_state.py +14 -0
  114. tpu_inference/distributed/tpu_connector.py +14 -9
  115. tpu_inference/distributed/utils.py +56 -4
  116. tpu_inference/executors/__init__.py +13 -0
  117. tpu_inference/executors/ray_distributed_executor.py +20 -3
  118. tpu_inference/experimental/__init__.py +13 -0
  119. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  120. tpu_inference/kernels/__init__.py +13 -0
  121. tpu_inference/kernels/collectives/__init__.py +13 -0
  122. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  123. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  124. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  125. tpu_inference/kernels/fused_moe/v1/kernel.py +171 -163
  126. tpu_inference/kernels/megablox/__init__.py +13 -0
  127. tpu_inference/kernels/megablox/common.py +54 -0
  128. tpu_inference/kernels/megablox/gmm.py +646 -0
  129. tpu_inference/kernels/mla/__init__.py +13 -0
  130. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  131. tpu_inference/kernels/mla/v1/kernel.py +20 -26
  132. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  133. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  134. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  135. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  136. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +112 -69
  137. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +85 -65
  138. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  139. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +374 -194
  140. tpu_inference/kernels/ragged_paged_attention/v3/util.py +13 -0
  141. tpu_inference/layers/__init__.py +13 -0
  142. tpu_inference/layers/common/__init__.py +13 -0
  143. tpu_inference/layers/common/attention_interface.py +26 -19
  144. tpu_inference/layers/common/attention_metadata.py +14 -0
  145. tpu_inference/layers/common/fused_moe_gmm.py +506 -0
  146. tpu_inference/layers/common/quant_methods.py +15 -0
  147. tpu_inference/layers/common/quantization.py +282 -0
  148. tpu_inference/layers/common/sharding.py +22 -3
  149. tpu_inference/layers/common/utils.py +94 -0
  150. tpu_inference/layers/jax/__init__.py +13 -0
  151. tpu_inference/layers/jax/attention/__init__.py +13 -0
  152. tpu_inference/layers/jax/attention/attention.py +19 -6
  153. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +52 -27
  154. tpu_inference/layers/jax/attention/gpt_oss_attention.py +19 -6
  155. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  156. tpu_inference/layers/jax/base.py +14 -0
  157. tpu_inference/layers/jax/constants.py +13 -0
  158. tpu_inference/layers/jax/layers.py +14 -0
  159. tpu_inference/layers/jax/misc.py +14 -0
  160. tpu_inference/layers/jax/moe/__init__.py +13 -0
  161. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  162. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  163. tpu_inference/layers/jax/moe/moe.py +43 -3
  164. tpu_inference/layers/jax/pp_utils.py +53 -0
  165. tpu_inference/layers/jax/rope.py +14 -0
  166. tpu_inference/layers/jax/rope_interface.py +14 -0
  167. tpu_inference/layers/jax/sample/__init__.py +13 -0
  168. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  169. tpu_inference/layers/jax/sample/sampling.py +15 -1
  170. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  171. tpu_inference/layers/jax/transformer_block.py +14 -0
  172. tpu_inference/layers/vllm/__init__.py +13 -0
  173. tpu_inference/layers/vllm/attention.py +4 -4
  174. tpu_inference/layers/vllm/fused_moe.py +100 -455
  175. tpu_inference/layers/vllm/linear.py +64 -0
  176. tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
  177. tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
  178. tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
  179. tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
  180. tpu_inference/layers/vllm/quantization/__init__.py +19 -3
  181. tpu_inference/layers/vllm/quantization/awq.py +96 -82
  182. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  183. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +19 -5
  184. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +119 -132
  185. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  186. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
  187. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
  188. tpu_inference/layers/vllm/quantization/{common.py → configs.py} +38 -26
  189. tpu_inference/layers/vllm/quantization/fp8.py +119 -0
  190. tpu_inference/layers/vllm/quantization/mxfp4.py +133 -220
  191. tpu_inference/layers/vllm/quantization/unquantized.py +154 -253
  192. tpu_inference/lora/__init__.py +13 -0
  193. tpu_inference/lora/torch_lora_ops.py +8 -13
  194. tpu_inference/models/__init__.py +13 -0
  195. tpu_inference/models/common/__init__.py +13 -0
  196. tpu_inference/models/common/model_loader.py +37 -16
  197. tpu_inference/models/jax/__init__.py +13 -0
  198. tpu_inference/models/jax/deepseek_v3.py +113 -124
  199. tpu_inference/models/jax/gpt_oss.py +23 -7
  200. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  201. tpu_inference/models/jax/llama3.py +99 -36
  202. tpu_inference/models/jax/llama4.py +14 -0
  203. tpu_inference/models/jax/llama_eagle3.py +14 -0
  204. tpu_inference/models/jax/llama_guard_4.py +15 -1
  205. tpu_inference/models/jax/qwen2.py +17 -2
  206. tpu_inference/models/jax/qwen2_5_vl.py +18 -4
  207. tpu_inference/models/jax/qwen3.py +17 -2
  208. tpu_inference/models/jax/utils/__init__.py +13 -0
  209. tpu_inference/models/jax/utils/file_utils.py +14 -0
  210. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  211. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  212. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +85 -24
  213. tpu_inference/models/jax/utils/weight_utils.py +32 -1
  214. tpu_inference/models/vllm/__init__.py +13 -0
  215. tpu_inference/models/vllm/vllm_model_wrapper.py +22 -4
  216. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  217. tpu_inference/platforms/__init__.py +14 -0
  218. tpu_inference/platforms/tpu_platform.py +27 -29
  219. tpu_inference/runner/__init__.py +13 -0
  220. tpu_inference/runner/compilation_manager.py +69 -35
  221. tpu_inference/runner/kv_cache.py +14 -0
  222. tpu_inference/runner/kv_cache_manager.py +15 -2
  223. tpu_inference/runner/lora_utils.py +16 -1
  224. tpu_inference/runner/multimodal_manager.py +16 -2
  225. tpu_inference/runner/persistent_batch_manager.py +14 -0
  226. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  227. tpu_inference/runner/structured_decoding_manager.py +14 -0
  228. tpu_inference/runner/tpu_runner.py +30 -10
  229. tpu_inference/spec_decode/__init__.py +13 -0
  230. tpu_inference/spec_decode/jax/__init__.py +13 -0
  231. tpu_inference/spec_decode/jax/eagle3.py +13 -0
  232. tpu_inference/tpu_info.py +14 -0
  233. tpu_inference/utils.py +31 -30
  234. tpu_inference/worker/__init__.py +13 -0
  235. tpu_inference/worker/tpu_worker.py +23 -7
  236. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +1 -1
  237. tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
  238. tpu_inference/layers/vllm/linear_common.py +0 -208
  239. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  240. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  241. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  242. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  243. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  244. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  245. tpu_inference-0.12.0.dev20251213.dist-info/RECORD +0 -175
  246. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
  247. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
  248. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
@@ -1,3 +1,17 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  import math
  from dataclasses import InitVar, dataclass
  from typing import Any, Tuple
@@ -6,7 +20,6 @@ import jax
  import jax.numpy as jnp
  from flax import nnx
  from flax.typing import Sharding
- from jax.experimental import shard_map
  from jax.sharding import Mesh
  from jax.sharding import PartitionSpec as P

@@ -17,6 +30,7 @@ from tpu_inference.kernels.ragged_paged_attention.v3.kernel import \
  from tpu_inference.kernels.ragged_paged_attention.v3.tuned_block_sizes import \
      get_tuned_block_sizes
  from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+ from tpu_inference.layers.common.quantization import quantize_kv
  from tpu_inference.layers.common.sharding import ShardingAxisName
  from tpu_inference.layers.jax.base import create_param
  from tpu_inference.layers.jax.layers import RMSNorm
@@ -52,8 +66,9 @@ class MLA(nnx.Module):
  rms_norm_eps: float

  # Sharding attributes
- nhd_sharding: Sharding = ()
+ rd_sharding: Sharding = ()
  q_da_sharding: Sharding = ()
+ ap_sharding: Sharding = ()
  anh_sharding: Sharding = ()
  kv_da_sharding: Sharding = ()

@@ -113,10 +128,10 @@
  self.q_da_sharding,
  self.dtype,
  random_init=self.random_init)
- self.kernel_q_up_proj_ANH = create_param(
+ self.kernel_q_up_proj_AP = create_param(
  rngs,
- (self.q_lora_rank, self.N, self.qk_head_dim),
- self.anh_sharding,
+ (self.q_lora_rank, self.N * self.qk_head_dim),
+ self.ap_sharding,
  self.dtype,
  random_init=self.random_init,
  )
@@ -127,6 +142,10 @@
  self.dtype,
  random_init=self.random_init,
  )
+ # NOTE (jacobplatin): we are keeping these variables as 3D because
+ # we would need to reshape them before the below projection,
+ # which caused issues as Qwix wasn't quantizing it correctly
+ # on the abstract pass
  if self.use_mla_kernel:
  self.kernel_k_up_proj_ANH = create_param(
  rngs,
@@ -143,17 +162,18 @@
  random_init=self.random_init,
  )
  else:
- self.kernel_kv_up_proj_ANH = create_param(
+ self.kernel_kv_up_proj_AL = create_param(
  rngs,
- (self.kv_lora_rank, self.N,
- self.qk_nope_head_dim + self.v_head_dim),
- self.anh_sharding,
+ (self.kv_lora_rank, self.N *
+ (self.qk_nope_head_dim + self.v_head_dim)),
+ self.
+ ap_sharding,  # NOTE: we use the same sharding for kv_up_proj_AL and kernel_q_up_proj_AP
  self.dtype,
  random_init=self.random_init,
  )
- self.kernel_o_proj_NHD = create_param(
- rngs, (self.N, self.v_head_dim, self.D),
- self.nhd_sharding,
+ self.kernel_o_proj_RD = create_param(
+ rngs, (self.N * self.v_head_dim, self.D),
+ self.rd_sharding,
  self.dtype,
  random_init=self.random_init)
  self.q_rms_norm = RMSNorm(
@@ -209,9 +229,10 @@
  q_TA = jnp.einsum("TD,DA -> TA", x_q_TD,
  self.kernel_q_down_proj_DA.value)
  q_TA = self.q_rms_norm(q_TA)
- # Query up projection.
- q_TNH = jnp.einsum("TA,ANH -> TNH", q_TA,
- self.kernel_q_up_proj_ANH.value)
+ # Query up projection, then reshape to TNH.
+ q_TP = jnp.einsum("TA,AP -> TP", q_TA,
+ self.kernel_q_up_proj_AP.value)
+ q_TNH = q_TP.reshape(q_TA.shape[0], self.N, self.qk_head_dim)
  # Split the query into nope and rope.
  q_nope_TNH = q_TNH[..., :self.qk_nope_head_dim]
  q_rope_TNH = q_TNH[..., self.qk_nope_head_dim:]
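Note on the hunk above: flattening the up-projection kernel from (A, N, H) to (A, N * H) and reshaping the einsum output afterwards is numerically equivalent to the original 3D projection. A minimal sketch, using illustrative sizes rather than the model's real dimensions:

import jax
import jax.numpy as jnp

T, A, N, H = 4, 8, 2, 16  # illustrative sizes only
key = jax.random.PRNGKey(0)
q_TA = jax.random.normal(key, (T, A))
w_ANH = jax.random.normal(key, (A, N, H))

# Old path: 3D kernel, einsum straight to TNH.
q_TNH_old = jnp.einsum("TA,ANH -> TNH", q_TA, w_ANH)

# New path: flattened 2D kernel, einsum to TP, then reshape to TNH.
w_AP = w_ANH.reshape(A, N * H)
q_TNH_new = jnp.einsum("TA,AP -> TP", q_TA, w_AP).reshape(T, N, H)

assert jnp.allclose(q_TNH_old, q_TNH_new, atol=1e-5)

Keeping the kernels 2D avoids a reshape of the weight itself before the projection, which (per the NOTE in the diff) tripped up Qwix quantization on the abstract pass.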
@@ -247,9 +268,12 @@
  k_rope_SNH = jnp.broadcast_to(
  k_rope_SNH,
  (k_rope_SNH.shape[0], self.N, self.qk_rope_head_dim))
- # KV up projection.
- kv_nope_SNH = jnp.einsum("SA,ANH -> SNH", kv_SA,
- self.kernel_kv_up_proj_ANH.value)
+ # KV up projection, then reshape to SN(Hk+Hv).
+ kv_SL = jnp.einsum("SA,AL -> SL", kv_SA,
+ self.kernel_kv_up_proj_AL.value)
+ kv_nope_SNH = kv_SL.reshape(
+ kv_SA.shape[0], self.N,
+ self.qk_nope_head_dim + self.v_head_dim)
  # Split the latent kv vector into k nope vector and v vector.
  k_nope_SNH = kv_nope_SNH[..., :self.qk_nope_head_dim]
  v_SNH = kv_nope_SNH[..., self.qk_nope_head_dim:]
@@ -287,9 +311,8 @@
  # TODO(kyuyeunk/jacobplatin): Enable w8a8 when VREG spill issue is resolved.
  k_scale = self._k_scale
  v_scale = self._v_scale
- k_SNH, v_SNH = utils.quantize_kv(
- k_SNH, v_SNH, self.kv_cache_quantized_dtype, k_scale,
- v_scale)
+ k_SNH, v_SNH = quantize_kv(self.kv_cache_quantized_dtype,
+ k_SNH, v_SNH, k_scale, v_scale)

  new_kv_cache, outputs_TNH = self.attention(
  is_prefill,
@@ -323,8 +346,10 @@
  with jax.named_scope("o_proj"):
  outputs_TNH = nnx.with_sharding_constraint(
  outputs_TNH, self.activation_attention_out_td)
- o_TD = jnp.einsum("TNH,NHD -> TD", outputs_TNH,
- self.kernel_o_proj_NHD.value)
+ outputs_TR = outputs_TNH.reshape(outputs_TNH.shape[0],
+ self.N * self.v_head_dim)
+ o_TD = jnp.einsum("TR,RD -> TD", outputs_TR,
+ self.kernel_o_proj_RD.value)

  return new_kv_cache, o_TD

@@ -391,12 +416,12 @@
  return outputs

  output_TNH, kv_cache = jax.jit(
- shard_map.shard_map(
+ jax.shard_map(
  _ragged_paged_attention,
  mesh=mesh,
  in_specs=in_specs,
  out_specs=out_specs,
- check_rep=False,
+ check_vma=False,
  ))(
  q_TNH,
  k_SKH,
@@ -502,12 +527,12 @@
  return kv_cache, output

  kv_cache, output_TNH = jax.jit(
- shard_map.shard_map(
+ jax.shard_map(
  _mla_ragged_paged_attention,
  mesh=mesh,
  in_specs=in_specs,
  out_specs=out_specs,
- check_rep=False,
+ check_vma=False,
  ), )(
  q_TNA,
  q_rope_TNH,
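The attention call sites above migrate from jax.experimental.shard_map.shard_map with check_rep=False to the top-level jax.shard_map with check_vma=False. A minimal, self-contained sketch of that calling pattern, assuming a JAX release recent enough to expose jax.shard_map and the check_vma keyword (the "data" axis name and the per-shard reduction are illustrative only, not the kernels used here):

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ("data",))

def per_shard_fn(x):
    # Runs once per shard, on that shard's local block of x.
    return jnp.sum(x, axis=0, keepdims=True)

sharded_fn = jax.jit(
    jax.shard_map(
        per_shard_fn,
        mesh=mesh,
        in_specs=P("data", None),
        out_specs=P("data", None),
        check_vma=False,  # replaces the old check_rep=False
    ))

x = jnp.ones((8 * len(jax.devices()), 4))
print(sharded_fn(x).shape)  # (len(jax.devices()), 4)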
@@ -1,3 +1,17 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  from dataclasses import InitVar, dataclass
  from typing import Tuple

@@ -5,7 +19,6 @@ import jax
  import jax.numpy as jnp
  from flax import nnx
  from flax.typing import Sharding
- from jax.experimental import shard_map
  from jax.sharding import Mesh
  from jax.sharding import PartitionSpec as P

@@ -13,6 +26,7 @@ from tpu_inference import utils
  from tpu_inference.kernels.ragged_paged_attention.v3.kernel_hd64 import \
      ragged_paged_attention_hd64
  from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+ from tpu_inference.layers.common.quantization import quantize_kv
  from tpu_inference.layers.jax.base import create_param
  from tpu_inference.layers.jax.rope import GptOssRotaryEmbedding

@@ -185,12 +199,12 @@
  )

  output_TNH, kv_cache = jax.jit(
- shard_map.shard_map(
+ jax.shard_map(
  _ragged_paged_attention_wrapper,
  mesh=mesh,
  in_specs=in_specs,
  out_specs=out_specs,
- check_rep=False,
+ check_vma=False,
  ))(
  q_TNH,
  k_SKH,
@@ -235,9 +249,8 @@
  # q_scale = self._q_scale
  k_scale = self._k_scale
  v_scale = self._v_scale
- k_TKH, v_TKH = utils.quantize_kv(k_TKH, v_TKH,
- self.kv_cache_quantized_dtype,
- k_scale, v_scale)
+ k_TKH, v_TKH = quantize_kv(self.kv_cache_quantized_dtype, k_TKH,
+ v_TKH, k_scale, v_scale)

  with jax.named_scope("attn_op"):
  new_kv_cache, attn_out_TNH = self.attention(
@@ -1,3 +1,17 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  from dataclasses import dataclass

  import jax
@@ -5,8 +19,8 @@ import jax.numpy as jnp
  from flax import nnx
  from jax.sharding import Sharding

- from tpu_inference import utils
  from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+ from tpu_inference.layers.common.quantization import quantize_kv
  from tpu_inference.layers.jax.attention.attention import Attention, KVCache
  from tpu_inference.layers.jax.rope_interface import apply_rope
  from tpu_inference.logger import init_logger
@@ -114,9 +128,8 @@ class Llama4Attention(Attention):
  # q_scale = self._q_scale
  k_scale = self._k_scale
  v_scale = self._v_scale
- k_SKH, v_SKH = utils.quantize_kv(k_SKH, v_SKH,
- self.kv_cache_quantized_dtype,
- k_scale, v_scale)
+ k_SKH, v_SKH = quantize_kv(self.kv_cache_quantized_dtype, k_SKH,
+ v_SKH, k_scale, v_scale)

  with jax.named_scope("attn_op"):
  new_kv_cache, outputs_TNH = self.attention(
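Across all three attention variants, KV-cache quantization moves from utils.quantize_kv(k, v, dtype, k_scale, v_scale) to the shared quantize_kv(dtype, k, v, k_scale, v_scale) in tpu_inference/layers/common/quantization.py: the quantized dtype now comes first. A hypothetical stand-in, only to illustrate the new argument order under an assumed symmetric per-tensor scaling; the package's actual implementation may differ:

import jax.numpy as jnp

def quantize_kv_sketch(quantized_dtype, k, v, k_scale, v_scale):
    # Hypothetical helper body: scale each tensor, then cast to the KV-cache dtype.
    k_q = (k / k_scale).astype(quantized_dtype)
    v_q = (v / v_scale).astype(quantized_dtype)
    return k_q, v_q

k_SKH = jnp.ones((4, 2, 64), dtype=jnp.bfloat16)
v_SKH = jnp.ones((4, 2, 64), dtype=jnp.bfloat16)
k_q, v_q = quantize_kv_sketch(jnp.float8_e4m3fn, k_SKH, v_SKH, 1.0, 1.0)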
@@ -1,3 +1,17 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  import dataclasses
  from dataclasses import dataclass, fields
  from typing import Any, Callable, Mapping
@@ -1,3 +1,16 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
  """
  Current Used Abbreviation for Tensor Dimensions:
  B: Batch size
@@ -1,3 +1,17 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  from dataclasses import InitVar, dataclass
  from typing import Any

@@ -1,3 +1,17 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  import math
  from typing import Tuple

@@ -0,0 +1,13 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
@@ -1,3 +1,17 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  import enum
  from dataclasses import InitVar, dataclass
  from functools import partial
@@ -14,8 +28,8 @@ from qwix._src.providers import ptq

  from tpu_inference.layers.jax.base import create_param
  from tpu_inference.layers.jax.layers import FlaxUtils
- from tpu_inference.layers.jax.moe.moe import MoE
- from tpu_inference.models.jax.utils.quantization.quantization_utils import (
+ from tpu_inference.layers.jax.moe.moe import CombineExperts, MoE
+ from tpu_inference.models.jax.utils.qwix.qwix_utils import (
  manually_quantize_qwix_activation, manually_quantize_qwix_weight)

  modeling_flax_utils = FlaxUtils()
@@ -150,6 +164,7 @@ class SparseMoE(MoE):

  def __post_init__(self, rngs: nnx.Rngs):
  super().__post_init__(rngs)
+ self.combine_experts = CombineExperts(dtype=self.dtype)

  # Derive the expert sharding
  self.expert_axis_name = self.edf_sharding[0]
@@ -331,15 +346,7 @@
  processed_tokens, jnp.argsort(sort_indices))
  reshaped_tokens_TXD = unsorted_tokens_tD.reshape(
  -1, self.num_experts_per_tok, self.hidden_size)
- with jax.named_scope("combine_weights"):
- output_TD = jnp.einsum(
- "TXD,TX -> TD",
- reshaped_tokens_TXD.astype(jnp.float32),
- router_weights_TX.astype(jnp.float32),
- precision='float32',
- )
-
- return output_TD.astype(self.dtype)
+ return self.combine_experts(reshaped_tokens_TXD, router_weights_TX)

  def _gmm(self, inputs, kernel, group_sizes):
  """Performs Grouped Matrix Multiply."""
@@ -575,11 +582,11 @@
  )
  out_specs = PartitionSpec(*self.activation_ffw_td)

- mapped_moe_fwd = partial(jax.experimental.shard_map.shard_map,
+ mapped_moe_fwd = partial(jax.shard_map,
  mesh=self.mesh,
  in_specs=in_specs,
  out_specs=out_specs,
- check_rep=False)(
+ check_vma=False)(
  SparseMoE._distributed_sparse_moe_fwd)

  kernel_gating_EDF = self.kernel_gating_EDF.value
@@ -1,3 +1,17 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  from dataclasses import InitVar, dataclass

  import jax
@@ -1,3 +1,17 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  from dataclasses import InitVar, dataclass

  import jax
@@ -12,6 +26,29 @@ from tpu_inference.layers.jax.layers import FlaxUtils
  modeling_flax_utils = FlaxUtils()


+ @dataclass(kw_only=True)
+ class CombineExperts(nnx.Module):
+ """Combines expert outputs with router weights.
+
+ Supports `TED,TE -> TD` when passed expert outputs, using float32
+ accumulation for numerical stability, then casting back to the target
+ dtype.
+ """
+
+ dtype: jnp.dtype
+
+ def __call__(self, expert_outputs_TED: Float, weights_TE: Float) -> Float:
+ with jax.named_scope("combine_experts"):
+ output_TD = jnp.einsum(
+ "TED,TE -> TD",
+ expert_outputs_TED.astype(jnp.float32),
+ weights_TE.astype(jnp.float32),
+ precision="float32",
+ )
+
+ return output_TD.astype(self.dtype)
+
+
  @dataclass(kw_only=True)
  class Router(nnx.Module):
  """Router module for Mixture-of-Experts (MoE) layers.
@@ -139,6 +176,9 @@ class MoE(nnx.Module):
  sharding=self.efd_sharding,
  random_init=self.random_init)

+ # Shared combine module for combine path
+ self.combine_experts = CombineExperts(dtype=self.dtype)
+
  def _moe_fwd_preapply_router_weights(self, x_TD: jax.Array, weights_TE):
  """Performs the forward pass of the MoE experts with router weights pre-applied to the inputs.

@@ -204,6 +244,6 @@
  with jax.named_scope("down_projection"):
  down_proj_TED = jnp.einsum('TEF,EFD -> TED', fuse_TEF,
  self.kernel_down_proj_EFD.value)
- with jax.named_scope("sum"):
- output_TD = jnp.einsum('TED,TE -> TD', down_proj_TED, weights)
- return output_TD.astype(self.dtype)
+ # Combine across experts
+ output_TD = self.combine_experts(down_proj_TED, weights)
+ return output_TD
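The new CombineExperts module factors the expert-combine step out of both MoE paths (the dense MoE above and the grouped-matmul SparseMoE), so the float32-accumulated einsum lives in one place. A small usage sketch with illustrative shapes, assuming only what the diff itself defines:

import jax
import jax.numpy as jnp
from tpu_inference.layers.jax.moe.moe import CombineExperts

T, E, D = 4, 2, 8  # tokens, experts per token, hidden size (illustrative only)
key = jax.random.PRNGKey(0)
expert_out_TED = jax.random.normal(key, (T, E, D), dtype=jnp.bfloat16)
router_weights_TE = jax.nn.softmax(jax.random.normal(key, (T, E)), axis=-1)

combine = CombineExperts(dtype=jnp.bfloat16)
# Internally: float32 einsum "TED,TE -> TD", then cast back to `dtype`.
output_TD = combine(expert_out_TED, router_weights_TE)
print(output_TD.shape)  # (4, 8)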
@@ -0,0 +1,53 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import List, Protocol
+
+ from flax import nnx
+ from vllm.distributed import get_pp_group
+ from vllm.distributed.utils import get_pp_indices
+
+
+ class PPMissingLayer(nnx.Module):
+ """
+ A placeholder layer for missing layers in a pipeline parallel model.
+ """
+
+ def __init__(self, *args, **kwargs):
+ pass
+
+ def __call__(self, *args, **kwargs):
+ """Return the first arg from args or the first value from kwargs."""
+ return args[0] if args else next(iter(kwargs.values()))
+
+
+ class LayerFn(Protocol):
+
+ def __call__(self) -> nnx.Module:
+ ...
+
+
+ def make_layers(
+ num_hidden_layers: int,
+ layer_fn: LayerFn,
+ ) -> tuple[int, int, List[nnx.Module]]:
+ start_layer, end_layer = get_pp_indices(num_hidden_layers,
+ get_pp_group().rank_in_group,
+ get_pp_group().world_size)
+
+ layers = [PPMissingLayer() for _ in range(start_layer)] \
+ + [layer_fn() for _ in range(start_layer, end_layer)] \
+ + [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)]
+
+ return start_layer, end_layer, layers
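make_layers builds a full-length layer list in which only this pipeline stage's slice holds real modules and every other position is a PPMissingLayer placeholder (which just passes its first argument through). The index math is delegated to vllm.distributed.utils.get_pp_indices; below is a self-contained sketch of the resulting partition, assuming an even split across stages (the real helper also handles uneven partitions):

# Sketch only: reproduces the per-stage layer ranges make_layers would assign
# under an even split; in the real code the boundaries come from get_pp_indices.
def pp_layer_ranges(num_hidden_layers: int, pp_size: int) -> list[range]:
    per_stage = num_hidden_layers // pp_size
    return [range(r * per_stage, (r + 1) * per_stage) for r in range(pp_size)]

# With 8 layers and 2 pipeline stages: stage 0 owns layers 0-3, stage 1 owns 4-7;
# the remaining positions in each stage's list are PPMissingLayer placeholders.
print(pp_layer_ranges(8, 2))  # [range(0, 4), range(4, 8)]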
@@ -1,3 +1,17 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  import math
  from dataclasses import dataclass, field
  from typing import Optional, Tuple
@@ -1,3 +1,17 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  import math
  from typing import Any, Dict

@@ -0,0 +1,13 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
@@ -1,3 +1,16 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
  """
  JAX-based rejection sampler for speculative decoding on TPU.

@@ -1,3 +1,17 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  import functools

  import jax
@@ -28,7 +42,7 @@ def sample(
  if tpu_sampling_metadata.do_sampling:
  # Unshard the logits explicity to avoid latency increase.
  logits = jax.lax.with_sharding_constraint(
- logits, NamedSharding(mesh, P(ShardingAxisName.ATTN_DATA, None)))
+ logits, NamedSharding(mesh, P(ShardingAxisName.MLP_DATA, None)))
  greedy_sampled = jnp.argmax(logits, axis=-1)
  if not tpu_sampling_metadata.do_sampling:
  return greedy_sampled
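The hunk above only swaps the axis used to unshard the logits, from ShardingAxisName.ATTN_DATA to ShardingAxisName.MLP_DATA; the surrounding pattern is the standard JAX sharding constraint. A minimal sketch of that pattern, with a generic "data" axis standing in for MLP_DATA and illustrative shapes:

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ("data",))

@jax.jit
def unshard_logits(logits):
    # Batch dimension sharded over "data", vocab dimension replicated.
    return jax.lax.with_sharding_constraint(
        logits, NamedSharding(mesh, P("data", None)))

logits = jnp.zeros((8 * len(jax.devices()), 1024))
print(unshard_logits(logits).sharding)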