tpu-inference 0.0.1rc1__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic. Click here for more details.

Files changed (174)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +53 -0
  6. tests/core/test_dp_scheduler.py +899 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/fused_moe_v1_test.py +374 -0
  10. tests/kernels/mla_v1_test.py +396 -0
  11. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  12. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  13. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  14. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
  15. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  16. tests/lora/__init__.py +0 -0
  17. tests/lora/conftest.py +32 -0
  18. tests/lora/test_bgmv.py +43 -0
  19. tests/lora/test_layers.py +648 -0
  20. tests/lora/test_lora.py +133 -0
  21. tests/lora/utils.py +88 -0
  22. tests/test_base.py +201 -0
  23. tests/test_envs.py +203 -0
  24. tests/test_quantization.py +836 -0
  25. tests/test_tpu_info.py +120 -0
  26. tests/test_utils.py +235 -0
  27. tpu_inference/__init__.py +53 -0
  28. tpu_inference/core/__init__.py +0 -0
  29. tpu_inference/core/core_tpu.py +786 -0
  30. tpu_inference/core/disagg_executor.py +118 -0
  31. tpu_inference/core/disagg_utils.py +49 -0
  32. tpu_inference/core/sched/__init__.py +0 -0
  33. tpu_inference/core/sched/dp_scheduler.py +523 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/jax_parallel_state.py +67 -0
  36. tpu_inference/distributed/tpu_connector.py +727 -0
  37. tpu_inference/distributed/utils.py +60 -0
  38. tpu_inference/env_override.py +9 -0
  39. tpu_inference/envs.py +160 -0
  40. tpu_inference/executors/__init__.py +0 -0
  41. tpu_inference/executors/ray_distributed_executor.py +382 -0
  42. tpu_inference/experimental/__init__.py +0 -0
  43. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  44. tpu_inference/kernels/__init__.py +0 -0
  45. tpu_inference/kernels/collectives/__init__.py +0 -0
  46. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  47. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  48. tpu_inference/kernels/collectives/util.py +47 -0
  49. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  50. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  51. tpu_inference/kernels/fused_moe/__init__.py +0 -0
  52. tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
  53. tpu_inference/kernels/fused_moe/v1/kernel.py +1566 -0
  54. tpu_inference/kernels/mla/__init__.py +0 -0
  55. tpu_inference/kernels/mla/v1/__init__.py +0 -0
  56. tpu_inference/kernels/mla/v1/kernel.py +1349 -0
  57. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  58. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  59. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  60. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  61. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  62. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  66. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1501 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1603 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
  71. tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
  72. tpu_inference/layers/__init__.py +0 -0
  73. tpu_inference/layers/common/__init__.py +0 -0
  74. tpu_inference/layers/common/attention_interface.py +396 -0
  75. tpu_inference/layers/common/attention_metadata.py +34 -0
  76. tpu_inference/layers/common/binary_search.py +295 -0
  77. tpu_inference/layers/common/quant_methods.py +8 -0
  78. tpu_inference/layers/common/sharding.py +582 -0
  79. tpu_inference/layers/jax/__init__.py +0 -0
  80. tpu_inference/layers/jax/attention/__init__.py +0 -0
  81. tpu_inference/layers/jax/attention/attention.py +255 -0
  82. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  83. tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
  84. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  85. tpu_inference/layers/jax/base.py +151 -0
  86. tpu_inference/layers/jax/constants.py +88 -0
  87. tpu_inference/layers/jax/layers.py +301 -0
  88. tpu_inference/layers/jax/misc.py +16 -0
  89. tpu_inference/layers/jax/moe/__init__.py +0 -0
  90. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  91. tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
  92. tpu_inference/layers/jax/moe/moe.py +209 -0
  93. tpu_inference/layers/jax/rope.py +280 -0
  94. tpu_inference/layers/jax/rope_interface.py +214 -0
  95. tpu_inference/layers/jax/sample/__init__.py +0 -0
  96. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  97. tpu_inference/layers/jax/sample/sampling.py +96 -0
  98. tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
  99. tpu_inference/layers/jax/transformer_block.py +107 -0
  100. tpu_inference/layers/vllm/__init__.py +0 -0
  101. tpu_inference/layers/vllm/attention.py +221 -0
  102. tpu_inference/layers/vllm/fused_moe.py +469 -0
  103. tpu_inference/layers/vllm/linear_common.py +186 -0
  104. tpu_inference/layers/vllm/quantization/__init__.py +39 -0
  105. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  106. tpu_inference/layers/vllm/quantization/common.py +110 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  108. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
  109. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
  110. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  111. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  112. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  113. tpu_inference/layers/vllm/quantization/mxfp4.py +331 -0
  114. tpu_inference/layers/vllm/quantization/unquantized.py +368 -0
  115. tpu_inference/layers/vllm/sharding.py +230 -0
  116. tpu_inference/logger.py +10 -0
  117. tpu_inference/lora/__init__.py +0 -0
  118. tpu_inference/lora/torch_lora_ops.py +103 -0
  119. tpu_inference/lora/torch_punica_tpu.py +310 -0
  120. tpu_inference/models/__init__.py +0 -0
  121. tpu_inference/models/common/__init__.py +0 -0
  122. tpu_inference/models/common/model_loader.py +478 -0
  123. tpu_inference/models/jax/__init__.py +0 -0
  124. tpu_inference/models/jax/deepseek_v3.py +868 -0
  125. tpu_inference/models/jax/gpt_oss.py +492 -0
  126. tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
  127. tpu_inference/models/jax/llama3.py +376 -0
  128. tpu_inference/models/jax/llama4.py +629 -0
  129. tpu_inference/models/jax/llama_eagle3.py +336 -0
  130. tpu_inference/models/jax/llama_guard_4.py +361 -0
  131. tpu_inference/models/jax/qwen2.py +376 -0
  132. tpu_inference/models/jax/qwen2_5_vl.py +1218 -0
  133. tpu_inference/models/jax/qwen3.py +303 -0
  134. tpu_inference/models/jax/utils/__init__.py +0 -0
  135. tpu_inference/models/jax/utils/file_utils.py +96 -0
  136. tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
  137. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  138. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
  139. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
  140. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
  141. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
  142. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
  143. tpu_inference/models/jax/utils/quantization/quantization_utils.py +650 -0
  144. tpu_inference/models/jax/utils/weight_utils.py +584 -0
  145. tpu_inference/models/vllm/__init__.py +0 -0
  146. tpu_inference/models/vllm/vllm_model_wrapper.py +293 -0
  147. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  148. tpu_inference/platforms/__init__.py +2 -0
  149. tpu_inference/platforms/tpu_platform.py +275 -0
  150. tpu_inference/runner/__init__.py +0 -0
  151. tpu_inference/runner/block_table.py +122 -0
  152. tpu_inference/runner/compilation_manager.py +865 -0
  153. tpu_inference/runner/input_batch.py +435 -0
  154. tpu_inference/runner/kv_cache.py +132 -0
  155. tpu_inference/runner/kv_cache_manager.py +478 -0
  156. tpu_inference/runner/lora_utils.py +92 -0
  157. tpu_inference/runner/multimodal_manager.py +217 -0
  158. tpu_inference/runner/persistent_batch_manager.py +282 -0
  159. tpu_inference/runner/speculative_decoding_manager.py +248 -0
  160. tpu_inference/runner/structured_decoding_manager.py +87 -0
  161. tpu_inference/runner/tpu_runner.py +1744 -0
  162. tpu_inference/runner/utils.py +426 -0
  163. tpu_inference/spec_decode/__init__.py +0 -0
  164. tpu_inference/spec_decode/jax/__init__.py +0 -0
  165. tpu_inference/spec_decode/jax/eagle3.py +417 -0
  166. tpu_inference/tpu_info.py +78 -0
  167. tpu_inference/utils.py +340 -0
  168. tpu_inference/worker/__init__.py +0 -0
  169. tpu_inference/worker/tpu_worker.py +458 -0
  170. tpu_inference-0.0.1rc1.dist-info/METADATA +108 -0
  171. tpu_inference-0.0.1rc1.dist-info/RECORD +174 -0
  172. tpu_inference-0.0.1rc1.dist-info/WHEEL +5 -0
  173. tpu_inference-0.0.1rc1.dist-info/licenses/LICENSE +201 -0
  174. tpu_inference-0.0.1rc1.dist-info/top_level.txt +2 -0
