tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic. More details are available on the original diff page.

Files changed (257)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +317 -34
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +406 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +320 -0
  64. tests/layers/vllm/test_unquantized.py +662 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +26 -6
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +110 -12
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +2 -45
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +15 -10
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +92 -8
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +25 -4
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +807 -230
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +218 -137
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +25 -12
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/fused_moe_gmm.py +506 -0
  154. tpu_inference/layers/common/quant_methods.py +15 -0
  155. tpu_inference/layers/common/quantization.py +282 -0
  156. tpu_inference/layers/common/sharding.py +32 -9
  157. tpu_inference/layers/common/utils.py +94 -0
  158. tpu_inference/layers/jax/__init__.py +13 -0
  159. tpu_inference/layers/jax/attention/__init__.py +13 -0
  160. tpu_inference/layers/jax/attention/attention.py +19 -6
  161. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  162. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  163. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  164. tpu_inference/layers/jax/base.py +14 -0
  165. tpu_inference/layers/jax/constants.py +13 -0
  166. tpu_inference/layers/jax/layers.py +14 -0
  167. tpu_inference/layers/jax/misc.py +14 -0
  168. tpu_inference/layers/jax/moe/__init__.py +13 -0
  169. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  170. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  171. tpu_inference/layers/jax/moe/moe.py +43 -3
  172. tpu_inference/layers/jax/pp_utils.py +53 -0
  173. tpu_inference/layers/jax/rope.py +14 -0
  174. tpu_inference/layers/jax/rope_interface.py +14 -0
  175. tpu_inference/layers/jax/sample/__init__.py +13 -0
  176. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  177. tpu_inference/layers/jax/sample/sampling.py +15 -1
  178. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  179. tpu_inference/layers/jax/transformer_block.py +14 -0
  180. tpu_inference/layers/vllm/__init__.py +13 -0
  181. tpu_inference/layers/vllm/attention.py +4 -4
  182. tpu_inference/layers/vllm/fused_moe.py +101 -494
  183. tpu_inference/layers/vllm/linear.py +64 -0
  184. tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
  185. tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
  186. tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
  187. tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
  188. tpu_inference/layers/vllm/quantization/__init__.py +19 -3
  189. tpu_inference/layers/vllm/quantization/awq.py +96 -82
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  191. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +23 -8
  192. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +172 -176
  193. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  194. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
  195. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
  196. tpu_inference/layers/vllm/quantization/{common.py → configs.py} +42 -25
  197. tpu_inference/layers/vllm/quantization/fp8.py +119 -0
  198. tpu_inference/layers/vllm/quantization/mxfp4.py +137 -178
  199. tpu_inference/layers/vllm/quantization/unquantized.py +157 -233
  200. tpu_inference/lora/__init__.py +13 -0
  201. tpu_inference/lora/torch_lora_ops.py +8 -13
  202. tpu_inference/models/__init__.py +13 -0
  203. tpu_inference/models/common/__init__.py +13 -0
  204. tpu_inference/models/common/model_loader.py +112 -35
  205. tpu_inference/models/jax/__init__.py +13 -0
  206. tpu_inference/models/jax/deepseek_v3.py +267 -157
  207. tpu_inference/models/jax/gpt_oss.py +26 -10
  208. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  209. tpu_inference/models/jax/llama3.py +99 -36
  210. tpu_inference/models/jax/llama4.py +14 -0
  211. tpu_inference/models/jax/llama_eagle3.py +18 -5
  212. tpu_inference/models/jax/llama_guard_4.py +15 -1
  213. tpu_inference/models/jax/qwen2.py +17 -2
  214. tpu_inference/models/jax/qwen2_5_vl.py +179 -51
  215. tpu_inference/models/jax/qwen3.py +17 -2
  216. tpu_inference/models/jax/utils/__init__.py +13 -0
  217. tpu_inference/models/jax/utils/file_utils.py +14 -0
  218. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  219. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  220. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +92 -32
  221. tpu_inference/models/jax/utils/weight_utils.py +234 -155
  222. tpu_inference/models/vllm/__init__.py +13 -0
  223. tpu_inference/models/vllm/vllm_model_wrapper.py +32 -8
  224. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  225. tpu_inference/platforms/__init__.py +14 -0
  226. tpu_inference/platforms/tpu_platform.py +51 -72
  227. tpu_inference/runner/__init__.py +13 -0
  228. tpu_inference/runner/compilation_manager.py +180 -80
  229. tpu_inference/runner/kv_cache.py +54 -20
  230. tpu_inference/runner/kv_cache_manager.py +55 -33
  231. tpu_inference/runner/lora_utils.py +16 -1
  232. tpu_inference/runner/multimodal_manager.py +16 -2
  233. tpu_inference/runner/persistent_batch_manager.py +54 -2
  234. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  235. tpu_inference/runner/structured_decoding_manager.py +16 -3
  236. tpu_inference/runner/tpu_runner.py +124 -61
  237. tpu_inference/runner/utils.py +2 -2
  238. tpu_inference/spec_decode/__init__.py +13 -0
  239. tpu_inference/spec_decode/jax/__init__.py +13 -0
  240. tpu_inference/spec_decode/jax/eagle3.py +84 -22
  241. tpu_inference/tpu_info.py +14 -0
  242. tpu_inference/utils.py +72 -44
  243. tpu_inference/worker/__init__.py +13 -0
  244. tpu_inference/worker/tpu_worker.py +66 -52
  245. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +8 -9
  246. tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
  247. tpu_inference/layers/vllm/linear_common.py +0 -186
  248. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  249. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  250. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  251. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  252. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  253. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  254. tpu_inference-0.11.1.dev202511220812.dist-info/RECORD +0 -174
  255. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
  256. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
  257. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
