tpu-inference 0.11.1.dev202512030818__py3-none-any.whl → 0.13.0rc2.post7__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (250)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +14 -0
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +405 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +312 -0
  64. tests/layers/vllm/test_unquantized.py +651 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +21 -3
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +78 -1
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +1 -43
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +14 -9
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +38 -7
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +17 -0
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +95 -78
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +26 -19
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/quant_methods.py +15 -0
  154. tpu_inference/layers/common/quantization.py +270 -0
  155. tpu_inference/layers/common/sharding.py +28 -5
  156. tpu_inference/layers/jax/__init__.py +13 -0
  157. tpu_inference/layers/jax/attention/__init__.py +13 -0
  158. tpu_inference/layers/jax/attention/attention.py +19 -6
  159. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  160. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  161. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  162. tpu_inference/layers/jax/base.py +14 -0
  163. tpu_inference/layers/jax/constants.py +13 -0
  164. tpu_inference/layers/jax/layers.py +14 -0
  165. tpu_inference/layers/jax/misc.py +14 -0
  166. tpu_inference/layers/jax/moe/__init__.py +13 -0
  167. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  168. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  169. tpu_inference/layers/jax/moe/moe.py +43 -3
  170. tpu_inference/layers/jax/pp_utils.py +53 -0
  171. tpu_inference/layers/jax/rope.py +14 -0
  172. tpu_inference/layers/jax/rope_interface.py +14 -0
  173. tpu_inference/layers/jax/sample/__init__.py +13 -0
  174. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  175. tpu_inference/layers/jax/sample/sampling.py +15 -1
  176. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  177. tpu_inference/layers/jax/transformer_block.py +14 -0
  178. tpu_inference/layers/vllm/__init__.py +13 -0
  179. tpu_inference/layers/vllm/attention.py +4 -4
  180. tpu_inference/layers/vllm/fused_moe.py +210 -260
  181. tpu_inference/layers/vllm/linear_common.py +57 -22
  182. tpu_inference/layers/vllm/quantization/__init__.py +16 -0
  183. tpu_inference/layers/vllm/quantization/awq.py +15 -1
  184. tpu_inference/layers/vllm/quantization/common.py +33 -18
  185. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  186. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
  187. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
  188. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  189. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
  191. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  192. tpu_inference/layers/vllm/quantization/mxfp4.py +278 -209
  193. tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
  194. tpu_inference/layers/vllm/sharding.py +21 -4
  195. tpu_inference/lora/__init__.py +13 -0
  196. tpu_inference/lora/torch_lora_ops.py +8 -13
  197. tpu_inference/models/__init__.py +13 -0
  198. tpu_inference/models/common/__init__.py +13 -0
  199. tpu_inference/models/common/model_loader.py +74 -35
  200. tpu_inference/models/jax/__init__.py +13 -0
  201. tpu_inference/models/jax/deepseek_v3.py +267 -157
  202. tpu_inference/models/jax/gpt_oss.py +26 -10
  203. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  204. tpu_inference/models/jax/llama3.py +99 -36
  205. tpu_inference/models/jax/llama4.py +14 -0
  206. tpu_inference/models/jax/llama_eagle3.py +14 -0
  207. tpu_inference/models/jax/llama_guard_4.py +15 -1
  208. tpu_inference/models/jax/qwen2.py +17 -2
  209. tpu_inference/models/jax/qwen2_5_vl.py +18 -4
  210. tpu_inference/models/jax/qwen3.py +17 -2
  211. tpu_inference/models/jax/utils/__init__.py +13 -0
  212. tpu_inference/models/jax/utils/file_utils.py +14 -0
  213. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  214. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  215. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +88 -25
  216. tpu_inference/models/jax/utils/weight_utils.py +39 -2
  217. tpu_inference/models/vllm/__init__.py +13 -0
  218. tpu_inference/models/vllm/vllm_model_wrapper.py +20 -3
  219. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  220. tpu_inference/platforms/__init__.py +14 -0
  221. tpu_inference/platforms/tpu_platform.py +47 -64
  222. tpu_inference/runner/__init__.py +13 -0
  223. tpu_inference/runner/compilation_manager.py +72 -37
  224. tpu_inference/runner/kv_cache.py +54 -20
  225. tpu_inference/runner/kv_cache_manager.py +45 -15
  226. tpu_inference/runner/lora_utils.py +14 -0
  227. tpu_inference/runner/multimodal_manager.py +15 -1
  228. tpu_inference/runner/persistent_batch_manager.py +14 -0
  229. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  230. tpu_inference/runner/structured_decoding_manager.py +14 -0
  231. tpu_inference/runner/tpu_runner.py +41 -16
  232. tpu_inference/spec_decode/__init__.py +13 -0
  233. tpu_inference/spec_decode/jax/__init__.py +13 -0
  234. tpu_inference/spec_decode/jax/eagle3.py +13 -0
  235. tpu_inference/tpu_info.py +14 -0
  236. tpu_inference/utils.py +42 -36
  237. tpu_inference/worker/__init__.py +13 -0
  238. tpu_inference/worker/tpu_worker.py +63 -50
  239. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/METADATA +11 -9
  240. tpu_inference-0.13.0rc2.post7.dist-info/RECORD +261 -0
  241. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  242. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  243. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  244. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  245. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  246. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  247. tpu_inference-0.11.1.dev202512030818.dist-info/RECORD +0 -174
  248. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/WHEEL +0 -0
  249. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/licenses/LICENSE +0 -0
  250. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/top_level.txt +0 -0
tpu_inference/layers/vllm/quantization/mxfp4.py

@@ -1,4 +1,18 @@
-from typing import Callable, Optional, Union
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Union
 
 import jax
 import jax.numpy as jnp
@@ -10,7 +24,7 @@ from torchax.interop import jax_view, torch_view
 from torchax.ops.mappings import t2j
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.config import (
-    FusedMoEConfig, FusedMoEQuantConfig, biased_moe_quant_config)
+    FusedMoEConfig, FusedMoEQuantConfig, mxfp4_w4a16_moe_quant_config)
 from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
                                                         FusedMoEMethodBase)
 from vllm.model_executor.layers.linear import LinearBase
@@ -28,44 +42,22 @@ from tpu_inference import envs
 from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe
 from tpu_inference.layers.common.quant_methods import (MXFP4,
                                                        get_tpu_quant_method)
-from tpu_inference.layers.vllm.fused_moe import fused_moe_func_padded
+from tpu_inference.layers.common.quantization import (
+    dequantize_tensor_from_mxfp4_packed, quantize_tensor)
+from tpu_inference.layers.common.sharding import ShardingAxisName
+from tpu_inference.layers.vllm.fused_moe import fused_moe_func
 from tpu_inference.layers.vllm.linear_common import \
     reorder_concatenated_tensor_for_sharding
 from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
 from tpu_inference.layers.vllm.quantization.unquantized import \
    VllmUnquantizedLinearMethod
+from tpu_inference.utils import get_mesh_shape_product
 
-MXFP4_BLOCK_SIZE = 32
+REQUANTIZED_BLOCK_SIZE = 512
 
 P = PartitionSpec
-logger = init_logger(__name__)
-
-
-# TODO(kyuyeunk): Move these functions into a common utility file.
-def u8_unpack_e2m1(u8_packed_e2m1: jax.Array) -> jax.Array:
-    assert u8_packed_e2m1.dtype == jnp.uint8
-    e2m1 = jax.lax.bitcast_convert_type(u8_packed_e2m1, jnp.float4_e2m1fn)
-    # bitcast creates one more dimension that splits 8 bits into two e2m1.
-    # we flatten them with the last dim.
-    return jnp.reshape(e2m1, e2m1.shape[:-2] + (-1, ))
-
-
-def e8m0_to_fp32(u8: jax.Array) -> jax.Array:
-    e8_finfo = jnp.finfo(jnp.float8_e8m0fnu)
-    exponents = u8.astype(jnp.int32) + e8_finfo.minexp
-    ones = jnp.ones_like(u8, dtype=jnp.float32)
-    return jnp.ldexp(ones, exponents)
 
-
-def dequantize_block_weight(weight: jax.Array,
-                            scale: jax.Array,
-                            block_size: int,
-                            out_dtype: jnp.dtype = jnp.bfloat16) -> jax.Array:
-    orig_shape = weight.shape
-    weight_block = weight.reshape(orig_shape[:-1] + (-1, block_size))
-    weight_dequantized = weight_block.astype(jnp.float32) * jnp.expand_dims(
-        scale, -1)
-    return weight_dequantized.reshape(orig_shape).astype(out_dtype)
+logger = init_logger(__name__)
 
 
 @register_quantization_config(get_tpu_quant_method(MXFP4))
@@ -87,9 +79,6 @@ class VllmMxfp4Config(Mxfp4Config, JaxCommonConfig):
                     fused_mapping=self.packed_modules_mapping,
             ):
                 return VllmUnquantizedLinearMethod(linear_config)
-            # TODO: Add support for MXFP4 Linear Method.
-            # MXFP4 LinearMethod is available in AMD-Quark, refer to that
-            # implementation if you are interested in enabling MXFP4 here.
            logger.warning_once(
                "MXFP4 linear layer is not implemented - falling back to "
                "UnquantizedLinearMethod.")
@@ -98,7 +87,6 @@ class VllmMxfp4Config(Mxfp4Config, JaxCommonConfig):
            moe_config = self.get_moe_config(layer)
            return VllmMxfp4MoEMethod(moe_config, self.mesh)
        elif isinstance(layer, Attention):
-            # TODO: Add support for MXFP4 Attention.
            logger.warning_once("MXFP4 attention layer is not implemented. "
                                "Skipping quantization for this layer.")
            return None
@@ -117,225 +105,306 @@ class VllmMxfp4MoEMethod(Mxfp4MoEMethod):
        self.mxfp4_backend = Mxfp4Backend.TRITON
 
        self.mesh = mesh
-        self.use_kernel = envs.USE_MOE_EP_KERNEL
+        self.use_kernel = envs.USE_MOE_EP_KERNEL and moe.use_ep
        self.ep_axis_name = ep_axis_name
        # TODO: Use autotune table once we have it.
        self.block_size = {
-            "bt": 64,
+            "bt": 256,
            "bf": 1024,
-            "bd1": 1536,
-            "bd2": 1536,
-            "btc": 64,
+            "bd1": 1024,
+            "bd2": 1024,
+            "btc": 256,
            "bfc": 1024,
-            "bd1c": 1536,
-            "bd2c": 1536,
+            "bd1c": 1024,
+            "bd2c": 1024,
        }
 
    def get_fused_moe_quant_config(
            self, layer: torch.nn.Module) -> FusedMoEQuantConfig | None:
-        # Because we have dequantized weights, we only need biased moe config.
-        # TODO(kyuyeunk): Add native support for MXFP4.
-        return biased_moe_quant_config(
-            layer.w13_bias,
-            layer.w2_bias,
+        return mxfp4_w4a16_moe_quant_config(
+            w1_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            w1_bias=layer.w13_bias,
+            w2_bias=layer.w2_bias,
        )
 
    def process_weights_after_loading(self, layer: torch.nn.Module):
        assert isinstance(layer, FusedMoE)
        assert layer.moe_config.has_bias, "mxfp4 quantization alwyas use bias."
 
-        w13_weight = u8_unpack_e2m1(t2j(layer.w13_weight, use_dlpack=False))
-        w13_weight_scale = e8m0_to_fp32(
-            t2j(layer.w13_weight_scale, use_dlpack=False))
+        w13_weight = t2j(layer.w13_weight, use_dlpack=False)
+        w13_weight_scale = t2j(layer.w13_weight_scale, use_dlpack=False)
        w13_bias = t2j(layer.w13_bias, use_dlpack=False)
 
-        w2_weight = u8_unpack_e2m1(t2j(layer.w2_weight, use_dlpack=False))
-        w2_weight_scale = e8m0_to_fp32(
-            t2j(layer.w2_weight_scale, use_dlpack=False))
+        w2_weight = t2j(layer.w2_weight, use_dlpack=False)
+        w2_weight_scale = t2j(layer.w2_weight_scale, use_dlpack=False)
        w2_bias = t2j(layer.w2_bias, use_dlpack=False)
 
-        # We dequantize fp4 weights into bf16.
-        # TODO(kyuyeunk): Add native support for MXFP4.
-        w13_weight = dequantize_block_weight(w13_weight, w13_weight_scale,
-                                             MXFP4_BLOCK_SIZE, jnp.bfloat16)
-        w2_weight = dequantize_block_weight(w2_weight, w2_weight_scale,
-                                            MXFP4_BLOCK_SIZE, jnp.bfloat16)
-
-        # Because we have dequantized weights, scales are not used anymore.
-        delattr(layer, "w13_weight_scale")
-        delattr(layer, "w2_weight_scale")
-
-        if layer.activation == "swigluoai":
-            # When using swigluoai, vLLM splits gmm output in a interleaved way.
-            # However, interleaved split is not performant on TPU. Therefore,
-            # we preprocess the weight so that splitting gmm output by middle
-            # can still get the same result.
-            w1_weight = w13_weight[:, ::2, :]
-            w3_weight = w13_weight[:, 1::2, :]
-            w13_weight = jnp.concat([w1_weight, w3_weight], axis=1)
-
-            w1_bias = w13_bias[:, ::2]
-            w3_bias = w13_bias[:, 1::2]
-            w13_bias = jnp.concat([w1_bias, w3_bias], axis=1)
-
-        if self.use_kernel and layer.use_ep:
-            # Kernel expects:
-            # w13: (num_experts, 2, hidden_size, intermediate_size)
-            # w2: (num_experts, intermediate_size, hidden_size)
-            # Current format:
-            # w13_weight: (num_experts, 2*intermediate_size, hidden_size)
-            # w2_weight: (num_experts, hidden_size, intermediate_size)
-            num_experts = w13_weight.shape[0]
-            intermediate_size = w13_weight.shape[1] // 2
-            hidden_size = w13_weight.shape[2]
-
-            # Reshape and transpose w13_weight to (num_experts, 2, hidden_size, intermediate_size)
-            w13_reshaped = w13_weight.reshape(num_experts, 2,
-                                              intermediate_size, hidden_size)
-            w13_weight_transposed = jnp.transpose(w13_reshaped, (0, 1, 3, 2))
-
-            # Transpose w2_weight to (num_experts, intermediate_size, hidden_size)
-            w2_weight_transposed = jnp.transpose(w2_weight, (0, 2, 1))
-
-            # Apply EP sharding
-            w13_weight = jax.device_put(
-                w13_weight_transposed,
-                Format(Layout((0, 1, 2, 3)),
-                       NamedSharding(self.mesh, P("model", None, None, None))))
-            w2_weight = jax.device_put(
-                w2_weight_transposed,
-                Format(Layout((0, 1, 2)),
-                       NamedSharding(self.mesh, P("model", None, None))))
-
-            if self.moe.has_bias:
-                w13_bias = w13_bias.reshape(num_experts, 2, intermediate_size)
+        # Wrap functions in jit to speedup requantization.
+        @jax.jit
+        def wrapper(w13_weight, w13_weight_scale, w13_bias, w2_weight,
+                    w2_weight_scale, w2_bias):
+            # Dequantize fp4 weights into fp32.
+            w13_weight = dequantize_tensor_from_mxfp4_packed(
+                w13_weight, w13_weight_scale, 2)
+            w2_weight = dequantize_tensor_from_mxfp4_packed(
+                w2_weight, w2_weight_scale, 2)
+
+            num_experts, orig_hidden_size, orig_intermediate_size = w2_weight.shape
+
+            # Requantize the weights into TPU friendly block size.
+            w13_weight, w13_weight_scale = quantize_tensor(
+                jnp.float4_e2m1fn, w13_weight, 2, REQUANTIZED_BLOCK_SIZE, True)
+            w2_weight, w2_weight_scale = quantize_tensor(
+                jnp.float4_e2m1fn, w2_weight, 2, REQUANTIZED_BLOCK_SIZE, True)
+
+            intermediate_size = w2_weight.shape[-1]
+            hidden_size = w13_weight.shape[-1]
+
+            # Dims may have been padded to align with subchannel size during
+            # quantization. We pad the corresponding dim on other weight.
+            # NOTE: We perform padding after quantization as padding value can
+            # affect quantization numerics.
+            intermediate_padding_size = 2 * (intermediate_size -
+                                             orig_intermediate_size)
+            w13_weight = jnp.pad(w13_weight,
+                                 ((0, 0), (0, intermediate_padding_size),
+                                  (0, 0)))
+            w13_weight_scale = jnp.pad(w13_weight_scale,
+                                       ((0, 0), (0, intermediate_padding_size),
+                                        (0, 0)))
+            w13_bias = jnp.pad(w13_bias,
+                               ((0, 0), (0, intermediate_padding_size)))
+
+            hidden_padding_size = hidden_size - orig_hidden_size
+            w2_weight = jnp.pad(w2_weight,
+                                ((0, 0), (0, hidden_padding_size), (0, 0)))
+            w2_weight_scale = jnp.pad(w2_weight_scale,
+                                      ((0, 0), (0, hidden_padding_size),
+                                       (0, 0)))
+            w2_bias = jnp.pad(w2_bias, ((0, 0), (0, hidden_padding_size)))
+
+            if layer.activation == "swigluoai":
+                # When using swigluoai, vLLM splits gmm output in a interleaved way.
+                # However, interleaved split is not performant on TPU. Therefore,
+                # we preprocess the weight so that splitting gmm output by middle
+                # can still get the same result.
+                w1_weight = w13_weight[:, ::2, :]
+                w3_weight = w13_weight[:, 1::2, :]
+                w13_weight = jnp.concat([w1_weight, w3_weight], axis=1)
+
+                w1_weight_scale = w13_weight_scale[:, ::2, :]
+                w3_weight_scale = w13_weight_scale[:, 1::2, :]
+                w13_weight_scale = jnp.concat(
+                    [w1_weight_scale, w3_weight_scale], axis=1)
+
+                w1_bias = w13_bias[:, ::2]
+                w3_bias = w13_bias[:, 1::2]
+                w13_bias = jnp.concat([w1_bias, w3_bias], axis=1)
+
+            if self.use_kernel:
+                # Kernel expects:
+                # w13: (num_experts, 2, hidden_size, intermediate_size)
+                # w2: (num_experts, intermediate_size, hidden_size)
+                # Current format:
+                # w13_weight: (num_experts, 2*intermediate_size, hidden_size)
+                # w2_weight: (num_experts, hidden_size, intermediate_size)
+
+                w13_weight = w13_weight.reshape(num_experts, 2,
+                                                intermediate_size, hidden_size)
+
+                w13_weight_scale = w13_weight_scale.reshape(
+                    num_experts, 2, intermediate_size, 1, -1)
+                w2_weight_scale = w2_weight_scale.reshape(
+                    num_experts, hidden_size, 1, -1)
+
+                w13_bias = w13_bias.astype(jnp.float32).reshape(
+                    num_experts, 2, 1, intermediate_size)
+                w2_bias = w2_bias.astype(jnp.float32).reshape(
+                    num_experts, 1, hidden_size)
+
+                # Transpose non-constracting dim to right most dim
+                w13_weight = jnp.swapaxes(w13_weight, 2, 3)
+                w2_weight = jnp.swapaxes(w2_weight, 1, 2)
+
+                w13_weight_scale = jnp.swapaxes(w13_weight_scale, 2, 4)
+                w2_weight_scale = jnp.swapaxes(w2_weight_scale, 1, 3)
 
                # Apply EP sharding
-                w13_bias = jax.device_put(
-                    w13_bias,
-                    Format(Layout((0, 1, 2)),
-                           NamedSharding(self.mesh, P("model", None, None))))
-                w2_bias = jax.device_put(
-                    w2_bias,
-                    Format(Layout((0, 1)),
-                           NamedSharding(self.mesh, P("model", None))))
-
-        else:
-            if layer.use_ep:
-                w13_weight = jax.device_put(
-                    w13_weight,
-                    Format(Layout((0, 1, 2)),
-                           NamedSharding(self.mesh, P("model", None, None))))
-                w2_weight = jax.device_put(
-                    w2_weight,
-                    Format(Layout((0, 1, 2)),
-                           NamedSharding(self.mesh, P("model", None, None))))
-
-                w13_bias = jax.device_put(
-                    w13_bias,
-                    Format(Layout((0, 1)),
-                           NamedSharding(self.mesh, P("model", None))))
-                w2_bias = jax.device_put(
-                    w2_bias,
-                    Format(Layout((0, 1)),
-                           NamedSharding(self.mesh, P("model", None))))
-
+                ep_sharding = NamedSharding(self.mesh, P("model"))
+
+                w13_weight = jax.lax.with_sharding_constraint(
+                    w13_weight, Format(Layout((0, 1, 2, 3)), ep_sharding))
+                w2_weight = jax.lax.with_sharding_constraint(
+                    w2_weight, Format(Layout((0, 1, 2)), ep_sharding))
+
+                w13_weight_scale = jax.lax.with_sharding_constraint(
+                    w13_weight_scale,
+                    Format(Layout((0, 1, 2, 3, 4)), ep_sharding))
+                w2_weight_scale = jax.lax.with_sharding_constraint(
+                    w2_weight_scale, Format(Layout((0, 1, 2, 3)), ep_sharding))
+
+                w13_bias = jax.lax.with_sharding_constraint(
+                    w13_bias, Format(Layout((0, 1, 2, 3)), ep_sharding))
+                w2_bias = jax.lax.with_sharding_constraint(
+                    w2_bias, Format(Layout((0, 1, 2)), ep_sharding))
            else:
-                intermediate_size = w13_weight.shape[1] // 2
-                assert intermediate_size == w2_weight.shape[-1]
-                output_sizes = [intermediate_size, intermediate_size]
-                n_shards = self.mesh.shape["model"]
-                assert intermediate_size % n_shards == 0
-                w13_weight = reorder_concatenated_tensor_for_sharding(
-                    w13_weight, output_sizes, n_shards, dim=1)
-                w13_weight = jax.device_put(
-                    w13_weight,
-                    Format(Layout((0, 1, 2)),
-                           NamedSharding(self.mesh, P(None, "model", None))))
-                w2_weight = jax.device_put(
-                    w2_weight,
-                    Format(Layout((0, 1, 2)),
-                           NamedSharding(self.mesh, P(None, None, "model"))))
-
-                w13_bias = reorder_concatenated_tensor_for_sharding(
-                    w13_bias, output_sizes, n_shards, dim=1)
-                w13_bias = jax.device_put(
-                    w13_bias,
-                    Format(Layout((0, 1)),
-                           NamedSharding(self.mesh, P(None, "model"))))
-                w2_bias = jax.device_put(
-                    w2_bias,
-                    Format(Layout((0, 1)),
-                           NamedSharding(self.mesh, P(None, None))))
+                w13_weight_scale = jnp.swapaxes(w13_weight_scale, 1, 2)
+                w13_weight_scale = jnp.expand_dims(w13_weight_scale, 2)
+                w2_weight_scale = jnp.swapaxes(w2_weight_scale, 1, 2)
+                w2_weight_scale = jnp.expand_dims(w2_weight_scale, 2)
+
+                w13_bias = jnp.expand_dims(w13_bias, 1)
+                w2_bias = jnp.expand_dims(w2_bias, 1)
+
+                if layer.use_ep:
+                    ep_sharding = NamedSharding(self.mesh,
+                                                P(ShardingAxisName.EXPERT))
+
+                    w13_weight = jax.lax.with_sharding_constraint(
+                        w13_weight, ep_sharding)
+                    w2_weight = jax.lax.with_sharding_constraint(
+                        w2_weight, ep_sharding)
+
+                    w13_weight_scale = jax.lax.with_sharding_constraint(
+                        w13_weight_scale, ep_sharding)
+                    w2_weight_scale = jax.lax.with_sharding_constraint(
+                        w2_weight_scale, ep_sharding)
+
+                    w13_bias = jax.lax.with_sharding_constraint(
+                        w13_bias, ep_sharding)
+                    w2_bias = jax.lax.with_sharding_constraint(
+                        w2_bias, ep_sharding)
+
+                else:
+                    output_sizes = [intermediate_size, intermediate_size]
+                    n_shards = get_mesh_shape_product(
+                        self.mesh, ShardingAxisName.MLP_TENSOR)
+                    assert intermediate_size % n_shards == 0
+
+                    # Reorder w13 weights so that splitting w1 and w3 output
+                    # can happen locally without any collective operations.
+                    w13_weight = reorder_concatenated_tensor_for_sharding(
+                        w13_weight,
+                        output_sizes,
+                        n_shards,
+                        dim=1,
+                    )
+                    w13_weight_scale = reorder_concatenated_tensor_for_sharding(
+                        w13_weight_scale,
+                        output_sizes,
+                        n_shards,
+                        dim=3,
+                    )
+                    w13_bias = reorder_concatenated_tensor_for_sharding(
+                        w13_bias,
+                        output_sizes,
+                        n_shards,
+                        dim=2,
+                    )
+
+                    w13_weight = jax.lax.with_sharding_constraint(
+                        w13_weight,
+                        NamedSharding(
+                            self.mesh,
+                            P(None, ShardingAxisName.MLP_TENSOR, None)))
+                    w2_weight = jax.lax.with_sharding_constraint(
+                        w2_weight,
+                        NamedSharding(
+                            self.mesh,
+                            P(None, None, ShardingAxisName.MLP_TENSOR)))
+                    w13_weight_scale = jax.lax.with_sharding_constraint(
+                        w13_weight_scale,
+                        NamedSharding(
+                            self.mesh,
+                            P(None, None, None, ShardingAxisName.MLP_TENSOR)))
+                    w2_weight_scale = jax.lax.with_sharding_constraint(
+                        w2_weight_scale,
+                        NamedSharding(
+                            self.mesh,
+                            P(None, ShardingAxisName.MLP_TENSOR, None, None)))
+                    w13_bias = jax.lax.with_sharding_constraint(
+                        w13_bias,
+                        NamedSharding(
+                            self.mesh,
+                            P(None, None, ShardingAxisName.MLP_TENSOR)))
+                    w2_bias = jax.lax.with_sharding_constraint(
+                        w2_bias, NamedSharding(self.mesh, P(None, None, None)))
+
+            return w13_weight, w13_weight_scale, w13_bias, w2_weight, w2_weight_scale, w2_bias
+
+        w13_weight, w13_weight_scale, w13_bias, w2_weight, w2_weight_scale, w2_bias = wrapper(
+            w13_weight, w13_weight_scale, w13_bias, w2_weight, w2_weight_scale,
+            w2_bias)
 
        layer.w13_weight = Parameter(torch_view(w13_weight),
                                     requires_grad=False)
-        layer.w13_bias = Parameter(torch_view(w13_bias), requires_grad=False)
-
        layer.w2_weight = Parameter(torch_view(w2_weight), requires_grad=False)
-        layer.w2_bias = Parameter(torch_view(w2_bias), requires_grad=False)
 
-        pass
+        layer.w13_weight_scale = Parameter(torch_view(w13_weight_scale),
+                                           requires_grad=False)
+        layer.w2_weight_scale = Parameter(torch_view(w2_weight_scale),
+                                          requires_grad=False)
+
+        layer.w13_bias = Parameter(torch_view(w13_bias), requires_grad=False)
+        layer.w2_bias = Parameter(torch_view(w2_bias), requires_grad=False)
 
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        use_grouped_topk: bool = False,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        global_num_experts: int = -1,
-        expert_map: Optional[torch.Tensor] = None,
-        custom_routing_function: Optional[Callable] = None,
-        scoring_func: str = "softmax",
-        routed_scaling_factor: float = 1.0,
-        e_score_correction_bias: Optional[torch.Tensor] = None,
-        apply_router_weight_on_input: bool = False,
-        activation: str = "silu",
-        enable_eplb: bool = False,
-        expert_load_view: Optional[torch.Tensor] = None,
-        logical_to_physical_map: Optional[torch.Tensor] = None,
-        logical_replica_count: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        assert isinstance(layer, FusedMoE)
-        if scoring_func != "softmax":
+        if layer.scoring_func != "softmax":
            raise NotImplementedError(
                "Only softmax is supported for scoring_func")
 
-        if self.use_kernel and layer.use_ep:
+        x = jax_view(x)
+        w13_weight = jax_view(layer.w13_weight)
+        w2_weight = jax_view(layer.w2_weight)
+        w13_weight_scale = jax_view(layer.w13_weight_scale)
+        w2_weight_scale = jax_view(layer.w2_weight_scale)
+        w13_bias = jax_view(layer.w13_bias)
+        w2_bias = jax_view(layer.w2_bias)
+        gating_output = jax_view(router_logits)
+
+        if self.use_kernel:
+            actual_hidden_size = x.shape[-1]
+            padding_size = w13_weight.shape[-2] - actual_hidden_size
+            x = jnp.pad(x, ((0, 0), (0, padding_size)))
            output = fused_ep_moe(
                mesh=self.mesh,
-                tokens=jax_view(x),
-                w1=jax_view(layer.w13_weight),
-                w2=jax_view(layer.w2_weight),
-                b1=jax_view(layer.w13_bias),
-                b2=jax_view(layer.w2_bias),
-                gating_output=jax_view(router_logits),
-                top_k=top_k,
+                tokens=x,
+                w1=w13_weight,
+                w2=w2_weight,
+                w1_scale=w13_weight_scale,
+                w2_scale=w2_weight_scale,
+                b1=w13_bias,
+                b2=w2_bias,
+                gating_output=gating_output,
+                subc_quant_wsz=REQUANTIZED_BLOCK_SIZE,
+                top_k=layer.top_k,
                ep_axis_name=self.ep_axis_name,
-                renormalize_topk_logits=renormalize,
-                act_fn=activation,
+                renormalize_topk_logits=layer.renormalize,
+                act_fn=layer.activation,
                **self.block_size,
-            )
+            )[:, :actual_hidden_size]
        else:
-            # Use the original implementation
-            output = fused_moe_func_padded(
-                jax_view(x),
-                jax_view(layer.w13_weight),
-                jax_view(layer.w2_weight),
-                jax_view(layer.w13_bias),
-                jax_view(layer.w2_bias),
-                jax_view(router_logits),
-                topk=top_k,
-                global_num_experts=global_num_experts,
-                renormalize=renormalize,
-                reduce_results=layer.reduce_results,
+            output = fused_moe_func(
+                hidden_states=x,
+                w1=w13_weight,
+                w2=w2_weight,
+                w1_scale=w13_weight_scale,
+                w2_scale=w2_weight_scale,
+                w1_bias=w13_bias,
+                w2_bias=w2_bias,
+                gating_output=gating_output,
+                topk=layer.top_k,
+                renormalize=layer.renormalize,
                mesh=self.mesh,
                use_ep=layer.use_ep,
-                activation=activation,
+                activation=layer.activation,
            )
 
        return torch_view(output)
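
Note on the swigluoai reordering kept in process_weights_after_loading above: the w13 matrix interleaves w1 and w3 rows, so recovering the two halves of the gmm output would need a strided slice, which is slow on TPU; de-interleaving the rows once at load time makes a contiguous middle split of the output equivalent. The following is a minimal standalone sketch of that equivalence (shapes and variable names are illustrative only, not code from the package):

# Minimal sketch: verify that de-interleaving w13 rows ahead of time lets the
# per-expert matmul output be split down the middle instead of with a strided
# (interleaved) slice.
import numpy as np
import jax.numpy as jnp

num_experts, intermediate_size, hidden_size, num_tokens = 2, 4, 8, 3
rng = np.random.default_rng(0)
w13 = jnp.asarray(
    rng.normal(size=(num_experts, 2 * intermediate_size, hidden_size)),
    dtype=jnp.float32)
x = jnp.asarray(rng.normal(size=(num_tokens, hidden_size)), dtype=jnp.float32)

# Interleaved layout: even rows of w13 are w1, odd rows are w3, so the two
# halves of the output must be recovered with stride-2 slices.
out = jnp.einsum("th,eih->eti", x, w13)
w1_out_ref, w3_out_ref = out[..., ::2], out[..., 1::2]

# Reordered layout (the trick used in process_weights_after_loading): all w1
# rows first, then all w3 rows, so the output splits contiguously in the middle.
w13_reordered = jnp.concatenate([w13[:, ::2, :], w13[:, 1::2, :]], axis=1)
out_reordered = jnp.einsum("th,eih->eti", x, w13_reordered)
w1_out, w3_out = jnp.split(out_reordered, 2, axis=-1)

assert jnp.allclose(w1_out, w1_out_ref)
assert jnp.allclose(w3_out, w3_out_ref)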