@@ -0,0 +1,49 @@
1
+ import importlib
2
+ import unittest
3
+ from unittest.mock import patch
4
+
5
+
6
class TestPathwaysInit(unittest.TestCase):
    """Checks ``vllm.envs.VLLM_TPU_USING_PATHWAYS`` against ``JAX_PLATFORMS``."""

    def _reload_envs(self):
        """Return ``vllm.envs`` reloaded under the currently patched environment.

        A second reload is registered as a cleanup. Cleanups run after the
        test method has returned, i.e. after the ``patch.dict`` decorator has
        restored ``os.environ``, so the module is re-initialised from the
        unpatched environment and later tests do not see state computed from
        this test's ``JAX_PLATFORMS`` value.
        """
        import vllm.envs as envs

        # Register first so the restore happens even if the reload below fails.
        self.addCleanup(importlib.reload, envs)
        return importlib.reload(envs)

    @patch.dict("os.environ", {"JAX_PLATFORMS": "proxy,cpu"})
    def test_VLLM_TPU_USING_PATHWAYS_enabled(self):
        """Test when JAX_PLATFORMS contains 'proxy'."""
        envs = self._reload_envs()

        # VLLM_TPU_USING_PATHWAYS is expected to be True when JAX_PLATFORMS
        # contains "proxy".
        self.assertTrue(envs.VLLM_TPU_USING_PATHWAYS)

    @patch.dict("os.environ", {"JAX_PLATFORMS": "cpu"})
    def test_VLLM_TPU_USING_PATHWAYS_not_enabled(self):
        """Test when JAX_PLATFORMS does not contain 'proxy'."""
        envs = self._reload_envs()

        # VLLM_TPU_USING_PATHWAYS is expected to be False when JAX_PLATFORMS
        # doesn't contain "proxy".
        self.assertFalse(envs.VLLM_TPU_USING_PATHWAYS)

    @patch.dict("os.environ", {"JAX_PLATFORMS": "PROXY,CPU"})
    def test_VLLM_TPU_USING_PATHWAYS_case_insensitive(self):
        """Test that JAX_PLATFORMS check is case insensitive."""
        envs = self._reload_envs()

        # VLLM_TPU_USING_PATHWAYS is expected to be True even with uppercase
        # "PROXY".
        self.assertTrue(envs.VLLM_TPU_USING_PATHWAYS)