@@ -1,7 +1,21 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
1
15
  import jax
2
16
  import jax.numpy as jnp
3
17
  import numpy as np
4
- from absl.testing import absltest
18
+ from absl.testing import absltest, parameterized
5
19
  from jax._src import test_util as jtu
6
20
  from jax.sharding import Mesh
7
21
 
@@ -10,6 +24,15 @@ from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe, ref_moe
10
24
  jax.config.parse_flags_with_absl()
11
25
 
12
26
 
27
+ def cdiv(a, b):
28
+ assert b != 0
29
+ return (a + b - 1) // b
30
+
31
+
32
+ def align_to(x, a):
33
+ return cdiv(x, a) * a
34
+
35
+
13
36
  def gen_moe_inputs(
14
37
  dtype,
15
38
  top_k,
@@ -19,11 +42,14 @@ def gen_moe_inputs(
19
42
  num_tokens,
20
43
  *,
21
44
  seed=1234,
45
+ has_bias=False,
22
46
  ):
23
47
  key = jax.random.key(seed)
24
- k0, k1, k2, k4, k5 = jax.random.split(key, 5)
48
+ k0, k1, k2, k3, k4, k5, k6 = jax.random.split(key, 7)
49
+
25
50
  a = jax.random.normal(k0, (num_tokens, hidden_size),
26
51
  dtype=jnp.float32).astype(dtype) / 10
52
+
27
53
  w1 = (jax.random.normal(
28
54
  k1,
29
55
  (num_experts, 2, hidden_size, intermediate_size),
@@ -31,21 +57,54 @@ def gen_moe_inputs(
31
57
  ) / 10).astype(dtype)
32
58
  w2 = (jax.random.normal(k2, (num_experts, intermediate_size, hidden_size),
33
59
  dtype=jnp.float32) / 10).astype(dtype)
60
+
61
+ if has_bias:
62
+ b1 = (jax.random.normal(k3, (num_experts, 2, intermediate_size),
63
+ dtype=jnp.float32) / 10).astype(dtype)
64
+ b2 = (jax.random.normal(k4, (num_experts, hidden_size),
65
+ dtype=jnp.float32) / 10).astype(dtype)
66
+ else:
67
+ b1 = b2 = None
68
+
34
69
  gating_output = (
35
- jax.random.normal(k4, (num_tokens, num_experts), dtype=jnp.float32) +
70
+ jax.random.normal(k5, (num_tokens, num_experts), dtype=jnp.float32) +
36
71
  jnp.arange(num_tokens * num_experts, dtype=jnp.float32).reshape(
37
72
  num_tokens, num_experts) / 100)
73
+
38
74
  # To generate unique top-k!
39
- top_k_indices = jax.random.randint(k5, (num_tokens, top_k),
75
+ top_k_indices = jax.random.randint(k6, (num_tokens, top_k),
40
76
  minval=0,
41
77
  maxval=num_experts - 1,
42
78
  dtype=jnp.int32)
79
+
43
80
  one_hot = (jnp.sum(
44
81
  jax.nn.one_hot(top_k_indices, num_experts, dtype=jnp.float32),
45
82
  axis=1,
46
- ) * 10)
83
+ ) * 30)
84
+
47
85
  gating_output = (gating_output + one_hot).astype(dtype)
48
- return a, w1, w2, gating_output
86
+
87
+ return a, w1, w2, b1, b2, gating_output
88
+
89
+
90
+ def sub_channel_quantize(x, quant_dtype, wsz=256):
91
+ """Quantizes x with sub-channel quantization on the 2nd minor."""
92
+ if jnp.issubdtype(quant_dtype, jnp.floating):
93
+ dtype_info = jnp.finfo(quant_dtype)
94
+ else:
95
+ dtype_info = jnp.iinfo(quant_dtype)
96
+ dtype_max = float(dtype_info.max)
97
+ w_lst, scale_lst = [], []
98
+ assert len(x.shape) >= 2
99
+ assert x.shape[-2] % wsz == 0
100
+ for i in range(0, x.shape[-2], wsz):
101
+ y = x[..., i:i + wsz, :]
102
+ abs_max = jnp.abs(y).max(axis=-2, keepdims=True)
103
+ scale = (abs_max / dtype_max).astype(jnp.float32)
104
+ w = (y / scale).astype(quant_dtype)
105
+ w_lst.append(w)
106
+ scale_lst.append(scale)
107
+ return jnp.concat(w_lst, axis=-2), jnp.concat(scale_lst, axis=-2)
49
108
 
50
109
 
51
110
  @jtu.with_config(jax_numpy_dtype_promotion="standard")
@@ -63,42 +122,266 @@ class MoEKernelTest(jtu.JaxTestCase):
63
122
  self.mesh = Mesh(np.array(self.mesh_devices).reshape(1, -1),
64
123
  axis_names=("data", "model"))
65
124
 
66
- def test_basic(self):
67
- dtype = jnp.bfloat16
68
- top_k = 2
69
- num_experts = 16
70
- hidden_size = 256
71
- intermediate_size = 256
72
- num_tokens = 8 * 2
73
-
74
- a, w1, w2, gating_output = gen_moe_inputs(
125
+ def _test_moe(
126
+ self,
127
+ dtype,
128
+ top_k,
129
+ num_experts,
130
+ hidden_size,
131
+ intermediate_size,
132
+ num_tokens,
133
+ seed,
134
+ renormalize_topk_logits,
135
+ bt,
136
+ bf,
137
+ bd1,
138
+ bd2,
139
+ btc,
140
+ bfc,
141
+ bd1c,
142
+ bd2c,
143
+ act_fn="silu",
144
+ w_dtype=None,
145
+ subc_quant_wsz=None,
146
+ has_bias=False,
147
+ atol=2e-1,
148
+ rtol=2e-1,
149
+ ):
150
+ a, w1, w2, b1, b2, gating_output = gen_moe_inputs(
75
151
  dtype,
76
152
  top_k,
77
153
  num_experts,
78
154
  hidden_size,
79
155
  intermediate_size,
80
156
  num_tokens,
157
+ seed=seed,
158
+ has_bias=has_bias,
159
+ )
160
+ w1_scale = None
161
+ w2_scale = None
162
+ if w_dtype is not None:
163
+ if subc_quant_wsz is None:
164
+ subc_quant_wsz = 256
165
+ w1, w1_scale = sub_channel_quantize(w1, w_dtype, subc_quant_wsz)
166
+ w2, w2_scale = sub_channel_quantize(w2, w_dtype, subc_quant_wsz)
167
+
168
+ actual = fused_ep_moe(
169
+ mesh=self.mesh,
170
+ tokens=a,
171
+ w1=w1,
172
+ w2=w2,
173
+ gating_output=gating_output,
174
+ top_k=top_k,
175
+ renormalize_topk_logits=renormalize_topk_logits,
176
+ act_fn=act_fn,
177
+ subc_quant_wsz=subc_quant_wsz,
178
+ w1_scale=w1_scale,
179
+ w2_scale=w2_scale,
180
+ b1=b1,
181
+ b2=b2,
182
+ bt=bt,
183
+ bf=bf,
184
+ bd1=bd1,
185
+ bd2=bd2,
186
+ btc=btc,
187
+ bfc=bfc,
188
+ bd1c=bd1c,
189
+ bd2c=bd2c,
190
+ )
191
+ expected = ref_moe(
192
+ a,
193
+ w1,
194
+ w2,
195
+ gating_output,
196
+ top_k,
197
+ b1=b1,
198
+ b2=b2,
199
+ renormalize_topk_logits=renormalize_topk_logits,
200
+ activation=act_fn,
201
+ subc_quant_wsz=subc_quant_wsz,
202
+ w1_scale=w1_scale,
203
+ w2_scale=w2_scale,
204
+ )
205
+ self.assertAllClose(actual, expected, atol=atol, rtol=rtol)
206
+
207
+ @parameterized.product(renormalize_topk_logits=[True, False], )
208
+ def test_basic(self, renormalize_topk_logits):
209
+ dtype = jnp.bfloat16
210
+ top_k = 8
211
+ num_experts = 128
212
+ hidden_size = 1024
213
+ intermediate_size = 1024
214
+ num_tokens = 8 * 32
215
+ self._test_moe(
216
+ dtype=dtype,
217
+ top_k=top_k,
218
+ num_experts=num_experts,
219
+ hidden_size=hidden_size,
220
+ intermediate_size=intermediate_size,
221
+ num_tokens=num_tokens,
222
+ seed=1234,
223
+ renormalize_topk_logits=renormalize_topk_logits,
224
+ bt=32,
225
+ bf=1024,
226
+ bd1=1024,
227
+ bd2=1024,
228
+ btc=32,
229
+ bfc=256,
230
+ bd1c=256,
231
+ bd2c=256,
81
232
  )
82
233
 
83
- actual = jax.block_until_ready(
84
- fused_ep_moe(
85
- mesh=self.mesh,
86
- tokens=a,
87
- w1=w1,
88
- w2=w2,
89
- gating_output=gating_output,
90
- top_k=top_k,
91
- bt=32,
92
- bf=512,
93
- bd1=512,
94
- bd2=512,
95
- btc=32,
96
- bfc=256,
97
- bd1c=256,
98
- bd2c=256,
99
- ))
100
- expected = ref_moe(a, w1, w2, gating_output, top_k)
101
- self.assertAllClose(expected, actual, atol=2e-2, rtol=2e-2)
234
+ @parameterized.product(act_fn=["silu", "gelu", "swigluoai"], )
235
+ def test_activation(self, act_fn):
236
+ dtype = jnp.bfloat16
237
+ top_k = 8
238
+ num_experts = 128
239
+ hidden_size = 1024
240
+ intermediate_size = 1024
241
+ num_tokens = 8 * 32
242
+ self._test_moe(
243
+ dtype=dtype,
244
+ top_k=top_k,
245
+ num_experts=num_experts,
246
+ hidden_size=hidden_size,
247
+ intermediate_size=intermediate_size,
248
+ num_tokens=num_tokens,
249
+ seed=1234,
250
+ renormalize_topk_logits=True,
251
+ act_fn=act_fn,
252
+ bt=32,
253
+ bf=512,
254
+ bd1=512,
255
+ bd2=512,
256
+ btc=32,
257
+ bfc=256,
258
+ bd1c=256,
259
+ bd2c=256,
260
+ )
261
+
262
+ def test_benchmark_qwen_235(self):
263
+ num_experts = 128
264
+ top_k = 8
265
+ hidden_size = 4096
266
+ intermediate_size = 1536
267
+ dtype = jnp.bfloat16
268
+ num_tokens = 8 * 64
269
+ seed = 54321
270
+ renormalize_topk_logits = True
271
+ self._test_moe(
272
+ dtype=dtype,
273
+ top_k=top_k,
274
+ num_experts=num_experts,
275
+ hidden_size=hidden_size,
276
+ intermediate_size=intermediate_size,
277
+ num_tokens=num_tokens,
278
+ seed=seed,
279
+ renormalize_topk_logits=renormalize_topk_logits,
280
+ bt=64,
281
+ bf=768,
282
+ bd1=2048,
283
+ bd2=2048,
284
+ btc=64,
285
+ bfc=768,
286
+ bd1c=2048,
287
+ bd2c=2048,
288
+ act_fn="silu",
289
+ atol=5e-2,
290
+ rtol=5e-2,
291
+ )
292
+
293
+ def test_benchmark_qwen_30b_a3b(self):
294
+ num_experts = 128
295
+ top_k = 8
296
+ hidden_size = 2048
297
+ intermediate_size = 768
298
+ dtype = jnp.bfloat16
299
+ num_tokens = 512
300
+ seed = 54321
301
+ renormalize_topk_logits = True
302
+ self._test_moe(
303
+ dtype=dtype,
304
+ top_k=top_k,
305
+ num_experts=num_experts,
306
+ hidden_size=hidden_size,
307
+ intermediate_size=intermediate_size,
308
+ num_tokens=num_tokens,
309
+ seed=seed,
310
+ renormalize_topk_logits=renormalize_topk_logits,
311
+ bt=16,
312
+ bf=384,
313
+ bd1=512,
314
+ bd2=512,
315
+ btc=16,
316
+ bfc=384,
317
+ bd1c=256,
318
+ bd2c=256,
319
+ act_fn="silu",
320
+ atol=5e-2,
321
+ rtol=5e-2,
322
+ )
323
+
324
+ @parameterized.product(
325
+ w_dtype=[jnp.int8, jnp.float8_e5m2, jnp.float4_e2m1fn], )
326
+ def test_sub_channel_quantization(self, w_dtype):
327
+ if w_dtype in (
328
+ jnp.float8_e5m2,
329
+ jnp.float4_e2m1fn,
330
+ ) and not jtu.is_device_tpu_at_least(version=7):
331
+ self.skipTest("Expect TPUv7+")
332
+ dtype = jnp.bfloat16
333
+ top_k = 8
334
+ num_experts = 128
335
+ hidden_size = 1024
336
+ intermediate_size = 1024
337
+ num_tokens = 8 * 32
338
+ self._test_moe(
339
+ dtype=dtype,
340
+ top_k=top_k,
341
+ num_experts=num_experts,
342
+ hidden_size=hidden_size,
343
+ intermediate_size=intermediate_size,
344
+ num_tokens=num_tokens,
345
+ seed=1234,
346
+ renormalize_topk_logits=False,
347
+ w_dtype=w_dtype,
348
+ subc_quant_wsz=256,
349
+ bt=32,
350
+ bf=1024,
351
+ bd1=1024,
352
+ bd2=1024,
353
+ btc=32,
354
+ bfc=256,
355
+ bd1c=256,
356
+ bd2c=256,
357
+ )
358
+
359
+ def test_bias(self):
360
+ dtype = jnp.bfloat16
361
+ top_k = 8
362
+ num_experts = 128
363
+ hidden_size = 1024
364
+ intermediate_size = 1024
365
+ num_tokens = 8 * 32
366
+ self._test_moe(
367
+ dtype=dtype,
368
+ top_k=top_k,
369
+ num_experts=num_experts,
370
+ hidden_size=hidden_size,
371
+ intermediate_size=intermediate_size,
372
+ num_tokens=num_tokens,
373
+ seed=1234,
374
+ renormalize_topk_logits=False,
375
+ has_bias=True,
376
+ bt=32,
377
+ bf=512,
378
+ bd1=512,
379
+ bd2=512,
380
+ btc=32,
381
+ bfc=256,
382
+ bd1c=256,
383
+ bd2c=256,
384
+ )
102
385
 
103
386
 
104
387
  if __name__ == "__main__":
@@ -0,0 +1,205 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import jax
16
+ import jax.numpy as jnp
17
+ from absl.testing import absltest, parameterized
18
+ from jax._src import test_util as jtu
19
+
20
+ from tpu_inference.kernels.megablox.gmm import gmm
21
+
22
+ jax.config.parse_flags_with_absl()
23
+
24
+
25
+ def quantize_tensor(x: jax.Array,
26
+ dtype: jnp.dtype,
27
+ axis: int = -1,
28
+ block_size: int = 256):
29
+ if jnp.issubdtype(dtype, jnp.integer):
30
+ dtype_info = jnp.iinfo(dtype)
31
+ max_val = int(dtype_info.max)
32
+ min_val = int(dtype_info.min)
33
+ else:
34
+ dtype_info = jnp.finfo(dtype)
35
+ max_val = float(dtype_info.max)
36
+ min_val = float(dtype_info.min)
37
+
38
+ orig_shape = x.shape
39
+ blocked_shape = orig_shape[:axis] + (-1,
40
+ block_size) + orig_shape[axis + 1:]
41
+ x_blocked = x.reshape(blocked_shape)
42
+
43
+ x_blocked_abs_max = jnp.max(jnp.abs(x_blocked),
44
+ axis=axis + 1,
45
+ keepdims=True)
46
+ scale = x_blocked_abs_max / max_val
47
+ x_blocked_q = jnp.clip(x_blocked / scale, min_val, max_val).astype(dtype)
48
+
49
+ x_q = x_blocked_q.reshape(orig_shape)
50
+ scale = scale.squeeze(axis=axis + 1).astype(jnp.float32)
51
+ return x_q, scale
52
+
53
+
54
+ def reference_gmm(
55
+ lhs: jax.Array,
56
+ rhs: jax.Array,
57
+ group_sizes: jax.Array,
58
+ rhs_scale: jax.Array | None = None,
59
+ rhs_bias: jax.Array | None = None,
60
+ group_offset: jax.Array | None = None,
61
+ ):
62
+ num_groups, out_size, in_size = rhs.shape
63
+ assert lhs.shape[1] == in_size
64
+
65
+ if group_offset is None:
66
+ group_offset = jnp.array(0, dtype=jnp.int32)
67
+ start = group_sizes[:group_offset].sum()
68
+ group_sizes = group_sizes[group_offset:]
69
+ assert len(group_sizes) == num_groups
70
+
71
+ if rhs_scale is not None:
72
+ num_blocks = rhs_scale.shape[1]
73
+ else:
74
+ num_blocks = 1
75
+ block_size = in_size // num_blocks
76
+
77
+ gmm_out = [jnp.zeros((start, out_size), lhs.dtype)]
78
+ for group in range(num_groups):
79
+ end = start + group_sizes[group]
80
+
81
+ lhs_slice = lhs[start:end]
82
+ rhs_slice = rhs[group]
83
+
84
+ out = 0
85
+ for block in range(num_blocks):
86
+ block_start = block * block_size
87
+ block_end = block_start + block_size
88
+ lhs_block = lhs_slice[:, block_start:block_end].astype(jnp.float32)
89
+ rhs_block = rhs_slice[:, block_start:block_end].astype(jnp.float32)
90
+
91
+ acc = jnp.einsum("bd,hd->bh", lhs_block, rhs_block)
92
+ if rhs_scale is not None:
93
+ acc *= rhs_scale[group][block]
94
+ out += acc
95
+ if rhs_bias is not None:
96
+ out = out + rhs_bias[group]
97
+
98
+ gmm_out.append(out.astype(lhs.dtype))
99
+ start = end
100
+
101
+ return jnp.concat(gmm_out, axis=0)
102
+
103
+
104
+ @jtu.with_config(jax_numpy_dtype_promotion="standard")
105
+ class GmmTest(jtu.JaxTestCase):
106
+
107
+ @parameterized.product(
108
+ batch_size=[128],
109
+ in_size=[1024],
110
+ out_size=[1024],
111
+ num_groups=[16, 32],
112
+ has_bias=[True, False],
113
+ )
114
+ def test_gmm(self, batch_size, in_size, out_size, num_groups, has_bias):
115
+ key = jax.random.key(0)
116
+
117
+ lhs = jax.random.normal(key, (batch_size, in_size), dtype=jnp.bfloat16)
118
+ rhs = jax.random.normal(key, (num_groups, out_size, in_size),
119
+ dtype=jnp.bfloat16)
120
+ rhs_bias = None
121
+ if has_bias:
122
+ rhs_bias = jax.random.normal(key, (num_groups, 1, out_size),
123
+ dtype=jnp.bfloat16)
124
+
125
+ group_sizes = jax.random.randint(key, (num_groups, ),
126
+ 0,
127
+ batch_size,
128
+ dtype=jnp.int32)
129
+
130
+ expected = reference_gmm(lhs, rhs, group_sizes, rhs_bias=rhs_bias)
131
+ actual = gmm(
132
+ lhs,
133
+ rhs,
134
+ group_sizes,
135
+ rhs_bias=rhs_bias,
136
+ transpose_rhs=True,
137
+ preferred_element_type=jnp.bfloat16,
138
+ )
139
+
140
+ self.assertArraysAllClose(actual, expected)
141
+
142
+ @parameterized.product(
143
+ batch_size=[128],
144
+ in_size=[1024],
145
+ out_size=[1024],
146
+ num_groups=[16, 32],
147
+ has_bias=[True, False],
148
+ weight_dtype=[jnp.int8, jnp.float8_e5m2, jnp.float4_e2m1fn],
149
+ block_size=[256, 512],
150
+ )
151
+ def test_gmm_weight_quantized(
152
+ self,
153
+ batch_size,
154
+ in_size,
155
+ out_size,
156
+ num_groups,
157
+ has_bias,
158
+ weight_dtype,
159
+ block_size,
160
+ ):
161
+ if weight_dtype == jnp.float4_e2m1fn and not jtu.is_device_tpu_at_least(
162
+ version=7):
163
+ self.skipTest("Expect TPUv7+")
164
+ key = jax.random.key(0)
165
+
166
+ lhs = jax.random.normal(key, (batch_size, in_size), dtype=jnp.bfloat16)
167
+ rhs = jax.random.normal(key, (num_groups, out_size, in_size),
168
+ dtype=jnp.bfloat16)
169
+ rhs_q, rhs_scale = quantize_tensor(rhs,
170
+ weight_dtype,
171
+ axis=2,
172
+ block_size=block_size)
173
+ rhs_scale = jnp.swapaxes(rhs_scale, 1, 2)
174
+ rhs_scale = jnp.expand_dims(rhs_scale, axis=2)
175
+
176
+ rhs_bias = None
177
+ if has_bias:
178
+ rhs_bias = jax.random.normal(key, (num_groups, 1, out_size),
179
+ dtype=jnp.bfloat16)
180
+
181
+ group_sizes = jax.random.randint(key, (num_groups, ),
182
+ 0,
183
+ batch_size,
184
+ dtype=jnp.int32)
185
+
186
+ expected = reference_gmm(lhs,
187
+ rhs_q,
188
+ group_sizes,
189
+ rhs_scale=rhs_scale,
190
+ rhs_bias=rhs_bias)
191
+ actual = gmm(
192
+ lhs,
193
+ rhs_q,
194
+ group_sizes,
195
+ rhs_scale=rhs_scale,
196
+ rhs_bias=rhs_bias,
197
+ transpose_rhs=True,
198
+ preferred_element_type=jnp.bfloat16,
199
+ )
200
+
201
+ self.assertArraysAllClose(actual, expected, atol=3e-1, rtol=3e-1)
202
+
203
+
204
+ if __name__ == "__main__":
205
+ absltest.main(testLoader=jtu.JaxTestLoader())