tpu-inference 0.12.0.dev20251213__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (248)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +14 -0
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +14 -0
  31. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  32. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_test.py +14 -0
  35. tests/layers/__init__.py +13 -0
  36. tests/layers/common/__init__.py +13 -0
  37. tests/layers/common/test_attention_interface.py +156 -0
  38. tests/layers/common/test_quantization.py +149 -0
  39. tests/layers/jax/__init__.py +13 -0
  40. tests/layers/jax/attention/__init__.py +13 -0
  41. tests/layers/jax/attention/test_common_attention.py +103 -0
  42. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  43. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  44. tests/layers/jax/moe/__init__.py +13 -0
  45. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  46. tests/layers/jax/sample/__init__.py +13 -0
  47. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  48. tests/layers/jax/sample/test_sampling.py +115 -0
  49. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  50. tests/layers/jax/test_layers.py +155 -0
  51. tests/{test_quantization.py → layers/jax/test_qwix.py} +180 -50
  52. tests/layers/jax/test_rope.py +93 -0
  53. tests/layers/jax/test_sharding.py +159 -0
  54. tests/layers/jax/test_transformer_block.py +152 -0
  55. tests/layers/vllm/__init__.py +13 -0
  56. tests/layers/vllm/test_attention.py +363 -0
  57. tests/layers/vllm/test_awq.py +406 -0
  58. tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
  59. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
  61. tests/layers/vllm/test_fp8.py +17 -0
  62. tests/layers/vllm/test_mxfp4.py +320 -0
  63. tests/layers/vllm/test_unquantized.py +662 -0
  64. tests/layers/vllm/utils.py +87 -0
  65. tests/lora/__init__.py +13 -0
  66. tests/lora/conftest.py +14 -0
  67. tests/lora/test_bgmv.py +14 -0
  68. tests/lora/test_layers.py +25 -8
  69. tests/lora/test_lora.py +15 -1
  70. tests/lora/test_lora_perf.py +14 -0
  71. tests/models/__init__.py +13 -0
  72. tests/models/common/__init__.py +13 -0
  73. tests/models/common/test_model_loader.py +455 -0
  74. tests/models/jax/__init__.py +13 -0
  75. tests/models/jax/test_deepseek_v3.py +401 -0
  76. tests/models/jax/test_llama3.py +184 -0
  77. tests/models/jax/test_llama4.py +298 -0
  78. tests/models/jax/test_llama_eagle3.py +197 -0
  79. tests/models/jax/test_llama_guard_4.py +242 -0
  80. tests/models/jax/test_qwen2.py +172 -0
  81. tests/models/jax/test_qwen2_5_vl.py +605 -0
  82. tests/models/jax/test_qwen3.py +169 -0
  83. tests/models/jax/test_weight_loading.py +180 -0
  84. tests/models/jax/utils/__init__.py +13 -0
  85. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  86. tests/platforms/__init__.py +13 -0
  87. tests/platforms/test_tpu_platform.py +54 -0
  88. tests/runner/__init__.py +13 -0
  89. tests/runner/test_block_table.py +395 -0
  90. tests/runner/test_input_batch.py +226 -0
  91. tests/runner/test_kv_cache.py +220 -0
  92. tests/runner/test_kv_cache_manager.py +498 -0
  93. tests/runner/test_multimodal_manager.py +429 -0
  94. tests/runner/test_persistent_batch_manager.py +84 -0
  95. tests/runner/test_speculative_decoding_manager.py +368 -0
  96. tests/runner/test_structured_decoding_manager.py +220 -0
  97. tests/runner/test_tpu_runner.py +261 -0
  98. tests/runner/test_tpu_runner_dp.py +1099 -0
  99. tests/runner/test_tpu_runner_mesh.py +200 -0
  100. tests/runner/test_utils.py +411 -0
  101. tests/spec_decode/__init__.py +13 -0
  102. tests/spec_decode/test_eagle3.py +311 -0
  103. tests/test_base.py +14 -0
  104. tests/test_tpu_info.py +14 -0
  105. tests/test_utils.py +1 -43
  106. tests/worker/__init__.py +13 -0
  107. tests/worker/tpu_worker_test.py +414 -0
  108. tpu_inference/__init__.py +14 -0
  109. tpu_inference/core/__init__.py +13 -0
  110. tpu_inference/core/sched/__init__.py +13 -0
  111. tpu_inference/core/sched/dp_scheduler.py +372 -56
  112. tpu_inference/distributed/__init__.py +13 -0
  113. tpu_inference/distributed/jax_parallel_state.py +14 -0
  114. tpu_inference/distributed/tpu_connector.py +14 -9
  115. tpu_inference/distributed/utils.py +56 -4
  116. tpu_inference/executors/__init__.py +13 -0
  117. tpu_inference/executors/ray_distributed_executor.py +20 -3
  118. tpu_inference/experimental/__init__.py +13 -0
  119. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  120. tpu_inference/kernels/__init__.py +13 -0
  121. tpu_inference/kernels/collectives/__init__.py +13 -0
  122. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  123. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  124. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  125. tpu_inference/kernels/fused_moe/v1/kernel.py +171 -163
  126. tpu_inference/kernels/megablox/__init__.py +13 -0
  127. tpu_inference/kernels/megablox/common.py +54 -0
  128. tpu_inference/kernels/megablox/gmm.py +646 -0
  129. tpu_inference/kernels/mla/__init__.py +13 -0
  130. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  131. tpu_inference/kernels/mla/v1/kernel.py +20 -26
  132. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  133. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  134. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  135. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  136. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +112 -69
  137. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +85 -65
  138. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  139. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +374 -194
  140. tpu_inference/kernels/ragged_paged_attention/v3/util.py +13 -0
  141. tpu_inference/layers/__init__.py +13 -0
  142. tpu_inference/layers/common/__init__.py +13 -0
  143. tpu_inference/layers/common/attention_interface.py +26 -19
  144. tpu_inference/layers/common/attention_metadata.py +14 -0
  145. tpu_inference/layers/common/fused_moe_gmm.py +506 -0
  146. tpu_inference/layers/common/quant_methods.py +15 -0
  147. tpu_inference/layers/common/quantization.py +282 -0
  148. tpu_inference/layers/common/sharding.py +22 -3
  149. tpu_inference/layers/common/utils.py +94 -0
  150. tpu_inference/layers/jax/__init__.py +13 -0
  151. tpu_inference/layers/jax/attention/__init__.py +13 -0
  152. tpu_inference/layers/jax/attention/attention.py +19 -6
  153. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +52 -27
  154. tpu_inference/layers/jax/attention/gpt_oss_attention.py +19 -6
  155. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  156. tpu_inference/layers/jax/base.py +14 -0
  157. tpu_inference/layers/jax/constants.py +13 -0
  158. tpu_inference/layers/jax/layers.py +14 -0
  159. tpu_inference/layers/jax/misc.py +14 -0
  160. tpu_inference/layers/jax/moe/__init__.py +13 -0
  161. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  162. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  163. tpu_inference/layers/jax/moe/moe.py +43 -3
  164. tpu_inference/layers/jax/pp_utils.py +53 -0
  165. tpu_inference/layers/jax/rope.py +14 -0
  166. tpu_inference/layers/jax/rope_interface.py +14 -0
  167. tpu_inference/layers/jax/sample/__init__.py +13 -0
  168. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  169. tpu_inference/layers/jax/sample/sampling.py +15 -1
  170. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  171. tpu_inference/layers/jax/transformer_block.py +14 -0
  172. tpu_inference/layers/vllm/__init__.py +13 -0
  173. tpu_inference/layers/vllm/attention.py +4 -4
  174. tpu_inference/layers/vllm/fused_moe.py +100 -455
  175. tpu_inference/layers/vllm/linear.py +64 -0
  176. tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
  177. tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
  178. tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
  179. tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
  180. tpu_inference/layers/vllm/quantization/__init__.py +19 -3
  181. tpu_inference/layers/vllm/quantization/awq.py +96 -82
  182. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  183. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +19 -5
  184. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +119 -132
  185. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  186. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
  187. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
  188. tpu_inference/layers/vllm/quantization/{common.py → configs.py} +38 -26
  189. tpu_inference/layers/vllm/quantization/fp8.py +119 -0
  190. tpu_inference/layers/vllm/quantization/mxfp4.py +133 -220
  191. tpu_inference/layers/vllm/quantization/unquantized.py +154 -253
  192. tpu_inference/lora/__init__.py +13 -0
  193. tpu_inference/lora/torch_lora_ops.py +8 -13
  194. tpu_inference/models/__init__.py +13 -0
  195. tpu_inference/models/common/__init__.py +13 -0
  196. tpu_inference/models/common/model_loader.py +37 -16
  197. tpu_inference/models/jax/__init__.py +13 -0
  198. tpu_inference/models/jax/deepseek_v3.py +113 -124
  199. tpu_inference/models/jax/gpt_oss.py +23 -7
  200. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  201. tpu_inference/models/jax/llama3.py +99 -36
  202. tpu_inference/models/jax/llama4.py +14 -0
  203. tpu_inference/models/jax/llama_eagle3.py +14 -0
  204. tpu_inference/models/jax/llama_guard_4.py +15 -1
  205. tpu_inference/models/jax/qwen2.py +17 -2
  206. tpu_inference/models/jax/qwen2_5_vl.py +18 -4
  207. tpu_inference/models/jax/qwen3.py +17 -2
  208. tpu_inference/models/jax/utils/__init__.py +13 -0
  209. tpu_inference/models/jax/utils/file_utils.py +14 -0
  210. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  211. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  212. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +85 -24
  213. tpu_inference/models/jax/utils/weight_utils.py +32 -1
  214. tpu_inference/models/vllm/__init__.py +13 -0
  215. tpu_inference/models/vllm/vllm_model_wrapper.py +22 -4
  216. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  217. tpu_inference/platforms/__init__.py +14 -0
  218. tpu_inference/platforms/tpu_platform.py +27 -29
  219. tpu_inference/runner/__init__.py +13 -0
  220. tpu_inference/runner/compilation_manager.py +69 -35
  221. tpu_inference/runner/kv_cache.py +14 -0
  222. tpu_inference/runner/kv_cache_manager.py +15 -2
  223. tpu_inference/runner/lora_utils.py +16 -1
  224. tpu_inference/runner/multimodal_manager.py +16 -2
  225. tpu_inference/runner/persistent_batch_manager.py +14 -0
  226. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  227. tpu_inference/runner/structured_decoding_manager.py +14 -0
  228. tpu_inference/runner/tpu_runner.py +30 -10
  229. tpu_inference/spec_decode/__init__.py +13 -0
  230. tpu_inference/spec_decode/jax/__init__.py +13 -0
  231. tpu_inference/spec_decode/jax/eagle3.py +13 -0
  232. tpu_inference/tpu_info.py +14 -0
  233. tpu_inference/utils.py +31 -30
  234. tpu_inference/worker/__init__.py +13 -0
  235. tpu_inference/worker/tpu_worker.py +23 -7
  236. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +1 -1
  237. tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
  238. tpu_inference/layers/vllm/linear_common.py +0 -208
  239. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  240. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  241. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  242. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  243. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  244. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  245. tpu_inference-0.12.0.dev20251213.dist-info/RECORD +0 -175
  246. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
  247. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
  248. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,64 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import jax
+from jax.sharding import Mesh, NamedSharding
+from jax.sharding import PartitionSpec as P
+
+from tpu_inference import envs
+from tpu_inference.kernels.quantized_matmul.kernel import (
+    quantized_matmul_kernel, xla_quantized_matmul)
+
+
+def sharded_quantized_matmul(x: jax.Array, w_q: jax.Array, w_s: jax.Array,
+                             mesh: Mesh, weight_sharding: P) -> jax.Array:
+    """
+    Wrapper around the quantized matmul kernel.
+
+    Args:
+        x: Activation.
+        w_q: Weight quantized array. [n_output_features, n_input_features]
+        w_s: Weight quantization scale. [n_output_features]
+        mesh: Mesh to shard on.
+        weight_sharding: PartitionSpec for the weight tensor.
+
+    Returns:
+        Output of the quantized matmul.
+    """
+
+    # NOTE (jacobplatin/kyuyeunk) there have been numeric issues (concerning) NaNs
+    # with the kernel and thus we disable it for now.
+    if envs.ENABLE_QUANTIZED_MATMUL_KERNEL:
+        out_axis, in_axis = weight_sharding
+        x_sharding = P(None, in_axis)
+        scale_sharding = P(out_axis, )
+        out_sharding = P(None, out_axis)
+
+        x = jax.lax.with_sharding_constraint(x,
+                                             NamedSharding(mesh, x_sharding))
+
+        def wrapper(x, w_q, w_s):
+            output = quantized_matmul_kernel(x, w_q, w_s, x_q_dtype=w_q.dtype)
+            if in_axis:
+                output = jax.lax.psum(output, axis_name=in_axis)
+            return output
+
+        return jax.shard_map(wrapper,
+                             mesh=mesh,
+                             in_specs=(x_sharding, weight_sharding,
+                                       scale_sharding),
+                             out_specs=(out_sharding),
+                             check_vma=False)(x, w_q, w_s)
+    else:
+        return xla_quantized_matmul(x, w_q, w_s)
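The hunk above adds a sharded wrapper that dispatches between the Pallas quantized matmul kernel and the XLA fallback based on ENABLE_QUANTIZED_MATMUL_KERNEL. Below is a minimal usage sketch; only the function signature comes from the hunk, while the import path (inferred from the +64/-0 entry for tpu_inference/layers/vllm/linear.py in the file list), the mesh axis name, shapes, and dtypes are illustrative assumptions.

# Hypothetical usage sketch for sharded_quantized_matmul. Shapes, dtypes, the
# "model" axis name, and the import path are assumptions; only the signature
# is taken from the diff above.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P

from tpu_inference.layers.vllm.linear import sharded_quantized_matmul  # assumed path

mesh = Mesh(np.array(jax.devices()), axis_names=("model", ))  # assumed axis name

x = jnp.ones((8, 4096), dtype=jnp.bfloat16)    # activation [tokens, n_input_features]
w_q = jnp.ones((11008, 4096), dtype=jnp.int8)  # quantized weight [n_output_features, n_input_features]
w_s = jnp.ones((11008, ), dtype=jnp.float32)   # per-output-channel scale [n_output_features]

# Shard the weight's output dimension over the "model" axis. With the env flag
# unset this falls through to xla_quantized_matmul; with it set, the Pallas
# kernel runs inside shard_map and psums only when the contraction dim is sharded.
out = sharded_quantized_matmul(x, w_q, w_s, mesh, P("model", None))
print(out.shape)  # (8, 11008)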
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
@@ -1,7 +1,18 @@
-import os
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 
 import jax
-import jax.numpy as jnp
 import torch
 import torchax
 from jax.sharding import Mesh, NamedSharding, PartitionSpec
@@ -9,6 +20,7 @@ from torch.nn import Parameter
 from torch.utils import _pytree as pytree
 from torchax.interop import jax_view, torch_view
 from torchax.ops.mappings import t2j
+from vllm import envs as vllm_envs
 from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
                               MergedColumnParallelLinearWithLoRA,
                               MergedQKVParallelLinearWithLoRA,
@@ -20,18 +32,14 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 
 from tpu_inference import envs
+from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.logger import init_logger
+from tpu_inference.utils import to_jax_dtype
 
 P = PartitionSpec
 
 logger = init_logger(__name__)
 
-TORCH_TO_JAX_DTYPE_MAP = {
-    torch.float32: jnp.float32,
-    torch.float16: jnp.float16,
-    torch.bfloat16: jnp.bfloat16,
-}
-
 
 def shard_model_to_tpu(model: torch.nn.Module,
                        mesh: Mesh) -> dict[str, torchax.torch.Tensor]:
@@ -88,10 +96,9 @@ def _tensor_is_in_cpu(tensor: torch.tensor) -> bool:
 
 def _convert_to_torchax_and_shard(tensor: torch.Tensor,
                                   sharding: NamedSharding) -> torch.Tensor:
-    if os.getenv("VLLM_TPU_USING_PATHWAYS", False) and isinstance(
-            tensor, torch.Tensor):
+    if vllm_envs.VLLM_TPU_USING_PATHWAYS and isinstance(tensor, torch.Tensor):
         np_tensor = tensor.detach().cpu().to(torch.float32).numpy()
-        dtype = TORCH_TO_JAX_DTYPE_MAP.get(tensor.dtype, jnp.float32)
+        dtype = to_jax_dtype(tensor.dtype)
         return torch_view(jax.device_put(np_tensor, sharding).astype(dtype))
     else:
         if isinstance(tensor, torchax.tensor.Tensor):
@@ -109,7 +116,8 @@ def _shard_tensor_to_tpu_replicated(tensor: torch.Tensor,
 def _shard_vocab_parallel_embedding(layer: VocabParallelEmbedding,
                                     mesh: Mesh) -> None:
     weight = _convert_to_torchax_and_shard(
-        layer.weight, NamedSharding(mesh, P('model', None)))
+        layer.weight, NamedSharding(mesh, P(ShardingAxisName.MLP_TENSOR,
+                                            None)))
     layer.weight = Parameter(weight, requires_grad=False)
 
 
@@ -118,11 +126,12 @@ def _shard_lm_head(layer: ParallelLMHead, mesh: Mesh):
     # if that config is set, then we should not create new weights but reuse the
    # weight from VocabParallelEmbedding
     weight = _convert_to_torchax_and_shard(
-        layer.weight, NamedSharding(mesh, P('model', None)))
+        layer.weight, NamedSharding(mesh, P(ShardingAxisName.MLP_TENSOR,
+                                            None)))
     layer.weight = Parameter(weight, requires_grad=False)
     if layer.bias is not None:
-        bias = _convert_to_torchax_and_shard(layer.bias,
-                                             NamedSharding(mesh, P('model')))
+        bias = _convert_to_torchax_and_shard(
+            layer.bias, NamedSharding(mesh, P(ShardingAxisName.MLP_TENSOR)))
         layer.bias = Parameter(bias, requires_grad=False)
 
 
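The hunks above replace the module-local TORCH_TO_JAX_DTYPE_MAP with the shared tpu_inference.utils.to_jax_dtype helper, read VLLM_TPU_USING_PATHWAYS through vllm.envs instead of os.getenv, and swap the hard-coded 'model' partition axis for ShardingAxisName.MLP_TENSOR. As a rough stand-in for what the removed mapping did (and, by assumption, roughly what to_jax_dtype is expected to return; its handling of unmapped dtypes is not shown in this diff):

# Sketch reproducing the removed TORCH_TO_JAX_DTYPE_MAP behavior. The real
# helper lives in tpu_inference.utils; treating float32 as the fallback for
# unmapped dtypes mirrors the deleted code and is an assumption about to_jax_dtype.
import jax.numpy as jnp
import torch

_TORCH_TO_JAX = {
    torch.float32: jnp.float32,
    torch.float16: jnp.float16,
    torch.bfloat16: jnp.bfloat16,
}


def to_jax_dtype_sketch(torch_dtype: torch.dtype):
    # The removed code defaulted to float32 for anything not in the map.
    return _TORCH_TO_JAX.get(torch_dtype, jnp.float32)


assert to_jax_dtype_sketch(torch.bfloat16) == jnp.bfloat16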
@@ -0,0 +1,369 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass, fields
+
+import jax
+import jax.numpy as jnp
+from jax.experimental.layout import Format, Layout, with_layout_constraint
+from jax.sharding import Mesh, NamedSharding, PartitionSpec
+from torchax.tensor import Tensor
+
+from tpu_inference.layers.common.quantization import quantize_tensor
+from tpu_inference.layers.common.sharding import ShardingAxisName
+from tpu_inference.layers.common.utils import \
+    reorder_concatenated_tensor_for_sharding
+from tpu_inference.layers.vllm.fused_moe import FusedMoEBackend
+from tpu_inference.utils import align_to
+
+P = PartitionSpec
+
+
+@jax.tree_util.register_dataclass
+@dataclass
+class FusedMoEWeights:
+    """Fused moe weights. weights can be either jax or torchax array."""
+    w13_weight: jax.Array | Tensor
+    w13_weight_scale: jax.Array | Tensor | None
+    w13_bias: jax.Array | Tensor | None
+    w2_weight: jax.Array | Tensor
+    w2_weight_scale: jax.Array | Tensor | None
+    w2_bias: jax.Array | Tensor | None
+
+
+def quantize_moe_weights(
+    weights: FusedMoEWeights,
+    dtype: jnp.dtype,
+    block_size: int | None,
+) -> FusedMoEWeights:
+    """Quantize fused moe weights into a given dtype and block size.
+
+    Args:
+        weights: fused moe weights.
+        dtype: dtype to perform quantization.
+        block_size: Specify block quantization size. If non, use per-channel
+            quantization. If contracting dim is not divisible by block size,
+            the dim will be automatically padded and corresponding dim on bias
+            and the other weight (w13_weight <-> w2_weight) is also padded.
+
+    Returns:
+        Quantized fused moe weights that may have also been padded.
+    """
+
+    # If scale is present, it means the weights are already quantized.
+    # Ensure that weights are not quantized by checking if scales are None.
+    assert weights.w13_weight_scale is None
+    assert weights.w2_weight_scale is None
+
+    w13_weight = weights.w13_weight
+    w2_weight = weights.w2_weight
+
+    if block_size is None:
+        # Use per-channel quantizaiton.
+        w13_block_size = w13_weight.shape[-1]
+        w2_block_size = w2_weight.shape[-1]
+    else:
+        w13_block_size = w2_block_size = block_size
+
+    _, orig_hidden_size, orig_intermediate_size = w2_weight.shape
+
+    w13_weight, w13_weight_scale = quantize_tensor(dtype, w13_weight, 2,
+                                                   w13_block_size, True)
+    w2_weight, w2_weight_scale = quantize_tensor(dtype, w2_weight, 2,
+                                                 w2_block_size, True)
+
+    intermediate_size = w2_weight.shape[-1]
+    hidden_size = w13_weight.shape[-1]
+
+    # Dims may have been padded to align with subchannel size during
+    # quantization. We pad the corresponding dim on other weight.
+    # NOTE: We perform padding after quantization as padding value can
+    # affect quantization numerics.
+    w13_pad_widths = [[0, 0] for _ in range(3)]
+    w13_pad_widths[1][1] = 2 * (intermediate_size - orig_intermediate_size)
+    w2_pad_widths = [[0, 0] for _ in range(3)]
+    w2_pad_widths[1][1] = hidden_size - orig_hidden_size
+
+    weights.w13_weight = jnp.pad(w13_weight, w13_pad_widths)
+    weights.w13_weight_scale = jnp.pad(w13_weight_scale, w13_pad_widths)
+    weights.w2_weight = jnp.pad(w2_weight, w2_pad_widths)
+    weights.w2_weight_scale = jnp.pad(w2_weight_scale, w2_pad_widths)
+
+    if (w13_bias := weights.w13_bias) is not None:
+        weights.w13_bias = jnp.pad(w13_bias, w13_pad_widths[:2])
+    if (w2_bias := weights.w2_bias) is not None:
+        weights.w2_bias = jnp.pad(w2_bias, w2_pad_widths[:2])
+
+    return weights
+
+
+def process_moe_weights(
+    weights: FusedMoEWeights,
+    moe_backend: FusedMoEBackend,
+    w13_reorder_size: int | None = None,
+    w13_interleave: bool = False,
+) -> FusedMoEWeights:
+    """Process fused moe weights to a layout that moe backend expects.
+
+    Args:
+        weights: fused moe weights.
+        moe_backend: backend type the weights should be processed for.
+        w13_reorder_size: only used when backend type is GMM_TP. in order to
+            eliminate collective operations when using tensor parallelism,
+            group w13_weight into w13_reorder_size number of chuncks where each
+            chunk stores both w1 and w3 weights.
+        w13_interleave: used when loaded w13_weight is stored in interleaved
+            pattern where even index element is w1 and odd index element is w3.
+            we uninterleave so that first half is w1 and second half is w3.
+
+    Returns:
+        MoE weights that are processed for specified backend.
+    """
+
+    w13_weight = weights.w13_weight
+    w13_weight_scale = weights.w13_weight_scale
+    w13_bias = weights.w13_bias
+    w2_weight = weights.w2_weight
+    w2_weight_scale = weights.w2_weight_scale
+    w2_bias = weights.w2_bias
+
+    num_experts, hidden_size, intermediate_size = w2_weight.shape
+
+    if w13_interleave:
+        w1_weight = w13_weight[:, ::2, :]
+        w3_weight = w13_weight[:, 1::2, :]
+        w13_weight = jnp.concat([w1_weight, w3_weight], axis=1)
+
+        if w13_weight_scale is not None:
+            w1_weight_scale = w13_weight_scale[:, ::2, :]
+            w3_weight_scale = w13_weight_scale[:, 1::2, :]
+            w13_weight_scale = jnp.concat([w1_weight_scale, w3_weight_scale],
+                                          axis=1)
+
+        if w13_bias is not None:
+            w1_bias = w13_bias[:, ::2]
+            w3_bias = w13_bias[:, 1::2]
+            w13_bias = jnp.concat([w1_bias, w3_bias], axis=1)
+
+    if w13_weight_scale is not None:
+        w13_weight_scale = w13_weight_scale.astype(jnp.float32)
+    if w2_weight_scale is not None:
+        w2_weight_scale = w2_weight_scale.astype(jnp.float32)
+    if w13_bias is not None:
+        w13_bias = w13_bias.astype(jnp.float32)
+    if w2_bias is not None:
+        w2_bias = w2_bias.astype(jnp.float32)
+
+    match moe_backend:
+        case FusedMoEBackend.FUSED_MOE:
+            # Kernel expects:
+            #   w13: (num_experts, 2, hidden_size, intermediate_size)
+            #   w2: (num_experts, intermediate_size, hidden_size)
+            # Current format:
+            #   w13_weight: (num_experts, 2*intermediate_size, hidden_size)
+            #   w2_weight: (num_experts, hidden_size, intermediate_size)
+
+            # Fused moe kernel expects dims to be multiple of 256.
+            pad_width_intermediate_size = align_to(intermediate_size,
+                                                   256) - intermediate_size
+            pad_width_hidden_size = align_to(hidden_size, 256) - hidden_size
+
+            w13_weight = w13_weight.reshape(
+                num_experts,
+                2,
+                intermediate_size,
+                hidden_size,
+            )
+
+            # Transpose non-constracting dim to right most dim
+            w13_weight = with_layout_constraint(w13_weight, Layout(
+                (0, 1, 2, 3)))
+            w2_weight = with_layout_constraint(w2_weight, Layout((0, 1, 2)))
+
+            w13_weight = jnp.swapaxes(w13_weight, 2, 3)
+            w2_weight = jnp.swapaxes(w2_weight, 1, 2)
+
+            # Workaround for JAX error "must have valid byte strides"
+            w13_weight = jnp.pad(
+                w13_weight,
+                ((0, 0), (0, 0), (0, pad_width_hidden_size),
+                 (0, pad_width_intermediate_size)),
+            )
+
+            w2_weight = jnp.pad(
+                w2_weight,
+                ((0, 0), (0, pad_width_intermediate_size),
+                 (0, pad_width_hidden_size)),
+            )
+
+            if w13_weight_scale is not None:
+                w13_weight_scale = w13_weight_scale.reshape(
+                    num_experts, 2, intermediate_size, 1, -1)
+                w13_weight_scale = jnp.swapaxes(w13_weight_scale, 2, 4)
+                w13_weight_scale = jnp.pad(
+                    w13_weight_scale,
+                    ((0, 0), (0, 0), (0, pad_width_hidden_size), (0, 0),
+                     (0, pad_width_intermediate_size)),
+                )
+            if w2_weight_scale is not None:
+                w2_weight_scale = w2_weight_scale.reshape(
+                    num_experts, hidden_size, 1, -1)
+                w2_weight_scale = jnp.swapaxes(w2_weight_scale, 1, 3)
+                w2_weight_scale = jnp.pad(
+                    w2_weight_scale,
+                    ((0, 0), (0, pad_width_intermediate_size), (0, 0),
+                     (0, pad_width_hidden_size)),
+                )
+
+            if w13_bias is not None:
+                w13_bias = w13_bias.reshape(num_experts, 2, 1,
+                                            intermediate_size)
+                w13_bias = jnp.pad(
+                    w13_bias,
+                    ((0, 0), (0, 0), (0, 0), (0, pad_width_intermediate_size)),
+                )
+            if w2_bias is not None:
+                w2_bias = w2_bias.reshape(num_experts, 1, hidden_size)
+                w2_bias = jnp.pad(
+                    w2_bias,
+                    ((0, 0), (0, 0), (0, pad_width_hidden_size)),
+                )
+
+        case FusedMoEBackend.GMM_EP | FusedMoEBackend.GMM_TP:
+            if w13_weight_scale is not None:
+                w13_weight_scale = jnp.swapaxes(w13_weight_scale, 1, 2)
+                w13_weight_scale = jnp.expand_dims(w13_weight_scale, 2)
+            if w2_weight_scale is not None:
+                w2_weight_scale = jnp.swapaxes(w2_weight_scale, 1, 2)
+                w2_weight_scale = jnp.expand_dims(w2_weight_scale, 2)
+            if w13_bias is not None:
+                w13_bias = jnp.expand_dims(w13_bias, 1)
+            if w2_bias is not None:
+                w2_bias = jnp.expand_dims(w2_bias, 1)
+
+            if moe_backend == FusedMoEBackend.GMM_TP:
+                assert w13_reorder_size is not None
+                assert intermediate_size % w13_reorder_size == 0
+                output_sizes = [intermediate_size, intermediate_size]
+                w13_weight = reorder_concatenated_tensor_for_sharding(
+                    w13_weight,
+                    output_sizes,
+                    w13_reorder_size,
+                    dim=1,
+                )
+                if w13_weight_scale is not None:
+                    w13_weight_scale = reorder_concatenated_tensor_for_sharding(
+                        w13_weight_scale,
+                        output_sizes,
+                        w13_reorder_size,
+                        dim=3,
+                    )
+                if w13_bias is not None:
+                    w13_bias = reorder_concatenated_tensor_for_sharding(
+                        w13_bias,
+                        output_sizes,
+                        w13_reorder_size,
+                        dim=2,
+                    )
+
+    return FusedMoEWeights(
+        w13_weight=w13_weight,
+        w13_weight_scale=w13_weight_scale,
+        w13_bias=w13_bias,
+        w2_weight=w2_weight,
+        w2_weight_scale=w2_weight_scale,
+        w2_bias=w2_bias,
+    )
+
+
+def shard_moe_weights(
+    weights: FusedMoEWeights,
+    moe_backend: FusedMoEBackend,
+    mesh: Mesh,
+) -> FusedMoEWeights:
+
+    match moe_backend:
+        case FusedMoEBackend.FUSED_MOE | FusedMoEBackend.GMM_EP:
+            ep_sharding = NamedSharding(mesh, P(ShardingAxisName.EXPERT))
+            weight_shardings = FusedMoEWeights(
+                w13_weight=ep_sharding,
+                w13_weight_scale=ep_sharding,
+                w13_bias=ep_sharding,
+                w2_weight=ep_sharding,
+                w2_weight_scale=ep_sharding,
+                w2_bias=ep_sharding,
+            )
+        case FusedMoEBackend.GMM_TP:
+            # When using per-channel, in_dim // block_size == 1. This means we
+            # are unable to shard w2_weight_scale along 1st dim. Therefore, we
+            # fully replicate it instead.
+            if (weights.w2_weight_scale is not None
+                    and weights.w2_weight_scale.shape[1] == 1):
+                w2_weight_scale_p_spec = P()
+            else:
+                w2_weight_scale_p_spec = P(None, ShardingAxisName.MLP_TENSOR)
+            weight_shardings = FusedMoEWeights(
+                w13_weight=NamedSharding(
+                    mesh,
+                    P(None, ShardingAxisName.MLP_TENSOR, None),
+                ),  # (num_experts, out_dim, in_dim)
+                w13_weight_scale=NamedSharding(
+                    mesh,
+                    P(None, None, None, ShardingAxisName.MLP_TENSOR),
+                ),  # (num_experts, in_dim // block_size, 1, out_dim)
+                w13_bias=NamedSharding(
+                    mesh,
+                    P(None, None, ShardingAxisName.MLP_TENSOR),
+                ),  # (num_experts, 1, out_dim)
+                w2_weight=NamedSharding(
+                    mesh,
+                    P(None, None, ShardingAxisName.MLP_TENSOR),
+                ),  # (num_experts, out_dim, in_dim)
+                w2_weight_scale=NamedSharding(
+                    mesh, w2_weight_scale_p_spec
+                ),  # (num_experts, in_dim // block_size, 1, out_dim)
+                w2_bias=NamedSharding(
+                    mesh,
+                    P(None, None, None),
+                ),  # (num_experts, 1, out_dim)
+            )
+
+    match moe_backend:
+        case FusedMoEBackend.FUSED_MOE:
+            weight_layouts = FusedMoEWeights(
+                w13_weight=Layout((0, 1, 2, 3)),
+                w13_weight_scale=Layout((0, 1, 2, 3, 4)),
+                w13_bias=Layout((0, 1, 2, 3)),
+                w2_weight=Layout((0, 1, 2)),
+                w2_weight_scale=Layout((0, 1, 2, 3)),
+                w2_bias=Layout((0, 1, 2)),
+            )
+        case FusedMoEBackend.GMM_TP | FusedMoEBackend.GMM_EP:
+            weight_layouts = FusedMoEWeights(
+                w13_weight=Layout((0, 1, 2)),
+                w13_weight_scale=Layout((0, 1, 2, 3)),
+                w13_bias=Layout((0, 1, 2)),
+                w2_weight=Layout((0, 1, 2)),
+                w2_weight_scale=Layout((0, 1, 2, 3)),
+                w2_bias=Layout((0, 1, 2)),
+            )
+
+    for field in fields(FusedMoEWeights):
+        key = field.name
+        if (weight := getattr(weights, key, None)) is not None:
+            layout = getattr(weight_layouts, key)
+            sharding = getattr(weight_shardings, key)
+            weight = jax.device_put(weight, Format(layout, sharding))
+            setattr(weights, key, weight)
+    return weights
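This 369-line hunk matches the new tpu_inference/layers/vllm/process_weights/fused_moe_weights.py entry in the file list. It splits fused-MoE weight preparation into three steps: optional quantization (quantize_moe_weights), backend-specific layout processing (process_moe_weights), and device placement with explicit shardings and layouts (shard_moe_weights). A hedged end-to-end sketch follows; the function signatures and FusedMoEWeights fields come from the hunk, while the shapes, quantization dtype, and mesh axis name are illustrative assumptions.

# Hedged sketch of the new MoE weight pipeline under assumed sizes and mesh.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh

from tpu_inference.layers.vllm.fused_moe import FusedMoEBackend
from tpu_inference.layers.vllm.process_weights.fused_moe_weights import (
    FusedMoEWeights, process_moe_weights, quantize_moe_weights,
    shard_moe_weights)

num_experts, hidden, intermediate = 8, 1024, 4096  # assumed sizes

weights = FusedMoEWeights(
    # w13 packs w1 and w3: (num_experts, 2 * intermediate, hidden)
    w13_weight=jnp.ones((num_experts, 2 * intermediate, hidden), jnp.bfloat16),
    w13_weight_scale=None,  # must be None before quantize_moe_weights
    w13_bias=None,
    # w2: (num_experts, hidden, intermediate)
    w2_weight=jnp.ones((num_experts, hidden, intermediate), jnp.bfloat16),
    w2_weight_scale=None,
    w2_bias=None,
)

# 1) Per-channel int8 quantization (block_size=None selects per-channel).
weights = quantize_moe_weights(weights, jnp.int8, block_size=None)

# 2) Rearrange weights, scales, and biases into the layout the expert-parallel
#    GMM backend expects.
weights = process_moe_weights(weights, FusedMoEBackend.GMM_EP)

# 3) Place on device with the backend's shardings and layouts.
mesh = Mesh(np.array(jax.devices()), axis_names=("expert", ))  # assumed axis name
weights = shard_moe_weights(weights, FusedMoEBackend.GMM_EP, mesh)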