tpu-inference 0.11.1.dev202511150811__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic. Click here for more details.

Files changed (179) hide show
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +53 -0
  6. tests/core/test_dp_scheduler.py +899 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/fused_moe_v1_test.py +105 -0
  10. tests/kernels/mla_v1_test.py +396 -0
  11. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  12. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  13. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  14. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
  15. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  16. tests/lora/__init__.py +0 -0
  17. tests/lora/conftest.py +32 -0
  18. tests/lora/test_bgmv.py +43 -0
  19. tests/lora/test_layers.py +654 -0
  20. tests/lora/test_lora.py +133 -0
  21. tests/lora/utils.py +96 -0
  22. tests/test_base.py +201 -0
  23. tests/test_envs.py +182 -0
  24. tests/test_quantization.py +836 -0
  25. tests/test_tpu_info.py +120 -0
  26. tests/test_utils.py +236 -0
  27. tpu_inference/__init__.py +34 -0
  28. tpu_inference/core/__init__.py +0 -0
  29. tpu_inference/core/core_tpu.py +786 -0
  30. tpu_inference/core/disagg_executor.py +118 -0
  31. tpu_inference/core/disagg_utils.py +51 -0
  32. tpu_inference/core/sched/__init__.py +0 -0
  33. tpu_inference/core/sched/dp_scheduler.py +523 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/jax_parallel_state.py +67 -0
  36. tpu_inference/distributed/tpu_connector.py +728 -0
  37. tpu_inference/distributed/utils.py +59 -0
  38. tpu_inference/env_override.py +9 -0
  39. tpu_inference/envs.py +107 -0
  40. tpu_inference/executors/__init__.py +0 -0
  41. tpu_inference/executors/ray_distributed_executor.py +362 -0
  42. tpu_inference/experimental/__init__.py +0 -0
  43. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  44. tpu_inference/kernels/__init__.py +0 -0
  45. tpu_inference/kernels/collectives/__init__.py +0 -0
  46. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  47. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  48. tpu_inference/kernels/collectives/util.py +47 -0
  49. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  50. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  51. tpu_inference/kernels/fused_moe/__init__.py +0 -0
  52. tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
  53. tpu_inference/kernels/fused_moe/v1/kernel.py +1035 -0
  54. tpu_inference/kernels/mla/__init__.py +0 -0
  55. tpu_inference/kernels/mla/v1/__init__.py +0 -0
  56. tpu_inference/kernels/mla/v1/kernel.py +1349 -0
  57. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  58. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  59. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  60. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  61. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  62. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  66. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1478 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1482 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
  71. tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
  72. tpu_inference/layers/__init__.py +0 -0
  73. tpu_inference/layers/common/__init__.py +0 -0
  74. tpu_inference/layers/common/attention_interface.py +390 -0
  75. tpu_inference/layers/common/attention_metadata.py +34 -0
  76. tpu_inference/layers/common/binary_search.py +295 -0
  77. tpu_inference/layers/common/quant_methods.py +8 -0
  78. tpu_inference/layers/common/sharding.py +582 -0
  79. tpu_inference/layers/jax/__init__.py +0 -0
  80. tpu_inference/layers/jax/attention/__init__.py +0 -0
  81. tpu_inference/layers/jax/attention/attention.py +255 -0
  82. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  83. tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
  84. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  85. tpu_inference/layers/jax/base.py +151 -0
  86. tpu_inference/layers/jax/constants.py +88 -0
  87. tpu_inference/layers/jax/layers.py +301 -0
  88. tpu_inference/layers/jax/misc.py +16 -0
  89. tpu_inference/layers/jax/moe/__init__.py +0 -0
  90. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  91. tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
  92. tpu_inference/layers/jax/moe/moe.py +209 -0
  93. tpu_inference/layers/jax/rope.py +280 -0
  94. tpu_inference/layers/jax/rope_interface.py +214 -0
  95. tpu_inference/layers/jax/sample/__init__.py +0 -0
  96. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  97. tpu_inference/layers/jax/sample/sampling.py +96 -0
  98. tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
  99. tpu_inference/layers/jax/transformer_block.py +107 -0
  100. tpu_inference/layers/vllm/__init__.py +0 -0
  101. tpu_inference/layers/vllm/attention.py +221 -0
  102. tpu_inference/layers/vllm/fused_moe.py +507 -0
  103. tpu_inference/layers/vllm/linear_common.py +186 -0
  104. tpu_inference/layers/vllm/quantization/__init__.py +39 -0
  105. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  106. tpu_inference/layers/vllm/quantization/common.py +105 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  108. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
  109. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
  110. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  111. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  112. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  113. tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
  114. tpu_inference/layers/vllm/quantization/unquantized.py +386 -0
  115. tpu_inference/layers/vllm/sharding.py +230 -0
  116. tpu_inference/logger.py +10 -0
  117. tpu_inference/lora/__init__.py +0 -0
  118. tpu_inference/lora/torch_lora_ops.py +103 -0
  119. tpu_inference/lora/torch_punica_tpu.py +311 -0
  120. tpu_inference/mock/__init__.py +0 -0
  121. tpu_inference/mock/vllm_config_utils.py +28 -0
  122. tpu_inference/mock/vllm_envs.py +1219 -0
  123. tpu_inference/mock/vllm_logger.py +212 -0
  124. tpu_inference/mock/vllm_logging_utils.py +15 -0
  125. tpu_inference/models/__init__.py +0 -0
  126. tpu_inference/models/common/__init__.py +0 -0
  127. tpu_inference/models/common/model_loader.py +444 -0
  128. tpu_inference/models/jax/__init__.py +0 -0
  129. tpu_inference/models/jax/deepseek_v3.py +868 -0
  130. tpu_inference/models/jax/gpt_oss.py +492 -0
  131. tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
  132. tpu_inference/models/jax/llama3.py +375 -0
  133. tpu_inference/models/jax/llama4.py +629 -0
  134. tpu_inference/models/jax/llama_eagle3.py +333 -0
  135. tpu_inference/models/jax/phi3.py +376 -0
  136. tpu_inference/models/jax/qwen2.py +375 -0
  137. tpu_inference/models/jax/qwen2_5_vl.py +1103 -0
  138. tpu_inference/models/jax/qwen3.py +302 -0
  139. tpu_inference/models/jax/utils/__init__.py +0 -0
  140. tpu_inference/models/jax/utils/file_utils.py +96 -0
  141. tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
  142. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  143. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
  144. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
  145. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
  146. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
  147. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
  148. tpu_inference/models/jax/utils/quantization/quantization_utils.py +653 -0
  149. tpu_inference/models/jax/utils/weight_utils.py +529 -0
  150. tpu_inference/models/vllm/__init__.py +0 -0
  151. tpu_inference/models/vllm/vllm_model_wrapper.py +286 -0
  152. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  153. tpu_inference/platforms/__init__.py +2 -0
  154. tpu_inference/platforms/tpu_platform.py +269 -0
  155. tpu_inference/runner/__init__.py +0 -0
  156. tpu_inference/runner/block_table.py +122 -0
  157. tpu_inference/runner/compilation_manager.py +780 -0
  158. tpu_inference/runner/input_batch.py +435 -0
  159. tpu_inference/runner/kv_cache.py +132 -0
  160. tpu_inference/runner/kv_cache_manager.py +479 -0
  161. tpu_inference/runner/lora_utils.py +92 -0
  162. tpu_inference/runner/multimodal_manager.py +217 -0
  163. tpu_inference/runner/persistent_batch_manager.py +244 -0
  164. tpu_inference/runner/speculative_decoding_manager.py +248 -0
  165. tpu_inference/runner/structured_decoding_manager.py +88 -0
  166. tpu_inference/runner/tpu_runner.py +1620 -0
  167. tpu_inference/runner/utils.py +426 -0
  168. tpu_inference/spec_decode/__init__.py +0 -0
  169. tpu_inference/spec_decode/jax/__init__.py +0 -0
  170. tpu_inference/spec_decode/jax/eagle3.py +367 -0
  171. tpu_inference/tpu_info.py +77 -0
  172. tpu_inference/utils.py +317 -0
  173. tpu_inference/worker/__init__.py +0 -0
  174. tpu_inference/worker/tpu_worker.py +321 -0
  175. tpu_inference-0.11.1.dev202511150811.dist-info/METADATA +107 -0
  176. tpu_inference-0.11.1.dev202511150811.dist-info/RECORD +179 -0
  177. tpu_inference-0.11.1.dev202511150811.dist-info/WHEEL +5 -0
  178. tpu_inference-0.11.1.dev202511150811.dist-info/licenses/LICENSE +201 -0
  179. tpu_inference-0.11.1.dev202511150811.dist-info/top_level.txt +2 -0
