tpu-inference 0.11.1.dev202511270815__py3-none-any.whl → 0.13.0rc2.post7__py3-none-any.whl

This diff reflects the changes between publicly available package versions as they appear in their respective public registries, and is provided for informational purposes only.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (251)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +14 -0
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +405 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +312 -0
  64. tests/layers/vllm/test_unquantized.py +651 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +21 -3
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +110 -12
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +2 -45
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +15 -10
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +92 -8
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +22 -1
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +167 -97
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +26 -19
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/quant_methods.py +15 -0
  154. tpu_inference/layers/common/quantization.py +270 -0
  155. tpu_inference/layers/common/sharding.py +31 -9
  156. tpu_inference/layers/jax/__init__.py +13 -0
  157. tpu_inference/layers/jax/attention/__init__.py +13 -0
  158. tpu_inference/layers/jax/attention/attention.py +19 -6
  159. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  160. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  161. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  162. tpu_inference/layers/jax/base.py +14 -0
  163. tpu_inference/layers/jax/constants.py +13 -0
  164. tpu_inference/layers/jax/layers.py +14 -0
  165. tpu_inference/layers/jax/misc.py +14 -0
  166. tpu_inference/layers/jax/moe/__init__.py +13 -0
  167. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  168. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  169. tpu_inference/layers/jax/moe/moe.py +43 -3
  170. tpu_inference/layers/jax/pp_utils.py +53 -0
  171. tpu_inference/layers/jax/rope.py +14 -0
  172. tpu_inference/layers/jax/rope_interface.py +14 -0
  173. tpu_inference/layers/jax/sample/__init__.py +13 -0
  174. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  175. tpu_inference/layers/jax/sample/sampling.py +15 -1
  176. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  177. tpu_inference/layers/jax/transformer_block.py +14 -0
  178. tpu_inference/layers/vllm/__init__.py +13 -0
  179. tpu_inference/layers/vllm/attention.py +4 -4
  180. tpu_inference/layers/vllm/fused_moe.py +210 -260
  181. tpu_inference/layers/vllm/linear_common.py +57 -22
  182. tpu_inference/layers/vllm/quantization/__init__.py +16 -0
  183. tpu_inference/layers/vllm/quantization/awq.py +15 -1
  184. tpu_inference/layers/vllm/quantization/common.py +33 -18
  185. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  186. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
  187. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
  188. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  189. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
  191. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  192. tpu_inference/layers/vllm/quantization/mxfp4.py +280 -210
  193. tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
  194. tpu_inference/layers/vllm/sharding.py +21 -4
  195. tpu_inference/lora/__init__.py +13 -0
  196. tpu_inference/lora/torch_lora_ops.py +8 -13
  197. tpu_inference/models/__init__.py +13 -0
  198. tpu_inference/models/common/__init__.py +13 -0
  199. tpu_inference/models/common/model_loader.py +77 -36
  200. tpu_inference/models/jax/__init__.py +13 -0
  201. tpu_inference/models/jax/deepseek_v3.py +267 -157
  202. tpu_inference/models/jax/gpt_oss.py +26 -10
  203. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  204. tpu_inference/models/jax/llama3.py +99 -36
  205. tpu_inference/models/jax/llama4.py +14 -0
  206. tpu_inference/models/jax/llama_eagle3.py +14 -0
  207. tpu_inference/models/jax/llama_guard_4.py +15 -1
  208. tpu_inference/models/jax/qwen2.py +17 -2
  209. tpu_inference/models/jax/qwen2_5_vl.py +18 -4
  210. tpu_inference/models/jax/qwen3.py +17 -2
  211. tpu_inference/models/jax/utils/__init__.py +13 -0
  212. tpu_inference/models/jax/utils/file_utils.py +14 -0
  213. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  214. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  215. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +91 -31
  216. tpu_inference/models/jax/utils/weight_utils.py +39 -2
  217. tpu_inference/models/vllm/__init__.py +13 -0
  218. tpu_inference/models/vllm/vllm_model_wrapper.py +20 -4
  219. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  220. tpu_inference/platforms/__init__.py +14 -0
  221. tpu_inference/platforms/tpu_platform.py +47 -71
  222. tpu_inference/runner/__init__.py +13 -0
  223. tpu_inference/runner/compilation_manager.py +158 -63
  224. tpu_inference/runner/kv_cache.py +54 -20
  225. tpu_inference/runner/kv_cache_manager.py +53 -30
  226. tpu_inference/runner/lora_utils.py +14 -0
  227. tpu_inference/runner/multimodal_manager.py +15 -1
  228. tpu_inference/runner/persistent_batch_manager.py +54 -2
  229. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  230. tpu_inference/runner/structured_decoding_manager.py +14 -0
  231. tpu_inference/runner/tpu_runner.py +105 -57
  232. tpu_inference/runner/utils.py +2 -2
  233. tpu_inference/spec_decode/__init__.py +13 -0
  234. tpu_inference/spec_decode/jax/__init__.py +13 -0
  235. tpu_inference/spec_decode/jax/eagle3.py +65 -19
  236. tpu_inference/tpu_info.py +14 -0
  237. tpu_inference/utils.py +72 -44
  238. tpu_inference/worker/__init__.py +13 -0
  239. tpu_inference/worker/tpu_worker.py +65 -52
  240. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/METADATA +11 -9
  241. tpu_inference-0.13.0rc2.post7.dist-info/RECORD +261 -0
  242. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  243. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  244. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  245. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  246. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  247. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  248. tpu_inference-0.11.1.dev202511270815.dist-info/RECORD +0 -174
  249. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/WHEEL +0 -0
  250. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/licenses/LICENSE +0 -0
  251. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/top_level.txt +0 -0