46
+
47
+
48
+ if __name__ == "__main__":
49
+ unittest.main()
File without changes
@@ -0,0 +1,374 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+ import numpy as np
4
+ from absl.testing import absltest, parameterized
5
+ from jax._src import test_util as jtu
6
+ from jax.sharding import Mesh
7
+
8
+ from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe, ref_moe
9
+
10
+ jax.config.parse_flags_with_absl()
11
+
12
+
13
def cdiv(a, b):
    """Ceiling division: ``a / b`` rounded up, computed with floor division.

    ``b`` must be non-zero; this is enforced with an ``assert``.
    """
    assert b != 0
    bumped = a + b - 1
    return bumped // b
16
+
17
+
18
def align_to(x, a):
    """Return ``cdiv(x, a) * a``, i.e. ``x`` rounded up to a multiple of ``a``."""
    num_chunks = cdiv(x, a)
    return num_chunks * a
20
+
21
+
22
def gen_moe_inputs(
    dtype,
    top_k,
    num_experts,
    hidden_size,
    intermediate_size,
    num_tokens,
    *,
    seed=1234,
    has_bias=False,
):
    """Generate random inputs for a mixture-of-experts test.

    Args:
        dtype: dtype of the returned tokens, weights, biases and gating logits.
        top_k: number of expert indices drawn per token for the gating boost.
        num_experts: number of experts.
        hidden_size: token feature size.
        intermediate_size: expert intermediate feature size.
        num_tokens: number of tokens.
        seed: seed for ``jax.random.key``.
        has_bias: when True, also generate the bias arrays ``b1`` and ``b2``.

    Returns:
        A tuple ``(a, w1, w2, b1, b2, gating_output)`` with shapes
        ``a``: ``(num_tokens, hidden_size)``;
        ``w1``: ``(num_experts, 2, hidden_size, intermediate_size)``;
        ``w2``: ``(num_experts, intermediate_size, hidden_size)``;
        ``b1``: ``(num_experts, 2, intermediate_size)`` or ``None``;
        ``b2``: ``(num_experts, hidden_size)`` or ``None``;
        ``gating_output``: ``(num_tokens, num_experts)``.
    """
    key = jax.random.key(seed)
    k0, k1, k2, k3, k4, k5, k6 = jax.random.split(key, 7)

    a = jax.random.normal(k0, (num_tokens, hidden_size),
                          dtype=jnp.float32).astype(dtype) / 10

    w1 = (jax.random.normal(
        k1,
        (num_experts, 2, hidden_size, intermediate_size),
        dtype=jnp.float32,
    ) / 10).astype(dtype)
    w2 = (jax.random.normal(k2, (num_experts, intermediate_size, hidden_size),
                            dtype=jnp.float32) / 10).astype(dtype)

    if has_bias:
        b1 = (jax.random.normal(k3, (num_experts, 2, intermediate_size),
                                dtype=jnp.float32) / 10).astype(dtype)
        b2 = (jax.random.normal(k4, (num_experts, hidden_size),
                                dtype=jnp.float32) / 10).astype(dtype)
    else:
        b1 = b2 = None

    # Gaussian noise plus a small ramp that increases with the flattened
    # (token, expert) index.
    gating_output = (
        jax.random.normal(k5, (num_tokens, num_experts), dtype=jnp.float32) +
        jnp.arange(num_tokens * num_experts, dtype=jnp.float32).reshape(
            num_tokens, num_experts) / 100)

    # Boost randomly chosen experts per token by a large offset so they stand
    # well clear of the remaining logits. The draw may repeat an index within
    # a row; a repeated index simply receives the offset more than once.
    # `maxval` is exclusive in jax.random.randint, so it must be `num_experts`
    # for the last expert to be selectable.
    top_k_indices = jax.random.randint(k6, (num_tokens, top_k),
                                       minval=0,
                                       maxval=num_experts,
                                       dtype=jnp.int32)

    one_hot = (jnp.sum(
        jax.nn.one_hot(top_k_indices, num_experts, dtype=jnp.float32),
        axis=1,
    ) * 30)

    gating_output = (gating_output + one_hot).astype(dtype)

    return a, w1, w2, b1, b2, gating_output
