tpu-inference 0.11.1.dev202512030818__py3-none-any.whl → 0.13.0rc2.post7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (250)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +14 -0
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +405 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +312 -0
  64. tests/layers/vllm/test_unquantized.py +651 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +21 -3
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +78 -1
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +1 -43
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +14 -9
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +38 -7
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +17 -0
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +95 -78
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +26 -19
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/quant_methods.py +15 -0
  154. tpu_inference/layers/common/quantization.py +270 -0
  155. tpu_inference/layers/common/sharding.py +28 -5
  156. tpu_inference/layers/jax/__init__.py +13 -0
  157. tpu_inference/layers/jax/attention/__init__.py +13 -0
  158. tpu_inference/layers/jax/attention/attention.py +19 -6
  159. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  160. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  161. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  162. tpu_inference/layers/jax/base.py +14 -0
  163. tpu_inference/layers/jax/constants.py +13 -0
  164. tpu_inference/layers/jax/layers.py +14 -0
  165. tpu_inference/layers/jax/misc.py +14 -0
  166. tpu_inference/layers/jax/moe/__init__.py +13 -0
  167. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  168. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  169. tpu_inference/layers/jax/moe/moe.py +43 -3
  170. tpu_inference/layers/jax/pp_utils.py +53 -0
  171. tpu_inference/layers/jax/rope.py +14 -0
  172. tpu_inference/layers/jax/rope_interface.py +14 -0
  173. tpu_inference/layers/jax/sample/__init__.py +13 -0
  174. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  175. tpu_inference/layers/jax/sample/sampling.py +15 -1
  176. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  177. tpu_inference/layers/jax/transformer_block.py +14 -0
  178. tpu_inference/layers/vllm/__init__.py +13 -0
  179. tpu_inference/layers/vllm/attention.py +4 -4
  180. tpu_inference/layers/vllm/fused_moe.py +210 -260
  181. tpu_inference/layers/vllm/linear_common.py +57 -22
  182. tpu_inference/layers/vllm/quantization/__init__.py +16 -0
  183. tpu_inference/layers/vllm/quantization/awq.py +15 -1
  184. tpu_inference/layers/vllm/quantization/common.py +33 -18
  185. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  186. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
  187. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
  188. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  189. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
  191. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  192. tpu_inference/layers/vllm/quantization/mxfp4.py +278 -209
  193. tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
  194. tpu_inference/layers/vllm/sharding.py +21 -4
  195. tpu_inference/lora/__init__.py +13 -0
  196. tpu_inference/lora/torch_lora_ops.py +8 -13
  197. tpu_inference/models/__init__.py +13 -0
  198. tpu_inference/models/common/__init__.py +13 -0
  199. tpu_inference/models/common/model_loader.py +74 -35
  200. tpu_inference/models/jax/__init__.py +13 -0
  201. tpu_inference/models/jax/deepseek_v3.py +267 -157
  202. tpu_inference/models/jax/gpt_oss.py +26 -10
  203. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  204. tpu_inference/models/jax/llama3.py +99 -36
  205. tpu_inference/models/jax/llama4.py +14 -0
  206. tpu_inference/models/jax/llama_eagle3.py +14 -0
  207. tpu_inference/models/jax/llama_guard_4.py +15 -1
  208. tpu_inference/models/jax/qwen2.py +17 -2
  209. tpu_inference/models/jax/qwen2_5_vl.py +18 -4
  210. tpu_inference/models/jax/qwen3.py +17 -2
  211. tpu_inference/models/jax/utils/__init__.py +13 -0
  212. tpu_inference/models/jax/utils/file_utils.py +14 -0
  213. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  214. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  215. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +88 -25
  216. tpu_inference/models/jax/utils/weight_utils.py +39 -2
  217. tpu_inference/models/vllm/__init__.py +13 -0
  218. tpu_inference/models/vllm/vllm_model_wrapper.py +20 -3
  219. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  220. tpu_inference/platforms/__init__.py +14 -0
  221. tpu_inference/platforms/tpu_platform.py +47 -64
  222. tpu_inference/runner/__init__.py +13 -0
  223. tpu_inference/runner/compilation_manager.py +72 -37
  224. tpu_inference/runner/kv_cache.py +54 -20
  225. tpu_inference/runner/kv_cache_manager.py +45 -15
  226. tpu_inference/runner/lora_utils.py +14 -0
  227. tpu_inference/runner/multimodal_manager.py +15 -1
  228. tpu_inference/runner/persistent_batch_manager.py +14 -0
  229. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  230. tpu_inference/runner/structured_decoding_manager.py +14 -0
  231. tpu_inference/runner/tpu_runner.py +41 -16
  232. tpu_inference/spec_decode/__init__.py +13 -0
  233. tpu_inference/spec_decode/jax/__init__.py +13 -0
  234. tpu_inference/spec_decode/jax/eagle3.py +13 -0
  235. tpu_inference/tpu_info.py +14 -0
  236. tpu_inference/utils.py +42 -36
  237. tpu_inference/worker/__init__.py +13 -0
  238. tpu_inference/worker/tpu_worker.py +63 -50
  239. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/METADATA +11 -9
  240. tpu_inference-0.13.0rc2.post7.dist-info/RECORD +261 -0
  241. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  242. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  243. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  244. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  245. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  246. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  247. tpu_inference-0.11.1.dev202512030818.dist-info/RECORD +0 -174
  248. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/WHEEL +0 -0
  249. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/licenses/LICENSE +0 -0
  250. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/top_level.txt +0 -0
tests/layers/vllm/test_mxfp4.py (new file)
@@ -0,0 +1,312 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import tempfile
+
+ import jax
+ import jax.numpy as jnp
+ import pytest
+ import torch
+ import torchax
+ from jax._src import test_util as jtu
+ from jax.sharding import PartitionSpec
+ from torchax.ops.mappings import j2t, t2j
+ from vllm.config import ParallelConfig, set_current_vllm_config
+ from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
+                                              init_distributed_environment)
+ from vllm.engine.arg_utils import EngineArgs
+ from vllm.forward_context import set_forward_context
+ from vllm.model_executor.layers.fused_moe.layer import FusedMoE
+
+ from tpu_inference.layers.vllm.quantization import get_tpu_quantization_config
+ from tpu_inference.layers.vllm.quantization.mxfp4 import (VllmMxfp4Config,
+                                                           VllmMxfp4MoEMethod)
+
+ from . import utils as test_utils
+
+ P = PartitionSpec
+ MODELS = ["openai/gpt-oss-20b"]
+ MXFP4_BLOCK_SIZE = 32
+
+ if not jtu.is_device_tpu_at_least(version=7):
+     pytest.skip(allow_module_level=True, reason="Expected TPUv7+")
+
+
+ def quantize_to_mxfp4(weight: torch.Tensor):
+     # Utilize JAX because native support for e2m1 makes it easier to work with.
+     weight = t2j(weight)
+     e2m1_finfo = jnp.finfo(jnp.float4_e2m1fn)
+     dtype_min = float(e2m1_finfo.min)
+     dtype_max = float(e2m1_finfo.max)
+
+     # Do a subchannel quantization where block size is 32.
+     weight_shape = weight.shape
+     weight_block = weight.reshape(weight_shape[:-1] + (-1, MXFP4_BLOCK_SIZE))
+     abs_max = jnp.max(jnp.abs(weight_block), axis=-1, keepdims=True)
+     scale = abs_max / dtype_max
+
+     weight_q = jnp.clip(weight_block / scale, dtype_min, dtype_max)
+     weight_q = weight_q.astype(jnp.float4_e2m1fn).reshape(weight_shape[:-1] +
+                                                           (-1, 2))
+     weight_packed = jax.lax.bitcast_convert_type(weight_q, jnp.uint8)
+
+     # We convert scale into e8m0 manually because there is no hardware support.
+     e8m0_finfo = jnp.finfo(jnp.float8_e8m0fnu)
+     _, scale_exp = jnp.frexp(scale.squeeze(axis=-1))
+     # Subtract by one since e8m0 has no mantissa bits.
+     scale_exp -= 1
+     scale_exp = (scale_exp - e8m0_finfo.minexp).astype(jnp.uint8)
+
+     return j2t(weight_packed), j2t(scale_exp)
+
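Aside (editorial, not part of the shipped file): the function above maps float weights to packed e2m1 nibbles plus biased e8m0 block exponents. A minimal sketch of the inverse — a hypothetical dequantize_mxfp4 helper, assuming the same packing and exponent bias as quantize_to_mxfp4 — makes the layout concrete:

def dequantize_mxfp4(weight_packed: jnp.ndarray, scale_exp: jnp.ndarray):
    # Hypothetical helper for round-trip checks; not in this release.
    # Split each uint8 back into a pair of e2m1 values (inverse of the
    # bitcast above), then regroup into blocks of MXFP4_BLOCK_SIZE.
    weight_q = jax.lax.bitcast_convert_type(weight_packed, jnp.float4_e2m1fn)
    weight = weight_q.astype(jnp.float32).reshape(weight_packed.shape[:-1] +
                                                  (-1, MXFP4_BLOCK_SIZE))
    # Undo the e8m0 bias: the stored byte is (exponent - 1) - minexp, so the
    # recovered block scale is the power of two 2**(stored + minexp).
    minexp = jnp.finfo(jnp.float8_e8m0fnu).minexp
    scale = jnp.exp2(scale_exp.astype(jnp.int32) + minexp)
    return (weight * scale[..., None]).reshape(weight_packed.shape[:-1] +
                                               (-1, ))

Under those assumptions, dequantize_mxfp4(t2j(w1_weight), t2j(w1_weight_scale)) should roughly reconstruct t2j(w1), up to e2m1 rounding and the power-of-two flooring of each block scale.
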
+
+ @pytest.fixture(autouse=True)
+ def setup_environment():
+     # This is a fake config used for init dist env.
+     # RowParallelLinear needs dist env to be initialized.
+     engine_args = EngineArgs(
+         model=MODELS[0],
+         max_model_len=64,
+         max_num_batched_tokens=64,
+         max_num_seqs=4,
+         load_format='dummy',
+     )
+
+     vllm_config = engine_args.create_engine_config()
+
+     with set_current_vllm_config(vllm_config):
+         temp_file = tempfile.mkstemp()[1]
+         init_distributed_environment(
+             1,
+             0,
+             local_rank=0,
+             distributed_init_method=f"file://{temp_file}",
+             backend="gloo")
+         ensure_model_parallel_initialized(1, 1)
+
+
+ @pytest.mark.parametrize("model", MODELS)
+ @pytest.mark.parametrize("mesh", [
+     test_utils.get_spmd_mesh(1),
+     test_utils.get_spmd_mesh(jax.local_device_count())
+ ])
+ def test_quant_override(model, mesh):
+
+     engine_args = EngineArgs(
+         model=model,
+         max_model_len=64,
+         max_num_batched_tokens=64,
+         max_num_seqs=4,
+         load_format='dummy',
+     )
+     vllm_config = engine_args.create_engine_config()
+     vllm_config.model_config.dtype = torch.bfloat16
+
+     quant_config = get_tpu_quantization_config(vllm_config, mesh)
+     assert isinstance(quant_config, VllmMxfp4Config)
+     assert quant_config.vllm_config == vllm_config
+     assert quant_config.mesh == mesh
+
+
+ @pytest.mark.parametrize("num_devices", [1, 2])
+ @pytest.mark.parametrize("num_tokens", [8])
+ @pytest.mark.parametrize("intermediate_size", [1024])
+ @pytest.mark.parametrize("hidden_size", [128])
+ @pytest.mark.parametrize("num_experts", [8])
+ @pytest.mark.parametrize("topk", [2])
+ @pytest.mark.parametrize("use_ep", [True, False])
+ @pytest.mark.parametrize("enable_attn_dp", [False, True])
+ def test_mxfp4_fused_moe(num_devices, num_tokens, intermediate_size,
+                          hidden_size, num_experts, topk, use_ep,
+                          enable_attn_dp):
+     # Skip if enable_attn_dp is True but we don't have enough devices.
+     if enable_attn_dp and num_devices < 2:
+         pytest.skip("enable_attn_dp requires at least 2 devices")
+
+     mesh = test_utils.get_spmd_mesh(num_devices, enable_attn_dp)
+     torch.manual_seed(42)
+     dtype = torch.bfloat16
+
+     a = torch.randn((num_tokens, hidden_size), dtype=dtype) / 10
+     w1 = torch.randn(
+         (num_experts, 2 * intermediate_size, hidden_size), dtype=dtype) / 10
+     w2 = torch.randn(
+         (num_experts, hidden_size, intermediate_size), dtype=dtype) / 10
+     w1_weight, w1_weight_scale = quantize_to_mxfp4(w1)
+     w2_weight, w2_weight_scale = quantize_to_mxfp4(w2)
+
+     w1_bias = torch.randn(
+         (num_experts, 2 * intermediate_size), dtype=dtype) / 10
+     w2_bias = torch.randn((num_experts, hidden_size), dtype=dtype) / 10
+     score = torch.randn((num_tokens, num_experts), dtype=dtype)
+
+     engine_args = EngineArgs(
+         model=MODELS[0],
+         max_model_len=64,
+         max_num_batched_tokens=64,
+         max_num_seqs=4,
+         load_format='dummy',
+     )
+     vllm_config = engine_args.create_engine_config()
+     vllm_config.model_config.dtype = dtype
+
+     quant_config = get_tpu_quantization_config(vllm_config, mesh)
+     with set_current_vllm_config(vllm_config):
+         vllm_fused_moe = FusedMoE(
+             num_experts=num_experts,
+             top_k=topk,
+             hidden_size=hidden_size,
+             intermediate_size=intermediate_size,
+             reduce_results=False,
+             renormalize=False,
+             tp_size=1,
+             dp_size=1,
+             quant_config=quant_config,
+             has_bias=True,
+         )
+     vllm_fused_moe.moe_parallel_config.use_ep = use_ep
+     vllm_fused_moe.w13_weight.data = w1_weight
+     vllm_fused_moe.w2_weight.data = w2_weight
+     vllm_fused_moe.w13_weight_scale.data = w1_weight_scale
+     vllm_fused_moe.w2_weight_scale.data = w2_weight_scale
+     vllm_fused_moe.w13_bias.data = w1_bias
+     vllm_fused_moe.w2_bias.data = w2_bias
+
+     expected = test_utils.ref_moe(a, score, w1, w2, w1_bias, w2_bias,
+                                   vllm_fused_moe.top_k,
+                                   vllm_fused_moe.renormalize,
+                                   vllm_fused_moe.activation)
+
+     with torchax.default_env(), set_forward_context(None, vllm_config):
+         assert isinstance(vllm_fused_moe.quant_method, VllmMxfp4MoEMethod)
+
+         jax_a = a.to('jax')
+         score = score.to('jax')
+
+         vllm_fused_moe.quant_method.process_weights_after_loading(
+             vllm_fused_moe)
+
+         actual = vllm_fused_moe(jax_a, score)
+
+         torch.testing.assert_close(expected,
+                                    actual,
+                                    check_device=False,
+                                    atol=1e-1,
+                                    rtol=1e-1)
+
+
+ @pytest.mark.parametrize("num_devices", [1, 2])
+ @pytest.mark.parametrize("num_tokens", [8])
+ @pytest.mark.parametrize("intermediate_size", [512])
+ @pytest.mark.parametrize("hidden_size", [1024])
+ @pytest.mark.parametrize("num_experts", [8])
+ @pytest.mark.parametrize("topk", [2])
+ @pytest.mark.parametrize("enable_attn_dp", [False, True])
+ def test_mxfp4_fused_moe_use_kernel(num_devices, num_tokens, intermediate_size,
+                                     hidden_size, num_experts, topk,
+                                     enable_attn_dp):
+     # Skip if enable_attn_dp is True but we don't have enough devices.
+     if enable_attn_dp and num_devices < 2:
+         pytest.skip("enable_attn_dp requires at least 2 devices")
+
+     # Skip attn_dp tests for fused_moe_use_kernel since the kernel only
+     # supports a 2D mesh.
+     if enable_attn_dp:
+         pytest.skip(
+             "fused_moe kernel does not support attn_dp (requires 2D mesh)")
+
+     mesh = test_utils.get_spmd_mesh(num_devices, enable_attn_dp)
+
+     torch.manual_seed(42)
+     dtype = torch.bfloat16
+
+     a = torch.randn((num_tokens, hidden_size), dtype=dtype) / 10
+     w1 = torch.randn(
+         (num_experts, 2 * intermediate_size, hidden_size), dtype=dtype) / 10
+     w2 = torch.randn(
+         (num_experts, hidden_size, intermediate_size), dtype=dtype) / 10
+     w1_weight, w1_weight_scale = quantize_to_mxfp4(w1)
+     w2_weight, w2_weight_scale = quantize_to_mxfp4(w2)
+
+     w1_bias = torch.randn(
+         (num_experts, 2 * intermediate_size), dtype=dtype) / 10
+     w2_bias = torch.randn((num_experts, hidden_size), dtype=dtype) / 10
+     score = torch.randn((num_tokens, num_experts), dtype=dtype)
+
+     engine_args = EngineArgs(
+         model=MODELS[0],
+         max_model_len=64,
+         max_num_batched_tokens=64,
+         max_num_seqs=4,
+         load_format='dummy',
+     )
+     vllm_config = engine_args.create_engine_config()
+     vllm_config.model_config.dtype = dtype
+     vllm_config.parallel_config = ParallelConfig(
+         tensor_parallel_size=mesh.devices.size)
+
+     quant_config = get_tpu_quantization_config(vllm_config, mesh)
+     with set_current_vllm_config(vllm_config):
+         vllm_fused_moe = FusedMoE(
+             num_experts=num_experts,
+             top_k=topk,
+             hidden_size=hidden_size,
+             intermediate_size=intermediate_size,
+             reduce_results=False,
+             renormalize=False,
+             tp_size=1,
+             dp_size=1,
+             quant_config=quant_config,
+             has_bias=True,
+         )
+     vllm_fused_moe.moe_parallel_config.use_ep = True
+
+     vllm_fused_moe.w13_weight.data = w1_weight
+     vllm_fused_moe.w2_weight.data = w2_weight
+     vllm_fused_moe.w13_weight_scale.data = w1_weight_scale
+     vllm_fused_moe.w2_weight_scale.data = w2_weight_scale
+     vllm_fused_moe.w13_bias.data = w1_bias
+     vllm_fused_moe.w2_bias.data = w2_bias
+
+     expected = test_utils.ref_moe(a, score, w1, w2, w1_bias, w2_bias,
+                                   vllm_fused_moe.top_k,
+                                   vllm_fused_moe.renormalize,
+                                   vllm_fused_moe.activation)
+
+     with torchax.default_env(), set_forward_context(None, vllm_config):
+         assert isinstance(vllm_fused_moe.quant_method, VllmMxfp4MoEMethod)
+
+         jax_a = a.to('jax')
+         score = score.to('jax')
+
+         vllm_fused_moe.quant_method.use_kernel = True
+         vllm_fused_moe.quant_method.process_weights_after_loading(
+             vllm_fused_moe)
+         vllm_fused_moe.quant_method.block_size = {
+             "bt": 32,
+             "bf": 512,
+             "bd1": 1024,
+             "bd2": 1024,
+             "btc": 32,
+             "bfc": 512,
+             "bd1c": 1024,
+             "bd2c": 1024,
+         }
+
+         actual = vllm_fused_moe(jax_a, score)
+
+         torch.testing.assert_close(expected,
+                                    actual,
+                                    check_device=False,
+                                    atol=2e-1,
+                                    rtol=2e-1)
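
Note (editorial): both tests compare the TPU path against test_utils.ref_moe from tests/layers/vllm/utils.py, which is not shown in this hunk. For orientation only — the real helper may differ in gate/up ordering and activation handling, and here the activation argument is dropped and hardcoded to SiLU gating — a dense top-k MoE reference over the same arguments looks roughly like:

import torch
import torch.nn.functional as F

def ref_moe_sketch(a, score, w1, w2, w1_bias, w2_bias, top_k, renormalize):
    # Hypothetical stand-in for test_utils.ref_moe; routing math is the
    # standard softmax-then-top-k scheme, not copied from the wheel.
    probs = torch.softmax(score.float(), dim=-1)
    topk_w, topk_ids = torch.topk(probs, top_k, dim=-1)
    if renormalize:
        topk_w = topk_w / topk_w.sum(dim=-1, keepdim=True)
    out = torch.zeros(a.shape, dtype=torch.float32)
    for t in range(a.shape[0]):
        for gate_w, e in zip(topk_w[t], topk_ids[t]):
            # w1 stacks gate and up projections: (2 * intermediate, hidden).
            h = a[t].float() @ w1[e].float().T + w1_bias[e].float()
            gate, up = h.chunk(2, dim=-1)
            h = F.silu(gate) * up
            h = h @ w2[e].float().T + w2_bias[e].float()
            out[t] += gate_w * h
    return out.to(a.dtype)

The loose tolerances in the asserts above (1e-1, and 2e-1 for the kernel path) reflect that the quantized TPU result is only expected to track such a reference approximately.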