tpu_inference/layers/vllm/quantization/mxfp4.py

@@ -1,4 +1,18 @@
- from typing import Callable, Optional, Union
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Optional, Union

  import jax
  import jax.numpy as jnp
@@ -10,7 +24,7 @@ from torchax.interop import jax_view, torch_view
  from torchax.ops.mappings import t2j
  from vllm.logger import init_logger
  from vllm.model_executor.layers.fused_moe.config import (
- FusedMoEConfig, FusedMoEQuantConfig, biased_moe_quant_config)
+ FusedMoEConfig, FusedMoEQuantConfig, mxfp4_w4a16_moe_quant_config)
  from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
  FusedMoEMethodBase)
  from vllm.model_executor.layers.linear import LinearBase
@@ -28,44 +42,22 @@ from tpu_inference import envs
  from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe
  from tpu_inference.layers.common.quant_methods import (MXFP4,
  get_tpu_quant_method)
- from tpu_inference.layers.vllm.fused_moe import fused_moe_func_padded
+ from tpu_inference.layers.common.quantization import (
+ dequantize_tensor_from_mxfp4_packed, quantize_tensor)
+ from tpu_inference.layers.common.sharding import ShardingAxisName
+ from tpu_inference.layers.vllm.fused_moe import fused_moe_func
  from tpu_inference.layers.vllm.linear_common import \
  reorder_concatenated_tensor_for_sharding
  from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
  from tpu_inference.layers.vllm.quantization.unquantized import \
  VllmUnquantizedLinearMethod
+ from tpu_inference.utils import get_mesh_shape_product

- MXFP4_BLOCK_SIZE = 32
+ REQUANTIZED_BLOCK_SIZE = 512

  P = PartitionSpec
- logger = init_logger(__name__)
-
-
- # TODO(kyuyeunk): Move these functions into a common utility file.
- def u8_unpack_e2m1(u8_packed_e2m1: jax.Array) -> jax.Array:
- assert u8_packed_e2m1.dtype == jnp.uint8
- e2m1 = jax.lax.bitcast_convert_type(u8_packed_e2m1, jnp.float4_e2m1fn)
- # bitcast creates one more dimension that splits 8 bits into two e2m1.
- # we flatten them with the last dim.
- return jnp.reshape(e2m1, e2m1.shape[:-2] + (-1, ))
-
-
- def e8m0_to_fp32(u8: jax.Array) -> jax.Array:
- e8_finfo = jnp.finfo(jnp.float8_e8m0fnu)
- exponents = u8.astype(jnp.int32) + e8_finfo.minexp
- ones = jnp.ones_like(u8, dtype=jnp.float32)
- return jnp.ldexp(ones, exponents)