74
+
75
+
76
def sub_channel_quantize(x, quant_dtype, wsz=256):
    """Quantizes x with sub-channel quantization on the 2nd minor.

    The second-to-last axis of ``x`` is split into consecutive windows of
    ``wsz`` rows. Each window gets one scale per position of the last axis,
    computed as ``abs_max / max(quant_dtype)``, and the window is divided by
    that scale before being cast to ``quant_dtype`` with ``astype``.

    Args:
        x: array with at least two dimensions whose second-to-last dimension
            is a multiple of ``wsz``.
        quant_dtype: target integer or floating dtype.
        wsz: window size along the second-to-last axis.

    Returns:
        A tuple ``(w, scale)``: ``w`` has the shape of ``x`` and dtype
        ``quant_dtype``; ``scale`` is float32 with the second-to-last
        dimension reduced to ``x.shape[-2] // wsz``.
    """
    if jnp.issubdtype(quant_dtype, jnp.floating):
        dtype_info = jnp.finfo(quant_dtype)
    else:
        dtype_info = jnp.iinfo(quant_dtype)
    dtype_max = float(dtype_info.max)
    assert len(x.shape) >= 2
    assert x.shape[-2] % wsz == 0
    w_lst, scale_lst = [], []
    for i in range(0, x.shape[-2], wsz):
        y = x[..., i:i + wsz, :]
        abs_max = jnp.abs(y).max(axis=-2, keepdims=True)
        scale = (abs_max / dtype_max).astype(jnp.float32)
        # An all-zero window yields a zero scale; dividing by it would give
        # 0/0 = NaN. Use a scale of 1 there so the window quantizes to zeros.
        scale = jnp.where(scale == 0, jnp.ones_like(scale), scale)
        w = (y / scale).astype(quant_dtype)
        w_lst.append(w)
        scale_lst.append(scale)
    return (jnp.concatenate(w_lst, axis=-2),
            jnp.concatenate(scale_lst, axis=-2))
