tpu-inference 0.12.0.dev20251213__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl

This diff compares the contents of two publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.

Note: this release of tpu-inference has been flagged as potentially problematic.

Files changed (248)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +14 -0
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +14 -0
  31. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  32. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_test.py +14 -0
  35. tests/layers/__init__.py +13 -0
  36. tests/layers/common/__init__.py +13 -0
  37. tests/layers/common/test_attention_interface.py +156 -0
  38. tests/layers/common/test_quantization.py +149 -0
  39. tests/layers/jax/__init__.py +13 -0
  40. tests/layers/jax/attention/__init__.py +13 -0
  41. tests/layers/jax/attention/test_common_attention.py +103 -0
  42. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  43. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  44. tests/layers/jax/moe/__init__.py +13 -0
  45. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  46. tests/layers/jax/sample/__init__.py +13 -0
  47. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  48. tests/layers/jax/sample/test_sampling.py +115 -0
  49. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  50. tests/layers/jax/test_layers.py +155 -0
  51. tests/{test_quantization.py → layers/jax/test_qwix.py} +180 -50
  52. tests/layers/jax/test_rope.py +93 -0
  53. tests/layers/jax/test_sharding.py +159 -0
  54. tests/layers/jax/test_transformer_block.py +152 -0
  55. tests/layers/vllm/__init__.py +13 -0
  56. tests/layers/vllm/test_attention.py +363 -0
  57. tests/layers/vllm/test_awq.py +406 -0
  58. tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
  59. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
  61. tests/layers/vllm/test_fp8.py +17 -0
  62. tests/layers/vllm/test_mxfp4.py +320 -0
  63. tests/layers/vllm/test_unquantized.py +662 -0
  64. tests/layers/vllm/utils.py +87 -0
  65. tests/lora/__init__.py +13 -0
  66. tests/lora/conftest.py +14 -0
  67. tests/lora/test_bgmv.py +14 -0
  68. tests/lora/test_layers.py +25 -8
  69. tests/lora/test_lora.py +15 -1
  70. tests/lora/test_lora_perf.py +14 -0
  71. tests/models/__init__.py +13 -0
  72. tests/models/common/__init__.py +13 -0
  73. tests/models/common/test_model_loader.py +455 -0
  74. tests/models/jax/__init__.py +13 -0
  75. tests/models/jax/test_deepseek_v3.py +401 -0
  76. tests/models/jax/test_llama3.py +184 -0
  77. tests/models/jax/test_llama4.py +298 -0
  78. tests/models/jax/test_llama_eagle3.py +197 -0
  79. tests/models/jax/test_llama_guard_4.py +242 -0
  80. tests/models/jax/test_qwen2.py +172 -0
  81. tests/models/jax/test_qwen2_5_vl.py +605 -0
  82. tests/models/jax/test_qwen3.py +169 -0
  83. tests/models/jax/test_weight_loading.py +180 -0
  84. tests/models/jax/utils/__init__.py +13 -0
  85. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  86. tests/platforms/__init__.py +13 -0
  87. tests/platforms/test_tpu_platform.py +54 -0
  88. tests/runner/__init__.py +13 -0
  89. tests/runner/test_block_table.py +395 -0
  90. tests/runner/test_input_batch.py +226 -0
  91. tests/runner/test_kv_cache.py +220 -0
  92. tests/runner/test_kv_cache_manager.py +498 -0
  93. tests/runner/test_multimodal_manager.py +429 -0
  94. tests/runner/test_persistent_batch_manager.py +84 -0
  95. tests/runner/test_speculative_decoding_manager.py +368 -0
  96. tests/runner/test_structured_decoding_manager.py +220 -0
  97. tests/runner/test_tpu_runner.py +261 -0
  98. tests/runner/test_tpu_runner_dp.py +1099 -0
  99. tests/runner/test_tpu_runner_mesh.py +200 -0
  100. tests/runner/test_utils.py +411 -0
  101. tests/spec_decode/__init__.py +13 -0
  102. tests/spec_decode/test_eagle3.py +311 -0
  103. tests/test_base.py +14 -0
  104. tests/test_tpu_info.py +14 -0
  105. tests/test_utils.py +1 -43
  106. tests/worker/__init__.py +13 -0
  107. tests/worker/tpu_worker_test.py +414 -0
  108. tpu_inference/__init__.py +14 -0
  109. tpu_inference/core/__init__.py +13 -0
  110. tpu_inference/core/sched/__init__.py +13 -0
  111. tpu_inference/core/sched/dp_scheduler.py +372 -56
  112. tpu_inference/distributed/__init__.py +13 -0
  113. tpu_inference/distributed/jax_parallel_state.py +14 -0
  114. tpu_inference/distributed/tpu_connector.py +14 -9
  115. tpu_inference/distributed/utils.py +56 -4
  116. tpu_inference/executors/__init__.py +13 -0
  117. tpu_inference/executors/ray_distributed_executor.py +20 -3
  118. tpu_inference/experimental/__init__.py +13 -0
  119. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  120. tpu_inference/kernels/__init__.py +13 -0
  121. tpu_inference/kernels/collectives/__init__.py +13 -0
  122. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  123. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  124. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  125. tpu_inference/kernels/fused_moe/v1/kernel.py +171 -163
  126. tpu_inference/kernels/megablox/__init__.py +13 -0
  127. tpu_inference/kernels/megablox/common.py +54 -0
  128. tpu_inference/kernels/megablox/gmm.py +646 -0
  129. tpu_inference/kernels/mla/__init__.py +13 -0
  130. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  131. tpu_inference/kernels/mla/v1/kernel.py +20 -26
  132. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  133. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  134. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  135. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  136. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +112 -69
  137. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +85 -65
  138. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  139. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +374 -194
  140. tpu_inference/kernels/ragged_paged_attention/v3/util.py +13 -0
  141. tpu_inference/layers/__init__.py +13 -0
  142. tpu_inference/layers/common/__init__.py +13 -0
  143. tpu_inference/layers/common/attention_interface.py +26 -19
  144. tpu_inference/layers/common/attention_metadata.py +14 -0
  145. tpu_inference/layers/common/fused_moe_gmm.py +506 -0
  146. tpu_inference/layers/common/quant_methods.py +15 -0
  147. tpu_inference/layers/common/quantization.py +282 -0
  148. tpu_inference/layers/common/sharding.py +22 -3
  149. tpu_inference/layers/common/utils.py +94 -0
  150. tpu_inference/layers/jax/__init__.py +13 -0
  151. tpu_inference/layers/jax/attention/__init__.py +13 -0
  152. tpu_inference/layers/jax/attention/attention.py +19 -6
  153. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +52 -27
  154. tpu_inference/layers/jax/attention/gpt_oss_attention.py +19 -6
  155. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  156. tpu_inference/layers/jax/base.py +14 -0
  157. tpu_inference/layers/jax/constants.py +13 -0
  158. tpu_inference/layers/jax/layers.py +14 -0
  159. tpu_inference/layers/jax/misc.py +14 -0
  160. tpu_inference/layers/jax/moe/__init__.py +13 -0
  161. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  162. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  163. tpu_inference/layers/jax/moe/moe.py +43 -3
  164. tpu_inference/layers/jax/pp_utils.py +53 -0
  165. tpu_inference/layers/jax/rope.py +14 -0
  166. tpu_inference/layers/jax/rope_interface.py +14 -0
  167. tpu_inference/layers/jax/sample/__init__.py +13 -0
  168. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  169. tpu_inference/layers/jax/sample/sampling.py +15 -1
  170. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  171. tpu_inference/layers/jax/transformer_block.py +14 -0
  172. tpu_inference/layers/vllm/__init__.py +13 -0
  173. tpu_inference/layers/vllm/attention.py +4 -4
  174. tpu_inference/layers/vllm/fused_moe.py +100 -455
  175. tpu_inference/layers/vllm/linear.py +64 -0
  176. tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
  177. tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
  178. tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
  179. tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
  180. tpu_inference/layers/vllm/quantization/__init__.py +19 -3
  181. tpu_inference/layers/vllm/quantization/awq.py +96 -82
  182. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  183. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +19 -5
  184. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +119 -132
  185. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  186. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
  187. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
  188. tpu_inference/layers/vllm/quantization/{common.py → configs.py} +38 -26
  189. tpu_inference/layers/vllm/quantization/fp8.py +119 -0
  190. tpu_inference/layers/vllm/quantization/mxfp4.py +133 -220
  191. tpu_inference/layers/vllm/quantization/unquantized.py +154 -253
  192. tpu_inference/lora/__init__.py +13 -0
  193. tpu_inference/lora/torch_lora_ops.py +8 -13
  194. tpu_inference/models/__init__.py +13 -0
  195. tpu_inference/models/common/__init__.py +13 -0
  196. tpu_inference/models/common/model_loader.py +37 -16
  197. tpu_inference/models/jax/__init__.py +13 -0
  198. tpu_inference/models/jax/deepseek_v3.py +113 -124
  199. tpu_inference/models/jax/gpt_oss.py +23 -7
  200. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  201. tpu_inference/models/jax/llama3.py +99 -36
  202. tpu_inference/models/jax/llama4.py +14 -0
  203. tpu_inference/models/jax/llama_eagle3.py +14 -0
  204. tpu_inference/models/jax/llama_guard_4.py +15 -1
  205. tpu_inference/models/jax/qwen2.py +17 -2
  206. tpu_inference/models/jax/qwen2_5_vl.py +18 -4
  207. tpu_inference/models/jax/qwen3.py +17 -2
  208. tpu_inference/models/jax/utils/__init__.py +13 -0
  209. tpu_inference/models/jax/utils/file_utils.py +14 -0
  210. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  211. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  212. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +85 -24
  213. tpu_inference/models/jax/utils/weight_utils.py +32 -1
  214. tpu_inference/models/vllm/__init__.py +13 -0
  215. tpu_inference/models/vllm/vllm_model_wrapper.py +22 -4
  216. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  217. tpu_inference/platforms/__init__.py +14 -0
  218. tpu_inference/platforms/tpu_platform.py +27 -29
  219. tpu_inference/runner/__init__.py +13 -0
  220. tpu_inference/runner/compilation_manager.py +69 -35
  221. tpu_inference/runner/kv_cache.py +14 -0
  222. tpu_inference/runner/kv_cache_manager.py +15 -2
  223. tpu_inference/runner/lora_utils.py +16 -1
  224. tpu_inference/runner/multimodal_manager.py +16 -2
  225. tpu_inference/runner/persistent_batch_manager.py +14 -0
  226. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  227. tpu_inference/runner/structured_decoding_manager.py +14 -0
  228. tpu_inference/runner/tpu_runner.py +30 -10
  229. tpu_inference/spec_decode/__init__.py +13 -0
  230. tpu_inference/spec_decode/jax/__init__.py +13 -0
  231. tpu_inference/spec_decode/jax/eagle3.py +13 -0
  232. tpu_inference/tpu_info.py +14 -0
  233. tpu_inference/utils.py +31 -30
  234. tpu_inference/worker/__init__.py +13 -0
  235. tpu_inference/worker/tpu_worker.py +23 -7
  236. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +1 -1
  237. tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
  238. tpu_inference/layers/vllm/linear_common.py +0 -208
  239. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  240. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  241. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  242. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  243. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  244. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  245. tpu_inference-0.12.0.dev20251213.dist-info/RECORD +0 -175
  246. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
  247. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
  248. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
tpu_inference/layers/vllm/quantization/mxfp4.py

@@ -1,16 +1,29 @@
-from typing import Optional, Union
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
 
 import jax
 import jax.numpy as jnp
 import torch
-from jax.experimental.layout import Format, Layout
-from jax.sharding import Mesh, NamedSharding, PartitionSpec
+from jax.sharding import Mesh, PartitionSpec
 from torch.nn.parameter import Parameter
-from torchax.interop import jax_view, torch_view
+from torchax.interop import torch_view
 from torchax.ops.mappings import t2j
-from vllm.logger import init_logger
+from vllm.attention.layer import Attention
 from vllm.model_executor.layers.fused_moe.config import (
-    FusedMoEConfig, FusedMoEQuantConfig, biased_moe_quant_config)
+    FusedMoEConfig, FusedMoEQuantConfig, mxfp4_w4a16_moe_quant_config)
 from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
                                                         FusedMoEMethodBase)
 from vllm.model_executor.layers.linear import LinearBase
@@ -24,52 +37,32 @@ from vllm.model_executor.layers.quantization.mxfp4 import (Mxfp4Backend,
 from vllm.model_executor.layers.quantization.utils.quant_utils import \
     is_layer_skipped
 
-from tpu_inference import envs
-from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe
 from tpu_inference.layers.common.quant_methods import (MXFP4,
                                                        get_tpu_quant_method)
-from tpu_inference.layers.vllm.fused_moe import fused_moe_func
-from tpu_inference.layers.vllm.linear_common import \
-    reorder_concatenated_tensor_for_sharding
-from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
+from tpu_inference.layers.common.quantization import \
+    dequantize_tensor_from_mxfp4_packed
+from tpu_inference.layers.common.sharding import ShardingAxisName
+from tpu_inference.layers.vllm.fused_moe import (FusedMoEBackend,
+                                                 fused_moe_apply,
+                                                 select_moe_backend)
+from tpu_inference.layers.vllm.process_weights.fused_moe_weights import (
+    FusedMoEWeights, process_moe_weights, quantize_moe_weights,
+    shard_moe_weights)
+from tpu_inference.layers.vllm.quantization.configs import VllmQuantConfig
 from tpu_inference.layers.vllm.quantization.unquantized import \
     VllmUnquantizedLinearMethod
+from tpu_inference.logger import init_logger
+from tpu_inference.utils import get_mesh_shape_product
 
-MXFP4_BLOCK_SIZE = 32
+REQUANTIZED_BLOCK_SIZE = 512
 
 P = PartitionSpec
-logger = init_logger(__name__)
-
-
-# TODO(kyuyeunk): Move these functions into a common utility file.
-def u8_unpack_e2m1(u8_packed_e2m1: jax.Array) -> jax.Array:
-    assert u8_packed_e2m1.dtype == jnp.uint8
-    e2m1 = jax.lax.bitcast_convert_type(u8_packed_e2m1, jnp.float4_e2m1fn)
-    # bitcast creates one more dimension that splits 8 bits into two e2m1.
-    # we flatten them with the last dim.
-    return jnp.reshape(e2m1, e2m1.shape[:-2] + (-1, ))
-
 
-def e8m0_to_fp32(u8: jax.Array) -> jax.Array:
-    e8_finfo = jnp.finfo(jnp.float8_e8m0fnu)
-    exponents = u8.astype(jnp.int32) + e8_finfo.minexp
-    ones = jnp.ones_like(u8, dtype=jnp.float32)
-    return jnp.ldexp(ones, exponents)
-
-
-def dequantize_block_weight(weight: jax.Array,
-                            scale: jax.Array,
-                            block_size: int,
-                            out_dtype: jnp.dtype = jnp.bfloat16) -> jax.Array:
-    orig_shape = weight.shape
-    weight_block = weight.reshape(orig_shape[:-1] + (-1, block_size))
-    weight_dequantized = weight_block.astype(jnp.float32) * jnp.expand_dims(
-        scale, -1)
-    return weight_dequantized.reshape(orig_shape).astype(out_dtype)
+logger = init_logger(__name__)
 
 
 @register_quantization_config(get_tpu_quant_method(MXFP4))
-class VllmMxfp4Config(Mxfp4Config, JaxCommonConfig):
+class VllmMxfp4Config(Mxfp4Config, VllmQuantConfig):
 
     @classmethod
     def get_name(cls):
@@ -77,7 +70,6 @@ class VllmMxfp4Config(Mxfp4Config, JaxCommonConfig):
 
     def get_quant_method(self, layer: torch.nn.Module,
                          prefix: str) -> Optional["QuantizeMethodBase"]:
-        from vllm.attention.layer import Attention  # Avoid circular import
 
         if isinstance(layer, LinearBase):
             linear_config = self.get_linear_config(layer)
@@ -102,10 +94,12 @@ class VllmMxfp4Config(Mxfp4Config, JaxCommonConfig):
 
 class VllmMxfp4MoEMethod(Mxfp4MoEMethod):
 
-    def __init__(self,
-                 moe: FusedMoEConfig,
-                 mesh: Mesh,
-                 ep_axis_name: str = 'model'):
+    def __init__(
+        self,
+        moe: FusedMoEConfig,
+        mesh: Mesh,
+        ep_axis_name: str = "model",
+    ):
         FusedMoEMethodBase.__init__(self, moe)
 
         # We piggyback on triton implementation as it applies minimal hardware
@@ -113,200 +107,119 @@ class VllmMxfp4MoEMethod(Mxfp4MoEMethod):
         self.mxfp4_backend = Mxfp4Backend.TRITON
 
         self.mesh = mesh
-        self.use_kernel = envs.USE_MOE_EP_KERNEL and moe.use_ep
-        self.ep_axis_name = ep_axis_name
-        # TODO: Use autotune table once we have it.
-        self.block_size = {
-            "bt": 64,
-            "bf": 1024,
-            "bd1": 1536,
-            "bd2": 1536,
-            "btc": 64,
-            "bfc": 1024,
-            "bd1c": 1536,
-            "bd2c": 1536,
-        }
+        self.moe_backend = select_moe_backend(self.moe)
+
+        self.extra_backend_kwargs = {}
+        if self.moe_backend == FusedMoEBackend.FUSED_MOE:
+            # When fused moe kernle is used, we pass extra arguments like
+            # tuned block sizes to the kernel.
+            self.extra_backend_kwargs = dict(
+                subc_quant_wsz=REQUANTIZED_BLOCK_SIZE,
+                ep_axis_name=ep_axis_name,
+                # TODO: Use autotune table once we have it.
+                bt=256,
+                bf=1024,
+                bd1=1024,
+                bd2=1024,
+                btc=256,
+                bfc=1024,
+                bd1c=1024,
+                bd2c=1024,
+            )
 
     def get_fused_moe_quant_config(
             self, layer: torch.nn.Module) -> FusedMoEQuantConfig | None:
-        # Because we have dequantized weights, we only need biased moe config.
-        # TODO(kyuyeunk): Add native support for MXFP4.
-        return biased_moe_quant_config(
-            layer.w13_bias,
-            layer.w2_bias,
+        return mxfp4_w4a16_moe_quant_config(
+            w1_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            w1_bias=layer.w13_bias,
+            w2_bias=layer.w2_bias,
         )
 
     def process_weights_after_loading(self, layer: torch.nn.Module):
         assert isinstance(layer, FusedMoE)
         assert layer.moe_config.has_bias, "mxfp4 quantization alwyas use bias."
 
-        w13_weight = u8_unpack_e2m1(t2j(layer.w13_weight, use_dlpack=False))
-        w13_weight_scale = e8m0_to_fp32(
-            t2j(layer.w13_weight_scale, use_dlpack=False))
+        w13_weight = t2j(layer.w13_weight, use_dlpack=False)
+        w13_weight_scale = t2j(layer.w13_weight_scale, use_dlpack=False)
         w13_bias = t2j(layer.w13_bias, use_dlpack=False)
 
-        w2_weight = u8_unpack_e2m1(t2j(layer.w2_weight, use_dlpack=False))
-        w2_weight_scale = e8m0_to_fp32(
-            t2j(layer.w2_weight_scale, use_dlpack=False))
+        w2_weight = t2j(layer.w2_weight, use_dlpack=False)
+        w2_weight_scale = t2j(layer.w2_weight_scale, use_dlpack=False)
         w2_bias = t2j(layer.w2_bias, use_dlpack=False)
 
-        # We dequantize fp4 weights into bf16.
-        # TODO(kyuyeunk): Add native support for MXFP4.
-        w13_weight = dequantize_block_weight(w13_weight, w13_weight_scale,
-                                             MXFP4_BLOCK_SIZE, jnp.bfloat16)
-        w2_weight = dequantize_block_weight(w2_weight, w2_weight_scale,
-                                            MXFP4_BLOCK_SIZE, jnp.bfloat16)
-
-        num_experts, hidden_size, intermediate_size = w2_weight.shape
-
-        # Because we have dequantized weights, scales are not used anymore.
-        delattr(layer, "w13_weight_scale")
-        delattr(layer, "w2_weight_scale")
-
-        if layer.activation == "swigluoai":
-            # When using swigluoai, vLLM splits gmm output in a interleaved way.
-            # However, interleaved split is not performant on TPU. Therefore,
-            # we preprocess the weight so that splitting gmm output by middle
-            # can still get the same result.
-            w1_weight = w13_weight[:, ::2, :]
-            w3_weight = w13_weight[:, 1::2, :]
-            w13_weight = jnp.concat([w1_weight, w3_weight], axis=1)
-
-            w1_bias = w13_bias[:, ::2]
-            w3_bias = w13_bias[:, 1::2]
-            w13_bias = jnp.concat([w1_bias, w3_bias], axis=1)
-
-        if self.use_kernel:
-            # Kernel expects:
-            #   w13: (num_experts, 2, hidden_size, intermediate_size)
-            #   w2: (num_experts, intermediate_size, hidden_size)
-            # Current format:
-            #   w13_weight: (num_experts, 2*intermediate_size, hidden_size)
-            #   w2_weight: (num_experts, hidden_size, intermediate_size)
-
-            w13_reshaped = w13_weight.reshape(num_experts, 2,
-                                              intermediate_size, hidden_size)
-
-            # Transpose non-constracting dim to right most dim
-            w13_weight_transposed = jnp.swapaxes(w13_reshaped, 2, 3)
-            w2_weight_transposed = jnp.swapaxes(w2_weight, 1, 2)
-
-            # Apply EP sharding
-            ep_sharding = NamedSharding(self.mesh, P("model"))
-
-            w13_weight = jax.device_put(
-                w13_weight_transposed, Format(Layout((0, 1, 2, 3)),
-                                              ep_sharding))
-            w2_weight = jax.device_put(w2_weight_transposed,
-                                       Format(Layout((0, 1, 2)), ep_sharding))
-
-            w13_bias = w13_bias.reshape(num_experts, 2, intermediate_size)
-            w13_bias = jax.device_put(w13_bias,
-                                      Format(Layout((0, 1, 2)), ep_sharding))
-            w2_bias = jax.device_put(w2_bias,
-                                     Format(Layout((0, 1)), ep_sharding))
-
-        else:
-            if layer.use_ep:
-                ep_sharding = NamedSharding(self.mesh, P("model"))
-                w13_weight = jax.device_put(
-                    w13_weight, Format(Layout((0, 1, 2)), ep_sharding))
-                w2_weight = jax.device_put(
-                    w2_weight, Format(Layout((0, 1, 2)), ep_sharding))
-
-                w13_bias = jax.device_put(w13_bias,
-                                          Format(Layout((0, 1)), ep_sharding))
-                w2_bias = jax.device_put(w2_bias,
-                                         Format(Layout((0, 1)), ep_sharding))
-
-            else:
-                output_sizes = [intermediate_size, intermediate_size]
-                n_shards = self.mesh.shape["model"]
-                assert intermediate_size % n_shards == 0
+        @jax.jit
+        def process_mxfp4_moe_weights(
+            w13_weight: jax.Array,
+            w13_weight_scale: jax.Array,
+            w13_bias: jax.Array,
+            w2_weight: jax.Array,
+            w2_weight_scale: jax.Array,
+            w2_bias: jax.Array,
+        ) -> FusedMoEWeights:
+            # Dequantize fp4 weights into fp32.
+            w13_weight = dequantize_tensor_from_mxfp4_packed(
+                w13_weight, w13_weight_scale, 2)
+            w2_weight = dequantize_tensor_from_mxfp4_packed(
+                w2_weight, w2_weight_scale, 2)
+
+            w13_interleave = layer.activation == "swigluoai"
+            w13_reorder_size = get_mesh_shape_product(
+                self.mesh, ShardingAxisName.MLP_TENSOR)
+
+            weights = quantize_moe_weights(
+                FusedMoEWeights(
+                    w13_weight=w13_weight,
+                    w13_weight_scale=None,
+                    w13_bias=w13_bias,
+                    w2_weight=w2_weight,
+                    w2_weight_scale=None,
+                    w2_bias=w2_bias,
+                ),
+                jnp.float4_e2m1fn,
+                REQUANTIZED_BLOCK_SIZE,
+            )
+            return process_moe_weights(
+                weights,
+                moe_backend=self.moe_backend,
+                w13_reorder_size=w13_reorder_size,
+                w13_interleave=w13_interleave,
+            )
 
-                w13_weight = reorder_concatenated_tensor_for_sharding(
-                    w13_weight,
-                    output_sizes,
-                    n_shards,
-                    dim=1,
-                )
-                w13_weight = jax.device_put(
-                    w13_weight,
-                    Format(Layout((0, 1, 2)),
-                           NamedSharding(self.mesh, P(None, "model", None))))
-                w2_weight = jax.device_put(
-                    w2_weight,
-                    Format(Layout((0, 1, 2)),
-                           NamedSharding(self.mesh, P(None, None, "model"))))
+        weights = process_mxfp4_moe_weights(
+            w13_weight,
+            w13_weight_scale,
+            w13_bias,
+            w2_weight,
+            w2_weight_scale,
+            w2_bias,
+        )
+        weights = torch_view(
+            shard_moe_weights(weights, self.moe_backend, self.mesh))
 
-                w13_bias = reorder_concatenated_tensor_for_sharding(
-                    w13_bias,
-                    output_sizes,
-                    n_shards,
-                    dim=1,
-                )
-                w13_bias = jax.device_put(
-                    w13_bias,
-                    Format(Layout((0, 1)),
-                           NamedSharding(self.mesh, P(None, "model"))))
-                w2_bias = jax.device_put(
-                    w2_bias,
-                    Format(Layout((0, 1)),
-                           NamedSharding(self.mesh, P(None, None))))
+        layer.w13_weight = Parameter(weights.w13_weight, requires_grad=False)
+        layer.w2_weight = Parameter(weights.w2_weight, requires_grad=False)
 
-        layer.w13_weight = Parameter(torch_view(w13_weight),
-                                     requires_grad=False)
-        layer.w2_weight = Parameter(torch_view(w2_weight), requires_grad=False)
+        layer.w13_weight_scale = Parameter(weights.w13_weight_scale,
+                                           requires_grad=False)
+        layer.w2_weight_scale = Parameter(weights.w2_weight_scale,
+                                          requires_grad=False)
 
-        layer.w13_bias = Parameter(torch_view(w13_bias), requires_grad=False)
-        layer.w2_bias = Parameter(torch_view(w2_bias), requires_grad=False)
+        layer.w13_bias = Parameter(weights.w13_bias, requires_grad=False)
+        layer.w2_bias = Parameter(weights.w2_bias, requires_grad=False)
 
     def apply(
         self,
         layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
-    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
-        assert isinstance(layer, FusedMoE)
-        if layer.scoring_func != "softmax":
-            raise NotImplementedError(
-                "Only softmax is supported for scoring_func")
-
-        x = jax_view(x)
-        w13_weight = jax_view(layer.w13_weight)
-        w2_weight = jax_view(layer.w2_weight)
-        w13_bias = jax_view(layer.w13_bias)
-        w2_bias = jax_view(layer.w2_bias)
-        gating_output = jax_view(router_logits)
-
-        if self.use_kernel:
-            output = fused_ep_moe(
-                mesh=self.mesh,
-                tokens=x,
-                w1=w13_weight,
-                w2=w2_weight,
-                b1=w13_bias,
-                b2=w2_bias,
-                gating_output=gating_output,
-                top_k=layer.top_k,
-                ep_axis_name=self.ep_axis_name,
-                renormalize_topk_logits=layer.renormalize,
-                act_fn=layer.activation,
-                **self.block_size,
-            )
-        else:
-            output = fused_moe_func(
-                hidden_states=x,
-                w1=w13_weight,
-                w2=w2_weight,
-                w1_bias=w13_bias,
-                w2_bias=w2_bias,
-                gating_output=gating_output,
-                topk=layer.top_k,
-                renormalize=layer.renormalize,
-                mesh=self.mesh,
-                use_ep=layer.use_ep,
-                activation=layer.activation,
-            )
-
-        return torch_view(output)
+    ) -> torch.Tensor:
+
+        return fused_moe_apply(
+            layer,
+            x,
+            router_logits,
+            self.moe_backend,
+            self.mesh,
+            self.extra_backend_kwargs,
+        )
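
For context on the format itself: the helpers deleted above (u8_unpack_e2m1, e8m0_to_fp32, dequantize_block_weight) spell out what MXFP4 block dequantization does, and what the new dequantize_tensor_from_mxfp4_packed call presumably consolidates. Each uint8 byte packs two e2m1 (4-bit float) values, and every block of 32 values shares one e8m0 power-of-two scale. Below is a minimal, self-contained sketch reconstructed from the removed code; the toy shapes and values at the end are illustrative only, not taken from the package.

import jax
import jax.numpy as jnp

MXFP4_BLOCK_SIZE = 32  # one e8m0 scale per block of 32 e2m1 values


def u8_unpack_e2m1(u8_packed_e2m1: jax.Array) -> jax.Array:
    # Each uint8 holds two 4-bit e2m1 values; the bitcast adds a trailing
    # axis of size 2, which is folded back into the last dimension.
    assert u8_packed_e2m1.dtype == jnp.uint8
    e2m1 = jax.lax.bitcast_convert_type(u8_packed_e2m1, jnp.float4_e2m1fn)
    return jnp.reshape(e2m1, e2m1.shape[:-2] + (-1, ))


def e8m0_to_fp32(u8: jax.Array) -> jax.Array:
    # e8m0 stores only a biased exponent: scale = 2 ** (u8 + minexp).
    exponents = u8.astype(jnp.int32) + jnp.finfo(jnp.float8_e8m0fnu).minexp
    return jnp.ldexp(jnp.ones_like(u8, dtype=jnp.float32), exponents)


def dequantize_block_weight(weight, scale, block_size, out_dtype=jnp.bfloat16):
    # Multiply every block of `block_size` values by its per-block scale.
    orig_shape = weight.shape
    blocks = weight.reshape(orig_shape[:-1] + (-1, block_size))
    dequant = blocks.astype(jnp.float32) * jnp.expand_dims(scale, -1)
    return dequant.reshape(orig_shape).astype(out_dtype)


# Toy example: one row of 64 packed values (32 bytes) with 2 block scales.
packed = jnp.zeros((1, 32), dtype=jnp.uint8)     # 64 e2m1 values
scales = jnp.full((1, 2), 127, dtype=jnp.uint8)  # e8m0 byte 127 -> scale 1.0
w = dequantize_block_weight(u8_unpack_e2m1(packed), e8m0_to_fp32(scales),
                            MXFP4_BLOCK_SIZE)
print(w.shape, w.dtype)  # (1, 64) bfloat16

In the new code path, by contrast, the weights are re-quantized to float4_e2m1fn with REQUANTIZED_BLOCK_SIZE = 512 and the scale tensors are kept on the layer as parameters, rather than being expanded to bf16 with the scales deleted as the old code did.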