tpu-inference 0.12.0.dev20251222__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (260) hide show
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +67 -0
  6. tests/core/test_dp_scheduler.py +724 -0
  7. tests/core/test_init.py +63 -0
  8. tests/distributed/__init__.py +13 -0
  9. tests/distributed/test_distributed_utils.py +120 -0
  10. tests/distributed/test_tpu_connector.py +478 -0
  11. tests/e2e/__init__.py +13 -0
  12. tests/e2e/test_async_scheduler.py +211 -0
  13. tests/e2e/test_data_parallel.py +393 -0
  14. tests/e2e/test_local_disagg.py +257 -0
  15. tests/e2e/test_model_loader.py +268 -0
  16. tests/e2e/test_multi_modal_inference.py +111 -0
  17. tests/e2e/test_pipeline_parallel.py +265 -0
  18. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  19. tests/e2e/test_sampling_params.py +269 -0
  20. tests/e2e/test_speculative_decoding.py +291 -0
  21. tests/e2e/test_structured_decoding.py +46 -0
  22. tests/executors/__init__.py +13 -0
  23. tests/executors/test_ray_distributed_executor.py +199 -0
  24. tests/experimental/__init__.py +13 -0
  25. tests/experimental/test_llama3_jax_stashed.py +208 -0
  26. tests/kernels/__init__.py +13 -0
  27. tests/kernels/collectives/__init__.py +13 -0
  28. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  29. tests/kernels/fused_moe_v1_test.py +388 -0
  30. tests/kernels/gmm_test.py +205 -0
  31. tests/kernels/mla_v1_test.py +498 -0
  32. tests/kernels/quantized_matmul_kernel_test.py +159 -0
  33. tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
  34. tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
  35. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
  36. tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
  37. tests/layers/__init__.py +13 -0
  38. tests/layers/common/__init__.py +13 -0
  39. tests/layers/common/test_attention_interface.py +156 -0
  40. tests/layers/common/test_quantization.py +149 -0
  41. tests/layers/jax/__init__.py +13 -0
  42. tests/layers/jax/attention/__init__.py +13 -0
  43. tests/layers/jax/attention/test_common_attention.py +103 -0
  44. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  45. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  46. tests/layers/jax/moe/__init__.py +13 -0
  47. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  48. tests/layers/jax/sample/__init__.py +13 -0
  49. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  50. tests/layers/jax/sample/test_sampling.py +115 -0
  51. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  52. tests/layers/jax/test_layers.py +155 -0
  53. tests/layers/jax/test_qwix.py +969 -0
  54. tests/layers/jax/test_rope.py +93 -0
  55. tests/layers/jax/test_sharding.py +159 -0
  56. tests/layers/jax/test_transformer_block.py +152 -0
  57. tests/layers/vllm/__init__.py +13 -0
  58. tests/layers/vllm/test_attention.py +363 -0
  59. tests/layers/vllm/test_awq.py +405 -0
  60. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
  62. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
  63. tests/layers/vllm/test_fp8.py +17 -0
  64. tests/layers/vllm/test_mxfp4.py +297 -0
  65. tests/layers/vllm/test_unquantized.py +621 -0
  66. tests/layers/vllm/utils.py +72 -0
  67. tests/lora/__init__.py +13 -0
  68. tests/lora/conftest.py +46 -0
  69. tests/lora/test_bgmv.py +57 -0
  70. tests/lora/test_layers.py +666 -0
  71. tests/lora/test_lora.py +147 -0
  72. tests/lora/test_lora_perf.py +67 -0
  73. tests/lora/utils.py +88 -0
  74. tests/models/__init__.py +13 -0
  75. tests/models/common/__init__.py +13 -0
  76. tests/models/common/test_model_loader.py +455 -0
  77. tests/models/jax/__init__.py +13 -0
  78. tests/models/jax/test_deepseek_v3.py +401 -0
  79. tests/models/jax/test_llama3.py +184 -0
  80. tests/models/jax/test_llama4.py +298 -0
  81. tests/models/jax/test_llama_eagle3.py +197 -0
  82. tests/models/jax/test_llama_guard_4.py +242 -0
  83. tests/models/jax/test_qwen2.py +172 -0
  84. tests/models/jax/test_qwen2_5_vl.py +606 -0
  85. tests/models/jax/test_qwen3.py +169 -0
  86. tests/models/jax/test_weight_loading.py +180 -0
  87. tests/models/jax/utils/__init__.py +13 -0
  88. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  89. tests/platforms/__init__.py +13 -0
  90. tests/platforms/test_tpu_platform.py +54 -0
  91. tests/runner/__init__.py +13 -0
  92. tests/runner/test_block_table.py +395 -0
  93. tests/runner/test_input_batch.py +226 -0
  94. tests/runner/test_kv_cache.py +220 -0
  95. tests/runner/test_kv_cache_manager.py +498 -0
  96. tests/runner/test_multimodal_manager.py +429 -0
  97. tests/runner/test_persistent_batch_manager.py +84 -0
  98. tests/runner/test_speculative_decoding_manager.py +368 -0
  99. tests/runner/test_structured_decoding_manager.py +220 -0
  100. tests/runner/test_tpu_runner.py +202 -0
  101. tests/runner/test_tpu_runner_dp.py +1033 -0
  102. tests/runner/test_tpu_runner_mesh.py +200 -0
  103. tests/runner/test_utils.py +411 -0
  104. tests/spec_decode/__init__.py +13 -0
  105. tests/spec_decode/test_eagle3.py +311 -0
  106. tests/test_base.py +215 -0
  107. tests/test_envs.py +280 -0
  108. tests/test_tpu_info.py +134 -0
  109. tests/test_utils.py +193 -0
  110. tests/worker/__init__.py +13 -0
  111. tests/worker/tpu_worker_test.py +414 -0
  112. tpu_inference/__init__.py +67 -0
  113. tpu_inference/core/__init__.py +13 -0
  114. tpu_inference/core/core_tpu.py +786 -0
  115. tpu_inference/core/disagg_executor.py +118 -0
  116. tpu_inference/core/disagg_utils.py +49 -0
  117. tpu_inference/core/sched/__init__.py +13 -0
  118. tpu_inference/core/sched/dp_scheduler.py +814 -0
  119. tpu_inference/distributed/__init__.py +13 -0
  120. tpu_inference/distributed/jax_parallel_state.py +81 -0
  121. tpu_inference/distributed/tpu_connector.py +732 -0
  122. tpu_inference/distributed/utils.py +112 -0
  123. tpu_inference/env_override.py +9 -0
  124. tpu_inference/envs.py +191 -0
  125. tpu_inference/executors/__init__.py +13 -0
  126. tpu_inference/executors/ray_distributed_executor.py +399 -0
  127. tpu_inference/experimental/__init__.py +13 -0
  128. tpu_inference/experimental/llama3_jax_stashed.py +272 -0
  129. tpu_inference/kernels/__init__.py +13 -0
  130. tpu_inference/kernels/collectives/__init__.py +13 -0
  131. tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
  132. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
  133. tpu_inference/kernels/collectives/util.py +47 -0
  134. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  135. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  136. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  137. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  138. tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
  139. tpu_inference/kernels/megablox/__init__.py +13 -0
  140. tpu_inference/kernels/megablox/common.py +54 -0
  141. tpu_inference/kernels/megablox/gmm.py +646 -0
  142. tpu_inference/kernels/mla/__init__.py +13 -0
  143. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  144. tpu_inference/kernels/mla/v1/kernel.py +1340 -0
  145. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  146. tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
  147. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  148. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  149. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  150. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  151. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
  152. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
  153. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  154. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  155. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
  156. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
  157. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
  158. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
  159. tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
  160. tpu_inference/layers/__init__.py +13 -0
  161. tpu_inference/layers/common/__init__.py +13 -0
  162. tpu_inference/layers/common/attention_interface.py +403 -0
  163. tpu_inference/layers/common/attention_metadata.py +48 -0
  164. tpu_inference/layers/common/binary_search.py +295 -0
  165. tpu_inference/layers/common/quant_methods.py +23 -0
  166. tpu_inference/layers/common/quantization.py +270 -0
  167. tpu_inference/layers/common/sharding.py +600 -0
  168. tpu_inference/layers/jax/__init__.py +13 -0
  169. tpu_inference/layers/jax/attention/__init__.py +13 -0
  170. tpu_inference/layers/jax/attention/attention.py +268 -0
  171. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
  172. tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
  173. tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
  174. tpu_inference/layers/jax/base.py +165 -0
  175. tpu_inference/layers/jax/constants.py +101 -0
  176. tpu_inference/layers/jax/layers.py +315 -0
  177. tpu_inference/layers/jax/misc.py +30 -0
  178. tpu_inference/layers/jax/moe/__init__.py +13 -0
  179. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
  180. tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
  181. tpu_inference/layers/jax/moe/moe.py +249 -0
  182. tpu_inference/layers/jax/pp_utils.py +53 -0
  183. tpu_inference/layers/jax/rope.py +294 -0
  184. tpu_inference/layers/jax/rope_interface.py +228 -0
  185. tpu_inference/layers/jax/sample/__init__.py +13 -0
  186. tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
  187. tpu_inference/layers/jax/sample/sampling.py +110 -0
  188. tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
  189. tpu_inference/layers/jax/transformer_block.py +121 -0
  190. tpu_inference/layers/vllm/__init__.py +13 -0
  191. tpu_inference/layers/vllm/attention.py +221 -0
  192. tpu_inference/layers/vllm/fused_moe.py +502 -0
  193. tpu_inference/layers/vllm/linear_common.py +221 -0
  194. tpu_inference/layers/vllm/quantization/__init__.py +55 -0
  195. tpu_inference/layers/vllm/quantization/awq.py +221 -0
  196. tpu_inference/layers/vllm/quantization/common.py +124 -0
  197. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  198. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
  199. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
  200. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  201. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
  202. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
  203. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  204. tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
  205. tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
  206. tpu_inference/layers/vllm/sharding.py +244 -0
  207. tpu_inference/logger.py +10 -0
  208. tpu_inference/lora/__init__.py +13 -0
  209. tpu_inference/lora/torch_lora_ops.py +98 -0
  210. tpu_inference/lora/torch_punica_tpu.py +310 -0
  211. tpu_inference/models/__init__.py +13 -0
  212. tpu_inference/models/common/__init__.py +13 -0
  213. tpu_inference/models/common/model_loader.py +520 -0
  214. tpu_inference/models/jax/__init__.py +13 -0
  215. tpu_inference/models/jax/deepseek_v3.py +978 -0
  216. tpu_inference/models/jax/gpt_oss.py +508 -0
  217. tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
  218. tpu_inference/models/jax/llama3.py +436 -0
  219. tpu_inference/models/jax/llama4.py +643 -0
  220. tpu_inference/models/jax/llama_eagle3.py +350 -0
  221. tpu_inference/models/jax/llama_guard_4.py +375 -0
  222. tpu_inference/models/jax/qwen2.py +390 -0
  223. tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
  224. tpu_inference/models/jax/qwen3.py +318 -0
  225. tpu_inference/models/jax/utils/__init__.py +13 -0
  226. tpu_inference/models/jax/utils/file_utils.py +110 -0
  227. tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
  228. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  229. tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
  230. tpu_inference/models/jax/utils/weight_utils.py +621 -0
  231. tpu_inference/models/vllm/__init__.py +13 -0
  232. tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
  233. tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
  234. tpu_inference/platforms/__init__.py +16 -0
  235. tpu_inference/platforms/tpu_platform.py +258 -0
  236. tpu_inference/runner/__init__.py +13 -0
  237. tpu_inference/runner/block_table.py +122 -0
  238. tpu_inference/runner/compilation_manager.py +890 -0
  239. tpu_inference/runner/input_batch.py +435 -0
  240. tpu_inference/runner/kv_cache.py +166 -0
  241. tpu_inference/runner/kv_cache_manager.py +508 -0
  242. tpu_inference/runner/lora_utils.py +106 -0
  243. tpu_inference/runner/multimodal_manager.py +231 -0
  244. tpu_inference/runner/persistent_batch_manager.py +296 -0
  245. tpu_inference/runner/speculative_decoding_manager.py +262 -0
  246. tpu_inference/runner/structured_decoding_manager.py +101 -0
  247. tpu_inference/runner/tpu_runner.py +1768 -0
  248. tpu_inference/runner/utils.py +426 -0
  249. tpu_inference/spec_decode/__init__.py +13 -0
  250. tpu_inference/spec_decode/jax/__init__.py +13 -0
  251. tpu_inference/spec_decode/jax/eagle3.py +430 -0
  252. tpu_inference/tpu_info.py +92 -0
  253. tpu_inference/utils.py +345 -0
  254. tpu_inference/worker/__init__.py +13 -0
  255. tpu_inference/worker/tpu_worker.py +468 -0
  256. tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
  257. tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
  258. tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
  259. tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
  260. tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