-
- def dequantize_block_weight(weight: jax.Array,
- scale: jax.Array,
- block_size: int,
- out_dtype: jnp.dtype = jnp.bfloat16) -> jax.Array:
- orig_shape = weight.shape
- weight_block = weight.reshape(orig_shape[:-1] + (-1, block_size))
- weight_dequantized = weight_block.astype(jnp.float32) * jnp.expand_dims(
- scale, -1)
- return weight_dequantized.reshape(orig_shape).astype(out_dtype)
+ logger = init_logger(__name__)


  @register_quantization_config(get_tpu_quant_method(MXFP4))
@@ -87,17 +79,14 @@ class VllmMxfp4Config(Mxfp4Config, JaxCommonConfig):
  fused_mapping=self.packed_modules_mapping,
  ):
  return VllmUnquantizedLinearMethod(linear_config)
- # TODO: Add support for MXFP4 Linear Method.
- # MXFP4 LinearMethod is available in AMD-Quark, refer to that
- # implementation if you are interested in enabling MXFP4 here.
  logger.warning_once(
  "MXFP4 linear layer is not implemented - falling back to "
  "UnquantizedLinearMethod.")
  return VllmUnquantizedLinearMethod(linear_config)
  elif isinstance(layer, FusedMoE):
- return VllmMxfp4MoEMethod(layer.moe_config, self.mesh)
+ moe_config = self.get_moe_config(layer)
+ return VllmMxfp4MoEMethod(moe_config, self.mesh)
  elif isinstance(layer, Attention):
- # TODO: Add support for MXFP4 Attention.
  logger.warning_once("MXFP4 attention layer is not implemented. "
  "Skipping quantization for this layer.")
  return None
@@ -116,225 +105,306 @@ class VllmMxfp4MoEMethod(Mxfp4MoEMethod):
  self.mxfp4_backend = Mxfp4Backend.TRITON

  self.mesh = mesh
- self.use_kernel = envs.USE_MOE_EP_KERNEL
+ self.use_kernel = envs.USE_MOE_EP_KERNEL and moe.use_ep
  self.ep_axis_name = ep_axis_name
  # TODO: Use autotune table once we have it.
  self.block_size = {
- "bt": 64,
+ "bt": 256,
  "bf": 1024,
- "bd1": 1536,
- "bd2": 1536,
- "btc": 64,
+ "bd1": 1024,
+ "bd2": 1024,
+ "btc": 256,
  "bfc": 1024,
- "bd1c": 1536,
- "bd2c": 1536,
+ "bd1c": 1024,
+ "bd2c": 1024,
  }

  def get_fused_moe_quant_config(
  self, layer: torch.nn.Module) -> FusedMoEQuantConfig | None:
- # Because we have dequantized weights, we only need biased moe config.
- # TODO(kyuyeunk): Add native support for MXFP4.
- return biased_moe_quant_config(
- layer.w13_bias,
- layer.w2_bias,
+ return mxfp4_w4a16_moe_quant_config(
+ w1_scale=layer.w13_weight_scale,
+ w2_scale=layer.w2_weight_scale,
+ w1_bias=layer.w13_bias,
+ w2_bias=layer.w2_bias,
  )

  def process_weights_after_loading(self, layer: torch.nn.Module):
  assert isinstance(layer, FusedMoE)
  assert layer.moe_config.has_bias, "mxfp4 quantization alwyas use bias."

- w13_weight = u8_unpack_e2m1(t2j(layer.w13_weight, use_dlpack=False))
- w13_weight_scale = e8m0_to_fp32(
- t2j(layer.w13_weight_scale, use_dlpack=False))
+ w13_weight = t2j(layer.w13_weight, use_dlpack=False)
+ w13_weight_scale = t2j(layer.w13_weight_scale, use_dlpack=False)
  w13_bias = t2j(layer.w13_bias, use_dlpack=False)

- w2_weight = u8_unpack_e2m1(t2j(layer.w2_weight, use_dlpack=False))
- w2_weight_scale = e8m0_to_fp32(
- t2j(layer.w2_weight_scale, use_dlpack=False))
+ w2_weight = t2j(layer.w2_weight, use_dlpack=False)
+ w2_weight_scale = t2j(layer.w2_weight_scale, use_dlpack=False)
  w2_bias = t2j(layer.w2_bias, use_dlpack=False)

- # We dequantize fp4 weights into bf16.
- # TODO(kyuyeunk): Add native support for MXFP4.
- w13_weight = dequantize_block_weight(w13_weight, w13_weight_scale,
- MXFP4_BLOCK_SIZE, jnp.bfloat16)
- w2_weight = dequantize_block_weight(w2_weight, w2_weight_scale,
- MXFP4_BLOCK_SIZE, jnp.bfloat16)
-
- # Because we have dequantized weights, scales are not used anymore.
- delattr(layer, "w13_weight_scale")
- delattr(layer, "w2_weight_scale")
-
- if layer.activation == "swigluoai":
- # When using swigluoai, vLLM splits gmm output in a interleaved way.
- # However, interleaved split is not performant on TPU. Therefore,
- # we preprocess the weight so that splitting gmm output by middle
- # can still get the same result.
- w1_weight = w13_weight[:, ::2, :]
- w3_weight = w13_weight[:, 1::2, :]
- w13_weight = jnp.concat([w1_weight, w3_weight], axis=1)
-
- w1_bias = w13_bias[:, ::2]
- w3_bias = w13_bias[:, 1::2]
- w13_bias = jnp.concat([w1_bias, w3_bias], axis=1)
-
- if self.use_kernel and layer.use_ep:
- # Kernel expects:
- # w13: (num_experts, 2, hidden_size, intermediate_size)
- # w2: (num_experts, intermediate_size, hidden_size)
- # Current format:
- # w13_weight: (num_experts, 2*intermediate_size, hidden_size)
- # w2_weight: (num_experts, hidden_size, intermediate_size)
- num_experts = w13_weight.shape[0]
- intermediate_size = w13_weight.shape[1] // 2
- hidden_size = w13_weight.shape[2]
-
- # Reshape and transpose w13_weight to (num_experts, 2, hidden_size, intermediate_size)
- w13_reshaped = w13_weight.reshape(num_experts, 2,
- intermediate_size, hidden_size)
- w13_weight_transposed = jnp.transpose(w13_reshaped, (0, 1, 3, 2))
-
- # Transpose w2_weight to (num_experts, intermediate_size, hidden_size)
- w2_weight_transposed = jnp.transpose(w2_weight, (0, 2, 1))
-
- # Apply EP sharding
- w13_weight = jax.device_put(
- w13_weight_transposed,
- Format(Layout((0, 1, 2, 3)),
- NamedSharding(self.mesh, P("model", None, None, None))))
- w2_weight = jax.device_put(
- w2_weight_transposed,
- Format(Layout((0, 1, 2)),
- NamedSharding(self.mesh, P("model", None, None))))
-
- if self.moe.has_bias:
- w13_bias = w13_bias.reshape(num_experts, 2, intermediate_size)
+ # Wrap functions in jit to speedup requantization.
+ @jax.jit
+ def wrapper(w13_weight, w13_weight_scale, w13_bias, w2_weight,
+ w2_weight_scale, w2_bias):
+ # Dequantize fp4 weights into fp32.
+ w13_weight = dequantize_tensor_from_mxfp4_packed(
+ w13_weight, w13_weight_scale, 2)
+ w2_weight = dequantize_tensor_from_mxfp4_packed(
+ w2_weight, w2_weight_scale, 2)
+
+ num_experts, orig_hidden_size, orig_intermediate_size = w2_weight.shape
+
+ # Requantize the weights into TPU friendly block size.
+ w13_weight, w13_weight_scale = quantize_tensor(
+ jnp.float4_e2m1fn, w13_weight, 2, REQUANTIZED_BLOCK_SIZE, True)
+ w2_weight, w2_weight_scale = quantize_tensor(
+ jnp.float4_e2m1fn, w2_weight, 2, REQUANTIZED_BLOCK_SIZE, True)
+
+ intermediate_size = w2_weight.shape[-1]
+ hidden_size = w13_weight.shape[-1]
+
+ # Dims may have been padded to align with subchannel size during
+ # quantization. We pad the corresponding dim on other weight.
+ # NOTE: We perform padding after quantization as padding value can
+ # affect quantization numerics.
+ intermediate_padding_size = 2 * (intermediate_size -
+ orig_intermediate_size)
+ w13_weight = jnp.pad(w13_weight,
+ ((0, 0), (0, intermediate_padding_size),
+ (0, 0)))
+ w13_weight_scale = jnp.pad(w13_weight_scale,
+ ((0, 0), (0, intermediate_padding_size),
+ (0, 0)))
+ w13_bias = jnp.pad(w13_bias,
+ ((0, 0), (0, intermediate_padding_size)))
+
+ hidden_padding_size = hidden_size - orig_hidden_size
+ w2_weight = jnp.pad(w2_weight,
+ ((0, 0), (0, hidden_padding_size), (0, 0)))
+ w2_weight_scale = jnp.pad(w2_weight_scale,
+ ((0, 0), (0, hidden_padding_size),
+ (0, 0)))
+ w2_bias = jnp.pad(w2_bias, ((0, 0), (0, hidden_padding_size)))
+
+ if layer.activation == "swigluoai":
+ # When using swigluoai, vLLM splits gmm output in a interleaved way.
+ # However, interleaved split is not performant on TPU. Therefore,
+ # we preprocess the weight so that splitting gmm output by middle
+ # can still get the same result.
+ w1_weight = w13_weight[:, ::2, :]
+ w3_weight = w13_weight[:, 1::2, :]
+ w13_weight = jnp.concat([w1_weight, w3_weight], axis=1)
+
+ w1_weight_scale = w13_weight_scale[:, ::2, :]
+ w3_weight_scale = w13_weight_scale[:, 1::2, :]
+ w13_weight_scale = jnp.concat(
+ [w1_weight_scale, w3_weight_scale], axis=1)
+
+ w1_bias = w13_bias[:, ::2]
+ w3_bias = w13_bias[:, 1::2]
+ w13_bias = jnp.concat([w1_bias, w3_bias], axis=1)
+
+ if self.use_kernel:
+ # Kernel expects:
+ # w13: (num_experts, 2, hidden_size, intermediate_size)
+ # w2: (num_experts, intermediate_size, hidden_size)
+ # Current format:
+ # w13_weight: (num_experts, 2*intermediate_size, hidden_size)
+ # w2_weight: (num_experts, hidden_size, intermediate_size)
+
+ w13_weight = w13_weight.reshape(num_experts, 2,
+ intermediate_size, hidden_size)
+
+ w13_weight_scale = w13_weight_scale.reshape(
+ num_experts, 2, intermediate_size, 1, -1)
+ w2_weight_scale = w2_weight_scale.reshape(
+ num_experts, hidden_size, 1, -1)
+
+ w13_bias = w13_bias.astype(jnp.float32).reshape(
+ num_experts, 2, 1, intermediate_size)
+ w2_bias = w2_bias.astype(jnp.float32).reshape(
+ num_experts, 1, hidden_size)
+
+ # Transpose non-constracting dim to right most dim
+ w13_weight = jnp.swapaxes(w13_weight, 2, 3)
+ w2_weight = jnp.swapaxes(w2_weight, 1, 2)
+
+ w13_weight_scale = jnp.swapaxes(w13_weight_scale, 2, 4)
+ w2_weight_scale = jnp.swapaxes(w2_weight_scale, 1, 3)

  # Apply EP sharding
- w13_bias = jax.device_put(
- w13_bias,
- Format(Layout((0, 1, 2)),
- NamedSharding(self.mesh, P("model", None, None))))
- w2_bias = jax.device_put(
- w2_bias,
- Format(Layout((0, 1)),
- NamedSharding(self.mesh, P("model", None))))
-
- else:
- if layer.use_ep:
- w13_weight = jax.device_put(
- w13_weight,
- Format(Layout((0, 1, 2)),
- NamedSharding(self.mesh, P("model", None, None))))
- w2_weight = jax.device_put(
- w2_weight,
- Format(Layout((0, 1, 2)),
- NamedSharding(self.mesh, P("model", None, None))))
-
- w13_bias = jax.device_put(
- w13_bias,
- Format(Layout((0, 1)),
- NamedSharding(self.mesh, P("model", None))))
- w2_bias = jax.device_put(
- w2_bias,
- Format(Layout((0, 1)),
- NamedSharding(self.mesh, P("model", None))))
-
+ ep_sharding = NamedSharding(self.mesh, P("model"))
+
+ w13_weight = jax.lax.with_sharding_constraint(
+ w13_weight, Format(Layout((0, 1, 2, 3)), ep_sharding))
+ w2_weight = jax.lax.with_sharding_constraint(
+ w2_weight, Format(Layout((0, 1, 2)), ep_sharding))
+
+ w13_weight_scale = jax.lax.with_sharding_constraint(
+ w13_weight_scale,
+ Format(Layout((0, 1, 2, 3, 4)), ep_sharding))
+ w2_weight_scale = jax.lax.with_sharding_constraint(
+ w2_weight_scale, Format(Layout((0, 1, 2, 3)), ep_sharding))
+
+ w13_bias = jax.lax.with_sharding_constraint(
+ w13_bias, Format(Layout((0, 1, 2, 3)), ep_sharding))
+ w2_bias = jax.lax.with_sharding_constraint(
+ w2_bias, Format(Layout((0, 1, 2)), ep_sharding))
  else:
- intermediate_size = w13_weight.shape[1] // 2
- assert intermediate_size == w2_weight.shape[-1]
- output_sizes = [intermediate_size, intermediate_size]
- n_shards = self.mesh.shape["model"]
- assert intermediate_size % n_shards == 0
- w13_weight = reorder_concatenated_tensor_for_sharding(
- w13_weight, output_sizes, n_shards, dim=1)
- w13_weight = jax.device_put(
- w13_weight,
- Format(Layout((0, 1, 2)),
- NamedSharding(self.mesh, P(None, "model", None))))
- w2_weight = jax.device_put(
- w2_weight,
- Format(Layout((0, 1, 2)),
- NamedSharding(self.mesh, P(None, None, "model"))))
-
- w13_bias = reorder_concatenated_tensor_for_sharding(
- w13_bias, output_sizes, n_shards, dim=1)
- w13_bias = jax.device_put(
- w13_bias,
- Format(Layout((0, 1)),
- NamedSharding(self.mesh, P(None, "model"))))
- w2_bias = jax.device_put(
- w2_bias,
- Format(Layout((0, 1)),
- NamedSharding(self.mesh, P(None, None))))
+ w13_weight_scale = jnp.swapaxes(w13_weight_scale, 1, 2)
+ w13_weight_scale = jnp.expand_dims(w13_weight_scale, 2)
+ w2_weight_scale = jnp.swapaxes(w2_weight_scale, 1, 2)
+ w2_weight_scale = jnp.expand_dims(w2_weight_scale, 2)
+
+ w13_bias = jnp.expand_dims(w13_bias, 1)
+ w2_bias = jnp.expand_dims(w2_bias, 1)
+
+ if layer.use_ep:
+ ep_sharding = NamedSharding(self.mesh,
+ P(ShardingAxisName.EXPERT))
+
+ w13_weight = jax.lax.with_sharding_constraint(
+ w13_weight, ep_sharding)
+ w2_weight = jax.lax.with_sharding_constraint(
+ w2_weight, ep_sharding)
+
+ w13_weight_scale = jax.lax.with_sharding_constraint(
+ w13_weight_scale, ep_sharding)
+ w2_weight_scale = jax.lax.with_sharding_constraint(
+ w2_weight_scale, ep_sharding)
+
+ w13_bias = jax.lax.with_sharding_constraint(
+ w13_bias, ep_sharding)
+ w2_bias = jax.lax.with_sharding_constraint(
+ w2_bias, ep_sharding)
+
+ else:
+ output_sizes = [intermediate_size, intermediate_size]
+ n_shards = get_mesh_shape_product(
+ self.mesh, ShardingAxisName.MLP_TENSOR)
+ assert intermediate_size % n_shards == 0
+
+ # Reorder w13 weights so that splitting w1 and w3 output
+ # can happen locally without any collective operations.
+ w13_weight = reorder_concatenated_tensor_for_sharding(
+ w13_weight,
+ output_sizes,
+ n_shards,
+ dim=1,
+ )
+ w13_weight_scale = reorder_concatenated_tensor_for_sharding(
+ w13_weight_scale,
+ output_sizes,
+ n_shards,
+ dim=3,
+ )
+ w13_bias = reorder_concatenated_tensor_for_sharding(
+ w13_bias,
+ output_sizes,
+ n_shards,
+ dim=2,
+ )
+
+ w13_weight = jax.lax.with_sharding_constraint(
+ w13_weight,
+ NamedSharding(
+ self.mesh,
+ P(None, ShardingAxisName.MLP_TENSOR, None)))
+ w2_weight = jax.lax.with_sharding_constraint(
+ w2_weight,
+ NamedSharding(
+ self.mesh,
+ P(None, None, ShardingAxisName.MLP_TENSOR)))
+ w13_weight_scale = jax.lax.with_sharding_constraint(
+ w13_weight_scale,
+ NamedSharding(
+ self.mesh,
+ P(None, None, None, ShardingAxisName.MLP_TENSOR)))
+ w2_weight_scale = jax.lax.with_sharding_constraint(
+ w2_weight_scale,
+ NamedSharding(
+ self.mesh,
+ P(None, ShardingAxisName.MLP_TENSOR, None, None)))
+ w13_bias = jax.lax.with_sharding_constraint(
+ w13_bias,
+ NamedSharding(
+ self.mesh,
+ P(None, None, ShardingAxisName.MLP_TENSOR)))
+ w2_bias = jax.lax.with_sharding_constraint(
+ w2_bias, NamedSharding(self.mesh, P(None, None, None)))
+
+ return w13_weight, w13_weight_scale, w13_bias, w2_weight, w2_weight_scale, w2_bias
+
+ w13_weight, w13_weight_scale, w13_bias, w2_weight, w2_weight_scale, w2_bias = wrapper(
+ w13_weight, w13_weight_scale, w13_bias, w2_weight, w2_weight_scale,
+ w2_bias)

  layer.w13_weight = Parameter(torch_view(w13_weight),
  requires_grad=False)
- layer.w13_bias = Parameter(torch_view(w13_bias), requires_grad=False)
-
  layer.w2_weight = Parameter(torch_view(w2_weight), requires_grad=False)
- layer.w2_bias = Parameter(torch_view(w2_bias), requires_grad=False)

- pass
+ layer.w13_weight_scale = Parameter(torch_view(w13_weight_scale),
+ requires_grad=False)
+ layer.w2_weight_scale = Parameter(torch_view(w2_weight_scale),
+ requires_grad=False)
+
+ layer.w13_bias = Parameter(torch_view(w13_bias), requires_grad=False)
+ layer.w2_bias = Parameter(torch_view(w2_bias), requires_grad=False)

  def apply(
  self,
  layer: torch.nn.Module,
  x: torch.Tensor,
  router_logits: torch.Tensor,
- top_k: int,
- renormalize: bool,
- use_grouped_topk: bool = False,
- topk_group: Optional[int] = None,
- num_expert_group: Optional[int] = None,
- global_num_experts: int = -1,
- expert_map: Optional[torch.Tensor] = None,
- custom_routing_function: Optional[Callable] = None,
- scoring_func: str = "softmax",
- routed_scaling_factor: float = 1.0,
- e_score_correction_bias: Optional[torch.Tensor] = None,
- apply_router_weight_on_input: bool = False,
- activation: str = "silu",
- enable_eplb: bool = False,
- expert_load_view: Optional[torch.Tensor] = None,
- logical_to_physical_map: Optional[torch.Tensor] = None,
- logical_replica_count: Optional[torch.Tensor] = None,
  ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
  assert isinstance(layer, FusedMoE)
- if scoring_func != "softmax":
+ if layer.scoring_func != "softmax":
  raise NotImplementedError(
  "Only softmax is supported for scoring_func")

- if self.use_kernel and layer.use_ep:
+ x = jax_view(x)
+ w13_weight = jax_view(layer.w13_weight)
+ w2_weight = jax_view(layer.w2_weight)
+ w13_weight_scale = jax_view(layer.w13_weight_scale)
+ w2_weight_scale = jax_view(layer.w2_weight_scale)
+ w13_bias = jax_view(layer.w13_bias)
+ w2_bias = jax_view(layer.w2_bias)
+ gating_output = jax_view(router_logits)
+
+ if self.use_kernel:
+ actual_hidden_size = x.shape[-1]
+ padding_size = w13_weight.shape[-2] - actual_hidden_size
+ x = jnp.pad(x, ((0, 0), (0, padding_size)))
  output = fused_ep_moe(
  mesh=self.mesh,
- tokens=jax_view(x),
- w1=jax_view(layer.w13_weight),
- w2=jax_view(layer.w2_weight),
- b1=jax_view(layer.w13_bias),
- b2=jax_view(layer.w2_bias),
- gating_output=jax_view(router_logits),
- top_k=top_k,
+ tokens=x,
+ w1=w13_weight,
+ w2=w2_weight,
+ w1_scale=w13_weight_scale,
+ w2_scale=w2_weight_scale,
+ b1=w13_bias,
+ b2=w2_bias,
+ gating_output=gating_output,
+ subc_quant_wsz=REQUANTIZED_BLOCK_SIZE,
+ top_k=layer.top_k,
  ep_axis_name=self.ep_axis_name,
- renormalize_topk_logits=renormalize,
- act_fn=activation,
+ renormalize_topk_logits=layer.renormalize,
+ act_fn=layer.activation,
  **self.block_size,
- )
+ )[:, :actual_hidden_size]
  else:
- # Use the original implementation
- output = fused_moe_func_padded(
- jax_view(x),
- jax_view(layer.w13_weight),
- jax_view(layer.w2_weight),
- jax_view(layer.w13_bias),
- jax_view(layer.w2_bias),
- jax_view(router_logits),
- topk=top_k,
- global_num_experts=global_num_experts,
- renormalize=renormalize,
- reduce_results=layer.reduce_results,
+ output = fused_moe_func(
+ hidden_states=x,
+ w1=w13_weight,
+ w2=w2_weight,
+ w1_scale=w13_weight_scale,
+ w2_scale=w2_weight_scale,
+ w1_bias=w13_bias,
+ w2_bias=w2_bias,
+ gating_output=gating_output,
+ topk=layer.top_k,
+ renormalize=layer.renormalize,
  mesh=self.mesh,
  use_ep=layer.use_ep,
- activation=activation,
+ activation=layer.activation,
  )

  return torch_view(output)