@@ -0,0 +1,4147 @@
1
+ """Auto-tuned block sizes for ragged paged attention."""
2
+
3
+ import jax.numpy as jnp
4
+
5
+ from tpu_inference.kernels.ragged_paged_attention.v3.util import (
6
+ align_to, get_dtype_packing, get_tpu_version, next_power_of_2)
7
+ from tpu_inference.logger import init_logger
8
+ from tpu_inference.utils import get_device_name
9
+
10
+ logger = init_logger(__name__)
11
+
12
+ # key
13
+ # - device_name
14
+ # - page_size
15
+ # - q_{q_dtype_name}_kv_{kv_dtype_name}
16
+ # - q_head-{num_q_heads}_kv_head-{num_kv_heads}_head-{head_dim}
17
+ # - max_model_len
18
+ # value:
19
+ # - (num_kv_pages_per_block, num_queries_per_block)
20
+ TUNED_BLOCK_SIZES = {
21
+ 'TPU v6e': {
22
+ 128: {
23
+ 'q_bfloat16_kv_bfloat16': {
24
+ 'q_head-128_kv_head-1_head-128': {
25
+ 1024: (8, 16),
26
+ 2048: (8, 16),
27
+ 256: (2, 8),
28
+ 4096: (16, 8),
29
+ 512: (2, 32),
30
+ 8192: (16, 16),
31
+ },
32
+ 'q_head-128_kv_head-1_head-256': {
33
+ 1024: (8, 16),
34
+ 2048: (16, 16),
35
+ 256: (2, 8),
36
+ 4096: (16, 8),
37
+ 512: (4, 8),
38
+ 8192: (16, 16),
39
+ },
40
+ 'q_head-128_kv_head-16_head-128': {
41
+ 1024: (8, 16),
42
+ 2048: (8, 16),
43
+ 256: (2, 16),
44
+ 4096: (8, 16),
45
+ 512: (4, 16),
46
+ 8192: (8, 16),
47
+ },
48
+ 'q_head-128_kv_head-16_head-256': {
49
+ 1024: (4, 8),
50
+ 2048: (4, 8),
51
+ 256: (2, 8),
52
+ 4096: (4, 8),
53
+ 512: (4, 8),
54
+ 8192: (4, 8),
55
+ },
56
+ 'q_head-128_kv_head-2_head-128': {
57
+ 1024: (8, 8),
58
+ 2048: (16, 8),
59
+ 256: (2, 32),
60
+ 4096: (16, 8),
61
+ 512: (4, 16),
62
+ 8192: (16, 32),
63
+ },
64
+ 'q_head-128_kv_head-2_head-256': {
65
+ 1024: (8, 16),
66
+ 2048: (16, 8),
67
+ 256: (2, 8),
68
+ 4096: (16, 16),
69
+ 512: (4, 8),
70
+ 8192: (16, 16),
71
+ },
72
+ 'q_head-128_kv_head-4_head-128': {
73
+ 1024: (8, 8),
74
+ 2048: (16, 16),
75
+ 256: (2, 8),
76
+ 4096: (16, 16),
77
+ 512: (4, 8),
78
+ 8192: (16, 32),
79
+ },
80
+ 'q_head-128_kv_head-4_head-256': {
81
+ 1024: (8, 8),
82
+ 2048: (16, 16),
83
+ 256: (2, 8),
84
+ 4096: (16, 16),
85
+ 512: (4, 8),
86
+ 8192: (16, 16),
87
+ },
88
+ 'q_head-128_kv_head-8_head-128': {
89
+ 1024: (8, 8),
90
+ 2048: (16, 16),
91
+ 256: (2, 16),
92
+ 4096: (16, 32),
93
+ 512: (4, 32),
94
+ 8192: (16, 32),
95
+ },
96
+ 'q_head-128_kv_head-8_head-256': {
97
+ 1024: (8, 8),
98
+ 2048: (8, 16),
99
+ 256: (2, 16),
100
+ 4096: (8, 16),
101
+ 512: (4, 8),
102
+ 8192: (8, 16),
103
+ },
104
+ 'q_head-16_kv_head-1_head-128': {
105
+ 1024: (8, 32),
106
+ 2048: (16, 8),
107
+ 256: (2, 128),
108
+ 4096: (16, 32),
109
+ 512: (4, 256),
110
+ 8192: (16, 128),
111
+ },
112
+ 'q_head-16_kv_head-1_head-256': {
113
+ 1024: (8, 32),
114
+ 2048: (16, 64),
115
+ 256: (2, 128),
116
+ 4096: (16, 16),
117
+ 512: (4, 32),
118
+ 8192: (16, 16),
119
+ },
120
+ 'q_head-16_kv_head-2_head-128': {
121
+ 1024: (8, 128),
122
+ 2048: (16, 16),
123
+ 256: (2, 64),
124
+ 4096: (16, 64),
125
+ 512: (4, 16),
126
+ 8192: (16, 128),
127
+ },
128
+ 'q_head-16_kv_head-2_head-256': {
129
+ 1024: (8, 32),
130
+ 2048: (16, 32),
131
+ 256: (2, 32),
132
+ 4096: (16, 64),
133
+ 512: (4, 32),
134
+ 8192: (16, 32),
135
+ },
136
+ 'q_head-16_kv_head-4_head-128': {
137
+ 1024: (8, 128),
138
+ 2048: (8, 128),
139
+ 256: (2, 32),
140
+ 4096: (16, 128),
141
+ 512: (4, 32),
142
+ 8192: (16, 128),
143
+ },
144
+ 'q_head-16_kv_head-4_head-256': {
145
+ 1024: (8, 16),
146
+ 2048: (8, 16),
147
+ 256: (2, 32),
148
+ 4096: (16, 128),
149
+ 512: (4, 32),
150
+ 8192: (16, 64),
151
+ },
152
+ 'q_head-16_kv_head-8_head-128': {
153
+ 1024: (8, 32),
154
+ 2048: (16, 64),
155
+ 256: (2, 64),
156
+ 4096: (16, 128),
157
+ 512: (4, 64),
158
+ 8192: (16, 128),
159
+ },
160
+ 'q_head-16_kv_head-8_head-256': {
161
+ 1024: (8, 64),
162
+ 2048: (16, 64),
163
+ 256: (2, 64),
164
+ 4096: (16, 64),
165
+ 512: (4, 64),
166
+ 8192: (8, 128),
167
+ },
168
+ 'q_head-2_kv_head-1_head-128': {
169
+ 1024: (8, 64),
170
+ 2048: (16, 256),
171
+ 256: (2, 64),
172
+ 4096: (16, 64),
173
+ 512: (4, 32),
174
+ 8192: (16, 32),
175
+ },
176
+ 'q_head-2_kv_head-1_head-256': {
177
+ 1024: (8, 16),
178
+ 2048: (16, 64),
179
+ 256: (2, 16),
180
+ 4096: (16, 32),
181
+ 512: (4, 256),
182
+ 8192: (16, 32),
183
+ },
184
+ 'q_head-32_kv_head-1_head-128': {
185
+ 1024: (8, 32),
186
+ 2048: (16, 16),
187
+ 256: (2, 32),
188
+ 4096: (16, 8),
189
+ 512: (4, 16),
190
+ 8192: (16, 16),
191
+ },
192
+ 'q_head-32_kv_head-1_head-256': {
193
+ 1024: (4, 16),
194
+ 2048: (16, 8),
195
+ 256: (2, 16),
196
+ 4096: (16, 16),
197
+ 512: (4, 8),
198
+ 8192: (16, 32),
199
+ },
200
+ 'q_head-32_kv_head-16_head-128': {
201
+ 1024: (8, 32),
202
+ 2048: (8, 32),
203
+ 256: (2, 32),
204
+ 4096: (8, 64),
205
+ 512: (4, 32),
206
+ 8192: (8, 32),
207
+ },
208
+ 'q_head-32_kv_head-16_head-256': {
209
+ 1024: (4, 32),
210
+ 2048: (4, 32),
211
+ 256: (2, 32),
212
+ 4096: (4, 32),
213
+ 512: (4, 32),
214
+ 8192: (4, 32),
215
+ },
216
+ 'q_head-32_kv_head-2_head-128': {
217
+ 1024: (8, 8),
218
+ 2048: (16, 16),
219
+ 256: (2, 16),
220
+ 4096: (16, 64),
221
+ 512: (4, 16),
222
+ 8192: (16, 16),
223
+ },
224
+ 'q_head-32_kv_head-2_head-256': {
225
+ 1024: (8, 32),
226
+ 2048: (16, 32),
227
+ 256: (2, 32),
228
+ 4096: (16, 32),
229
+ 512: (2, 16),
230
+ 8192: (16, 16),
231
+ },
232
+ 'q_head-32_kv_head-4_head-128': {
233
+ 1024: (8, 16),
234
+ 2048: (16, 32),
235
+ 256: (2, 64),
236
+ 4096: (16, 32),
237
+ 512: (4, 64),
238
+ 8192: (16, 64),
239
+ },
240
+ 'q_head-32_kv_head-4_head-256': {
241
+ 1024: (8, 32),
242
+ 2048: (16, 64),
243
+ 256: (2, 32),
244
+ 4096: (16, 64),
245
+ 512: (4, 8),
246
+ 8192: (16, 32),
247
+ },
248
+ 'q_head-32_kv_head-8_head-128': {
249
+ 1024: (8, 32),
250
+ 2048: (16, 64),
251
+ 256: (2, 64),
252
+ 4096: (16, 32),
253
+ 512: (4, 32),
254
+ 8192: (16, 32),
255
+ },
256
+ 'q_head-32_kv_head-8_head-256': {
257
+ 1024: (8, 32),
258
+ 2048: (16, 32),
259
+ 256: (2, 16),
260
+ 4096: (16, 32),
261
+ 512: (4, 32),
262
+ 8192: (8, 64),
263
+ },
264
+ 'q_head-4_kv_head-1_head-128': {
265
+ 1024: (8, 8),
266
+ 2048: (16, 128),
267
+ 256: (2, 8),
268
+ 4096: (16, 256),
269
+ 512: (4, 8),
270
+ 8192: (16, 32),
271
+ },
272
+ 'q_head-4_kv_head-1_head-256': {
273
+ 1024: (8, 16),
274
+ 2048: (16, 64),
275
+ 256: (2, 64),
276
+ 4096: (16, 128),
277
+ 512: (4, 64),
278
+ 8192: (16, 32),
279
+ },
280
+ 'q_head-4_kv_head-2_head-128': {
281
+ 1024: (8, 64),
282
+ 2048: (16, 64),
283
+ 256: (2, 16),
284
+ 4096: (16, 64),
285
+ 512: (4, 32),
286
+ 8192: (16, 128),
287
+ },
288
+ 'q_head-4_kv_head-2_head-256': {
289
+ 1024: (8, 256),
290
+ 2048: (16, 64),
291
+ 256: (2, 64),
292
+ 4096: (16, 64),
293
+ 512: (4, 32),
294
+ 8192: (16, 128),
295
+ },
296
+ 'q_head-64_kv_head-1_head-128': {
297
+ 1024: (8, 16),
298
+ 2048: (16, 8),
299
+ 256: (2, 16),
300
+ 4096: (16, 32),
301
+ 512: (4, 8),
302
+ 8192: (16, 16),
303
+ },
304
+ 'q_head-64_kv_head-1_head-256': {
305
+ 1024: (8, 16),
306
+ 2048: (16, 16),
307
+ 256: (2, 16),
308
+ 4096: (16, 32),
309
+ 512: (4, 16),
310
+ 8192: (16, 32),
311
+ },
312
+ 'q_head-64_kv_head-16_head-128': {
313
+ 1024: (8, 32),
314
+ 2048: (8, 32),
315
+ 256: (2, 32),
316
+ 4096: (8, 32),
317
+ 512: (4, 16),
318
+ 8192: (8, 32),
319
+ },
320
+ 'q_head-64_kv_head-16_head-256': {
321
+ 1024: (4, 16),
322
+ 2048: (4, 16),
323
+ 256: (2, 16),
324
+ 4096: (4, 16),
325
+ 512: (4, 16),
326
+ 8192: (4, 16),
327
+ },
328
+ 'q_head-64_kv_head-2_head-128': {
329
+ 1024: (8, 16),
330
+ 2048: (16, 32),
331
+ 256: (2, 16),
332
+ 4096: (16, 32),
333
+ 512: (4, 16),
334
+ 8192: (16, 32),
335
+ },
336
+ 'q_head-64_kv_head-2_head-256': {
337
+ 1024: (8, 16),
338
+ 2048: (16, 16),
339
+ 256: (2, 8),
340
+ 4096: (16, 16),
341
+ 512: (4, 8),
342
+ 8192: (16, 32),
343
+ },
344
+ 'q_head-64_kv_head-4_head-128': {
345
+ 1024: (8, 32),
346
+ 2048: (16, 16),
347
+ 256: (2, 16),
348
+ 4096: (16, 16),
349
+ 512: (4, 8),
350
+ 8192: (16, 32),
351
+ },
352
+ 'q_head-64_kv_head-4_head-256': {
353
+ 1024: (8, 16),
354
+ 2048: (16, 32),
355
+ 256: (2, 8),
356
+ 4096: (16, 16),
357
+ 512: (4, 32),
358
+ 8192: (16, 32),
359
+ },
360
+ 'q_head-64_kv_head-8_head-128': {
361
+ 1024: (8, 16),
362
+ 2048: (16, 32),
363
+ 256: (2, 32),
364
+ 4096: (16, 32),
365
+ 512: (4, 32),
366
+ 8192: (16, 32),
367
+ },
368
+ 'q_head-64_kv_head-8_head-256': {
369
+ 1024: (8, 16),
370
+ 2048: (16, 16),
371
+ 256: (2, 8),
372
+ 4096: (8, 32),
373
+ 512: (4, 32),
374
+ 8192: (8, 32),
375
+ },
376
+ 'q_head-8_kv_head-1_head-128': {
377
+ 1024: (8, 128),
378
+ 2048: (16, 32),
379
+ 256: (2, 256),
380
+ 4096: (16, 128),
381
+ 512: (4, 16),
382
+ 8192: (16, 64),
383
+ },
384
+ 'q_head-8_kv_head-1_head-256': {
385
+ 1024: (8, 64),
386
+ 2048: (16, 32),
387
+ 256: (2, 32),
388
+ 4096: (16, 32),
389
+ 512: (4, 32),
390
+ 8192: (16, 16),
391
+ },
392
+ 'q_head-8_kv_head-2_head-128': {
393
+ 1024: (8, 128),
394
+ 2048: (16, 128),
395
+ 256: (2, 64),
396
+ 4096: (16, 32),
397
+ 512: (4, 32),
398
+ 8192: (16, 64),
399
+ },
400
+ 'q_head-8_kv_head-2_head-256': {
401
+ 1024: (8, 64),
402
+ 2048: (16, 128),
403
+ 256: (2, 16),
404
+ 4096: (16, 128),
405
+ 512: (4, 16),
406
+ 8192: (16, 64),
407
+ },
408
+ 'q_head-8_kv_head-4_head-128': {
409
+ 1024: (8, 8),
410
+ 2048: (16, 256),
411
+ 256: (2, 64),
412
+ 4096: (16, 64),
413
+ 512: (4, 32),
414
+ 8192: (16, 64),
415
+ },
416
+ 'q_head-8_kv_head-4_head-256': {
417
+ 1024: (8, 16),
418
+ 2048: (16, 64),
419
+ 256: (2, 8),
420
+ 4096: (16, 64),
421
+ 512: (4, 64),
422
+ 8192: (16, 64),
423
+ },
424
+ }
425
+ },
426
+ 256: {
427
+ 'q_bfloat16_kv_bfloat16': {
428
+ 'q_head-128_kv_head-1_head-128': {
429
+ 1024: (4, 8),
430
+ 2048: (8, 16),
431
+ 4096: (8, 8),
432
+ 512: (2, 8),
433
+ 8192: (8, 16),
434
+ },
435
+ 'q_head-128_kv_head-1_head-256': {
436
+ 1024: (4, 8),
437
+ 2048: (4, 8),
438
+ 4096: (8, 16),
439
+ 512: (2, 8),
440
+ 8192: (8, 16),
441
+ },
442
+ 'q_head-128_kv_head-16_head-128': {
443
+ 1024: (4, 16),
444
+ 2048: (4, 16),
445
+ 4096: (4, 16),
446
+ 512: (2, 16),
447
+ 8192: (4, 16),
448
+ },
449
+ 'q_head-128_kv_head-16_head-256': {
450
+ 1024: (2, 8),
451
+ 2048: (2, 8),
452
+ 4096: (2, 8),
453
+ 512: (2, 8),
454
+ 8192: (2, 8),
455
+ },
456
+ 'q_head-128_kv_head-2_head-128': {
457
+ 1024: (4, 8),
458
+ 2048: (8, 8),
459
+ 4096: (8, 8),
460
+ 512: (2, 8),
461
+ 8192: (8, 16),
462
+ },
463
+ 'q_head-128_kv_head-2_head-256': {
464
+ 1024: (4, 16),
465
+ 2048: (8, 8),
466
+ 4096: (8, 8),
467
+ 512: (2, 8),
468
+ 8192: (8, 8),
469
+ },
470
+ 'q_head-128_kv_head-4_head-128': {
471
+ 1024: (4, 16),
472
+ 2048: (8, 16),
473
+ 4096: (8, 16),
474
+ 512: (2, 32),
475
+ 8192: (8, 32),
476
+ },
477
+ 'q_head-128_kv_head-4_head-256': {
478
+ 1024: (4, 8),
479
+ 2048: (8, 8),
480
+ 4096: (8, 16),
481
+ 512: (2, 8),
482
+ 8192: (8, 16),
483
+ },
484
+ 'q_head-128_kv_head-8_head-128': {
485
+ 1024: (4, 16),
486
+ 2048: (8, 32),
487
+ 4096: (8, 32),
488
+ 512: (2, 32),
489
+ 8192: (8, 32),
490
+ },
491
+ 'q_head-128_kv_head-8_head-256': {
492
+ 1024: (4, 16),
493
+ 2048: (4, 16),
494
+ 4096: (4, 16),
495
+ 512: (2, 8),
496
+ 8192: (4, 16),
497
+ },
498
+ 'q_head-16_kv_head-1_head-128': {
499
+ 1024: (4, 32),
500
+ 2048: (8, 128),
501
+ 4096: (8, 128),
502
+ 512: (2, 32),
503
+ 8192: (8, 64),
504
+ },
505
+ 'q_head-16_kv_head-1_head-256': {
506
+ 1024: (4, 16),
507
+ 2048: (8, 32),
508
+ 4096: (8, 32),
509
+ 512: (2, 32),
510
+ 8192: (8, 32),
511
+ },
512
+ 'q_head-16_kv_head-2_head-128': {
513
+ 1024: (4, 8),
514
+ 2048: (8, 32),
515
+ 4096: (8, 16),
516
+ 512: (2, 32),
517
+ 8192: (8, 128),
518
+ },
519
+ 'q_head-16_kv_head-2_head-256': {
520
+ 1024: (4, 32),
521
+ 2048: (8, 32),
522
+ 4096: (8, 64),
523
+ 512: (2, 32),
524
+ 8192: (8, 16),
525
+ },
526
+ 'q_head-16_kv_head-4_head-128': {
527
+ 1024: (4, 32),
528
+ 2048: (8, 32),
529
+ 4096: (8, 64),
530
+ 512: (2, 16),
531
+ 8192: (8, 32),
532
+ },
533
+ 'q_head-16_kv_head-4_head-256': {
534
+ 1024: (4, 128),
535
+ 2048: (4, 32),
536
+ 4096: (8, 64),
537
+ 512: (2, 64),
538
+ 8192: (8, 64),
539
+ },
540
+ 'q_head-16_kv_head-8_head-128': {
541
+ 1024: (4, 64),
542
+ 2048: (8, 32),
543
+ 4096: (8, 64),
544
+ 512: (2, 64),
545
+ 8192: (8, 128),
546
+ },
547
+ 'q_head-16_kv_head-8_head-256': {
548
+ 1024: (4, 64),
549
+ 2048: (8, 64),
550
+ 4096: (8, 64),
551
+ 512: (2, 32),
552
+ 8192: (4, 128),
553
+ },
554
+ 'q_head-2_kv_head-1_head-128': {
555
+ 1024: (4, 64),
556
+ 2048: (8, 32),
557
+ 4096: (8, 32),
558
+ 512: (2, 64),
559
+ 8192: (8, 32),
560
+ },
561
+ 'q_head-2_kv_head-1_head-256': {
562
+ 1024: (4, 32),
563
+ 2048: (8, 16),
564
+ 4096: (8, 128),
565
+ 512: (2, 256),
566
+ 8192: (8, 256),
567
+ },
568
+ 'q_head-32_kv_head-1_head-128': {
569
+ 1024: (4, 32),
570
+ 2048: (8, 16),
571
+ 4096: (8, 32),
572
+ 512: (2, 64),
573
+ 8192: (8, 32),
574
+ },
575
+ 'q_head-32_kv_head-1_head-256': {
576
+ 1024: (4, 8),
577
+ 2048: (8, 32),
578
+ 4096: (8, 16),
579
+ 512: (2, 8),
580
+ 8192: (8, 32),
581
+ },
582
+ 'q_head-32_kv_head-16_head-128': {
583
+ 1024: (4, 32),
584
+ 2048: (4, 64),
585
+ 4096: (4, 64),
586
+ 512: (2, 64),
587
+ 8192: (4, 64),
588
+ },
589
+ 'q_head-32_kv_head-16_head-256': {
590
+ 1024: (2, 32),
591
+ 2048: (2, 32),
592
+ 4096: (2, 32),
593
+ 512: (2, 32),
594
+ 8192: (4, 32),
595
+ },
596
+ 'q_head-32_kv_head-2_head-128': {
597
+ 1024: (4, 64),
598
+ 2048: (8, 16),
599
+ 4096: (8, 16),
600
+ 512: (2, 64),
601
+ 8192: (8, 32),
602
+ },
603
+ 'q_head-32_kv_head-2_head-256': {
604
+ 1024: (4, 32),
605
+ 2048: (8, 16),
606
+ 4096: (8, 16),
607
+ 512: (2, 32),
608
+ 8192: (8, 32),
609
+ },
610
+ 'q_head-32_kv_head-4_head-128': {
611
+ 1024: (4, 64),
612
+ 2048: (8, 64),
613
+ 4096: (8, 64),
614
+ 512: (2, 16),
615
+ 8192: (8, 32),
616
+ },
617
+ 'q_head-32_kv_head-4_head-256': {
618
+ 1024: (4, 32),
619
+ 2048: (8, 64),
620
+ 4096: (8, 32),
621
+ 512: (2, 32),
622
+ 8192: (8, 32),
623
+ },
624
+ 'q_head-32_kv_head-8_head-128': {
625
+ 1024: (4, 64),
626
+ 2048: (8, 128),
627
+ 4096: (8, 32),
628
+ 512: (2, 64),
629
+ 8192: (8, 128),
630
+ },
631
+ 'q_head-32_kv_head-8_head-256': {
632
+ 1024: (4, 32),
633
+ 2048: (8, 32),
634
+ 4096: (8, 32),
635
+ 512: (2, 16),
636
+ 8192: (4, 64),
637
+ },
638
+ 'q_head-4_kv_head-1_head-128': {
639
+ 1024: (4, 8),
640
+ 2048: (8, 16),
641
+ 4096: (8, 128),
642
+ 512: (2, 8),
643
+ 8192: (8, 64),
644
+ },
645
+ 'q_head-4_kv_head-1_head-256': {
646
+ 1024: (4, 32),
647
+ 2048: (8, 128),
648
+ 4096: (8, 128),
649
+ 512: (2, 32),
650
+ 8192: (8, 64),
651
+ },
652
+ 'q_head-4_kv_head-2_head-128': {
653
+ 1024: (4, 64),
654
+ 2048: (8, 256),
655
+ 4096: (8, 256),
656
+ 512: (2, 16),
657
+ 8192: (8, 32),
658
+ },
659
+ 'q_head-4_kv_head-2_head-256': {
660
+ 1024: (2, 32),
661
+ 2048: (8, 64),
662
+ 4096: (8, 256),
663
+ 512: (2, 32),
664
+ 8192: (8, 32),
665
+ },
666
+ 'q_head-64_kv_head-1_head-128': {
667
+ 1024: (2, 16),
668
+ 2048: (8, 8),
669
+ 4096: (8, 8),
670
+ 512: (2, 16),
671
+ 8192: (8, 16),
672
+ },
673
+ 'q_head-64_kv_head-1_head-256': {
674
+ 1024: (4, 16),
675
+ 2048: (8, 16),
676
+ 4096: (8, 16),
677
+ 512: (2, 8),
678
+ 8192: (8, 32),
679
+ },
680
+ 'q_head-64_kv_head-16_head-128': {
681
+ 1024: (4, 32),
682
+ 2048: (4, 32),
683
+ 4096: (4, 32),
684
+ 512: (2, 32),
685
+ 8192: (4, 32),
686
+ },
687
+ 'q_head-64_kv_head-16_head-256': {
688
+ 1024: (2, 16),
689
+ 2048: (2, 16),
690
+ 4096: (2, 16),
691
+ 512: (2, 16),
692
+ 8192: (2, 16),
693
+ },
694
+ 'q_head-64_kv_head-2_head-128': {
695
+ 1024: (4, 8),
696
+ 2048: (8, 16),
697
+ 4096: (8, 16),
698
+ 512: (2, 8),
699
+ 8192: (8, 16),
700
+ },
701
+ 'q_head-64_kv_head-2_head-256': {
702
+ 1024: (4, 16),
703
+ 2048: (8, 32),
704
+ 4096: (8, 32),
705
+ 512: (2, 16),
706
+ 8192: (8, 32),
707
+ },
708
+ 'q_head-64_kv_head-4_head-128': {
709
+ 1024: (4, 32),
710
+ 2048: (8, 64),
711
+ 4096: (8, 32),
712
+ 512: (2, 32),
713
+ 8192: (8, 32),
714
+ },
715
+ 'q_head-64_kv_head-4_head-256': {
716
+ 1024: (4, 16),
717
+ 2048: (8, 32),
718
+ 4096: (8, 32),
719
+ 512: (2, 16),
720
+ 8192: (8, 32),
721
+ },
722
+ 'q_head-64_kv_head-8_head-128': {
723
+ 1024: (4, 16),
724
+ 2048: (8, 64),
725
+ 4096: (8, 64),
726
+ 512: (2, 16),
727
+ 8192: (8, 32),
728
+ },
729
+ 'q_head-64_kv_head-8_head-256': {
730
+ 1024: (4, 32),
731
+ 2048: (8, 16),
732
+ 4096: (8, 16),
733
+ 512: (2, 16),
734
+ 8192: (4, 32),
735
+ },
736
+ 'q_head-8_kv_head-1_head-128': {
737
+ 1024: (4, 16),
738
+ 2048: (8, 64),
739
+ 4096: (8, 64),
740
+ 512: (2, 32),
741
+ 8192: (8, 32),
742
+ },
743
+ 'q_head-8_kv_head-1_head-256': {
744
+ 1024: (4, 32),
745
+ 2048: (8, 128),
746
+ 4096: (8, 32),
747
+ 512: (2, 64),
748
+ 8192: (8, 32),
749
+ },
750
+ 'q_head-8_kv_head-2_head-128': {
751
+ 1024: (4, 32),
752
+ 2048: (8, 16),
753
+ 4096: (8, 32),
754
+ 512: (2, 32),
755
+ 8192: (8, 128),
756
+ },
757
+ 'q_head-8_kv_head-2_head-256': {
758
+ 1024: (4, 32),
759
+ 2048: (8, 128),
760
+ 4096: (8, 128),
761
+ 512: (2, 8),
762
+ 8192: (8, 128),
763
+ },
764
+ 'q_head-8_kv_head-4_head-128': {
765
+ 1024: (4, 16),
766
+ 2048: (8, 64),
767
+ 4096: (8, 32),
768
+ 512: (2, 8),
769
+ 8192: (8, 64),
770
+ },
771
+ 'q_head-8_kv_head-4_head-256': {
772
+ 1024: (4, 32),
773
+ 2048: (8, 64),
774
+ 4096: (8, 64),
775
+ 512: (2, 64),
776
+ 8192: (8, 256),
777
+ },
778
+ }
779
+ },
780
+ 64: {
781
+ 'q_bfloat16_kv_bfloat16': {
782
+ 'q_head-128_kv_head-1_head-128': {
783
+ 1024: (8, 16),
784
+ 128: (2, 16),
785
+ 2048: (32, 32),
786
+ 256: (4, 16),
787
+ 4096: (32, 16),
788
+ 512: (8, 8),
789
+ 8192: (32, 16),
790
+ },
791
+ 'q_head-128_kv_head-1_head-256': {
792
+ 1024: (16, 8),
793
+ 128: (2, 8),
794
+ 2048: (32, 8),
795
+ 256: (4, 8),
796
+ 4096: (32, 8),
797
+ 512: (8, 16),
798
+ 8192: (32, 8),
799
+ },
800
+ 'q_head-128_kv_head-16_head-128': {
801
+ 1024: (16, 16),
802
+ 128: (2, 16),
803
+ 2048: (16, 16),
804
+ 256: (4, 16),
805
+ 4096: (16, 16),
806
+ 512: (8, 16),
807
+ 8192: (16, 16),
808
+ },
809
+ 'q_head-128_kv_head-16_head-256': {
810
+ 1024: (8, 8),
811
+ 128: (2, 8),
812
+ 2048: (8, 8),
813
+ 256: (4, 8),
814
+ 4096: (8, 8),
815
+ 512: (8, 8),
816
+ 8192: (8, 8),
817
+ },
818
+ 'q_head-128_kv_head-2_head-128': {
819
+ 1024: (16, 8),
820
+ 128: (2, 8),
821
+ 2048: (32, 16),
822
+ 256: (4, 16),
823
+ 4096: (32, 16),
824
+ 512: (8, 32),
825
+ 8192: (32, 32),
826
+ },
827
+ 'q_head-128_kv_head-2_head-256': {
828
+ 1024: (16, 16),
829
+ 128: (2, 8),
830
+ 2048: (32, 8),
831
+ 256: (4, 8),
832
+ 4096: (32, 16),
833
+ 512: (8, 8),
834
+ 8192: (32, 16),
835
+ },
836
+ 'q_head-128_kv_head-4_head-128': {
837
+ 1024: (16, 8),
838
+ 128: (2, 16),
839
+ 2048: (32, 32),
840
+ 256: (4, 8),
841
+ 4096: (32, 16),
842
+ 512: (8, 16),
843
+ 8192: (32, 32),
844
+ },
845
+ 'q_head-128_kv_head-4_head-256': {
846
+ 1024: (16, 8),
847
+ 128: (2, 8),
848
+ 2048: (32, 16),
849
+ 256: (4, 8),
850
+ 4096: (32, 16),
851
+ 512: (8, 8),
852
+ 8192: (32, 16),
853
+ },
854
+ 'q_head-128_kv_head-8_head-128': {
855
+ 1024: (16, 16),
856
+ 128: (2, 16),
857
+ 2048: (32, 16),
858
+ 256: (4, 16),
859
+ 4096: (32, 32),
860
+ 512: (8, 8),
861
+ 8192: (32, 32),
862
+ },
863
+ 'q_head-128_kv_head-8_head-256': {
864
+ 1024: (16, 8),
865
+ 128: (2, 8),
866
+ 2048: (16, 16),
867
+ 256: (4, 8),
868
+ 4096: (16, 16),
869
+ 512: (8, 8),
870
+ 8192: (16, 16),
871
+ },
872
+ 'q_head-16_kv_head-1_head-128': {
873
+ 1024: (16, 64),
874
+ 128: (2, 16),
875
+ 2048: (32, 128),
876
+ 256: (4, 16),
877
+ 4096: (32, 32),
878
+ 512: (8, 16),
879
+ 8192: (32, 16),
880
+ },
881
+ 'q_head-16_kv_head-1_head-256': {
882
+ 1024: (16, 32),
883
+ 128: (2, 8),
884
+ 2048: (32, 16),
885
+ 256: (4, 8),
886
+ 4096: (32, 32),
887
+ 512: (8, 32),
888
+ 8192: (32, 64),
889
+ },
890
+ 'q_head-16_kv_head-2_head-128': {
891
+ 1024: (16, 16),
892
+ 128: (2, 32),
893
+ 2048: (32, 64),
894
+ 256: (4, 16),
895
+ 4096: (32, 32),
896
+ 512: (8, 32),
897
+ 8192: (32, 128),
898
+ },
899
+ 'q_head-16_kv_head-2_head-256': {
900
+ 1024: (16, 16),
901
+ 128: (2, 16),
902
+ 2048: (32, 16),
903
+ 256: (4, 128),
904
+ 4096: (16, 32),
905
+ 512: (8, 32),
906
+ 8192: (32, 64),
907
+ },
908
+ 'q_head-16_kv_head-4_head-128': {
909
+ 1024: (16, 16),
910
+ 128: (2, 32),
911
+ 2048: (32, 64),
912
+ 256: (4, 128),
913
+ 4096: (32, 32),
914
+ 512: (8, 32),
915
+ 8192: (32, 128),
916
+ },
917
+ 'q_head-16_kv_head-4_head-256': {
918
+ 1024: (16, 32),
919
+ 128: (2, 8),
920
+ 2048: (32, 64),
921
+ 256: (4, 32),
922
+ 4096: (32, 128),
923
+ 512: (8, 8),
924
+ 8192: (32, 64),
925
+ },
926
+ 'q_head-16_kv_head-8_head-128': {
927
+ 1024: (16, 64),
928
+ 128: (2, 32),
929
+ 2048: (32, 32),
930
+ 256: (4, 16),
931
+ 4096: (32, 64),
932
+ 512: (8, 32),
933
+ 8192: (32, 64),
934
+ },
935
+ 'q_head-16_kv_head-8_head-256': {
936
+ 1024: (16, 64),
937
+ 128: (2, 16),
938
+ 2048: (32, 32),
939
+ 256: (4, 64),
940
+ 4096: (32, 64),
941
+ 512: (8, 64),
942
+ 8192: (16, 128),
943
+ },
944
+ 'q_head-2_kv_head-1_head-128': {
945
+ 1024: (16, 32),
946
+ 128: (2, 256),
947
+ 2048: (32, 256),
948
+ 256: (4, 32),
949
+ 4096: (32, 256),
950
+ 512: (8, 256),
951
+ 8192: (32, 256),
952
+ },
953
+ 'q_head-2_kv_head-1_head-256': {
954
+ 1024: (16, 64),
955
+ 128: (2, 32),
956
+ 2048: (32, 128),
957
+ 256: (4, 32),
958
+ 4096: (32, 256),
959
+ 512: (8, 64),
960
+ 8192: (32, 64),
961
+ },
962
+ 'q_head-32_kv_head-1_head-128': {
963
+ 1024: (8, 64),
964
+ 128: (2, 32),
965
+ 2048: (32, 32),
966
+ 256: (4, 128),
967
+ 4096: (32, 16),
968
+ 512: (4, 16),
969
+ 8192: (32, 16),
970
+ },
971
+ 'q_head-32_kv_head-1_head-256': {
972
+ 1024: (16, 16),
973
+ 128: (2, 16),
974
+ 2048: (32, 8),
975
+ 256: (4, 8),
976
+ 4096: (16, 16),
977
+ 512: (8, 8),
978
+ 8192: (32, 32),
979
+ },
980
+ 'q_head-32_kv_head-16_head-128': {
981
+ 1024: (16, 64),
982
+ 128: (2, 64),
983
+ 2048: (16, 64),
984
+ 256: (4, 64),
985
+ 4096: (16, 64),
986
+ 512: (8, 64),
987
+ 8192: (16, 64),
988
+ },
989
+ 'q_head-32_kv_head-16_head-256': {
990
+ 1024: (8, 32),
991
+ 128: (2, 32),
992
+ 2048: (8, 32),
993
+ 256: (4, 32),
994
+ 4096: (8, 32),
995
+ 512: (8, 32),
996
+ 8192: (8, 32),
997
+ },
998
+ 'q_head-32_kv_head-2_head-128': {
999
+ 1024: (16, 64),
1000
+ 128: (2, 16),
1001
+ 2048: (32, 16),
1002
+ 256: (4, 32),
1003
+ 4096: (32, 32),
1004
+ 512: (8, 64),
1005
+ 8192: (32, 32),
1006
+ },
1007
+ 'q_head-32_kv_head-2_head-256': {
1008
+ 1024: (16, 32),
1009
+ 128: (2, 16),
1010
+ 2048: (32, 32),
1011
+ 256: (4, 16),
1012
+ 4096: (32, 16),
1013
+ 512: (8, 32),
1014
+ 8192: (32, 32),
1015
+ },
1016
+ 'q_head-32_kv_head-4_head-128': {
1017
+ 1024: (16, 32),
1018
+ 128: (2, 128),
1019
+ 2048: (32, 32),
1020
+ 256: (4, 16),
1021
+ 4096: (32, 32),
1022
+ 512: (8, 16),
1023
+ 8192: (32, 16),
1024
+ },
1025
+ 'q_head-32_kv_head-4_head-256': {
1026
+ 1024: (16, 32),
1027
+ 128: (2, 16),
1028
+ 2048: (32, 32),
1029
+ 256: (4, 16),
1030
+ 4096: (32, 32),
1031
+ 512: (8, 16),
1032
+ 8192: (32, 32),
1033
+ },
1034
+ 'q_head-32_kv_head-8_head-128': {
1035
+ 1024: (16, 32),
1036
+ 128: (2, 32),
1037
+ 2048: (32, 64),
1038
+ 256: (4, 8),
1039
+ 4096: (32, 32),
1040
+ 512: (8, 32),
1041
+ 8192: (32, 32),
1042
+ },
1043
+ 'q_head-32_kv_head-8_head-256': {
1044
+ 1024: (16, 32),
1045
+ 128: (2, 32),
1046
+ 2048: (32, 32),
1047
+ 256: (4, 32),
1048
+ 4096: (32, 32),
1049
+ 512: (8, 8),
1050
+ 8192: (16, 64),
1051
+ },
1052
+ 'q_head-4_kv_head-1_head-128': {
1053
+ 1024: (16, 32),
1054
+ 128: (2, 16),
1055
+ 2048: (32, 16),
1056
+ 256: (4, 256),
1057
+ 4096: (32, 128),
1058
+ 512: (4, 64),
1059
+ 8192: (32, 32),
1060
+ },
1061
+ 'q_head-4_kv_head-1_head-256': {
1062
+ 1024: (8, 128),
1063
+ 128: (2, 16),
1064
+ 2048: (16, 64),
1065
+ 256: (4, 256),
1066
+ 4096: (32, 64),
1067
+ 512: (8, 32),
1068
+ 8192: (32, 128),
1069
+ },
1070
+ 'q_head-4_kv_head-2_head-128': {
1071
+ 1024: (16, 64),
1072
+ 128: (2, 32),
1073
+ 2048: (32, 32),
1074
+ 256: (4, 256),
1075
+ 4096: (32, 64),
1076
+ 512: (8, 128),
1077
+ 8192: (32, 64),
1078
+ },
1079
+ 'q_head-4_kv_head-2_head-256': {
1080
+ 1024: (16, 64),
1081
+ 128: (2, 32),
1082
+ 2048: (32, 32),
1083
+ 256: (4, 64),
1084
+ 4096: (32, 32),
1085
+ 512: (8, 16),
1086
+ 8192: (32, 64),
1087
+ },
1088
+ 'q_head-64_kv_head-1_head-128': {
1089
+ 1024: (8, 32),
1090
+ 128: (2, 16),
1091
+ 2048: (32, 32),
1092
+ 256: (2, 16),
1093
+ 4096: (32, 32),
1094
+ 512: (8, 32),
1095
+ 8192: (32, 32),
1096
+ },
1097
+ 'q_head-64_kv_head-1_head-256': {
1098
+ 1024: (16, 32),
1099
+ 128: (2, 8),
1100
+ 2048: (32, 8),
1101
+ 256: (4, 8),
1102
+ 4096: (32, 32),
1103
+ 512: (8, 8),
1104
+ 8192: (32, 32),
1105
+ },
1106
+ 'q_head-64_kv_head-16_head-128': {
1107
+ 1024: (16, 32),
1108
+ 128: (2, 16),
1109
+ 2048: (16, 32),
1110
+ 256: (4, 32),
1111
+ 4096: (16, 32),
1112
+ 512: (8, 32),
1113
+ 8192: (16, 32),
1114
+ },
1115
+ 'q_head-64_kv_head-16_head-256': {
1116
+ 1024: (8, 16),
1117
+ 128: (2, 16),
1118
+ 2048: (8, 16),
1119
+ 256: (4, 16),
1120
+ 4096: (8, 16),
1121
+ 512: (8, 16),
1122
+ 8192: (8, 16),
1123
+ },
1124
+ 'q_head-64_kv_head-2_head-128': {
1125
+ 1024: (16, 32),
1126
+ 128: (2, 64),
1127
+ 2048: (32, 8),
1128
+ 256: (4, 16),
1129
+ 4096: (32, 32),
1130
+ 512: (8, 16),
1131
+ 8192: (32, 32),
1132
+ },
1133
+ 'q_head-64_kv_head-2_head-256': {
1134
+ 1024: (16, 16),
1135
+ 128: (2, 8),
1136
+ 2048: (32, 32),
1137
+ 256: (4, 8),
1138
+ 4096: (32, 32),
1139
+ 512: (8, 8),
1140
+ 8192: (32, 16),
1141
+ },
1142
+ 'q_head-64_kv_head-4_head-128': {
1143
+ 1024: (16, 16),
1144
+ 128: (2, 16),
1145
+ 2048: (32, 32),
1146
+ 256: (4, 16),
1147
+ 4096: (32, 32),
1148
+ 512: (8, 32),
1149
+ 8192: (32, 32),
1150
+ },
1151
+ 'q_head-64_kv_head-4_head-256': {
1152
+ 1024: (16, 8),
1153
+ 128: (2, 16),
1154
+ 2048: (32, 16),
1155
+ 256: (4, 32),
1156
+ 4096: (32, 32),
1157
+ 512: (8, 16),
1158
+ 8192: (32, 32),
1159
+ },
1160
+ 'q_head-64_kv_head-8_head-128': {
1161
+ 1024: (16, 32),
1162
+ 128: (2, 32),
1163
+ 2048: (32, 32),
1164
+ 256: (4, 32),
1165
+ 4096: (32, 64),
1166
+ 512: (8, 16),
1167
+ 8192: (32, 32),
1168
+ },
1169
+ 'q_head-64_kv_head-8_head-256': {
1170
+ 1024: (16, 32),
1171
+ 128: (2, 8),
1172
+ 2048: (32, 16),
1173
+ 256: (4, 16),
1174
+ 4096: (32, 16),
1175
+ 512: (8, 16),
1176
+ 8192: (16, 32),
1177
+ },
1178
+ 'q_head-8_kv_head-1_head-128': {
1179
+ 1024: (16, 16),
1180
+ 128: (2, 256),
1181
+ 2048: (16, 32),
1182
+ 256: (4, 128),
1183
+ 4096: (32, 32),
1184
+ 512: (8, 64),
1185
+ 8192: (32, 32),
1186
+ },
1187
+ 'q_head-8_kv_head-1_head-256': {
1188
+ 1024: (16, 8),
1189
+ 128: (2, 64),
1190
+ 2048: (32, 64),
1191
+ 256: (2, 32),
1192
+ 4096: (32, 64),
1193
+ 512: (4, 32),
1194
+ 8192: (32, 32),
1195
+ },
1196
+ 'q_head-8_kv_head-2_head-128': {
1197
+ 1024: (16, 16),
1198
+ 128: (2, 32),
1199
+ 2048: (32, 32),
1200
+ 256: (4, 128),
1201
+ 4096: (32, 128),
1202
+ 512: (8, 64),
1203
+ 8192: (32, 64),
1204
+ },
1205
+ 'q_head-8_kv_head-2_head-256': {
1206
+ 1024: (16, 64),
1207
+ 128: (2, 32),
1208
+ 2048: (32, 32),
1209
+ 256: (4, 16),
1210
+ 4096: (32, 128),
1211
+ 512: (8, 16),
1212
+ 8192: (32, 128),
1213
+ },
1214
+ 'q_head-8_kv_head-4_head-128': {
1215
+ 1024: (16, 64),
1216
+ 128: (2, 16),
1217
+ 2048: (32, 64),
1218
+ 256: (4, 64),
1219
+ 4096: (32, 64),
1220
+ 512: (8, 64),
1221
+ 8192: (32, 128),
1222
+ },
1223
+ 'q_head-8_kv_head-4_head-256': {
1224
+ 1024: (16, 64),
1225
+ 128: (2, 32),
1226
+ 2048: (32, 64),
1227
+ 256: (4, 32),
1228
+ 4096: (32, 64),
1229
+ 512: (8, 64),
1230
+ 8192: (32, 128),
1231
+ },
1232
+ }
1233
+ },
1234
+ 16: {
1235
+ 'q_bfloat16_kv_bfloat16': {
1236
+ 'q_head-8_kv_head-1_head-128': {
1237
+ 262144: (128, 256),
1238
+ }
1239
+ }
1240
+ },
1241
+ },
1242
+ 'TPU v5e': {
1243
+ 128: {
1244
+ 'q_bfloat16_kv_bfloat16': {
1245
+ 'q_head-128_kv_head-1_head-128': {
1246
+ 1024: (4, 32),
1247
+ 128: (1, 8),
1248
+ 2048: (16, 8),
1249
+ 256: (2, 8),
1250
+ 4096: (16, 16),
1251
+ 512: (4, 8),
1252
+ 8192: (16, 16),
1253
+ },
1254
+ 'q_head-128_kv_head-1_head-256': {
1255
+ 1024: (8, 16),
1256
+ 128: (1, 8),
1257
+ 2048: (16, 8),
1258
+ 256: (2, 8),
1259
+ 4096: (16, 8),
1260
+ 512: (2, 8),
1261
+ 8192: (16, 8),
1262
+ },
1263
+ 'q_head-128_kv_head-16_head-128': {
1264
+ 1024: (8, 16),
1265
+ 128: (1, 16),
1266
+ 2048: (8, 16),
1267
+ 256: (2, 8),
1268
+ 4096: (8, 16),
1269
+ 512: (2, 16),
1270
+ 8192: (8, 16),
1271
+ },
1272
+ 'q_head-128_kv_head-16_head-256': {
1273
+ 1024: (4, 8),
1274
+ 128: (1, 8),
1275
+ 2048: (4, 8),
1276
+ 256: (2, 8),
1277
+ 4096: (4, 8),
1278
+ 512: (4, 8),
1279
+ 8192: (4, 8),
1280
+ },
1281
+ 'q_head-128_kv_head-2_head-128': {
1282
+ 1024: (8, 8),
1283
+ 128: (1, 8),
1284
+ 2048: (16, 8),
1285
+ 256: (2, 16),
1286
+ 4096: (8, 16),
1287
+ 512: (4, 16),
1288
+ 8192: (16, 16),
1289
+ },
1290
+ 'q_head-128_kv_head-2_head-256': {
1291
+ 1024: (8, 8),
1292
+ 128: (1, 8),
1293
+ 2048: (16, 8),
1294
+ 256: (2, 8),
1295
+ 4096: (8, 16),
1296
+ 512: (4, 8),
1297
+ 8192: (8, 8),
1298
+ },
1299
+ 'q_head-128_kv_head-4_head-128': {
1300
+ 1024: (8, 8),
1301
+ 128: (1, 16),
1302
+ 2048: (8, 8),
1303
+ 256: (2, 8),
1304
+ 4096: (8, 32),
1305
+ 512: (4, 8),
1306
+ 8192: (8, 16),
1307
+ },
1308
+ 'q_head-128_kv_head-4_head-256': {
1309
+ 1024: (4, 8),
1310
+ 128: (1, 8),
1311
+ 2048: (8, 16),
1312
+ 256: (2, 8),
1313
+ 4096: (8, 16),
1314
+ 512: (4, 8),
1315
+ 8192: (8, 16),
1316
+ },
1317
+ 'q_head-128_kv_head-8_head-128': {
1318
+ 1024: (8, 32),
1319
+ 128: (1, 8),
1320
+ 2048: (8, 16),
1321
+ 256: (2, 16),
1322
+ 4096: (8, 16),
1323
+ 512: (4, 16),
1324
+ 8192: (8, 16),
1325
+ },
1326
+ 'q_head-128_kv_head-8_head-256': {
1327
+ 1024: (4, 16),
1328
+ 128: (1, 8),
1329
+ 2048: (8, 16),
1330
+ 256: (2, 8),
1331
+ 4096: (8, 16),
1332
+ 512: (4, 16),
1333
+ 8192: (4, 16),
1334
+ },
1335
+ 'q_head-16_kv_head-1_head-128': {
1336
+ 2048: (8, 64),
1337
+ 512: (4, 64)
1338
+ },
1339
+ 'q_head-16_kv_head-1_head-256': {
1340
+ 128: (1, 32),
1341
+ 256: (2, 8)
1342
+ },
1343
+ 'q_head-16_kv_head-2_head-128': {
1344
+ 128: (1, 128),
1345
+ 256: (2, 8),
1346
+ 512: (2, 32),
1347
+ 8192: (16, 32),
1348
+ },
1349
+ 'q_head-16_kv_head-2_head-256': {
1350
+ 128: (1, 32),
1351
+ 2048: (8, 32),
1352
+ 256: (2, 32),
1353
+ },
1354
+ 'q_head-16_kv_head-4_head-128': {
1355
+ 1024: (8, 32),
1356
+ 128: (1, 64),
1357
+ 256: (2, 16),
1358
+ 512: (4, 64),
1359
+ },
1360
+ 'q_head-16_kv_head-4_head-256': {
1361
+ 1024: (8, 128),
1362
+ 128: (1, 16),
1363
+ 2048: (8, 64),
1364
+ 256: (2, 32),
1365
+ 4096: (8, 32),
1366
+ 512: (4, 32),
1367
+ 8192: (16, 64),
1368
+ },
1369
+ 'q_head-16_kv_head-8_head-128': {
1370
+ 1024: (8, 256),
1371
+ 128: (1, 128),
1372
+ 2048: (8, 128),
1373
+ 256: (2, 16),
1374
+ 4096: (8, 64),
1375
+ 512: (4, 64),
1376
+ 8192: (4, 128),
1377
+ },
1378
+ 'q_head-16_kv_head-8_head-256': {
1379
+ 1024: (8, 128),
1380
+ 128: (1, 16),
1381
+ 2048: (8, 128),
1382
+ 256: (2, 64),
1383
+ 4096: (8, 128),
1384
+ 512: (2, 32),
1385
+ 8192: (8, 128),
1386
+ },
1387
+ 'q_head-2_kv_head-1_head-128': {
1388
+ 1024: (8, 128),
1389
+ 128: (1, 256),
1390
+ 2048: (8, 32),
1391
+ 256: (2, 8),
1392
+ 512: (4, 256),
1393
+ 8192: (16, 32),
1394
+ },
1395
+ 'q_head-2_kv_head-1_head-256': {
1396
+ 1024: (8, 128),
1397
+ 2048: (8, 64),
1398
+ 256: (2, 8),
1399
+ 4096: (8, 128),
1400
+ 512: (4, 32),
1401
+ 8192: (16, 64),
1402
+ },
1403
+ 'q_head-32_kv_head-1_head-128': {
1404
+ 1024: (8, 16),
1405
+ 128: (1, 128),
1406
+ 2048: (8, 32),
1407
+ 256: (2, 16),
1408
+ 4096: (16, 64),
1409
+ 512: (4, 64),
1410
+ 8192: (16, 16),
1411
+ },
1412
+ 'q_head-32_kv_head-1_head-256': {
1413
+ 1024: (8, 16),
1414
+ 128: (1, 16),
1415
+ 2048: (16, 32),
1416
+ 256: (2, 8),
1417
+ 4096: (16, 16),
1418
+ 512: (4, 16),
1419
+ 8192: (16, 16),
1420
+ },
1421
+ 'q_head-32_kv_head-16_head-128': {
1422
+ 1024: (8, 64),
1423
+ 128: (1, 8),
1424
+ 2048: (8, 64),
1425
+ 256: (2, 32),
1426
+ 4096: (8, 64),
1427
+ 512: (4, 64),
1428
+ 8192: (8, 64),
1429
+ },
1430
+ 'q_head-32_kv_head-16_head-256': {
1431
+ 1024: (4, 32),
1432
+ 128: (1, 8),
1433
+ 2048: (4, 32),
1434
+ 256: (2, 32),
1435
+ 4096: (4, 32),
1436
+ 512: (4, 32),
1437
+ 8192: (4, 32),
1438
+ },
1439
+ 'q_head-32_kv_head-2_head-128': {
1440
+ 1024: (4, 8),
1441
+ 128: (1, 32),
1442
+ 2048: (8, 64),
1443
+ 256: (2, 8),
1444
+ 4096: (16, 32),
1445
+ 512: (4, 32),
1446
+ 8192: (16, 16),
1447
+ },
1448
+ 'q_head-32_kv_head-2_head-256': {
1449
+ 1024: (8, 16),
1450
+ 128: (1, 16),
1451
+ 2048: (8, 32),
1452
+ 256: (2, 16),
1453
+ 4096: (8, 32),
1454
+ 512: (4, 8),
1455
+ 8192: (8, 32),
1456
+ },
1457
+ 'q_head-32_kv_head-4_head-128': {
1458
+ 1024: (8, 64),
1459
+ 128: (1, 32),
1460
+ 2048: (8, 64),
1461
+ 256: (2, 16),
1462
+ 4096: (8, 32),
1463
+ 512: (4, 16),
1464
+ 8192: (8, 32),
1465
+ },
1466
+ 'q_head-32_kv_head-4_head-256': {
1467
+ 1024: (8, 32),
1468
+ 128: (1, 16),
1469
+ 2048: (8, 32),
1470
+ 256: (2, 32),
1471
+ 4096: (8, 32),
1472
+ 512: (4, 16),
1473
+ 8192: (8, 32),
1474
+ },
1475
+ 'q_head-32_kv_head-8_head-128': {
1476
+ 1024: (8, 128),
1477
+ 128: (1, 16),
1478
+ 2048: (4, 32),
1479
+ 256: (1, 16),
1480
+ 4096: (16, 32),
1481
+ 512: (4, 64),
1482
+ 8192: (4, 64),
1483
+ },
1484
+ 'q_head-32_kv_head-8_head-256': {
1485
+ 1024: (8, 32),
1486
+ 128: (1, 8),
1487
+ 2048: (4, 64),
1488
+ 256: (2, 16),
1489
+ 4096: (8, 64),
1490
+ 512: (4, 32),
1491
+ 8192: (8, 64),
1492
+ },
1493
+ 'q_head-4_kv_head-1_head-128': {
1494
+ 1024: (8, 32),
1495
+ 2048: (8, 128),
1496
+ 256: (1, 256),
1497
+ 4096: (16, 128),
1498
+ 512: (4, 128),
1499
+ 8192: (16, 16),
1500
+ },
1501
+ 'q_head-4_kv_head-1_head-256': {
1502
+ 1024: (8, 16),
1503
+ 2048: (8, 32),
1504
+ 4096: (16, 32),
1505
+ 8192: (16, 32),
1506
+ },
1507
+ 'q_head-4_kv_head-2_head-128': {
1508
+ 1024: (8, 64),
1509
+ 128: (1, 64),
1510
+ 2048: (8, 128),
1511
+ 256: (1, 256),
1512
+ 4096: (16, 128),
1513
+ 8192: (8, 32),
1514
+ },
1515
+ 'q_head-4_kv_head-2_head-256': {
1516
+ 1024: (8, 32),
1517
+ 128: (1, 8),
1518
+ 4096: (8, 256),
1519
+ 8192: (8, 128),
1520
+ },
1521
+ 'q_head-64_kv_head-1_head-128': {
1522
+ 1024: (4, 32),
1523
+ 128: (1, 16),
1524
+ 2048: (16, 32),
1525
+ 256: (2, 32),
1526
+ 4096: (16, 32),
1527
+ 512: (4, 16),
1528
+ 8192: (16, 32),
1529
+ },
1530
+ 'q_head-64_kv_head-1_head-256': {
1531
+ 1024: (8, 16),
1532
+ 128: (1, 8),
1533
+ 2048: (16, 8),
1534
+ 256: (2, 16),
1535
+ 4096: (16, 16),
1536
+ 512: (4, 16),
1537
+ 8192: (16, 16),
1538
+ },
1539
+ 'q_head-64_kv_head-16_head-128': {
1540
+ 1024: (4, 32),
1541
+ 128: (1, 16),
1542
+ 2048: (8, 32),
1543
+ 256: (2, 32),
1544
+ 4096: (8, 32),
1545
+ 512: (2, 32),
1546
+ 8192: (8, 32),
1547
+ },
1548
+ 'q_head-64_kv_head-16_head-256': {
1549
+ 1024: (4, 16),
1550
+ 128: (1, 16),
1551
+ 2048: (4, 16),
1552
+ 256: (2, 16),
1553
+ 4096: (4, 16),
1554
+ 512: (4, 16),
1555
+ 8192: (4, 16),
1556
+ },
1557
+ 'q_head-64_kv_head-2_head-128': {
1558
+ 1024: (8, 8),
1559
+ 128: (1, 16),
1560
+ 2048: (8, 16),
1561
+ 256: (1, 16),
1562
+ 4096: (8, 16),
1563
+ 512: (4, 16),
1564
+ 8192: (8, 32),
1565
+ },
1566
+ 'q_head-64_kv_head-2_head-256': {
1567
+ 1024: (4, 8),
1568
+ 128: (1, 8),
1569
+ 2048: (16, 16),
1570
+ 256: (2, 8),
1571
+ 4096: (8, 16),
1572
+ 512: (4, 8),
1573
+ 8192: (8, 16),
1574
+ },
1575
+ 'q_head-64_kv_head-4_head-128': {
1576
+ 1024: (8, 32),
1577
+ 128: (1, 8),
1578
+ 2048: (16, 16),
1579
+ 256: (1, 32),
1580
+ 4096: (8, 32),
1581
+ 512: (4, 32),
1582
+ 8192: (16, 32),
1583
+ },
1584
+ 'q_head-64_kv_head-4_head-256': {
1585
+ 1024: (4, 16),
1586
+ 128: (1, 8),
1587
+ 2048: (8, 32),
1588
+ 256: (1, 8),
1589
+ 4096: (8, 32),
1590
+ 512: (4, 16),
1591
+ 8192: (8, 32),
1592
+ },
1593
+ 'q_head-64_kv_head-8_head-128': {
1594
+ 1024: (8, 16),
1595
+ 128: (1, 32),
1596
+ 2048: (4, 32),
1597
+ 256: (2, 64),
1598
+ 4096: (4, 32),
1599
+ 512: (4, 32),
1600
+ 8192: (16, 32),
1601
+ },
1602
+ 'q_head-64_kv_head-8_head-256': {
1603
+ 1024: (8, 32),
1604
+ 128: (1, 8),
1605
+ 2048: (8, 32),
1606
+ 256: (2, 16),
1607
+ 4096: (4, 32),
1608
+ 512: (4, 16),
1609
+ 8192: (8, 32),
1610
+ },
1611
+ 'q_head-8_kv_head-1_head-128': {
1612
+ 2048: (8, 32),
1613
+ 4096: (8, 16),
1614
+ 512: (4, 128),
1615
+ 8192: (16, 32),
1616
+ },
1617
+ 'q_head-8_kv_head-1_head-256': {
1618
+ 128: (1, 8),
1619
+ 2048: (8, 16),
1620
+ 8192: (8, 32),
1621
+ },
1622
+ 'q_head-8_kv_head-2_head-128': {
1623
+ 128: (1, 64),
1624
+ 256: (2, 64),
1625
+ 4096: (16, 32),
1626
+ 512: (4, 64),
1627
+ 8192: (16, 128),
1628
+ },
1629
+ 'q_head-8_kv_head-2_head-256': {
1630
+ 1024: (8, 128),
1631
+ 128: (1, 32),
1632
+ 8192: (8, 128),
1633
+ },
1634
+ 'q_head-8_kv_head-4_head-128': {
1635
+ 128: (1, 16),
1636
+ 256: (2, 32),
1637
+ 4096: (16, 32),
1638
+ 512: (4, 8),
1639
+ },
1640
+ 'q_head-8_kv_head-4_head-256': {
1641
+ 128: (1, 32),
1642
+ 2048: (8, 128),
1643
+ 256: (2, 32),
1644
+ 512: (4, 16),
1645
+ },
1646
+ }
1647
+ },
1648
+ 256: {
1649
+ 'q_bfloat16_kv_bfloat16': {
1650
+ 'q_head-128_kv_head-1_head-128': {
1651
+ 1024: (2, 16),
1652
+ 2048: (4, 8),
1653
+ 256: (1, 8),
1654
+ 4096: (8, 8),
1655
+ 512: (2, 8),
1656
+ 8192: (8, 16),
1657
+ },
1658
+ 'q_head-128_kv_head-1_head-256': {
1659
+ 1024: (4, 8),
1660
+ 2048: (4, 8),
1661
+ 256: (1, 8),
1662
+ 4096: (8, 8),
1663
+ 512: (2, 8),
1664
+ 8192: (8, 8),
1665
+ },
1666
+ 'q_head-128_kv_head-16_head-128': {
1667
+ 1024: (4, 16),
1668
+ 2048: (4, 16),
1669
+ 256: (1, 16),
1670
+ 4096: (4, 16),
1671
+ 512: (2, 16),
1672
+ 8192: (4, 16),
1673
+ },
1674
+ 'q_head-128_kv_head-16_head-256': {
1675
+ 1024: (2, 8),
1676
+ 2048: (2, 8),
1677
+ 256: (1, 8),
1678
+ 4096: (2, 8),
1679
+ 512: (2, 8),
1680
+ 8192: (2, 8),
1681
+ },
1682
+ 'q_head-128_kv_head-2_head-128': {
1683
+ 1024: (4, 8),
1684
+ 2048: (8, 8),
1685
+ 256: (1, 16),
1686
+ 4096: (8, 8),
1687
+ 512: (2, 8),
1688
+ 8192: (8, 16),
1689
+ },
1690
+ 'q_head-128_kv_head-2_head-256': {
1691
+ 1024: (4, 8),
1692
+ 2048: (4, 8),
1693
+ 256: (1, 8),
1694
+ 4096: (8, 8),
1695
+ 512: (1, 8),
1696
+ 8192: (8, 8),
1697
+ },
1698
+ 'q_head-128_kv_head-4_head-128': {
1699
+ 1024: (4, 16),
1700
+ 2048: (4, 16),
1701
+ 256: (1, 32),
1702
+ 4096: (8, 16),
1703
+ 512: (2, 32),
1704
+ 8192: (4, 16),
1705
+ },
1706
+ 'q_head-128_kv_head-4_head-256': {
1707
+ 1024: (2, 8),
1708
+ 2048: (4, 16),
1709
+ 256: (1, 8),
1710
+ 4096: (8, 8),
1711
+ 512: (2, 8),
1712
+ 8192: (4, 16),
1713
+ },
1714
+ 'q_head-128_kv_head-8_head-128': {
1715
+ 1024: (4, 16),
1716
+ 2048: (4, 32),
1717
+ 256: (1, 32),
1718
+ 4096: (4, 32),
1719
+ 512: (2, 16),
1720
+ 8192: (2, 32),
1721
+ },
1722
+ 'q_head-128_kv_head-8_head-256': {
1723
+ 1024: (4, 16),
1724
+ 2048: (2, 16),
1725
+ 256: (1, 8),
1726
+ 4096: (2, 16),
1727
+ 512: (2, 16),
1728
+ 8192: (2, 16),
1729
+ },
1730
+ 'q_head-16_kv_head-1_head-128': {
1731
+ 1024: (2, 32),
1732
+ 2048: (8, 16),
1733
+ 256: (1, 32),
1734
+ 4096: (8, 32),
1735
+ 512: (1, 64),
1736
+ 8192: (8, 32),
1737
+ },
1738
+ 'q_head-16_kv_head-1_head-256': {
1739
+ 1024: (4, 32),
1740
+ 2048: (4, 16),
1741
+ 256: (1, 32),
1742
+ 4096: (8, 16),
1743
+ 512: (2, 8),
1744
+ 8192: (8, 16),
1745
+ },
1746
+ 'q_head-16_kv_head-2_head-128': {
1747
+ 1024: (4, 16),
1748
+ 2048: (4, 32),
1749
+ 256: (1, 8),
1750
+ 4096: (4, 64),
1751
+ 512: (2, 16),
1752
+ 8192: (8, 128),
1753
+ },
1754
+ 'q_head-16_kv_head-2_head-256': {
1755
+ 1024: (4, 32),
1756
+ 2048: (4, 16),
1757
+ 256: (1, 64),
1758
+ 4096: (8, 32),
1759
+ 512: (2, 16),
1760
+ 8192: (4, 32),
1761
+ },
1762
+ 'q_head-16_kv_head-4_head-128': {
1763
+ 1024: (2, 64),
1764
+ 2048: (2, 64),
1765
+ 256: (1, 64),
1766
+ 4096: (4, 32),
1767
+ 512: (2, 128),
1768
+ 8192: (8, 32),
1769
+ },
1770
+ 'q_head-16_kv_head-4_head-256': {
1771
+ 1024: (2, 64),
1772
+ 2048: (8, 32),
1773
+ 256: (1, 32),
1774
+ 4096: (4, 128),
1775
+ 512: (2, 16),
1776
+ 8192: (4, 32),
1777
+ },
1778
+ 'q_head-16_kv_head-8_head-128': {
1779
+ 1024: (4, 64),
1780
+ 2048: (4, 32),
1781
+ 256: (1, 8),
1782
+ 4096: (2, 128),
1783
+ 512: (2, 64),
1784
+ 8192: (8, 128),
1785
+ },
1786
+ 'q_head-16_kv_head-8_head-256': {
1787
+ 1024: (4, 64),
1788
+ 2048: (4, 128),
1789
+ 256: (1, 16),
1790
+ 4096: (4, 128),
1791
+ 512: (1, 32),
1792
+ 8192: (4, 128),
1793
+ },
1794
+ 'q_head-2_kv_head-1_head-128': {
1795
+ 1024: (4, 64),
1796
+ 2048: (8, 128),
1797
+ 256: (1, 64),
1798
+ 4096: (8, 256),
1799
+ 512: (2, 64),
1800
+ 8192: (8, 256),
1801
+ },
1802
+ 'q_head-2_kv_head-1_head-256': {
1803
+ 1024: (4, 128),
1804
+ 2048: (8, 32),
1805
+ 256: (1, 32),
1806
+ 4096: (8, 256),
1807
+ 512: (2, 32),
1808
+ 8192: (4, 32),
1809
+ },
1810
+ 'q_head-32_kv_head-1_head-128': {
1811
+ 1024: (2, 32),
1812
+ 2048: (4, 16),
1813
+ 256: (1, 64),
1814
+ 4096: (8, 16),
1815
+ 512: (2, 32),
1816
+ 8192: (8, 64),
1817
+ },
1818
+ 'q_head-32_kv_head-1_head-256': {
1819
+ 1024: (4, 8),
1820
+ 2048: (8, 16),
1821
+ 256: (1, 16),
1822
+ 4096: (8, 16),
1823
+ 512: (2, 16),
1824
+ 8192: (8, 16),
1825
+ },
1826
+ 'q_head-32_kv_head-16_head-128': {
1827
+ 1024: (4, 64),
1828
+ 2048: (4, 64),
1829
+ 256: (1, 64),
1830
+ 4096: (4, 64),
1831
+ 512: (2, 32),
1832
+ 8192: (4, 64),
1833
+ },
1834
+ 'q_head-32_kv_head-16_head-256': {
1835
+ 1024: (2, 32),
1836
+ 2048: (2, 32),
1837
+ 256: (1, 32),
1838
+ 4096: (2, 32),
1839
+ 512: (2, 32),
1840
+ 8192: (2, 32),
1841
+ },
1842
+ 'q_head-32_kv_head-2_head-128': {
1843
+ 1024: (4, 16),
1844
+ 2048: (8, 16),
1845
+ 256: (1, 8),
1846
+ 4096: (4, 32),
1847
+ 512: (2, 16),
1848
+ 8192: (8, 32),
1849
+ },
1850
+ 'q_head-32_kv_head-2_head-256': {
1851
+ 1024: (2, 16),
1852
+ 2048: (8, 16),
1853
+ 256: (1, 32),
1854
+ 4096: (8, 16),
1855
+ 512: (2, 16),
1856
+ 8192: (8, 32),
1857
+ },
1858
+ 'q_head-32_kv_head-4_head-128': {
1859
+ 1024: (4, 64),
1860
+ 2048: (8, 32),
1861
+ 256: (1, 16),
1862
+ 4096: (4, 128),
1863
+ 512: (2, 16),
1864
+ 8192: (4, 128),
1865
+ },
1866
+ 'q_head-32_kv_head-4_head-256': {
1867
+ 1024: (4, 16),
1868
+ 2048: (2, 32),
1869
+ 256: (1, 32),
1870
+ 4096: (8, 32),
1871
+ 512: (2, 32),
1872
+ 8192: (4, 32),
1873
+ },
1874
+ 'q_head-32_kv_head-8_head-128': {
1875
+ 1024: (4, 128),
1876
+ 2048: (4, 128),
1877
+ 256: (1, 32),
1878
+ 4096: (4, 128),
1879
+ 512: (2, 16),
1880
+ 8192: (2, 64),
1881
+ },
1882
+ 'q_head-32_kv_head-8_head-256': {
1883
+ 1024: (2, 64),
1884
+ 2048: (2, 32),
1885
+ 256: (1, 16),
1886
+ 4096: (4, 64),
1887
+ 512: (1, 32),
1888
+ 8192: (4, 64),
1889
+ },
1890
+ 'q_head-4_kv_head-1_head-128': {
1891
+ 1024: (4, 16),
1892
+ 2048: (8, 16),
1893
+ 256: (1, 128),
1894
+ 4096: (4, 128),
1895
+ 512: (2, 128),
1896
+ 8192: (8, 32),
1897
+ },
1898
+ 'q_head-4_kv_head-1_head-256': {
1899
+ 1024: (4, 16),
1900
+ 2048: (4, 32),
1901
+ 256: (1, 64),
1902
+ 4096: (8, 64),
1903
+ 512: (2, 64),
1904
+ 8192: (4, 64),
1905
+ },
1906
+ 'q_head-4_kv_head-2_head-128': {
1907
+ 1024: (4, 256),
1908
+ 2048: (8, 128),
1909
+ 256: (1, 64),
1910
+ 4096: (8, 256),
1911
+ 512: (1, 64),
1912
+ 8192: (8, 128),
1913
+ },
1914
+ 'q_head-4_kv_head-2_head-256': {
1915
+ 1024: (4, 32),
1916
+ 2048: (4, 32),
1917
+ 256: (1, 8),
1918
+ 4096: (8, 64),
1919
+ 512: (2, 64),
1920
+ 8192: (4, 64),
1921
+ },
1922
+ 'q_head-64_kv_head-1_head-128': {
1923
+ 1024: (2, 8),
1924
+ 2048: (8, 16),
1925
+ 256: (1, 32),
1926
+ 4096: (8, 16),
1927
+ 512: (2, 16),
1928
+ 8192: (8, 8),
1929
+ },
1930
+ 'q_head-64_kv_head-1_head-256': {
1931
+ 1024: (4, 8),
1932
+ 2048: (8, 8),
1933
+ 256: (1, 8),
1934
+ 4096: (4, 8),
1935
+ 512: (1, 16),
1936
+ 8192: (8, 16),
1937
+ },
1938
+ 'q_head-64_kv_head-16_head-128': {
1939
+ 1024: (2, 32),
1940
+ 2048: (4, 32),
1941
+ 256: (1, 16),
1942
+ 4096: (2, 32),
1943
+ 512: (2, 32),
1944
+ 8192: (4, 32),
1945
+ },
1946
+ 'q_head-64_kv_head-16_head-256': {
1947
+ 1024: (2, 16),
1948
+ 2048: (2, 16),
1949
+ 256: (1, 16),
1950
+ 4096: (2, 16),
1951
+ 512: (2, 16),
1952
+ 8192: (2, 16),
1953
+ },
1954
+ 'q_head-64_kv_head-2_head-128': {
1955
+ 1024: (4, 16),
1956
+ 2048: (8, 16),
1957
+ 256: (1, 8),
1958
+ 4096: (8, 16),
1959
+ 512: (2, 32),
1960
+ 8192: (8, 16),
1961
+ },
1962
+ 'q_head-64_kv_head-2_head-256': {
1963
+ 1024: (2, 8),
1964
+ 2048: (4, 16),
1965
+ 256: (1, 16),
1966
+ 4096: (4, 16),
1967
+ 512: (2, 8),
1968
+ 8192: (4, 32),
1969
+ },
1970
+ 'q_head-64_kv_head-4_head-128': {
1971
+ 1024: (4, 16),
1972
+ 2048: (8, 32),
1973
+ 256: (1, 32),
1974
+ 4096: (8, 32),
1975
+ 512: (2, 64),
1976
+ 8192: (4, 32),
1977
+ },
1978
+ 'q_head-64_kv_head-4_head-256': {
1979
+ 1024: (4, 32),
1980
+ 2048: (8, 16),
1981
+ 256: (1, 16),
1982
+ 4096: (4, 16),
1983
+ 512: (2, 16),
1984
+ 8192: (4, 32),
1985
+ },
1986
+ 'q_head-64_kv_head-8_head-128': {
1987
+ 1024: (4, 16),
1988
+ 2048: (2, 32),
1989
+ 256: (1, 8),
1990
+ 4096: (8, 32),
1991
+ 512: (2, 64),
1992
+ 8192: (4, 32),
1993
+ },
1994
+ 'q_head-64_kv_head-8_head-256': {
1995
+ 1024: (4, 32),
1996
+ 2048: (4, 32),
1997
+ 256: (1, 8),
1998
+ 4096: (4, 32),
1999
+ 512: (2, 16),
2000
+ 8192: (4, 32),
2001
+ },
2002
+ 'q_head-8_kv_head-1_head-128': {
2003
+ 1024: (4, 8),
2004
+ 2048: (8, 64),
2005
+ 256: (1, 32),
2006
+ 4096: (8, 64),
2007
+ 512: (2, 32),
2008
+ 8192: (8, 32),
2009
+ },
2010
+ 'q_head-8_kv_head-1_head-256': {
2011
+ 1024: (2, 16),
2012
+ 2048: (8, 8),
2013
+ 256: (1, 64),
2014
+ 4096: (8, 64),
2015
+ 512: (2, 16),
2016
+ 8192: (8, 64),
2017
+ },
2018
+ 'q_head-8_kv_head-2_head-128': {
2019
+ 1024: (4, 64),
2020
+ 2048: (8, 16),
2021
+ 256: (1, 16),
2022
+ 4096: (8, 32),
2023
+ 512: (2, 128),
2024
+ 8192: (8, 32),
2025
+ },
2026
+ 'q_head-8_kv_head-2_head-256': {
2027
+ 1024: (2, 32),
2028
+ 2048: (2, 32),
2029
+ 256: (1, 32),
2030
+ 4096: (4, 64),
2031
+ 512: (2, 16),
2032
+ 8192: (4, 64),
2033
+ },
2034
+ 'q_head-8_kv_head-4_head-128': {
2035
+ 1024: (4, 256),
2036
+ 2048: (4, 32),
2037
+ 256: (1, 64),
2038
+ 4096: (8, 64),
2039
+ 512: (2, 64),
2040
+ 8192: (4, 64),
2041
+ },
2042
+ 'q_head-8_kv_head-4_head-256': {
2043
+ 1024: (4, 64),
2044
+ 2048: (4, 64),
2045
+ 256: (1, 64),
2046
+ 4096: (4, 128),
2047
+ 512: (2, 64),
2048
+ 8192: (4, 128),
2049
+ },
2050
+ }
2051
+ },
2052
+ 64: {
2053
+ 'q_bfloat16_kv_bfloat16': {
2054
+ 'q_head-128_kv_head-1_head-128': {
2055
+ 1024: (8, 16),
2056
+ 128: (2, 16),
2057
+ 2048: (16, 16),
2058
+ 256: (4, 8),
2059
+ 512: (4, 16),
2060
+ 64: (1, 8),
2061
+ },
2062
+ 'q_head-128_kv_head-1_head-256': {
2063
+ 1024: (16, 8),
2064
+ 2048: (32, 8),
2065
+ 256: (2, 8),
2066
+ 512: (8, 8),
2067
+ 64: (1, 8),
2068
+ 8192: (32, 8),
2069
+ },
2070
+ 'q_head-128_kv_head-16_head-128': {
2071
+ 1024: (16, 16),
2072
+ 128: (2, 16),
2073
+ 256: (2, 8),
2074
+ 512: (8, 16),
2075
+ 64: (1, 8),
2076
+ },
2077
+ 'q_head-128_kv_head-16_head-256': {
2078
+ 128: (2, 8),
2079
+ 256: (4, 8),
2080
+ 4096: (8, 8),
2081
+ 512: (8, 8),
2082
+ 64: (1, 8),
2083
+ },
2084
+ 'q_head-128_kv_head-2_head-128': {
2085
+ 1024: (16, 16),
2086
+ 2048: (16, 8),
2087
+ 256: (4, 8),
2088
+ 4096: (16, 16),
2089
+ 512: (8, 16),
2090
+ 64: (1, 8),
2091
+ 8192: (32, 16),
2092
+ },
2093
+ 'q_head-128_kv_head-2_head-256': {
2094
+ 1024: (16, 8),
2095
+ 2048: (16, 8),
2096
+ 256: (4, 8),
2097
+ 4096: (32, 8),
2098
+ },
2099
+ 'q_head-128_kv_head-4_head-128': {
2100
+ 1024: (16, 8),
2101
+ 128: (1, 8),
2102
+ 2048: (16, 8),
2103
+ 4096: (16, 16),
2104
+ 512: (8, 32),
2105
+ 64: (1, 32),
2106
+ 8192: (16, 32),
2107
+ },
2108
+ 'q_head-128_kv_head-4_head-256': {
2109
+ 1024: (8, 8),
2110
+ 128: (2, 8),
2111
+ 2048: (16, 8),
2112
+ 256: (4, 8),
2113
+ 4096: (32, 32),
2114
+ 64: (1, 8),
2115
+ 8192: (32, 32),
2116
+ },
2117
+ 'q_head-128_kv_head-8_head-128': {
2118
+ 1024: (8, 16),
2119
+ 4096: (8, 16),
2120
+ 64: (1, 8),
2121
+ 8192: (8, 32),
2122
+ },
2123
+ 'q_head-128_kv_head-8_head-256': {
2124
+ 128: (2, 8),
2125
+ 256: (4, 8),
2126
+ 4096: (16, 16),
2127
+ 64: (1, 8),
2128
+ 8192: (8, 16),
2129
+ },
2130
+ 'q_head-16_kv_head-1_head-128': {
2131
+ 1024: (16, 8),
2132
+ 128: (2, 16),
2133
+ 2048: (16, 64),
2134
+ 256: (4, 8),
2135
+ 4096: (32, 64),
2136
+ 512: (8, 16),
2137
+ 64: (1, 128),
2138
+ 8192: (32, 128),
2139
+ },
2140
+ 'q_head-16_kv_head-1_head-256': {
2141
+ 1024: (8, 16),
2142
+ 128: (2, 32),
2143
+ 2048: (32, 8),
2144
+ 256: (4, 64),
2145
+ 4096: (32, 16),
2146
+ 512: (8, 8),
2147
+ 64: (1, 16),
2148
+ 8192: (32, 16),
2149
+ },
2150
+ 'q_head-16_kv_head-2_head-128': {
2151
+ 1024: (16, 16),
2152
+ 128: (2, 64),
2153
+ 2048: (16, 16),
2154
+ 256: (4, 128),
2155
+ 4096: (32, 32),
2156
+ 512: (8, 64),
2157
+ 64: (1, 16),
2158
+ 8192: (32, 64),
2159
+ },
2160
+ 'q_head-16_kv_head-2_head-256': {
2161
+ 1024: (16, 16),
2162
+ 128: (2, 8),
2163
+ 2048: (16, 32),
2164
+ 256: (4, 8),
2165
+ 4096: (8, 32),
2166
+ 512: (8, 16),
2167
+ 64: (1, 8),
2168
+ 8192: (32, 32),
2169
+ },
2170
+ 'q_head-16_kv_head-4_head-128': {
2171
+ 1024: (8, 64),
2172
+ 128: (2, 32),
2173
+ 2048: (16, 32),
2174
+ 256: (4, 128),
2175
+ 4096: (16, 32),
2176
+ 512: (4, 128),
2177
+ 64: (1, 16),
2178
+ 8192: (16, 128),
2179
+ },
2180
+ 'q_head-16_kv_head-4_head-256': {
2181
+ 1024: (16, 32),
2182
+ 128: (2, 32),
2183
+ 2048: (16, 128),
2184
+ 256: (4, 32),
2185
+ 4096: (16, 128),
2186
+ 512: (4, 32),
2187
+ 64: (1, 8),
2188
+ 8192: (16, 32),
2189
+ },
2190
+ 'q_head-16_kv_head-8_head-128': {
2191
+ 1024: (8, 64),
2192
+ 128: (2, 32),
2193
+ 2048: (8, 64),
2194
+ 256: (4, 64),
2195
+ 4096: (32, 64),
2196
+ 512: (8, 8),
2197
+ 64: (1, 16),
2198
+ 8192: (8, 128),
2199
+ },
2200
+ 'q_head-16_kv_head-8_head-256': {
2201
+ 1024: (8, 128),
2202
+ 128: (2, 8),
2203
+ 2048: (8, 64),
2204
+ 256: (4, 32),
2205
+ 4096: (8, 128),
2206
+ 512: (8, 64),
2207
+ 64: (1, 8),
2208
+ 8192: (8, 128),
2209
+ },
2210
+ 'q_head-2_kv_head-1_head-128': {
2211
+ 1024: (16, 256),
2212
+ 128: (1, 8),
2213
+ 2048: (32, 32),
2214
+ 256: (4, 16),
2215
+ 4096: (32, 64),
2216
+ 512: (8, 256),
2217
+ 64: (1, 256),
2218
+ 8192: (32, 128),
2219
+ },
2220
+ 'q_head-2_kv_head-1_head-256': {
2221
+ 1024: (8, 64),
2222
+ 2048: (16, 64),
2223
+ 256: (2, 32),
2224
+ 4096: (32, 128),
2225
+ 512: (8, 32),
2226
+ 8192: (32, 64),
2227
+ },
2228
+ 'q_head-32_kv_head-1_head-128': {
2229
+ 1024: (16, 16),
2230
+ 128: (2, 16),
2231
+ 2048: (16, 16),
2232
+ 256: (4, 8),
2233
+ 4096: (32, 16),
2234
+ 512: (8, 16),
2235
+ 64: (1, 32),
2236
+ 8192: (32, 32),
2237
+ },
2238
+ 'q_head-32_kv_head-1_head-256': {
2239
+ 1024: (8, 16),
2240
+ 128: (2, 16),
2241
+ 2048: (16, 8),
2242
+ 256: (4, 16),
2243
+ 4096: (32, 32),
2244
+ 512: (8, 16),
2245
+ 64: (1, 16),
2246
+ 8192: (32, 16),
2247
+ },
2248
+ 'q_head-32_kv_head-16_head-128': {
2249
+ 1024: (16, 64),
2250
+ 128: (2, 64),
2251
+ 2048: (16, 64),
2252
+ 256: (2, 32),
2253
+ 4096: (16, 64),
2254
+ 512: (8, 32),
2255
+ 64: (1, 8),
2256
+ 8192: (16, 64),
2257
+ },
2258
+ 'q_head-32_kv_head-16_head-256': {
2259
+ 1024: (8, 32),
2260
+ 128: (2, 8),
2261
+ 2048: (8, 32),
2262
+ 256: (4, 8),
2263
+ 4096: (8, 32),
2264
+ 512: (8, 32),
2265
+ 64: (1, 16),
2266
+ 8192: (4, 32),
2267
+ },
2268
+ 'q_head-32_kv_head-2_head-128': {
2269
+ 1024: (16, 16),
2270
+ 128: (2, 32),
2271
+ 2048: (16, 16),
2272
+ 256: (4, 8),
2273
+ 4096: (32, 64),
2274
+ 512: (8, 32),
2275
+ 64: (1, 8),
2276
+ 8192: (32, 64),
2277
+ },
2278
+ 'q_head-32_kv_head-2_head-256': {
2279
+ 1024: (16, 32),
2280
+ 128: (2, 8),
2281
+ 2048: (32, 32),
2282
+ 256: (4, 8),
2283
+ 4096: (16, 32),
2284
+ 512: (8, 32),
2285
+ 64: (1, 8),
2286
+ 8192: (32, 32),
2287
+ },
2288
+ 'q_head-32_kv_head-4_head-128': {
2289
+ 1024: (8, 32),
2290
+ 128: (1, 64),
2291
+ 2048: (32, 16),
2292
+ 256: (4, 32),
2293
+ 4096: (16, 16),
2294
+ 512: (8, 16),
2295
+ 64: (1, 8),
2296
+ 8192: (16, 32),
2297
+ },
2298
+ 'q_head-32_kv_head-4_head-256': {
2299
+ 1024: (8, 32),
2300
+ 128: (2, 16),
2301
+ 2048: (16, 32),
2302
+ 256: (4, 16),
2303
+ 4096: (16, 32),
2304
+ 512: (4, 16),
2305
+ 64: (1, 16),
2306
+ 8192: (16, 32),
2307
+ },
2308
+ 'q_head-32_kv_head-8_head-128': {
2309
+ 1024: (16, 32),
2310
+ 128: (2, 16),
2311
+ 2048: (16, 32),
2312
+ 256: (2, 16),
2313
+ 4096: (32, 32),
2314
+ 512: (8, 32),
2315
+ 64: (1, 16),
2316
+ 8192: (32, 32),
2317
+ },
2318
+ 'q_head-32_kv_head-8_head-256': {
2319
+ 1024: (8, 32),
2320
+ 128: (2, 16),
2321
+ 2048: (8, 64),
2322
+ 256: (4, 16),
2323
+ 4096: (16, 64),
2324
+ 512: (8, 32),
2325
+ 64: (1, 16),
2326
+ 8192: (8, 64),
2327
+ },
2328
+ 'q_head-4_kv_head-1_head-128': {
2329
+ 1024: (16, 32),
2330
+ 128: (2, 16),
2331
+ 2048: (32, 128),
2332
+ 256: (4, 8),
2333
+ 4096: (32, 16),
2334
+ 512: (4, 32),
2335
+ 64: (1, 32),
2336
+ 8192: (32, 128),
2337
+ },
2338
+ 'q_head-4_kv_head-1_head-256': {
2339
+ 1024: (16, 128),
2340
+ 128: (1, 32),
2341
+ 2048: (32, 32),
2342
+ 256: (4, 32),
2343
+ 4096: (32, 64),
2344
+ 512: (8, 64),
2345
+ 64: (1, 128),
2346
+ 8192: (32, 64),
2347
+ },
2348
+ 'q_head-4_kv_head-2_head-128': {
2349
+ 1024: (16, 256),
2350
+ 128: (2, 256),
2351
+ 2048: (32, 32),
2352
+ 256: (4, 8),
2353
+ 4096: (32, 64),
2354
+ 512: (8, 32),
2355
+ 64: (1, 32),
2356
+ 8192: (32, 64),
2357
+ },
2358
+ 'q_head-4_kv_head-2_head-256': {
2359
+ 1024: (8, 64),
2360
+ 128: (2, 32),
2361
+ 2048: (32, 128),
2362
+ 256: (4, 8),
2363
+ 4096: (32, 128),
2364
+ 512: (8, 16),
2365
+ 64: (1, 16),
2366
+ 8192: (16, 128),
2367
+ },
2368
+ 'q_head-64_kv_head-1_head-128': {
2369
+ 1024: (16, 16),
2370
+ 128: (2, 16),
2371
+ 2048: (32, 16),
2372
+ 256: (4, 8),
2373
+ 4096: (32, 16),
2374
+ 512: (8, 8),
2375
+ 64: (1, 16),
2376
+ },
2377
+ 'q_head-64_kv_head-1_head-256': {
2378
+ 1024: (16, 16),
2379
+ 128: (2, 16),
2380
+ 2048: (32, 8),
2381
+ 256: (2, 8),
2382
+ 4096: (32, 8),
2383
+ 512: (8, 8),
2384
+ 64: (1, 8),
2385
+ },
2386
+ 'q_head-64_kv_head-16_head-128': {
2387
+ 1024: (16, 32),
2388
+ 128: (2, 16),
2389
+ 256: (4, 16),
2390
+ 4096: (8, 32),
2391
+ 512: (8, 16),
2392
+ 64: (1, 16),
2393
+ 8192: (16, 32),
2394
+ },
2395
+ 'q_head-64_kv_head-16_head-256': {
2396
+ 1024: (4, 16),
2397
+ 128: (2, 16),
2398
+ 2048: (8, 16),
2399
+ 256: (4, 16),
2400
+ 4096: (8, 16),
2401
+ 512: (8, 16),
2402
+ 64: (1, 16),
2403
+ 8192: (8, 16),
2404
+ },
2405
+ 'q_head-64_kv_head-2_head-128': {
2406
+ 1024: (16, 16),
2407
+ 128: (2, 32),
2408
+ 2048: (32, 32),
2409
+ 256: (4, 16),
2410
+ 4096: (32, 16),
2411
+ 512: (8, 64),
2412
+ 64: (1, 32),
2413
+ },
2414
+ 'q_head-64_kv_head-2_head-256': {
2415
+ 1024: (16, 16),
2416
+ 128: (2, 16),
2417
+ 2048: (32, 16),
2418
+ 256: (4, 8),
2419
+ 4096: (16, 16),
2420
+ 512: (8, 8),
2421
+ 64: (1, 8),
2422
+ 8192: (32, 16),
2423
+ },
2424
+ 'q_head-64_kv_head-4_head-128': {
2425
+ 1024: (8, 16),
2426
+ 128: (1, 8),
2427
+ 2048: (16, 32),
2428
+ 256: (4, 8),
2429
+ 4096: (16, 16),
2430
+ 512: (8, 64),
2431
+ 64: (1, 8),
2432
+ 8192: (16, 32),
2433
+ },
2434
+ 'q_head-64_kv_head-4_head-256': {
2435
+ 1024: (16, 16),
2436
+ 2048: (16, 32),
2437
+ 256: (4, 8),
2438
+ 4096: (16, 16),
2439
+ 64: (1, 8),
2440
+ 8192: (16, 32),
2441
+ },
2442
+ 'q_head-64_kv_head-8_head-128': {
2443
+ 1024: (16, 64),
2444
+ 128: (2, 16),
2445
+ 2048: (16, 32),
2446
+ 256: (4, 16),
2447
+ 4096: (16, 64),
2448
+ 64: (1, 32),
2449
+ 8192: (16, 32),
2450
+ },
2451
+ 'q_head-64_kv_head-8_head-256': {
2452
+ 1024: (8, 32),
2453
+ 128: (2, 8),
2454
+ 2048: (16, 32),
2455
+ 256: (4, 16),
2456
+ 4096: (16, 32),
2457
+ 512: (8, 32),
2458
+ 64: (1, 8),
2459
+ 8192: (16, 32),
2460
+ },
2461
+ 'q_head-8_kv_head-1_head-128': {
2462
+ 1024: (16, 64),
2463
+ 128: (2, 64),
2464
+ 2048: (32, 32),
2465
+ 256: (4, 128),
2466
+ 4096: (32, 32),
2467
+ 512: (8, 8),
2468
+ 64: (1, 128),
2469
+ 8192: (32, 32),
2470
+ },
2471
+ 'q_head-8_kv_head-1_head-256': {
2472
+ 1024: (16, 64),
2473
+ 128: (2, 32),
2474
+ 2048: (32, 32),
2475
+ 256: (4, 16),
2476
+ 4096: (32, 64),
2477
+ 512: (8, 8),
2478
+ 64: (1, 32),
2479
+ 8192: (32, 32),
2480
+ },
2481
+ 'q_head-8_kv_head-2_head-128': {
2482
+ 1024: (16, 64),
2483
+ 128: (2, 64),
2484
+ 2048: (32, 32),
2485
+ 256: (4, 128),
2486
+ 4096: (32, 32),
2487
+ 512: (8, 128),
2488
+ 64: (1, 16),
2489
+ 8192: (32, 32),
2490
+ },
2491
+ 'q_head-8_kv_head-2_head-256': {
2492
+ 1024: (16, 128),
2493
+ 128: (2, 64),
2494
+ 2048: (32, 32),
2495
+ 256: (4, 8),
2496
+ 4096: (16, 32),
2497
+ 512: (8, 64),
2498
+ 64: (1, 16),
2499
+ 8192: (32, 128),
2500
+ },
2501
+ 'q_head-8_kv_head-4_head-128': {
2502
+ 1024: (16, 32),
2503
+ 128: (2, 32),
2504
+ 2048: (32, 64),
2505
+ 256: (4, 32),
2506
+ 4096: (16, 64),
2507
+ 512: (8, 64),
2508
+ 64: (1, 16),
2509
+ 8192: (16, 64),
2510
+ },
2511
+ 'q_head-8_kv_head-4_head-256': {
2512
+ 1024: (8, 32),
2513
+ 128: (2, 32),
2514
+ 2048: (8, 128),
2515
+ 256: (4, 64),
2516
+ 4096: (8, 128),
2517
+ 512: (8, 128),
2518
+ 64: (1, 64),
2519
+ 8192: (8, 128),
2520
+ },
2521
+ }
2522
+ },
2523
+ },
2524
+ 'TPU v7': {
2525
+ 256: {
2526
+ 'q_bfloat16_kv_bfloat16': {
2527
+ 'q_head-8_kv_head-4_head-256': {
2528
+ 2048: (8, 64),
2529
+ 4096: (16, 64),
2530
+ 8192: (16, 64),
2531
+ 256: (1, 64),
2532
+ 512: (2, 64),
2533
+ 1024: (4, 32),
2534
+ },
2535
+ 'q_head-16_kv_head-4_head-128': {
2536
+ 256: (1, 8),
2537
+ 512: (2, 128),
2538
+ 1024: (4, 16),
2539
+ 2048: (8, 16),
2540
+ 4096: (16, 8),
2541
+ 8192: (16, 16),
2542
+ },
2543
+ 'q_head-32_kv_head-16_head-256': {
2544
+ 4096: (2, 16),
2545
+ 8192: (2, 16),
2546
+ 256: (1, 16),
2547
+ 512: (2, 8),
2548
+ 1024: (2, 16),
2549
+ 2048: (2, 16),
2550
+ },
2551
+ 'q_head-32_kv_head-2_head-256': {
2552
+ 1024: (4, 8),
2553
+ 2048: (8, 8),
2554
+ 4096: (16, 8),
2555
+ 8192: (16, 32),
2556
+ 256: (1, 8),
2557
+ 512: (2, 8),
2558
+ },
2559
+ 'q_head-64_kv_head-2_head-128': {
2560
+ 4096: (16, 16),
2561
+ 8192: (16, 16),
2562
+ 256: (1, 8),
2563
+ 512: (2, 16),
2564
+ 1024: (4, 16),
2565
+ 2048: (8, 16),
2566
+ },
2567
+ 'q_head-64_kv_head-16_head-128': {
2568
+ 256: (1, 8),
2569
+ 512: (2, 16),
2570
+ 1024: (4, 16),
2571
+ 2048: (2, 16),
2572
+ 4096: (4, 16),
2573
+ 8192: (4, 16),
2574
+ },
2575
+ 'q_head-128_kv_head-8_head-256': {
2576
+ 1024: (4, 8),
2577
+ 2048: (4, 8),
2578
+ 4096: (8, 8),
2579
+ 8192: (8, 8),
2580
+ 256: (1, 8),
2581
+ 512: (2, 8),
2582
+ },
2583
+ 'q_head-4_kv_head-2_head-128': {
2584
+ 2048: (8, 16),
2585
+ 4096: (16, 32),
2586
+ 8192: (16, 64),
2587
+ 256: (1, 32),
2588
+ 512: (2, 32),
2589
+ 1024: (4, 128),
2590
+ },
2591
+ 'q_head-4_kv_head-1_head-256': {
2592
+ 8192: (16, 64),
2593
+ 256: (1, 16),
2594
+ 512: (2, 128),
2595
+ 1024: (4, 16),
2596
+ 2048: (8, 8),
2597
+ 4096: (16, 16),
2598
+ },
2599
+ 'q_head-128_kv_head-2_head-128': {
2600
+ 256: (1, 8),
2601
+ 512: (2, 8),
2602
+ 1024: (4, 16),
2603
+ 2048: (8, 16),
2604
+ 4096: (8, 8),
2605
+ 8192: (8, 16),
2606
+ },
2607
+ 'q_head-64_kv_head-2_head-256': {
2608
+ 256: (1, 8),
2609
+ 512: (2, 8),
2610
+ 1024: (4, 16),
2611
+ 2048: (8, 8),
2612
+ 4096: (16, 16),
2613
+ 8192: (16, 16),
2614
+ },
2615
+ 'q_head-128_kv_head-16_head-128': {
2616
+ 256: (1, 8),
2617
+ 512: (2, 8),
2618
+ 1024: (4, 8),
2619
+ 2048: (4, 8),
2620
+ 4096: (4, 8),
2621
+ 8192: (4, 8),
2622
+ },
2623
+ 'q_head-4_kv_head-2_head-256': {
2624
+ 256: (1, 128),
2625
+ 512: (2, 128),
2626
+ 1024: (4, 64),
2627
+ 2048: (8, 32),
2628
+ 4096: (16, 32),
2629
+ 8192: (16, 16),
2630
+ },
2631
+ 'q_head-32_kv_head-4_head-128': {
2632
+ 256: (1, 16),
2633
+ 512: (2, 8),
2634
+ 1024: (4, 8),
2635
+ 2048: (8, 64),
2636
+ 4096: (8, 16),
2637
+ 8192: (16, 16),
2638
+ },
2639
+ 'q_head-8_kv_head-1_head-128': {
2640
+ 256: (1, 256),
2641
+ 512: (2, 128),
2642
+ 1024: (4, 128),
2643
+ 2048: (8, 32),
2644
+ 4096: (16, 32),
2645
+ 8192: (16, 128),
2646
+ },
2647
+ 'q_head-64_kv_head-16_head-256': {
2648
+ 256: (1, 8),
2649
+ 512: (2, 8),
2650
+ 1024: (2, 8),
2651
+ 2048: (2, 8),
2652
+ 4096: (2, 8),
2653
+ 8192: (2, 8),
2654
+ },
2655
+ 'q_head-16_kv_head-4_head-256': {
2656
+ 256: (1, 16),
2657
+ 512: (2, 16),
2658
+ 1024: (4, 64),
2659
+ 2048: (8, 16),
2660
+ 4096: (16, 32),
2661
+ 8192: (16, 64),
2662
+ },
2663
+ 'q_head-16_kv_head-2_head-256': {
2664
+ 256: (1, 8),
2665
+ 512: (2, 16),
2666
+ 1024: (4, 16),
2667
+ 2048: (8, 8),
2668
+ 4096: (16, 64),
2669
+ 8192: (16, 32),
2670
+ },
2671
+ 'q_head-32_kv_head-16_head-128': {
2672
+ 4096: (4, 32),
2673
+ 8192: (4, 32),
2674
+ 256: (1, 8),
2675
+ 512: (2, 16),
2676
+ 1024: (4, 32),
2677
+ 2048: (4, 32),
2678
+ },
2679
+ 'q_head-32_kv_head-2_head-128': {
2680
+ 1024: (4, 8),
2681
+ 2048: (8, 8),
2682
+ 256: (1, 64),
2683
+ 4096: (16, 32),
2684
+ 512: (2, 8),
2685
+ 8192: (16, 32),
2686
+ },
2687
+ 'q_head-64_kv_head-1_head-256': {
2688
+ 4096: (8, 16),
2689
+ 8192: (16, 8),
2690
+ 256: (1, 16),
2691
+ 512: (2, 8),
2692
+ 1024: (4, 8),
2693
+ 2048: (8, 8),
2694
+ },
2695
+ 'q_head-64_kv_head-8_head-256': {
2696
+ 256: (1, 16),
2697
+ 512: (2, 16),
2698
+ 1024: (4, 16),
2699
+ 2048: (8, 16),
2700
+ 4096: (8, 16),
2701
+ 8192: (8, 16),
2702
+ },
2703
+ 'q_head-128_kv_head-8_head-128': {
2704
+ 2048: (8, 16),
2705
+ 4096: (8, 16),
2706
+ 8192: (8, 16),
2707
+ 256: (1, 8),
2708
+ 512: (2, 16),
2709
+ 1024: (4, 8),
2710
+ },
2711
+ 'q_head-2_kv_head-1_head-256': {
2712
+ 2048: (8, 32),
2713
+ 4096: (16, 32),
2714
+ 8192: (16, 32),
2715
+ 256: (1, 128),
2716
+ 512: (2, 8),
2717
+ 1024: (4, 64),
2718
+ },
2719
+ 'q_head-4_kv_head-1_head-128': {
2720
+ 8192: (16, 16),
2721
+ 256: (1, 16),
2722
+ 512: (2, 16),
2723
+ 1024: (4, 128),
2724
+ 2048: (8, 8),
2725
+ 4096: (16, 16),
2726
+ },
2727
+ 'q_head-64_kv_head-32_head-128': {
2728
+ 256: (1, 8),
2729
+ 512: (2, 8),
2730
+ 1024: (2, 8),
2731
+ 2048: (2, 8),
2732
+ 4096: (2, 8),
2733
+ 8192: (4, 8),
2734
+ },
2735
+ 'q_head-128_kv_head-2_head-256': {
2736
+ 256: (1, 8),
2737
+ 512: (2, 8),
2738
+ 1024: (4, 8),
2739
+ 2048: (8, 8),
2740
+ 4096: (16, 8),
2741
+ 8192: (8, 16),
2742
+ },
2743
+ 'q_head-16_kv_head-8_head-128': {
2744
+ 256: (1, 32),
2745
+ 512: (2, 16),
2746
+ 1024: (4, 32),
2747
+ 2048: (8, 32),
2748
+ 4096: (16, 64),
2749
+ 8192: (16, 64),
2750
+ },
2751
+ 'q_head-64_kv_head-4_head-128': {
2752
+ 256: (1, 32),
2753
+ 512: (2, 16),
2754
+ 1024: (4, 8),
2755
+ 2048: (8, 16),
2756
+ 4096: (16, 8),
2757
+ 8192: (16, 16),
2758
+ },
2759
+ 'q_head-16_kv_head-1_head-128': {
2760
+ 256: (1, 16),
2761
+ 512: (2, 8),
2762
+ 1024: (4, 64),
2763
+ 2048: (8, 16),
2764
+ 4096: (16, 128),
2765
+ 8192: (16, 16),
2766
+ },
2767
+ 'q_head-32_kv_head-4_head-256': {
2768
+ 256: (1, 16),
2769
+ 512: (2, 16),
2770
+ 1024: (4, 64),
2771
+ 2048: (8, 32),
2772
+ 4096: (16, 16),
2773
+ 8192: (16, 32),
2774
+ },
2775
+ 'q_head-8_kv_head-1_head-256': {
2776
+ 256: (1, 128),
2777
+ 512: (2, 128),
2778
+ 1024: (4, 16),
2779
+ 2048: (8, 8),
2780
+ 4096: (8, 16),
2781
+ 8192: (16, 32),
2782
+ },
2783
+ 'q_head-128_kv_head-4_head-128': {
2784
+ 256: (1, 16),
2785
+ 512: (2, 16),
2786
+ 1024: (4, 8),
2787
+ 2048: (8, 16),
2788
+ 4096: (16, 8),
2789
+ 8192: (8, 16),
2790
+ },
2791
+ 'q_head-16_kv_head-8_head-256': {
2792
+ 256: (1, 16),
2793
+ 512: (2, 16),
2794
+ 1024: (4, 32),
2795
+ 2048: (8, 32),
2796
+ 4096: (8, 32),
2797
+ 8192: (8, 32),
2798
+ },
2799
+ 'q_head-8_kv_head-4_head-128': {
2800
+ 256: (1, 128),
2801
+ 512: (2, 32),
2802
+ 1024: (4, 16),
2803
+ 2048: (8, 64),
2804
+ 4096: (16, 256),
2805
+ 8192: (16, 64),
2806
+ },
2807
+ 'q_head-32_kv_head-1_head-128': {
2808
+ 256: (1, 16),
2809
+ 512: (2, 8),
2810
+ 1024: (4, 64),
2811
+ 2048: (8, 16),
2812
+ 4096: (16, 16),
2813
+ 8192: (16, 8),
2814
+ },
2815
+ 'q_head-64_kv_head-4_head-256': {
2816
+ 256: (1, 8),
2817
+ 512: (2, 32),
2818
+ 1024: (4, 16),
2819
+ 2048: (8, 16),
2820
+ 4096: (8, 32),
2821
+ 8192: (16, 16),
2822
+ },
2823
+ 'q_head-2_kv_head-1_head-128': {
2824
+ 256: (1, 256),
2825
+ 512: (2, 256),
2826
+ 1024: (4, 128),
2827
+ 2048: (8, 128),
2828
+ 4096: (16, 128),
2829
+ 8192: (16, 32),
2830
+ },
2831
+ 'q_head-16_kv_head-1_head-256': {
2832
+ 256: (1, 32),
2833
+ 512: (2, 8),
2834
+ 1024: (4, 64),
2835
+ 2048: (8, 32),
2836
+ 4096: (16, 16),
2837
+ 8192: (16, 16),
2838
+ },
2839
+ 'q_head-32_kv_head-8_head-128': {
2840
+ 256: (1, 64),
2841
+ 512: (2, 32),
2842
+ 1024: (4, 16),
2843
+ 2048: (8, 16),
2844
+ 4096: (16, 32),
2845
+ 8192: (16, 16),
2846
+ },
2847
+ 'q_head-8_kv_head-2_head-128': {
2848
+ 256: (1, 128),
2849
+ 512: (2, 32),
2850
+ 1024: (4, 128),
2851
+ 2048: (8, 16),
2852
+ 4096: (16, 32),
2853
+ 8192: (16, 64),
2854
+ },
2855
+ 'q_head-64_kv_head-1_head-128': {
2856
+ 256: (1, 16),
2857
+ 512: (2, 32),
2858
+ 1024: (4, 8),
2859
+ 2048: (8, 16),
2860
+ 4096: (8, 8),
2861
+ 8192: (16, 32),
2862
+ },
2863
+ 'q_head-128_kv_head-4_head-256': {
2864
+ 256: (1, 16),
2865
+ 512: (2, 8),
2866
+ 1024: (4, 16),
2867
+ 2048: (8, 16),
2868
+ 4096: (8, 8),
2869
+ 8192: (8, 16),
2870
+ },
2871
+ 'q_head-32_kv_head-1_head-256': {
2872
+ 256: (1, 32),
2873
+ 512: (2, 32),
2874
+ 1024: (4, 16),
2875
+ 2048: (4, 32),
2876
+ 4096: (16, 8),
2877
+ 8192: (16, 16),
2878
+ },
2879
+ 'q_head-64_kv_head-8_head-128': {
2880
+ 256: (1, 16),
2881
+ 512: (2, 8),
2882
+ 1024: (4, 16),
2883
+ 2048: (8, 16),
2884
+ 4096: (16, 16),
2885
+ 8192: (16, 16),
2886
+ },
2887
+ 'q_head-16_kv_head-2_head-128': {
2888
+ 256: (1, 128),
2889
+ 512: (2, 32),
2890
+ 1024: (4, 32),
2891
+ 2048: (8, 8),
2892
+ 4096: (16, 8),
2893
+ 8192: (16, 32),
2894
+ },
2895
+ 'q_head-32_kv_head-8_head-256': {
2896
+ 256: (1, 16),
2897
+ 512: (2, 8),
2898
+ 1024: (4, 16),
2899
+ 2048: (8, 16),
2900
+ 4096: (8, 32),
2901
+ 8192: (8, 32),
2902
+ },
2903
+ 'q_head-8_kv_head-2_head-256': {
2904
+ 256: (1, 8),
2905
+ 512: (2, 128),
2906
+ 1024: (4, 8),
2907
+ 2048: (8, 32),
2908
+ 4096: (16, 16),
2909
+ 8192: (16, 128),
2910
+ },
2911
+ 'q_head-128_kv_head-1_head-128': {
2912
+ 256: (1, 8),
2913
+ 512: (2, 16),
2914
+ 1024: (4, 8),
2915
+ 2048: (8, 8),
2916
+ 4096: (8, 8),
2917
+ 8192: (8, 16),
2918
+ },
2919
+ 'q_head-128_kv_head-1_head-256': {
2920
+ 256: (1, 8),
2921
+ 512: (2, 8),
2922
+ 1024: (4, 8),
2923
+ 2048: (4, 8),
2924
+ 4096: (16, 8),
2925
+ 8192: (16, 8),
2926
+ },
2927
+ },
2928
+ 'q_bfloat16_kv_float8_e4m3fn': {
2929
+ 'q_head-16_kv_head-4_head-128': {
2930
+ 2048: (8, 16),
2931
+ 4096: (16, 64),
2932
+ 8192: (16, 16),
2933
+ 256: (1, 16),
2934
+ 512: (2, 32),
2935
+ 1024: (4, 8),
2936
+ },
2937
+ 'q_head-32_kv_head-2_head-256': {
2938
+ 8192: (16, 32),
2939
+ 256: (1, 16),
2940
+ 512: (2, 8),
2941
+ 1024: (4, 8),
2942
+ 2048: (8, 16),
2943
+ 4096: (16, 16),
2944
+ },
2945
+ 'q_head-32_kv_head-16_head-256': {
2946
+ 2048: (8, 8),
2947
+ 4096: (8, 16),
2948
+ 8192: (8, 16),
2949
+ 512: (2, 16),
2950
+ 1024: (4, 16),
2951
+ 256: (1, 16),
2952
+ },
2953
+ 'q_head-64_kv_head-16_head-128': {
2954
+ 8192: (16, 8),
2955
+ 256: (1, 16),
2956
+ 512: (2, 16),
2957
+ 1024: (4, 8),
2958
+ 2048: (8, 8),
2959
+ 4096: (8, 16),
2960
+ },
2961
+ 'q_head-128_kv_head-2_head-128': {
2962
+ 1024: (4, 16),
2963
+ 2048: (8, 8),
2964
+ 4096: (8, 16),
2965
+ 8192: (16, 8),
2966
+ 256: (1, 16),
2967
+ 512: (2, 8),
2968
+ },
2969
+ 'q_head-64_kv_head-2_head-256': {
2970
+ 256: (1, 8),
2971
+ 512: (2, 32),
2972
+ 1024: (2, 16),
2973
+ 2048: (8, 16),
2974
+ 4096: (16, 16),
2975
+ 8192: (16, 8),
2976
+ },
2977
+ 'q_head-128_kv_head-16_head-128': {
2978
+ 256: (1, 8),
2979
+ 512: (2, 8),
2980
+ 1024: (4, 8),
2981
+ 2048: (8, 8),
2982
+ 4096: (8, 8),
2983
+ 8192: (8, 8),
2984
+ },
2985
+ 'q_head-32_kv_head-4_head-128': {
2986
+ 256: (1, 32),
2987
+ 512: (2, 16),
2988
+ 1024: (4, 64),
2989
+ 2048: (8, 8),
2990
+ 4096: (16, 16),
2991
+ 8192: (16, 32),
2992
+ },
2993
+ 'q_head-64_kv_head-16_head-256': {
2994
+ 256: (1, 8),
2995
+ 512: (2, 8),
2996
+ 1024: (4, 8),
2997
+ 2048: (8, 8),
2998
+ 4096: (8, 8),
2999
+ 8192: (8, 8),
3000
+ },
3001
+ 'q_head-16_kv_head-4_head-256': {
3002
+ 256: (1, 64),
3003
+ 512: (2, 32),
3004
+ 1024: (4, 8),
3005
+ 2048: (8, 8),
3006
+ 4096: (16, 64),
3007
+ 8192: (16, 32),
3008
+ },
3009
+ 'q_head-64_kv_head-32_head-128': {
3010
+ 256: (1, 8),
3011
+ 512: (2, 8),
3012
+ 1024: (4, 8),
3013
+ 2048: (4, 8),
3014
+ 4096: (4, 8),
3015
+ 8192: (8, 8),
3016
+ },
3017
+ 'q_head-128_kv_head-2_head-256': {
3018
+ 256: (1, 32),
3019
+ 512: (2, 8),
3020
+ 1024: (4, 8),
3021
+ 2048: (8, 16),
3022
+ 4096: (8, 16),
3023
+ 8192: (8, 16),
3024
+ },
3025
+ 'q_head-16_kv_head-2_head-256': {
3026
+ 2048: (8, 8),
3027
+ 256: (1, 64),
3028
+ 4096: (16, 32),
3029
+ 512: (2, 16),
3030
+ 1024: (4, 32),
3031
+ 8192: (16, 32),
3032
+ },
3033
+ 'q_head-32_kv_head-2_head-128': {
3034
+ 8192: (16, 16),
3035
+ 256: (1, 16),
3036
+ 512: (2, 16),
3037
+ 1024: (4, 8),
3038
+ 2048: (8, 8),
3039
+ 4096: (8, 16),
3040
+ },
3041
+ 'q_head-32_kv_head-16_head-128': {
3042
+ 2048: (8, 32),
3043
+ 4096: (16, 16),
3044
+ 256: (1, 32),
3045
+ 512: (2, 16),
3046
+ 1024: (4, 16),
3047
+ 8192: (16, 16),
3048
+ },
3049
+ 'q_head-64_kv_head-2_head-128': {
3050
+ 256: (1, 8),
3051
+ 512: (2, 8),
3052
+ 1024: (4, 8),
3053
+ 2048: (4, 8),
3054
+ 4096: (16, 8),
3055
+ 8192: (16, 16),
3056
+ },
3057
+ 'q_head-64_kv_head-8_head-256': {
3058
+ 8192: (16, 16),
3059
+ 256: (1, 16),
3060
+ 512: (2, 8),
3061
+ 1024: (4, 16),
3062
+ 2048: (8, 8),
3063
+ 4096: (16, 8),
3064
+ },
3065
+ 'q_head-16_kv_head-2_head-128': {
3066
+ 256: (1, 8),
3067
+ 512: (2, 16),
3068
+ 1024: (4, 32),
3069
+ 2048: (8, 16),
3070
+ 4096: (16, 8),
3071
+ 8192: (16, 32),
3072
+ },
3073
+ 'q_head-64_kv_head-4_head-128': {
3074
+ 256: (1, 8),
3075
+ 512: (1, 16),
3076
+ 1024: (4, 8),
3077
+ 2048: (8, 16),
3078
+ 4096: (16, 16),
3079
+ 8192: (16, 16),
3080
+ },
3081
+ 'q_head-32_kv_head-4_head-256': {
3082
+ 256: (1, 32),
3083
+ 512: (2, 32),
3084
+ 1024: (4, 16),
3085
+ 2048: (8, 8),
3086
+ 4096: (16, 32),
3087
+ 8192: (16, 16),
3088
+ },
3089
+ 'q_head-16_kv_head-8_head-128': {
3090
+ 256: (1, 64),
3091
+ 512: (2, 32),
3092
+ 1024: (4, 128),
3093
+ 2048: (8, 128),
3094
+ 4096: (16, 16),
3095
+ 8192: (16, 32),
3096
+ },
3097
+ 'q_head-128_kv_head-4_head-128': {
3098
+ 256: (1, 8),
3099
+ 512: (2, 8),
3100
+ 1024: (4, 8),
3101
+ 2048: (8, 16),
3102
+ 4096: (16, 8),
3103
+ 8192: (16, 8),
3104
+ },
3105
+ 'q_head-64_kv_head-4_head-256': {
3106
+ 256: (1, 16),
3107
+ 512: (2, 16),
3108
+ 1024: (4, 16),
3109
+ 2048: (8, 8),
3110
+ 4096: (16, 16),
3111
+ 8192: (16, 16),
3112
+ },
3113
+ 'q_head-32_kv_head-8_head-128': {
3114
+ 256: (1, 32),
3115
+ 512: (2, 8),
3116
+ 1024: (4, 16),
3117
+ 2048: (8, 32),
3118
+ 4096: (16, 16),
3119
+ 8192: (16, 32),
3120
+ },
3121
+ 'q_head-16_kv_head-8_head-256': {
3122
+ 256: (1, 16),
3123
+ 512: (2, 64),
3124
+ 1024: (4, 32),
3125
+ 2048: (8, 32),
3126
+ 4096: (16, 32),
3127
+ 8192: (16, 32),
3128
+ },
3129
+ 'q_head-128_kv_head-4_head-256': {
3130
+ 256: (1, 8),
3131
+ 512: (2, 8),
3132
+ 1024: (4, 16),
3133
+ 2048: (8, 8),
3134
+ 4096: (8, 16),
3135
+ 8192: (16, 8),
3136
+ },
3137
+ 'q_head-64_kv_head-8_head-128': {
3138
+ 256: (1, 8),
3139
+ 512: (2, 16),
3140
+ 1024: (4, 8),
3141
+ 2048: (8, 32),
3142
+ 4096: (16, 16),
3143
+ 8192: (16, 16),
3144
+ },
3145
+ 'q_head-32_kv_head-8_head-256': {
3146
+ 256: (1, 16),
3147
+ 512: (2, 16),
3148
+ 1024: (4, 8),
3149
+ 2048: (8, 32),
3150
+ 4096: (16, 32),
3151
+ 8192: (16, 32),
3152
+ },
3153
+ 'q_head-128_kv_head-8_head-256': {
3154
+ 256: (1, 8),
3155
+ 512: (2, 8),
3156
+ 1024: (4, 8),
3157
+ 2048: (4, 8),
3158
+ 4096: (16, 8),
3159
+ 8192: (16, 8),
3160
+ },
3161
+ 'q_head-128_kv_head-8_head-128': {
3162
+ 256: (1, 8),
3163
+ 512: (2, 8),
3164
+ 1024: (4, 16),
3165
+ 2048: (8, 16),
3166
+ 4096: (16, 8),
3167
+ 8192: (16, 16),
3168
+ },
3169
+ 'q_head-4_kv_head-2_head-256': {
3170
+ 8192: (16, 32),
3171
+ 256: (1, 8),
3172
+ 512: (2, 32),
3173
+ 1024: (4, 16),
3174
+ 2048: (8, 16),
3175
+ 4096: (16, 16),
3176
+ },
3177
+ 'q_head-8_kv_head-2_head-128': {
3178
+ 1024: (4, 64),
3179
+ 2048: (8, 32),
3180
+ 4096: (16, 8),
3181
+ 8192: (16, 32),
3182
+ 256: (1, 128),
3183
+ 512: (2, 128),
3184
+ },
3185
+ 'q_head-8_kv_head-2_head-256': {
3186
+ 256: (1, 8),
3187
+ 512: (2, 16),
3188
+ 1024: (4, 8),
3189
+ 2048: (8, 16),
3190
+ 4096: (16, 64),
3191
+ 8192: (16, 32),
3192
+ },
3193
+ 'q_head-4_kv_head-2_head-128': {
3194
+ 8192: (16, 32),
3195
+ 256: (1, 64),
3196
+ 512: (2, 128),
3197
+ 1024: (4, 16),
3198
+ 2048: (8, 64),
3199
+ 4096: (16, 64),
3200
+ },
3201
+ 'q_head-2_kv_head-2_head-128': {
3202
+ 256: (1, 64),
3203
+ 512: (2, 128),
3204
+ 1024: (4, 256),
3205
+ 2048: (8, 32),
3206
+ 4096: (16, 32),
3207
+ 8192: (16, 32),
3208
+ },
3209
+ 'q_head-8_kv_head-4_head-128': {
3210
+ 256: (1, 32),
3211
+ 512: (2, 16),
3212
+ 1024: (4, 8),
3213
+ 2048: (8, 16),
3214
+ 4096: (16, 32),
3215
+ 8192: (16, 32),
3216
+ },
3217
+ 'q_head-8_kv_head-4_head-256': {
3218
+ 256: (1, 8),
3219
+ 512: (2, 32),
3220
+ 1024: (4, 32),
3221
+ 2048: (8, 32),
3222
+ 4096: (16, 16),
3223
+ 8192: (16, 32),
3224
+ },
3225
+ 'q_head-2_kv_head-2_head-256': {
3226
+ 256: (1, 128),
3227
+ 512: (2, 32),
3228
+ 1024: (4, 64),
3229
+ 2048: (8, 32),
3230
+ 4096: (16, 32),
3231
+ 8192: (16, 32),
3232
+ },
3233
+ },
3234
+ },
3235
+ 128: {
3236
+ 'q_bfloat16_kv_bfloat16': {
3237
+ 'q_head-4_kv_head-2_head-128': {
3238
+ 128: (1, 32),
3239
+ 256: (2, 256),
3240
+ 512: (4, 64),
3241
+ 8192: (32, 64),
3242
+ 1024: (8, 32),
3243
+ 2048: (16, 16),
3244
+ 4096: (32, 32),
3245
+ },
3246
+ 'q_head-2_kv_head-1_head-128': {
3247
+ 512: (4, 64),
3248
+ 2048: (16, 128),
3249
+ 256: (2, 256),
3250
+ 1024: (4, 32),
3251
+ 4096: (16, 128),
3252
+ 128: (1, 32),
3253
+ 8192: (32, 128),
3254
+ },
3255
+ 'q_head-16_kv_head-8_head-128': {
3256
+ 256: (2, 64),
3257
+ 512: (4, 16),
3258
+ 1024: (8, 16),
3259
+ 2048: (16, 64),
3260
+ 4096: (32, 32),
3261
+ 8192: (32, 32),
3262
+ 128: (1, 32),
3263
+ },
3264
+ 'q_head-32_kv_head-4_head-256': {
3265
+ 1024: (8, 16),
3266
+ 2048: (16, 16),
3267
+ 4096: (16, 64),
3268
+ 8192: (16, 32),
3269
+ 128: (1, 16),
3270
+ 256: (2, 32),
3271
+ 512: (4, 16),
3272
+ },
3273
+ 'q_head-64_kv_head-4_head-128': {
3274
+ 4096: (32, 16),
3275
+ 8192: (32, 16),
3276
+ 128: (1, 8),
3277
+ 256: (2, 8),
3278
+ 512: (4, 32),
3279
+ 1024: (8, 16),
3280
+ 2048: (16, 16),
3281
+ },
3282
+ 'q_head-16_kv_head-8_head-256': {
3283
+ 128: (1, 32),
3284
+ 256: (2, 32),
3285
+ 512: (4, 32),
3286
+ 1024: (8, 32),
3287
+ 2048: (16, 32),
3288
+ 4096: (16, 32),
3289
+ 8192: (16, 64),
3290
+ },
3291
+ 'q_head-16_kv_head-1_head-128': {
3292
+ 4096: (32, 128),
3293
+ 8192: (32, 16),
3294
+ 128: (1, 8),
3295
+ 256: (2, 256),
3296
+ 512: (4, 64),
3297
+ 1024: (8, 32),
3298
+ 2048: (8, 32),
3299
+ },
3300
+ 'q_head-64_kv_head-32_head-128': {
3301
+ 1024: (4, 8),
3302
+ 2048: (4, 8),
3303
+ 4096: (4, 8),
3304
+ 8192: (8, 8),
3305
+ 128: (1, 8),
3306
+ 256: (2, 8),
3307
+ 512: (4, 8),
3308
+ },
3309
+ 'q_head-128_kv_head-4_head-128': {
3310
+ 128: (1, 8),
3311
+ 256: (2, 8),
3312
+ 512: (2, 8),
3313
+ 1024: (4, 16),
3314
+ 2048: (8, 16),
3315
+ 4096: (16, 16),
3316
+ 8192: (32, 8),
3317
+ },
3318
+ 'q_head-8_kv_head-1_head-256': {
3319
+ 1024: (8, 8),
3320
+ 2048: (16, 64),
3321
+ 4096: (16, 32),
3322
+ 8192: (16, 32),
3323
+ 128: (1, 32),
3324
+ 256: (2, 16),
3325
+ 512: (4, 8),
3326
+ },
3327
+ 'q_head-32_kv_head-1_head-128': {
3328
+ 128: (1, 64),
3329
+ 256: (2, 16),
3330
+ 512: (4, 32),
3331
+ 1024: (8, 16),
3332
+ 2048: (16, 8),
3333
+ 4096: (32, 8),
3334
+ 8192: (32, 64),
3335
+ },
3336
+ 'q_head-64_kv_head-4_head-256': {
3337
+ 128: (1, 16),
3338
+ 256: (2, 32),
3339
+ 512: (4, 16),
3340
+ 1024: (8, 8),
3341
+ 2048: (16, 16),
3342
+ 4096: (16, 16),
3343
+ 8192: (16, 32),
3344
+ },
3345
+ 'q_head-2_kv_head-1_head-256': {
3346
+ 512: (4, 64),
3347
+ 4096: (32, 32),
3348
+ 256: (2, 64),
3349
+ 1024: (8, 16),
3350
+ 8192: (32, 128),
3351
+ 128: (1, 128),
3352
+ 2048: (16, 16),
3353
+ },
3354
+ 'q_head-16_kv_head-1_head-256': {
3355
+ 128: (1, 16),
3356
+ 256: (1, 64),
3357
+ 512: (4, 32),
3358
+ 1024: (4, 16),
3359
+ 2048: (16, 16),
3360
+ 4096: (32, 16),
3361
+ 8192: (32, 16),
3362
+ },
3363
+ 'q_head-32_kv_head-8_head-128': {
3364
+ 128: (1, 16),
3365
+ 256: (2, 64),
3366
+ 512: (4, 32),
3367
+ 1024: (8, 32),
3368
+ 2048: (16, 64),
3369
+ 4096: (32, 32),
3370
+ 8192: (32, 32),
3371
+ },
3372
+ 'q_head-8_kv_head-2_head-128': {
3373
+ 128: (1, 8),
3374
+ 256: (2, 128),
3375
+ 512: (4, 128),
3376
+ 1024: (8, 16),
3377
+ 2048: (16, 16),
3378
+ 4096: (32, 128),
3379
+ 8192: (32, 32),
3380
+ },
3381
+ 'q_head-64_kv_head-1_head-128': {
3382
+ 128: (1, 8),
3383
+ 256: (2, 32),
3384
+ 512: (4, 16),
3385
+ 1024: (8, 16),
3386
+ 2048: (16, 8),
3387
+ 4096: (16, 8),
3388
+ 8192: (32, 8),
3389
+ },
3390
+ 'q_head-4_kv_head-2_head-256': {
3391
+ 128: (1, 128),
3392
+ 1024: (8, 32),
3393
+ 256: (2, 16),
3394
+ 512: (4, 16),
3395
+ 2048: (16, 32),
3396
+ 4096: (32, 128),
3397
+ 8192: (32, 32),
3398
+ },
3399
+ 'q_head-8_kv_head-1_head-128': {
3400
+ 512: (4, 32),
3401
+ 1024: (8, 128),
3402
+ 2048: (16, 8),
3403
+ 128: (1, 16),
3404
+ 4096: (32, 128),
3405
+ 256: (2, 128),
3406
+ 8192: (32, 32),
3407
+ },
3408
+ 'q_head-16_kv_head-4_head-256': {
3409
+ 256: (2, 16),
3410
+ 512: (4, 32),
3411
+ 1024: (8, 16),
3412
+ 2048: (16, 16),
3413
+ 128: (1, 8),
3414
+ 4096: (32, 32),
3415
+ 8192: (32, 32),
3416
+ },
3417
+ 'q_head-32_kv_head-4_head-128': {
3418
+ 1024: (8, 8),
3419
+ 2048: (16, 16),
3420
+ 4096: (32, 16),
3421
+ 128: (1, 16),
3422
+ 256: (2, 64),
3423
+ 8192: (32, 32),
3424
+ 512: (4, 32),
3425
+ },
3426
+ 'q_head-64_kv_head-2_head-256': {
3427
+ 4096: (32, 8),
3428
+ 8192: (32, 8),
3429
+ 128: (1, 8),
3430
+ 256: (2, 16),
3431
+ 512: (4, 8),
3432
+ 1024: (8, 16),
3433
+ 2048: (16, 16),
3434
+ },
3435
+ 'q_head-64_kv_head-16_head-256': {
3436
+ 512: (4, 8),
3437
+ 1024: (4, 8),
3438
+ 2048: (4, 8),
3439
+ 128: (1, 8),
3440
+ 4096: (4, 8),
3441
+ 256: (2, 8),
3442
+ 8192: (4, 8),
3443
+ },
3444
+ 'q_head-128_kv_head-2_head-256': {
3445
+ 128: (1, 16),
3446
+ 256: (2, 8),
3447
+ 512: (4, 16),
3448
+ 1024: (8, 8),
3449
+ 2048: (16, 16),
3450
+ 4096: (16, 8),
3451
+ 8192: (16, 16),
3452
+ },
3453
+ 'q_head-128_kv_head-16_head-128': {
3454
+ 2048: (8, 8),
3455
+ 4096: (8, 8),
3456
+ 8192: (8, 8),
3457
+ 128: (1, 8),
3458
+ 256: (2, 8),
3459
+ 512: (4, 8),
3460
+ 1024: (8, 8),
3461
+ },
3462
+ 'q_head-32_kv_head-1_head-256': {
3463
+ 128: (1, 16),
3464
+ 256: (2, 32),
3465
+ 512: (4, 8),
3466
+ 1024: (8, 8),
3467
+ 2048: (16, 8),
3468
+ 4096: (16, 16),
3469
+ 8192: (32, 16),
3470
+ },
3471
+ 'q_head-128_kv_head-4_head-256': {
3472
+ 128: (1, 8),
3473
+ 256: (2, 8),
3474
+ 512: (4, 8),
3475
+ 1024: (8, 16),
3476
+ 2048: (8, 8),
3477
+ 4096: (16, 8),
3478
+ 8192: (16, 16),
3479
+ },
3480
+ 'q_head-64_kv_head-8_head-128': {
3481
+ 128: (1, 8),
3482
+ 256: (2, 16),
3483
+ 512: (4, 8),
3484
+ 1024: (8, 16),
3485
+ 2048: (16, 32),
3486
+ 4096: (16, 32),
3487
+ 8192: (16, 32),
3488
+ },
3489
+ 'q_head-16_kv_head-2_head-128': {
3490
+ 128: (1, 128),
3491
+ 256: (2, 32),
3492
+ 512: (4, 8),
3493
+ 1024: (8, 16),
3494
+ 2048: (16, 16),
3495
+ 4096: (32, 8),
3496
+ 8192: (32, 32),
3497
+ },
3498
+ 'q_head-32_kv_head-8_head-256': {
3499
+ 128: (1, 32),
3500
+ 256: (2, 32),
3501
+ 512: (4, 8),
3502
+ 1024: (8, 16),
3503
+ 2048: (16, 32),
3504
+ 4096: (16, 32),
3505
+ 8192: (16, 32),
3506
+ },
3507
+ 'q_head-128_kv_head-1_head-128': {
3508
+ 128: (1, 16),
3509
+ 256: (2, 16),
3510
+ 512: (4, 16),
3511
+ 1024: (8, 8),
3512
+ 2048: (16, 8),
3513
+ 4096: (16, 8),
3514
+ 8192: (16, 8),
3515
+ },
3516
+ 'q_head-8_kv_head-2_head-256': {
3517
+ 128: (1, 16),
3518
+ 256: (2, 32),
3519
+ 512: (4, 16),
3520
+ 1024: (8, 16),
3521
+ 2048: (16, 32),
3522
+ 4096: (32, 32),
3523
+ 8192: (32, 64),
3524
+ },
3525
+ 'q_head-32_kv_head-16_head-128': {
3526
+ 128: (1, 32),
3527
+ 256: (2, 32),
3528
+ 512: (4, 32),
3529
+ 1024: (8, 32),
3530
+ 2048: (8, 32),
3531
+ 4096: (8, 32),
3532
+ 8192: (8, 32),
3533
+ },
3534
+ 'q_head-64_kv_head-1_head-256': {
3535
+ 128: (1, 8),
3536
+ 256: (2, 16),
3537
+ 512: (4, 8),
3538
+ 1024: (8, 16),
3539
+ 2048: (16, 8),
3540
+ 4096: (16, 8),
3541
+ 8192: (32, 8),
3542
+ },
3543
+ 'q_head-8_kv_head-4_head-128': {
3544
+ 128: (1, 64),
3545
+ 256: (2, 128),
3546
+ 512: (4, 32),
3547
+ 1024: (8, 64),
3548
+ 2048: (16, 32),
3549
+ 4096: (32, 128),
3550
+ 8192: (32, 32),
3551
+ },
3552
+ 'q_head-32_kv_head-2_head-128': {
3553
+ 128: (1, 8),
3554
+ 256: (2, 16),
3555
+ 512: (4, 8),
3556
+ 1024: (8, 8),
3557
+ 2048: (16, 16),
3558
+ 4096: (16, 16),
3559
+ 8192: (32, 16),
3560
+ },
3561
+ 'q_head-128_kv_head-8_head-128': {
3562
+ 128: (1, 8),
3563
+ 256: (2, 16),
3564
+ 512: (4, 16),
3565
+ 1024: (8, 8),
3566
+ 2048: (16, 16),
3567
+ 4096: (32, 8),
3568
+ 8192: (16, 16),
3569
+ },
3570
+ 'q_head-4_kv_head-1_head-256': {
3571
+ 128: (1, 32),
3572
+ 256: (2, 16),
3573
+ 512: (4, 64),
3574
+ 1024: (8, 32),
3575
+ 2048: (16, 64),
3576
+ 4096: (16, 16),
3577
+ 8192: (32, 32),
3578
+ },
3579
+ 'q_head-64_kv_head-8_head-256': {
3580
+ 128: (1, 16),
3581
+ 256: (2, 8),
3582
+ 512: (4, 8),
3583
+ 1024: (8, 16),
3584
+ 2048: (8, 16),
3585
+ 4096: (16, 16),
3586
+ 8192: (16, 16),
3587
+ },
3588
+ 'q_head-16_kv_head-2_head-256': {
3589
+ 128: (1, 8),
3590
+ 256: (2, 8),
3591
+ 512: (4, 8),
3592
+ 1024: (8, 16),
3593
+ 2048: (16, 8),
3594
+ 4096: (32, 16),
3595
+ 8192: (32, 32),
3596
+ },
3597
+ 'q_head-128_kv_head-1_head-256': {
3598
+ 128: (1, 8),
3599
+ 256: (2, 8),
3600
+ 512: (4, 8),
3601
+ 1024: (8, 8),
3602
+ 2048: (8, 8),
3603
+ 4096: (16, 8),
3604
+ 8192: (16, 16),
3605
+ },
3606
+ 'q_head-32_kv_head-16_head-256': {
3607
+ 128: (1, 16),
3608
+ 256: (2, 16),
3609
+ 512: (4, 16),
3610
+ 1024: (4, 16),
3611
+ 2048: (4, 16),
3612
+ 4096: (4, 16),
3613
+ 8192: (4, 16),
3614
+ },
3615
+ 'q_head-64_kv_head-2_head-128': {
3616
+ 128: (1, 16),
3617
+ 256: (2, 8),
3618
+ 512: (4, 8),
3619
+ 1024: (8, 16),
3620
+ 2048: (16, 8),
3621
+ 4096: (16, 32),
3622
+ 8192: (32, 32),
3623
+ },
3624
+ 'q_head-8_kv_head-4_head-256': {
3625
+ 128: (1, 64),
3626
+ 256: (2, 64),
3627
+ 512: (4, 128),
3628
+ 1024: (8, 32),
3629
+ 2048: (16, 32),
3630
+ 4096: (32, 32),
3631
+ 8192: (32, 32),
3632
+ },
3633
+ 'q_head-128_kv_head-8_head-256': {
3634
+ 128: (1, 8),
3635
+ 256: (2, 8),
3636
+ 512: (4, 8),
3637
+ 1024: (8, 8),
3638
+ 2048: (8, 8),
3639
+ 4096: (16, 8),
3640
+ 8192: (16, 8),
3641
+ },
3642
+ 'q_head-32_kv_head-2_head-256': {
3643
+ 128: (1, 8),
3644
+ 256: (2, 8),
3645
+ 512: (4, 8),
3646
+ 1024: (8, 8),
3647
+ 2048: (16, 32),
3648
+ 4096: (32, 32),
3649
+ 8192: (32, 32),
3650
+ },
3651
+ 'q_head-64_kv_head-16_head-128': {
3652
+ 128: (1, 16),
3653
+ 256: (2, 16),
3654
+ 512: (4, 16),
3655
+ 1024: (8, 16),
3656
+ 2048: (8, 16),
3657
+ 4096: (8, 16),
3658
+ 8192: (8, 16),
3659
+ },
3660
+ 'q_head-16_kv_head-4_head-128': {
3661
+ 128: (1, 128),
3662
+ 256: (2, 64),
3663
+ 512: (4, 16),
3664
+ 1024: (8, 16),
3665
+ 2048: (8, 32),
3666
+ 4096: (32, 16),
3667
+ 8192: (32, 64),
3668
+ },
3669
+ 'q_head-4_kv_head-1_head-128': {
3670
+ 128: (1, 32),
3671
+ 256: (2, 256),
3672
+ 512: (4, 64),
3673
+ 1024: (8, 8),
3674
+ 2048: (8, 32),
3675
+ 4096: (16, 32),
3676
+ 8192: (32, 128),
3677
+ },
3678
+ 'q_head-128_kv_head-2_head-128': {
3679
+ 128: (1, 8),
3680
+ 256: (2, 8),
3681
+ 512: (4, 8),
3682
+ 1024: (8, 16),
3683
+ 2048: (16, 16),
3684
+ 4096: (16, 8),
3685
+ 8192: (16, 16),
3686
+ },
3687
+ },
3688
+ 'q_bfloat16_kv_float8_e4m3fn': {
3689
+ 'q_head-32_kv_head-2_head-256': {
3690
+ 256: (2, 8),
3691
+ 512: (4, 8),
3692
+ 1024: (8, 32),
3693
+ 2048: (16, 8),
3694
+ 4096: (32, 16),
3695
+ 8192: (32, 32),
3696
+ 128: (1, 16),
3697
+ },
3698
+ 'q_head-32_kv_head-8_head-128': {
3699
+ 8192: (32, 32),
3700
+ 128: (1, 64),
3701
+ 256: (2, 16),
3702
+ 512: (4, 16),
3703
+ 1024: (8, 16),
3704
+ 2048: (16, 16),
3705
+ 4096: (32, 32),
3706
+ },
3707
+ 'q_head-64_kv_head-2_head-128': {
3708
+ 2048: (16, 16),
3709
+ 4096: (16, 32),
3710
+ 8192: (32, 8),
3711
+ 128: (1, 32),
3712
+ 256: (1, 16),
3713
+ 512: (4, 8),
3714
+ 1024: (8, 32),
3715
+ },
3716
+ 'q_head-64_kv_head-8_head-128': {
3717
+ 128: (1, 8),
3718
+ 256: (2, 16),
3719
+ 512: (4, 16),
3720
+ 1024: (8, 16),
3721
+ 2048: (16, 16),
3722
+ 4096: (32, 16),
3723
+ 8192: (32, 16),
3724
+ },
3725
+ 'q_head-128_kv_head-4_head-256': {
3726
+ 512: (4, 16),
3727
+ 1024: (4, 8),
3728
+ 2048: (16, 16),
3729
+ 4096: (16, 8),
3730
+ 8192: (16, 16),
3731
+ 128: (1, 8),
3732
+ 256: (2, 16),
3733
+ },
3734
+ 'q_head-32_kv_head-8_head-256': {
3735
+ 128: (1, 32),
3736
+ 256: (2, 8),
3737
+ 512: (4, 16),
3738
+ 1024: (8, 16),
3739
+ 2048: (16, 32),
3740
+ 4096: (32, 16),
3741
+ 8192: (32, 16),
3742
+ },
3743
+ 'q_head-128_kv_head-2_head-128': {
3744
+ 128: (1, 8),
3745
+ 256: (2, 16),
3746
+ 512: (4, 8),
3747
+ 1024: (8, 32),
3748
+ 2048: (16, 8),
3749
+ 4096: (16, 16),
3750
+ 8192: (16, 8),
3751
+ },
3752
+ 'q_head-16_kv_head-2_head-128': {
3753
+ 256: (2, 8),
3754
+ 1024: (8, 64),
3755
+ 8192: (32, 32),
3756
+ 512: (4, 16),
3757
+ 2048: (16, 8),
3758
+ 128: (1, 32),
3759
+ 4096: (16, 64),
3760
+ },
3761
+ 'q_head-16_kv_head-2_head-256': {
3762
+ 2048: (16, 32),
3763
+ 512: (4, 8),
3764
+ 128: (1, 32),
3765
+ 1024: (8, 8),
3766
+ 4096: (32, 16),
3767
+ 256: (2, 16),
3768
+ 8192: (32, 32),
3769
+ },
3770
+ 'q_head-16_kv_head-4_head-128': {
3771
+ 4096: (32, 32),
3772
+ 256: (2, 64),
3773
+ 8192: (32, 64),
3774
+ 2048: (16, 16),
3775
+ 512: (4, 64),
3776
+ 1024: (8, 32),
3777
+ 128: (1, 8),
3778
+ },
3779
+ 'q_head-16_kv_head-4_head-256': {
3780
+ 8192: (32, 64),
3781
+ 512: (4, 32),
3782
+ 4096: (32, 64),
3783
+ 1024: (8, 64),
3784
+ 2048: (8, 16),
3785
+ 128: (1, 32),
3786
+ 256: (2, 64),
3787
+ },
3788
+ 'q_head-32_kv_head-16_head-128': {
3789
+ 128: (1, 32),
3790
+ 4096: (32, 16),
3791
+ 256: (2, 32),
3792
+ 512: (4, 16),
3793
+ 1024: (8, 32),
3794
+ 2048: (16, 16),
3795
+ 8192: (32, 16),
3796
+ },
3797
+ 'q_head-32_kv_head-16_head-256': {
3798
+ 256: (2, 16),
3799
+ 512: (4, 16),
3800
+ 128: (1, 16),
3801
+ 1024: (8, 16),
3802
+ 2048: (16, 16),
3803
+ 4096: (16, 16),
3804
+ 8192: (16, 16),
3805
+ },
3806
+ 'q_head-64_kv_head-2_head-256': {
3807
+ 128: (1, 8),
3808
+ 256: (2, 8),
3809
+ 512: (2, 16),
3810
+ 1024: (8, 16),
3811
+ 2048: (16, 16),
3812
+ 4096: (32, 8),
3813
+ 8192: (32, 16),
3814
+ },
3815
+ 'q_head-128_kv_head-8_head-128': {
3816
+ 128: (1, 8),
3817
+ 256: (2, 16),
3818
+ 512: (4, 16),
3819
+ 1024: (8, 16),
3820
+ 2048: (16, 16),
3821
+ 4096: (32, 8),
3822
+ 8192: (32, 16),
3823
+ },
3824
+ 'q_head-32_kv_head-2_head-128': {
3825
+ 128: (1, 64),
3826
+ 256: (2, 32),
3827
+ 512: (4, 32),
3828
+ 1024: (8, 16),
3829
+ 2048: (16, 128),
3830
+ 4096: (32, 32),
3831
+ 8192: (32, 16),
3832
+ },
3833
+ 'q_head-64_kv_head-8_head-256': {
3834
+ 128: (1, 16),
3835
+ 256: (2, 16),
3836
+ 512: (4, 16),
3837
+ 1024: (8, 16),
3838
+ 2048: (16, 16),
3839
+ 4096: (16, 16),
3840
+ 8192: (32, 16),
3841
+ },
3842
+ 'q_head-16_kv_head-8_head-128': {
3843
+ 1024: (8, 32),
3844
+ 8192: (32, 32),
3845
+ 2048: (16, 32),
3846
+ 128: (1, 128),
3847
+ 4096: (32, 32),
3848
+ 256: (2, 32),
3849
+ 512: (4, 128),
3850
+ },
3851
+ 'q_head-16_kv_head-8_head-256': {
3852
+ 2048: (16, 32),
3853
+ 128: (1, 32),
3854
+ 4096: (32, 16),
3855
+ 256: (2, 8),
3856
+ 8192: (32, 64),
3857
+ 512: (4, 16),
3858
+ 1024: (8, 8),
3859
+ },
3860
+ 'q_head-32_kv_head-4_head-256': {
3861
+ 8192: (32, 32),
3862
+ 128: (1, 32),
3863
+ 256: (2, 32),
3864
+ 512: (4, 8),
3865
+ 1024: (8, 32),
3866
+ 2048: (16, 32),
3867
+ 4096: (32, 32),
3868
+ },
3869
+ 'q_head-64_kv_head-4_head-256': {
3870
+ 128: (1, 32),
3871
+ 256: (2, 8),
3872
+ 512: (4, 16),
3873
+ 1024: (8, 8),
3874
+ 2048: (16, 32),
3875
+ 4096: (32, 16),
3876
+ 8192: (32, 8),
3877
+ },
3878
+ 'q_head-64_kv_head-32_head-128': {
3879
+ 8192: (16, 8),
3880
+ 128: (1, 8),
3881
+ 256: (2, 8),
3882
+ 512: (4, 8),
3883
+ 1024: (8, 8),
3884
+ 2048: (8, 8),
3885
+ 4096: (8, 8),
3886
+ },
3887
+ 'q_head-128_kv_head-4_head-128': {
3888
+ 512: (4, 32),
3889
+ 1024: (8, 8),
3890
+ 2048: (16, 16),
3891
+ 128: (1, 8),
3892
+ 4096: (32, 8),
3893
+ 256: (2, 32),
3894
+ 8192: (32, 8),
3895
+ },
3896
+ 'q_head-128_kv_head-2_head-256': {
3897
+ 128: (1, 16),
3898
+ 256: (2, 8),
3899
+ 512: (4, 8),
3900
+ 1024: (8, 16),
3901
+ 2048: (8, 8),
3902
+ 4096: (32, 8),
3903
+ 8192: (32, 8),
3904
+ },
3905
+ 'q_head-128_kv_head-8_head-256': {
3906
+ 128: (1, 8),
3907
+ 256: (2, 8),
3908
+ 512: (4, 8),
3909
+ 1024: (8, 8),
3910
+ 2048: (16, 8),
3911
+ 4096: (32, 8),
3912
+ 8192: (32, 8),
3913
+ },
3914
+ 'q_head-64_kv_head-16_head-128': {
3915
+ 128: (1, 16),
3916
+ 256: (2, 16),
3917
+ 512: (4, 16),
3918
+ 1024: (8, 8),
3919
+ 2048: (16, 16),
3920
+ 4096: (32, 8),
3921
+ 8192: (32, 8),
3922
+ },
3923
+ 'q_head-128_kv_head-16_head-128': {
3924
+ 128: (1, 8),
3925
+ 256: (2, 8),
3926
+ 512: (4, 8),
3927
+ 1024: (8, 8),
3928
+ 2048: (16, 8),
3929
+ 4096: (16, 8),
3930
+ 8192: (16, 8),
3931
+ },
3932
+ 'q_head-32_kv_head-4_head-128': {
3933
+ 128: (1, 32),
3934
+ 256: (2, 8),
3935
+ 512: (4, 64),
3936
+ 1024: (8, 32),
3937
+ 2048: (16, 16),
3938
+ 4096: (32, 64),
3939
+ 8192: (32, 32),
3940
+ },
3941
+ 'q_head-64_kv_head-16_head-256': {
3942
+ 128: (1, 8),
3943
+ 256: (2, 8),
3944
+ 512: (4, 8),
3945
+ 1024: (8, 8),
3946
+ 2048: (16, 8),
3947
+ 4096: (16, 8),
3948
+ 8192: (16, 8),
3949
+ },
3950
+ 'q_head-64_kv_head-4_head-128': {
3951
+ 128: (1, 32),
3952
+ 256: (2, 16),
3953
+ 512: (4, 8),
3954
+ 1024: (8, 16),
3955
+ 2048: (16, 8),
3956
+ 4096: (32, 16),
3957
+ 8192: (32, 16),
3958
+ },
3959
+ 'q_head-2_kv_head-2_head-256': {
3960
+ 256: (2, 16),
3961
+ 512: (4, 8),
3962
+ 1024: (8, 128),
3963
+ 2048: (16, 128),
3964
+ 4096: (32, 64),
3965
+ 8192: (32, 256),
3966
+ 128: (1, 8),
3967
+ },
3968
+ 'q_head-4_kv_head-2_head-128': {
3969
+ 2048: (16, 128),
3970
+ 4096: (32, 64),
3971
+ 8192: (32, 128),
3972
+ 128: (1, 256),
3973
+ 256: (2, 128),
3974
+ 512: (4, 64),
3975
+ 1024: (8, 32),
3976
+ },
3977
+ 'q_head-8_kv_head-4_head-256': {
3978
+ 512: (4, 32),
3979
+ 1024: (8, 16),
3980
+ 2048: (16, 16),
3981
+ 4096: (32, 32),
3982
+ 8192: (32, 64),
3983
+ 128: (1, 128),
3984
+ 256: (2, 16),
3985
+ },
3986
+ 'q_head-8_kv_head-2_head-128': {
3987
+ 128: (1, 128),
3988
+ 256: (2, 256),
3989
+ 512: (4, 128),
3990
+ 1024: (8, 8),
3991
+ 2048: (16, 128),
3992
+ 4096: (32, 16),
3993
+ 8192: (32, 32),
3994
+ },
3995
+ 'q_head-2_kv_head-2_head-128': {
3996
+ 256: (2, 128),
3997
+ 128: (1, 64),
3998
+ 8192: (32, 16),
3999
+ 512: (4, 128),
4000
+ 1024: (8, 32),
4001
+ 2048: (16, 32),
4002
+ 4096: (32, 32),
4003
+ },
4004
+ 'q_head-4_kv_head-2_head-256': {
4005
+ 128: (1, 16),
4006
+ 256: (2, 32),
4007
+ 512: (4, 256),
4008
+ 1024: (8, 64),
4009
+ 2048: (16, 16),
4010
+ 4096: (32, 16),
4011
+ 8192: (32, 32),
4012
+ },
4013
+ 'q_head-8_kv_head-4_head-128': {
4014
+ 512: (4, 16),
4015
+ 1024: (8, 16),
4016
+ 2048: (16, 16),
4017
+ 128: (1, 256),
4018
+ 256: (2, 16),
4019
+ 4096: (32, 128),
4020
+ 8192: (32, 128),
4021
+ },
4022
+ 'q_head-8_kv_head-2_head-256': {
4023
+ 128: (1, 32),
4024
+ 256: (2, 64),
4025
+ 512: (4, 16),
4026
+ 1024: (8, 16),
4027
+ 2048: (16, 16),
4028
+ 4096: (32, 128),
4029
+ 8192: (32, 32),
4030
+ },
4031
+ },
4032
+ },
4033
+ },
4034
+ }
4035
+
4036
+
4037
+ def get_tuned_block_sizes(
4038
+ q_dtype,
4039
+ kv_dtype,
4040
+ actual_num_q_heads,
4041
+ actual_num_kv_heads,
4042
+ head_dim,
4043
+ page_size,
4044
+ max_num_tokens,
4045
+ pages_per_seq,
4046
+ ) -> tuple[int, int]:
4047
+ """Search tuned values for (num_kv_pages_per_blk, num_queries_per_blk)."""
4048
+
4049
+ # Set default block sizes for each tpu_version.
4050
+ tpu_version = get_tpu_version()
4051
+ if tpu_version < 4:
4052
+ raise NotImplementedError('TPU version must be 4 or higher.')
4053
+ match tpu_version:
4054
+ case 4:
4055
+ # TPUv4 has much smaller VMEM size so we pick fixed block sizes.
4056
+ bkv_p, bq = (512 // page_size, 32)
4057
+ case 7:
4058
+ bkv_p, bq = (4096 // page_size, 32)
4059
+ case _:
4060
+ bkv_p, bq = (2048 // page_size, 32)
4061
+
4062
+ keys = get_lookup_keys(
4063
+ page_size,
4064
+ q_dtype,
4065
+ kv_dtype,
4066
+ actual_num_q_heads,
4067
+ actual_num_kv_heads,
4068
+ head_dim,
4069
+ page_size * pages_per_seq,
4070
+ )
4071
+ device, page_size, dtypes, head_dims, max_model_len = keys
4072
+
4073
+ try:
4074
+ bkv_p, bq = TUNED_BLOCK_SIZES[device][page_size][dtypes][head_dims][
4075
+ max_model_len]
4076
+ except KeyError:
4077
+ logger.warning_once(
4078
+ 'Couldn`t find tuned sizes for the RPA v3 kernel with %s', keys)
4079
+
4080
+ return (min(pages_per_seq, bkv_p), min(max_num_tokens, bq))
4081
+
4082
+
4083
def get_lookup_keys(
    page_size,
    q_dtype,
    kv_dtype,
    num_q_heads,
    num_kv_heads,
    head_dim,
    max_model_len,
):
    """Get the lookup keys for tuned block sizes.

    Normalizes the raw parameters with ``get_simplified_raw_key`` and then
    formats them into the five levels used to index ``TUNED_BLOCK_SIZES``.

    Returns:
        A tuple ``(device_name, page_size_key, dtype_key, heads_key,
        max_model_len_key)``, where ``dtype_key`` looks like
        ``'q_<q dtype>_kv_<kv dtype>'`` and ``heads_key`` looks like
        ``'q_head-<n>_kv_head-<m>_head-<d>'``.
    """
    simplified = get_simplified_raw_key(
        page_size,
        q_dtype,
        kv_dtype,
        num_q_heads,
        num_kv_heads,
        head_dim,
        max_model_len,
    )
    (norm_page_size, q_name, kv_name, norm_q_heads, norm_kv_heads,
     norm_head_dim, norm_model_len) = simplified

    device_name = get_device_name()
    page_key = next_power_of_2(norm_page_size)
    dtype_key = f'q_{q_name}_kv_{kv_name}'
    heads_key = (f'q_head-{norm_q_heads}_kv_head-{norm_kv_heads}'
                 f'_head-{norm_head_dim}')
    model_len_key = next_power_of_2(norm_model_len)

    return device_name, page_key, dtype_key, heads_key, model_len_key
4118
+
4119
+
4120
def get_simplified_raw_key(
    page_size,
    q_dtype,
    kv_dtype,
    actual_num_q_heads,
    actual_num_kv_heads,
    head_dim,
    max_model_len,
):
    """Get the simplified key.

    Rounds the raw attention parameters onto the coarse grid used by the
    tuning table:
    - page size and max model length go to the next power of two;
    - dtypes are replaced by their ``jnp.dtype(...).name``;
    - head dim is aligned to a multiple of 128;
    - head counts are padded to the dtype packing and then rounded to a
      power of two.

    Returns:
        ``(page_size, q_dtype_name, kv_dtype_name, num_q_heads,
        num_kv_heads, head_dim, max_model_len)`` after normalization.
    """
    assert actual_num_q_heads % actual_num_kv_heads == 0
    q_heads_per_group = actual_num_q_heads // actual_num_kv_heads

    # NOTE(review): get_dtype_packing presumably returns how many elements
    # of the dtype share one packed word; confirm against its definition.
    q_pack = get_dtype_packing(q_dtype)
    kv_pack = get_dtype_packing(kv_dtype)

    # The KV head count is doubled before padding (the "_x2" naming suggests
    # K and V heads are counted together — confirm against the kernel).
    padded_kv_heads_x2 = align_to(actual_num_kv_heads * 2, kv_pack)
    padded_q_heads_per_group = align_to(q_heads_per_group, q_pack)
    assert padded_kv_heads_x2 % 2 == 0

    page_key = next_power_of_2(page_size)
    q_dtype_name = jnp.dtype(q_dtype).name
    kv_dtype_name = jnp.dtype(kv_dtype).name
    q_heads_key = next_power_of_2(padded_q_heads_per_group *
                                  actual_num_kv_heads)
    kv_heads_key = next_power_of_2(padded_kv_heads_x2) // 2
    head_dim_key = align_to(head_dim, 128)
    model_len_key = next_power_of_2(max_model_len)

    return (
        page_key,
        q_dtype_name,
        kv_dtype_name,
        q_heads_key,
        kv_heads_key,
        head_dim_key,
        model_len_key,
    )