@@ -0,0 +1,1482 @@
1
+ # Copyright 2025 The JAX Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Auto-tuned block sizes for ragged paged attention."""
15
+
16
+ import jax
17
+ import jax.numpy as jnp
18
+
19
+ # The page size is too small. We only have 32 SREGs in TC. If the number of
20
+ # pages per seq is too large, SREGs will spill.
21
+ MAX_PAGES_PER_SEQ = 16
22
+
23
+ # key:
24
+ # - q_dtype_name
25
+ # - kv_dtype_name
26
+ # - num_q_heads_per_blk
27
+ # - num_kv_heads_per_blk
28
+ # - head_dim
29
+ # - page_size
30
+ # - max_num_batched_tokens
31
+ # - max_model_len = page_size * pages_per_seq
32
+ # value:
33
+ # - num_kv_pages_per_block
34
+ # - num_queries_per_block
35
+ TUNED_BLOCK_SIZES = {
36
+ 'TPU v6': {
37
+ # go/keep-sorted start
38
+ ('bfloat16', 'bfloat16', 12, 2, 128, 128, 1024, 1024): (8, 32),
39
+ ('bfloat16', 'bfloat16', 12, 2, 128, 128, 1024, 1280): (4, 32),
40
+ ('bfloat16', 'bfloat16', 12, 2, 128, 128, 1024, 2048): (16, 32),
41
+ ('bfloat16', 'bfloat16', 12, 2, 128, 128, 1024, 512): (4, 32),
42
+ ('bfloat16', 'bfloat16', 12, 2, 128, 128, 2048, 1024): (8, 32),
43
+ ('bfloat16', 'bfloat16', 12, 2, 128, 128, 2048, 1280): (8, 32),
44
+ ('bfloat16', 'bfloat16', 12, 2, 128, 128, 2048, 2048): (16, 32),
45
+ ('bfloat16', 'bfloat16', 12, 2, 128, 128, 2048, 512): (4, 32),
46
+ ('bfloat16', 'bfloat16', 12, 2, 128, 128, 4096, 1024): (8, 32),
47
+ ('bfloat16', 'bfloat16', 12, 2, 128, 128, 4096, 1280): (8, 32),
48
+ ('bfloat16', 'bfloat16', 12, 2, 128, 128, 4096, 2048): (16, 32),
49
+ ('bfloat16', 'bfloat16', 12, 2, 128, 128, 4096, 512): (4, 64),
50
+ ('bfloat16', 'bfloat16', 12, 2, 128, 128, 512, 1024): (8, 32),
51
+ ('bfloat16', 'bfloat16', 12, 2, 128, 128, 512, 1280): (8, 32),
52
+ ('bfloat16', 'bfloat16', 12, 2, 128, 128, 512, 2048): (16, 32),
53
+ ('bfloat16', 'bfloat16', 12, 2, 128, 128, 512, 512): (4, 32),
54
+ ('bfloat16', 'bfloat16', 12, 2, 128, 16, 1024, 128): (8, 32),
55
+ ('bfloat16', 'bfloat16', 12, 2, 128, 16, 1024, 256): (16, 32),
56
+ ('bfloat16', 'bfloat16', 12, 2, 128, 16, 1024, 64): (4, 64),
57
+ ('bfloat16', 'bfloat16', 12, 2, 128, 16, 2048, 128): (4, 32),
58
+ ('bfloat16', 'bfloat16', 12, 2, 128, 16, 2048, 256): (16, 32),
59
+ ('bfloat16', 'bfloat16', 12, 2, 128, 16, 2048, 64): (4, 32),
60
+ ('bfloat16', 'bfloat16', 12, 2, 128, 16, 4096, 128): (8, 64),
61
+ ('bfloat16', 'bfloat16', 12, 2, 128, 16, 4096, 256): (16, 32),
62
+ ('bfloat16', 'bfloat16', 12, 2, 128, 16, 4096, 64): (4, 64),
63
+ ('bfloat16', 'bfloat16', 12, 2, 128, 16, 512, 128): (8, 32),
64
+ ('bfloat16', 'bfloat16', 12, 2, 128, 16, 512, 256): (16, 32),
65
+ ('bfloat16', 'bfloat16', 12, 2, 128, 16, 512, 64): (4, 64),
66
+ ('bfloat16', 'bfloat16', 12, 2, 128, 256, 1024, 1024): (4, 32),
67
+ ('bfloat16', 'bfloat16', 12, 2, 128, 256, 1024, 1280): (4, 32),
68
+ ('bfloat16', 'bfloat16', 12, 2, 128, 256, 1024, 2048): (8, 32),
69
+ ('bfloat16', 'bfloat16', 12, 2, 128, 256, 1024, 4096): (16, 32),
70
+ ('bfloat16', 'bfloat16', 12, 2, 128, 256, 2048, 1024): (4, 32),
71
+ ('bfloat16', 'bfloat16', 12, 2, 128, 256, 2048, 1280): (4, 32),
72
+ ('bfloat16', 'bfloat16', 12, 2, 128, 256, 2048, 2048): (8, 32),
73
+ ('bfloat16', 'bfloat16', 12, 2, 128, 256, 2048, 4096): (16, 32),
74
+ ('bfloat16', 'bfloat16', 12, 2, 128, 256, 4096, 1024): (4, 32),
75
+ ('bfloat16', 'bfloat16', 12, 2, 128, 256, 4096, 1280): (4, 32),
76
+ ('bfloat16', 'bfloat16', 12, 2, 128, 256, 4096, 2048): (8, 32),
77
+ ('bfloat16', 'bfloat16', 12, 2, 128, 256, 4096, 4096): (16, 32),
78
+ ('bfloat16', 'bfloat16', 12, 2, 128, 256, 512, 1024): (4, 32),
79
+ ('bfloat16', 'bfloat16', 12, 2, 128, 256, 512, 1280): (4, 32),
80
+ ('bfloat16', 'bfloat16', 12, 2, 128, 256, 512, 2048): (8, 32),
81
+ ('bfloat16', 'bfloat16', 12, 2, 128, 256, 512, 4096): (16, 32),
82
+ ('bfloat16', 'bfloat16', 12, 2, 128, 32, 1024, 128): (4, 32),
83
+ ('bfloat16', 'bfloat16', 12, 2, 128, 32, 1024, 256): (8, 64),
84
+ ('bfloat16', 'bfloat16', 12, 2, 128, 32, 1024, 512): (16, 32),
85
+ ('bfloat16', 'bfloat16', 12, 2, 128, 32, 2048, 128): (4, 32),
86
+ ('bfloat16', 'bfloat16', 12, 2, 128, 32, 2048, 256): (8, 32),
87
+ ('bfloat16', 'bfloat16', 12, 2, 128, 32, 2048, 512): (16, 32),
88
+ ('bfloat16', 'bfloat16', 12, 2, 128, 32, 4096, 128): (4, 64),
89
+ ('bfloat16', 'bfloat16', 12, 2, 128, 32, 4096, 256): (8, 32),
90
+ ('bfloat16', 'bfloat16', 12, 2, 128, 32, 4096, 512): (16, 32),
91
+ ('bfloat16', 'bfloat16', 12, 2, 128, 32, 512, 128): (4, 32),
92
+ ('bfloat16', 'bfloat16', 12, 2, 128, 32, 512, 256): (8, 32),
93
+ ('bfloat16', 'bfloat16', 12, 2, 128, 32, 512, 512): (16, 32),
94
+ ('bfloat16', 'bfloat16', 12, 2, 128, 64, 1024, 1024): (16, 32),
95
+ ('bfloat16', 'bfloat16', 12, 2, 128, 64, 1024, 256): (4, 32),
96
+ ('bfloat16', 'bfloat16', 12, 2, 128, 64, 1024, 512): (8, 32),
97
+ ('bfloat16', 'bfloat16', 12, 2, 128, 64, 2048, 1024): (16, 32),
98
+ ('bfloat16', 'bfloat16', 12, 2, 128, 64, 2048, 256): (4, 32),
99
+ ('bfloat16', 'bfloat16', 12, 2, 128, 64, 2048, 512): (8, 32),
100
+ ('bfloat16', 'bfloat16', 12, 2, 128, 64, 4096, 1024): (16, 32),
101
+ ('bfloat16', 'bfloat16', 12, 2, 128, 64, 4096, 256): (4, 32),
102
+ ('bfloat16', 'bfloat16', 12, 2, 128, 64, 4096, 512): (8, 32),
103
+ ('bfloat16', 'bfloat16', 12, 2, 128, 64, 512, 1024): (16, 32),
104
+ ('bfloat16', 'bfloat16', 12, 2, 128, 64, 512, 256): (4, 32),
105
+ ('bfloat16', 'bfloat16', 12, 2, 128, 64, 512, 512): (8, 32),
106
+ ('bfloat16', 'bfloat16', 28, 4, 128, 128, 1024, 1024): (8, 32),
107
+ ('bfloat16', 'bfloat16', 28, 4, 128, 128, 1024, 1280): (8, 32),
108
+ ('bfloat16', 'bfloat16', 28, 4, 128, 128, 1024, 2048): (16, 32),
109
+ ('bfloat16', 'bfloat16', 28, 4, 128, 128, 1024, 512): (4, 32),
110
+ ('bfloat16', 'bfloat16', 28, 4, 128, 128, 2048, 1024): (8, 32),
111
+ ('bfloat16', 'bfloat16', 28, 4, 128, 128, 2048, 1280): (8, 32),
112
+ ('bfloat16', 'bfloat16', 28, 4, 128, 128, 2048, 2048): (16, 32),
113
+ ('bfloat16', 'bfloat16', 28, 4, 128, 128, 2048, 512): (4, 32),
114
+ ('bfloat16', 'bfloat16', 28, 4, 128, 128, 4096, 1024): (8, 32),
115
+ ('bfloat16', 'bfloat16', 28, 4, 128, 128, 4096, 1280): (8, 32),
116
+ ('bfloat16', 'bfloat16', 28, 4, 128, 128, 4096, 2048): (16, 32),
117
+ ('bfloat16', 'bfloat16', 28, 4, 128, 128, 4096, 512): (4, 32),
118
+ ('bfloat16', 'bfloat16', 28, 4, 128, 128, 512, 1024): (8, 32),
119
+ ('bfloat16', 'bfloat16', 28, 4, 128, 128, 512, 1280): (8, 32),
120
+ ('bfloat16', 'bfloat16', 28, 4, 128, 128, 512, 2048): (16, 32),
121
+ ('bfloat16', 'bfloat16', 28, 4, 128, 128, 512, 512): (4, 32),
122
+ ('bfloat16', 'bfloat16', 28, 4, 128, 16, 1024, 128): (8, 32),
123
+ ('bfloat16', 'bfloat16', 28, 4, 128, 16, 1024, 256): (16, 32),
124
+ ('bfloat16', 'bfloat16', 28, 4, 128, 16, 1024, 64): (4, 32),
125
+ ('bfloat16', 'bfloat16', 28, 4, 128, 16, 2048, 128): (8, 32),
126
+ ('bfloat16', 'bfloat16', 28, 4, 128, 16, 2048, 256): (16, 32),
127
+ ('bfloat16', 'bfloat16', 28, 4, 128, 16, 2048, 64): (4, 32),
128
+ ('bfloat16', 'bfloat16', 28, 4, 128, 16, 4096, 128): (8, 32),
129
+ ('bfloat16', 'bfloat16', 28, 4, 128, 16, 4096, 256): (16, 32),
130
+ ('bfloat16', 'bfloat16', 28, 4, 128, 16, 4096, 64): (4, 32),
131
+ ('bfloat16', 'bfloat16', 28, 4, 128, 16, 512, 128): (8, 32),
132
+ ('bfloat16', 'bfloat16', 28, 4, 128, 16, 512, 256): (16, 32),
133
+ ('bfloat16', 'bfloat16', 28, 4, 128, 16, 512, 64): (4, 32),
134
+ ('bfloat16', 'bfloat16', 28, 4, 128, 256, 1024, 1024): (4, 32),
135
+ ('bfloat16', 'bfloat16', 28, 4, 128, 256, 1024, 1280): (4, 32),
136
+ ('bfloat16', 'bfloat16', 28, 4, 128, 256, 1024, 2048): (8, 32),
137
+ ('bfloat16', 'bfloat16', 28, 4, 128, 256, 1024, 4096): (16, 32),
138
+ ('bfloat16', 'bfloat16', 28, 4, 128, 256, 2048, 1024): (4, 32),
139
+ ('bfloat16', 'bfloat16', 28, 4, 128, 256, 2048, 1280): (4, 32),
140
+ ('bfloat16', 'bfloat16', 28, 4, 128, 256, 2048, 2048): (8, 32),
141
+ ('bfloat16', 'bfloat16', 28, 4, 128, 256, 2048, 4096): (16, 32),
142
+ ('bfloat16', 'bfloat16', 28, 4, 128, 256, 4096, 1024): (4, 32),
143
+ ('bfloat16', 'bfloat16', 28, 4, 128, 256, 4096, 1280): (4, 32),
144
+ ('bfloat16', 'bfloat16', 28, 4, 128, 256, 4096, 2048): (8, 32),
145
+ ('bfloat16', 'bfloat16', 28, 4, 128, 256, 4096, 4096): (16, 32),
146
+ ('bfloat16', 'bfloat16', 28, 4, 128, 256, 512, 1024): (4, 32),
147
+ ('bfloat16', 'bfloat16', 28, 4, 128, 256, 512, 1280): (4, 32),
148
+ ('bfloat16', 'bfloat16', 28, 4, 128, 256, 512, 2048): (8, 32),
149
+ ('bfloat16', 'bfloat16', 28, 4, 128, 256, 512, 4096): (16, 32),
150
+ ('bfloat16', 'bfloat16', 28, 4, 128, 32, 1024, 128): (4, 32),
151
+ ('bfloat16', 'bfloat16', 28, 4, 128, 32, 1024, 256): (8, 32),
152
+ ('bfloat16', 'bfloat16', 28, 4, 128, 32, 1024, 512): (16, 32),
153
+ ('bfloat16', 'bfloat16', 28, 4, 128, 32, 2048, 128): (4, 32),
154
+ ('bfloat16', 'bfloat16', 28, 4, 128, 32, 2048, 256): (8, 32),
155
+ ('bfloat16', 'bfloat16', 28, 4, 128, 32, 2048, 512): (16, 32),
156
+ ('bfloat16', 'bfloat16', 28, 4, 128, 32, 4096, 128): (4, 32),
157
+ ('bfloat16', 'bfloat16', 28, 4, 128, 32, 4096, 256): (8, 32),
158
+ ('bfloat16', 'bfloat16', 28, 4, 128, 32, 4096, 512): (16, 32),
159
+ ('bfloat16', 'bfloat16', 28, 4, 128, 32, 512, 128): (4, 32),
160
+ ('bfloat16', 'bfloat16', 28, 4, 128, 32, 512, 256): (8, 32),
161
+ ('bfloat16', 'bfloat16', 28, 4, 128, 32, 512, 512): (16, 32),
162
+ ('bfloat16', 'bfloat16', 28, 4, 128, 64, 1024, 1024): (16, 32),
163
+ ('bfloat16', 'bfloat16', 28, 4, 128, 64, 1024, 256): (4, 32),
164
+ ('bfloat16', 'bfloat16', 28, 4, 128, 64, 1024, 512): (8, 32),
165
+ ('bfloat16', 'bfloat16', 28, 4, 128, 64, 2048, 1024): (16, 32),
166
+ ('bfloat16', 'bfloat16', 28, 4, 128, 64, 2048, 256): (4, 32),
167
+ ('bfloat16', 'bfloat16', 28, 4, 128, 64, 2048, 512): (8, 32),
168
+ ('bfloat16', 'bfloat16', 28, 4, 128, 64, 4096, 1024): (16, 32),
169
+ ('bfloat16', 'bfloat16', 28, 4, 128, 64, 4096, 256): (4, 32),
170
+ ('bfloat16', 'bfloat16', 28, 4, 128, 64, 4096, 512): (8, 32),
171
+ ('bfloat16', 'bfloat16', 28, 4, 128, 64, 512, 1024): (16, 32),
172
+ ('bfloat16', 'bfloat16', 28, 4, 128, 64, 512, 256): (4, 32),
173
+ ('bfloat16', 'bfloat16', 28, 4, 128, 64, 512, 512): (8, 32),
174
+ ('bfloat16', 'bfloat16', 32, 8, 128, 128, 1024, 1024): (8, 32),
175
+ ('bfloat16', 'bfloat16', 32, 8, 128, 128, 1024, 1280): (4, 32),
176
+ ('bfloat16', 'bfloat16', 32, 8, 128, 128, 1024, 512): (4, 32),
177
+ ('bfloat16', 'bfloat16', 32, 8, 128, 128, 2048, 1024): (8, 32),
178
+ ('bfloat16', 'bfloat16', 32, 8, 128, 128, 2048, 1280): (4, 32),
179
+ ('bfloat16', 'bfloat16', 32, 8, 128, 128, 2048, 512): (4, 32),
180
+ ('bfloat16', 'bfloat16', 32, 8, 128, 128, 4096, 1024): (8, 32),
181
+ ('bfloat16', 'bfloat16', 32, 8, 128, 128, 4096, 1280): (4, 32),
182
+ ('bfloat16', 'bfloat16', 32, 8, 128, 128, 4096, 512): (4, 32),
183
+ ('bfloat16', 'bfloat16', 32, 8, 128, 128, 512, 1024): (8, 32),
184
+ ('bfloat16', 'bfloat16', 32, 8, 128, 128, 512, 1280): (4, 32),
185
+ ('bfloat16', 'bfloat16', 32, 8, 128, 128, 512, 512): (4, 32),
186
+ ('bfloat16', 'bfloat16', 32, 8, 128, 16, 1024, 128): (4, 32),
187
+ ('bfloat16', 'bfloat16', 32, 8, 128, 16, 1024, 256): (16, 32),
188
+ ('bfloat16', 'bfloat16', 32, 8, 128, 16, 1024, 64): (4, 32),
189
+ ('bfloat16', 'bfloat16', 32, 8, 128, 16, 2048, 128): (8, 32),
190
+ ('bfloat16', 'bfloat16', 32, 8, 128, 16, 2048, 256): (16, 32),
191
+ ('bfloat16', 'bfloat16', 32, 8, 128, 16, 2048, 64): (4, 32),
192
+ ('bfloat16', 'bfloat16', 32, 8, 128, 16, 4096, 128): (8, 32),
193
+ ('bfloat16', 'bfloat16', 32, 8, 128, 16, 4096, 256): (16, 32),
194
+ ('bfloat16', 'bfloat16', 32, 8, 128, 16, 4096, 64): (4, 32),
195
+ ('bfloat16', 'bfloat16', 32, 8, 128, 16, 512, 128): (8, 32),
196
+ ('bfloat16', 'bfloat16', 32, 8, 128, 16, 512, 256): (16, 32),
197
+ ('bfloat16', 'bfloat16', 32, 8, 128, 16, 512, 64): (4, 32),
198
+ ('bfloat16', 'bfloat16', 32, 8, 128, 256, 1024, 1024): (4, 32),
199
+ ('bfloat16', 'bfloat16', 32, 8, 128, 256, 1024, 1280): (4, 32),
200
+ ('bfloat16', 'bfloat16', 32, 8, 128, 256, 2048, 1024): (4, 32),
201
+ ('bfloat16', 'bfloat16', 32, 8, 128, 256, 2048, 1280): (4, 32),
202
+ ('bfloat16', 'bfloat16', 32, 8, 128, 256, 4096, 1024): (4, 32),
203
+ ('bfloat16', 'bfloat16', 32, 8, 128, 256, 4096, 1280): (4, 32),
204
+ ('bfloat16', 'bfloat16', 32, 8, 128, 256, 512, 1024): (4, 32),
205
+ ('bfloat16', 'bfloat16', 32, 8, 128, 256, 512, 1280): (4, 32),
206
+ ('bfloat16', 'bfloat16', 32, 8, 128, 32, 1024, 128): (4, 32),
207
+ ('bfloat16', 'bfloat16', 32, 8, 128, 32, 1024, 256): (8, 32),
208
+ ('bfloat16', 'bfloat16', 32, 8, 128, 32, 1024, 512): (16, 32),
209
+ ('bfloat16', 'bfloat16', 32, 8, 128, 32, 2048, 128): (4, 32),
210
+ ('bfloat16', 'bfloat16', 32, 8, 128, 32, 2048, 256): (8, 32),
211
+ ('bfloat16', 'bfloat16', 32, 8, 128, 32, 2048, 512): (16, 32),
212
+ ('bfloat16', 'bfloat16', 32, 8, 128, 32, 4096, 128): (4, 32),
213
+ ('bfloat16', 'bfloat16', 32, 8, 128, 32, 4096, 256): (8, 32),
214
+ ('bfloat16', 'bfloat16', 32, 8, 128, 32, 4096, 512): (16, 32),
215
+ ('bfloat16', 'bfloat16', 32, 8, 128, 32, 512, 128): (4, 32),
216
+ ('bfloat16', 'bfloat16', 32, 8, 128, 32, 512, 256): (8, 32),
217
+ ('bfloat16', 'bfloat16', 32, 8, 128, 32, 512, 512): (16, 32),
218
+ ('bfloat16', 'bfloat16', 32, 8, 128, 64, 1024, 1024): (16, 32),
219
+ ('bfloat16', 'bfloat16', 32, 8, 128, 64, 1024, 256): (4, 32),
220
+ ('bfloat16', 'bfloat16', 32, 8, 128, 64, 1024, 512): (8, 32),
221
+ ('bfloat16', 'bfloat16', 32, 8, 128, 64, 2048, 1024): (16, 32),
222
+ ('bfloat16', 'bfloat16', 32, 8, 128, 64, 2048, 256): (4, 32),
223
+ ('bfloat16', 'bfloat16', 32, 8, 128, 64, 2048, 512): (8, 32),
224
+ ('bfloat16', 'bfloat16', 32, 8, 128, 64, 4096, 1024): (16, 64),
225
+ ('bfloat16', 'bfloat16', 32, 8, 128, 64, 4096, 256): (4, 32),
226
+ ('bfloat16', 'bfloat16', 32, 8, 128, 64, 4096, 512): (8, 32),
227
+ ('bfloat16', 'bfloat16', 32, 8, 128, 64, 512, 1024): (16, 32),
228
+ ('bfloat16', 'bfloat16', 32, 8, 128, 64, 512, 256): (4, 32),
229
+ ('bfloat16', 'bfloat16', 32, 8, 128, 64, 512, 512): (8, 32),
230
+ ('bfloat16', 'bfloat16', 4, 1, 128, 128, 1024, 1024): (8, 128),
231
+ ('bfloat16', 'bfloat16', 4, 1, 128, 128, 1024, 1280): (4, 128),
232
+ ('bfloat16', 'bfloat16', 4, 1, 128, 128, 1024, 2048): (8, 64),
233
+ ('bfloat16', 'bfloat16', 4, 1, 128, 128, 1024, 512): (4, 32),
234
+ ('bfloat16', 'bfloat16', 4, 1, 128, 128, 2048, 1024): (8, 64),
235
+ ('bfloat16', 'bfloat16', 4, 1, 128, 128, 2048, 1280): (4, 64),
236
+ ('bfloat16', 'bfloat16', 4, 1, 128, 128, 2048, 2048): (8, 64),
237
+ ('bfloat16', 'bfloat16', 4, 1, 128, 128, 2048, 512): (4, 64),
238
+ ('bfloat16', 'bfloat16', 4, 1, 128, 128, 4096, 1024): (8, 128),
239
+ ('bfloat16', 'bfloat16', 4, 1, 128, 128, 4096, 1280): (8, 128),
240
+ ('bfloat16', 'bfloat16', 4, 1, 128, 128, 4096, 2048): (16, 128),
241
+ ('bfloat16', 'bfloat16', 4, 1, 128, 128, 4096, 512): (4, 128),
242
+ ('bfloat16', 'bfloat16', 4, 1, 128, 128, 512, 1024): (8, 32),
243
+ ('bfloat16', 'bfloat16', 4, 1, 128, 128, 512, 1280): (4, 64),
244
+ ('bfloat16', 'bfloat16', 4, 1, 128, 128, 512, 2048): (16, 32),
245
+ ('bfloat16', 'bfloat16', 4, 1, 128, 128, 512, 512): (4, 32),
246
+ ('bfloat16', 'bfloat16', 4, 1, 128, 16, 1024, 128): (8, 128),
247
+ ('bfloat16', 'bfloat16', 4, 1, 128, 16, 1024, 256): (16, 32),
248
+ ('bfloat16', 'bfloat16', 4, 1, 128, 16, 1024, 64): (4, 64),
249
+ ('bfloat16', 'bfloat16', 4, 1, 128, 16, 2048, 128): (8, 128),
250
+ ('bfloat16', 'bfloat16', 4, 1, 128, 16, 2048, 256): (16, 64),
251
+ ('bfloat16', 'bfloat16', 4, 1, 128, 16, 2048, 64): (4, 128),
252
+ ('bfloat16', 'bfloat16', 4, 1, 128, 16, 4096, 128): (8, 128),
253
+ ('bfloat16', 'bfloat16', 4, 1, 128, 16, 4096, 256): (16, 256),
254
+ ('bfloat16', 'bfloat16', 4, 1, 128, 16, 4096, 64): (4, 256),
255
+ ('bfloat16', 'bfloat16', 4, 1, 128, 16, 512, 128): (8, 256),
256
+ ('bfloat16', 'bfloat16', 4, 1, 128, 16, 512, 256): (16, 128),
257
+ ('bfloat16', 'bfloat16', 4, 1, 128, 16, 512, 64): (4, 64),
258
+ ('bfloat16', 'bfloat16', 4, 1, 128, 256, 1024, 1024): (4, 64),
259
+ ('bfloat16', 'bfloat16', 4, 1, 128, 256, 1024, 1280): (4, 128),
260
+ ('bfloat16', 'bfloat16', 4, 1, 128, 256, 1024, 2048): (4, 64),
261
+ ('bfloat16', 'bfloat16', 4, 1, 128, 256, 1024, 4096): (8, 64),
262
+ ('bfloat16', 'bfloat16', 4, 1, 128, 256, 2048, 1024): (4, 64),
263
+ ('bfloat16', 'bfloat16', 4, 1, 128, 256, 2048, 1280): (4, 128),
264
+ ('bfloat16', 'bfloat16', 4, 1, 128, 256, 2048, 2048): (8, 64),
265
+ ('bfloat16', 'bfloat16', 4, 1, 128, 256, 2048, 4096): (16, 32),
266
+ ('bfloat16', 'bfloat16', 4, 1, 128, 256, 4096, 1024): (4, 64),
267
+ ('bfloat16', 'bfloat16', 4, 1, 128, 256, 4096, 1280): (4, 128),
268
+ ('bfloat16', 'bfloat16', 4, 1, 128, 256, 4096, 2048): (8, 64),
269
+ ('bfloat16', 'bfloat16', 4, 1, 128, 256, 4096, 4096): (16, 64),
270
+ ('bfloat16', 'bfloat16', 4, 1, 128, 256, 512, 1024): (4, 32),
271
+ ('bfloat16', 'bfloat16', 4, 1, 128, 256, 512, 1280): (4, 64),
272
+ ('bfloat16', 'bfloat16', 4, 1, 128, 256, 512, 2048): (8, 64),
273
+ ('bfloat16', 'bfloat16', 4, 1, 128, 256, 512, 4096): (16, 32),
274
+ ('bfloat16', 'bfloat16', 4, 1, 128, 32, 1024, 128): (4, 32),
275
+ ('bfloat16', 'bfloat16', 4, 1, 128, 32, 1024, 256): (8, 128),
276
+ ('bfloat16', 'bfloat16', 4, 1, 128, 32, 1024, 512): (16, 64),
277
+ ('bfloat16', 'bfloat16', 4, 1, 128, 32, 2048, 128): (4, 128),
278
+ ('bfloat16', 'bfloat16', 4, 1, 128, 32, 2048, 256): (8, 64),
279
+ ('bfloat16', 'bfloat16', 4, 1, 128, 32, 2048, 512): (16, 128),
280
+ ('bfloat16', 'bfloat16', 4, 1, 128, 32, 4096, 128): (4, 64),
281
+ ('bfloat16', 'bfloat16', 4, 1, 128, 32, 4096, 256): (8, 128),
282
+ ('bfloat16', 'bfloat16', 4, 1, 128, 32, 4096, 512): (16, 256),
283
+ ('bfloat16', 'bfloat16', 4, 1, 128, 32, 512, 128): (4, 64),
284
+ ('bfloat16', 'bfloat16', 4, 1, 128, 32, 512, 256): (4, 64),
285
+ ('bfloat16', 'bfloat16', 4, 1, 128, 32, 512, 512): (16, 64),
286
+ ('bfloat16', 'bfloat16', 4, 1, 128, 64, 1024, 1024): (16, 64),
287
+ ('bfloat16', 'bfloat16', 4, 1, 128, 64, 1024, 256): (4, 64),
288
+ ('bfloat16', 'bfloat16', 4, 1, 128, 64, 1024, 512): (8, 64),
289
+ ('bfloat16', 'bfloat16', 4, 1, 128, 64, 2048, 1024): (16, 128),
290
+ ('bfloat16', 'bfloat16', 4, 1, 128, 64, 2048, 256): (4, 64),
291
+ ('bfloat16', 'bfloat16', 4, 1, 128, 64, 2048, 512): (8, 64),
292
+ ('bfloat16', 'bfloat16', 4, 1, 128, 64, 4096, 1024): (16, 128),
293
+ ('bfloat16', 'bfloat16', 4, 1, 128, 64, 4096, 256): (4, 64),
294
+ ('bfloat16', 'bfloat16', 4, 1, 128, 64, 4096, 512): (8, 128),
295
+ ('bfloat16', 'bfloat16', 4, 1, 128, 64, 512, 1024): (16, 32),
296
+ ('bfloat16', 'bfloat16', 4, 1, 128, 64, 512, 256): (4, 32),
297
+ ('bfloat16', 'bfloat16', 4, 1, 128, 64, 512, 512): (8, 128),
298
+ ('bfloat16', 'bfloat16', 4, 2, 128, 128, 1024, 1024): (8, 64),
299
+ ('bfloat16', 'bfloat16', 4, 2, 128, 128, 1024, 1280): (8, 64),
300
+ ('bfloat16', 'bfloat16', 4, 2, 128, 128, 1024, 2048): (16, 64),
301
+ ('bfloat16', 'bfloat16', 4, 2, 128, 128, 1024, 512): (4, 64),
302
+ ('bfloat16', 'bfloat16', 4, 2, 128, 128, 2048, 1024): (8, 64),
303
+ ('bfloat16', 'bfloat16', 4, 2, 128, 128, 2048, 1280): (8, 128),
304
+ ('bfloat16', 'bfloat16', 4, 2, 128, 128, 2048, 2048): (16, 64),
305
+ ('bfloat16', 'bfloat16', 4, 2, 128, 128, 2048, 512): (4, 128),
306
+ ('bfloat16', 'bfloat16', 4, 2, 128, 128, 4096, 1024): (8, 64),
307
+ ('bfloat16', 'bfloat16', 4, 2, 128, 128, 4096, 1280): (4, 64),
308
+ ('bfloat16', 'bfloat16', 4, 2, 128, 128, 4096, 2048): (16, 64),
309
+ ('bfloat16', 'bfloat16', 4, 2, 128, 128, 4096, 512): (4, 128),
310
+ ('bfloat16', 'bfloat16', 4, 2, 128, 128, 512, 1024): (8, 32),
311
+ ('bfloat16', 'bfloat16', 4, 2, 128, 128, 512, 1280): (4, 64),
312
+ ('bfloat16', 'bfloat16', 4, 2, 128, 128, 512, 2048): (16, 64),
313
+ ('bfloat16', 'bfloat16', 4, 2, 128, 128, 512, 512): (4, 32),
314
+ ('bfloat16', 'bfloat16', 4, 2, 128, 16, 1024, 128): (8, 32),
315
+ ('bfloat16', 'bfloat16', 4, 2, 128, 16, 1024, 256): (16, 64),
316
+ ('bfloat16', 'bfloat16', 4, 2, 128, 16, 1024, 64): (4, 64),
317
+ ('bfloat16', 'bfloat16', 4, 2, 128, 16, 2048, 128): (8, 64),
318
+ ('bfloat16', 'bfloat16', 4, 2, 128, 16, 2048, 256): (16, 64),
319
+ ('bfloat16', 'bfloat16', 4, 2, 128, 16, 2048, 64): (4, 64),
320
+ ('bfloat16', 'bfloat16', 4, 2, 128, 16, 4096, 128): (8, 128),
321
+ ('bfloat16', 'bfloat16', 4, 2, 128, 16, 4096, 256): (16, 128),
322
+ ('bfloat16', 'bfloat16', 4, 2, 128, 16, 4096, 64): (4, 128),
323
+ ('bfloat16', 'bfloat16', 4, 2, 128, 16, 512, 128): (4, 32),
324
+ ('bfloat16', 'bfloat16', 4, 2, 128, 16, 512, 256): (16, 32),
325
+ ('bfloat16', 'bfloat16', 4, 2, 128, 16, 512, 64): (4, 64),
326
+ ('bfloat16', 'bfloat16', 4, 2, 128, 256, 1024, 1024): (4, 64),
327
+ ('bfloat16', 'bfloat16', 4, 2, 128, 256, 1024, 1280): (4, 64),
328
+ ('bfloat16', 'bfloat16', 4, 2, 128, 256, 1024, 2048): (8, 64),
329
+ ('bfloat16', 'bfloat16', 4, 2, 128, 256, 1024, 4096): (16, 64),
330
+ ('bfloat16', 'bfloat16', 4, 2, 128, 256, 2048, 1024): (4, 64),
331
+ ('bfloat16', 'bfloat16', 4, 2, 128, 256, 2048, 1280): (4, 64),
332
+ ('bfloat16', 'bfloat16', 4, 2, 128, 256, 2048, 2048): (8, 128),
333
+ ('bfloat16', 'bfloat16', 4, 2, 128, 256, 2048, 4096): (16, 64),
334
+ ('bfloat16', 'bfloat16', 4, 2, 128, 256, 4096, 1024): (4, 64),
335
+ ('bfloat16', 'bfloat16', 4, 2, 128, 256, 4096, 1280): (4, 128),
336
+ ('bfloat16', 'bfloat16', 4, 2, 128, 256, 4096, 2048): (8, 64),
337
+ ('bfloat16', 'bfloat16', 4, 2, 128, 256, 4096, 4096): (16, 64),
338
+ ('bfloat16', 'bfloat16', 4, 2, 128, 256, 512, 1024): (4, 32),
339
+ ('bfloat16', 'bfloat16', 4, 2, 128, 256, 512, 1280): (4, 64),
340
+ ('bfloat16', 'bfloat16', 4, 2, 128, 256, 512, 2048): (8, 32),
341
+ ('bfloat16', 'bfloat16', 4, 2, 128, 256, 512, 4096): (16, 64),
342
+ ('bfloat16', 'bfloat16', 4, 2, 128, 32, 1024, 128): (4, 64),
343
+ ('bfloat16', 'bfloat16', 4, 2, 128, 32, 1024, 256): (8, 32),
344
+ ('bfloat16', 'bfloat16', 4, 2, 128, 32, 1024, 512): (16, 64),
345
+ ('bfloat16', 'bfloat16', 4, 2, 128, 32, 2048, 128): (4, 64),
346
+ ('bfloat16', 'bfloat16', 4, 2, 128, 32, 2048, 256): (8, 64),
347
+ ('bfloat16', 'bfloat16', 4, 2, 128, 32, 2048, 512): (16, 128),
348
+ ('bfloat16', 'bfloat16', 4, 2, 128, 32, 4096, 128): (4, 64),
349
+ ('bfloat16', 'bfloat16', 4, 2, 128, 32, 4096, 256): (8, 128),
350
+ ('bfloat16', 'bfloat16', 4, 2, 128, 32, 4096, 512): (16, 64),
351
+ ('bfloat16', 'bfloat16', 4, 2, 128, 32, 512, 128): (4, 64),
352
+ ('bfloat16', 'bfloat16', 4, 2, 128, 32, 512, 256): (8, 64),
353
+ ('bfloat16', 'bfloat16', 4, 2, 128, 32, 512, 512): (16, 128),
354
+ ('bfloat16', 'bfloat16', 4, 2, 128, 64, 1024, 1024): (16, 128),
355
+ ('bfloat16', 'bfloat16', 4, 2, 128, 64, 1024, 256): (4, 128),
356
+ ('bfloat16', 'bfloat16', 4, 2, 128, 64, 1024, 512): (8, 128),
357
+ ('bfloat16', 'bfloat16', 4, 2, 128, 64, 2048, 1024): (16, 64),
358
+ ('bfloat16', 'bfloat16', 4, 2, 128, 64, 2048, 256): (4, 64),
359
+ ('bfloat16', 'bfloat16', 4, 2, 128, 64, 2048, 512): (8, 64),
360
+ ('bfloat16', 'bfloat16', 4, 2, 128, 64, 4096, 1024): (16, 128),
361
+ ('bfloat16', 'bfloat16', 4, 2, 128, 64, 4096, 256): (4, 128),
362
+ ('bfloat16', 'bfloat16', 4, 2, 128, 64, 4096, 512): (8, 128),
363
+ ('bfloat16', 'bfloat16', 4, 2, 128, 64, 512, 1024): (16, 64),
364
+ ('bfloat16', 'bfloat16', 4, 2, 128, 64, 512, 256): (4, 64),
365
+ ('bfloat16', 'bfloat16', 4, 2, 128, 64, 512, 512): (8, 64),
366
+ ('bfloat16', 'bfloat16', 5, 1, 128, 128, 1024, 1024): (8, 32),
367
+ ('bfloat16', 'bfloat16', 5, 1, 128, 128, 1024, 1280): (8, 32),
368
+ ('bfloat16', 'bfloat16', 5, 1, 128, 128, 1024, 2048): (16, 32),
369
+ ('bfloat16', 'bfloat16', 5, 1, 128, 128, 1024, 512): (4, 32),
370
+ ('bfloat16', 'bfloat16', 5, 1, 128, 128, 2048, 1024): (8, 32),
371
+ ('bfloat16', 'bfloat16', 5, 1, 128, 128, 2048, 1280): (8, 32),
372
+ ('bfloat16', 'bfloat16', 5, 1, 128, 128, 2048, 2048): (16, 32),
373
+ ('bfloat16', 'bfloat16', 5, 1, 128, 128, 2048, 512): (4, 32),
374
+ ('bfloat16', 'bfloat16', 5, 1, 128, 128, 4096, 1024): (8, 32),
375
+ ('bfloat16', 'bfloat16', 5, 1, 128, 128, 4096, 1280): (8, 32),
376
+ ('bfloat16', 'bfloat16', 5, 1, 128, 128, 4096, 2048): (16, 32),
377
+ ('bfloat16', 'bfloat16', 5, 1, 128, 128, 4096, 512): (4, 32),
378
+ ('bfloat16', 'bfloat16', 5, 1, 128, 128, 512, 1024): (8, 32),
379
+ ('bfloat16', 'bfloat16', 5, 1, 128, 128, 512, 1280): (8, 32),
380
+ ('bfloat16', 'bfloat16', 5, 1, 128, 128, 512, 2048): (16, 32),
381
+ ('bfloat16', 'bfloat16', 5, 1, 128, 128, 512, 512): (4, 32),
382
+ ('bfloat16', 'bfloat16', 5, 1, 128, 16, 1024, 128): (8, 32),
383
+ ('bfloat16', 'bfloat16', 5, 1, 128, 16, 1024, 256): (16, 32),
384
+ ('bfloat16', 'bfloat16', 5, 1, 128, 16, 1024, 64): (4, 32),
385
+ ('bfloat16', 'bfloat16', 5, 1, 128, 16, 2048, 128): (8, 32),
386
+ ('bfloat16', 'bfloat16', 5, 1, 128, 16, 2048, 256): (16, 32),
387
+ ('bfloat16', 'bfloat16', 5, 1, 128, 16, 2048, 64): (4, 32),
388
+ ('bfloat16', 'bfloat16', 5, 1, 128, 16, 4096, 128): (8, 32),
389
+ ('bfloat16', 'bfloat16', 5, 1, 128, 16, 4096, 256): (16, 64),
390
+ ('bfloat16', 'bfloat16', 5, 1, 128, 16, 4096, 64): (4, 32),
391
+ ('bfloat16', 'bfloat16', 5, 1, 128, 16, 512, 128): (8, 32),
392
+ ('bfloat16', 'bfloat16', 5, 1, 128, 16, 512, 256): (16, 32),
393
+ ('bfloat16', 'bfloat16', 5, 1, 128, 16, 512, 64): (4, 32),
394
+ ('bfloat16', 'bfloat16', 5, 1, 128, 256, 1024, 1024): (4, 32),
395
+ ('bfloat16', 'bfloat16', 5, 1, 128, 256, 1024, 1280): (4, 32),
396
+ ('bfloat16', 'bfloat16', 5, 1, 128, 256, 1024, 2048): (8, 32),
397
+ ('bfloat16', 'bfloat16', 5, 1, 128, 256, 1024, 4096): (16, 32),
398
+ ('bfloat16', 'bfloat16', 5, 1, 128, 256, 2048, 1024): (4, 32),
399
+ ('bfloat16', 'bfloat16', 5, 1, 128, 256, 2048, 1280): (4, 32),
400
+ ('bfloat16', 'bfloat16', 5, 1, 128, 256, 2048, 2048): (8, 32),
401
+ ('bfloat16', 'bfloat16', 5, 1, 128, 256, 2048, 4096): (16, 32),
402
+ ('bfloat16', 'bfloat16', 5, 1, 128, 256, 4096, 1024): (4, 32),
403
+ ('bfloat16', 'bfloat16', 5, 1, 128, 256, 4096, 1280): (4, 32),
404
+ ('bfloat16', 'bfloat16', 5, 1, 128, 256, 4096, 2048): (8, 32),
405
+ ('bfloat16', 'bfloat16', 5, 1, 128, 256, 4096, 4096): (16, 64),
406
+ ('bfloat16', 'bfloat16', 5, 1, 128, 256, 512, 1024): (4, 32),
407
+ ('bfloat16', 'bfloat16', 5, 1, 128, 256, 512, 1280): (4, 32),
408
+ ('bfloat16', 'bfloat16', 5, 1, 128, 256, 512, 2048): (8, 32),
409
+ ('bfloat16', 'bfloat16', 5, 1, 128, 256, 512, 4096): (16, 32),
410
+ ('bfloat16', 'bfloat16', 5, 1, 128, 32, 1024, 128): (4, 32),
411
+ ('bfloat16', 'bfloat16', 5, 1, 128, 32, 1024, 256): (8, 32),
412
+ ('bfloat16', 'bfloat16', 5, 1, 128, 32, 1024, 512): (16, 32),
413
+ ('bfloat16', 'bfloat16', 5, 1, 128, 32, 2048, 128): (4, 64),
414
+ ('bfloat16', 'bfloat16', 5, 1, 128, 32, 2048, 256): (8, 32),
415
+ ('bfloat16', 'bfloat16', 5, 1, 128, 32, 2048, 512): (16, 32),
416
+ ('bfloat16', 'bfloat16', 5, 1, 128, 32, 4096, 128): (4, 32),
417
+ ('bfloat16', 'bfloat16', 5, 1, 128, 32, 4096, 256): (8, 32),
418
+ ('bfloat16', 'bfloat16', 5, 1, 128, 32, 4096, 512): (16, 32),
419
+ ('bfloat16', 'bfloat16', 5, 1, 128, 32, 512, 128): (4, 32),
420
+ ('bfloat16', 'bfloat16', 5, 1, 128, 32, 512, 256): (8, 32),
421
+ ('bfloat16', 'bfloat16', 5, 1, 128, 32, 512, 512): (16, 32),
422
+ ('bfloat16', 'bfloat16', 5, 1, 128, 64, 1024, 1024): (16, 32),
423
+ ('bfloat16', 'bfloat16', 5, 1, 128, 64, 1024, 256): (4, 32),
424
+ ('bfloat16', 'bfloat16', 5, 1, 128, 64, 1024, 512): (8, 32),
425
+ ('bfloat16', 'bfloat16', 5, 1, 128, 64, 2048, 1024): (16, 32),
426
+ ('bfloat16', 'bfloat16', 5, 1, 128, 64, 2048, 256): (4, 32),
427
+ ('bfloat16', 'bfloat16', 5, 1, 128, 64, 2048, 512): (8, 32),
428
+ ('bfloat16', 'bfloat16', 5, 1, 128, 64, 4096, 1024): (16, 32),
429
+ ('bfloat16', 'bfloat16', 5, 1, 128, 64, 4096, 256): (4, 32),
430
+ ('bfloat16', 'bfloat16', 5, 1, 128, 64, 4096, 512): (8, 32),
431
+ ('bfloat16', 'bfloat16', 5, 1, 128, 64, 512, 1024): (16, 32),
432
+ ('bfloat16', 'bfloat16', 5, 1, 128, 64, 512, 256): (4, 32),
433
+ ('bfloat16', 'bfloat16', 5, 1, 128, 64, 512, 512): (8, 32),
434
+ ('bfloat16', 'bfloat16', 6, 1, 128, 128, 1024, 1024): (8, 32),
435
+ ('bfloat16', 'bfloat16', 6, 1, 128, 128, 1024, 1280): (8, 64),
436
+ ('bfloat16', 'bfloat16', 6, 1, 128, 128, 1024, 2048): (16, 32),
437
+ ('bfloat16', 'bfloat16', 6, 1, 128, 128, 1024, 512): (4, 64),
438
+ ('bfloat16', 'bfloat16', 6, 1, 128, 128, 2048, 1024): (8, 32),
439
+ ('bfloat16', 'bfloat16', 6, 1, 128, 128, 2048, 1280): (8, 64),
440
+ ('bfloat16', 'bfloat16', 6, 1, 128, 128, 2048, 2048): (16, 32),
441
+ ('bfloat16', 'bfloat16', 6, 1, 128, 128, 2048, 512): (4, 64),
442
+ ('bfloat16', 'bfloat16', 6, 1, 128, 128, 4096, 1024): (8, 64),
443
+ ('bfloat16', 'bfloat16', 6, 1, 128, 128, 4096, 1280): (8, 64),
444
+ ('bfloat16', 'bfloat16', 6, 1, 128, 128, 4096, 2048): (16, 32),
445
+ ('bfloat16', 'bfloat16', 6, 1, 128, 128, 4096, 512): (4, 32),
446
+ ('bfloat16', 'bfloat16', 6, 1, 128, 128, 512, 1024): (8, 64),
447
+ ('bfloat16', 'bfloat16', 6, 1, 128, 128, 512, 1280): (4, 32),
448
+ ('bfloat16', 'bfloat16', 6, 1, 128, 128, 512, 2048): (16, 32),
449
+ ('bfloat16', 'bfloat16', 6, 1, 128, 128, 512, 512): (4, 32),
450
+ ('bfloat16', 'bfloat16', 6, 1, 128, 16, 1024, 128): (8, 32),
451
+ ('bfloat16', 'bfloat16', 6, 1, 128, 16, 1024, 256): (16, 32),
452
+ ('bfloat16', 'bfloat16', 6, 1, 128, 16, 1024, 64): (4, 32),
453
+ ('bfloat16', 'bfloat16', 6, 1, 128, 16, 2048, 128): (8, 64),
454
+ ('bfloat16', 'bfloat16', 6, 1, 128, 16, 2048, 256): (16, 32),
455
+ ('bfloat16', 'bfloat16', 6, 1, 128, 16, 2048, 64): (4, 32),
456
+ ('bfloat16', 'bfloat16', 6, 1, 128, 16, 4096, 128): (8, 32),
457
+ ('bfloat16', 'bfloat16', 6, 1, 128, 16, 4096, 256): (16, 32),
458
+ ('bfloat16', 'bfloat16', 6, 1, 128, 16, 4096, 64): (4, 32),
459
+ ('bfloat16', 'bfloat16', 6, 1, 128, 16, 512, 128): (4, 32),
460
+ ('bfloat16', 'bfloat16', 6, 1, 128, 16, 512, 256): (16, 64),
461
+ ('bfloat16', 'bfloat16', 6, 1, 128, 16, 512, 64): (4, 64),
462
+ ('bfloat16', 'bfloat16', 6, 1, 128, 256, 1024, 1024): (4, 32),
463
+ ('bfloat16', 'bfloat16', 6, 1, 128, 256, 1024, 1280): (4, 32),
464
+ ('bfloat16', 'bfloat16', 6, 1, 128, 256, 1024, 2048): (8, 32),
465
+ ('bfloat16', 'bfloat16', 6, 1, 128, 256, 1024, 4096): (8, 32),
466
+ ('bfloat16', 'bfloat16', 6, 1, 128, 256, 2048, 1024): (4, 32),
467
+ ('bfloat16', 'bfloat16', 6, 1, 128, 256, 2048, 1280): (4, 32),
468
+ ('bfloat16', 'bfloat16', 6, 1, 128, 256, 2048, 2048): (8, 32),
469
+ ('bfloat16', 'bfloat16', 6, 1, 128, 256, 2048, 4096): (16, 32),
470
+ ('bfloat16', 'bfloat16', 6, 1, 128, 256, 4096, 1024): (4, 64),
471
+ ('bfloat16', 'bfloat16', 6, 1, 128, 256, 4096, 1280): (4, 32),
472
+ ('bfloat16', 'bfloat16', 6, 1, 128, 256, 4096, 2048): (8, 32),
473
+ ('bfloat16', 'bfloat16', 6, 1, 128, 256, 4096, 4096): (16, 32),
474
+ ('bfloat16', 'bfloat16', 6, 1, 128, 256, 512, 1024): (4, 32),
475
+ ('bfloat16', 'bfloat16', 6, 1, 128, 256, 512, 1280): (4, 32),
476
+ ('bfloat16', 'bfloat16', 6, 1, 128, 256, 512, 2048): (8, 32),
477
+ ('bfloat16', 'bfloat16', 6, 1, 128, 256, 512, 4096): (16, 32),
478
+ ('bfloat16', 'bfloat16', 6, 1, 128, 32, 1024, 128): (4, 64),
479
+ ('bfloat16', 'bfloat16', 6, 1, 128, 32, 1024, 256): (8, 64),
480
+ ('bfloat16', 'bfloat16', 6, 1, 128, 32, 1024, 512): (16, 64),
481
+ ('bfloat16', 'bfloat16', 6, 1, 128, 32, 2048, 128): (4, 32),
482
+ ('bfloat16', 'bfloat16', 6, 1, 128, 32, 2048, 256): (8, 32),
483
+ ('bfloat16', 'bfloat16', 6, 1, 128, 32, 2048, 512): (16, 32),
484
+ ('bfloat16', 'bfloat16', 6, 1, 128, 32, 4096, 128): (4, 64),
485
+ ('bfloat16', 'bfloat16', 6, 1, 128, 32, 4096, 256): (8, 32),
486
+ ('bfloat16', 'bfloat16', 6, 1, 128, 32, 4096, 512): (16, 128),
487
+ ('bfloat16', 'bfloat16', 6, 1, 128, 32, 512, 128): (4, 32),
488
+ ('bfloat16', 'bfloat16', 6, 1, 128, 32, 512, 256): (8, 32),
489
+ ('bfloat16', 'bfloat16', 6, 1, 128, 32, 512, 512): (16, 32),
490
+ ('bfloat16', 'bfloat16', 6, 1, 128, 64, 1024, 1024): (16, 32),
491
+ ('bfloat16', 'bfloat16', 6, 1, 128, 64, 1024, 256): (4, 64),
492
+ ('bfloat16', 'bfloat16', 6, 1, 128, 64, 1024, 512): (8, 32),
493
+ ('bfloat16', 'bfloat16', 6, 1, 128, 64, 2048, 1024): (16, 32),
494
+ ('bfloat16', 'bfloat16', 6, 1, 128, 64, 2048, 256): (4, 32),
495
+ ('bfloat16', 'bfloat16', 6, 1, 128, 64, 2048, 512): (8, 32),
496
+ ('bfloat16', 'bfloat16', 6, 1, 128, 64, 4096, 1024): (16, 64),
497
+ ('bfloat16', 'bfloat16', 6, 1, 128, 64, 4096, 256): (4, 128),
498
+ ('bfloat16', 'bfloat16', 6, 1, 128, 64, 4096, 512): (8, 64),
499
+ ('bfloat16', 'bfloat16', 6, 1, 128, 64, 512, 1024): (8, 64),
500
+ ('bfloat16', 'bfloat16', 6, 1, 128, 64, 512, 256): (4, 32),
501
+ ('bfloat16', 'bfloat16', 6, 1, 128, 64, 512, 512): (8, 32),
502
+ ('bfloat16', 'bfloat16', 8, 1, 128, 128, 1024, 1024): (8, 64),
503
+ ('bfloat16', 'bfloat16', 8, 1, 128, 128, 1024, 1280): (8, 64),
504
+ ('bfloat16', 'bfloat16', 8, 1, 128, 128, 1024, 2048): (16, 32),
505
+ ('bfloat16', 'bfloat16', 8, 1, 128, 128, 1024, 512): (4, 64),
506
+ ('bfloat16', 'bfloat16', 8, 1, 128, 128, 2048, 1024): (8, 64),
507
+ ('bfloat16', 'bfloat16', 8, 1, 128, 128, 2048, 1280): (4, 64),
508
+ ('bfloat16', 'bfloat16', 8, 1, 128, 128, 2048, 2048): (16, 32),
509
+ ('bfloat16', 'bfloat16', 8, 1, 128, 128, 2048, 512): (4, 64),
510
+ ('bfloat16', 'bfloat16', 8, 1, 128, 128, 4096, 1024): (8, 64),
511
+ ('bfloat16', 'bfloat16', 8, 1, 128, 128, 4096, 1280): (4, 64),
512
+ ('bfloat16', 'bfloat16', 8, 1, 128, 128, 4096, 2048): (16, 32),
513
+ ('bfloat16', 'bfloat16', 8, 1, 128, 128, 4096, 512): (4, 64),
514
+ ('bfloat16', 'bfloat16', 8, 1, 128, 128, 512, 1024): (8, 32),
515
+ ('bfloat16', 'bfloat16', 8, 1, 128, 128, 512, 1280): (8, 32),
516
+ ('bfloat16', 'bfloat16', 8, 1, 128, 128, 512, 2048): (16, 32),
517
+ ('bfloat16', 'bfloat16', 8, 1, 128, 128, 512, 512): (4, 32),
518
+ ('bfloat16', 'bfloat16', 8, 1, 128, 16, 1024, 128): (8, 128),
519
+ ('bfloat16', 'bfloat16', 8, 1, 128, 16, 1024, 256): (16, 32),
520
+ ('bfloat16', 'bfloat16', 8, 1, 128, 16, 1024, 64): (4, 64),
521
+ ('bfloat16', 'bfloat16', 8, 1, 128, 16, 2048, 128): (8, 64),
522
+ ('bfloat16', 'bfloat16', 8, 1, 128, 16, 2048, 256): (8, 64),
523
+ ('bfloat16', 'bfloat16', 8, 1, 128, 16, 2048, 64): (4, 64),
524
+ ('bfloat16', 'bfloat16', 8, 1, 128, 16, 4096, 128): (8, 64),
525
+ ('bfloat16', 'bfloat16', 8, 1, 128, 16, 4096, 256): (16, 64),
526
+ ('bfloat16', 'bfloat16', 8, 1, 128, 16, 4096, 64): (4, 64),
527
+ ('bfloat16', 'bfloat16', 8, 1, 128, 16, 512, 128): (8, 32),
528
+ ('bfloat16', 'bfloat16', 8, 1, 128, 16, 512, 256): (16, 64),
529
+ ('bfloat16', 'bfloat16', 8, 1, 128, 16, 512, 64): (4, 128),
530
+ ('bfloat16', 'bfloat16', 8, 1, 128, 256, 1024, 1024): (4, 32),
531
+ ('bfloat16', 'bfloat16', 8, 1, 128, 256, 1024, 1280): (4, 64),
532
+ ('bfloat16', 'bfloat16', 8, 1, 128, 256, 1024, 2048): (8, 32),
533
+ ('bfloat16', 'bfloat16', 8, 1, 128, 256, 1024, 4096): (16, 32),
534
+ ('bfloat16', 'bfloat16', 8, 1, 128, 256, 2048, 1024): (4, 32),
535
+ ('bfloat16', 'bfloat16', 8, 1, 128, 256, 2048, 1280): (4, 64),
536
+ ('bfloat16', 'bfloat16', 8, 1, 128, 256, 2048, 2048): (8, 32),
537
+ ('bfloat16', 'bfloat16', 8, 1, 128, 256, 2048, 4096): (16, 32),
538
+ ('bfloat16', 'bfloat16', 8, 1, 128, 256, 4096, 1024): (4, 64),
539
+ ('bfloat16', 'bfloat16', 8, 1, 128, 256, 4096, 1280): (4, 32),
540
+ ('bfloat16', 'bfloat16', 8, 1, 128, 256, 4096, 2048): (8, 64),
541
+ ('bfloat16', 'bfloat16', 8, 1, 128, 256, 4096, 4096): (16, 32),
542
+ ('bfloat16', 'bfloat16', 8, 1, 128, 256, 512, 1024): (4, 32),
543
+ ('bfloat16', 'bfloat16', 8, 1, 128, 256, 512, 1280): (4, 32),
544
+ ('bfloat16', 'bfloat16', 8, 1, 128, 256, 512, 2048): (8, 64),
545
+ ('bfloat16', 'bfloat16', 8, 1, 128, 256, 512, 4096): (8, 32),
546
+ ('bfloat16', 'bfloat16', 8, 1, 128, 32, 1024, 128): (4, 64),
547
+ ('bfloat16', 'bfloat16', 8, 1, 128, 32, 1024, 256): (8, 128),
548
+ ('bfloat16', 'bfloat16', 8, 1, 128, 32, 1024, 512): (16, 32),
549
+ ('bfloat16', 'bfloat16', 8, 1, 128, 32, 2048, 128): (4, 128),
550
+ ('bfloat16', 'bfloat16', 8, 1, 128, 32, 2048, 256): (8, 128),
551
+ ('bfloat16', 'bfloat16', 8, 1, 128, 32, 2048, 512): (16, 64),
552
+ ('bfloat16', 'bfloat16', 8, 1, 128, 32, 4096, 128): (4, 128),
553
+ ('bfloat16', 'bfloat16', 8, 1, 128, 32, 4096, 256): (8, 32),
554
+ ('bfloat16', 'bfloat16', 8, 1, 128, 32, 4096, 512): (16, 64),
555
+ ('bfloat16', 'bfloat16', 8, 1, 128, 32, 512, 128): (4, 64),
556
+ ('bfloat16', 'bfloat16', 8, 1, 128, 32, 512, 256): (8, 32),
557
+ ('bfloat16', 'bfloat16', 8, 1, 128, 32, 512, 512): (16, 64),
558
+ ('bfloat16', 'bfloat16', 8, 1, 128, 64, 1024, 1024): (16, 32),
559
+ ('bfloat16', 'bfloat16', 8, 1, 128, 64, 1024, 256): (4, 32),
560
+ ('bfloat16', 'bfloat16', 8, 1, 128, 64, 1024, 512): (8, 64),
561
+ ('bfloat16', 'bfloat16', 8, 1, 128, 64, 2048, 1024): (8, 32),
562
+ ('bfloat16', 'bfloat16', 8, 1, 128, 64, 2048, 256): (4, 128),
563
+ ('bfloat16', 'bfloat16', 8, 1, 128, 64, 2048, 512): (8, 64),
564
+ ('bfloat16', 'bfloat16', 8, 1, 128, 64, 4096, 1024): (16, 64),
565
+ ('bfloat16', 'bfloat16', 8, 1, 128, 64, 4096, 256): (4, 128),
566
+ ('bfloat16', 'bfloat16', 8, 1, 128, 64, 4096, 512): (8, 128),
567
+ ('bfloat16', 'bfloat16', 8, 1, 128, 64, 512, 1024): (16, 32),
568
+ ('bfloat16', 'bfloat16', 8, 1, 128, 64, 512, 256): (4, 64),
569
+ ('bfloat16', 'bfloat16', 8, 1, 128, 64, 512, 512): (8, 64),
570
+ ('bfloat16', 'bfloat16', 8, 2, 128, 128, 1024, 1024): (8, 64),
571
+ ('bfloat16', 'bfloat16', 8, 2, 128, 128, 1024, 1280): (8, 64),
572
+ ('bfloat16', 'bfloat16', 8, 2, 128, 128, 1024, 2048): (16, 32),
573
+ ('bfloat16', 'bfloat16', 8, 2, 128, 128, 1024, 512): (4, 64),
574
+ ('bfloat16', 'bfloat16', 8, 2, 128, 128, 2048, 1024): (8, 64),
575
+ ('bfloat16', 'bfloat16', 8, 2, 128, 128, 2048, 1280): (8, 64),
576
+ ('bfloat16', 'bfloat16', 8, 2, 128, 128, 2048, 2048): (16, 32),
577
+ ('bfloat16', 'bfloat16', 8, 2, 128, 128, 2048, 512): (4, 64),
578
+ ('bfloat16', 'bfloat16', 8, 2, 128, 128, 4096, 1024): (8, 64),
579
+ ('bfloat16', 'bfloat16', 8, 2, 128, 128, 4096, 1280): (4, 64),
580
+ ('bfloat16', 'bfloat16', 8, 2, 128, 128, 4096, 2048): (16, 64),
581
+ ('bfloat16', 'bfloat16', 8, 2, 128, 128, 4096, 512): (4, 32),
582
+ ('bfloat16', 'bfloat16', 8, 2, 128, 128, 512, 1024): (8, 64),
583
+ ('bfloat16', 'bfloat16', 8, 2, 128, 128, 512, 1280): (8, 64),
584
+ ('bfloat16', 'bfloat16', 8, 2, 128, 128, 512, 2048): (16, 32),
585
+ ('bfloat16', 'bfloat16', 8, 2, 128, 128, 512, 512): (4, 32),
586
+ ('bfloat16', 'bfloat16', 8, 2, 128, 16, 1024, 128): (8, 32),
587
+ ('bfloat16', 'bfloat16', 8, 2, 128, 16, 1024, 256): (16, 32),
588
+ ('bfloat16', 'bfloat16', 8, 2, 128, 16, 1024, 64): (4, 64),
589
+ ('bfloat16', 'bfloat16', 8, 2, 128, 16, 2048, 128): (8, 64),
590
+ ('bfloat16', 'bfloat16', 8, 2, 128, 16, 2048, 256): (16, 64),
591
+ ('bfloat16', 'bfloat16', 8, 2, 128, 16, 2048, 64): (4, 32),
592
+ ('bfloat16', 'bfloat16', 8, 2, 128, 16, 4096, 128): (8, 64),
593
+ ('bfloat16', 'bfloat16', 8, 2, 128, 16, 4096, 256): (16, 128),
594
+ ('bfloat16', 'bfloat16', 8, 2, 128, 16, 4096, 64): (4, 128),
595
+ ('bfloat16', 'bfloat16', 8, 2, 128, 16, 512, 128): (8, 32),
596
+ ('bfloat16', 'bfloat16', 8, 2, 128, 16, 512, 256): (16, 32),
597
+ ('bfloat16', 'bfloat16', 8, 2, 128, 16, 512, 64): (4, 32),
598
+ ('bfloat16', 'bfloat16', 8, 2, 128, 256, 1024, 1024): (4, 32),
599
+ ('bfloat16', 'bfloat16', 8, 2, 128, 256, 1024, 1280): (4, 64),
600
+ ('bfloat16', 'bfloat16', 8, 2, 128, 256, 1024, 2048): (8, 32),
601
+ ('bfloat16', 'bfloat16', 8, 2, 128, 256, 1024, 4096): (16, 32),
602
+ ('bfloat16', 'bfloat16', 8, 2, 128, 256, 2048, 1024): (4, 64),
603
+ ('bfloat16', 'bfloat16', 8, 2, 128, 256, 2048, 1280): (4, 64),
604
+ ('bfloat16', 'bfloat16', 8, 2, 128, 256, 2048, 2048): (8, 64),
605
+ ('bfloat16', 'bfloat16', 8, 2, 128, 256, 2048, 4096): (16, 32),
606
+ ('bfloat16', 'bfloat16', 8, 2, 128, 256, 4096, 1024): (4, 64),
607
+ ('bfloat16', 'bfloat16', 8, 2, 128, 256, 4096, 1280): (4, 64),
608
+ ('bfloat16', 'bfloat16', 8, 2, 128, 256, 4096, 2048): (8, 64),
609
+ ('bfloat16', 'bfloat16', 8, 2, 128, 256, 4096, 4096): (16, 64),
610
+ ('bfloat16', 'bfloat16', 8, 2, 128, 256, 512, 1024): (4, 32),
611
+ ('bfloat16', 'bfloat16', 8, 2, 128, 256, 512, 1280): (4, 32),
612
+ ('bfloat16', 'bfloat16', 8, 2, 128, 256, 512, 2048): (8, 32),
613
+ ('bfloat16', 'bfloat16', 8, 2, 128, 256, 512, 4096): (16, 32),
614
+ ('bfloat16', 'bfloat16', 8, 2, 128, 32, 1024, 128): (4, 64),
615
+ ('bfloat16', 'bfloat16', 8, 2, 128, 32, 1024, 256): (8, 32),
616
+ ('bfloat16', 'bfloat16', 8, 2, 128, 32, 1024, 512): (16, 64),
617
+ ('bfloat16', 'bfloat16', 8, 2, 128, 32, 2048, 128): (4, 32),
618
+ ('bfloat16', 'bfloat16', 8, 2, 128, 32, 2048, 256): (8, 32),
619
+ ('bfloat16', 'bfloat16', 8, 2, 128, 32, 2048, 512): (16, 32),
620
+ ('bfloat16', 'bfloat16', 8, 2, 128, 32, 4096, 128): (4, 128),
621
+ ('bfloat16', 'bfloat16', 8, 2, 128, 32, 4096, 256): (8, 64),
622
+ ('bfloat16', 'bfloat16', 8, 2, 128, 32, 4096, 512): (16, 64),
623
+ ('bfloat16', 'bfloat16', 8, 2, 128, 32, 512, 128): (4, 64),
624
+ ('bfloat16', 'bfloat16', 8, 2, 128, 32, 512, 256): (8, 32),
625
+ ('bfloat16', 'bfloat16', 8, 2, 128, 32, 512, 512): (16, 32),
626
+ ('bfloat16', 'bfloat16', 8, 2, 128, 64, 1024, 1024): (16, 64),
627
+ ('bfloat16', 'bfloat16', 8, 2, 128, 64, 1024, 256): (4, 32),
628
+ ('bfloat16', 'bfloat16', 8, 2, 128, 64, 1024, 512): (8, 64),
629
+ ('bfloat16', 'bfloat16', 8, 2, 128, 64, 2048, 1024): (16, 64),
630
+ ('bfloat16', 'bfloat16', 8, 2, 128, 64, 2048, 256): (4, 64),
631
+ ('bfloat16', 'bfloat16', 8, 2, 128, 64, 2048, 512): (8, 32),
632
+ ('bfloat16', 'bfloat16', 8, 2, 128, 64, 4096, 1024): (16, 64),
633
+ ('bfloat16', 'bfloat16', 8, 2, 128, 64, 4096, 256): (4, 32),
634
+ ('bfloat16', 'bfloat16', 8, 2, 128, 64, 4096, 512): (8, 64),
635
+ ('bfloat16', 'bfloat16', 8, 2, 128, 64, 512, 1024): (16, 32),
636
+ ('bfloat16', 'bfloat16', 8, 2, 128, 64, 512, 256): (4, 32),
637
+ ('bfloat16', 'bfloat16', 8, 2, 128, 64, 512, 512): (8, 32),
638
+ ('bfloat16', 'bfloat16', 8, 4, 128, 128, 1024, 1024): (8, 64),
639
+ ('bfloat16', 'bfloat16', 8, 4, 128, 128, 1024, 1280): (4, 64),
640
+ ('bfloat16', 'bfloat16', 8, 4, 128, 128, 1024, 2048): (16, 64),
641
+ ('bfloat16', 'bfloat16', 8, 4, 128, 128, 1024, 512): (4, 64),
642
+ ('bfloat16', 'bfloat16', 8, 4, 128, 128, 2048, 1024): (8, 32),
643
+ ('bfloat16', 'bfloat16', 8, 4, 128, 128, 2048, 1280): (8, 64),
644
+ ('bfloat16', 'bfloat16', 8, 4, 128, 128, 2048, 2048): (16, 64),
645
+ ('bfloat16', 'bfloat16', 8, 4, 128, 128, 2048, 512): (4, 32),
646
+ ('bfloat16', 'bfloat16', 8, 4, 128, 128, 4096, 1024): (8, 64),
647
+ ('bfloat16', 'bfloat16', 8, 4, 128, 128, 4096, 1280): (8, 64),
648
+ ('bfloat16', 'bfloat16', 8, 4, 128, 128, 4096, 2048): (16, 64),
649
+ ('bfloat16', 'bfloat16', 8, 4, 128, 128, 4096, 512): (4, 64),
650
+ ('bfloat16', 'bfloat16', 8, 4, 128, 128, 512, 1024): (8, 32),
651
+ ('bfloat16', 'bfloat16', 8, 4, 128, 128, 512, 1280): (8, 64),
652
+ ('bfloat16', 'bfloat16', 8, 4, 128, 128, 512, 2048): (16, 32),
653
+ ('bfloat16', 'bfloat16', 8, 4, 128, 128, 512, 512): (4, 64),
654
+ ('bfloat16', 'bfloat16', 8, 4, 128, 16, 1024, 128): (8, 64),
655
+ ('bfloat16', 'bfloat16', 8, 4, 128, 16, 1024, 256): (16, 64),
656
+ ('bfloat16', 'bfloat16', 8, 4, 128, 16, 1024, 64): (4, 64),
657
+ ('bfloat16', 'bfloat16', 8, 4, 128, 16, 2048, 128): (8, 32),
658
+ ('bfloat16', 'bfloat16', 8, 4, 128, 16, 2048, 256): (16, 64),
659
+ ('bfloat16', 'bfloat16', 8, 4, 128, 16, 2048, 64): (4, 64),
660
+ ('bfloat16', 'bfloat16', 8, 4, 128, 16, 4096, 128): (8, 64),
661
+ ('bfloat16', 'bfloat16', 8, 4, 128, 16, 4096, 256): (16, 32),
662
+ ('bfloat16', 'bfloat16', 8, 4, 128, 16, 4096, 64): (4, 64),
663
+ ('bfloat16', 'bfloat16', 8, 4, 128, 16, 512, 128): (8, 128),
664
+ ('bfloat16', 'bfloat16', 8, 4, 128, 16, 512, 256): (16, 32),
665
+ ('bfloat16', 'bfloat16', 8, 4, 128, 16, 512, 64): (4, 32),
666
+ ('bfloat16', 'bfloat16', 8, 4, 128, 256, 1024, 1024): (4, 64),
667
+ ('bfloat16', 'bfloat16', 8, 4, 128, 256, 1024, 1280): (4, 64),
668
+ ('bfloat16', 'bfloat16', 8, 4, 128, 256, 1024, 2048): (8, 64),
669
+ ('bfloat16', 'bfloat16', 8, 4, 128, 256, 1024, 4096): (16, 64),
670
+ ('bfloat16', 'bfloat16', 8, 4, 128, 256, 2048, 1024): (4, 64),
671
+ ('bfloat16', 'bfloat16', 8, 4, 128, 256, 2048, 1280): (4, 64),
672
+ ('bfloat16', 'bfloat16', 8, 4, 128, 256, 2048, 2048): (8, 64),
673
+ ('bfloat16', 'bfloat16', 8, 4, 128, 256, 2048, 4096): (16, 64),
674
+ ('bfloat16', 'bfloat16', 8, 4, 128, 256, 4096, 1024): (4, 64),
675
+ ('bfloat16', 'bfloat16', 8, 4, 128, 256, 4096, 1280): (4, 64),
676
+ ('bfloat16', 'bfloat16', 8, 4, 128, 256, 4096, 2048): (8, 64),
677
+ ('bfloat16', 'bfloat16', 8, 4, 128, 256, 4096, 4096): (16, 64),
678
+ ('bfloat16', 'bfloat16', 8, 4, 128, 256, 512, 1024): (4, 32),
679
+ ('bfloat16', 'bfloat16', 8, 4, 128, 256, 512, 1280): (4, 64),
680
+ ('bfloat16', 'bfloat16', 8, 4, 128, 256, 512, 2048): (8, 64),
681
+ ('bfloat16', 'bfloat16', 8, 4, 128, 256, 512, 4096): (16, 64),
682
+ ('bfloat16', 'bfloat16', 8, 4, 128, 32, 1024, 128): (4, 64),
683
+ ('bfloat16', 'bfloat16', 8, 4, 128, 32, 1024, 256): (8, 64),
684
+ ('bfloat16', 'bfloat16', 8, 4, 128, 32, 1024, 512): (16, 64),
685
+ ('bfloat16', 'bfloat16', 8, 4, 128, 32, 2048, 128): (4, 64),
686
+ ('bfloat16', 'bfloat16', 8, 4, 128, 32, 2048, 256): (8, 64),
687
+ ('bfloat16', 'bfloat16', 8, 4, 128, 32, 2048, 512): (16, 32),
688
+ ('bfloat16', 'bfloat16', 8, 4, 128, 32, 4096, 128): (4, 64),
689
+ ('bfloat16', 'bfloat16', 8, 4, 128, 32, 4096, 256): (8, 128),
690
+ ('bfloat16', 'bfloat16', 8, 4, 128, 32, 4096, 512): (16, 32),
691
+ ('bfloat16', 'bfloat16', 8, 4, 128, 32, 512, 128): (4, 32),
692
+ ('bfloat16', 'bfloat16', 8, 4, 128, 32, 512, 256): (8, 32),
693
+ ('bfloat16', 'bfloat16', 8, 4, 128, 32, 512, 512): (16, 32),
694
+ ('bfloat16', 'bfloat16', 8, 4, 128, 64, 1024, 1024): (16, 32),
695
+ ('bfloat16', 'bfloat16', 8, 4, 128, 64, 1024, 256): (4, 64),
696
+ ('bfloat16', 'bfloat16', 8, 4, 128, 64, 1024, 512): (8, 32),
697
+ ('bfloat16', 'bfloat16', 8, 4, 128, 64, 2048, 1024): (16, 32),
698
+ ('bfloat16', 'bfloat16', 8, 4, 128, 64, 2048, 256): (4, 64),
699
+ ('bfloat16', 'bfloat16', 8, 4, 128, 64, 2048, 512): (8, 64),
700
+ ('bfloat16', 'bfloat16', 8, 4, 128, 64, 4096, 1024): (16, 64),
701
+ ('bfloat16', 'bfloat16', 8, 4, 128, 64, 4096, 256): (4, 64),
702
+ ('bfloat16', 'bfloat16', 8, 4, 128, 64, 4096, 512): (8, 64),
703
+ ('bfloat16', 'bfloat16', 8, 4, 128, 64, 512, 1024): (16, 64),
704
+ ('bfloat16', 'bfloat16', 8, 4, 128, 64, 512, 256): (4, 64),
705
+ ('bfloat16', 'bfloat16', 8, 4, 128, 64, 512, 512): (8, 32),
706
+ # go/keep-sorted end
707
+ },
708
+ 'TPU v5': {
709
+ # go/keep-sorted start
710
+ ('bfloat16', 'bfloat16', 12, 2, 128, 128, 1024, 1024): (8, 32),
711
+ ('bfloat16', 'bfloat16', 12, 2, 128, 128, 1024, 1280): (4, 32),
712
+ ('bfloat16', 'bfloat16', 12, 2, 128, 128, 1024, 2048): (16, 32),
713
+ ('bfloat16', 'bfloat16', 12, 2, 128, 128, 1024, 512): (4, 32),
714
+ ('bfloat16', 'bfloat16', 12, 2, 128, 128, 2048, 1024): (8, 32),
715
+ ('bfloat16', 'bfloat16', 12, 2, 128, 128, 2048, 1280): (8, 32),
716
+ ('bfloat16', 'bfloat16', 12, 2, 128, 128, 2048, 2048): (16, 32),
717
+ ('bfloat16', 'bfloat16', 12, 2, 128, 128, 2048, 512): (4, 32),
718
+ ('bfloat16', 'bfloat16', 12, 2, 128, 128, 4096, 1024): (8, 32),
719
+ ('bfloat16', 'bfloat16', 12, 2, 128, 128, 4096, 1280): (8, 32),
720
+ ('bfloat16', 'bfloat16', 12, 2, 128, 128, 4096, 2048): (16, 32),
721
+ ('bfloat16', 'bfloat16', 12, 2, 128, 128, 4096, 512): (4, 32),
722
+ ('bfloat16', 'bfloat16', 12, 2, 128, 128, 512, 1024): (8, 32),
723
+ ('bfloat16', 'bfloat16', 12, 2, 128, 128, 512, 1280): (8, 32),
724
+ ('bfloat16', 'bfloat16', 12, 2, 128, 128, 512, 2048): (16, 32),
725
+ ('bfloat16', 'bfloat16', 12, 2, 128, 128, 512, 512): (4, 32),
726
+ ('bfloat16', 'bfloat16', 12, 2, 128, 16, 1024, 128): (8, 32),
727
+ ('bfloat16', 'bfloat16', 12, 2, 128, 16, 1024, 256): (16, 32),
728
+ ('bfloat16', 'bfloat16', 12, 2, 128, 16, 1024, 64): (4, 32),
729
+ ('bfloat16', 'bfloat16', 12, 2, 128, 16, 2048, 128): (8, 32),
730
+ ('bfloat16', 'bfloat16', 12, 2, 128, 16, 2048, 256): (16, 32),
731
+ ('bfloat16', 'bfloat16', 12, 2, 128, 16, 2048, 64): (4, 32),
732
+ ('bfloat16', 'bfloat16', 12, 2, 128, 16, 4096, 128): (8, 32),
733
+ ('bfloat16', 'bfloat16', 12, 2, 128, 16, 4096, 256): (16, 32),
734
+ ('bfloat16', 'bfloat16', 12, 2, 128, 16, 4096, 64): (4, 32),
735
+ ('bfloat16', 'bfloat16', 12, 2, 128, 16, 512, 128): (8, 32),
736
+ ('bfloat16', 'bfloat16', 12, 2, 128, 16, 512, 256): (16, 32),
737
+ ('bfloat16', 'bfloat16', 12, 2, 128, 16, 512, 64): (4, 32),
738
+ ('bfloat16', 'bfloat16', 12, 2, 128, 256, 1024, 1024): (4, 32),
739
+ ('bfloat16', 'bfloat16', 12, 2, 128, 256, 1024, 1280): (4, 32),
740
+ ('bfloat16', 'bfloat16', 12, 2, 128, 256, 1024, 2048): (8, 32),
741
+ ('bfloat16', 'bfloat16', 12, 2, 128, 256, 1024, 4096): (16, 32),
742
+ ('bfloat16', 'bfloat16', 12, 2, 128, 256, 2048, 1024): (4, 32),
743
+ ('bfloat16', 'bfloat16', 12, 2, 128, 256, 2048, 1280): (4, 32),
744
+ ('bfloat16', 'bfloat16', 12, 2, 128, 256, 2048, 2048): (8, 32),
745
+ ('bfloat16', 'bfloat16', 12, 2, 128, 256, 2048, 4096): (16, 32),
746
+ ('bfloat16', 'bfloat16', 12, 2, 128, 256, 4096, 1024): (4, 32),
747
+ ('bfloat16', 'bfloat16', 12, 2, 128, 256, 4096, 1280): (4, 32),
748
+ ('bfloat16', 'bfloat16', 12, 2, 128, 256, 4096, 2048): (8, 32),
749
+ ('bfloat16', 'bfloat16', 12, 2, 128, 256, 4096, 4096): (16, 32),
750
+ ('bfloat16', 'bfloat16', 12, 2, 128, 256, 512, 1024): (4, 32),
751
+ ('bfloat16', 'bfloat16', 12, 2, 128, 256, 512, 1280): (4, 32),
752
+ ('bfloat16', 'bfloat16', 12, 2, 128, 256, 512, 2048): (8, 32),
753
+ ('bfloat16', 'bfloat16', 12, 2, 128, 256, 512, 4096): (16, 32),
754
+ ('bfloat16', 'bfloat16', 12, 2, 128, 32, 1024, 128): (4, 32),
755
+ ('bfloat16', 'bfloat16', 12, 2, 128, 32, 1024, 256): (8, 32),
756
+ ('bfloat16', 'bfloat16', 12, 2, 128, 32, 1024, 512): (16, 32),
757
+ ('bfloat16', 'bfloat16', 12, 2, 128, 32, 2048, 128): (4, 32),
758
+ ('bfloat16', 'bfloat16', 12, 2, 128, 32, 2048, 256): (8, 32),
759
+ ('bfloat16', 'bfloat16', 12, 2, 128, 32, 2048, 512): (16, 32),
760
+ ('bfloat16', 'bfloat16', 12, 2, 128, 32, 4096, 128): (4, 32),
761
+ ('bfloat16', 'bfloat16', 12, 2, 128, 32, 4096, 256): (8, 32),
762
+ ('bfloat16', 'bfloat16', 12, 2, 128, 32, 4096, 512): (16, 32),
763
+ ('bfloat16', 'bfloat16', 12, 2, 128, 32, 512, 128): (4, 32),
764
+ ('bfloat16', 'bfloat16', 12, 2, 128, 32, 512, 256): (8, 32),
765
+ ('bfloat16', 'bfloat16', 12, 2, 128, 32, 512, 512): (16, 32),
766
+ ('bfloat16', 'bfloat16', 12, 2, 128, 64, 1024, 1024): (16, 32),
767
+ ('bfloat16', 'bfloat16', 12, 2, 128, 64, 1024, 256): (4, 32),
768
+ ('bfloat16', 'bfloat16', 12, 2, 128, 64, 1024, 512): (8, 32),
769
+ ('bfloat16', 'bfloat16', 12, 2, 128, 64, 2048, 1024): (16, 32),
770
+ ('bfloat16', 'bfloat16', 12, 2, 128, 64, 2048, 256): (4, 32),
771
+ ('bfloat16', 'bfloat16', 12, 2, 128, 64, 2048, 512): (8, 32),
772
+ ('bfloat16', 'bfloat16', 12, 2, 128, 64, 4096, 1024): (16, 32),
773
+ ('bfloat16', 'bfloat16', 12, 2, 128, 64, 4096, 256): (4, 32),
774
+ ('bfloat16', 'bfloat16', 12, 2, 128, 64, 4096, 512): (8, 32),
775
+ ('bfloat16', 'bfloat16', 12, 2, 128, 64, 512, 1024): (16, 32),
776
+ ('bfloat16', 'bfloat16', 12, 2, 128, 64, 512, 256): (4, 32),
777
+ ('bfloat16', 'bfloat16', 12, 2, 128, 64, 512, 512): (8, 32),
778
+ ('bfloat16', 'bfloat16', 28, 4, 128, 128, 1024, 1024): (8, 32),
779
+ ('bfloat16', 'bfloat16', 28, 4, 128, 128, 1024, 1280): (8, 32),
780
+ ('bfloat16', 'bfloat16', 28, 4, 128, 128, 1024, 2048): (16, 32),
781
+ ('bfloat16', 'bfloat16', 28, 4, 128, 128, 1024, 512): (4, 32),
782
+ ('bfloat16', 'bfloat16', 28, 4, 128, 128, 2048, 1024): (8, 32),
783
+ ('bfloat16', 'bfloat16', 28, 4, 128, 128, 2048, 1280): (8, 32),
784
+ ('bfloat16', 'bfloat16', 28, 4, 128, 128, 2048, 2048): (16, 32),
785
+ ('bfloat16', 'bfloat16', 28, 4, 128, 128, 2048, 512): (4, 32),
786
+ ('bfloat16', 'bfloat16', 28, 4, 128, 128, 4096, 1024): (8, 32),
787
+ ('bfloat16', 'bfloat16', 28, 4, 128, 128, 4096, 1280): (8, 32),
788
+ ('bfloat16', 'bfloat16', 28, 4, 128, 128, 4096, 2048): (16, 32),
789
+ ('bfloat16', 'bfloat16', 28, 4, 128, 128, 4096, 512): (4, 32),
790
+ ('bfloat16', 'bfloat16', 28, 4, 128, 128, 512, 1024): (8, 32),
791
+ ('bfloat16', 'bfloat16', 28, 4, 128, 128, 512, 1280): (8, 32),
792
+ ('bfloat16', 'bfloat16', 28, 4, 128, 128, 512, 2048): (16, 32),
793
+ ('bfloat16', 'bfloat16', 28, 4, 128, 128, 512, 512): (4, 32),
794
+ ('bfloat16', 'bfloat16', 28, 4, 128, 16, 1024, 128): (8, 32),
795
+ ('bfloat16', 'bfloat16', 28, 4, 128, 16, 1024, 256): (16, 32),
796
+ ('bfloat16', 'bfloat16', 28, 4, 128, 16, 1024, 64): (4, 32),
797
+ ('bfloat16', 'bfloat16', 28, 4, 128, 16, 2048, 128): (8, 32),
798
+ ('bfloat16', 'bfloat16', 28, 4, 128, 16, 2048, 256): (16, 32),
799
+ ('bfloat16', 'bfloat16', 28, 4, 128, 16, 2048, 64): (4, 32),
800
+ ('bfloat16', 'bfloat16', 28, 4, 128, 16, 4096, 128): (8, 32),
801
+ ('bfloat16', 'bfloat16', 28, 4, 128, 16, 4096, 256): (16, 32),
802
+ ('bfloat16', 'bfloat16', 28, 4, 128, 16, 4096, 64): (4, 32),
803
+ ('bfloat16', 'bfloat16', 28, 4, 128, 16, 512, 128): (8, 32),
804
+ ('bfloat16', 'bfloat16', 28, 4, 128, 16, 512, 256): (16, 32),
805
+ ('bfloat16', 'bfloat16', 28, 4, 128, 16, 512, 64): (4, 32),
806
+ ('bfloat16', 'bfloat16', 28, 4, 128, 256, 1024, 1024): (4, 32),
807
+ ('bfloat16', 'bfloat16', 28, 4, 128, 256, 1024, 1280): (4, 32),
808
+ ('bfloat16', 'bfloat16', 28, 4, 128, 256, 1024, 2048): (8, 32),
809
+ ('bfloat16', 'bfloat16', 28, 4, 128, 256, 1024, 4096): (16, 32),
810
+ ('bfloat16', 'bfloat16', 28, 4, 128, 256, 2048, 1024): (4, 32),
811
+ ('bfloat16', 'bfloat16', 28, 4, 128, 256, 2048, 1280): (4, 32),
812
+ ('bfloat16', 'bfloat16', 28, 4, 128, 256, 2048, 2048): (8, 32),
813
+ ('bfloat16', 'bfloat16', 28, 4, 128, 256, 2048, 4096): (16, 32),
814
+ ('bfloat16', 'bfloat16', 28, 4, 128, 256, 4096, 1024): (4, 32),
815
+ ('bfloat16', 'bfloat16', 28, 4, 128, 256, 4096, 1280): (4, 32),
816
+ ('bfloat16', 'bfloat16', 28, 4, 128, 256, 4096, 2048): (8, 32),
817
+ ('bfloat16', 'bfloat16', 28, 4, 128, 256, 4096, 4096): (16, 32),
818
+ ('bfloat16', 'bfloat16', 28, 4, 128, 256, 512, 1024): (4, 32),
819
+ ('bfloat16', 'bfloat16', 28, 4, 128, 256, 512, 1280): (4, 32),
820
+ ('bfloat16', 'bfloat16', 28, 4, 128, 256, 512, 2048): (8, 32),
821
+ ('bfloat16', 'bfloat16', 28, 4, 128, 256, 512, 4096): (16, 32),
822
+ ('bfloat16', 'bfloat16', 28, 4, 128, 32, 1024, 128): (4, 32),
823
+ ('bfloat16', 'bfloat16', 28, 4, 128, 32, 1024, 256): (8, 32),
824
+ ('bfloat16', 'bfloat16', 28, 4, 128, 32, 1024, 512): (16, 32),
825
+ ('bfloat16', 'bfloat16', 28, 4, 128, 32, 2048, 128): (4, 32),
826
+ ('bfloat16', 'bfloat16', 28, 4, 128, 32, 2048, 256): (8, 32),
827
+ ('bfloat16', 'bfloat16', 28, 4, 128, 32, 2048, 512): (16, 32),
828
+ ('bfloat16', 'bfloat16', 28, 4, 128, 32, 4096, 128): (4, 32),
829
+ ('bfloat16', 'bfloat16', 28, 4, 128, 32, 4096, 256): (8, 32),
830
+ ('bfloat16', 'bfloat16', 28, 4, 128, 32, 4096, 512): (16, 32),
831
+ ('bfloat16', 'bfloat16', 28, 4, 128, 32, 512, 128): (4, 32),
832
+ ('bfloat16', 'bfloat16', 28, 4, 128, 32, 512, 256): (8, 32),
833
+ ('bfloat16', 'bfloat16', 28, 4, 128, 32, 512, 512): (16, 32),
834
+ ('bfloat16', 'bfloat16', 28, 4, 128, 64, 1024, 1024): (16, 32),
835
+ ('bfloat16', 'bfloat16', 28, 4, 128, 64, 1024, 256): (4, 32),
836
+ ('bfloat16', 'bfloat16', 28, 4, 128, 64, 1024, 512): (8, 32),
837
+ ('bfloat16', 'bfloat16', 28, 4, 128, 64, 2048, 1024): (16, 32),
838
+ ('bfloat16', 'bfloat16', 28, 4, 128, 64, 2048, 256): (4, 32),
839
+ ('bfloat16', 'bfloat16', 28, 4, 128, 64, 2048, 512): (8, 32),
840
+ ('bfloat16', 'bfloat16', 28, 4, 128, 64, 4096, 1024): (16, 32),
841
+ ('bfloat16', 'bfloat16', 28, 4, 128, 64, 4096, 256): (4, 32),
842
+ ('bfloat16', 'bfloat16', 28, 4, 128, 64, 4096, 512): (8, 32),
843
+ ('bfloat16', 'bfloat16', 28, 4, 128, 64, 512, 1024): (16, 32),
844
+ ('bfloat16', 'bfloat16', 28, 4, 128, 64, 512, 256): (4, 32),
845
+ ('bfloat16', 'bfloat16', 28, 4, 128, 64, 512, 512): (8, 32),
846
+ ('bfloat16', 'bfloat16', 32, 8, 128, 128, 1024, 1024): (8, 32),
847
+ ('bfloat16', 'bfloat16', 32, 8, 128, 128, 1024, 1280): (4, 32),
848
+ ('bfloat16', 'bfloat16', 32, 8, 128, 128, 1024, 512): (4, 32),
849
+ ('bfloat16', 'bfloat16', 32, 8, 128, 128, 2048, 1024): (8, 32),
850
+ ('bfloat16', 'bfloat16', 32, 8, 128, 128, 2048, 1280): (4, 32),
851
+ ('bfloat16', 'bfloat16', 32, 8, 128, 128, 2048, 512): (4, 32),
852
+ ('bfloat16', 'bfloat16', 32, 8, 128, 128, 4096, 1024): (8, 32),
853
+ ('bfloat16', 'bfloat16', 32, 8, 128, 128, 4096, 1280): (4, 32),
854
+ ('bfloat16', 'bfloat16', 32, 8, 128, 128, 4096, 512): (4, 32),
855
+ ('bfloat16', 'bfloat16', 32, 8, 128, 128, 512, 1024): (8, 32),
856
+ ('bfloat16', 'bfloat16', 32, 8, 128, 128, 512, 1280): (4, 32),
857
+ ('bfloat16', 'bfloat16', 32, 8, 128, 128, 512, 512): (4, 32),
858
+ ('bfloat16', 'bfloat16', 32, 8, 128, 16, 1024, 128): (8, 32),
859
+ ('bfloat16', 'bfloat16', 32, 8, 128, 16, 1024, 256): (16, 32),
860
+ ('bfloat16', 'bfloat16', 32, 8, 128, 16, 1024, 64): (4, 32),
861
+ ('bfloat16', 'bfloat16', 32, 8, 128, 16, 2048, 128): (8, 32),
862
+ ('bfloat16', 'bfloat16', 32, 8, 128, 16, 2048, 256): (16, 32),
863
+ ('bfloat16', 'bfloat16', 32, 8, 128, 16, 2048, 64): (4, 32),
864
+ ('bfloat16', 'bfloat16', 32, 8, 128, 16, 4096, 128): (8, 32),
865
+ ('bfloat16', 'bfloat16', 32, 8, 128, 16, 4096, 256): (16, 32),
866
+ ('bfloat16', 'bfloat16', 32, 8, 128, 16, 4096, 64): (4, 32),
867
+ ('bfloat16', 'bfloat16', 32, 8, 128, 16, 512, 128): (8, 32),
868
+ ('bfloat16', 'bfloat16', 32, 8, 128, 16, 512, 256): (16, 32),
869
+ ('bfloat16', 'bfloat16', 32, 8, 128, 16, 512, 64): (4, 32),
870
+ ('bfloat16', 'bfloat16', 32, 8, 128, 256, 1024, 1024): (4, 32),
871
+ ('bfloat16', 'bfloat16', 32, 8, 128, 256, 1024, 1280): (4, 32),
872
+ ('bfloat16', 'bfloat16', 32, 8, 128, 256, 2048, 1024): (4, 32),
873
+ ('bfloat16', 'bfloat16', 32, 8, 128, 256, 2048, 1280): (4, 32),
874
+ ('bfloat16', 'bfloat16', 32, 8, 128, 256, 4096, 1024): (4, 32),
875
+ ('bfloat16', 'bfloat16', 32, 8, 128, 256, 4096, 1280): (4, 32),
876
+ ('bfloat16', 'bfloat16', 32, 8, 128, 256, 512, 1024): (4, 32),
877
+ ('bfloat16', 'bfloat16', 32, 8, 128, 256, 512, 1280): (4, 32),
878
+ ('bfloat16', 'bfloat16', 32, 8, 128, 32, 1024, 128): (4, 32),
879
+ ('bfloat16', 'bfloat16', 32, 8, 128, 32, 1024, 256): (8, 32),
880
+ ('bfloat16', 'bfloat16', 32, 8, 128, 32, 1024, 512): (16, 32),
881
+ ('bfloat16', 'bfloat16', 32, 8, 128, 32, 2048, 128): (4, 32),
882
+ ('bfloat16', 'bfloat16', 32, 8, 128, 32, 2048, 256): (8, 32),
883
+ ('bfloat16', 'bfloat16', 32, 8, 128, 32, 2048, 512): (16, 32),
884
+ ('bfloat16', 'bfloat16', 32, 8, 128, 32, 4096, 128): (4, 32),
885
+ ('bfloat16', 'bfloat16', 32, 8, 128, 32, 4096, 256): (8, 32),
886
+ ('bfloat16', 'bfloat16', 32, 8, 128, 32, 4096, 512): (16, 32),
887
+ ('bfloat16', 'bfloat16', 32, 8, 128, 32, 512, 128): (4, 32),
888
+ ('bfloat16', 'bfloat16', 32, 8, 128, 32, 512, 256): (8, 32),
889
+ ('bfloat16', 'bfloat16', 32, 8, 128, 32, 512, 512): (16, 32),
890
+ ('bfloat16', 'bfloat16', 32, 8, 128, 64, 1024, 1024): (16, 32),
891
+ ('bfloat16', 'bfloat16', 32, 8, 128, 64, 1024, 256): (4, 32),
892
+ ('bfloat16', 'bfloat16', 32, 8, 128, 64, 1024, 512): (8, 32),
893
+ ('bfloat16', 'bfloat16', 32, 8, 128, 64, 2048, 1024): (16, 32),
894
+ ('bfloat16', 'bfloat16', 32, 8, 128, 64, 2048, 256): (4, 32),
895
+ ('bfloat16', 'bfloat16', 32, 8, 128, 64, 2048, 512): (8, 32),
896
+ ('bfloat16', 'bfloat16', 32, 8, 128, 64, 4096, 1024): (16, 32),
897
+ ('bfloat16', 'bfloat16', 32, 8, 128, 64, 4096, 256): (4, 32),
898
+ ('bfloat16', 'bfloat16', 32, 8, 128, 64, 4096, 512): (8, 32),
899
+ ('bfloat16', 'bfloat16', 32, 8, 128, 64, 512, 1024): (16, 32),
900
+ ('bfloat16', 'bfloat16', 32, 8, 128, 64, 512, 256): (4, 32),
901
+ ('bfloat16', 'bfloat16', 32, 8, 128, 64, 512, 512): (8, 32),
902
+ ('bfloat16', 'bfloat16', 4, 1, 128, 128, 1024, 1024): (8, 64),
903
+ ('bfloat16', 'bfloat16', 4, 1, 128, 128, 1024, 1280): (8, 64),
904
+ ('bfloat16', 'bfloat16', 4, 1, 128, 128, 1024, 2048): (16, 32),
905
+ ('bfloat16', 'bfloat16', 4, 1, 128, 128, 1024, 512): (4, 32),
906
+ ('bfloat16', 'bfloat16', 4, 1, 128, 128, 2048, 1024): (8, 64),
907
+ ('bfloat16', 'bfloat16', 4, 1, 128, 128, 2048, 1280): (4, 32),
908
+ ('bfloat16', 'bfloat16', 4, 1, 128, 128, 2048, 2048): (16, 64),
909
+ ('bfloat16', 'bfloat16', 4, 1, 128, 128, 2048, 512): (4, 32),
910
+ ('bfloat16', 'bfloat16', 4, 1, 128, 128, 4096, 1024): (8, 32),
911
+ ('bfloat16', 'bfloat16', 4, 1, 128, 128, 4096, 1280): (4, 64),
912
+ ('bfloat16', 'bfloat16', 4, 1, 128, 128, 4096, 2048): (16, 64),
913
+ ('bfloat16', 'bfloat16', 4, 1, 128, 128, 4096, 512): (4, 32),
914
+ ('bfloat16', 'bfloat16', 4, 1, 128, 128, 512, 1024): (8, 32),
915
+ ('bfloat16', 'bfloat16', 4, 1, 128, 128, 512, 1280): (4, 32),
916
+ ('bfloat16', 'bfloat16', 4, 1, 128, 128, 512, 2048): (16, 32),
917
+ ('bfloat16', 'bfloat16', 4, 1, 128, 128, 512, 512): (4, 64),
918
+ ('bfloat16', 'bfloat16', 4, 1, 128, 16, 1024, 128): (8, 32),
919
+ ('bfloat16', 'bfloat16', 4, 1, 128, 16, 1024, 256): (16, 128),
920
+ ('bfloat16', 'bfloat16', 4, 1, 128, 16, 1024, 64): (4, 64),
921
+ ('bfloat16', 'bfloat16', 4, 1, 128, 16, 2048, 128): (8, 32),
922
+ ('bfloat16', 'bfloat16', 4, 1, 128, 16, 2048, 256): (16, 128),
923
+ ('bfloat16', 'bfloat16', 4, 1, 128, 16, 2048, 64): (4, 32),
924
+ ('bfloat16', 'bfloat16', 4, 1, 128, 16, 4096, 128): (4, 64),
925
+ ('bfloat16', 'bfloat16', 4, 1, 128, 16, 4096, 256): (16, 64),
926
+ ('bfloat16', 'bfloat16', 4, 1, 128, 16, 4096, 64): (4, 32),
927
+ ('bfloat16', 'bfloat16', 4, 1, 128, 16, 512, 128): (8, 64),
928
+ ('bfloat16', 'bfloat16', 4, 1, 128, 16, 512, 256): (16, 128),
929
+ ('bfloat16', 'bfloat16', 4, 1, 128, 16, 512, 64): (4, 32),
930
+ ('bfloat16', 'bfloat16', 4, 1, 128, 256, 1024, 1024): (4, 128),
931
+ ('bfloat16', 'bfloat16', 4, 1, 128, 256, 1024, 1280): (4, 32),
932
+ ('bfloat16', 'bfloat16', 4, 1, 128, 256, 1024, 2048): (8, 64),
933
+ ('bfloat16', 'bfloat16', 4, 1, 128, 256, 1024, 4096): (16, 64),
934
+ ('bfloat16', 'bfloat16', 4, 1, 128, 256, 2048, 1024): (4, 32),
935
+ ('bfloat16', 'bfloat16', 4, 1, 128, 256, 2048, 1280): (4, 64),
936
+ ('bfloat16', 'bfloat16', 4, 1, 128, 256, 2048, 2048): (4, 64),
937
+ ('bfloat16', 'bfloat16', 4, 1, 128, 256, 2048, 4096): (16, 64),
938
+ ('bfloat16', 'bfloat16', 4, 1, 128, 256, 4096, 1024): (4, 32),
939
+ ('bfloat16', 'bfloat16', 4, 1, 128, 256, 4096, 1280): (4, 64),
940
+ ('bfloat16', 'bfloat16', 4, 1, 128, 256, 4096, 2048): (8, 64),
941
+ ('bfloat16', 'bfloat16', 4, 1, 128, 256, 4096, 4096): (16, 64),
942
+ ('bfloat16', 'bfloat16', 4, 1, 128, 256, 512, 1024): (4, 32),
943
+ ('bfloat16', 'bfloat16', 4, 1, 128, 256, 512, 1280): (4, 32),
944
+ ('bfloat16', 'bfloat16', 4, 1, 128, 256, 512, 2048): (8, 64),
945
+ ('bfloat16', 'bfloat16', 4, 1, 128, 256, 512, 4096): (16, 32),
946
+ ('bfloat16', 'bfloat16', 4, 1, 128, 32, 1024, 128): (4, 32),
947
+ ('bfloat16', 'bfloat16', 4, 1, 128, 32, 1024, 256): (8, 32),
948
+ ('bfloat16', 'bfloat16', 4, 1, 128, 32, 1024, 512): (16, 64),
949
+ ('bfloat16', 'bfloat16', 4, 1, 128, 32, 2048, 128): (4, 32),
950
+ ('bfloat16', 'bfloat16', 4, 1, 128, 32, 2048, 256): (8, 32),
951
+ ('bfloat16', 'bfloat16', 4, 1, 128, 32, 2048, 512): (16, 32),
952
+ ('bfloat16', 'bfloat16', 4, 1, 128, 32, 4096, 128): (4, 64),
953
+ ('bfloat16', 'bfloat16', 4, 1, 128, 32, 4096, 256): (8, 64),
954
+ ('bfloat16', 'bfloat16', 4, 1, 128, 32, 4096, 512): (8, 64),
955
+ ('bfloat16', 'bfloat16', 4, 1, 128, 32, 512, 128): (4, 64),
956
+ ('bfloat16', 'bfloat16', 4, 1, 128, 32, 512, 256): (8, 32),
957
+ ('bfloat16', 'bfloat16', 4, 1, 128, 32, 512, 512): (16, 64),
958
+ ('bfloat16', 'bfloat16', 4, 1, 128, 64, 1024, 1024): (8, 64),
959
+ ('bfloat16', 'bfloat16', 4, 1, 128, 64, 1024, 256): (4, 64),
960
+ ('bfloat16', 'bfloat16', 4, 1, 128, 64, 1024, 512): (8, 64),
961
+ ('bfloat16', 'bfloat16', 4, 1, 128, 64, 2048, 1024): (8, 32),
962
+ ('bfloat16', 'bfloat16', 4, 1, 128, 64, 2048, 256): (4, 64),
963
+ ('bfloat16', 'bfloat16', 4, 1, 128, 64, 2048, 512): (8, 32),
964
+ ('bfloat16', 'bfloat16', 4, 1, 128, 64, 4096, 1024): (16, 32),
965
+ ('bfloat16', 'bfloat16', 4, 1, 128, 64, 4096, 256): (4, 64),
966
+ ('bfloat16', 'bfloat16', 4, 1, 128, 64, 4096, 512): (4, 64),
967
+ ('bfloat16', 'bfloat16', 4, 1, 128, 64, 512, 1024): (16, 32),
968
+ ('bfloat16', 'bfloat16', 4, 1, 128, 64, 512, 256): (4, 64),
969
+ ('bfloat16', 'bfloat16', 4, 1, 128, 64, 512, 512): (8, 32),
970
+ ('bfloat16', 'bfloat16', 4, 2, 128, 128, 1024, 1024): (8, 64),
971
+ ('bfloat16', 'bfloat16', 4, 2, 128, 128, 1024, 1280): (8, 64),
972
+ ('bfloat16', 'bfloat16', 4, 2, 128, 128, 1024, 2048): (16, 64),
973
+ ('bfloat16', 'bfloat16', 4, 2, 128, 128, 1024, 512): (4, 32),
974
+ ('bfloat16', 'bfloat16', 4, 2, 128, 128, 2048, 1024): (8, 64),
975
+ ('bfloat16', 'bfloat16', 4, 2, 128, 128, 2048, 1280): (4, 64),
976
+ ('bfloat16', 'bfloat16', 4, 2, 128, 128, 2048, 2048): (16, 64),
977
+ ('bfloat16', 'bfloat16', 4, 2, 128, 128, 2048, 512): (4, 64),
978
+ ('bfloat16', 'bfloat16', 4, 2, 128, 128, 4096, 1024): (8, 64),
979
+ ('bfloat16', 'bfloat16', 4, 2, 128, 128, 4096, 1280): (4, 32),
980
+ ('bfloat16', 'bfloat16', 4, 2, 128, 128, 4096, 2048): (16, 64),
981
+ ('bfloat16', 'bfloat16', 4, 2, 128, 128, 4096, 512): (4, 64),
982
+ ('bfloat16', 'bfloat16', 4, 2, 128, 128, 512, 1024): (8, 32),
983
+ ('bfloat16', 'bfloat16', 4, 2, 128, 128, 512, 1280): (4, 32),
984
+ ('bfloat16', 'bfloat16', 4, 2, 128, 128, 512, 2048): (16, 64),
985
+ ('bfloat16', 'bfloat16', 4, 2, 128, 128, 512, 512): (4, 32),
986
+ ('bfloat16', 'bfloat16', 4, 2, 128, 16, 1024, 128): (4, 32),
987
+ ('bfloat16', 'bfloat16', 4, 2, 128, 16, 1024, 256): (8, 32),
988
+ ('bfloat16', 'bfloat16', 4, 2, 128, 16, 1024, 64): (4, 32),
989
+ ('bfloat16', 'bfloat16', 4, 2, 128, 16, 2048, 128): (8, 32),
990
+ ('bfloat16', 'bfloat16', 4, 2, 128, 16, 2048, 256): (16, 32),
991
+ ('bfloat16', 'bfloat16', 4, 2, 128, 16, 2048, 64): (4, 32),
992
+ ('bfloat16', 'bfloat16', 4, 2, 128, 16, 4096, 128): (8, 64),
993
+ ('bfloat16', 'bfloat16', 4, 2, 128, 16, 4096, 256): (16, 64),
994
+ ('bfloat16', 'bfloat16', 4, 2, 128, 16, 4096, 64): (4, 64),
995
+ ('bfloat16', 'bfloat16', 4, 2, 128, 16, 512, 128): (8, 32),
996
+ ('bfloat16', 'bfloat16', 4, 2, 128, 16, 512, 256): (16, 32),
997
+ ('bfloat16', 'bfloat16', 4, 2, 128, 16, 512, 64): (4, 64),
998
+ ('bfloat16', 'bfloat16', 4, 2, 128, 256, 1024, 1024): (4, 32),
999
+ ('bfloat16', 'bfloat16', 4, 2, 128, 256, 1024, 1280): (4, 64),
1000
+ ('bfloat16', 'bfloat16', 4, 2, 128, 256, 1024, 2048): (8, 64),
1001
+ ('bfloat16', 'bfloat16', 4, 2, 128, 256, 1024, 4096): (16, 64),
1002
+ ('bfloat16', 'bfloat16', 4, 2, 128, 256, 2048, 1024): (4, 64),
1003
+ ('bfloat16', 'bfloat16', 4, 2, 128, 256, 2048, 1280): (4, 64),
1004
+ ('bfloat16', 'bfloat16', 4, 2, 128, 256, 2048, 2048): (8, 64),
1005
+ ('bfloat16', 'bfloat16', 4, 2, 128, 256, 2048, 4096): (8, 64),
1006
+ ('bfloat16', 'bfloat16', 4, 2, 128, 256, 4096, 1024): (4, 64),
1007
+ ('bfloat16', 'bfloat16', 4, 2, 128, 256, 4096, 1280): (4, 64),
1008
+ ('bfloat16', 'bfloat16', 4, 2, 128, 256, 4096, 2048): (8, 64),
1009
+ ('bfloat16', 'bfloat16', 4, 2, 128, 256, 4096, 4096): (16, 64),
1010
+ ('bfloat16', 'bfloat16', 4, 2, 128, 256, 512, 1024): (4, 64),
1011
+ ('bfloat16', 'bfloat16', 4, 2, 128, 256, 512, 1280): (4, 64),
1012
+ ('bfloat16', 'bfloat16', 4, 2, 128, 256, 512, 2048): (8, 64),
1013
+ ('bfloat16', 'bfloat16', 4, 2, 128, 256, 512, 4096): (16, 64),
1014
+ ('bfloat16', 'bfloat16', 4, 2, 128, 32, 1024, 128): (4, 64),
1015
+ ('bfloat16', 'bfloat16', 4, 2, 128, 32, 1024, 256): (8, 64),
1016
+ ('bfloat16', 'bfloat16', 4, 2, 128, 32, 1024, 512): (16, 32),
1017
+ ('bfloat16', 'bfloat16', 4, 2, 128, 32, 2048, 128): (4, 32),
1018
+ ('bfloat16', 'bfloat16', 4, 2, 128, 32, 2048, 256): (8, 128),
1019
+ ('bfloat16', 'bfloat16', 4, 2, 128, 32, 2048, 512): (16, 64),
1020
+ ('bfloat16', 'bfloat16', 4, 2, 128, 32, 4096, 128): (4, 32),
1021
+ ('bfloat16', 'bfloat16', 4, 2, 128, 32, 4096, 256): (8, 32),
1022
+ ('bfloat16', 'bfloat16', 4, 2, 128, 32, 4096, 512): (16, 64),
1023
+ ('bfloat16', 'bfloat16', 4, 2, 128, 32, 512, 128): (4, 64),
1024
+ ('bfloat16', 'bfloat16', 4, 2, 128, 32, 512, 256): (8, 64),
1025
+ ('bfloat16', 'bfloat16', 4, 2, 128, 32, 512, 512): (16, 64),
1026
+ ('bfloat16', 'bfloat16', 4, 2, 128, 64, 1024, 1024): (16, 32),
1027
+ ('bfloat16', 'bfloat16', 4, 2, 128, 64, 1024, 256): (4, 32),
1028
+ ('bfloat16', 'bfloat16', 4, 2, 128, 64, 1024, 512): (8, 64),
1029
+ ('bfloat16', 'bfloat16', 4, 2, 128, 64, 2048, 1024): (16, 64),
1030
+ ('bfloat16', 'bfloat16', 4, 2, 128, 64, 2048, 256): (4, 32),
1031
+ ('bfloat16', 'bfloat16', 4, 2, 128, 64, 2048, 512): (8, 32),
1032
+ ('bfloat16', 'bfloat16', 4, 2, 128, 64, 4096, 1024): (16, 32),
1033
+ ('bfloat16', 'bfloat16', 4, 2, 128, 64, 4096, 256): (4, 32),
1034
+ ('bfloat16', 'bfloat16', 4, 2, 128, 64, 4096, 512): (8, 64),
1035
+ ('bfloat16', 'bfloat16', 4, 2, 128, 64, 512, 1024): (16, 64),
1036
+ ('bfloat16', 'bfloat16', 4, 2, 128, 64, 512, 256): (4, 64),
1037
+ ('bfloat16', 'bfloat16', 4, 2, 128, 64, 512, 512): (8, 32),
1038
+ ('bfloat16', 'bfloat16', 5, 1, 128, 128, 1024, 1024): (8, 32),
1039
+ ('bfloat16', 'bfloat16', 5, 1, 128, 128, 1024, 1280): (8, 32),
1040
+ ('bfloat16', 'bfloat16', 5, 1, 128, 128, 1024, 2048): (16, 32),
1041
+ ('bfloat16', 'bfloat16', 5, 1, 128, 128, 1024, 512): (4, 32),
1042
+ ('bfloat16', 'bfloat16', 5, 1, 128, 128, 2048, 1024): (8, 32),
1043
+ ('bfloat16', 'bfloat16', 5, 1, 128, 128, 2048, 1280): (8, 32),
1044
+ ('bfloat16', 'bfloat16', 5, 1, 128, 128, 2048, 2048): (16, 32),
1045
+ ('bfloat16', 'bfloat16', 5, 1, 128, 128, 2048, 512): (4, 32),
1046
+ ('bfloat16', 'bfloat16', 5, 1, 128, 128, 4096, 1024): (8, 32),
1047
+ ('bfloat16', 'bfloat16', 5, 1, 128, 128, 4096, 1280): (8, 32),
1048
+ ('bfloat16', 'bfloat16', 5, 1, 128, 128, 4096, 2048): (16, 32),
1049
+ ('bfloat16', 'bfloat16', 5, 1, 128, 128, 4096, 512): (4, 32),
1050
+ ('bfloat16', 'bfloat16', 5, 1, 128, 128, 512, 1024): (8, 32),
1051
+ ('bfloat16', 'bfloat16', 5, 1, 128, 128, 512, 1280): (4, 32),
1052
+ ('bfloat16', 'bfloat16', 5, 1, 128, 128, 512, 2048): (16, 32),
1053
+ ('bfloat16', 'bfloat16', 5, 1, 128, 128, 512, 512): (4, 32),
1054
+ ('bfloat16', 'bfloat16', 5, 1, 128, 16, 1024, 128): (8, 32),
1055
+ ('bfloat16', 'bfloat16', 5, 1, 128, 16, 1024, 256): (16, 32),
1056
+ ('bfloat16', 'bfloat16', 5, 1, 128, 16, 1024, 64): (4, 32),
1057
+ ('bfloat16', 'bfloat16', 5, 1, 128, 16, 2048, 128): (8, 32),
1058
+ ('bfloat16', 'bfloat16', 5, 1, 128, 16, 2048, 256): (16, 32),
1059
+ ('bfloat16', 'bfloat16', 5, 1, 128, 16, 2048, 64): (4, 32),
1060
+ ('bfloat16', 'bfloat16', 5, 1, 128, 16, 4096, 128): (8, 32),
1061
+ ('bfloat16', 'bfloat16', 5, 1, 128, 16, 4096, 256): (16, 32),
1062
+ ('bfloat16', 'bfloat16', 5, 1, 128, 16, 4096, 64): (4, 32),
1063
+ ('bfloat16', 'bfloat16', 5, 1, 128, 16, 512, 128): (8, 32),
1064
+ ('bfloat16', 'bfloat16', 5, 1, 128, 16, 512, 256): (16, 32),
1065
+ ('bfloat16', 'bfloat16', 5, 1, 128, 16, 512, 64): (4, 32),
1066
+ ('bfloat16', 'bfloat16', 5, 1, 128, 256, 1024, 1024): (4, 32),
1067
+ ('bfloat16', 'bfloat16', 5, 1, 128, 256, 1024, 1280): (4, 32),
1068
+ ('bfloat16', 'bfloat16', 5, 1, 128, 256, 1024, 2048): (8, 32),
1069
+ ('bfloat16', 'bfloat16', 5, 1, 128, 256, 1024, 4096): (16, 32),
1070
+ ('bfloat16', 'bfloat16', 5, 1, 128, 256, 2048, 1024): (4, 32),
1071
+ ('bfloat16', 'bfloat16', 5, 1, 128, 256, 2048, 1280): (4, 32),
1072
+ ('bfloat16', 'bfloat16', 5, 1, 128, 256, 2048, 2048): (8, 32),
1073
+ ('bfloat16', 'bfloat16', 5, 1, 128, 256, 2048, 4096): (16, 32),
1074
+ ('bfloat16', 'bfloat16', 5, 1, 128, 256, 4096, 1024): (4, 32),
1075
+ ('bfloat16', 'bfloat16', 5, 1, 128, 256, 4096, 1280): (4, 32),
1076
+ ('bfloat16', 'bfloat16', 5, 1, 128, 256, 4096, 2048): (8, 32),
1077
+ ('bfloat16', 'bfloat16', 5, 1, 128, 256, 4096, 4096): (16, 32),
1078
+ ('bfloat16', 'bfloat16', 5, 1, 128, 256, 512, 1024): (4, 32),
1079
+ ('bfloat16', 'bfloat16', 5, 1, 128, 256, 512, 1280): (4, 32),
1080
+ ('bfloat16', 'bfloat16', 5, 1, 128, 256, 512, 2048): (8, 32),
1081
+ ('bfloat16', 'bfloat16', 5, 1, 128, 256, 512, 4096): (16, 32),
1082
+ ('bfloat16', 'bfloat16', 5, 1, 128, 32, 1024, 128): (4, 32),
1083
+ ('bfloat16', 'bfloat16', 5, 1, 128, 32, 1024, 256): (8, 32),
1084
+ ('bfloat16', 'bfloat16', 5, 1, 128, 32, 1024, 512): (16, 32),
1085
+ ('bfloat16', 'bfloat16', 5, 1, 128, 32, 2048, 128): (4, 32),
1086
+ ('bfloat16', 'bfloat16', 5, 1, 128, 32, 2048, 256): (8, 32),
1087
+ ('bfloat16', 'bfloat16', 5, 1, 128, 32, 2048, 512): (16, 32),
1088
+ ('bfloat16', 'bfloat16', 5, 1, 128, 32, 4096, 128): (4, 32),
1089
+ ('bfloat16', 'bfloat16', 5, 1, 128, 32, 4096, 256): (8, 32),
1090
+ ('bfloat16', 'bfloat16', 5, 1, 128, 32, 4096, 512): (16, 32),
1091
+ ('bfloat16', 'bfloat16', 5, 1, 128, 32, 512, 128): (4, 32),
1092
+ ('bfloat16', 'bfloat16', 5, 1, 128, 32, 512, 256): (8, 32),
1093
+ ('bfloat16', 'bfloat16', 5, 1, 128, 32, 512, 512): (16, 32),
1094
+ ('bfloat16', 'bfloat16', 5, 1, 128, 64, 1024, 1024): (16, 32),
1095
+ ('bfloat16', 'bfloat16', 5, 1, 128, 64, 1024, 256): (4, 32),
1096
+ ('bfloat16', 'bfloat16', 5, 1, 128, 64, 1024, 512): (8, 32),
1097
+ ('bfloat16', 'bfloat16', 5, 1, 128, 64, 2048, 1024): (16, 32),
1098
+ ('bfloat16', 'bfloat16', 5, 1, 128, 64, 2048, 256): (4, 64),
1099
+ ('bfloat16', 'bfloat16', 5, 1, 128, 64, 2048, 512): (8, 32),
1100
+ ('bfloat16', 'bfloat16', 5, 1, 128, 64, 4096, 1024): (16, 32),
1101
+ ('bfloat16', 'bfloat16', 5, 1, 128, 64, 4096, 256): (4, 32),
1102
+ ('bfloat16', 'bfloat16', 5, 1, 128, 64, 4096, 512): (8, 32),
1103
+ ('bfloat16', 'bfloat16', 5, 1, 128, 64, 512, 1024): (16, 32),
1104
+ ('bfloat16', 'bfloat16', 5, 1, 128, 64, 512, 256): (4, 32),
1105
+ ('bfloat16', 'bfloat16', 5, 1, 128, 64, 512, 512): (8, 32),
1106
+ ('bfloat16', 'bfloat16', 6, 1, 128, 128, 1024, 1024): (8, 64),
1107
+ ('bfloat16', 'bfloat16', 6, 1, 128, 128, 1024, 1280): (8, 32),
1108
+ ('bfloat16', 'bfloat16', 6, 1, 128, 128, 1024, 2048): (16, 32),
1109
+ ('bfloat16', 'bfloat16', 6, 1, 128, 128, 1024, 512): (4, 32),
1110
+ ('bfloat16', 'bfloat16', 6, 1, 128, 128, 2048, 1024): (8, 32),
1111
+ ('bfloat16', 'bfloat16', 6, 1, 128, 128, 2048, 1280): (8, 32),
1112
+ ('bfloat16', 'bfloat16', 6, 1, 128, 128, 2048, 2048): (16, 32),
1113
+ ('bfloat16', 'bfloat16', 6, 1, 128, 128, 2048, 512): (4, 32),
1114
+ ('bfloat16', 'bfloat16', 6, 1, 128, 128, 4096, 1024): (8, 32),
1115
+ ('bfloat16', 'bfloat16', 6, 1, 128, 128, 4096, 1280): (4, 32),
1116
+ ('bfloat16', 'bfloat16', 6, 1, 128, 128, 4096, 2048): (16, 32),
1117
+ ('bfloat16', 'bfloat16', 6, 1, 128, 128, 4096, 512): (4, 32),
1118
+ ('bfloat16', 'bfloat16', 6, 1, 128, 128, 512, 1024): (8, 32),
1119
+ ('bfloat16', 'bfloat16', 6, 1, 128, 128, 512, 1280): (4, 32),
1120
+ ('bfloat16', 'bfloat16', 6, 1, 128, 128, 512, 2048): (16, 32),
1121
+ ('bfloat16', 'bfloat16', 6, 1, 128, 128, 512, 512): (4, 32),
1122
+ ('bfloat16', 'bfloat16', 6, 1, 128, 16, 1024, 128): (8, 32),
1123
+ ('bfloat16', 'bfloat16', 6, 1, 128, 16, 1024, 256): (16, 64),
1124
+ ('bfloat16', 'bfloat16', 6, 1, 128, 16, 1024, 64): (4, 32),
1125
+ ('bfloat16', 'bfloat16', 6, 1, 128, 16, 2048, 128): (8, 32),
1126
+ ('bfloat16', 'bfloat16', 6, 1, 128, 16, 2048, 256): (16, 32),
1127
+ ('bfloat16', 'bfloat16', 6, 1, 128, 16, 2048, 64): (4, 32),
1128
+ ('bfloat16', 'bfloat16', 6, 1, 128, 16, 4096, 128): (8, 32),
1129
+ ('bfloat16', 'bfloat16', 6, 1, 128, 16, 4096, 256): (16, 64),
1130
+ ('bfloat16', 'bfloat16', 6, 1, 128, 16, 4096, 64): (4, 32),
1131
+ ('bfloat16', 'bfloat16', 6, 1, 128, 16, 512, 128): (8, 32),
1132
+ ('bfloat16', 'bfloat16', 6, 1, 128, 16, 512, 256): (16, 32),
1133
+ ('bfloat16', 'bfloat16', 6, 1, 128, 16, 512, 64): (4, 32),
1134
+ ('bfloat16', 'bfloat16', 6, 1, 128, 256, 1024, 1024): (4, 32),
1135
+ ('bfloat16', 'bfloat16', 6, 1, 128, 256, 1024, 1280): (4, 32),
1136
+ ('bfloat16', 'bfloat16', 6, 1, 128, 256, 1024, 2048): (8, 32),
1137
+ ('bfloat16', 'bfloat16', 6, 1, 128, 256, 1024, 4096): (16, 32),
1138
+ ('bfloat16', 'bfloat16', 6, 1, 128, 256, 2048, 1024): (4, 32),
1139
+ ('bfloat16', 'bfloat16', 6, 1, 128, 256, 2048, 1280): (4, 32),
1140
+ ('bfloat16', 'bfloat16', 6, 1, 128, 256, 2048, 2048): (8, 32),
1141
+ ('bfloat16', 'bfloat16', 6, 1, 128, 256, 2048, 4096): (16, 32),
1142
+ ('bfloat16', 'bfloat16', 6, 1, 128, 256, 4096, 1024): (4, 64),
1143
+ ('bfloat16', 'bfloat16', 6, 1, 128, 256, 4096, 1280): (4, 32),
1144
+ ('bfloat16', 'bfloat16', 6, 1, 128, 256, 4096, 2048): (8, 32),
1145
+ ('bfloat16', 'bfloat16', 6, 1, 128, 256, 4096, 4096): (16, 32),
1146
+ ('bfloat16', 'bfloat16', 6, 1, 128, 256, 512, 1024): (4, 32),
1147
+ ('bfloat16', 'bfloat16', 6, 1, 128, 256, 512, 1280): (4, 32),
1148
+ ('bfloat16', 'bfloat16', 6, 1, 128, 256, 512, 2048): (8, 32),
1149
+ ('bfloat16', 'bfloat16', 6, 1, 128, 256, 512, 4096): (16, 32),
1150
+ ('bfloat16', 'bfloat16', 6, 1, 128, 32, 1024, 128): (4, 32),
1151
+ ('bfloat16', 'bfloat16', 6, 1, 128, 32, 1024, 256): (8, 32),
1152
+ ('bfloat16', 'bfloat16', 6, 1, 128, 32, 1024, 512): (16, 32),
1153
+ ('bfloat16', 'bfloat16', 6, 1, 128, 32, 2048, 128): (4, 32),
1154
+ ('bfloat16', 'bfloat16', 6, 1, 128, 32, 2048, 256): (8, 32),
1155
+ ('bfloat16', 'bfloat16', 6, 1, 128, 32, 2048, 512): (16, 32),
1156
+ ('bfloat16', 'bfloat16', 6, 1, 128, 32, 4096, 128): (4, 32),
1157
+ ('bfloat16', 'bfloat16', 6, 1, 128, 32, 4096, 256): (8, 64),
1158
+ ('bfloat16', 'bfloat16', 6, 1, 128, 32, 4096, 512): (16, 32),
1159
+ ('bfloat16', 'bfloat16', 6, 1, 128, 32, 512, 128): (4, 32),
1160
+ ('bfloat16', 'bfloat16', 6, 1, 128, 32, 512, 256): (8, 64),
1161
+ ('bfloat16', 'bfloat16', 6, 1, 128, 32, 512, 512): (16, 32),
1162
+ ('bfloat16', 'bfloat16', 6, 1, 128, 64, 1024, 1024): (16, 32),
1163
+ ('bfloat16', 'bfloat16', 6, 1, 128, 64, 1024, 256): (4, 32),
1164
+ ('bfloat16', 'bfloat16', 6, 1, 128, 64, 1024, 512): (8, 32),
1165
+ ('bfloat16', 'bfloat16', 6, 1, 128, 64, 2048, 1024): (8, 32),
1166
+ ('bfloat16', 'bfloat16', 6, 1, 128, 64, 2048, 256): (4, 32),
1167
+ ('bfloat16', 'bfloat16', 6, 1, 128, 64, 2048, 512): (8, 32),
1168
+ ('bfloat16', 'bfloat16', 6, 1, 128, 64, 4096, 1024): (16, 32),
1169
+ ('bfloat16', 'bfloat16', 6, 1, 128, 64, 4096, 256): (4, 32),
1170
+ ('bfloat16', 'bfloat16', 6, 1, 128, 64, 4096, 512): (8, 32),
1171
+ ('bfloat16', 'bfloat16', 6, 1, 128, 64, 512, 1024): (16, 32),
1172
+ ('bfloat16', 'bfloat16', 6, 1, 128, 64, 512, 256): (4, 32),
1173
+ ('bfloat16', 'bfloat16', 6, 1, 128, 64, 512, 512): (8, 32),
1174
+ ('bfloat16', 'bfloat16', 8, 1, 128, 128, 1024, 1024): (8, 32),
1175
+ ('bfloat16', 'bfloat16', 8, 1, 128, 128, 1024, 1280): (8, 32),
1176
+ ('bfloat16', 'bfloat16', 8, 1, 128, 128, 1024, 2048): (16, 32),
1177
+ ('bfloat16', 'bfloat16', 8, 1, 128, 128, 1024, 512): (4, 32),
1178
+ ('bfloat16', 'bfloat16', 8, 1, 128, 128, 2048, 1024): (8, 32),
1179
+ ('bfloat16', 'bfloat16', 8, 1, 128, 128, 2048, 1280): (4, 32),
1180
+ ('bfloat16', 'bfloat16', 8, 1, 128, 128, 2048, 2048): (8, 32),
1181
+ ('bfloat16', 'bfloat16', 8, 1, 128, 128, 2048, 512): (4, 32),
1182
+ ('bfloat16', 'bfloat16', 8, 1, 128, 128, 4096, 1024): (8, 64),
1183
+ ('bfloat16', 'bfloat16', 8, 1, 128, 128, 4096, 1280): (4, 32),
1184
+ ('bfloat16', 'bfloat16', 8, 1, 128, 128, 4096, 2048): (16, 32),
1185
+ ('bfloat16', 'bfloat16', 8, 1, 128, 128, 4096, 512): (4, 32),
1186
+ ('bfloat16', 'bfloat16', 8, 1, 128, 128, 512, 1024): (8, 32),
1187
+ ('bfloat16', 'bfloat16', 8, 1, 128, 128, 512, 1280): (4, 32),
1188
+ ('bfloat16', 'bfloat16', 8, 1, 128, 128, 512, 2048): (16, 32),
1189
+ ('bfloat16', 'bfloat16', 8, 1, 128, 128, 512, 512): (4, 32),
1190
+ ('bfloat16', 'bfloat16', 8, 1, 128, 16, 1024, 128): (4, 32),
1191
+ ('bfloat16', 'bfloat16', 8, 1, 128, 16, 1024, 256): (8, 32),
1192
+ ('bfloat16', 'bfloat16', 8, 1, 128, 16, 1024, 64): (4, 64),
1193
+ ('bfloat16', 'bfloat16', 8, 1, 128, 16, 2048, 128): (8, 32),
1194
+ ('bfloat16', 'bfloat16', 8, 1, 128, 16, 2048, 256): (16, 64),
1195
+ ('bfloat16', 'bfloat16', 8, 1, 128, 16, 2048, 64): (4, 128),
1196
+ ('bfloat16', 'bfloat16', 8, 1, 128, 16, 4096, 128): (8, 64),
1197
+ ('bfloat16', 'bfloat16', 8, 1, 128, 16, 4096, 256): (16, 64),
1198
+ ('bfloat16', 'bfloat16', 8, 1, 128, 16, 4096, 64): (4, 64),
1199
+ ('bfloat16', 'bfloat16', 8, 1, 128, 16, 512, 128): (4, 32),
1200
+ ('bfloat16', 'bfloat16', 8, 1, 128, 16, 512, 256): (16, 32),
1201
+ ('bfloat16', 'bfloat16', 8, 1, 128, 16, 512, 64): (4, 32),
1202
+ ('bfloat16', 'bfloat16', 8, 1, 128, 256, 1024, 1024): (4, 32),
1203
+ ('bfloat16', 'bfloat16', 8, 1, 128, 256, 1024, 1280): (4, 32),
1204
+ ('bfloat16', 'bfloat16', 8, 1, 128, 256, 1024, 2048): (8, 32),
1205
+ ('bfloat16', 'bfloat16', 8, 1, 128, 256, 1024, 4096): (16, 32),
1206
+ ('bfloat16', 'bfloat16', 8, 1, 128, 256, 2048, 1024): (4, 32),
1207
+ ('bfloat16', 'bfloat16', 8, 1, 128, 256, 2048, 1280): (4, 32),
1208
+ ('bfloat16', 'bfloat16', 8, 1, 128, 256, 2048, 2048): (8, 32),
1209
+ ('bfloat16', 'bfloat16', 8, 1, 128, 256, 2048, 4096): (16, 32),
1210
+ ('bfloat16', 'bfloat16', 8, 1, 128, 256, 4096, 1024): (4, 64),
1211
+ ('bfloat16', 'bfloat16', 8, 1, 128, 256, 4096, 1280): (4, 32),
1212
+ ('bfloat16', 'bfloat16', 8, 1, 128, 256, 4096, 2048): (4, 32),
1213
+ ('bfloat16', 'bfloat16', 8, 1, 128, 256, 4096, 4096): (16, 32),
1214
+ ('bfloat16', 'bfloat16', 8, 1, 128, 256, 512, 1024): (4, 32),
1215
+ ('bfloat16', 'bfloat16', 8, 1, 128, 256, 512, 1280): (4, 32),
1216
+ ('bfloat16', 'bfloat16', 8, 1, 128, 256, 512, 2048): (8, 32),
1217
+ ('bfloat16', 'bfloat16', 8, 1, 128, 256, 512, 4096): (4, 32),
1218
+ ('bfloat16', 'bfloat16', 8, 1, 128, 32, 1024, 128): (4, 128),
1219
+ ('bfloat16', 'bfloat16', 8, 1, 128, 32, 1024, 256): (8, 64),
1220
+ ('bfloat16', 'bfloat16', 8, 1, 128, 32, 1024, 512): (16, 32),
1221
+ ('bfloat16', 'bfloat16', 8, 1, 128, 32, 2048, 128): (4, 32),
1222
+ ('bfloat16', 'bfloat16', 8, 1, 128, 32, 2048, 256): (8, 32),
1223
+ ('bfloat16', 'bfloat16', 8, 1, 128, 32, 2048, 512): (16, 64),
1224
+ ('bfloat16', 'bfloat16', 8, 1, 128, 32, 4096, 128): (4, 64),
1225
+ ('bfloat16', 'bfloat16', 8, 1, 128, 32, 4096, 256): (8, 64),
1226
+ ('bfloat16', 'bfloat16', 8, 1, 128, 32, 4096, 512): (16, 32),
1227
+ ('bfloat16', 'bfloat16', 8, 1, 128, 32, 512, 128): (4, 32),
1228
+ ('bfloat16', 'bfloat16', 8, 1, 128, 32, 512, 256): (8, 32),
1229
+ ('bfloat16', 'bfloat16', 8, 1, 128, 32, 512, 512): (16, 64),
1230
+ ('bfloat16', 'bfloat16', 8, 1, 128, 64, 1024, 1024): (16, 32),
1231
+ ('bfloat16', 'bfloat16', 8, 1, 128, 64, 1024, 256): (4, 64),
1232
+ ('bfloat16', 'bfloat16', 8, 1, 128, 64, 1024, 512): (8, 32),
1233
+ ('bfloat16', 'bfloat16', 8, 1, 128, 64, 2048, 1024): (8, 32),
1234
+ ('bfloat16', 'bfloat16', 8, 1, 128, 64, 2048, 256): (4, 32),
1235
+ ('bfloat16', 'bfloat16', 8, 1, 128, 64, 2048, 512): (8, 64),
1236
+ ('bfloat16', 'bfloat16', 8, 1, 128, 64, 4096, 1024): (16, 32),
1237
+ ('bfloat16', 'bfloat16', 8, 1, 128, 64, 4096, 256): (4, 128),
1238
+ ('bfloat16', 'bfloat16', 8, 1, 128, 64, 4096, 512): (8, 32),
1239
+ ('bfloat16', 'bfloat16', 8, 1, 128, 64, 512, 1024): (16, 32),
1240
+ ('bfloat16', 'bfloat16', 8, 1, 128, 64, 512, 256): (4, 32),
1241
+ ('bfloat16', 'bfloat16', 8, 1, 128, 64, 512, 512): (8, 32),
1242
+ ('bfloat16', 'bfloat16', 8, 2, 128, 128, 1024, 1024): (8, 32),
1243
+ ('bfloat16', 'bfloat16', 8, 2, 128, 128, 1024, 1280): (4, 32),
1244
+ ('bfloat16', 'bfloat16', 8, 2, 128, 128, 1024, 2048): (16, 32),
1245
+ ('bfloat16', 'bfloat16', 8, 2, 128, 128, 1024, 512): (4, 32),
1246
+ ('bfloat16', 'bfloat16', 8, 2, 128, 128, 2048, 1024): (8, 32),
1247
+ ('bfloat16', 'bfloat16', 8, 2, 128, 128, 2048, 1280): (8, 32),
1248
+ ('bfloat16', 'bfloat16', 8, 2, 128, 128, 2048, 2048): (16, 32),
1249
+ ('bfloat16', 'bfloat16', 8, 2, 128, 128, 2048, 512): (4, 32),
1250
+ ('bfloat16', 'bfloat16', 8, 2, 128, 128, 4096, 1024): (8, 32),
1251
+ ('bfloat16', 'bfloat16', 8, 2, 128, 128, 4096, 1280): (4, 32),
1252
+ ('bfloat16', 'bfloat16', 8, 2, 128, 128, 4096, 2048): (16, 32),
1253
+ ('bfloat16', 'bfloat16', 8, 2, 128, 128, 4096, 512): (4, 32),
1254
+ ('bfloat16', 'bfloat16', 8, 2, 128, 128, 512, 1024): (8, 32),
1255
+ ('bfloat16', 'bfloat16', 8, 2, 128, 128, 512, 1280): (4, 32),
1256
+ ('bfloat16', 'bfloat16', 8, 2, 128, 128, 512, 2048): (8, 32),
1257
+ ('bfloat16', 'bfloat16', 8, 2, 128, 128, 512, 512): (4, 32),
1258
+ ('bfloat16', 'bfloat16', 8, 2, 128, 16, 1024, 128): (8, 64),
1259
+ ('bfloat16', 'bfloat16', 8, 2, 128, 16, 1024, 256): (16, 32),
1260
+ ('bfloat16', 'bfloat16', 8, 2, 128, 16, 1024, 64): (4, 32),
1261
+ ('bfloat16', 'bfloat16', 8, 2, 128, 16, 2048, 128): (8, 32),
1262
+ ('bfloat16', 'bfloat16', 8, 2, 128, 16, 2048, 256): (16, 32),
1263
+ ('bfloat16', 'bfloat16', 8, 2, 128, 16, 2048, 64): (4, 32),
1264
+ ('bfloat16', 'bfloat16', 8, 2, 128, 16, 4096, 128): (8, 64),
1265
+ ('bfloat16', 'bfloat16', 8, 2, 128, 16, 4096, 256): (16, 64),
1266
+ ('bfloat16', 'bfloat16', 8, 2, 128, 16, 4096, 64): (4, 32),
1267
+ ('bfloat16', 'bfloat16', 8, 2, 128, 16, 512, 128): (8, 32),
1268
+ ('bfloat16', 'bfloat16', 8, 2, 128, 16, 512, 256): (16, 32),
1269
+ ('bfloat16', 'bfloat16', 8, 2, 128, 16, 512, 64): (4, 32),
1270
+ ('bfloat16', 'bfloat16', 8, 2, 128, 256, 1024, 1024): (4, 32),
1271
+ ('bfloat16', 'bfloat16', 8, 2, 128, 256, 1024, 1280): (4, 32),
1272
+ ('bfloat16', 'bfloat16', 8, 2, 128, 256, 1024, 2048): (4, 32),
1273
+ ('bfloat16', 'bfloat16', 8, 2, 128, 256, 1024, 4096): (8, 32),
1274
+ ('bfloat16', 'bfloat16', 8, 2, 128, 256, 2048, 1024): (4, 32),
1275
+ ('bfloat16', 'bfloat16', 8, 2, 128, 256, 2048, 1280): (4, 32),
1276
+ ('bfloat16', 'bfloat16', 8, 2, 128, 256, 2048, 2048): (4, 32),
1277
+ ('bfloat16', 'bfloat16', 8, 2, 128, 256, 2048, 4096): (16, 32),
1278
+ ('bfloat16', 'bfloat16', 8, 2, 128, 256, 4096, 1024): (4, 32),
1279
+ ('bfloat16', 'bfloat16', 8, 2, 128, 256, 4096, 1280): (4, 32),
1280
+ ('bfloat16', 'bfloat16', 8, 2, 128, 256, 4096, 2048): (8, 32),
1281
+ ('bfloat16', 'bfloat16', 8, 2, 128, 256, 4096, 4096): (16, 64),
1282
+ ('bfloat16', 'bfloat16', 8, 2, 128, 256, 512, 1024): (4, 32),
1283
+ ('bfloat16', 'bfloat16', 8, 2, 128, 256, 512, 1280): (4, 32),
1284
+ ('bfloat16', 'bfloat16', 8, 2, 128, 256, 512, 2048): (8, 32),
1285
+ ('bfloat16', 'bfloat16', 8, 2, 128, 256, 512, 4096): (16, 32),
1286
+ ('bfloat16', 'bfloat16', 8, 2, 128, 32, 1024, 128): (4, 32),
1287
+ ('bfloat16', 'bfloat16', 8, 2, 128, 32, 1024, 256): (8, 32),
1288
+ ('bfloat16', 'bfloat16', 8, 2, 128, 32, 1024, 512): (16, 64),
1289
+ ('bfloat16', 'bfloat16', 8, 2, 128, 32, 2048, 128): (4, 32),
1290
+ ('bfloat16', 'bfloat16', 8, 2, 128, 32, 2048, 256): (8, 32),
1291
+ ('bfloat16', 'bfloat16', 8, 2, 128, 32, 2048, 512): (16, 64),
1292
+ ('bfloat16', 'bfloat16', 8, 2, 128, 32, 4096, 128): (4, 32),
1293
+ ('bfloat16', 'bfloat16', 8, 2, 128, 32, 4096, 256): (8, 64),
1294
+ ('bfloat16', 'bfloat16', 8, 2, 128, 32, 4096, 512): (16, 32),
1295
+ ('bfloat16', 'bfloat16', 8, 2, 128, 32, 512, 128): (4, 32),
1296
+ ('bfloat16', 'bfloat16', 8, 2, 128, 32, 512, 256): (8, 32),
1297
+ ('bfloat16', 'bfloat16', 8, 2, 128, 32, 512, 512): (16, 64),
1298
+ ('bfloat16', 'bfloat16', 8, 2, 128, 64, 1024, 1024): (16, 32),
1299
+ ('bfloat16', 'bfloat16', 8, 2, 128, 64, 1024, 256): (4, 32),
1300
+ ('bfloat16', 'bfloat16', 8, 2, 128, 64, 1024, 512): (8, 32),
1301
+ ('bfloat16', 'bfloat16', 8, 2, 128, 64, 2048, 1024): (16, 64),
1302
+ ('bfloat16', 'bfloat16', 8, 2, 128, 64, 2048, 256): (4, 32),
1303
+ ('bfloat16', 'bfloat16', 8, 2, 128, 64, 2048, 512): (8, 32),
1304
+ ('bfloat16', 'bfloat16', 8, 2, 128, 64, 4096, 1024): (8, 32),
1305
+ ('bfloat16', 'bfloat16', 8, 2, 128, 64, 4096, 256): (4, 64),
1306
+ ('bfloat16', 'bfloat16', 8, 2, 128, 64, 4096, 512): (8, 32),
1307
+ ('bfloat16', 'bfloat16', 8, 2, 128, 64, 512, 1024): (8, 32),
1308
+ ('bfloat16', 'bfloat16', 8, 2, 128, 64, 512, 256): (4, 32),
1309
+ ('bfloat16', 'bfloat16', 8, 2, 128, 64, 512, 512): (8, 32),
1310
+ ('bfloat16', 'bfloat16', 8, 4, 128, 128, 1024, 1024): (8, 64),
1311
+ ('bfloat16', 'bfloat16', 8, 4, 128, 128, 1024, 1280): (8, 64),
1312
+ ('bfloat16', 'bfloat16', 8, 4, 128, 128, 1024, 2048): (16, 64),
1313
+ ('bfloat16', 'bfloat16', 8, 4, 128, 128, 1024, 512): (4, 32),
1314
+ ('bfloat16', 'bfloat16', 8, 4, 128, 128, 2048, 1024): (8, 64),
1315
+ ('bfloat16', 'bfloat16', 8, 4, 128, 128, 2048, 1280): (8, 64),
1316
+ ('bfloat16', 'bfloat16', 8, 4, 128, 128, 2048, 2048): (16, 64),
1317
+ ('bfloat16', 'bfloat16', 8, 4, 128, 128, 2048, 512): (4, 32),
1318
+ ('bfloat16', 'bfloat16', 8, 4, 128, 128, 4096, 1024): (8, 64),
1319
+ ('bfloat16', 'bfloat16', 8, 4, 128, 128, 4096, 1280): (8, 64),
1320
+ ('bfloat16', 'bfloat16', 8, 4, 128, 128, 4096, 2048): (16, 64),
1321
+ ('bfloat16', 'bfloat16', 8, 4, 128, 128, 4096, 512): (4, 32),
1322
+ ('bfloat16', 'bfloat16', 8, 4, 128, 128, 512, 1024): (8, 32),
1323
+ ('bfloat16', 'bfloat16', 8, 4, 128, 128, 512, 1280): (4, 32),
1324
+ ('bfloat16', 'bfloat16', 8, 4, 128, 128, 512, 2048): (16, 64),
1325
+ ('bfloat16', 'bfloat16', 8, 4, 128, 128, 512, 512): (4, 64),
1326
+ ('bfloat16', 'bfloat16', 8, 4, 128, 16, 1024, 128): (4, 32),
1327
+ ('bfloat16', 'bfloat16', 8, 4, 128, 16, 1024, 256): (16, 32),
1328
+ ('bfloat16', 'bfloat16', 8, 4, 128, 16, 1024, 64): (4, 32),
1329
+ ('bfloat16', 'bfloat16', 8, 4, 128, 16, 2048, 128): (8, 32),
1330
+ ('bfloat16', 'bfloat16', 8, 4, 128, 16, 2048, 256): (16, 32),
1331
+ ('bfloat16', 'bfloat16', 8, 4, 128, 16, 2048, 64): (4, 64),
1332
+ ('bfloat16', 'bfloat16', 8, 4, 128, 16, 4096, 128): (4, 32),
1333
+ ('bfloat16', 'bfloat16', 8, 4, 128, 16, 4096, 256): (16, 32),
1334
+ ('bfloat16', 'bfloat16', 8, 4, 128, 16, 4096, 64): (4, 32),
1335
+ ('bfloat16', 'bfloat16', 8, 4, 128, 16, 512, 128): (8, 32),
1336
+ ('bfloat16', 'bfloat16', 8, 4, 128, 16, 512, 256): (16, 32),
1337
+ ('bfloat16', 'bfloat16', 8, 4, 128, 16, 512, 64): (4, 32),
1338
+ ('bfloat16', 'bfloat16', 8, 4, 128, 256, 1024, 1024): (4, 64),
1339
+ ('bfloat16', 'bfloat16', 8, 4, 128, 256, 1024, 1280): (4, 64),
1340
+ ('bfloat16', 'bfloat16', 8, 4, 128, 256, 1024, 2048): (8, 64),
1341
+ ('bfloat16', 'bfloat16', 8, 4, 128, 256, 1024, 4096): (8, 64),
1342
+ ('bfloat16', 'bfloat16', 8, 4, 128, 256, 2048, 1024): (4, 64),
1343
+ ('bfloat16', 'bfloat16', 8, 4, 128, 256, 2048, 1280): (4, 64),
1344
+ ('bfloat16', 'bfloat16', 8, 4, 128, 256, 2048, 2048): (8, 64),
1345
+ ('bfloat16', 'bfloat16', 8, 4, 128, 256, 2048, 4096): (16, 64),
1346
+ ('bfloat16', 'bfloat16', 8, 4, 128, 256, 4096, 1024): (4, 64),
1347
+ ('bfloat16', 'bfloat16', 8, 4, 128, 256, 4096, 1280): (4, 64),
1348
+ ('bfloat16', 'bfloat16', 8, 4, 128, 256, 4096, 2048): (8, 64),
1349
+ ('bfloat16', 'bfloat16', 8, 4, 128, 256, 4096, 4096): (8, 64),
1350
+ ('bfloat16', 'bfloat16', 8, 4, 128, 256, 512, 1024): (4, 64),
1351
+ ('bfloat16', 'bfloat16', 8, 4, 128, 256, 512, 1280): (4, 64),
1352
+ ('bfloat16', 'bfloat16', 8, 4, 128, 256, 512, 2048): (8, 64),
1353
+ ('bfloat16', 'bfloat16', 8, 4, 128, 256, 512, 4096): (16, 64),
1354
+ ('bfloat16', 'bfloat16', 8, 4, 128, 32, 1024, 128): (4, 32),
1355
+ ('bfloat16', 'bfloat16', 8, 4, 128, 32, 1024, 256): (8, 32),
1356
+ ('bfloat16', 'bfloat16', 8, 4, 128, 32, 1024, 512): (16, 32),
1357
+ ('bfloat16', 'bfloat16', 8, 4, 128, 32, 2048, 128): (4, 32),
1358
+ ('bfloat16', 'bfloat16', 8, 4, 128, 32, 2048, 256): (8, 32),
1359
+ ('bfloat16', 'bfloat16', 8, 4, 128, 32, 2048, 512): (16, 64),
1360
+ ('bfloat16', 'bfloat16', 8, 4, 128, 32, 4096, 128): (4, 64),
1361
+ ('bfloat16', 'bfloat16', 8, 4, 128, 32, 4096, 256): (8, 32),
1362
+ ('bfloat16', 'bfloat16', 8, 4, 128, 32, 4096, 512): (16, 32),
1363
+ ('bfloat16', 'bfloat16', 8, 4, 128, 32, 512, 128): (4, 32),
1364
+ ('bfloat16', 'bfloat16', 8, 4, 128, 32, 512, 256): (8, 64),
1365
+ ('bfloat16', 'bfloat16', 8, 4, 128, 32, 512, 512): (16, 32),
1366
+ ('bfloat16', 'bfloat16', 8, 4, 128, 64, 1024, 1024): (16, 32),
1367
+ ('bfloat16', 'bfloat16', 8, 4, 128, 64, 1024, 256): (4, 32),
1368
+ ('bfloat16', 'bfloat16', 8, 4, 128, 64, 1024, 512): (8, 32),
1369
+ ('bfloat16', 'bfloat16', 8, 4, 128, 64, 2048, 1024): (16, 64),
1370
+ ('bfloat16', 'bfloat16', 8, 4, 128, 64, 2048, 256): (4, 32),
1371
+ ('bfloat16', 'bfloat16', 8, 4, 128, 64, 2048, 512): (8, 64),
1372
+ ('bfloat16', 'bfloat16', 8, 4, 128, 64, 4096, 1024): (16, 64),
1373
+ ('bfloat16', 'bfloat16', 8, 4, 128, 64, 4096, 256): (4, 32),
1374
+ ('bfloat16', 'bfloat16', 8, 4, 128, 64, 4096, 512): (8, 32),
1375
+ ('bfloat16', 'bfloat16', 8, 4, 128, 64, 512, 1024): (16, 64),
1376
+ ('bfloat16', 'bfloat16', 8, 4, 128, 64, 512, 256): (4, 32),
1377
+ ('bfloat16', 'bfloat16', 8, 4, 128, 64, 512, 512): (8, 32),
1378
+ # go/keep-sorted end
1379
+ },
1380
+ }
1381
+
1382
+
1383
def next_power_of_2(x: int):
    """Return the smallest power of two that is greater than or equal to x.

    Args:
      x: A strictly positive integer.

    Returns:
      The smallest integer power of 2 that is >= x.
    """
    assert x > 0
    # For x > 1, the bit length of (x - 1) is the exponent of the answer.
    return 1 if x == 1 else 1 << (x - 1).bit_length()
