tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Note: this release of tpu-inference has been flagged as potentially problematic.

Files changed (257)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +317 -34
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +406 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +320 -0
  64. tests/layers/vllm/test_unquantized.py +662 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +26 -6
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +110 -12
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +2 -45
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +15 -10
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +92 -8
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +25 -4
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +807 -230
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +218 -137
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +25 -12
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/fused_moe_gmm.py +506 -0
  154. tpu_inference/layers/common/quant_methods.py +15 -0
  155. tpu_inference/layers/common/quantization.py +282 -0
  156. tpu_inference/layers/common/sharding.py +32 -9
  157. tpu_inference/layers/common/utils.py +94 -0
  158. tpu_inference/layers/jax/__init__.py +13 -0
  159. tpu_inference/layers/jax/attention/__init__.py +13 -0
  160. tpu_inference/layers/jax/attention/attention.py +19 -6
  161. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  162. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  163. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  164. tpu_inference/layers/jax/base.py +14 -0
  165. tpu_inference/layers/jax/constants.py +13 -0
  166. tpu_inference/layers/jax/layers.py +14 -0
  167. tpu_inference/layers/jax/misc.py +14 -0
  168. tpu_inference/layers/jax/moe/__init__.py +13 -0
  169. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  170. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  171. tpu_inference/layers/jax/moe/moe.py +43 -3
  172. tpu_inference/layers/jax/pp_utils.py +53 -0
  173. tpu_inference/layers/jax/rope.py +14 -0
  174. tpu_inference/layers/jax/rope_interface.py +14 -0
  175. tpu_inference/layers/jax/sample/__init__.py +13 -0
  176. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  177. tpu_inference/layers/jax/sample/sampling.py +15 -1
  178. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  179. tpu_inference/layers/jax/transformer_block.py +14 -0
  180. tpu_inference/layers/vllm/__init__.py +13 -0
  181. tpu_inference/layers/vllm/attention.py +4 -4
  182. tpu_inference/layers/vllm/fused_moe.py +101 -494
  183. tpu_inference/layers/vllm/linear.py +64 -0
  184. tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
  185. tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
  186. tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
  187. tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
  188. tpu_inference/layers/vllm/quantization/__init__.py +19 -3
  189. tpu_inference/layers/vllm/quantization/awq.py +96 -82
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  191. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +23 -8
  192. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +172 -176
  193. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  194. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
  195. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
  196. tpu_inference/layers/vllm/quantization/{common.py → configs.py} +42 -25
  197. tpu_inference/layers/vllm/quantization/fp8.py +119 -0
  198. tpu_inference/layers/vllm/quantization/mxfp4.py +137 -178
  199. tpu_inference/layers/vllm/quantization/unquantized.py +157 -233
  200. tpu_inference/lora/__init__.py +13 -0
  201. tpu_inference/lora/torch_lora_ops.py +8 -13
  202. tpu_inference/models/__init__.py +13 -0
  203. tpu_inference/models/common/__init__.py +13 -0
  204. tpu_inference/models/common/model_loader.py +112 -35
  205. tpu_inference/models/jax/__init__.py +13 -0
  206. tpu_inference/models/jax/deepseek_v3.py +267 -157
  207. tpu_inference/models/jax/gpt_oss.py +26 -10
  208. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  209. tpu_inference/models/jax/llama3.py +99 -36
  210. tpu_inference/models/jax/llama4.py +14 -0
  211. tpu_inference/models/jax/llama_eagle3.py +18 -5
  212. tpu_inference/models/jax/llama_guard_4.py +15 -1
  213. tpu_inference/models/jax/qwen2.py +17 -2
  214. tpu_inference/models/jax/qwen2_5_vl.py +179 -51
  215. tpu_inference/models/jax/qwen3.py +17 -2
  216. tpu_inference/models/jax/utils/__init__.py +13 -0
  217. tpu_inference/models/jax/utils/file_utils.py +14 -0
  218. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  219. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  220. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +92 -32
  221. tpu_inference/models/jax/utils/weight_utils.py +234 -155
  222. tpu_inference/models/vllm/__init__.py +13 -0
  223. tpu_inference/models/vllm/vllm_model_wrapper.py +32 -8
  224. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  225. tpu_inference/platforms/__init__.py +14 -0
  226. tpu_inference/platforms/tpu_platform.py +51 -72
  227. tpu_inference/runner/__init__.py +13 -0
  228. tpu_inference/runner/compilation_manager.py +180 -80
  229. tpu_inference/runner/kv_cache.py +54 -20
  230. tpu_inference/runner/kv_cache_manager.py +55 -33
  231. tpu_inference/runner/lora_utils.py +16 -1
  232. tpu_inference/runner/multimodal_manager.py +16 -2
  233. tpu_inference/runner/persistent_batch_manager.py +54 -2
  234. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  235. tpu_inference/runner/structured_decoding_manager.py +16 -3
  236. tpu_inference/runner/tpu_runner.py +124 -61
  237. tpu_inference/runner/utils.py +2 -2
  238. tpu_inference/spec_decode/__init__.py +13 -0
  239. tpu_inference/spec_decode/jax/__init__.py +13 -0
  240. tpu_inference/spec_decode/jax/eagle3.py +84 -22
  241. tpu_inference/tpu_info.py +14 -0
  242. tpu_inference/utils.py +72 -44
  243. tpu_inference/worker/__init__.py +13 -0
  244. tpu_inference/worker/tpu_worker.py +66 -52
  245. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +8 -9
  246. tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
  247. tpu_inference/layers/vllm/linear_common.py +0 -186
  248. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  249. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  250. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  251. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  252. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  253. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  254. tpu_inference-0.11.1.dev202511220812.dist-info/RECORD +0 -174
  255. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
  256. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
  257. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0

tpu_inference/layers/vllm/quantization/mxfp4.py

@@ -1,16 +1,29 @@
- from typing import Callable, Optional, Union
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Optional

  import jax
  import jax.numpy as jnp
  import torch
- from jax.experimental.layout import Format, Layout
- from jax.sharding import Mesh, NamedSharding, PartitionSpec
+ from jax.sharding import Mesh, PartitionSpec
  from torch.nn.parameter import Parameter
- from torchax.interop import jax_view, torch_view
+ from torchax.interop import torch_view
  from torchax.ops.mappings import t2j
- from vllm.logger import init_logger
+ from vllm.attention.layer import Attention
  from vllm.model_executor.layers.fused_moe.config import (
-     FusedMoEConfig, FusedMoEQuantConfig, biased_moe_quant_config)
+     FusedMoEConfig, FusedMoEQuantConfig, mxfp4_w4a16_moe_quant_config)
  from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
                                                          FusedMoEMethodBase)
  from vllm.model_executor.layers.linear import LinearBase
@@ -26,48 +39,30 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import \

  from tpu_inference.layers.common.quant_methods import (MXFP4,
                                                          get_tpu_quant_method)
- from tpu_inference.layers.vllm.fused_moe import fused_moe_func_padded
- from tpu_inference.layers.vllm.linear_common import \
-     reorder_concatenated_tensor_for_sharding
- from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
+ from tpu_inference.layers.common.quantization import \
+     dequantize_tensor_from_mxfp4_packed
+ from tpu_inference.layers.common.sharding import ShardingAxisName
+ from tpu_inference.layers.vllm.fused_moe import (FusedMoEBackend,
+                                                  fused_moe_apply,
+                                                  select_moe_backend)
+ from tpu_inference.layers.vllm.process_weights.fused_moe_weights import (
+     FusedMoEWeights, process_moe_weights, quantize_moe_weights,
+     shard_moe_weights)
+ from tpu_inference.layers.vllm.quantization.configs import VllmQuantConfig
  from tpu_inference.layers.vllm.quantization.unquantized import \
      VllmUnquantizedLinearMethod
+ from tpu_inference.logger import init_logger
+ from tpu_inference.utils import get_mesh_shape_product

- MXFP4_BLOCK_SIZE = 32
+ REQUANTIZED_BLOCK_SIZE = 512

  P = PartitionSpec
- logger = init_logger(__name__)
-
-
- # TODO(kyuyeunk): Move these functions into a common utility file.
- def u8_unpack_e2m1(u8_packed_e2m1: jax.Array) -> jax.Array:
-     assert u8_packed_e2m1.dtype == jnp.uint8
-     e2m1 = jax.lax.bitcast_convert_type(u8_packed_e2m1, jnp.float4_e2m1fn)
-     # bitcast creates one more dimension that splits 8 bits into two e2m1.
-     # we flatten them with the last dim.
-     return jnp.reshape(e2m1, e2m1.shape[:-2] + (-1, ))
-
-
- def e8m0_to_fp32(u8: jax.Array) -> jax.Array:
-     e8_finfo = jnp.finfo(jnp.float8_e8m0fnu)
-     exponents = u8.astype(jnp.int32) + e8_finfo.minexp
-     ones = jnp.ones_like(u8, dtype=jnp.float32)
-     return jnp.ldexp(ones, exponents)

-
- def dequantize_block_weight(weight: jax.Array,
-                             scale: jax.Array,
-                             block_size: int,
-                             out_dtype: jnp.dtype = jnp.bfloat16) -> jax.Array:
-     orig_shape = weight.shape
-     weight_block = weight.reshape(orig_shape[:-1] + (-1, block_size))
-     weight_dequantized = weight_block.astype(jnp.float32) * jnp.expand_dims(
-         scale, -1)
-     return weight_dequantized.reshape(orig_shape).astype(out_dtype)
+ logger = init_logger(__name__)


  @register_quantization_config(get_tpu_quant_method(MXFP4))
- class VllmMxfp4Config(Mxfp4Config, JaxCommonConfig):
+ class VllmMxfp4Config(Mxfp4Config, VllmQuantConfig):

      @classmethod
      def get_name(cls):
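
For context, the helpers removed in this hunk implement standard MXFP4 block dequantization: each uint8 packs two e2m1 (fp4) values, every 32-value block shares one e8m0 exponent scale, and dequantization multiplies the unpacked values by their block scale. A minimal standalone sketch combining the three removed helpers is shown below (assuming a JAX build that exposes the float4_e2m1fn and float8_e8m0fnu dtypes; the function name is illustrative and is not the new dequantize_tensor_from_mxfp4_packed API):

    import jax
    import jax.numpy as jnp

    def mxfp4_dequantize(packed_u8: jax.Array,
                         scale_e8m0_u8: jax.Array,
                         block_size: int = 32) -> jax.Array:
        # Unpack: each uint8 holds two e2m1 values; the bitcast adds a trailing
        # dimension of size 2, which is folded back into the last axis.
        e2m1 = jax.lax.bitcast_convert_type(packed_u8, jnp.float4_e2m1fn)
        values = jnp.reshape(e2m1, e2m1.shape[:-2] + (-1, ))

        # Decode the e8m0 scales (biased exponents) into fp32 powers of two.
        minexp = jnp.finfo(jnp.float8_e8m0fnu).minexp
        scales = jnp.ldexp(jnp.ones_like(scale_e8m0_u8, dtype=jnp.float32),
                           scale_e8m0_u8.astype(jnp.int32) + minexp)

        # Apply one scale per block of `block_size` values.
        blocks = values.reshape(values.shape[:-1] + (-1, block_size))
        dequant = blocks.astype(jnp.float32) * scales[..., None]
        return dequant.reshape(values.shape).astype(jnp.bfloat16)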
@@ -75,7 +70,6 @@ class VllmMxfp4Config(Mxfp4Config, JaxCommonConfig):

      def get_quant_method(self, layer: torch.nn.Module,
                           prefix: str) -> Optional["QuantizeMethodBase"]:
-         from vllm.attention.layer import Attention # Avoid circular import

          if isinstance(layer, LinearBase):
              linear_config = self.get_linear_config(layer)
@@ -85,17 +79,14 @@ class VllmMxfp4Config(Mxfp4Config, JaxCommonConfig):
                  fused_mapping=self.packed_modules_mapping,
              ):
                  return VllmUnquantizedLinearMethod(linear_config)
-             # TODO: Add support for MXFP4 Linear Method.
-             # MXFP4 LinearMethod is available in AMD-Quark, refer to that
-             # implementation if you are interested in enabling MXFP4 here.
              logger.warning_once(
                  "MXFP4 linear layer is not implemented - falling back to "
                  "UnquantizedLinearMethod.")
              return VllmUnquantizedLinearMethod(linear_config)
          elif isinstance(layer, FusedMoE):
-             return VllmMxfp4MoEMethod(layer.moe_config, self.mesh)
+             moe_config = self.get_moe_config(layer)
+             return VllmMxfp4MoEMethod(moe_config, self.mesh)
          elif isinstance(layer, Attention):
-             # TODO: Add support for MXFP4 Attention.
              logger.warning_once("MXFP4 attention layer is not implemented. "
                                  "Skipping quantization for this layer.")
              return None
@@ -103,164 +94,132 @@ class VllmMxfp4Config(Mxfp4Config, JaxCommonConfig):

  class VllmMxfp4MoEMethod(Mxfp4MoEMethod):

-     def __init__(self, moe: FusedMoEConfig, mesh: Mesh):
+     def __init__(
+         self,
+         moe: FusedMoEConfig,
+         mesh: Mesh,
+         ep_axis_name: str = "model",
+     ):
          FusedMoEMethodBase.__init__(self, moe)

          # We piggyback on triton implementation as it applies minimal hardware
          # specific post processing to the weights.
          self.mxfp4_backend = Mxfp4Backend.TRITON
+
          self.mesh = mesh
+         self.moe_backend = select_moe_backend(self.moe)
+
+         self.extra_backend_kwargs = {}
+         if self.moe_backend == FusedMoEBackend.FUSED_MOE:
+             # When fused moe kernle is used, we pass extra arguments like
+             # tuned block sizes to the kernel.
+             self.extra_backend_kwargs = dict(
+                 subc_quant_wsz=REQUANTIZED_BLOCK_SIZE,
+                 ep_axis_name=ep_axis_name,
+                 # TODO: Use autotune table once we have it.
+                 bt=256,
+                 bf=1024,
+                 bd1=1024,
+                 bd2=1024,
+                 btc=256,
+                 bfc=1024,
+                 bd1c=1024,
+                 bd2c=1024,
+             )

      def get_fused_moe_quant_config(
              self, layer: torch.nn.Module) -> FusedMoEQuantConfig | None:
-         # Because we have dequantized weights, we only need biased moe config.
-         # TODO(kyuyeunk): Add native support for MXFP4.
-         return biased_moe_quant_config(
-             layer.w13_bias,
-             layer.w2_bias,
+         return mxfp4_w4a16_moe_quant_config(
+             w1_scale=layer.w13_weight_scale,
+             w2_scale=layer.w2_weight_scale,
+             w1_bias=layer.w13_bias,
+             w2_bias=layer.w2_bias,
          )

      def process_weights_after_loading(self, layer: torch.nn.Module):
          assert isinstance(layer, FusedMoE)
+         assert layer.moe_config.has_bias, "mxfp4 quantization alwyas use bias."

-         w13_weight = u8_unpack_e2m1(t2j(layer.w13_weight, use_dlpack=False))
-         w13_weight_scale = e8m0_to_fp32(
-             t2j(layer.w13_weight_scale, use_dlpack=False))
+         w13_weight = t2j(layer.w13_weight, use_dlpack=False)
+         w13_weight_scale = t2j(layer.w13_weight_scale, use_dlpack=False)
          w13_bias = t2j(layer.w13_bias, use_dlpack=False)

-         w2_weight = u8_unpack_e2m1(t2j(layer.w2_weight, use_dlpack=False))
-         w2_weight_scale = e8m0_to_fp32(
-             t2j(layer.w2_weight_scale, use_dlpack=False))
+         w2_weight = t2j(layer.w2_weight, use_dlpack=False)
+         w2_weight_scale = t2j(layer.w2_weight_scale, use_dlpack=False)
          w2_bias = t2j(layer.w2_bias, use_dlpack=False)

-         # We dequantize fp4 weights into bf16.
-         # TODO(kyuyeunk): Add native support for MXFP4.
-         w13_weight = dequantize_block_weight(w13_weight, w13_weight_scale,
-                                              MXFP4_BLOCK_SIZE, jnp.bfloat16)
-         w2_weight = dequantize_block_weight(w2_weight, w2_weight_scale,
-                                             MXFP4_BLOCK_SIZE, jnp.bfloat16)
-
-         # Because we have dequantized weights, scales are not used anymore.
-         delattr(layer, "w13_weight_scale")
-         delattr(layer, "w2_weight_scale")
-
-         if layer.activation == "swigluoai":
-             # When using swigluoai, vLLM splits gmm output in a interleaved way.
-             # However, interleaved split is not performant on TPU. Therefore,
-             # we preprocess the weight so that splitting gmm output by middle
-             # can still get the same result.
-             w1_weight = w13_weight[:, ::2, :]
-             w3_weight = w13_weight[:, 1::2, :]
-             w13_weight = jnp.concat([w1_weight, w3_weight], axis=1)
-
-             w1_bias = w13_bias[:, ::2]
-             w3_bias = w13_bias[:, 1::2]
-             w13_bias = jnp.concat([w1_bias, w3_bias], axis=1)
-
-         # TODO(kyuyeunk): Add weight processing logic for the new kernel.
-         if layer.use_ep:
-             w13_weight = jax.device_put(
-                 w13_weight,
-                 Format(Layout((0, 1, 2)),
-                        NamedSharding(self.mesh, P("model", None, None))))
-             w2_weight = jax.device_put(
-                 w2_weight,
-                 Format(Layout((0, 1, 2)),
-                        NamedSharding(self.mesh, P("model", None, None))))
-
-             w13_bias = jax.device_put(
-                 w13_bias,
-                 Format(Layout((0, 1)),
-                        NamedSharding(self.mesh, P("model", None))))
-             w2_bias = jax.device_put(
-                 w2_bias,
-                 Format(Layout((0, 1)),
-                        NamedSharding(self.mesh, P("model", None))))
-
-         else:
-             intermediate_size = w13_weight.shape[1] // 2
-             assert intermediate_size == w2_weight.shape[-1]
-             output_sizes = [intermediate_size, intermediate_size]
-             n_shards = self.mesh.shape["model"]
-             assert intermediate_size % n_shards == 0
-             w13_weight = reorder_concatenated_tensor_for_sharding(w13_weight,
-                                                                   output_sizes,
-                                                                   n_shards,
-                                                                   dim=1)
-             w13_weight = jax.device_put(
-                 w13_weight,
-                 Format(Layout((0, 1, 2)),
-                        NamedSharding(self.mesh, P(None, "model", None))))
-             w2_weight = jax.device_put(
-                 w2_weight,
-                 Format(Layout((0, 1, 2)),
-                        NamedSharding(self.mesh, P(None, None, "model"))))
-
-             w13_bias = reorder_concatenated_tensor_for_sharding(w13_bias,
-                                                                 output_sizes,
-                                                                 n_shards,
-                                                                 dim=1)
-             w13_bias = jax.device_put(
-                 w13_bias,
-                 Format(Layout((0, 1)),
-                        NamedSharding(self.mesh, P(None, "model"))))
-             w2_bias = jax.device_put(
-                 w2_bias,
-                 Format(Layout((0, 1)), NamedSharding(self.mesh, P(None,
-                                                                   None))))
+         @jax.jit
+         def process_mxfp4_moe_weights(
+             w13_weight: jax.Array,
+             w13_weight_scale: jax.Array,
+             w13_bias: jax.Array,
+             w2_weight: jax.Array,
+             w2_weight_scale: jax.Array,
+             w2_bias: jax.Array,
+         ) -> FusedMoEWeights:
+             # Dequantize fp4 weights into fp32.
+             w13_weight = dequantize_tensor_from_mxfp4_packed(
+                 w13_weight, w13_weight_scale, 2)
+             w2_weight = dequantize_tensor_from_mxfp4_packed(
+                 w2_weight, w2_weight_scale, 2)
+
+             w13_interleave = layer.activation == "swigluoai"
+             w13_reorder_size = get_mesh_shape_product(
+                 self.mesh, ShardingAxisName.MLP_TENSOR)
+
+             weights = quantize_moe_weights(
+                 FusedMoEWeights(
+                     w13_weight=w13_weight,
+                     w13_weight_scale=None,
+                     w13_bias=w13_bias,
+                     w2_weight=w2_weight,
+                     w2_weight_scale=None,
+                     w2_bias=w2_bias,
+                 ),
+                 jnp.float4_e2m1fn,
+                 REQUANTIZED_BLOCK_SIZE,
+             )
+             return process_moe_weights(
+                 weights,
+                 moe_backend=self.moe_backend,
+                 w13_reorder_size=w13_reorder_size,
+                 w13_interleave=w13_interleave,
+             )
+
+         weights = process_mxfp4_moe_weights(
+             w13_weight,
+             w13_weight_scale,
+             w13_bias,
+             w2_weight,
+             w2_weight_scale,
+             w2_bias,
+         )
+         weights = torch_view(
+             shard_moe_weights(weights, self.moe_backend, self.mesh))

-         layer.w13_weight = Parameter(torch_view(w13_weight),
-                                      requires_grad=False)
-         layer.w13_bias = Parameter(torch_view(w13_bias), requires_grad=False)
+         layer.w13_weight = Parameter(weights.w13_weight, requires_grad=False)
+         layer.w2_weight = Parameter(weights.w2_weight, requires_grad=False)

-         layer.w2_weight = Parameter(torch_view(w2_weight), requires_grad=False)
-         layer.w2_bias = Parameter(torch_view(w2_bias), requires_grad=False)
+         layer.w13_weight_scale = Parameter(weights.w13_weight_scale,
+                                            requires_grad=False)
+         layer.w2_weight_scale = Parameter(weights.w2_weight_scale,
+                                           requires_grad=False)

-         pass
+         layer.w13_bias = Parameter(weights.w13_bias, requires_grad=False)
+         layer.w2_bias = Parameter(weights.w2_bias, requires_grad=False)

      def apply(
          self,
          layer: torch.nn.Module,
          x: torch.Tensor,
          router_logits: torch.Tensor,
-         top_k: int,
-         renormalize: bool,
-         use_grouped_topk: bool = False,
-         topk_group: Optional[int] = None,
-         num_expert_group: Optional[int] = None,
-         global_num_experts: int = -1,
-         expert_map: Optional[torch.Tensor] = None,
-         custom_routing_function: Optional[Callable] = None,
-         scoring_func: str = "softmax",
-         routed_scaling_factor: float = 1.0,
-         e_score_correction_bias: Optional[torch.Tensor] = None,
-         apply_router_weight_on_input: bool = False,
-         activation: str = "silu",
-         enable_eplb: bool = False,
-         expert_load_view: Optional[torch.Tensor] = None,
-         logical_to_physical_map: Optional[torch.Tensor] = None,
-         logical_replica_count: Optional[torch.Tensor] = None,
-     ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
-         assert isinstance(layer, FusedMoE)
-         if scoring_func != "softmax":
-             raise NotImplementedError(
-                 "Only softmax is supported for scoring_func")
-
-         # Use the original implementation
-         output = fused_moe_func_padded(
-             jax_view(x),
-             jax_view(layer.w13_weight),
-             jax_view(layer.w2_weight),
-             jax_view(layer.w13_bias) if self.moe.has_bias else None,
-             jax_view(layer.w2_bias) if self.moe.has_bias else None,
-             jax_view(router_logits),
-             topk=top_k,
-             global_num_experts=global_num_experts,
-             renormalize=renormalize,
-             reduce_results=layer.reduce_results,
-             mesh=self.mesh,
-             use_ep=layer.use_ep,
-             activation=activation,
+     ) -> torch.Tensor:
+
+         return fused_moe_apply(
+             layer,
+             x,
+             router_logits,
+             self.moe_backend,
+             self.mesh,
+             self.extra_backend_kwargs,
          )
-
-         return torch_view(output)
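
Taken together, the new process_weights_after_loading path dequantizes the checkpoint's 32-wide MXFP4 blocks to full precision and then requantizes the weights into fp4 with the coarser 512-wide blocks (REQUANTIZED_BLOCK_SIZE) expected by the fused MoE kernel, before backend-specific reordering and sharding. A toy sketch of such block-wise requantization, assuming per-block absmax fp32 scales and a JAX build with the float4_e2m1fn dtype (the real quantize_moe_weights lives in process_weights/fused_moe_weights.py and may use a different scale layout or dtype):

    import jax.numpy as jnp

    def blockwise_quantize_to_fp4(w: jnp.ndarray, block_size: int = 512):
        # Split the last axis into blocks, derive a per-block absmax scale, and
        # cast the scaled values to e2m1 (fp4), whose largest magnitude is 6.0.
        fp4_max = 6.0
        blocks = w.reshape(w.shape[:-1] + (-1, block_size))
        scales = jnp.max(jnp.abs(blocks), axis=-1, keepdims=True) / fp4_max
        scales = jnp.where(scales == 0, 1.0, scales)
        q = (blocks / scales).astype(jnp.float4_e2m1fn)
        return q.reshape(w.shape), jnp.squeeze(scales, -1)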