tpu-inference 0.12.0.dev20251213__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic.

Files changed (248)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +14 -0
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +14 -0
  31. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  32. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_test.py +14 -0
  35. tests/layers/__init__.py +13 -0
  36. tests/layers/common/__init__.py +13 -0
  37. tests/layers/common/test_attention_interface.py +156 -0
  38. tests/layers/common/test_quantization.py +149 -0
  39. tests/layers/jax/__init__.py +13 -0
  40. tests/layers/jax/attention/__init__.py +13 -0
  41. tests/layers/jax/attention/test_common_attention.py +103 -0
  42. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  43. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  44. tests/layers/jax/moe/__init__.py +13 -0
  45. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  46. tests/layers/jax/sample/__init__.py +13 -0
  47. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  48. tests/layers/jax/sample/test_sampling.py +115 -0
  49. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  50. tests/layers/jax/test_layers.py +155 -0
  51. tests/{test_quantization.py → layers/jax/test_qwix.py} +180 -50
  52. tests/layers/jax/test_rope.py +93 -0
  53. tests/layers/jax/test_sharding.py +159 -0
  54. tests/layers/jax/test_transformer_block.py +152 -0
  55. tests/layers/vllm/__init__.py +13 -0
  56. tests/layers/vllm/test_attention.py +363 -0
  57. tests/layers/vllm/test_awq.py +406 -0
  58. tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
  59. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
  61. tests/layers/vllm/test_fp8.py +17 -0
  62. tests/layers/vllm/test_mxfp4.py +320 -0
  63. tests/layers/vllm/test_unquantized.py +662 -0
  64. tests/layers/vllm/utils.py +87 -0
  65. tests/lora/__init__.py +13 -0
  66. tests/lora/conftest.py +14 -0
  67. tests/lora/test_bgmv.py +14 -0
  68. tests/lora/test_layers.py +25 -8
  69. tests/lora/test_lora.py +15 -1
  70. tests/lora/test_lora_perf.py +14 -0
  71. tests/models/__init__.py +13 -0
  72. tests/models/common/__init__.py +13 -0
  73. tests/models/common/test_model_loader.py +455 -0
  74. tests/models/jax/__init__.py +13 -0
  75. tests/models/jax/test_deepseek_v3.py +401 -0
  76. tests/models/jax/test_llama3.py +184 -0
  77. tests/models/jax/test_llama4.py +298 -0
  78. tests/models/jax/test_llama_eagle3.py +197 -0
  79. tests/models/jax/test_llama_guard_4.py +242 -0
  80. tests/models/jax/test_qwen2.py +172 -0
  81. tests/models/jax/test_qwen2_5_vl.py +605 -0
  82. tests/models/jax/test_qwen3.py +169 -0
  83. tests/models/jax/test_weight_loading.py +180 -0
  84. tests/models/jax/utils/__init__.py +13 -0
  85. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  86. tests/platforms/__init__.py +13 -0
  87. tests/platforms/test_tpu_platform.py +54 -0
  88. tests/runner/__init__.py +13 -0
  89. tests/runner/test_block_table.py +395 -0
  90. tests/runner/test_input_batch.py +226 -0
  91. tests/runner/test_kv_cache.py +220 -0
  92. tests/runner/test_kv_cache_manager.py +498 -0
  93. tests/runner/test_multimodal_manager.py +429 -0
  94. tests/runner/test_persistent_batch_manager.py +84 -0
  95. tests/runner/test_speculative_decoding_manager.py +368 -0
  96. tests/runner/test_structured_decoding_manager.py +220 -0
  97. tests/runner/test_tpu_runner.py +261 -0
  98. tests/runner/test_tpu_runner_dp.py +1099 -0
  99. tests/runner/test_tpu_runner_mesh.py +200 -0
  100. tests/runner/test_utils.py +411 -0
  101. tests/spec_decode/__init__.py +13 -0
  102. tests/spec_decode/test_eagle3.py +311 -0
  103. tests/test_base.py +14 -0
  104. tests/test_tpu_info.py +14 -0
  105. tests/test_utils.py +1 -43
  106. tests/worker/__init__.py +13 -0
  107. tests/worker/tpu_worker_test.py +414 -0
  108. tpu_inference/__init__.py +14 -0
  109. tpu_inference/core/__init__.py +13 -0
  110. tpu_inference/core/sched/__init__.py +13 -0
  111. tpu_inference/core/sched/dp_scheduler.py +372 -56
  112. tpu_inference/distributed/__init__.py +13 -0
  113. tpu_inference/distributed/jax_parallel_state.py +14 -0
  114. tpu_inference/distributed/tpu_connector.py +14 -9
  115. tpu_inference/distributed/utils.py +56 -4
  116. tpu_inference/executors/__init__.py +13 -0
  117. tpu_inference/executors/ray_distributed_executor.py +20 -3
  118. tpu_inference/experimental/__init__.py +13 -0
  119. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  120. tpu_inference/kernels/__init__.py +13 -0
  121. tpu_inference/kernels/collectives/__init__.py +13 -0
  122. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  123. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  124. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  125. tpu_inference/kernels/fused_moe/v1/kernel.py +171 -163
  126. tpu_inference/kernels/megablox/__init__.py +13 -0
  127. tpu_inference/kernels/megablox/common.py +54 -0
  128. tpu_inference/kernels/megablox/gmm.py +646 -0
  129. tpu_inference/kernels/mla/__init__.py +13 -0
  130. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  131. tpu_inference/kernels/mla/v1/kernel.py +20 -26
  132. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  133. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  134. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  135. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  136. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +112 -69
  137. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +85 -65
  138. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  139. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +374 -194
  140. tpu_inference/kernels/ragged_paged_attention/v3/util.py +13 -0
  141. tpu_inference/layers/__init__.py +13 -0
  142. tpu_inference/layers/common/__init__.py +13 -0
  143. tpu_inference/layers/common/attention_interface.py +26 -19
  144. tpu_inference/layers/common/attention_metadata.py +14 -0
  145. tpu_inference/layers/common/fused_moe_gmm.py +506 -0
  146. tpu_inference/layers/common/quant_methods.py +15 -0
  147. tpu_inference/layers/common/quantization.py +282 -0
  148. tpu_inference/layers/common/sharding.py +22 -3
  149. tpu_inference/layers/common/utils.py +94 -0
  150. tpu_inference/layers/jax/__init__.py +13 -0
  151. tpu_inference/layers/jax/attention/__init__.py +13 -0
  152. tpu_inference/layers/jax/attention/attention.py +19 -6
  153. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +52 -27
  154. tpu_inference/layers/jax/attention/gpt_oss_attention.py +19 -6
  155. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  156. tpu_inference/layers/jax/base.py +14 -0
  157. tpu_inference/layers/jax/constants.py +13 -0
  158. tpu_inference/layers/jax/layers.py +14 -0
  159. tpu_inference/layers/jax/misc.py +14 -0
  160. tpu_inference/layers/jax/moe/__init__.py +13 -0
  161. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  162. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  163. tpu_inference/layers/jax/moe/moe.py +43 -3
  164. tpu_inference/layers/jax/pp_utils.py +53 -0
  165. tpu_inference/layers/jax/rope.py +14 -0
  166. tpu_inference/layers/jax/rope_interface.py +14 -0
  167. tpu_inference/layers/jax/sample/__init__.py +13 -0
  168. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  169. tpu_inference/layers/jax/sample/sampling.py +15 -1
  170. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  171. tpu_inference/layers/jax/transformer_block.py +14 -0
  172. tpu_inference/layers/vllm/__init__.py +13 -0
  173. tpu_inference/layers/vllm/attention.py +4 -4
  174. tpu_inference/layers/vllm/fused_moe.py +100 -455
  175. tpu_inference/layers/vllm/linear.py +64 -0
  176. tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
  177. tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
  178. tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
  179. tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
  180. tpu_inference/layers/vllm/quantization/__init__.py +19 -3
  181. tpu_inference/layers/vllm/quantization/awq.py +96 -82
  182. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  183. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +19 -5
  184. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +119 -132
  185. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  186. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
  187. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
  188. tpu_inference/layers/vllm/quantization/{common.py → configs.py} +38 -26
  189. tpu_inference/layers/vllm/quantization/fp8.py +119 -0
  190. tpu_inference/layers/vllm/quantization/mxfp4.py +133 -220
  191. tpu_inference/layers/vllm/quantization/unquantized.py +154 -253
  192. tpu_inference/lora/__init__.py +13 -0
  193. tpu_inference/lora/torch_lora_ops.py +8 -13
  194. tpu_inference/models/__init__.py +13 -0
  195. tpu_inference/models/common/__init__.py +13 -0
  196. tpu_inference/models/common/model_loader.py +37 -16
  197. tpu_inference/models/jax/__init__.py +13 -0
  198. tpu_inference/models/jax/deepseek_v3.py +113 -124
  199. tpu_inference/models/jax/gpt_oss.py +23 -7
  200. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  201. tpu_inference/models/jax/llama3.py +99 -36
  202. tpu_inference/models/jax/llama4.py +14 -0
  203. tpu_inference/models/jax/llama_eagle3.py +14 -0
  204. tpu_inference/models/jax/llama_guard_4.py +15 -1
  205. tpu_inference/models/jax/qwen2.py +17 -2
  206. tpu_inference/models/jax/qwen2_5_vl.py +18 -4
  207. tpu_inference/models/jax/qwen3.py +17 -2
  208. tpu_inference/models/jax/utils/__init__.py +13 -0
  209. tpu_inference/models/jax/utils/file_utils.py +14 -0
  210. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  211. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  212. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +85 -24
  213. tpu_inference/models/jax/utils/weight_utils.py +32 -1
  214. tpu_inference/models/vllm/__init__.py +13 -0
  215. tpu_inference/models/vllm/vllm_model_wrapper.py +22 -4
  216. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  217. tpu_inference/platforms/__init__.py +14 -0
  218. tpu_inference/platforms/tpu_platform.py +27 -29
  219. tpu_inference/runner/__init__.py +13 -0
  220. tpu_inference/runner/compilation_manager.py +69 -35
  221. tpu_inference/runner/kv_cache.py +14 -0
  222. tpu_inference/runner/kv_cache_manager.py +15 -2
  223. tpu_inference/runner/lora_utils.py +16 -1
  224. tpu_inference/runner/multimodal_manager.py +16 -2
  225. tpu_inference/runner/persistent_batch_manager.py +14 -0
  226. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  227. tpu_inference/runner/structured_decoding_manager.py +14 -0
  228. tpu_inference/runner/tpu_runner.py +30 -10
  229. tpu_inference/spec_decode/__init__.py +13 -0
  230. tpu_inference/spec_decode/jax/__init__.py +13 -0
  231. tpu_inference/spec_decode/jax/eagle3.py +13 -0
  232. tpu_inference/tpu_info.py +14 -0
  233. tpu_inference/utils.py +31 -30
  234. tpu_inference/worker/__init__.py +13 -0
  235. tpu_inference/worker/tpu_worker.py +23 -7
  236. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +1 -1
  237. tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
  238. tpu_inference/layers/vllm/linear_common.py +0 -208
  239. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  240. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  241. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  242. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  243. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  244. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  245. tpu_inference-0.12.0.dev20251213.dist-info/RECORD +0 -175
  246. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
  247. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
  248. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
@@ -1,19 +1,29 @@
- from typing import Any, Optional, Union
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Any, Optional

  import jax
  import jax.numpy as jnp
  import torch
- from jax.experimental.layout import Format, Layout
  from jax.sharding import Mesh, NamedSharding, PartitionSpec
  from torch.nn.parameter import Parameter
  from torchax.interop import jax_view, torch_view
  from torchax.ops.mappings import t2j
  from vllm.attention.layer import Attention
- from vllm.logger import init_logger
  from vllm.model_executor.layers.fused_moe.layer import (
      FusedMoE, FusedMoEConfig, UnquantizedFusedMoEMethod)
- from vllm.model_executor.layers.fused_moe.modular_kernel import (
-     FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize)
  from vllm.model_executor.layers.linear import (LinearBase,
                                                 UnquantizedLinearMethod)
  from vllm.model_executor.layers.quantization import \
@@ -21,27 +31,31 @@ from vllm.model_executor.layers.quantization import \
  from vllm.model_executor.layers.quantization.base_config import (
      QuantizationConfig, QuantizeMethodBase)

- from tpu_inference import envs
- from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe
  from tpu_inference.layers.common.quant_methods import (UNQUANTIZED,
                                                          get_tpu_quant_method)
- from tpu_inference.layers.vllm.fused_moe import fused_moe_func
- from tpu_inference.layers.vllm.linear_common import (
-     reorder_concatenated_tensor_for_sharding,
-     slice_sharded_tensor_for_concatenation, torch_to_jax_param)
- from tpu_inference.layers.vllm.quantization.common import (
-     JaxCommonConfig, JaxCommonLinearConfig)
+ from tpu_inference.layers.common.sharding import ShardingAxisName
+ from tpu_inference.layers.common.utils import \
+     slice_sharded_tensor_for_concatenation
+ from tpu_inference.layers.vllm.fused_moe import (FusedMoEBackend,
+                                                  fused_moe_apply,
+                                                  select_moe_backend)
+ from tpu_inference.layers.vllm.process_weights.fused_moe_weights import (
+     FusedMoEWeights, process_moe_weights, shard_moe_weights)
+ from tpu_inference.layers.vllm.process_weights.linear_weights import (
+     LinearWeights, process_lienar_weights, shard_linear_weights,
+     to_parameter_list)
+ from tpu_inference.layers.vllm.quantization.configs import (
+     VllmQuantConfig, VllmQuantLinearConfig)
+ from tpu_inference.logger import init_logger
+ from tpu_inference.utils import get_mesh_shape_product

  P = PartitionSpec
- logger = init_logger(__name__)
-

- def align_to(a, b):
-     return (a + b - 1) // b * b
+ logger = init_logger(__name__)


  @register_quantization_config(get_tpu_quant_method(UNQUANTIZED))
- class VllmUnquantizedConfig(QuantizationConfig, JaxCommonConfig):
+ class VllmUnquantizedConfig(QuantizationConfig, VllmQuantConfig):

      @classmethod
      def get_name(cls) -> str:
@@ -78,35 +92,54 @@ class VllmUnquantizedConfig(QuantizationConfig, JaxCommonConfig):

  class VllmUnquantizedLinearMethod(UnquantizedLinearMethod):

-     def __init__(self, jax_config: JaxCommonLinearConfig):
-         self.jax_config = jax_config
+     def __init__(self, linear_config: VllmQuantLinearConfig):
+         self.linear_config = linear_config

      def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-         weight = torch_to_jax_param(
-             layer.weight,
-             NamedSharding(self.jax_config.mesh,
-                           self.jax_config.weight_sharding),
-             self.jax_config.output_sizes,
-             self.jax_config.n_shards,
-             self.jax_config.fuse_matmuls,
-         )
+         weight = t2j(layer.weight, use_dlpack=False)
          delattr(layer, "weight")
-         layer.weight = weight
-
          if layer.bias is not None and not layer.skip_bias_add:
              if layer.return_bias:
                  logger.warning_once("Bias might return incorrect value.")
-
-             bias = torch_to_jax_param(
-                 layer.bias,
-                 NamedSharding(self.jax_config.mesh,
-                               self.jax_config.bias_sharding),
-                 self.jax_config.output_sizes,
-                 self.jax_config.n_shards,
-                 self.jax_config.fuse_matmuls,
-             )
+             bias = t2j(layer.bias, use_dlpack=False)
              delattr(layer, "bias")
-             layer.bias = bias
+         else:
+             bias = None
+
+         @jax.jit
+         def process_unquantized_linear_weights(
+             weight: jax.Array,
+             bias: jax.Array | None,
+         ) -> LinearWeights:
+             return process_lienar_weights(
+                 LinearWeights(
+                     weight=weight,
+                     weight_scale=None,
+                     zero_point=None,
+                     bias=bias,
+                 ),
+                 fused=self.linear_config.fuse_matmuls,
+                 output_sizes=self.linear_config.output_sizes,
+                 reorder_size=self.linear_config.n_shards,
+             )
+
+         weights = process_unquantized_linear_weights(weight, bias)
+         weights = torch_view(
+             shard_linear_weights(
+                 weights,
+                 mesh=self.linear_config.mesh,
+                 weight_p_spec=self.linear_config.weight_sharding,
+                 bias_p_spec=self.linear_config.bias_sharding,
+             ))
+
+         if self.linear_config.fuse_matmuls:
+             layer.weight = Parameter(weights.weight, requires_grad=False)
+             if bias is not None:
+                 layer.bias = Parameter(weights.bias, requires_grad=False)
+         else:
+             layer.weight = to_parameter_list(weights.weight)
+             if bias is not None:
+                 layer.bias = to_parameter_list(weights.bias)

      def apply(self,
                layer: torch.nn.Module,
@@ -115,16 +148,17 @@ class VllmUnquantizedLinearMethod(UnquantizedLinearMethod):
          assert isinstance(layer, LinearBase)

          with jax.named_scope(layer._get_name()):
-             if in_sharding := self.jax_config.get_input_sharding(x):
-                 x.shard_(NamedSharding(self.jax_config.mesh, in_sharding))
+             if in_sharding := self.linear_config.get_input_sharding(x):
+                 x.shard_(NamedSharding(self.linear_config.mesh, in_sharding))

-             if self.jax_config.fuse_matmuls:
+             if self.linear_config.fuse_matmuls:
                  out = self._apply_fused(layer, x, bias)
              else:
                  out = self._apply_split(layer, x, bias)

-             if out_sharding := self.jax_config.get_output_sharding(out):
-                 out.shard_(NamedSharding(self.jax_config.mesh, out_sharding))
+             if out_sharding := self.linear_config.get_output_sharding(out):
+                 out.shard_(NamedSharding(self.linear_config.mesh,
+                                          out_sharding))

          return out

@@ -140,7 +174,7 @@ class VllmUnquantizedLinearMethod(UnquantizedLinearMethod):
              outs += bias.jax()

          outs = slice_sharded_tensor_for_concatenation(
-             outs, self.jax_config.output_sizes, self.jax_config.n_shards)
+             outs, self.linear_config.output_sizes, self.linear_config.n_shards)
          out = jnp.concatenate(outs, axis=-1)
          return torch_view(out)

@@ -166,232 +200,99 @@ class VllmUnquantizedLinearMethod(UnquantizedLinearMethod):

  class VllmUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):

-     def __init__(self,
-                  moe: FusedMoEConfig,
-                  mesh: Mesh,
-                  ep_axis_name: str = 'model'):
-         super().__init__(moe)
-         self.mesh = mesh
-         self.use_kernel = envs.USE_MOE_EP_KERNEL and moe.use_ep
-         self.ep_axis_name = ep_axis_name
-         # TODO: Use autotune table once we have it.
-         self.block_size = {
-             "bt": 64,
-             "bf": 1024,
-             "bd1": 1536,
-             "bd2": 1536,
-             "btc": 64,
-             "bfc": 1024,
-             "bd1c": 1536,
-             "bd2c": 1536,
-         }
-
-     def select_gemm_impl(
+     def __init__(
          self,
-         prepare_finalize: FusedMoEPrepareAndFinalize,
          moe: FusedMoEConfig,
-         layer: torch.nn.Module,
-     ) -> FusedMoEPermuteExpertsUnpermute:
-         raise NotImplementedError(
-             "Selecting gemm implementation is currently not supported.")
+         mesh: Mesh,
+         ep_axis_name: str = "model",
+     ):
+         super().__init__(moe)
+         self.mesh = mesh
+         self.moe_backend = select_moe_backend(self.moe)
+
+         self.extra_backend_kwargs = {}
+         if self.moe_backend == FusedMoEBackend.FUSED_MOE:
+             # When fused moe kernle is used, we pass extra arguments like
+             # tuned block sizes to the kernel.
+             self.extra_backend_kwargs = dict(
+                 ep_axis_name=ep_axis_name,
+                 # TODO: Use autotune table once we have it.
+                 bt=64,
+                 bf=1024,
+                 bd1=1536,
+                 bd2=1536,
+                 btc=64,
+                 bfc=1024,
+                 bd1c=1536,
+                 bd2c=1536,
+             )

      def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
          assert isinstance(layer, FusedMoE)
+
          w13_weight = t2j(layer.w13_weight, use_dlpack=False)
          w2_weight = t2j(layer.w2_weight, use_dlpack=False)

-         num_experts, hidden_size, intermediate_size = w2_weight.shape
-
          if self.moe.has_bias:
              w13_bias = t2j(layer.w13_bias, use_dlpack=False)
              w2_bias = t2j(layer.w2_bias, use_dlpack=False)
-
-         if layer.activation == "swigluoai":
-             # When using swigluoai, vLLM splits gmm output in a interleaved way.
-             # However, interleaved split is not performant on TPU. Therefore,
-             # we preprocess the weight so that splitting gmm output by middle
-             # can still get the same result.
-             w1_weight = w13_weight[:, ::2, :]
-             w3_weight = w13_weight[:, 1::2, :]
-             w13_weight = jnp.concat([w1_weight, w3_weight], axis=1)
-
-             if self.moe.has_bias:
-                 w1_bias = w13_bias[:, ::2]
-                 w3_bias = w13_bias[:, 1::2]
-                 w13_bias = jnp.concat([w1_bias, w3_bias], axis=1)
-
-         if self.use_kernel:
-             # Kernel expects:
-             # w13: (num_experts, 2, hidden_size, intermediate_size)
-             # w2: (num_experts, intermediate_size, hidden_size)
-             # Current format:
-             # w13_weight: (num_experts, 2*intermediate_size, hidden_size)
-             # w2_weight: (num_experts, hidden_size, intermediate_size)
-             num_experts = w13_weight.shape[0]
-             intermediate_size = w13_weight.shape[1] // 2
-             hidden_size = w13_weight.shape[2]
-
-             padded_intermediate_size = align_to(intermediate_size, 256)
-             padded_hidden_size = align_to(hidden_size, 256)
-
-             w13_weight = w13_weight.reshape(num_experts, 2, intermediate_size,
-                                             hidden_size)
-             w13_weight = jnp.transpose(w13_weight, (0, 1, 3, 2))
-
-             # Transpose w2_weight to (num_experts, intermediate_size, hidden_size)
-             w2_weight = jnp.transpose(w2_weight, (0, 2, 1))
-
-             w13_weight = jnp.pad(
-                 w13_weight,
-                 ((0, 0), (0, 0), (0, padded_hidden_size - hidden_size),
-                  (0, padded_intermediate_size - intermediate_size)),
-                 constant_values=0)
-
-             w2_weight = jnp.pad(
-                 w2_weight,
-                 ((0, 0), (0, padded_intermediate_size - intermediate_size),
-                  (0, padded_hidden_size - hidden_size)),
-                 constant_values=0)
-
-             # Apply EP sharding
-             ep_sharding = NamedSharding(self.mesh, P("model"))
-
-             w13_weight = jax.device_put(
-                 w13_weight,
-                 Format(Layout((0, 1, 2, 3)),
-                        NamedSharding(self.mesh, P("model", None, None, None))))
-             w2_weight = jax.device_put(
-                 w2_weight,
-                 Format(Layout((0, 1, 2)),
-                        NamedSharding(self.mesh, P("model", None, None))))
-
-             if self.moe.has_bias:
-                 w13_bias = w13_bias.astype(jnp.float32).reshape(
-                     num_experts, 2, 1, intermediate_size)
-                 w2_bias = w2_bias.astype(jnp.float32).reshape(
-                     num_experts, 1, hidden_size)
-
-                 w13_bias = jnp.pad(
-                     w13_bias,
-                     ((0, 0), (0, 0), (0, 0),
-                      (0, padded_intermediate_size - intermediate_size)),
-                     constant_values=0)
-
-                 w2_bias = jnp.pad(w2_bias,
-                                   ((0, 0), (0, 0),
-                                    (0, padded_hidden_size - hidden_size)),
-                                   constant_values=0)
-
-                 # Apply EP sharding
-                 w13_bias = jax.device_put(
-                     w13_bias, Format(Layout((0, 1, 2, 3)), ep_sharding))
-                 w2_bias = jax.device_put(
-                     w2_bias, Format(Layout((0, 1, 2)), ep_sharding))
          else:
+             w13_bias = w2_bias = None
+
+         @jax.jit
+         def process_unquantized_moe_weights(
+             w13_weight: jax.Array,
+             w13_bias: jax.Array | None,
+             w2_weight: jax.Array,
+             w2_bias: jax.Array | None,
+         ) -> FusedMoEWeights:
+
+             w13_interleave = layer.activation == "swigluoai"
+             w13_reorder_size = get_mesh_shape_product(
+                 self.mesh, ShardingAxisName.MLP_TENSOR)
+
+             return process_moe_weights(
+                 FusedMoEWeights(
+                     w13_weight=w13_weight,
+                     w13_weight_scale=None,
+                     w13_bias=w13_bias,
+                     w2_weight=w2_weight,
+                     w2_weight_scale=None,
+                     w2_bias=w2_bias,
+                 ),
+                 moe_backend=self.moe_backend,
+                 w13_reorder_size=w13_reorder_size,
+                 w13_interleave=w13_interleave,
+             )

-             if layer.use_ep:
-                 ep_sharding = NamedSharding(self.mesh, P("model"))
-                 w13_weight = jax.device_put(
-                     w13_weight, Format(Layout((0, 1, 2)), ep_sharding))
-                 w2_weight = jax.device_put(
-                     w2_weight, Format(Layout((0, 1, 2)), ep_sharding))
+         weights = process_unquantized_moe_weights(
+             w13_weight,
+             w13_bias,
+             w2_weight,
+             w2_bias,
+         )
+         weights = torch_view(
+             shard_moe_weights(weights, self.moe_backend, self.mesh))

-                 if self.moe.has_bias:
-                     w13_bias = jax.device_put(
-                         w13_bias, Format(Layout((0, 1)), ep_sharding))
-                     w2_bias = jax.device_put(
-                         w2_bias, Format(Layout((0, 1)), ep_sharding))
-
-             else:
-                 output_sizes = [intermediate_size, intermediate_size]
-                 n_shards = self.mesh.shape["model"]
-                 assert intermediate_size % n_shards == 0
-
-                 w13_weight = reorder_concatenated_tensor_for_sharding(
-                     w13_weight, output_sizes, n_shards, dim=1)
-                 w13_weight = jax.device_put(
-                     w13_weight,
-                     Format(Layout((0, 1, 2)),
-                            NamedSharding(self.mesh, P(None, "model", None))))
-                 w2_weight = jax.device_put(
-                     w2_weight,
-                     Format(Layout((0, 1, 2)),
-                            NamedSharding(self.mesh, P(None, None, "model"))))
-
-                 if self.moe.has_bias:
-                     w13_bias = reorder_concatenated_tensor_for_sharding(
-                         w13_bias, output_sizes, n_shards, dim=1)
-                     w13_bias = jax.device_put(
-                         w13_bias,
-                         Format(Layout((0, 1)),
-                                NamedSharding(self.mesh, P(None, "model"))))
-                     w2_bias = jax.device_put(
-                         w2_bias,
-                         Format(Layout((0, 1)),
-                                NamedSharding(self.mesh, P(None, None))))
-
-         layer.w13_weight = Parameter(torch_view(w13_weight),
-                                      requires_grad=False)
-         layer.w2_weight = Parameter(torch_view(w2_weight), requires_grad=False)
+         layer.w13_weight = Parameter(weights.w13_weight, requires_grad=False)
+         layer.w2_weight = Parameter(weights.w2_weight, requires_grad=False)

          if self.moe.has_bias:
-             layer.w13_bias = Parameter(torch_view(w13_bias),
-                                        requires_grad=False)
-             layer.w2_bias = Parameter(torch_view(w2_bias), requires_grad=False)
+             layer.w13_bias = Parameter(weights.w13_bias, requires_grad=False)
+             layer.w2_bias = Parameter(weights.w2_bias, requires_grad=False)

      def apply(
          self,
          layer: torch.nn.Module,
          x: torch.Tensor,
          router_logits: torch.Tensor,
-     ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
-         assert isinstance(layer, FusedMoE)
-         if layer.scoring_func != "softmax":
-             raise NotImplementedError(
-                 "Only softmax is supported for scoring_func")
-
-         x = jax_view(x)
-         w13_weight = jax_view(layer.w13_weight)
-         w2_weight = jax_view(layer.w2_weight)
-         w13_bias = w2_bias = None
-         if self.moe.has_bias:
-             w13_bias = jax_view(layer.w13_bias)
-             w2_bias = jax_view(layer.w2_bias)
-         gating_output = jax_view(router_logits)
-
-         if self.use_kernel:
-             actual_hidden_size = x.shape[-1]
-             padded_hidden_size = align_to(actual_hidden_size, 256)
-             x = jnp.pad(x,
-                         ((0, 0), (0, padded_hidden_size - actual_hidden_size)),
-                         constant_values=0)
-             output = fused_ep_moe(
-                 mesh=self.mesh,
-                 tokens=x,
-                 w1=w13_weight,
-                 w2=w2_weight,
-                 b1=w13_bias,
-                 b2=w2_bias,
-                 gating_output=gating_output,
-                 top_k=layer.top_k,
-                 ep_axis_name=self.ep_axis_name,
-                 renormalize_topk_logits=layer.renormalize,
-                 act_fn=layer.activation,
-                 **self.block_size,
-             )[:, :actual_hidden_size]
-         else:
-             output = fused_moe_func(
-                 hidden_states=x,
-                 w1=w13_weight,
-                 w2=w2_weight,
-                 w1_bias=w13_bias,
-                 w2_bias=w2_bias,
-                 gating_output=gating_output,
-                 topk=layer.top_k,
-                 renormalize=layer.renormalize,
-                 mesh=self.mesh,
-                 use_ep=layer.use_ep,
-                 activation=layer.activation,
-             )
-
-         return torch_view(output)
+     ) -> torch.Tensor:
+
+         return fused_moe_apply(
+             layer,
+             x,
+             router_logits,
+             self.moe_backend,
+             self.mesh,
+             self.extra_backend_kwargs,
+         )
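
The swigluoai branch removed in the hunk above de-interleaved the fused gate/up (w13) expert weights at load time, because splitting the GMM output in an interleaved way is slow on TPU; the new code delegates the same reordering to process_moe_weights(..., w13_interleave=True). A minimal, self-contained sketch of that reordering (editor's illustration only, not code from the package; the shapes and the helper name deinterleave_w13 are made up):

import jax.numpy as jnp

def deinterleave_w13(w13_weight):
    # Reorder interleaved rows [w1, w3, w1, w3, ...] along axis 1 into two
    # contiguous blocks [w1 ..., w3 ...], so the GMM output can be split down
    # the middle instead of with a strided (interleaved) slice.
    w1 = w13_weight[:, ::2, :]   # rows produced by the gate projection
    w3 = w13_weight[:, 1::2, :]  # rows produced by the up projection
    return jnp.concatenate([w1, w3], axis=1)

# Toy check: 2 experts, fused intermediate size 8, hidden size 3.
w13 = jnp.arange(2 * 8 * 3, dtype=jnp.float32).reshape(2, 8, 3)
reordered = deinterleave_w13(w13)
assert reordered.shape == (2, 8, 3)
assert (reordered[0, 0] == w13[0, 0]).all()  # first gate row stays first
assert (reordered[0, 4] == w13[0, 1]).all()  # first up row moves to the middle
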
@@ -0,0 +1,13 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
@@ -4,7 +4,6 @@
  import jax
  import jax.numpy as jnp
  import torch
- import torch.nn.functional as F
  from torchax.interop import call_jax

@@ -85,19 +84,15 @@ def bgmv_expand_slice(
          add_inputs (bool): Whether or not to add the input tensor to the output
            tensor.
      """
-     outputs = bgmv_torch(inputs, lora_b_weights, lora_indices_tensor)
+     outputs = bgmv_torch(inputs, lora_b_weights,
+                          lora_indices_tensor)  # [num_tokens, out_features]

-     outputs = F.pad(
-         outputs,
-         (
-             slice_offset,
-             output_tensor.shape[1] - (slice_offset + slice_size),
-             0,
-             0,
-         ),
-     )
+     # Create a padded tensor manually to avoid issues with F.pad on sharded tensors.
+     # This is a more robust way to handle padding in a distributed environment.
+     outputs_padded = torch.zeros_like(output_tensor)
+     outputs_padded[:, slice_offset:slice_offset + slice_size] = outputs

      if add_inputs:
-         return output_tensor + outputs
+         return output_tensor + outputs_padded
      else:
-         return outputs
+         return outputs_padded
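
The bgmv_expand_slice hunk above replaces F.pad with a manually built zero tensor before writing the LoRA slice, per the in-diff comment about F.pad on sharded tensors. A small sketch of the same padding pattern on plain torch tensors (editor's illustration only; the helper name pad_slice_into is made up, and in the package the tensors are torchax views rather than plain torch tensors):

import torch

def pad_slice_into(output_like, outputs, slice_offset, slice_size):
    # Allocate a zero tensor shaped like the full output, then write the slice
    # into its column range [slice_offset, slice_offset + slice_size).
    padded = torch.zeros_like(output_like)
    padded[:, slice_offset:slice_offset + slice_size] = outputs
    return padded

output_tensor = torch.ones(2, 10)
lora_out = torch.full((2, 4), 2.0)
padded = pad_slice_into(output_tensor, lora_out, slice_offset=3, slice_size=4)
assert torch.equal(padded[:, 3:7], lora_out)
assert padded[:, :3].abs().sum() == 0  # columns outside the slice stay zero
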
@@ -0,0 +1,13 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
@@ -0,0 +1,13 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.