tpu-inference 0.11.1.dev202512030818__py3-none-any.whl → 0.13.0rc2.post7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of tpu-inference might be problematic.

Files changed (250)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +14 -0
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +405 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +312 -0
  64. tests/layers/vllm/test_unquantized.py +651 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +21 -3
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +78 -1
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +1 -43
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +14 -9
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +38 -7
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +17 -0
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +95 -78
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +26 -19
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/quant_methods.py +15 -0
  154. tpu_inference/layers/common/quantization.py +270 -0
  155. tpu_inference/layers/common/sharding.py +28 -5
  156. tpu_inference/layers/jax/__init__.py +13 -0
  157. tpu_inference/layers/jax/attention/__init__.py +13 -0
  158. tpu_inference/layers/jax/attention/attention.py +19 -6
  159. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  160. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  161. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  162. tpu_inference/layers/jax/base.py +14 -0
  163. tpu_inference/layers/jax/constants.py +13 -0
  164. tpu_inference/layers/jax/layers.py +14 -0
  165. tpu_inference/layers/jax/misc.py +14 -0
  166. tpu_inference/layers/jax/moe/__init__.py +13 -0
  167. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  168. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  169. tpu_inference/layers/jax/moe/moe.py +43 -3
  170. tpu_inference/layers/jax/pp_utils.py +53 -0
  171. tpu_inference/layers/jax/rope.py +14 -0
  172. tpu_inference/layers/jax/rope_interface.py +14 -0
  173. tpu_inference/layers/jax/sample/__init__.py +13 -0
  174. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  175. tpu_inference/layers/jax/sample/sampling.py +15 -1
  176. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  177. tpu_inference/layers/jax/transformer_block.py +14 -0
  178. tpu_inference/layers/vllm/__init__.py +13 -0
  179. tpu_inference/layers/vllm/attention.py +4 -4
  180. tpu_inference/layers/vllm/fused_moe.py +210 -260
  181. tpu_inference/layers/vllm/linear_common.py +57 -22
  182. tpu_inference/layers/vllm/quantization/__init__.py +16 -0
  183. tpu_inference/layers/vllm/quantization/awq.py +15 -1
  184. tpu_inference/layers/vllm/quantization/common.py +33 -18
  185. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  186. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
  187. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
  188. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  189. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
  191. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  192. tpu_inference/layers/vllm/quantization/mxfp4.py +278 -209
  193. tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
  194. tpu_inference/layers/vllm/sharding.py +21 -4
  195. tpu_inference/lora/__init__.py +13 -0
  196. tpu_inference/lora/torch_lora_ops.py +8 -13
  197. tpu_inference/models/__init__.py +13 -0
  198. tpu_inference/models/common/__init__.py +13 -0
  199. tpu_inference/models/common/model_loader.py +74 -35
  200. tpu_inference/models/jax/__init__.py +13 -0
  201. tpu_inference/models/jax/deepseek_v3.py +267 -157
  202. tpu_inference/models/jax/gpt_oss.py +26 -10
  203. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  204. tpu_inference/models/jax/llama3.py +99 -36
  205. tpu_inference/models/jax/llama4.py +14 -0
  206. tpu_inference/models/jax/llama_eagle3.py +14 -0
  207. tpu_inference/models/jax/llama_guard_4.py +15 -1
  208. tpu_inference/models/jax/qwen2.py +17 -2
  209. tpu_inference/models/jax/qwen2_5_vl.py +18 -4
  210. tpu_inference/models/jax/qwen3.py +17 -2
  211. tpu_inference/models/jax/utils/__init__.py +13 -0
  212. tpu_inference/models/jax/utils/file_utils.py +14 -0
  213. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  214. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  215. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +88 -25
  216. tpu_inference/models/jax/utils/weight_utils.py +39 -2
  217. tpu_inference/models/vllm/__init__.py +13 -0
  218. tpu_inference/models/vllm/vllm_model_wrapper.py +20 -3
  219. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  220. tpu_inference/platforms/__init__.py +14 -0
  221. tpu_inference/platforms/tpu_platform.py +47 -64
  222. tpu_inference/runner/__init__.py +13 -0
  223. tpu_inference/runner/compilation_manager.py +72 -37
  224. tpu_inference/runner/kv_cache.py +54 -20
  225. tpu_inference/runner/kv_cache_manager.py +45 -15
  226. tpu_inference/runner/lora_utils.py +14 -0
  227. tpu_inference/runner/multimodal_manager.py +15 -1
  228. tpu_inference/runner/persistent_batch_manager.py +14 -0
  229. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  230. tpu_inference/runner/structured_decoding_manager.py +14 -0
  231. tpu_inference/runner/tpu_runner.py +41 -16
  232. tpu_inference/spec_decode/__init__.py +13 -0
  233. tpu_inference/spec_decode/jax/__init__.py +13 -0
  234. tpu_inference/spec_decode/jax/eagle3.py +13 -0
  235. tpu_inference/tpu_info.py +14 -0
  236. tpu_inference/utils.py +42 -36
  237. tpu_inference/worker/__init__.py +13 -0
  238. tpu_inference/worker/tpu_worker.py +63 -50
  239. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/METADATA +11 -9
  240. tpu_inference-0.13.0rc2.post7.dist-info/RECORD +261 -0
  241. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  242. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  243. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  244. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  245. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  246. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  247. tpu_inference-0.11.1.dev202512030818.dist-info/RECORD +0 -174
  248. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/WHEEL +0 -0
  249. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/licenses/LICENSE +0 -0
  250. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,18 @@
-from typing import Any, Callable, Optional, Union
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Optional, Union
 
 import jax
 import jax.numpy as jnp
@@ -25,17 +39,23 @@ from tpu_inference import envs
 from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe
 from tpu_inference.layers.common.quant_methods import (UNQUANTIZED,
                                                        get_tpu_quant_method)
-from tpu_inference.layers.vllm.fused_moe import fused_moe_func_padded
+from tpu_inference.layers.common.sharding import ShardingAxisName
+from tpu_inference.layers.vllm.fused_moe import fused_moe_func
 from tpu_inference.layers.vllm.linear_common import (
     reorder_concatenated_tensor_for_sharding,
     slice_sharded_tensor_for_concatenation, torch_to_jax_param)
 from tpu_inference.layers.vllm.quantization.common import (
     JaxCommonConfig, JaxCommonLinearConfig)
+from tpu_inference.utils import get_mesh_shape_product
 
 P = PartitionSpec
 logger = init_logger(__name__)
 
 
+def align_to(a, b):
+    return (a + b - 1) // b * b
+
+
 @register_quantization_config(get_tpu_quant_method(UNQUANTIZED))
 class VllmUnquantizedConfig(QuantizationConfig, JaxCommonConfig):
 
@@ -168,7 +188,7 @@ class VllmUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
                  ep_axis_name: str = 'model'):
         super().__init__(moe)
         self.mesh = mesh
-        self.use_kernel = envs.USE_MOE_EP_KERNEL
+        self.use_kernel = envs.USE_MOE_EP_KERNEL and moe.use_ep
         self.ep_axis_name = ep_axis_name
         # TODO: Use autotune table once we have it.
         self.block_size = {
@@ -196,6 +216,8 @@ class VllmUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
         w13_weight = t2j(layer.w13_weight, use_dlpack=False)
         w2_weight = t2j(layer.w2_weight, use_dlpack=False)
 
+        num_experts, hidden_size, intermediate_size = w2_weight.shape
+
         if self.moe.has_bias:
             w13_bias = t2j(layer.w13_bias, use_dlpack=False)
             w2_bias = t2j(layer.w2_bias, use_dlpack=False)
@@ -214,7 +236,7 @@ class VllmUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
             w3_bias = w13_bias[:, 1::2]
             w13_bias = jnp.concat([w1_bias, w3_bias], axis=1)
 
-        if self.use_kernel and layer.use_ep:
+        if self.use_kernel:
             # Kernel expects:
             # w13: (num_experts, 2, hidden_size, intermediate_size)
             # w2: (num_experts, intermediate_size, hidden_size)
@@ -225,87 +247,119 @@ class VllmUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
             intermediate_size = w13_weight.shape[1] // 2
             hidden_size = w13_weight.shape[2]
 
-            # Reshape and transpose w13_weight to (num_experts, 2, hidden_size, intermediate_size)
-            w13_reshaped = w13_weight.reshape(num_experts, 2,
-                                              intermediate_size, hidden_size)
-            w13_weight_transposed = jnp.transpose(w13_reshaped, (0, 1, 3, 2))
+            padded_intermediate_size = align_to(intermediate_size, 256)
+            padded_hidden_size = align_to(hidden_size, 256)
 
             # Transpose w2_weight to (num_experts, intermediate_size, hidden_size)
-            w2_weight_transposed = jnp.transpose(w2_weight, (0, 2, 1))
+            w13_weight = w13_weight.reshape(num_experts, 2, intermediate_size,
+                                            hidden_size)
+            w13_weight = jnp.swapaxes(w13_weight, 3, 2)
+
+            w2_weight = jnp.swapaxes(w2_weight, 2, 1)
+
+            w13_weight = jnp.pad(
+                w13_weight,
+                ((0, 0), (0, 0), (0, padded_hidden_size - hidden_size),
+                 (0, padded_intermediate_size - intermediate_size)),
+                constant_values=0)
+
+            w2_weight = jnp.pad(
+                w2_weight,
+                ((0, 0), (0, padded_intermediate_size - intermediate_size),
+                 (0, padded_hidden_size - hidden_size)),
+                constant_values=0)
 
             # Apply EP sharding
+            ep_sharding = NamedSharding(self.mesh, P("model"))
+
             w13_weight = jax.device_put(
-                w13_weight_transposed,
+                w13_weight,
                 Format(Layout((0, 1, 2, 3)),
                        NamedSharding(self.mesh, P("model", None, None, None))))
             w2_weight = jax.device_put(
-                w2_weight_transposed,
+                w2_weight,
                 Format(Layout((0, 1, 2)),
                        NamedSharding(self.mesh, P("model", None, None))))
 
             if self.moe.has_bias:
-                w13_bias = w13_bias.reshape(num_experts, 2, intermediate_size)
+                w13_bias = w13_bias.astype(jnp.float32).reshape(
+                    num_experts, 2, 1, intermediate_size)
+                w2_bias = w2_bias.astype(jnp.float32).reshape(
+                    num_experts, 1, hidden_size)
+
+                w13_bias = jnp.pad(
+                    w13_bias,
+                    ((0, 0), (0, 0), (0, 0),
+                     (0, padded_intermediate_size - intermediate_size)),
+                    constant_values=0)
+
+                w2_bias = jnp.pad(w2_bias,
+                                  ((0, 0), (0, 0),
+                                   (0, padded_hidden_size - hidden_size)),
+                                  constant_values=0)
 
                 # Apply EP sharding
                 w13_bias = jax.device_put(
-                    w13_bias,
-                    Format(Layout((0, 1, 2)),
-                           NamedSharding(self.mesh, P("model", None, None))))
+                    w13_bias, Format(Layout((0, 1, 2, 3)), ep_sharding))
                 w2_bias = jax.device_put(
-                    w2_bias,
-                    Format(Layout((0, 1)),
-                           NamedSharding(self.mesh, P("model", None))))
-
+                    w2_bias, Format(Layout((0, 1, 2)), ep_sharding))
         else:
-            # Original logic for non-kernel path
+            if self.moe.has_bias:
+                w13_bias = jnp.expand_dims(w13_bias, 1)
+                w2_bias = jnp.expand_dims(w2_bias, 1)
+
             if layer.use_ep:
+                ep_sharding = NamedSharding(self.mesh,
+                                            P(ShardingAxisName.EXPERT))
                 w13_weight = jax.device_put(
-                    w13_weight,
-                    Format(Layout((0, 1, 2)),
-                           NamedSharding(self.mesh, P("model", None, None))))
+                    w13_weight, Format(Layout((0, 1, 2)), ep_sharding))
                 w2_weight = jax.device_put(
-                    w2_weight,
-                    Format(Layout((0, 1, 2)),
-                           NamedSharding(self.mesh, P("model", None, None))))
+                    w2_weight, Format(Layout((0, 1, 2)), ep_sharding))
 
                 if self.moe.has_bias:
                     w13_bias = jax.device_put(
-                        w13_bias,
-                        Format(Layout((0, 1)),
-                               NamedSharding(self.mesh, P("model", None))))
+                        w13_bias, Format(Layout((0, 1, 2)), ep_sharding))
                     w2_bias = jax.device_put(
-                        w2_bias,
-                        Format(Layout((0, 1)),
-                               NamedSharding(self.mesh, P("model", None))))
+                        w2_bias, Format(Layout((0, 1, 2)), ep_sharding))
 
             else:
-                intermediate_size = w13_weight.shape[1] // 2
-                assert intermediate_size == w2_weight.shape[-1]
                 output_sizes = [intermediate_size, intermediate_size]
-                n_shards = self.mesh.shape["model"]
+                n_shards = get_mesh_shape_product(self.mesh,
+                                                  ShardingAxisName.MLP_TENSOR)
                 assert intermediate_size % n_shards == 0
+
                 w13_weight = reorder_concatenated_tensor_for_sharding(
                     w13_weight, output_sizes, n_shards, dim=1)
                 w13_weight = jax.device_put(
                     w13_weight,
-                    Format(Layout((0, 1, 2)),
-                           NamedSharding(self.mesh, P(None, "model", None))))
+                    Format(
+                        Layout((0, 1, 2)),
+                        NamedSharding(
+                            self.mesh,
+                            P(None, ShardingAxisName.MLP_TENSOR, None))))
                 w2_weight = jax.device_put(
                     w2_weight,
-                    Format(Layout((0, 1, 2)),
-                           NamedSharding(self.mesh, P(None, None, "model"))))
+                    Format(
+                        Layout((0, 1, 2)),
+                        NamedSharding(
+                            self.mesh,
+                            P(None, None, ShardingAxisName.MLP_TENSOR))))
 
                 if self.moe.has_bias:
                     w13_bias = reorder_concatenated_tensor_for_sharding(
-                        w13_bias, output_sizes, n_shards, dim=1)
+                        w13_bias, output_sizes, n_shards, dim=2)
+
                     w13_bias = jax.device_put(
                         w13_bias,
-                        Format(Layout((0, 1)),
-                               NamedSharding(self.mesh, P(None, "model"))))
+                        Format(
+                            Layout((0, 1, 2)),
+                            NamedSharding(
+                                self.mesh,
+                                P(None, None, ShardingAxisName.MLP_TENSOR))))
                     w2_bias = jax.device_put(
                         w2_bias,
-                        Format(Layout((0, 1)),
-                               NamedSharding(self.mesh, P(None, None))))
+                        Format(Layout((0, 1, 2)),
+                               NamedSharding(self.mesh, P(None, None, None))))
 
         layer.w13_weight = Parameter(torch_view(w13_weight),
                                      requires_grad=False)
@@ -321,60 +375,54 @@ class VllmUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
         layer: torch.nn.Module,
         x: torch.Tensor,
         router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        use_grouped_topk: bool = False,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        global_num_experts: int = -1,
-        expert_map: Optional[torch.Tensor] = None,
-        custom_routing_function: Optional[Callable] = None,
-        scoring_func: str = "softmax",
-        routed_scaling_factor: float = 1.0,
-        e_score_correction_bias: Optional[torch.Tensor] = None,
-        apply_router_weight_on_input: bool = False,
-        activation: str = "silu",
-        enable_eplb: bool = False,
-        expert_load_view: Optional[torch.Tensor] = None,
-        logical_to_physical_map: Optional[torch.Tensor] = None,
-        logical_replica_count: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
         assert isinstance(layer, FusedMoE)
-        if scoring_func != "softmax":
+        if layer.scoring_func != "softmax":
             raise NotImplementedError(
                 "Only softmax is supported for scoring_func")
 
-        if self.use_kernel and layer.use_ep:
+        x = jax_view(x)
+        w13_weight = jax_view(layer.w13_weight)
+        w2_weight = jax_view(layer.w2_weight)
+        w13_bias = w2_bias = None
+        if self.moe.has_bias:
+            w13_bias = jax_view(layer.w13_bias)
+            w2_bias = jax_view(layer.w2_bias)
+        gating_output = jax_view(router_logits)
+
+        if self.use_kernel:
+            actual_hidden_size = x.shape[-1]
+            padding_size = w13_weight.shape[-2] - actual_hidden_size
+            x = jnp.pad(x, ((0, 0), (0, padding_size)))
             output = fused_ep_moe(
                 mesh=self.mesh,
-                tokens=jax_view(x),
-                w1=jax_view(layer.w13_weight),
-                w2=jax_view(layer.w2_weight),
-                b1=jax_view(layer.w13_bias) if self.moe.has_bias else None,
-                b2=jax_view(layer.w2_bias) if self.moe.has_bias else None,
-                gating_output=jax_view(router_logits),
-                top_k=top_k,
+                tokens=x,
+                w1=w13_weight,
+                w2=w2_weight,
+                b1=w13_bias,
+                b2=w2_bias,
+                gating_output=gating_output,
+                top_k=layer.top_k,
                 ep_axis_name=self.ep_axis_name,
-                renormalize_topk_logits=renormalize,
-                act_fn=activation,
+                renormalize_topk_logits=layer.renormalize,
+                act_fn=layer.activation,
                 **self.block_size,
-            )
+            )[:, :actual_hidden_size]
         else:
-            # Use the original implementation
-            output = fused_moe_func_padded(
-                jax_view(x),
-                jax_view(layer.w13_weight),
-                jax_view(layer.w2_weight),
-                jax_view(layer.w13_bias) if self.moe.has_bias else None,
-                jax_view(layer.w2_bias) if self.moe.has_bias else None,
-                jax_view(router_logits),
-                topk=top_k,
-                global_num_experts=global_num_experts,
-                renormalize=renormalize,
-                reduce_results=layer.reduce_results,
+            output = fused_moe_func(
+                hidden_states=x,
+                w1=w13_weight,
+                w2=w2_weight,
+                w1_scale=None,
+                w2_scale=None,
+                w1_bias=w13_bias,
+                w2_bias=w2_bias,
+                gating_output=gating_output,
+                topk=layer.top_k,
+                renormalize=layer.renormalize,
                 mesh=self.mesh,
                 use_ep=layer.use_ep,
-                activation=activation,
+                activation=layer.activation,
             )
 
         return torch_view(output)
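
Note on the kernel-path padding introduced above: the new align_to helper rounds each weight dimension up to the next multiple of 256 before handing it to the fused EP kernel, and apply() pads the token activations to the same width and slices the kernel output back down. The following is a minimal standalone sketch of that flow; align_to is copied from the diff, while the shapes, the dummy tensors, and the 256 alignment target are illustrative assumptions, not values from the package.

import jax.numpy as jnp

def align_to(a, b):
    # Round a up to the next multiple of b, e.g. align_to(300, 256) == 512.
    return (a + b - 1) // b * b

# Hypothetical, deliberately small shapes -- not taken from any real model.
num_experts, hidden_size, intermediate_size = 4, 300, 700
w2 = jnp.ones((num_experts, hidden_size, intermediate_size))

padded_hidden_size = align_to(hidden_size, 256)              # 512
padded_intermediate_size = align_to(intermediate_size, 256)  # 768

# Move to the (num_experts, intermediate, hidden) layout the kernel expects,
# then zero-pad the last two dims up to the aligned sizes.
w2 = jnp.swapaxes(w2, 2, 1)
w2 = jnp.pad(w2, ((0, 0),
                  (0, padded_intermediate_size - intermediate_size),
                  (0, padded_hidden_size - hidden_size)))

# At apply() time the tokens get the same hidden-size padding, and the kernel
# output is sliced back to the original width.
x = jnp.ones((8, hidden_size))
x = jnp.pad(x, ((0, 0), (0, padded_hidden_size - hidden_size)))
out = x  # stand-in for the fused_ep_moe(...) result on the padded operands
out = out[:, :hidden_size]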
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import os
 
 import jax
@@ -20,6 +34,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 
 from tpu_inference import envs
+from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.logger import init_logger
 
 P = PartitionSpec
@@ -109,7 +124,8 @@ def _shard_tensor_to_tpu_replicated(tensor: torch.Tensor,
 def _shard_vocab_parallel_embedding(layer: VocabParallelEmbedding,
                                     mesh: Mesh) -> None:
     weight = _convert_to_torchax_and_shard(
-        layer.weight, NamedSharding(mesh, P('model', None)))
+        layer.weight, NamedSharding(mesh, P(ShardingAxisName.MLP_TENSOR,
+                                            None)))
     layer.weight = Parameter(weight, requires_grad=False)
 
 
@@ -118,11 +134,12 @@ def _shard_lm_head(layer: ParallelLMHead, mesh: Mesh):
     # if that config is set, then we should not create new weights but reuse the
     # weight from VocabParallelEmbedding
     weight = _convert_to_torchax_and_shard(
-        layer.weight, NamedSharding(mesh, P('model', None)))
+        layer.weight, NamedSharding(mesh, P(ShardingAxisName.MLP_TENSOR,
+                                            None)))
     layer.weight = Parameter(weight, requires_grad=False)
     if layer.bias is not None:
-        bias = _convert_to_torchax_and_shard(layer.bias,
-                                             NamedSharding(mesh, P('model')))
+        bias = _convert_to_torchax_and_shard(
+            layer.bias, NamedSharding(mesh, P(ShardingAxisName.MLP_TENSOR)))
         layer.bias = Parameter(bias, requires_grad=False)
 
 
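
For context on the sharding change above: the hard-coded 'model' mesh axis is replaced by the ShardingAxisName.MLP_TENSOR constant, but the resulting NamedSharding still splits the vocab dimension of the embedding and LM-head weights across one mesh axis while replicating the hidden dimension. Below is a small sketch of what that spec does, using a stand-in constant and toy shapes rather than the real tpu_inference definitions.

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Stand-in for ShardingAxisName.MLP_TENSOR; the real constant lives in
# tpu_inference.layers.common.sharding and may map to a different axis name.
MLP_TENSOR = "model"

devices = np.array(jax.devices())
mesh = Mesh(devices.reshape(-1), axis_names=(MLP_TENSOR,))

# Toy embedding table: vocab rows are split across the MLP_TENSOR axis,
# the hidden dimension stays replicated (P(MLP_TENSOR, None)).
vocab, hidden = 8 * len(devices), 16
weight = np.zeros((vocab, hidden), dtype=np.float32)
sharded = jax.device_put(weight, NamedSharding(mesh, P(MLP_TENSOR, None)))
print(sharded.sharding)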
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
@@ -4,7 +4,6 @@
 import jax
 import jax.numpy as jnp
 import torch
-import torch.nn.functional as F
 from torchax.interop import call_jax
 
 
@@ -85,19 +84,15 @@ def bgmv_expand_slice(
         add_inputs (bool): Whether or not to add the input tensor to the output
             tensor.
     """
-    outputs = bgmv_torch(inputs, lora_b_weights, lora_indices_tensor)
+    outputs = bgmv_torch(inputs, lora_b_weights,
+                         lora_indices_tensor)  # [num_tokens, out_features]
 
-    outputs = F.pad(
-        outputs,
-        (
-            slice_offset,
-            output_tensor.shape[1] - (slice_offset + slice_size),
-            0,
-            0,
-        ),
-    )
+    # Create a padded tensor manually to avoid issues with F.pad on sharded tensors.
+    # This is a more robust way to handle padding in a distributed environment.
+    outputs_padded = torch.zeros_like(output_tensor)
+    outputs_padded[:, slice_offset:slice_offset + slice_size] = outputs
 
     if add_inputs:
-        return output_tensor + outputs
+        return output_tensor + outputs_padded
     else:
-        return outputs
+        return outputs_padded
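
The bgmv_expand_slice change above swaps F.pad for a manually built zero tensor plus a slice assignment, which the new in-code comment attributes to F.pad misbehaving on sharded tensors. The two forms are numerically equivalent; the following is a small self-contained check, with all sizes made up for illustration.

import torch

num_tokens, out_features = 4, 16
slice_offset, slice_size = 6, 5

output_tensor = torch.zeros(num_tokens, out_features)
outputs = torch.ones(num_tokens, slice_size)  # stand-in for bgmv_torch(...)

# Old path: pad the last dim out to out_features with F.pad.
padded_old = torch.nn.functional.pad(
    outputs, (slice_offset, out_features - (slice_offset + slice_size), 0, 0))

# New path from the diff: allocate the full-size zero tensor and write the
# computed slice into place.
padded_new = torch.zeros_like(output_tensor)
padded_new[:, slice_offset:slice_offset + slice_size] = outputs

assert torch.equal(padded_old, padded_new)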
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+ # limitations under the License.