1396
+
1397
+
1398
def simplify_key(key):
    """Normalize a tuning key so that similar configurations share an entry.

    Args:
      key: An 8-tuple of (q_dtype, kv_dtype, num_q_heads_per_blk,
        num_kv_heads_per_blk, head_dim, page_size, max_num_batched_tokens,
        pages_per_seq).

    Returns:
      An 8-tuple with the dtypes replaced by their names, head_dim rounded up
      to a multiple of 128, the head counts, page size and token count rounded
      up to powers of two, and the last element replaced by the next power of
      two of page_size * pages_per_seq.
    """
    (q_dtype, kv_dtype, q_heads, kv_heads, head_dim, page_size, max_tokens,
     pages_per_seq) = key

    q_name = jnp.dtype(q_dtype).name
    kv_name = jnp.dtype(kv_dtype).name
    # Round head_dim up to the nearest multiple of 128.
    padded_head_dim = (head_dim + 127) // 128 * 128
    # The final entry is the per-sequence token capacity, not a page count.
    seq_capacity = page_size * pages_per_seq

    return (
        q_name,
        kv_name,
        next_power_of_2(q_heads),
        next_power_of_2(kv_heads),
        padded_head_dim,
        next_power_of_2(page_size),
        next_power_of_2(max_tokens),
        next_power_of_2(seq_capacity),
    )
