tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (257)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +317 -34
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +406 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +320 -0
  64. tests/layers/vllm/test_unquantized.py +662 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +26 -6
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +110 -12
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +2 -45
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +15 -10
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +92 -8
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +25 -4
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +807 -230
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +218 -137
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +25 -12
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/fused_moe_gmm.py +506 -0
  154. tpu_inference/layers/common/quant_methods.py +15 -0
  155. tpu_inference/layers/common/quantization.py +282 -0
  156. tpu_inference/layers/common/sharding.py +32 -9
  157. tpu_inference/layers/common/utils.py +94 -0
  158. tpu_inference/layers/jax/__init__.py +13 -0
  159. tpu_inference/layers/jax/attention/__init__.py +13 -0
  160. tpu_inference/layers/jax/attention/attention.py +19 -6
  161. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  162. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  163. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  164. tpu_inference/layers/jax/base.py +14 -0
  165. tpu_inference/layers/jax/constants.py +13 -0
  166. tpu_inference/layers/jax/layers.py +14 -0
  167. tpu_inference/layers/jax/misc.py +14 -0
  168. tpu_inference/layers/jax/moe/__init__.py +13 -0
  169. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  170. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  171. tpu_inference/layers/jax/moe/moe.py +43 -3
  172. tpu_inference/layers/jax/pp_utils.py +53 -0
  173. tpu_inference/layers/jax/rope.py +14 -0
  174. tpu_inference/layers/jax/rope_interface.py +14 -0
  175. tpu_inference/layers/jax/sample/__init__.py +13 -0
  176. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  177. tpu_inference/layers/jax/sample/sampling.py +15 -1
  178. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  179. tpu_inference/layers/jax/transformer_block.py +14 -0
  180. tpu_inference/layers/vllm/__init__.py +13 -0
  181. tpu_inference/layers/vllm/attention.py +4 -4
  182. tpu_inference/layers/vllm/fused_moe.py +101 -494
  183. tpu_inference/layers/vllm/linear.py +64 -0
  184. tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
  185. tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
  186. tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
  187. tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
  188. tpu_inference/layers/vllm/quantization/__init__.py +19 -3
  189. tpu_inference/layers/vllm/quantization/awq.py +96 -82
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  191. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +23 -8
  192. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +172 -176
  193. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  194. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
  195. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
  196. tpu_inference/layers/vllm/quantization/{common.py → configs.py} +42 -25
  197. tpu_inference/layers/vllm/quantization/fp8.py +119 -0
  198. tpu_inference/layers/vllm/quantization/mxfp4.py +137 -178
  199. tpu_inference/layers/vllm/quantization/unquantized.py +157 -233
  200. tpu_inference/lora/__init__.py +13 -0
  201. tpu_inference/lora/torch_lora_ops.py +8 -13
  202. tpu_inference/models/__init__.py +13 -0
  203. tpu_inference/models/common/__init__.py +13 -0
  204. tpu_inference/models/common/model_loader.py +112 -35
  205. tpu_inference/models/jax/__init__.py +13 -0
  206. tpu_inference/models/jax/deepseek_v3.py +267 -157
  207. tpu_inference/models/jax/gpt_oss.py +26 -10
  208. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  209. tpu_inference/models/jax/llama3.py +99 -36
  210. tpu_inference/models/jax/llama4.py +14 -0
  211. tpu_inference/models/jax/llama_eagle3.py +18 -5
  212. tpu_inference/models/jax/llama_guard_4.py +15 -1
  213. tpu_inference/models/jax/qwen2.py +17 -2
  214. tpu_inference/models/jax/qwen2_5_vl.py +179 -51
  215. tpu_inference/models/jax/qwen3.py +17 -2
  216. tpu_inference/models/jax/utils/__init__.py +13 -0
  217. tpu_inference/models/jax/utils/file_utils.py +14 -0
  218. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  219. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  220. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +92 -32
  221. tpu_inference/models/jax/utils/weight_utils.py +234 -155
  222. tpu_inference/models/vllm/__init__.py +13 -0
  223. tpu_inference/models/vllm/vllm_model_wrapper.py +32 -8
  224. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  225. tpu_inference/platforms/__init__.py +14 -0
  226. tpu_inference/platforms/tpu_platform.py +51 -72
  227. tpu_inference/runner/__init__.py +13 -0
  228. tpu_inference/runner/compilation_manager.py +180 -80
  229. tpu_inference/runner/kv_cache.py +54 -20
  230. tpu_inference/runner/kv_cache_manager.py +55 -33
  231. tpu_inference/runner/lora_utils.py +16 -1
  232. tpu_inference/runner/multimodal_manager.py +16 -2
  233. tpu_inference/runner/persistent_batch_manager.py +54 -2
  234. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  235. tpu_inference/runner/structured_decoding_manager.py +16 -3
  236. tpu_inference/runner/tpu_runner.py +124 -61
  237. tpu_inference/runner/utils.py +2 -2
  238. tpu_inference/spec_decode/__init__.py +13 -0
  239. tpu_inference/spec_decode/jax/__init__.py +13 -0
  240. tpu_inference/spec_decode/jax/eagle3.py +84 -22
  241. tpu_inference/tpu_info.py +14 -0
  242. tpu_inference/utils.py +72 -44
  243. tpu_inference/worker/__init__.py +13 -0
  244. tpu_inference/worker/tpu_worker.py +66 -52
  245. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +8 -9
  246. tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
  247. tpu_inference/layers/vllm/linear_common.py +0 -186
  248. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  249. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  250. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  251. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  252. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  253. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  254. tpu_inference-0.11.1.dev202511220812.dist-info/RECORD +0 -174
  255. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
  256. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
  257. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
tests/layers/vllm/test_mxfp4.py (new file)
@@ -0,0 +1,320 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import tempfile
+from unittest import mock
+
+import jax
+import jax.numpy as jnp
+import pytest
+import torch
+import torchax
+from jax._src import test_util as jtu
+from jax.sharding import PartitionSpec
+from torchax.ops.mappings import j2t, t2j
+from vllm.config import ParallelConfig, set_current_vllm_config
+from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
+                                             init_distributed_environment)
+from vllm.engine.arg_utils import EngineArgs
+from vllm.forward_context import set_forward_context
+from vllm.model_executor.layers.fused_moe.layer import FusedMoE
+
+from tpu_inference.layers.vllm.fused_moe import FusedMoEBackend
+from tpu_inference.layers.vllm.quantization import get_tpu_quantization_config
+from tpu_inference.layers.vllm.quantization.mxfp4 import (VllmMxfp4Config,
+                                                          VllmMxfp4MoEMethod)
+
+from . import utils as test_utils
+
+P = PartitionSpec
+MODELS = ["openai/gpt-oss-20b"]
+MXFP4_BLOCK_SIZE = 32
+
+if not jtu.is_device_tpu_at_least(version=7):
+    pytest.skip(allow_module_level=True, reason="Expected TPUv7+")
+
+
+def quantize_to_mxfp4(weight: torch.Tensor):
+    # Utilize JAX because native support for e2m1 makes it easier to work with.
+    weight = t2j(weight)
+    e2m1_finfo = jnp.finfo(jnp.float4_e2m1fn)
+    dtype_min = float(e2m1_finfo.min)
+    dtype_max = float(e2m1_finfo.max)
+
+    # Do a subchannel quantization where block size is 32.
+    weight_shape = weight.shape
+    weight_block = weight.reshape(weight_shape[:-1] + (-1, MXFP4_BLOCK_SIZE))
+    abs_max = jnp.max(jnp.abs(weight_block), axis=-1, keepdims=True)
+    scale = abs_max / dtype_max
+
+    weight_q = jnp.clip(weight_block / scale, dtype_min, dtype_max)
+    weight_q = weight_q.astype(jnp.float4_e2m1fn).reshape(weight_shape[:-1] +
+                                                          (-1, 2))
+    weight_packed = jax.lax.bitcast_convert_type(weight_q, jnp.uint8)
+
+    # We convert scale into e8m0 manually because there is no hardware support.
+    e8m0_finfo = jnp.finfo(jnp.float8_e8m0fnu)
+    _, scale_exp = jnp.frexp(scale.squeeze(axis=-1))
+    # Subtract by one since e8m0 has no decimal
+    scale_exp -= 1
+    scale_exp = (scale_exp - e8m0_finfo.minexp).astype(jnp.uint8)
+
+    return j2t(weight_packed), j2t(scale_exp)
+
+
+@pytest.fixture(autouse=True)
+def setup_environment():
+    # This is a fake config used for init dist env.
+    # RowParallelLinear needs dist env to be initialized.
+    engine_args = EngineArgs(
+        model=MODELS[0],
+        max_model_len=64,
+        max_num_batched_tokens=64,
+        max_num_seqs=4,
+        load_format='dummy',
+    )
+
+    vllm_config = engine_args.create_engine_config()
+
+    with set_current_vllm_config(vllm_config):
+        temp_file = tempfile.mkstemp()[1]
+        init_distributed_environment(
+            1,
+            0,
+            local_rank=0,
+            distributed_init_method=f"file://{temp_file}",
+            backend="gloo")
+        ensure_model_parallel_initialized(1, 1)
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("mesh", [
+    test_utils.get_spmd_mesh(1),
+    test_utils.get_spmd_mesh(jax.local_device_count())
+])
+def test_quant_override(model, mesh):
+
+    engine_args = EngineArgs(
+        model=model,
+        max_model_len=64,
+        max_num_batched_tokens=64,
+        max_num_seqs=4,
+        load_format='dummy',
+    )
+    vllm_config = engine_args.create_engine_config()
+    vllm_config.model_config.dtype = torch.bfloat16
+
+    quant_config = get_tpu_quantization_config(vllm_config, mesh)
+    assert isinstance(quant_config, VllmMxfp4Config)
+    assert quant_config.vllm_config == vllm_config
+    assert quant_config.mesh == mesh
+
+
+@pytest.mark.parametrize("num_devices", [1, 2])
+@pytest.mark.parametrize("num_tokens", [8])
+@pytest.mark.parametrize("intermediate_size", [1024])
+@pytest.mark.parametrize("hidden_size", [128])
+@pytest.mark.parametrize("num_experts", [8])
+@pytest.mark.parametrize("topk", [2])
+@pytest.mark.parametrize("use_ep", [True, False])
+@pytest.mark.parametrize("enable_attn_dp", [False, True])
+def test_mxfp4_fused_moe(num_devices, num_tokens, intermediate_size,
+                         hidden_size, num_experts, topk, use_ep,
+                         enable_attn_dp):
+    # Skip if enable_attn_dp is True but we don't have enough devices
+    if enable_attn_dp and num_devices < 2:
+        pytest.skip("enable_attn_dp requires at least 2 devices")
+
+    mesh = test_utils.get_spmd_mesh(num_devices, enable_attn_dp)
+    torch.manual_seed(42)
+    dtype = torch.bfloat16
+
+    a = torch.randn((num_tokens, hidden_size), dtype=dtype) / 10
+    w1 = torch.randn(
+        (num_experts, 2 * intermediate_size, hidden_size), dtype=dtype) / 10
+    w2 = torch.randn(
+        (num_experts, hidden_size, intermediate_size), dtype=dtype) / 10
+    w1_weight, w1_weight_scale = quantize_to_mxfp4(w1)
+    w2_weight, w2_weight_scale = quantize_to_mxfp4(w2)
+
+    w1_bias = torch.randn(
+        (num_experts, 2 * intermediate_size), dtype=dtype) / 10
+    w2_bias = torch.randn((num_experts, hidden_size), dtype=dtype) / 10
+    score = torch.randn((num_tokens, num_experts), dtype=dtype)
+
+    engine_args = EngineArgs(
+        model=MODELS[0],
+        max_model_len=64,
+        max_num_batched_tokens=64,
+        max_num_seqs=4,
+        load_format='dummy',
+    )
+    vllm_config = engine_args.create_engine_config()
+    vllm_config.model_config.dtype = dtype
+    vllm_config.parallel_config = ParallelConfig(
+        tensor_parallel_size=mesh.devices.size, enable_expert_parallel=use_ep)
+
+    quant_config = get_tpu_quantization_config(vllm_config, mesh)
+    with set_current_vllm_config(vllm_config):
+        vllm_fused_moe = FusedMoE(
+            num_experts=num_experts,
+            top_k=topk,
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            reduce_results=False,
+            renormalize=False,
+            tp_size=1,
+            dp_size=1,
+            quant_config=quant_config,
+            has_bias=True,
+        )
+    vllm_fused_moe.moe_parallel_config.use_ep = use_ep
+    vllm_fused_moe.w13_weight.data = w1_weight
+    vllm_fused_moe.w2_weight.data = w2_weight
+    vllm_fused_moe.w13_weight_scale.data = w1_weight_scale
+    vllm_fused_moe.w2_weight_scale.data = w2_weight_scale
+    vllm_fused_moe.w13_bias.data = w1_bias
+    vllm_fused_moe.w2_bias.data = w2_bias
+
+    expected = test_utils.ref_moe(a, score, w1, w2, w1_bias, w2_bias,
+                                  vllm_fused_moe.top_k,
+                                  vllm_fused_moe.renormalize,
+                                  vllm_fused_moe.activation)
+
+    with torchax.default_env(), set_forward_context(None, vllm_config):
+        assert isinstance(vllm_fused_moe.quant_method, VllmMxfp4MoEMethod)
+        if use_ep:
+            assert vllm_fused_moe.quant_method.moe_backend == FusedMoEBackend.GMM_EP
+        else:
+            assert vllm_fused_moe.quant_method.moe_backend == FusedMoEBackend.GMM_TP
+
+        jax_a = a.to('jax')
+        score = score.to('jax')
+
+        vllm_fused_moe.quant_method.process_weights_after_loading(
+            vllm_fused_moe)
+        actual = vllm_fused_moe(jax_a, score)
+
+    torch.testing.assert_close(expected,
+                               actual,
+                               check_device=False,
+                               atol=1e-1,
+                               rtol=1e-1)
+
+
+@pytest.mark.parametrize("num_devices", [1, 2])
+@pytest.mark.parametrize("num_tokens", [8])
+@pytest.mark.parametrize("intermediate_size", [512])
+@pytest.mark.parametrize("hidden_size", [1024])
+@pytest.mark.parametrize("num_experts", [8])
+@pytest.mark.parametrize("topk", [2])
+@pytest.mark.parametrize("enable_attn_dp", [False, True])
+@mock.patch("os.environ", {"USE_MOE_EP_KERNEL": "1"})
+def test_mxfp4_fused_moe_use_kernel(num_devices, num_tokens, intermediate_size,
+                                    hidden_size, num_experts, topk,
+                                    enable_attn_dp):
+    # Skip if enable_attn_dp is True but we don't have enough devices
+    if enable_attn_dp and num_devices < 2:
+        pytest.skip("enable_attn_dp requires at least 2 devices")
+
+    # Skip attn_dp tests for fused_moe_use_kernel since the kernel only supports 2D mesh
+    if enable_attn_dp:
+        pytest.skip(
+            "fused_moe kernel does not support attn_dp (requires 2D mesh)")
+
+    mesh = test_utils.get_spmd_mesh(num_devices, enable_attn_dp)
+
+    torch.manual_seed(42)
+    dtype = torch.bfloat16
+
+    a = torch.randn((num_tokens, hidden_size), dtype=dtype) / 10
+    w1 = torch.randn(
+        (num_experts, 2 * intermediate_size, hidden_size), dtype=dtype) / 10
+    w2 = torch.randn(
+        (num_experts, hidden_size, intermediate_size), dtype=dtype) / 10
+    w1_weight, w1_weight_scale = quantize_to_mxfp4(w1)
+    w2_weight, w2_weight_scale = quantize_to_mxfp4(w2)
+
+    w1_bias = torch.randn(
+        (num_experts, 2 * intermediate_size), dtype=dtype) / 10
+    w2_bias = torch.randn((num_experts, hidden_size), dtype=dtype) / 10
+    score = torch.randn((num_tokens, num_experts), dtype=dtype)
+
+    engine_args = EngineArgs(
+        model=MODELS[0],
+        max_model_len=64,
+        max_num_batched_tokens=64,
+        max_num_seqs=4,
+        load_format='dummy',
+    )
+    vllm_config = engine_args.create_engine_config()
+    vllm_config.model_config.dtype = dtype
+    vllm_config.parallel_config = ParallelConfig(
+        tensor_parallel_size=mesh.devices.size, enable_expert_parallel=True)
+
+    quant_config = get_tpu_quantization_config(vllm_config, mesh)
+    with set_current_vllm_config(vllm_config):
+        vllm_fused_moe = FusedMoE(
+            num_experts=num_experts,
+            top_k=topk,
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            reduce_results=False,
+            renormalize=False,
+            tp_size=1,
+            dp_size=1,
+            quant_config=quant_config,
+            has_bias=True,
+        )
+    vllm_fused_moe.moe_parallel_config.use_ep = True
+
+    vllm_fused_moe.w13_weight.data = w1_weight
+    vllm_fused_moe.w2_weight.data = w2_weight
+    vllm_fused_moe.w13_weight_scale.data = w1_weight_scale
+    vllm_fused_moe.w2_weight_scale.data = w2_weight_scale
+    vllm_fused_moe.w13_bias.data = w1_bias
+    vllm_fused_moe.w2_bias.data = w2_bias
+
+    expected = test_utils.ref_moe(a, score, w1, w2, w1_bias, w2_bias,
+                                  vllm_fused_moe.top_k,
+                                  vllm_fused_moe.renormalize,
+                                  vllm_fused_moe.activation)
+
+    with torchax.default_env(), set_forward_context(None, vllm_config):
+        assert isinstance(vllm_fused_moe.quant_method, VllmMxfp4MoEMethod)
+        assert vllm_fused_moe.quant_method.moe_backend == FusedMoEBackend.FUSED_MOE
+
+        jax_a = a.to('jax')
+        score = score.to('jax')
+
+        vllm_fused_moe.quant_method.process_weights_after_loading(
+            vllm_fused_moe)
+        vllm_fused_moe.quant_method.extra_backend_kwargs.update({
+            "bt": 32,
+            "bf": 512,
+            "bd1": 1024,
+            "bd2": 1024,
+            "btc": 32,
+            "bfc": 512,
+            "bd1c": 1024,
+            "bd2c": 1024,
+        })
+
+        actual = vllm_fused_moe(jax_a, score)
+
+    torch.testing.assert_close(expected,
+                               actual,
+                               check_device=False,
+                               atol=2e-1,
+                               rtol=2e-1)
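
For readers skimming the new tests/layers/vllm/test_mxfp4.py above: the quantize_to_mxfp4 helper performs MXFP4 block quantization, i.e. each 32-element block stores e2m1 values plus a shared power-of-two (e8m0) scale. Below is a minimal sketch of that arithmetic, not part of the release: it uses only NumPy, the function name mxfp4_block_quantize is hypothetical, and it simplifies the helper by quantizing directly against the rounded power-of-two scale and omitting the final e2m1 cast and uint8 packing.

    # Illustrative only -- not part of tpu-inference. Mirrors the arithmetic of
    # the quantize_to_mxfp4 helper in the diff above, in simplified form.
    import numpy as np

    MXFP4_BLOCK_SIZE = 32   # same block size as the test
    E2M1_MAX = 6.0          # largest finite float4_e2m1fn value


    def mxfp4_block_quantize(block: np.ndarray):
        """Quantize one 32-element block to values in [-6, 6] plus a 2**k scale."""
        assert block.shape == (MXFP4_BLOCK_SIZE,)
        abs_max = np.abs(block).max()
        # Real-valued scale that maps the block's abs-max onto the e2m1 maximum.
        scale = abs_max / E2M1_MAX
        # e8m0 stores only an exponent, so keep just the power-of-two part of the
        # scale (frexp returns scale = m * 2**e with m in [0.5, 1), hence e - 1,
        # the same adjustment as the `scale_exp -= 1` step in the test helper).
        _, exp = np.frexp(scale)
        pow2_scale = 2.0 ** (int(exp) - 1)
        # The test additionally casts to float4_e2m1fn (snapping to the eight
        # representable magnitudes) and packs two values per uint8; omitted here.
        quantized = np.clip(block / pow2_scale, -E2M1_MAX, E2M1_MAX)
        return quantized, pow2_scale


    if __name__ == "__main__":
        rng = np.random.default_rng(0)
        block = rng.standard_normal(MXFP4_BLOCK_SIZE).astype(np.float32) / 10
        q, s = mxfp4_block_quantize(block)
        print("power-of-two scale:", s)
        print("max |q * s - block|:", float(np.abs(q * s - block).max()))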