94
+
95
+
96
@jtu.with_config(jax_numpy_dtype_promotion="standard")
class MoEKernelTest(jtu.JaxTestCase):
    """Compares ``fused_ep_moe`` with ``ref_moe`` on generated MoE inputs."""

    def setUp(self):
        """Build a ``(1, num_devices)`` mesh over ordered devices."""
        super().setUp()

        def device_order(dev):
            # Sort by the first coordinate, then by the second coordinate,
            # reversing the second coordinate's direction when the first
            # coordinate is odd.
            first = dev.coords[0]
            direction = -1 if first % 2 else 1
            return (first, direction * dev.coords[1])

        self.mesh_devices = sorted(jax.devices(), key=device_order)
        device_grid = np.array(self.mesh_devices).reshape(1, -1)
        self.mesh = Mesh(device_grid, axis_names=("data", "model"))

    def _test_moe(
        self,
        dtype,
        top_k,
        num_experts,
        hidden_size,
        intermediate_size,
        num_tokens,
        seed,
        renormalize_topk_logits,
        bt,
        bf,
        bd1,
        bd2,
        btc,
        bfc,
        bd1c,
        bd2c,
        act_fn="silu",
        w_dtype=None,
        subc_quant_wsz=None,
        has_bias=False,
        atol=2e-1,
        rtol=2e-1,
    ):
        """Run the fused kernel and the reference on the same inputs.

        Inputs come from ``gen_moe_inputs``. When ``w_dtype`` is given, both
        weight tensors are quantized with ``sub_channel_quantize`` first
        (window size ``subc_quant_wsz``, defaulting to 256). The two results
        are compared with ``assertAllClose`` using ``atol`` and ``rtol``.
        """
        a, w1, w2, b1, b2, gating_output = gen_moe_inputs(
            dtype,
            top_k,
            num_experts,
            hidden_size,
            intermediate_size,
            num_tokens,
            seed=seed,
            has_bias=has_bias,
        )

        w1_scale = w2_scale = None
        if w_dtype is not None:
            if subc_quant_wsz is None:
                subc_quant_wsz = 256
            w1, w1_scale = sub_channel_quantize(w1, w_dtype, subc_quant_wsz)
            w2, w2_scale = sub_channel_quantize(w2, w_dtype, subc_quant_wsz)

        # Keyword arguments that the kernel and the reference both accept
        # under the same names.
        shared_kwargs = dict(
            renormalize_topk_logits=renormalize_topk_logits,
            subc_quant_wsz=subc_quant_wsz,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            b1=b1,
            b2=b2,
        )

        actual = fused_ep_moe(
            mesh=self.mesh,
            tokens=a,
            w1=w1,
            w2=w2,
            gating_output=gating_output,
            top_k=top_k,
            act_fn=act_fn,
            bt=bt,
            bf=bf,
            bd1=bd1,
            bd2=bd2,
            btc=btc,
            bfc=bfc,
            bd1c=bd1c,
            bd2c=bd2c,
            **shared_kwargs,
        )
        expected = ref_moe(
            a,
            w1,
            w2,
            gating_output,
            top_k,
            activation=act_fn,
            **shared_kwargs,
        )
        self.assertAllClose(actual, expected, atol=atol, rtol=rtol)

    def _run_default_shape(self, *, block_size, **overrides):
        """Call ``_test_moe`` with the configuration shared by the small tests.

        ``block_size`` is used for ``bf``, ``bd1`` and ``bd2``; ``overrides``
        carries the arguments that differ between tests.
        """
        self._test_moe(
            dtype=jnp.bfloat16,
            top_k=8,
            num_experts=128,
            hidden_size=1024,
            intermediate_size=1024,
            num_tokens=8 * 32,
            seed=1234,
            bt=32,
            bf=block_size,
            bd1=block_size,
            bd2=block_size,
            btc=32,
            bfc=256,
            bd1c=256,
            bd2c=256,
            **overrides,
        )

    @parameterized.product(renormalize_topk_logits=[True, False], )
    def test_basic(self, renormalize_topk_logits):
        """Default shape, with and without top-k logit renormalization."""
        self._run_default_shape(
            block_size=1024,
            renormalize_topk_logits=renormalize_topk_logits,
        )

    @parameterized.product(act_fn=["silu", "gelu", "swigluoai"], )
    def test_activation(self, act_fn):
        """Default shape with each supported activation name."""
        self._run_default_shape(
            block_size=512,
            renormalize_topk_logits=True,
            act_fn=act_fn,
        )

    def test_benchmark_qwen_235(self):
        """Shape labelled qwen_235 by the test name, with tighter tolerances."""
        config = dict(
            dtype=jnp.bfloat16,
            top_k=8,
            num_experts=128,
            hidden_size=4096,
            intermediate_size=1536,
            num_tokens=8 * 64,
            seed=54321,
            renormalize_topk_logits=True,
            bt=64,
            bf=768,
            bd1=2048,
            bd2=2048,
            btc=64,
            bfc=768,
            bd1c=2048,
            bd2c=2048,
            act_fn="silu",
            atol=5e-2,
            rtol=5e-2,
        )
        self._test_moe(**config)

    def test_benchmark_qwen_30b_a3b(self):
        """Shape labelled qwen_30b_a3b by the test name, with tighter tolerances."""
        config = dict(
            dtype=jnp.bfloat16,
            top_k=8,
            num_experts=128,
            hidden_size=2048,
            intermediate_size=768,
            num_tokens=512,
            seed=54321,
            renormalize_topk_logits=True,
            bt=16,
            bf=384,
            bd1=512,
            bd2=512,
            btc=16,
            bfc=384,
            bd1c=256,
            bd2c=256,
            act_fn="silu",
            atol=5e-2,
            rtol=5e-2,
        )
        self._test_moe(**config)

    @parameterized.product(
        w_dtype=[jnp.int8, jnp.float8_e5m2, jnp.float4_e2m1fn], )
    def test_sub_channel_quantization(self, w_dtype):
        """Default shape with weights quantized to ``w_dtype``."""
        needs_v7 = w_dtype in (jnp.float8_e5m2, jnp.float4_e2m1fn)
        if needs_v7 and not jtu.is_device_tpu_at_least(version=7):
            self.skipTest("Expect TPUv7+")
        self._run_default_shape(
            block_size=1024,
            renormalize_topk_logits=False,
            w_dtype=w_dtype,
            subc_quant_wsz=256,
        )

    def test_bias(self):
        """Default shape with the bias arrays enabled."""
        self._run_default_shape(
            block_size=512,
            renormalize_topk_logits=False,
            has_bias=True,
        )
371
+
372
+
373
+ if __name__ == "__main__":
374
+ absltest.main(testLoader=jtu.JaxTestLoader())