1420
+
1421
+
1422
def get_tpu_version() -> int:
    """Returns the numeric version of the TPU, or -1 if not on TPU.

    The version is parsed from the `device_kind` string of the first JAX
    device. All leading digits after the `TPU` prefix (with an optional ` v`)
    are used, so multi-digit versions and kinds carrying a suffix are handled.
    NOTE(review): suffixed forms such as 'TPU v5p' or 'TPU7x' are assumed
    spellings of `device_kind`; confirm against the JAX runtime in use.

    Returns:
      The TPU generation as an int, or -1 when the device kind does not
      contain 'TPU'.

    Raises:
      ValueError: If the device kind contains 'TPU' but no version number can
        be parsed from it.
    """
    import re  # Local import so the block does not rely on file-level imports.

    kind = jax.devices()[0].device_kind
    if 'TPU' not in kind:
        return -1
    # Anchor at the start, as the original prefix check did, but accept any
    # number of digits and ignore whatever follows them (' lite', 'p', ...).
    match = re.match(r'TPU\s*v?(\d+)', kind)
    if match is None:
        raise ValueError(f'Unrecognized TPU device kind: {kind!r}')
    return int(match.group(1))
1431
+
1432
+
1433
+ def get_device_name(num_devices: int | None = None):
1434
+ name = ' '.join(jax.devices()[0].device_kind.split()[:2])
1435
+ if num_devices is not None:
1436
+ name += f'-{num_devices}'
1437
+ return name
1438
+
1439
+
1440
def get_tuned_block_sizes(
    q_dtype,
    kv_dtype,
    num_q_heads_per_blk,
    num_kv_heads_per_blk,
    head_dim,
    page_size,
    max_num_batched_tokens,
    pages_per_seq,
) -> tuple[int, int]:
    """Look up the best (num_kv_pages_per_blk, num_queries_per_blk) pair.

    The arguments are normalized with `simplify_key` and looked up in
    `TUNED_BLOCK_SIZES` under the current device name. When no tuned entry
    exists, (128, 32) is used; on TPU v4 an untuned (32, 32) is always used.

    Returns:
      The selected pair, with the first element capped at `pages_per_seq` and
      the second capped at `max_num_batched_tokens`.

    Raises:
      NotImplementedError: If the detected TPU version is lower than 4.
    """
    tpu_version = get_tpu_version()
    if tpu_version < 4:
        raise NotImplementedError('TPU version must be 4 or higher.')

    lookup_key = simplify_key((
        q_dtype,
        kv_dtype,
        num_q_heads_per_blk,
        num_kv_heads_per_blk,
        head_dim,
        page_size,
        max_num_batched_tokens,
        pages_per_seq,
    ))
    device_name = get_device_name()

    if tpu_version == 4:
        # This default block size is not tuned, only make sure there's no
        # OOM in vmem
        bkv, bq = 32, 32
    else:
        # Fall back to the generic default when either the device or the key
        # has no tuned entry.
        device_table = TUNED_BLOCK_SIZES.get(device_name, {})
        bkv, bq = device_table.get(lookup_key, (128, 32))

    return (min(pages_per_seq, bkv), min(max_num_batched_tokens, bq))
1477
+
1478
+
1479
def get_min_page_size(max_model_len, min_page_size=16):
    """Recommended min page size for high-performance kernel.

    Args:
      max_model_len: Maximum model length; rounded up to a power of two before
        being divided by `MAX_PAGES_PER_SEQ`.
      min_page_size: Lower bound for the returned value.

    Returns:
      The larger of next_power_of_2(max_model_len) // MAX_PAGES_PER_SEQ and
      `min_page_size`.
    """
    size_from_model_len = next_power_of_2(max_model_len) // MAX_PAGES_PER_SEQ
    return max(size_from_model_len, min_page_size)