tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (257)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +317 -34
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +406 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +320 -0
  64. tests/layers/vllm/test_unquantized.py +662 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +26 -6
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +110 -12
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +2 -45
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +15 -10
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +92 -8
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +25 -4
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +807 -230
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +218 -137
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +25 -12
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/fused_moe_gmm.py +506 -0
  154. tpu_inference/layers/common/quant_methods.py +15 -0
  155. tpu_inference/layers/common/quantization.py +282 -0
  156. tpu_inference/layers/common/sharding.py +32 -9
  157. tpu_inference/layers/common/utils.py +94 -0
  158. tpu_inference/layers/jax/__init__.py +13 -0
  159. tpu_inference/layers/jax/attention/__init__.py +13 -0
  160. tpu_inference/layers/jax/attention/attention.py +19 -6
  161. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  162. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  163. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  164. tpu_inference/layers/jax/base.py +14 -0
  165. tpu_inference/layers/jax/constants.py +13 -0
  166. tpu_inference/layers/jax/layers.py +14 -0
  167. tpu_inference/layers/jax/misc.py +14 -0
  168. tpu_inference/layers/jax/moe/__init__.py +13 -0
  169. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  170. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  171. tpu_inference/layers/jax/moe/moe.py +43 -3
  172. tpu_inference/layers/jax/pp_utils.py +53 -0
  173. tpu_inference/layers/jax/rope.py +14 -0
  174. tpu_inference/layers/jax/rope_interface.py +14 -0
  175. tpu_inference/layers/jax/sample/__init__.py +13 -0
  176. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  177. tpu_inference/layers/jax/sample/sampling.py +15 -1
  178. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  179. tpu_inference/layers/jax/transformer_block.py +14 -0
  180. tpu_inference/layers/vllm/__init__.py +13 -0
  181. tpu_inference/layers/vllm/attention.py +4 -4
  182. tpu_inference/layers/vllm/fused_moe.py +101 -494
  183. tpu_inference/layers/vllm/linear.py +64 -0
  184. tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
  185. tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
  186. tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
  187. tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
  188. tpu_inference/layers/vllm/quantization/__init__.py +19 -3
  189. tpu_inference/layers/vllm/quantization/awq.py +96 -82
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  191. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +23 -8
  192. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +172 -176
  193. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  194. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
  195. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
  196. tpu_inference/layers/vllm/quantization/{common.py → configs.py} +42 -25
  197. tpu_inference/layers/vllm/quantization/fp8.py +119 -0
  198. tpu_inference/layers/vllm/quantization/mxfp4.py +137 -178
  199. tpu_inference/layers/vllm/quantization/unquantized.py +157 -233
  200. tpu_inference/lora/__init__.py +13 -0
  201. tpu_inference/lora/torch_lora_ops.py +8 -13
  202. tpu_inference/models/__init__.py +13 -0
  203. tpu_inference/models/common/__init__.py +13 -0
  204. tpu_inference/models/common/model_loader.py +112 -35
  205. tpu_inference/models/jax/__init__.py +13 -0
  206. tpu_inference/models/jax/deepseek_v3.py +267 -157
  207. tpu_inference/models/jax/gpt_oss.py +26 -10
  208. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  209. tpu_inference/models/jax/llama3.py +99 -36
  210. tpu_inference/models/jax/llama4.py +14 -0
  211. tpu_inference/models/jax/llama_eagle3.py +18 -5
  212. tpu_inference/models/jax/llama_guard_4.py +15 -1
  213. tpu_inference/models/jax/qwen2.py +17 -2
  214. tpu_inference/models/jax/qwen2_5_vl.py +179 -51
  215. tpu_inference/models/jax/qwen3.py +17 -2
  216. tpu_inference/models/jax/utils/__init__.py +13 -0
  217. tpu_inference/models/jax/utils/file_utils.py +14 -0
  218. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  219. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  220. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +92 -32
  221. tpu_inference/models/jax/utils/weight_utils.py +234 -155
  222. tpu_inference/models/vllm/__init__.py +13 -0
  223. tpu_inference/models/vllm/vllm_model_wrapper.py +32 -8
  224. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  225. tpu_inference/platforms/__init__.py +14 -0
  226. tpu_inference/platforms/tpu_platform.py +51 -72
  227. tpu_inference/runner/__init__.py +13 -0
  228. tpu_inference/runner/compilation_manager.py +180 -80
  229. tpu_inference/runner/kv_cache.py +54 -20
  230. tpu_inference/runner/kv_cache_manager.py +55 -33
  231. tpu_inference/runner/lora_utils.py +16 -1
  232. tpu_inference/runner/multimodal_manager.py +16 -2
  233. tpu_inference/runner/persistent_batch_manager.py +54 -2
  234. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  235. tpu_inference/runner/structured_decoding_manager.py +16 -3
  236. tpu_inference/runner/tpu_runner.py +124 -61
  237. tpu_inference/runner/utils.py +2 -2
  238. tpu_inference/spec_decode/__init__.py +13 -0
  239. tpu_inference/spec_decode/jax/__init__.py +13 -0
  240. tpu_inference/spec_decode/jax/eagle3.py +84 -22
  241. tpu_inference/tpu_info.py +14 -0
  242. tpu_inference/utils.py +72 -44
  243. tpu_inference/worker/__init__.py +13 -0
  244. tpu_inference/worker/tpu_worker.py +66 -52
  245. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +8 -9
  246. tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
  247. tpu_inference/layers/vllm/linear_common.py +0 -186
  248. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  249. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  250. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  251. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  252. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  253. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  254. tpu_inference-0.11.1.dev202511220812.dist-info/RECORD +0 -174
  255. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
  256. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
  257. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
tpu_inference/layers/vllm/quantization/unquantized.py

@@ -1,19 +1,29 @@
-from typing import Any, Callable, Optional, Union
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Optional
 
 import jax
 import jax.numpy as jnp
 import torch
-from jax.experimental.layout import Format, Layout
 from jax.sharding import Mesh, NamedSharding, PartitionSpec
 from torch.nn.parameter import Parameter
 from torchax.interop import jax_view, torch_view
 from torchax.ops.mappings import t2j
 from vllm.attention.layer import Attention
-from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE, FusedMoEConfig, UnquantizedFusedMoEMethod)
-from vllm.model_executor.layers.fused_moe.modular_kernel import (
-    FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize)
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
 from vllm.model_executor.layers.quantization import \
@@ -21,23 +31,31 @@ from vllm.model_executor.layers.quantization import \
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 
-from tpu_inference import envs
-from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe
 from tpu_inference.layers.common.quant_methods import (UNQUANTIZED,
                                                        get_tpu_quant_method)
-from tpu_inference.layers.vllm.fused_moe import fused_moe_func_padded
-from tpu_inference.layers.vllm.linear_common import (
-    reorder_concatenated_tensor_for_sharding,
-    slice_sharded_tensor_for_concatenation, torch_to_jax_param)
-from tpu_inference.layers.vllm.quantization.common import (
-    JaxCommonConfig, JaxCommonLinearConfig)
+from tpu_inference.layers.common.sharding import ShardingAxisName
+from tpu_inference.layers.common.utils import \
+    slice_sharded_tensor_for_concatenation
+from tpu_inference.layers.vllm.fused_moe import (FusedMoEBackend,
+                                                 fused_moe_apply,
+                                                 select_moe_backend)
+from tpu_inference.layers.vllm.process_weights.fused_moe_weights import (
+    FusedMoEWeights, process_moe_weights, shard_moe_weights)
+from tpu_inference.layers.vllm.process_weights.linear_weights import (
+    LinearWeights, process_lienar_weights, shard_linear_weights,
+    to_parameter_list)
+from tpu_inference.layers.vllm.quantization.configs import (
+    VllmQuantConfig, VllmQuantLinearConfig)
+from tpu_inference.logger import init_logger
+from tpu_inference.utils import get_mesh_shape_product
 
 P = PartitionSpec
+
 logger = init_logger(__name__)
 
 
 @register_quantization_config(get_tpu_quant_method(UNQUANTIZED))
-class VllmUnquantizedConfig(QuantizationConfig, JaxCommonConfig):
+class VllmUnquantizedConfig(QuantizationConfig, VllmQuantConfig):
 
     @classmethod
     def get_name(cls) -> str:
@@ -74,51 +92,73 @@ class VllmUnquantizedConfig(QuantizationConfig, JaxCommonConfig):
 
 class VllmUnquantizedLinearMethod(UnquantizedLinearMethod):
 
-    def __init__(self, jax_config: JaxCommonLinearConfig):
-        self.jax_config = jax_config
+    def __init__(self, linear_config: VllmQuantLinearConfig):
+        self.linear_config = linear_config
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        weight = torch_to_jax_param(
-            layer.weight,
-            NamedSharding(self.jax_config.mesh,
-                          self.jax_config.weight_sharding),
-            self.jax_config.output_sizes,
-            self.jax_config.n_shards,
-            self.jax_config.fuse_matmuls,
-        )
+        weight = t2j(layer.weight, use_dlpack=False)
         delattr(layer, "weight")
-        layer.weight = weight
-
         if layer.bias is not None and not layer.skip_bias_add:
             if layer.return_bias:
                 logger.warning_once("Bias might return incorrect value.")
-
-            bias = torch_to_jax_param(
-                layer.bias,
-                NamedSharding(self.jax_config.mesh,
-                              self.jax_config.bias_sharding),
-                self.jax_config.output_sizes,
-                self.jax_config.n_shards,
-                self.jax_config.fuse_matmuls,
-            )
+            bias = t2j(layer.bias, use_dlpack=False)
             delattr(layer, "bias")
-            layer.bias = bias
+        else:
+            bias = None
+
+        @jax.jit
+        def process_unquantized_linear_weights(
+            weight: jax.Array,
+            bias: jax.Array | None,
+        ) -> LinearWeights:
+            return process_lienar_weights(
+                LinearWeights(
+                    weight=weight,
+                    weight_scale=None,
+                    zero_point=None,
+                    bias=bias,
+                ),
+                fused=self.linear_config.fuse_matmuls,
+                output_sizes=self.linear_config.output_sizes,
+                reorder_size=self.linear_config.n_shards,
+            )
+
+        weights = process_unquantized_linear_weights(weight, bias)
+        weights = torch_view(
+            shard_linear_weights(
+                weights,
+                mesh=self.linear_config.mesh,
+                weight_p_spec=self.linear_config.weight_sharding,
+                bias_p_spec=self.linear_config.bias_sharding,
+            ))
+
+        if self.linear_config.fuse_matmuls:
+            layer.weight = Parameter(weights.weight, requires_grad=False)
+            if bias is not None:
+                layer.bias = Parameter(weights.bias, requires_grad=False)
+        else:
+            layer.weight = to_parameter_list(weights.weight)
+            if bias is not None:
+                layer.bias = to_parameter_list(weights.bias)
 
     def apply(self,
               layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        assert isinstance(layer, LinearBase)
+
         with jax.named_scope(layer._get_name()):
-            if in_sharding := self.jax_config.get_input_sharding(x):
-                x.shard_(NamedSharding(self.jax_config.mesh, in_sharding))
+            if in_sharding := self.linear_config.get_input_sharding(x):
+                x.shard_(NamedSharding(self.linear_config.mesh, in_sharding))
 
-            if self.jax_config.fuse_matmuls:
+            if self.linear_config.fuse_matmuls:
                 out = self._apply_fused(layer, x, bias)
             else:
                 out = self._apply_split(layer, x, bias)
 
-            if out_sharding := self.jax_config.get_output_sharding(out):
-                out.shard_(NamedSharding(self.jax_config.mesh, out_sharding))
+            if out_sharding := self.linear_config.get_output_sharding(out):
+                out.shard_(NamedSharding(self.linear_config.mesh,
                                          out_sharding))
 
         return out
 
@@ -134,7 +174,7 @@ class VllmUnquantizedLinearMethod(UnquantizedLinearMethod):
             outs += bias.jax()
 
         outs = slice_sharded_tensor_for_concatenation(
-            outs, self.jax_config.output_sizes, self.jax_config.n_shards)
+            outs, self.linear_config.output_sizes, self.linear_config.n_shards)
         out = jnp.concatenate(outs, axis=-1)
         return torch_view(out)
 
@@ -160,215 +200,99 @@ class VllmUnquantizedLinearMethod(UnquantizedLinearMethod):
 
 class VllmUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
 
-    def __init__(self,
-                 moe: FusedMoEConfig,
-                 mesh: Mesh,
-                 ep_axis_name: str = 'model'):
-        super().__init__(moe)
-        self.mesh = mesh
-        self.use_kernel = envs.USE_MOE_EP_KERNEL
-        self.ep_axis_name = ep_axis_name
-        # TODO: Use autotune table once we have it.
-        self.block_size = {
-            "bt": 16,
-            "bf": 384,
-            "bd1": 512,
-            "bd2": 512,
-            "btc": 16,
-            "bfc": 384,
-            "bd1c": 256,
-            "bd2c": 256,
-        }
-
-    def select_gemm_impl(
+    def __init__(
         self,
-        prepare_finalize: FusedMoEPrepareAndFinalize,
         moe: FusedMoEConfig,
-        layer: torch.nn.Module,
-    ) -> FusedMoEPermuteExpertsUnpermute:
-        raise NotImplementedError(
-            "Selecting gemm implementation is currently not supported.")
+        mesh: Mesh,
+        ep_axis_name: str = "model",
+    ):
+        super().__init__(moe)
+        self.mesh = mesh
+        self.moe_backend = select_moe_backend(self.moe)
+
+        self.extra_backend_kwargs = {}
+        if self.moe_backend == FusedMoEBackend.FUSED_MOE:
+            # When fused moe kernle is used, we pass extra arguments like
+            # tuned block sizes to the kernel.
+            self.extra_backend_kwargs = dict(
+                ep_axis_name=ep_axis_name,
+                # TODO: Use autotune table once we have it.
+                bt=64,
+                bf=1024,
+                bd1=1536,
+                bd2=1536,
+                btc=64,
+                bfc=1024,
+                bd1c=1536,
+                bd2c=1536,
+            )
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         assert isinstance(layer, FusedMoE)
+
         w13_weight = t2j(layer.w13_weight, use_dlpack=False)
         w2_weight = t2j(layer.w2_weight, use_dlpack=False)
 
         if self.moe.has_bias:
             w13_bias = t2j(layer.w13_bias, use_dlpack=False)
             w2_bias = t2j(layer.w2_bias, use_dlpack=False)
-
-        if layer.activation == "swigluoai":
-            # When using swigluoai, vLLM splits gmm output in a interleaved way.
-            # However, interleaved split is not performant on TPU. Therefore,
-            # we preprocess the weight so that splitting gmm output by middle
-            # can still get the same result.
-            w1_weight = w13_weight[:, ::2, :]
-            w3_weight = w13_weight[:, 1::2, :]
-            w13_weight = jnp.concat([w1_weight, w3_weight], axis=1)
-
-            if self.moe.has_bias:
-                w1_bias = w13_bias[:, ::2]
-                w3_bias = w13_bias[:, 1::2]
-                w13_bias = jnp.concat([w1_bias, w3_bias], axis=1)
-
-        if self.use_kernel and layer.use_ep:
-            # Kernel expects:
-            #   w13: (num_experts, 2, hidden_size, intermediate_size)
-            #   w2: (num_experts, intermediate_size, hidden_size)
-            # Current format:
-            #   w13_weight: (num_experts, 2*intermediate_size, hidden_size)
-            #   w2_weight: (num_experts, hidden_size, intermediate_size)
-            num_experts = w13_weight.shape[0]
-            intermediate_size = w13_weight.shape[1] // 2
-            hidden_size = w13_weight.shape[2]
-
-            # Reshape and transpose w13_weight to (num_experts, 2, hidden_size, intermediate_size)
-            w13_reshaped = w13_weight.reshape(num_experts, 2,
-                                              intermediate_size, hidden_size)
-            w13_weight_transposed = jnp.transpose(w13_reshaped, (0, 1, 3, 2))
-
-            # Transpose w2_weight to (num_experts, intermediate_size, hidden_size)
-            w2_weight_transposed = jnp.transpose(w2_weight, (0, 2, 1))
-
-            # Apply EP sharding
-            w13_weight = jax.device_put(
-                w13_weight_transposed,
-                Format(Layout((0, 1, 2, 3)),
-                       NamedSharding(self.mesh, P("model", None, None, None))))
-            w2_weight = jax.device_put(
-                w2_weight_transposed,
-                Format(Layout((0, 1, 2)),
-                       NamedSharding(self.mesh, P("model", None, None))))
-
-            if self.moe.has_bias:
-                w13_bias = w13_bias.reshape(num_experts, 2, intermediate_size)
-
-                # Apply EP sharding
-                w13_bias = jax.device_put(
-                    w13_bias,
-                    Format(Layout((0, 1, 2)),
-                           NamedSharding(self.mesh, P("model", None, None))))
-                w2_bias = jax.device_put(
-                    w2_bias,
-                    Format(Layout((0, 1)),
-                           NamedSharding(self.mesh, P("model", None))))
-
         else:
-            # Original logic for non-kernel path
-            if layer.use_ep:
-                w13_weight = jax.device_put(
-                    w13_weight,
-                    Format(Layout((0, 1, 2)),
-                           NamedSharding(self.mesh, P("model", None, None))))
-                w2_weight = jax.device_put(
-                    w2_weight,
-                    Format(Layout((0, 1, 2)),
-                           NamedSharding(self.mesh, P("model", None, None))))
-
-                if self.moe.has_bias:
-                    w13_bias = jax.device_put(
-                        w13_bias,
-                        Format(Layout((0, 1)),
-                               NamedSharding(self.mesh, P("model", None))))
-                    w2_bias = jax.device_put(
-                        w2_bias,
-                        Format(Layout((0, 1)),
-                               NamedSharding(self.mesh, P("model", None))))
+            w13_bias = w2_bias = None
+
+        @jax.jit
+        def process_unquantized_moe_weights(
+            w13_weight: jax.Array,
+            w13_bias: jax.Array | None,
+            w2_weight: jax.Array,
+            w2_bias: jax.Array | None,
+        ) -> FusedMoEWeights:
+
+            w13_interleave = layer.activation == "swigluoai"
+            w13_reorder_size = get_mesh_shape_product(
+                self.mesh, ShardingAxisName.MLP_TENSOR)
+
+            return process_moe_weights(
+                FusedMoEWeights(
+                    w13_weight=w13_weight,
+                    w13_weight_scale=None,
+                    w13_bias=w13_bias,
+                    w2_weight=w2_weight,
+                    w2_weight_scale=None,
+                    w2_bias=w2_bias,
+                ),
+                moe_backend=self.moe_backend,
+                w13_reorder_size=w13_reorder_size,
+                w13_interleave=w13_interleave,
+            )
 
-        else:
-            intermediate_size = w13_weight.shape[1] // 2
-            assert intermediate_size == w2_weight.shape[-1]
-            output_sizes = [intermediate_size, intermediate_size]
-            n_shards = self.mesh.shape["model"]
-            assert intermediate_size % n_shards == 0
-            w13_weight = reorder_concatenated_tensor_for_sharding(
-                w13_weight, output_sizes, n_shards, dim=1)
-            w13_weight = jax.device_put(
-                w13_weight,
-                Format(Layout((0, 1, 2)),
-                       NamedSharding(self.mesh, P(None, "model", None))))
-            w2_weight = jax.device_put(
-                w2_weight,
-                Format(Layout((0, 1, 2)),
-                       NamedSharding(self.mesh, P(None, None, "model"))))
-
-            if self.moe.has_bias:
-                w13_bias = reorder_concatenated_tensor_for_sharding(
-                    w13_bias, output_sizes, n_shards, dim=1)
-                w13_bias = jax.device_put(
-                    w13_bias,
-                    Format(Layout((0, 1)),
-                           NamedSharding(self.mesh, P(None, "model"))))
-                w2_bias = jax.device_put(
-                    w2_bias,
-                    Format(Layout((0, 1)),
-                           NamedSharding(self.mesh, P(None, None))))
-
-        layer.w13_weight = Parameter(torch_view(w13_weight),
-                                     requires_grad=False)
-        layer.w2_weight = Parameter(torch_view(w2_weight), requires_grad=False)
+        weights = process_unquantized_moe_weights(
+            w13_weight,
+            w13_bias,
+            w2_weight,
+            w2_bias,
+        )
+        weights = torch_view(
+            shard_moe_weights(weights, self.moe_backend, self.mesh))
+
+        layer.w13_weight = Parameter(weights.w13_weight, requires_grad=False)
+        layer.w2_weight = Parameter(weights.w2_weight, requires_grad=False)
 
         if self.moe.has_bias:
-            layer.w13_bias = Parameter(torch_view(w13_bias),
-                                       requires_grad=False)
-            layer.w2_bias = Parameter(torch_view(w2_bias), requires_grad=False)
+            layer.w13_bias = Parameter(weights.w13_bias, requires_grad=False)
+            layer.w2_bias = Parameter(weights.w2_bias, requires_grad=False)
 
     def apply(
         self,
         layer: torch.nn.Module,
         x: torch.Tensor,
         router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        use_grouped_topk: bool = False,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        global_num_experts: int = -1,
-        expert_map: Optional[torch.Tensor] = None,
-        custom_routing_function: Optional[Callable] = None,
-        scoring_func: str = "softmax",
-        routed_scaling_factor: float = 1.0,
-        e_score_correction_bias: Optional[torch.Tensor] = None,
-        apply_router_weight_on_input: bool = False,
-        activation: str = "silu",
-        enable_eplb: bool = False,
-        expert_load_view: Optional[torch.Tensor] = None,
-        logical_to_physical_map: Optional[torch.Tensor] = None,
-        logical_replica_count: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
-        assert isinstance(layer, FusedMoE)
-        if scoring_func != "softmax":
-            raise NotImplementedError(
-                "Only softmax is supported for scoring_func")
-
-        if self.use_kernel and layer.use_ep:
-            output = fused_ep_moe(
-                mesh=self.mesh,
-                tokens=jax_view(x),
-                w1=jax_view(layer.w13_weight),
-                w2=jax_view(layer.w2_weight),
-                gating_output=jax_view(router_logits),
-                top_k=top_k,
-                ep_axis_name=self.ep_axis_name,
-                **self.block_size,
-            )
-        else:
-            # Use the original implementation
-            output = fused_moe_func_padded(
-                jax_view(x),
-                jax_view(layer.w13_weight),
-                jax_view(layer.w2_weight),
-                jax_view(layer.w13_bias) if self.moe.has_bias else None,
-                jax_view(layer.w2_bias) if self.moe.has_bias else None,
-                jax_view(router_logits),
-                topk=top_k,
-                global_num_experts=global_num_experts,
-                renormalize=renormalize,
-                reduce_results=layer.reduce_results,
-                mesh=self.mesh,
-                use_ep=layer.use_ep,
-                activation=activation,
-            )
-
-        return torch_view(output)
+    ) -> torch.Tensor:
+
+        return fused_moe_apply(
+            layer,
+            x,
+            router_logits,
+            self.moe_backend,
+            self.mesh,
+            self.extra_backend_kwargs,
+        )
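
Note on the swigluoai handling removed above: the old inline code de-interleaved the fused w13 (gate/up) weight so that a plain middle split of the GMM output gives the same result as vLLM's interleaved split, and the new code routes the same transformation through process_moe_weights via the w13_interleave flag. The following standalone sketch (shapes and values invented for illustration, not tpu-inference code) shows the reordering the removed comment describes:

    # Standalone sketch of the w13 de-interleaving from the removed code:
    # vLLM stores gate (w1) and up (w3) rows interleaved along axis 1; on TPU it
    # is cheaper to split the GMM output down the middle, so the rows are
    # reordered once at load time. Shapes below are made up for illustration.
    import jax.numpy as jnp

    num_experts, intermediate_size, hidden_size = 2, 4, 8
    w13 = jnp.arange(num_experts * 2 * intermediate_size * hidden_size,
                     dtype=jnp.float32).reshape(num_experts,
                                                2 * intermediate_size,
                                                hidden_size)

    # Interleaved layout: even rows belong to w1, odd rows to w3.
    w1 = w13[:, ::2, :]
    w3 = w13[:, 1::2, :]

    # After concatenation, the first half of axis 1 is w1 and the second half is
    # w3, so downstream code can split the GMM output at the midpoint instead of
    # gathering every other row.
    w13_reordered = jnp.concatenate([w1, w3], axis=1)
    assert w13_reordered.shape == (num_experts, 2 * intermediate_size, hidden_size)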
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
tpu_inference/lora/torch_lora_ops.py

@@ -4,7 +4,6 @@
 import jax
 import jax.numpy as jnp
 import torch
-import torch.nn.functional as F
 from torchax.interop import call_jax
 
 
@@ -85,19 +84,15 @@ def bgmv_expand_slice(
        add_inputs (bool): Whether or not to add the input tensor to the output
            tensor.
    """
-    outputs = bgmv_torch(inputs, lora_b_weights, lora_indices_tensor)
+    outputs = bgmv_torch(inputs, lora_b_weights,
+                         lora_indices_tensor)  # [num_tokens, out_features]
 
-    outputs = F.pad(
-        outputs,
-        (
-            slice_offset,
-            output_tensor.shape[1] - (slice_offset + slice_size),
-            0,
-            0,
-        ),
-    )
+    # Create a padded tensor manually to avoid issues with F.pad on sharded tensors.
+    # This is a more robust way to handle padding in a distributed environment.
+    outputs_padded = torch.zeros_like(output_tensor)
+    outputs_padded[:, slice_offset:slice_offset + slice_size] = outputs
 
     if add_inputs:
-        return output_tensor + outputs
+        return output_tensor + outputs_padded
     else:
-        return outputs
+        return outputs_padded
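
For context on the padding change above: writing the bgmv output into a zero tensor shaped like output_tensor is equivalent to the removed F.pad call, which padded the last dimension out to the full output width. A standalone sketch (made-up shapes, not tpu-inference code) demonstrating the equivalence:

    # Place a [num_tokens, slice_size] block into a wider
    # [num_tokens, out_features] tensor, once via F.pad and once via direct
    # assignment into a zero tensor; both produce the same result.
    import torch
    import torch.nn.functional as F

    num_tokens, out_features = 4, 16
    slice_offset, slice_size = 6, 5

    outputs = torch.randn(num_tokens, slice_size)
    output_tensor = torch.zeros(num_tokens, out_features)

    # Old approach: pad the last dim on both sides to reach the output width.
    padded_via_fpad = F.pad(
        outputs,
        (slice_offset, out_features - (slice_offset + slice_size), 0, 0))

    # New approach: allocate the full-size tensor and write the slice in place.
    padded_via_assign = torch.zeros_like(output_tensor)
    padded_via_assign[:, slice_offset:slice_offset + slice_size] = outputs

    assert torch.equal(padded_via_fpad, padded_via_assign)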
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.