tpu-inference 0.11.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic. Click here for more details.

Files changed (168) hide show
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_adapters.py +83 -0
  4. tests/core/test_core_tpu.py +523 -0
  5. tests/core/test_disagg_executor.py +60 -0
  6. tests/core/test_disagg_utils.py +53 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  10. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  11. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  12. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  13. tests/lora/__init__.py +0 -0
  14. tests/lora/test_lora.py +123 -0
  15. tests/test_base.py +201 -0
  16. tests/test_quantization.py +836 -0
  17. tests/test_tpu_info.py +120 -0
  18. tests/test_utils.py +218 -0
  19. tests/tpu_backend_test.py +59 -0
  20. tpu_inference/__init__.py +30 -0
  21. tpu_inference/adapters/__init__.py +0 -0
  22. tpu_inference/adapters/vllm_adapters.py +42 -0
  23. tpu_inference/adapters/vllm_config_adapters.py +134 -0
  24. tpu_inference/backend.py +69 -0
  25. tpu_inference/core/__init__.py +0 -0
  26. tpu_inference/core/adapters.py +153 -0
  27. tpu_inference/core/core_tpu.py +776 -0
  28. tpu_inference/core/disagg_executor.py +117 -0
  29. tpu_inference/core/disagg_utils.py +51 -0
  30. tpu_inference/di/__init__.py +0 -0
  31. tpu_inference/di/abstracts.py +28 -0
  32. tpu_inference/di/host.py +76 -0
  33. tpu_inference/di/interfaces.py +51 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/tpu_connector.py +699 -0
  36. tpu_inference/distributed/utils.py +59 -0
  37. tpu_inference/executors/__init__.py +0 -0
  38. tpu_inference/executors/ray_distributed_executor.py +346 -0
  39. tpu_inference/experimental/__init__.py +0 -0
  40. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  41. tpu_inference/interfaces/__init__.py +0 -0
  42. tpu_inference/interfaces/cache.py +31 -0
  43. tpu_inference/interfaces/config.py +47 -0
  44. tpu_inference/interfaces/config_parts.py +117 -0
  45. tpu_inference/interfaces/engine.py +51 -0
  46. tpu_inference/interfaces/outputs.py +22 -0
  47. tpu_inference/interfaces/params.py +21 -0
  48. tpu_inference/interfaces/platform.py +74 -0
  49. tpu_inference/interfaces/request.py +39 -0
  50. tpu_inference/interfaces/scheduler.py +31 -0
  51. tpu_inference/kernels/__init__.py +0 -0
  52. tpu_inference/kernels/collectives/__init__.py +0 -0
  53. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  54. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  55. tpu_inference/kernels/collectives/util.py +47 -0
  56. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  57. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  58. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  59. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  60. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  61. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  62. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  66. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1447 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3834 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/util.py +47 -0
  71. tpu_inference/layers/__init__.py +0 -0
  72. tpu_inference/layers/common/__init__.py +0 -0
  73. tpu_inference/layers/common/attention_metadata.py +34 -0
  74. tpu_inference/layers/jax/__init__.py +0 -0
  75. tpu_inference/layers/jax/attention/__init__.py +0 -0
  76. tpu_inference/layers/jax/attention/attention.py +254 -0
  77. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  78. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  79. tpu_inference/layers/jax/attention_interface.py +356 -0
  80. tpu_inference/layers/jax/base.py +151 -0
  81. tpu_inference/layers/jax/binary_search.py +295 -0
  82. tpu_inference/layers/jax/constants.py +88 -0
  83. tpu_inference/layers/jax/layers.py +301 -0
  84. tpu_inference/layers/jax/misc.py +16 -0
  85. tpu_inference/layers/jax/moe/__init__.py +0 -0
  86. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  87. tpu_inference/layers/jax/moe/moe.py +209 -0
  88. tpu_inference/layers/jax/rope.py +172 -0
  89. tpu_inference/layers/jax/rope_interface.py +214 -0
  90. tpu_inference/layers/jax/sample/__init__.py +0 -0
  91. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  92. tpu_inference/layers/jax/sample/sampling.py +95 -0
  93. tpu_inference/layers/jax/sample/sampling_metadata.py +69 -0
  94. tpu_inference/layers/jax/sharding.py +406 -0
  95. tpu_inference/layers/jax/transformer_block.py +76 -0
  96. tpu_inference/layers/vllm/__init__.py +0 -0
  97. tpu_inference/layers/vllm/attention.py +184 -0
  98. tpu_inference/layers/vllm/fused_moe.py +399 -0
  99. tpu_inference/layers/vllm/linear_common.py +186 -0
  100. tpu_inference/layers/vllm/quantization/__init__.py +34 -0
  101. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  102. tpu_inference/layers/vllm/quantization/common.py +105 -0
  103. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  104. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +121 -0
  105. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  106. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  108. tpu_inference/layers/vllm/quantization/unquantized.py +263 -0
  109. tpu_inference/layers/vllm/sharding.py +151 -0
  110. tpu_inference/logger.py +10 -0
  111. tpu_inference/lora/__init__.py +0 -0
  112. tpu_inference/lora/torch_lora_ops.py +103 -0
  113. tpu_inference/lora/torch_punica_tpu.py +308 -0
  114. tpu_inference/mock/__init__.py +0 -0
  115. tpu_inference/mock/vllm_config_utils.py +28 -0
  116. tpu_inference/mock/vllm_envs.py +1233 -0
  117. tpu_inference/mock/vllm_logger.py +212 -0
  118. tpu_inference/mock/vllm_logging_utils.py +15 -0
  119. tpu_inference/models/__init__.py +0 -0
  120. tpu_inference/models/common/__init__.py +0 -0
  121. tpu_inference/models/common/model_loader.py +433 -0
  122. tpu_inference/models/jax/__init__.py +0 -0
  123. tpu_inference/models/jax/deepseek_v3.py +868 -0
  124. tpu_inference/models/jax/llama3.py +366 -0
  125. tpu_inference/models/jax/llama4.py +473 -0
  126. tpu_inference/models/jax/llama_eagle3.py +333 -0
  127. tpu_inference/models/jax/phi3.py +376 -0
  128. tpu_inference/models/jax/qwen2.py +375 -0
  129. tpu_inference/models/jax/qwen2_5_vl.py +976 -0
  130. tpu_inference/models/jax/qwen3.py +302 -0
  131. tpu_inference/models/jax/utils/__init__.py +0 -0
  132. tpu_inference/models/jax/utils/file_utils.py +96 -0
  133. tpu_inference/models/jax/utils/multi_modal_utils.py +164 -0
  134. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  135. tpu_inference/models/jax/utils/quantization/quantization_utils.py +588 -0
  136. tpu_inference/models/jax/utils/weight_utils.py +510 -0
  137. tpu_inference/models/vllm/__init__.py +0 -0
  138. tpu_inference/models/vllm/vllm_model_wrapper.py +272 -0
  139. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  140. tpu_inference/platforms/__init__.py +2 -0
  141. tpu_inference/platforms/tpu_jax.py +257 -0
  142. tpu_inference/runner/__init__.py +0 -0
  143. tpu_inference/runner/block_table_jax.py +122 -0
  144. tpu_inference/runner/compilation_manager.py +672 -0
  145. tpu_inference/runner/input_batch_jax.py +435 -0
  146. tpu_inference/runner/kv_cache.py +119 -0
  147. tpu_inference/runner/kv_cache_manager.py +460 -0
  148. tpu_inference/runner/lora_utils.py +92 -0
  149. tpu_inference/runner/multimodal_manager.py +208 -0
  150. tpu_inference/runner/persistent_batch_manager.py +244 -0
  151. tpu_inference/runner/speculative_decoding_manager.py +250 -0
  152. tpu_inference/runner/structured_decoding_manager.py +89 -0
  153. tpu_inference/runner/tpu_jax_runner.py +771 -0
  154. tpu_inference/runner/utils.py +426 -0
  155. tpu_inference/spec_decode/__init__.py +0 -0
  156. tpu_inference/spec_decode/jax/__init__.py +0 -0
  157. tpu_inference/spec_decode/jax/eagle3.py +334 -0
  158. tpu_inference/tpu_info.py +77 -0
  159. tpu_inference/utils.py +294 -0
  160. tpu_inference/worker/__init__.py +0 -0
  161. tpu_inference/worker/_temporary_vllm_compat.py +129 -0
  162. tpu_inference/worker/base.py +100 -0
  163. tpu_inference/worker/tpu_worker_jax.py +321 -0
  164. tpu_inference-0.11.1.dist-info/METADATA +101 -0
  165. tpu_inference-0.11.1.dist-info/RECORD +168 -0
  166. tpu_inference-0.11.1.dist-info/WHEEL +5 -0
  167. tpu_inference-0.11.1.dist-info/licenses/LICENSE +201 -0
  168. tpu_inference-0.11.1.dist-info/top_level.txt +2 -0
@@ -0,0 +1,3834 @@
1
+ """Auto-tuned block sizes for ragged paged attention."""
2
+
3
+ import jax.numpy as jnp
4
+
5
+ from tpu_inference.kernels.ragged_paged_attention.v3.util import (
6
+ align_to, get_dtype_packing, get_tpu_version, next_power_of_2)
7
+ from tpu_inference.logger import init_logger
8
+ from tpu_inference.utils import get_device_name
9
+
10
+ logger = init_logger(__name__)
11
+
12
+ # key:
13
+ # - device_name
14
+ # - page_size
15
+ # - q_{q_dtype_name}_kv_{kv_dtype_name}
16
+ # - q_head-{num_q_heads}_kv_head-{num_kv_heads}_head-{head_dim}
17
+ # - max_model_len
18
+ # value:
19
+ # - (num_kv_pages_per_block, num_queries_per_block)
20
+ TUNED_BLOCK_SIZES = {
21
+ 'TPU v6e': {
22
+ 128: {
23
+ 'q_bfloat16_kv_bfloat16': {
24
+ 'q_head-128_kv_head-1_head-128': {
25
+ 1024: (8, 16),
26
+ 2048: (8, 16),
27
+ 256: (2, 8),
28
+ 4096: (16, 8),
29
+ 512: (2, 32),
30
+ 8192: (16, 16),
31
+ },
32
+ 'q_head-128_kv_head-1_head-256': {
33
+ 1024: (8, 16),
34
+ 2048: (16, 16),
35
+ 256: (2, 8),
36
+ 4096: (16, 8),
37
+ 512: (4, 8),
38
+ 8192: (16, 16),
39
+ },
40
+ 'q_head-128_kv_head-16_head-128': {
41
+ 1024: (8, 16),
42
+ 2048: (8, 16),
43
+ 256: (2, 16),
44
+ 4096: (8, 16),
45
+ 512: (4, 16),
46
+ 8192: (8, 16),
47
+ },
48
+ 'q_head-128_kv_head-16_head-256': {
49
+ 1024: (4, 8),
50
+ 2048: (4, 8),
51
+ 256: (2, 8),
52
+ 4096: (4, 8),
53
+ 512: (4, 8),
54
+ 8192: (4, 8),
55
+ },
56
+ 'q_head-128_kv_head-2_head-128': {
57
+ 1024: (8, 8),
58
+ 2048: (16, 8),
59
+ 256: (2, 32),
60
+ 4096: (16, 8),
61
+ 512: (4, 16),
62
+ 8192: (16, 32),
63
+ },
64
+ 'q_head-128_kv_head-2_head-256': {
65
+ 1024: (8, 16),
66
+ 2048: (16, 8),
67
+ 256: (2, 8),
68
+ 4096: (16, 16),
69
+ 512: (4, 8),
70
+ 8192: (16, 16),
71
+ },
72
+ 'q_head-128_kv_head-4_head-128': {
73
+ 1024: (8, 8),
74
+ 2048: (16, 16),
75
+ 256: (2, 8),
76
+ 4096: (16, 16),
77
+ 512: (4, 8),
78
+ 8192: (16, 32),
79
+ },
80
+ 'q_head-128_kv_head-4_head-256': {
81
+ 1024: (8, 8),
82
+ 2048: (16, 16),
83
+ 256: (2, 8),
84
+ 4096: (16, 16),
85
+ 512: (4, 8),
86
+ 8192: (16, 16),
87
+ },
88
+ 'q_head-128_kv_head-8_head-128': {
89
+ 1024: (8, 8),
90
+ 2048: (16, 16),
91
+ 256: (2, 16),
92
+ 4096: (16, 32),
93
+ 512: (4, 32),
94
+ 8192: (16, 32),
95
+ },
96
+ 'q_head-128_kv_head-8_head-256': {
97
+ 1024: (8, 8),
98
+ 2048: (8, 16),
99
+ 256: (2, 16),
100
+ 4096: (8, 16),
101
+ 512: (4, 8),
102
+ 8192: (8, 16),
103
+ },
104
+ 'q_head-16_kv_head-1_head-128': {
105
+ 1024: (8, 32),
106
+ 2048: (16, 8),
107
+ 256: (2, 128),
108
+ 4096: (16, 32),
109
+ 512: (4, 256),
110
+ 8192: (16, 128),
111
+ },
112
+ 'q_head-16_kv_head-1_head-256': {
113
+ 1024: (8, 32),
114
+ 2048: (16, 64),
115
+ 256: (2, 128),
116
+ 4096: (16, 16),
117
+ 512: (4, 32),
118
+ 8192: (16, 16),
119
+ },
120
+ 'q_head-16_kv_head-2_head-128': {
121
+ 1024: (8, 128),
122
+ 2048: (16, 16),
123
+ 256: (2, 64),
124
+ 4096: (16, 64),
125
+ 512: (4, 16),
126
+ 8192: (16, 128),
127
+ },
128
+ 'q_head-16_kv_head-2_head-256': {
129
+ 1024: (8, 32),
130
+ 2048: (16, 32),
131
+ 256: (2, 32),
132
+ 4096: (16, 64),
133
+ 512: (4, 32),
134
+ 8192: (16, 32),
135
+ },
136
+ 'q_head-16_kv_head-4_head-128': {
137
+ 1024: (8, 128),
138
+ 2048: (8, 128),
139
+ 256: (2, 32),
140
+ 4096: (16, 128),
141
+ 512: (4, 32),
142
+ 8192: (16, 128),
143
+ },
144
+ 'q_head-16_kv_head-4_head-256': {
145
+ 1024: (8, 16),
146
+ 2048: (8, 16),
147
+ 256: (2, 32),
148
+ 4096: (16, 128),
149
+ 512: (4, 32),
150
+ 8192: (16, 64),
151
+ },
152
+ 'q_head-16_kv_head-8_head-128': {
153
+ 1024: (8, 32),
154
+ 2048: (16, 64),
155
+ 256: (2, 64),
156
+ 4096: (16, 128),
157
+ 512: (4, 64),
158
+ 8192: (16, 128),
159
+ },
160
+ 'q_head-16_kv_head-8_head-256': {
161
+ 1024: (8, 64),
162
+ 2048: (16, 64),
163
+ 256: (2, 64),
164
+ 4096: (16, 64),
165
+ 512: (4, 64),
166
+ 8192: (8, 128),
167
+ },
168
+ 'q_head-2_kv_head-1_head-128': {
169
+ 1024: (8, 64),
170
+ 2048: (16, 256),
171
+ 256: (2, 64),
172
+ 4096: (16, 64),
173
+ 512: (4, 32),
174
+ 8192: (16, 32),
175
+ },
176
+ 'q_head-2_kv_head-1_head-256': {
177
+ 1024: (8, 16),
178
+ 2048: (16, 64),
179
+ 256: (2, 16),
180
+ 4096: (16, 32),
181
+ 512: (4, 256),
182
+ 8192: (16, 32),
183
+ },
184
+ 'q_head-32_kv_head-1_head-128': {
185
+ 1024: (8, 32),
186
+ 2048: (16, 16),
187
+ 256: (2, 32),
188
+ 4096: (16, 8),
189
+ 512: (4, 16),
190
+ 8192: (16, 16),
191
+ },
192
+ 'q_head-32_kv_head-1_head-256': {
193
+ 1024: (4, 16),
194
+ 2048: (16, 8),
195
+ 256: (2, 16),
196
+ 4096: (16, 16),
197
+ 512: (4, 8),
198
+ 8192: (16, 32),
199
+ },
200
+ 'q_head-32_kv_head-16_head-128': {
201
+ 1024: (8, 32),
202
+ 2048: (8, 32),
203
+ 256: (2, 32),
204
+ 4096: (8, 64),
205
+ 512: (4, 32),
206
+ 8192: (8, 32),
207
+ },
208
+ 'q_head-32_kv_head-16_head-256': {
209
+ 1024: (4, 32),
210
+ 2048: (4, 32),
211
+ 256: (2, 32),
212
+ 4096: (4, 32),
213
+ 512: (4, 32),
214
+ 8192: (4, 32),
215
+ },
216
+ 'q_head-32_kv_head-2_head-128': {
217
+ 1024: (8, 8),
218
+ 2048: (16, 16),
219
+ 256: (2, 16),
220
+ 4096: (16, 64),
221
+ 512: (4, 16),
222
+ 8192: (16, 16),
223
+ },
224
+ 'q_head-32_kv_head-2_head-256': {
225
+ 1024: (8, 32),
226
+ 2048: (16, 32),
227
+ 256: (2, 32),
228
+ 4096: (16, 32),
229
+ 512: (2, 16),
230
+ 8192: (16, 16),
231
+ },
232
+ 'q_head-32_kv_head-4_head-128': {
233
+ 1024: (8, 16),
234
+ 2048: (16, 32),
235
+ 256: (2, 64),
236
+ 4096: (16, 32),
237
+ 512: (4, 64),
238
+ 8192: (16, 64),
239
+ },
240
+ 'q_head-32_kv_head-4_head-256': {
241
+ 1024: (8, 32),
242
+ 2048: (16, 64),
243
+ 256: (2, 32),
244
+ 4096: (16, 64),
245
+ 512: (4, 8),
246
+ 8192: (16, 32),
247
+ },
248
+ 'q_head-32_kv_head-8_head-128': {
249
+ 1024: (8, 32),
250
+ 2048: (16, 64),
251
+ 256: (2, 64),
252
+ 4096: (16, 32),
253
+ 512: (4, 32),
254
+ 8192: (16, 32),
255
+ },
256
+ 'q_head-32_kv_head-8_head-256': {
257
+ 1024: (8, 32),
258
+ 2048: (16, 32),
259
+ 256: (2, 16),
260
+ 4096: (16, 32),
261
+ 512: (4, 32),
262
+ 8192: (8, 64),
263
+ },
264
+ 'q_head-4_kv_head-1_head-128': {
265
+ 1024: (8, 8),
266
+ 2048: (16, 128),
267
+ 256: (2, 8),
268
+ 4096: (16, 256),
269
+ 512: (4, 8),
270
+ 8192: (16, 32),
271
+ },
272
+ 'q_head-4_kv_head-1_head-256': {
273
+ 1024: (8, 16),
274
+ 2048: (16, 64),
275
+ 256: (2, 64),
276
+ 4096: (16, 128),
277
+ 512: (4, 64),
278
+ 8192: (16, 32),
279
+ },
280
+ 'q_head-4_kv_head-2_head-128': {
281
+ 1024: (8, 64),
282
+ 2048: (16, 64),
283
+ 256: (2, 16),
284
+ 4096: (16, 64),
285
+ 512: (4, 32),
286
+ 8192: (16, 128),
287
+ },
288
+ 'q_head-4_kv_head-2_head-256': {
289
+ 1024: (8, 256),
290
+ 2048: (16, 64),
291
+ 256: (2, 64),
292
+ 4096: (16, 64),
293
+ 512: (4, 32),
294
+ 8192: (16, 128),
295
+ },
296
+ 'q_head-64_kv_head-1_head-128': {
297
+ 1024: (8, 16),
298
+ 2048: (16, 8),
299
+ 256: (2, 16),
300
+ 4096: (16, 32),
301
+ 512: (4, 8),
302
+ 8192: (16, 16),
303
+ },
304
+ 'q_head-64_kv_head-1_head-256': {
305
+ 1024: (8, 16),
306
+ 2048: (16, 16),
307
+ 256: (2, 16),
308
+ 4096: (16, 32),
309
+ 512: (4, 16),
310
+ 8192: (16, 32),
311
+ },
312
+ 'q_head-64_kv_head-16_head-128': {
313
+ 1024: (8, 32),
314
+ 2048: (8, 32),
315
+ 256: (2, 32),
316
+ 4096: (8, 32),
317
+ 512: (4, 16),
318
+ 8192: (8, 32),
319
+ },
320
+ 'q_head-64_kv_head-16_head-256': {
321
+ 1024: (4, 16),
322
+ 2048: (4, 16),
323
+ 256: (2, 16),
324
+ 4096: (4, 16),
325
+ 512: (4, 16),
326
+ 8192: (4, 16),
327
+ },
328
+ 'q_head-64_kv_head-2_head-128': {
329
+ 1024: (8, 16),
330
+ 2048: (16, 32),
331
+ 256: (2, 16),
332
+ 4096: (16, 32),
333
+ 512: (4, 16),
334
+ 8192: (16, 32),
335
+ },
336
+ 'q_head-64_kv_head-2_head-256': {
337
+ 1024: (8, 16),
338
+ 2048: (16, 16),
339
+ 256: (2, 8),
340
+ 4096: (16, 16),
341
+ 512: (4, 8),
342
+ 8192: (16, 32),
343
+ },
344
+ 'q_head-64_kv_head-4_head-128': {
345
+ 1024: (8, 32),
346
+ 2048: (16, 16),
347
+ 256: (2, 16),
348
+ 4096: (16, 16),
349
+ 512: (4, 8),
350
+ 8192: (16, 32),
351
+ },
352
+ 'q_head-64_kv_head-4_head-256': {
353
+ 1024: (8, 16),
354
+ 2048: (16, 32),
355
+ 256: (2, 8),
356
+ 4096: (16, 16),
357
+ 512: (4, 32),
358
+ 8192: (16, 32),
359
+ },
360
+ 'q_head-64_kv_head-8_head-128': {
361
+ 1024: (8, 16),
362
+ 2048: (16, 32),
363
+ 256: (2, 32),
364
+ 4096: (16, 32),
365
+ 512: (4, 32),
366
+ 8192: (16, 32),
367
+ },
368
+ 'q_head-64_kv_head-8_head-256': {
369
+ 1024: (8, 16),
370
+ 2048: (16, 16),
371
+ 256: (2, 8),
372
+ 4096: (8, 32),
373
+ 512: (4, 32),
374
+ 8192: (8, 32),
375
+ },
376
+ 'q_head-8_kv_head-1_head-128': {
377
+ 1024: (8, 128),
378
+ 2048: (16, 32),
379
+ 256: (2, 256),
380
+ 4096: (16, 128),
381
+ 512: (4, 16),
382
+ 8192: (16, 64),
383
+ },
384
+ 'q_head-8_kv_head-1_head-256': {
385
+ 1024: (8, 64),
386
+ 2048: (16, 32),
387
+ 256: (2, 32),
388
+ 4096: (16, 32),
389
+ 512: (4, 32),
390
+ 8192: (16, 16),
391
+ },
392
+ 'q_head-8_kv_head-2_head-128': {
393
+ 1024: (8, 128),
394
+ 2048: (16, 128),
395
+ 256: (2, 64),
396
+ 4096: (16, 32),
397
+ 512: (4, 32),
398
+ 8192: (16, 64),
399
+ },
400
+ 'q_head-8_kv_head-2_head-256': {
401
+ 1024: (8, 64),
402
+ 2048: (16, 128),
403
+ 256: (2, 16),
404
+ 4096: (16, 128),
405
+ 512: (4, 16),
406
+ 8192: (16, 64),
407
+ },
408
+ 'q_head-8_kv_head-4_head-128': {
409
+ 1024: (8, 8),
410
+ 2048: (16, 256),
411
+ 256: (2, 64),
412
+ 4096: (16, 64),
413
+ 512: (4, 32),
414
+ 8192: (16, 64),
415
+ },
416
+ 'q_head-8_kv_head-4_head-256': {
417
+ 1024: (8, 16),
418
+ 2048: (16, 64),
419
+ 256: (2, 8),
420
+ 4096: (16, 64),
421
+ 512: (4, 64),
422
+ 8192: (16, 64),
423
+ },
424
+ }
425
+ },
426
+ 256: {
427
+ 'q_bfloat16_kv_bfloat16': {
428
+ 'q_head-128_kv_head-1_head-128': {
429
+ 1024: (4, 8),
430
+ 2048: (8, 16),
431
+ 4096: (8, 8),
432
+ 512: (2, 8),
433
+ 8192: (8, 16),
434
+ },
435
+ 'q_head-128_kv_head-1_head-256': {
436
+ 1024: (4, 8),
437
+ 2048: (4, 8),
438
+ 4096: (8, 16),
439
+ 512: (2, 8),
440
+ 8192: (8, 16),
441
+ },
442
+ 'q_head-128_kv_head-16_head-128': {
443
+ 1024: (4, 16),
444
+ 2048: (4, 16),
445
+ 4096: (4, 16),
446
+ 512: (2, 16),
447
+ 8192: (4, 16),
448
+ },
449
+ 'q_head-128_kv_head-16_head-256': {
450
+ 1024: (2, 8),
451
+ 2048: (2, 8),
452
+ 4096: (2, 8),
453
+ 512: (2, 8),
454
+ 8192: (2, 8),
455
+ },
456
+ 'q_head-128_kv_head-2_head-128': {
457
+ 1024: (4, 8),
458
+ 2048: (8, 8),
459
+ 4096: (8, 8),
460
+ 512: (2, 8),
461
+ 8192: (8, 16),
462
+ },
463
+ 'q_head-128_kv_head-2_head-256': {
464
+ 1024: (4, 16),
465
+ 2048: (8, 8),
466
+ 4096: (8, 8),
467
+ 512: (2, 8),
468
+ 8192: (8, 8),
469
+ },
470
+ 'q_head-128_kv_head-4_head-128': {
471
+ 1024: (4, 16),
472
+ 2048: (8, 16),
473
+ 4096: (8, 16),
474
+ 512: (2, 32),
475
+ 8192: (8, 32),
476
+ },
477
+ 'q_head-128_kv_head-4_head-256': {
478
+ 1024: (4, 8),
479
+ 2048: (8, 8),
480
+ 4096: (8, 16),
481
+ 512: (2, 8),
482
+ 8192: (8, 16),
483
+ },
484
+ 'q_head-128_kv_head-8_head-128': {
485
+ 1024: (4, 16),
486
+ 2048: (8, 32),
487
+ 4096: (8, 32),
488
+ 512: (2, 32),
489
+ 8192: (8, 32),
490
+ },
491
+ 'q_head-128_kv_head-8_head-256': {
492
+ 1024: (4, 16),
493
+ 2048: (4, 16),
494
+ 4096: (4, 16),
495
+ 512: (2, 8),
496
+ 8192: (4, 16),
497
+ },
498
+ 'q_head-16_kv_head-1_head-128': {
499
+ 1024: (4, 32),
500
+ 2048: (8, 128),
501
+ 4096: (8, 128),
502
+ 512: (2, 32),
503
+ 8192: (8, 64),
504
+ },
505
+ 'q_head-16_kv_head-1_head-256': {
506
+ 1024: (4, 16),
507
+ 2048: (8, 32),
508
+ 4096: (8, 32),
509
+ 512: (2, 32),
510
+ 8192: (8, 32),
511
+ },
512
+ 'q_head-16_kv_head-2_head-128': {
513
+ 1024: (4, 8),
514
+ 2048: (8, 32),
515
+ 4096: (8, 16),
516
+ 512: (2, 32),
517
+ 8192: (8, 128),
518
+ },
519
+ 'q_head-16_kv_head-2_head-256': {
520
+ 1024: (4, 32),
521
+ 2048: (8, 32),
522
+ 4096: (8, 64),
523
+ 512: (2, 32),
524
+ 8192: (8, 16),
525
+ },
526
+ 'q_head-16_kv_head-4_head-128': {
527
+ 1024: (4, 32),
528
+ 2048: (8, 32),
529
+ 4096: (8, 64),
530
+ 512: (2, 16),
531
+ 8192: (8, 32),
532
+ },
533
+ 'q_head-16_kv_head-4_head-256': {
534
+ 1024: (4, 128),
535
+ 2048: (4, 32),
536
+ 4096: (8, 64),
537
+ 512: (2, 64),
538
+ 8192: (8, 64),
539
+ },
540
+ 'q_head-16_kv_head-8_head-128': {
541
+ 1024: (4, 64),
542
+ 2048: (8, 32),
543
+ 4096: (8, 64),
544
+ 512: (2, 64),
545
+ 8192: (8, 128),
546
+ },
547
+ 'q_head-16_kv_head-8_head-256': {
548
+ 1024: (4, 64),
549
+ 2048: (8, 64),
550
+ 4096: (8, 64),
551
+ 512: (2, 32),
552
+ 8192: (4, 128),
553
+ },
554
+ 'q_head-2_kv_head-1_head-128': {
555
+ 1024: (4, 64),
556
+ 2048: (8, 32),
557
+ 4096: (8, 32),
558
+ 512: (2, 64),
559
+ 8192: (8, 32),
560
+ },
561
+ 'q_head-2_kv_head-1_head-256': {
562
+ 1024: (4, 32),
563
+ 2048: (8, 16),
564
+ 4096: (8, 128),
565
+ 512: (2, 256),
566
+ 8192: (8, 256),
567
+ },
568
+ 'q_head-32_kv_head-1_head-128': {
569
+ 1024: (4, 32),
570
+ 2048: (8, 16),
571
+ 4096: (8, 32),
572
+ 512: (2, 64),
573
+ 8192: (8, 32),
574
+ },
575
+ 'q_head-32_kv_head-1_head-256': {
576
+ 1024: (4, 8),
577
+ 2048: (8, 32),
578
+ 4096: (8, 16),
579
+ 512: (2, 8),
580
+ 8192: (8, 32),
581
+ },
582
+ 'q_head-32_kv_head-16_head-128': {
583
+ 1024: (4, 32),
584
+ 2048: (4, 64),
585
+ 4096: (4, 64),
586
+ 512: (2, 64),
587
+ 8192: (4, 64),
588
+ },
589
+ 'q_head-32_kv_head-16_head-256': {
590
+ 1024: (2, 32),
591
+ 2048: (2, 32),
592
+ 4096: (2, 32),
593
+ 512: (2, 32),
594
+ 8192: (4, 32),
595
+ },
596
+ 'q_head-32_kv_head-2_head-128': {
597
+ 1024: (4, 64),
598
+ 2048: (8, 16),
599
+ 4096: (8, 16),
600
+ 512: (2, 64),
601
+ 8192: (8, 32),
602
+ },
603
+ 'q_head-32_kv_head-2_head-256': {
604
+ 1024: (4, 32),
605
+ 2048: (8, 16),
606
+ 4096: (8, 16),
607
+ 512: (2, 32),
608
+ 8192: (8, 32),
609
+ },
610
+ 'q_head-32_kv_head-4_head-128': {
611
+ 1024: (4, 64),
612
+ 2048: (8, 64),
613
+ 4096: (8, 64),
614
+ 512: (2, 16),
615
+ 8192: (8, 32),
616
+ },
617
+ 'q_head-32_kv_head-4_head-256': {
618
+ 1024: (4, 32),
619
+ 2048: (8, 64),
620
+ 4096: (8, 32),
621
+ 512: (2, 32),
622
+ 8192: (8, 32),
623
+ },
624
+ 'q_head-32_kv_head-8_head-128': {
625
+ 1024: (4, 64),
626
+ 2048: (8, 128),
627
+ 4096: (8, 32),
628
+ 512: (2, 64),
629
+ 8192: (8, 128),
630
+ },
631
+ 'q_head-32_kv_head-8_head-256': {
632
+ 1024: (4, 32),
633
+ 2048: (8, 32),
634
+ 4096: (8, 32),
635
+ 512: (2, 16),
636
+ 8192: (4, 64),
637
+ },
638
+ 'q_head-4_kv_head-1_head-128': {
639
+ 1024: (4, 8),
640
+ 2048: (8, 16),
641
+ 4096: (8, 128),
642
+ 512: (2, 8),
643
+ 8192: (8, 64),
644
+ },
645
+ 'q_head-4_kv_head-1_head-256': {
646
+ 1024: (4, 32),
647
+ 2048: (8, 128),
648
+ 4096: (8, 128),
649
+ 512: (2, 32),
650
+ 8192: (8, 64),
651
+ },
652
+ 'q_head-4_kv_head-2_head-128': {
653
+ 1024: (4, 64),
654
+ 2048: (8, 256),
655
+ 4096: (8, 256),
656
+ 512: (2, 16),
657
+ 8192: (8, 32),
658
+ },
659
+ 'q_head-4_kv_head-2_head-256': {
660
+ 1024: (2, 32),
661
+ 2048: (8, 64),
662
+ 4096: (8, 256),
663
+ 512: (2, 32),
664
+ 8192: (8, 32),
665
+ },
666
+ 'q_head-64_kv_head-1_head-128': {
667
+ 1024: (2, 16),
668
+ 2048: (8, 8),
669
+ 4096: (8, 8),
670
+ 512: (2, 16),
671
+ 8192: (8, 16),
672
+ },
673
+ 'q_head-64_kv_head-1_head-256': {
674
+ 1024: (4, 16),
675
+ 2048: (8, 16),
676
+ 4096: (8, 16),
677
+ 512: (2, 8),
678
+ 8192: (8, 32),
679
+ },
680
+ 'q_head-64_kv_head-16_head-128': {
681
+ 1024: (4, 32),
682
+ 2048: (4, 32),
683
+ 4096: (4, 32),
684
+ 512: (2, 32),
685
+ 8192: (4, 32),
686
+ },
687
+ 'q_head-64_kv_head-16_head-256': {
688
+ 1024: (2, 16),
689
+ 2048: (2, 16),
690
+ 4096: (2, 16),
691
+ 512: (2, 16),
692
+ 8192: (2, 16),
693
+ },
694
+ 'q_head-64_kv_head-2_head-128': {
695
+ 1024: (4, 8),
696
+ 2048: (8, 16),
697
+ 4096: (8, 16),
698
+ 512: (2, 8),
699
+ 8192: (8, 16),
700
+ },
701
+ 'q_head-64_kv_head-2_head-256': {
702
+ 1024: (4, 16),
703
+ 2048: (8, 32),
704
+ 4096: (8, 32),
705
+ 512: (2, 16),
706
+ 8192: (8, 32),
707
+ },
708
+ 'q_head-64_kv_head-4_head-128': {
709
+ 1024: (4, 32),
710
+ 2048: (8, 64),
711
+ 4096: (8, 32),
712
+ 512: (2, 32),
713
+ 8192: (8, 32),
714
+ },
715
+ 'q_head-64_kv_head-4_head-256': {
716
+ 1024: (4, 16),
717
+ 2048: (8, 32),
718
+ 4096: (8, 32),
719
+ 512: (2, 16),
720
+ 8192: (8, 32),
721
+ },
722
+ 'q_head-64_kv_head-8_head-128': {
723
+ 1024: (4, 16),
724
+ 2048: (8, 64),
725
+ 4096: (8, 64),
726
+ 512: (2, 16),
727
+ 8192: (8, 32),
728
+ },
729
+ 'q_head-64_kv_head-8_head-256': {
730
+ 1024: (4, 32),
731
+ 2048: (8, 16),
732
+ 4096: (8, 16),
733
+ 512: (2, 16),
734
+ 8192: (4, 32),
735
+ },
736
+ 'q_head-8_kv_head-1_head-128': {
737
+ 1024: (4, 16),
738
+ 2048: (8, 64),
739
+ 4096: (8, 64),
740
+ 512: (2, 32),
741
+ 8192: (8, 32),
742
+ },
743
+ 'q_head-8_kv_head-1_head-256': {
744
+ 1024: (4, 32),
745
+ 2048: (8, 128),
746
+ 4096: (8, 32),
747
+ 512: (2, 64),
748
+ 8192: (8, 32),
749
+ },
750
+ 'q_head-8_kv_head-2_head-128': {
751
+ 1024: (4, 32),
752
+ 2048: (8, 16),
753
+ 4096: (8, 32),
754
+ 512: (2, 32),
755
+ 8192: (8, 128),
756
+ },
757
+ 'q_head-8_kv_head-2_head-256': {
758
+ 1024: (4, 32),
759
+ 2048: (8, 128),
760
+ 4096: (8, 128),
761
+ 512: (2, 8),
762
+ 8192: (8, 128),
763
+ },
764
+ 'q_head-8_kv_head-4_head-128': {
765
+ 1024: (4, 16),
766
+ 2048: (8, 64),
767
+ 4096: (8, 32),
768
+ 512: (2, 8),
769
+ 8192: (8, 64),
770
+ },
771
+ 'q_head-8_kv_head-4_head-256': {
772
+ 1024: (4, 32),
773
+ 2048: (8, 64),
774
+ 4096: (8, 64),
775
+ 512: (2, 64),
776
+ 8192: (8, 256),
777
+ },
778
+ }
779
+ },
780
+ 64: {
781
+ 'q_bfloat16_kv_bfloat16': {
782
+ 'q_head-128_kv_head-1_head-128': {
783
+ 1024: (8, 16),
784
+ 128: (2, 16),
785
+ 2048: (32, 32),
786
+ 256: (4, 16),
787
+ 4096: (32, 16),
788
+ 512: (8, 8),
789
+ 8192: (32, 16),
790
+ },
791
+ 'q_head-128_kv_head-1_head-256': {
792
+ 1024: (16, 8),
793
+ 128: (2, 8),
794
+ 2048: (32, 8),
795
+ 256: (4, 8),
796
+ 4096: (32, 8),
797
+ 512: (8, 16),
798
+ 8192: (32, 8),
799
+ },
800
+ 'q_head-128_kv_head-16_head-128': {
801
+ 1024: (16, 16),
802
+ 128: (2, 16),
803
+ 2048: (16, 16),
804
+ 256: (4, 16),
805
+ 4096: (16, 16),
806
+ 512: (8, 16),
807
+ 8192: (16, 16),
808
+ },
809
+ 'q_head-128_kv_head-16_head-256': {
810
+ 1024: (8, 8),
811
+ 128: (2, 8),
812
+ 2048: (8, 8),
813
+ 256: (4, 8),
814
+ 4096: (8, 8),
815
+ 512: (8, 8),
816
+ 8192: (8, 8),
817
+ },
818
+ 'q_head-128_kv_head-2_head-128': {
819
+ 1024: (16, 8),
820
+ 128: (2, 8),
821
+ 2048: (32, 16),
822
+ 256: (4, 16),
823
+ 4096: (32, 16),
824
+ 512: (8, 32),
825
+ 8192: (32, 32),
826
+ },
827
+ 'q_head-128_kv_head-2_head-256': {
828
+ 1024: (16, 16),
829
+ 128: (2, 8),
830
+ 2048: (32, 8),
831
+ 256: (4, 8),
832
+ 4096: (32, 16),
833
+ 512: (8, 8),
834
+ 8192: (32, 16),
835
+ },
836
+ 'q_head-128_kv_head-4_head-128': {
837
+ 1024: (16, 8),
838
+ 128: (2, 16),
839
+ 2048: (32, 32),
840
+ 256: (4, 8),
841
+ 4096: (32, 16),
842
+ 512: (8, 16),
843
+ 8192: (32, 32),
844
+ },
845
+ 'q_head-128_kv_head-4_head-256': {
846
+ 1024: (16, 8),
847
+ 128: (2, 8),
848
+ 2048: (32, 16),
849
+ 256: (4, 8),
850
+ 4096: (32, 16),
851
+ 512: (8, 8),
852
+ 8192: (32, 16),
853
+ },
854
+ 'q_head-128_kv_head-8_head-128': {
855
+ 1024: (16, 16),
856
+ 128: (2, 16),
857
+ 2048: (32, 16),
858
+ 256: (4, 16),
859
+ 4096: (32, 32),
860
+ 512: (8, 8),
861
+ 8192: (32, 32),
862
+ },
863
+ 'q_head-128_kv_head-8_head-256': {
864
+ 1024: (16, 8),
865
+ 128: (2, 8),
866
+ 2048: (16, 16),
867
+ 256: (4, 8),
868
+ 4096: (16, 16),
869
+ 512: (8, 8),
870
+ 8192: (16, 16),
871
+ },
872
+ 'q_head-16_kv_head-1_head-128': {
873
+ 1024: (16, 64),
874
+ 128: (2, 16),
875
+ 2048: (32, 128),
876
+ 256: (4, 16),
877
+ 4096: (32, 32),
878
+ 512: (8, 16),
879
+ 8192: (32, 16),
880
+ },
881
+ 'q_head-16_kv_head-1_head-256': {
882
+ 1024: (16, 32),
883
+ 128: (2, 8),
884
+ 2048: (32, 16),
885
+ 256: (4, 8),
886
+ 4096: (32, 32),
887
+ 512: (8, 32),
888
+ 8192: (32, 64),
889
+ },
890
+ 'q_head-16_kv_head-2_head-128': {
891
+ 1024: (16, 16),
892
+ 128: (2, 32),
893
+ 2048: (32, 64),
894
+ 256: (4, 16),
895
+ 4096: (32, 32),
896
+ 512: (8, 32),
897
+ 8192: (32, 128),
898
+ },
899
+ 'q_head-16_kv_head-2_head-256': {
900
+ 1024: (16, 16),
901
+ 128: (2, 16),
902
+ 2048: (32, 16),
903
+ 256: (4, 128),
904
+ 4096: (16, 32),
905
+ 512: (8, 32),
906
+ 8192: (32, 64),
907
+ },
908
+ 'q_head-16_kv_head-4_head-128': {
909
+ 1024: (16, 16),
910
+ 128: (2, 32),
911
+ 2048: (32, 64),
912
+ 256: (4, 128),
913
+ 4096: (32, 32),
914
+ 512: (8, 32),
915
+ 8192: (32, 128),
916
+ },
917
+ 'q_head-16_kv_head-4_head-256': {
918
+ 1024: (16, 32),
919
+ 128: (2, 8),
920
+ 2048: (32, 64),
921
+ 256: (4, 32),
922
+ 4096: (32, 128),
923
+ 512: (8, 8),
924
+ 8192: (32, 64),
925
+ },
926
+ 'q_head-16_kv_head-8_head-128': {
927
+ 1024: (16, 64),
928
+ 128: (2, 32),
929
+ 2048: (32, 32),
930
+ 256: (4, 16),
931
+ 4096: (32, 64),
932
+ 512: (8, 32),
933
+ 8192: (32, 64),
934
+ },
935
+ 'q_head-16_kv_head-8_head-256': {
936
+ 1024: (16, 64),
937
+ 128: (2, 16),
938
+ 2048: (32, 32),
939
+ 256: (4, 64),
940
+ 4096: (32, 64),
941
+ 512: (8, 64),
942
+ 8192: (16, 128),
943
+ },
944
+ 'q_head-2_kv_head-1_head-128': {
945
+ 1024: (16, 32),
946
+ 128: (2, 256),
947
+ 2048: (32, 256),
948
+ 256: (4, 32),
949
+ 4096: (32, 256),
950
+ 512: (8, 256),
951
+ 8192: (32, 256),
952
+ },
953
+ 'q_head-2_kv_head-1_head-256': {
954
+ 1024: (16, 64),
955
+ 128: (2, 32),
956
+ 2048: (32, 128),
957
+ 256: (4, 32),
958
+ 4096: (32, 256),
959
+ 512: (8, 64),
960
+ 8192: (32, 64),
961
+ },
962
+ 'q_head-32_kv_head-1_head-128': {
963
+ 1024: (8, 64),
964
+ 128: (2, 32),
965
+ 2048: (32, 32),
966
+ 256: (4, 128),
967
+ 4096: (32, 16),
968
+ 512: (4, 16),
969
+ 8192: (32, 16),
970
+ },
971
+ 'q_head-32_kv_head-1_head-256': {
972
+ 1024: (16, 16),
973
+ 128: (2, 16),
974
+ 2048: (32, 8),
975
+ 256: (4, 8),
976
+ 4096: (16, 16),
977
+ 512: (8, 8),
978
+ 8192: (32, 32),
979
+ },
980
+ 'q_head-32_kv_head-16_head-128': {
981
+ 1024: (16, 64),
982
+ 128: (2, 64),
983
+ 2048: (16, 64),
984
+ 256: (4, 64),
985
+ 4096: (16, 64),
986
+ 512: (8, 64),
987
+ 8192: (16, 64),
988
+ },
989
+ 'q_head-32_kv_head-16_head-256': {
990
+ 1024: (8, 32),
991
+ 128: (2, 32),
992
+ 2048: (8, 32),
993
+ 256: (4, 32),
994
+ 4096: (8, 32),
995
+ 512: (8, 32),
996
+ 8192: (8, 32),
997
+ },
998
+ 'q_head-32_kv_head-2_head-128': {
999
+ 1024: (16, 64),
1000
+ 128: (2, 16),
1001
+ 2048: (32, 16),
1002
+ 256: (4, 32),
1003
+ 4096: (32, 32),
1004
+ 512: (8, 64),
1005
+ 8192: (32, 32),
1006
+ },
1007
+ 'q_head-32_kv_head-2_head-256': {
1008
+ 1024: (16, 32),
1009
+ 128: (2, 16),
1010
+ 2048: (32, 32),
1011
+ 256: (4, 16),
1012
+ 4096: (32, 16),
1013
+ 512: (8, 32),
1014
+ 8192: (32, 32),
1015
+ },
1016
+ 'q_head-32_kv_head-4_head-128': {
1017
+ 1024: (16, 32),
1018
+ 128: (2, 128),
1019
+ 2048: (32, 32),
1020
+ 256: (4, 16),
1021
+ 4096: (32, 32),
1022
+ 512: (8, 16),
1023
+ 8192: (32, 16),
1024
+ },
1025
+ 'q_head-32_kv_head-4_head-256': {
1026
+ 1024: (16, 32),
1027
+ 128: (2, 16),
1028
+ 2048: (32, 32),
1029
+ 256: (4, 16),
1030
+ 4096: (32, 32),
1031
+ 512: (8, 16),
1032
+ 8192: (32, 32),
1033
+ },
1034
+ 'q_head-32_kv_head-8_head-128': {
1035
+ 1024: (16, 32),
1036
+ 128: (2, 32),
1037
+ 2048: (32, 64),
1038
+ 256: (4, 8),
1039
+ 4096: (32, 32),
1040
+ 512: (8, 32),
1041
+ 8192: (32, 32),
1042
+ },
1043
+ 'q_head-32_kv_head-8_head-256': {
1044
+ 1024: (16, 32),
1045
+ 128: (2, 32),
1046
+ 2048: (32, 32),
1047
+ 256: (4, 32),
1048
+ 4096: (32, 32),
1049
+ 512: (8, 8),
1050
+ 8192: (16, 64),
1051
+ },
1052
+ 'q_head-4_kv_head-1_head-128': {
1053
+ 1024: (16, 32),
1054
+ 128: (2, 16),
1055
+ 2048: (32, 16),
1056
+ 256: (4, 256),
1057
+ 4096: (32, 128),
1058
+ 512: (4, 64),
1059
+ 8192: (32, 32),
1060
+ },
1061
+ 'q_head-4_kv_head-1_head-256': {
1062
+ 1024: (8, 128),
1063
+ 128: (2, 16),
1064
+ 2048: (16, 64),
1065
+ 256: (4, 256),
1066
+ 4096: (32, 64),
1067
+ 512: (8, 32),
1068
+ 8192: (32, 128),
1069
+ },
1070
+ 'q_head-4_kv_head-2_head-128': {
1071
+ 1024: (16, 64),
1072
+ 128: (2, 32),
1073
+ 2048: (32, 32),
1074
+ 256: (4, 256),
1075
+ 4096: (32, 64),
1076
+ 512: (8, 128),
1077
+ 8192: (32, 64),
1078
+ },
1079
+ 'q_head-4_kv_head-2_head-256': {
1080
+ 1024: (16, 64),
1081
+ 128: (2, 32),
1082
+ 2048: (32, 32),
1083
+ 256: (4, 64),
1084
+ 4096: (32, 32),
1085
+ 512: (8, 16),
1086
+ 8192: (32, 64),
1087
+ },
1088
+ 'q_head-64_kv_head-1_head-128': {
1089
+ 1024: (8, 32),
1090
+ 128: (2, 16),
1091
+ 2048: (32, 32),
1092
+ 256: (2, 16),
1093
+ 4096: (32, 32),
1094
+ 512: (8, 32),
1095
+ 8192: (32, 32),
1096
+ },
1097
+ 'q_head-64_kv_head-1_head-256': {
1098
+ 1024: (16, 32),
1099
+ 128: (2, 8),
1100
+ 2048: (32, 8),
1101
+ 256: (4, 8),
1102
+ 4096: (32, 32),
1103
+ 512: (8, 8),
1104
+ 8192: (32, 32),
1105
+ },
1106
+ 'q_head-64_kv_head-16_head-128': {
1107
+ 1024: (16, 32),
1108
+ 128: (2, 16),
1109
+ 2048: (16, 32),
1110
+ 256: (4, 32),
1111
+ 4096: (16, 32),
1112
+ 512: (8, 32),
1113
+ 8192: (16, 32),
1114
+ },
1115
+ 'q_head-64_kv_head-16_head-256': {
1116
+ 1024: (8, 16),
1117
+ 128: (2, 16),
1118
+ 2048: (8, 16),
1119
+ 256: (4, 16),
1120
+ 4096: (8, 16),
1121
+ 512: (8, 16),
1122
+ 8192: (8, 16),
1123
+ },
1124
+ 'q_head-64_kv_head-2_head-128': {
1125
+ 1024: (16, 32),
1126
+ 128: (2, 64),
1127
+ 2048: (32, 8),
1128
+ 256: (4, 16),
1129
+ 4096: (32, 32),
1130
+ 512: (8, 16),
1131
+ 8192: (32, 32),
1132
+ },
1133
+ 'q_head-64_kv_head-2_head-256': {
1134
+ 1024: (16, 16),
1135
+ 128: (2, 8),
1136
+ 2048: (32, 32),
1137
+ 256: (4, 8),
1138
+ 4096: (32, 32),
1139
+ 512: (8, 8),
1140
+ 8192: (32, 16),
1141
+ },
1142
+ 'q_head-64_kv_head-4_head-128': {
1143
+ 1024: (16, 16),
1144
+ 128: (2, 16),
1145
+ 2048: (32, 32),
1146
+ 256: (4, 16),
1147
+ 4096: (32, 32),
1148
+ 512: (8, 32),
1149
+ 8192: (32, 32),
1150
+ },
1151
+ 'q_head-64_kv_head-4_head-256': {
1152
+ 1024: (16, 8),
1153
+ 128: (2, 16),
1154
+ 2048: (32, 16),
1155
+ 256: (4, 32),
1156
+ 4096: (32, 32),
1157
+ 512: (8, 16),
1158
+ 8192: (32, 32),
1159
+ },
1160
+ 'q_head-64_kv_head-8_head-128': {
1161
+ 1024: (16, 32),
1162
+ 128: (2, 32),
1163
+ 2048: (32, 32),
1164
+ 256: (4, 32),
1165
+ 4096: (32, 64),
1166
+ 512: (8, 16),
1167
+ 8192: (32, 32),
1168
+ },
1169
+ 'q_head-64_kv_head-8_head-256': {
1170
+ 1024: (16, 32),
1171
+ 128: (2, 8),
1172
+ 2048: (32, 16),
1173
+ 256: (4, 16),
1174
+ 4096: (32, 16),
1175
+ 512: (8, 16),
1176
+ 8192: (16, 32),
1177
+ },
1178
+ 'q_head-8_kv_head-1_head-128': {
1179
+ 1024: (16, 16),
1180
+ 128: (2, 256),
1181
+ 2048: (16, 32),
1182
+ 256: (4, 128),
1183
+ 4096: (32, 32),
1184
+ 512: (8, 64),
1185
+ 8192: (32, 32),
1186
+ },
1187
+ 'q_head-8_kv_head-1_head-256': {
1188
+ 1024: (16, 8),
1189
+ 128: (2, 64),
1190
+ 2048: (32, 64),
1191
+ 256: (2, 32),
1192
+ 4096: (32, 64),
1193
+ 512: (4, 32),
1194
+ 8192: (32, 32),
1195
+ },
1196
+ 'q_head-8_kv_head-2_head-128': {
1197
+ 1024: (16, 16),
1198
+ 128: (2, 32),
1199
+ 2048: (32, 32),
1200
+ 256: (4, 128),
1201
+ 4096: (32, 128),
1202
+ 512: (8, 64),
1203
+ 8192: (32, 64),
1204
+ },
1205
+ 'q_head-8_kv_head-2_head-256': {
1206
+ 1024: (16, 64),
1207
+ 128: (2, 32),
1208
+ 2048: (32, 32),
1209
+ 256: (4, 16),
1210
+ 4096: (32, 128),
1211
+ 512: (8, 16),
1212
+ 8192: (32, 128),
1213
+ },
1214
+ 'q_head-8_kv_head-4_head-128': {
1215
+ 1024: (16, 64),
1216
+ 128: (2, 16),
1217
+ 2048: (32, 64),
1218
+ 256: (4, 64),
1219
+ 4096: (32, 64),
1220
+ 512: (8, 64),
1221
+ 8192: (32, 128),
1222
+ },
1223
+ 'q_head-8_kv_head-4_head-256': {
1224
+ 1024: (16, 64),
1225
+ 128: (2, 32),
1226
+ 2048: (32, 64),
1227
+ 256: (4, 32),
1228
+ 4096: (32, 64),
1229
+ 512: (8, 64),
1230
+ 8192: (32, 128),
1231
+ },
1232
+ }
1233
+ },
1234
+ },
1235
+ 'TPU v5e': {
1236
+ 128: {
1237
+ 'q_bfloat16_kv_bfloat16': {
1238
+ 'q_head-128_kv_head-1_head-128': {
1239
+ 1024: (4, 32),
1240
+ 128: (1, 8),
1241
+ 2048: (16, 8),
1242
+ 256: (2, 8),
1243
+ 4096: (16, 16),
1244
+ 512: (4, 8),
1245
+ 8192: (16, 16),
1246
+ },
1247
+ 'q_head-128_kv_head-1_head-256': {
1248
+ 1024: (8, 16),
1249
+ 128: (1, 8),
1250
+ 2048: (16, 8),
1251
+ 256: (2, 8),
1252
+ 4096: (16, 8),
1253
+ 512: (2, 8),
1254
+ 8192: (16, 8),
1255
+ },
1256
+ 'q_head-128_kv_head-16_head-128': {
1257
+ 1024: (8, 16),
1258
+ 128: (1, 16),
1259
+ 2048: (8, 16),
1260
+ 256: (2, 8),
1261
+ 4096: (8, 16),
1262
+ 512: (2, 16),
1263
+ 8192: (8, 16),
1264
+ },
1265
+ 'q_head-128_kv_head-16_head-256': {
1266
+ 1024: (4, 8),
1267
+ 128: (1, 8),
1268
+ 2048: (4, 8),
1269
+ 256: (2, 8),
1270
+ 4096: (4, 8),
1271
+ 512: (4, 8),
1272
+ 8192: (4, 8),
1273
+ },
1274
+ 'q_head-128_kv_head-2_head-128': {
1275
+ 1024: (8, 8),
1276
+ 128: (1, 8),
1277
+ 2048: (16, 8),
1278
+ 256: (2, 16),
1279
+ 4096: (8, 16),
1280
+ 512: (4, 16),
1281
+ 8192: (16, 16),
1282
+ },
1283
+ 'q_head-128_kv_head-2_head-256': {
1284
+ 1024: (8, 8),
1285
+ 128: (1, 8),
1286
+ 2048: (16, 8),
1287
+ 256: (2, 8),
1288
+ 4096: (8, 16),
1289
+ 512: (4, 8),
1290
+ 8192: (8, 8),
1291
+ },
1292
+ 'q_head-128_kv_head-4_head-128': {
1293
+ 1024: (8, 8),
1294
+ 128: (1, 16),
1295
+ 2048: (8, 8),
1296
+ 256: (2, 8),
1297
+ 4096: (8, 32),
1298
+ 512: (4, 8),
1299
+ 8192: (8, 16),
1300
+ },
1301
+ 'q_head-128_kv_head-4_head-256': {
1302
+ 1024: (4, 8),
1303
+ 128: (1, 8),
1304
+ 2048: (8, 16),
1305
+ 256: (2, 8),
1306
+ 4096: (8, 16),
1307
+ 512: (4, 8),
1308
+ 8192: (8, 16),
1309
+ },
1310
+ 'q_head-128_kv_head-8_head-128': {
1311
+ 1024: (8, 32),
1312
+ 128: (1, 8),
1313
+ 2048: (8, 16),
1314
+ 256: (2, 16),
1315
+ 4096: (8, 16),
1316
+ 512: (4, 16),
1317
+ 8192: (8, 16),
1318
+ },
1319
+ 'q_head-128_kv_head-8_head-256': {
1320
+ 1024: (4, 16),
1321
+ 128: (1, 8),
1322
+ 2048: (8, 16),
1323
+ 256: (2, 8),
1324
+ 4096: (8, 16),
1325
+ 512: (4, 16),
1326
+ 8192: (4, 16),
1327
+ },
1328
+ 'q_head-16_kv_head-1_head-128': {
1329
+ 2048: (8, 64),
1330
+ 512: (4, 64)
1331
+ },
1332
+ 'q_head-16_kv_head-1_head-256': {
1333
+ 128: (1, 32),
1334
+ 256: (2, 8)
1335
+ },
1336
+ 'q_head-16_kv_head-2_head-128': {
1337
+ 128: (1, 128),
1338
+ 256: (2, 8),
1339
+ 512: (2, 32),
1340
+ 8192: (16, 32),
1341
+ },
1342
+ 'q_head-16_kv_head-2_head-256': {
1343
+ 128: (1, 32),
1344
+ 2048: (8, 32),
1345
+ 256: (2, 32),
1346
+ },
1347
+ 'q_head-16_kv_head-4_head-128': {
1348
+ 1024: (8, 32),
1349
+ 128: (1, 64),
1350
+ 256: (2, 16),
1351
+ 512: (4, 64),
1352
+ },
1353
+ 'q_head-16_kv_head-4_head-256': {
1354
+ 1024: (8, 128),
1355
+ 128: (1, 16),
1356
+ 2048: (8, 64),
1357
+ 256: (2, 32),
1358
+ 4096: (8, 32),
1359
+ 512: (4, 32),
1360
+ 8192: (16, 64),
1361
+ },
1362
+ 'q_head-16_kv_head-8_head-128': {
1363
+ 1024: (8, 256),
1364
+ 128: (1, 128),
1365
+ 2048: (8, 128),
1366
+ 256: (2, 16),
1367
+ 4096: (8, 64),
1368
+ 512: (4, 64),
1369
+ 8192: (4, 128),
1370
+ },
1371
+ 'q_head-16_kv_head-8_head-256': {
1372
+ 1024: (8, 128),
1373
+ 128: (1, 16),
1374
+ 2048: (8, 128),
1375
+ 256: (2, 64),
1376
+ 4096: (8, 128),
1377
+ 512: (2, 32),
1378
+ 8192: (8, 128),
1379
+ },
1380
+ 'q_head-2_kv_head-1_head-128': {
1381
+ 1024: (8, 128),
1382
+ 128: (1, 256),
1383
+ 2048: (8, 32),
1384
+ 256: (2, 8),
1385
+ 512: (4, 256),
1386
+ 8192: (16, 32),
1387
+ },
1388
+ 'q_head-2_kv_head-1_head-256': {
1389
+ 1024: (8, 128),
1390
+ 2048: (8, 64),
1391
+ 256: (2, 8),
1392
+ 4096: (8, 128),
1393
+ 512: (4, 32),
1394
+ 8192: (16, 64),
1395
+ },
1396
+ 'q_head-32_kv_head-1_head-128': {
1397
+ 1024: (8, 16),
1398
+ 128: (1, 128),
1399
+ 2048: (8, 32),
1400
+ 256: (2, 16),
1401
+ 4096: (16, 64),
1402
+ 512: (4, 64),
1403
+ 8192: (16, 16),
1404
+ },
1405
+ 'q_head-32_kv_head-1_head-256': {
1406
+ 1024: (8, 16),
1407
+ 128: (1, 16),
1408
+ 2048: (16, 32),
1409
+ 256: (2, 8),
1410
+ 4096: (16, 16),
1411
+ 512: (4, 16),
1412
+ 8192: (16, 16),
1413
+ },
1414
+ 'q_head-32_kv_head-16_head-128': {
1415
+ 1024: (8, 64),
1416
+ 128: (1, 8),
1417
+ 2048: (8, 64),
1418
+ 256: (2, 32),
1419
+ 4096: (8, 64),
1420
+ 512: (4, 64),
1421
+ 8192: (8, 64),
1422
+ },
1423
+ 'q_head-32_kv_head-16_head-256': {
1424
+ 1024: (4, 32),
1425
+ 128: (1, 8),
1426
+ 2048: (4, 32),
1427
+ 256: (2, 32),
1428
+ 4096: (4, 32),
1429
+ 512: (4, 32),
1430
+ 8192: (4, 32),
1431
+ },
1432
+ 'q_head-32_kv_head-2_head-128': {
1433
+ 1024: (4, 8),
1434
+ 128: (1, 32),
1435
+ 2048: (8, 64),
1436
+ 256: (2, 8),
1437
+ 4096: (16, 32),
1438
+ 512: (4, 32),
1439
+ 8192: (16, 16),
1440
+ },
1441
+ 'q_head-32_kv_head-2_head-256': {
1442
+ 1024: (8, 16),
1443
+ 128: (1, 16),
1444
+ 2048: (8, 32),
1445
+ 256: (2, 16),
1446
+ 4096: (8, 32),
1447
+ 512: (4, 8),
1448
+ 8192: (8, 32),
1449
+ },
1450
+ 'q_head-32_kv_head-4_head-128': {
1451
+ 1024: (8, 64),
1452
+ 128: (1, 32),
1453
+ 2048: (8, 64),
1454
+ 256: (2, 16),
1455
+ 4096: (8, 32),
1456
+ 512: (4, 16),
1457
+ 8192: (8, 32),
1458
+ },
1459
+ 'q_head-32_kv_head-4_head-256': {
1460
+ 1024: (8, 32),
1461
+ 128: (1, 16),
1462
+ 2048: (8, 32),
1463
+ 256: (2, 32),
1464
+ 4096: (8, 32),
1465
+ 512: (4, 16),
1466
+ 8192: (8, 32),
1467
+ },
1468
+ 'q_head-32_kv_head-8_head-128': {
1469
+ 1024: (8, 128),
1470
+ 128: (1, 16),
1471
+ 2048: (4, 32),
1472
+ 256: (1, 16),
1473
+ 4096: (16, 32),
1474
+ 512: (4, 64),
1475
+ 8192: (4, 64),
1476
+ },
1477
+ 'q_head-32_kv_head-8_head-256': {
1478
+ 1024: (8, 32),
1479
+ 128: (1, 8),
1480
+ 2048: (4, 64),
1481
+ 256: (2, 16),
1482
+ 4096: (8, 64),
1483
+ 512: (4, 32),
1484
+ 8192: (8, 64),
1485
+ },
1486
+ 'q_head-4_kv_head-1_head-128': {
1487
+ 1024: (8, 32),
1488
+ 2048: (8, 128),
1489
+ 256: (1, 256),
1490
+ 4096: (16, 128),
1491
+ 512: (4, 128),
1492
+ 8192: (16, 16),
1493
+ },
1494
+ 'q_head-4_kv_head-1_head-256': {
1495
+ 1024: (8, 16),
1496
+ 2048: (8, 32),
1497
+ 4096: (16, 32),
1498
+ 8192: (16, 32),
1499
+ },
1500
+ 'q_head-4_kv_head-2_head-128': {
1501
+ 1024: (8, 64),
1502
+ 128: (1, 64),
1503
+ 2048: (8, 128),
1504
+ 256: (1, 256),
1505
+ 4096: (16, 128),
1506
+ 8192: (8, 32),
1507
+ },
1508
+ 'q_head-4_kv_head-2_head-256': {
1509
+ 1024: (8, 32),
1510
+ 128: (1, 8),
1511
+ 4096: (8, 256),
1512
+ 8192: (8, 128),
1513
+ },
1514
+ 'q_head-64_kv_head-1_head-128': {
1515
+ 1024: (4, 32),
1516
+ 128: (1, 16),
1517
+ 2048: (16, 32),
1518
+ 256: (2, 32),
1519
+ 4096: (16, 32),
1520
+ 512: (4, 16),
1521
+ 8192: (16, 32),
1522
+ },
1523
+ 'q_head-64_kv_head-1_head-256': {
1524
+ 1024: (8, 16),
1525
+ 128: (1, 8),
1526
+ 2048: (16, 8),
1527
+ 256: (2, 16),
1528
+ 4096: (16, 16),
1529
+ 512: (4, 16),
1530
+ 8192: (16, 16),
1531
+ },
1532
+ 'q_head-64_kv_head-16_head-128': {
1533
+ 1024: (4, 32),
1534
+ 128: (1, 16),
1535
+ 2048: (8, 32),
1536
+ 256: (2, 32),
1537
+ 4096: (8, 32),
1538
+ 512: (2, 32),
1539
+ 8192: (8, 32),
1540
+ },
1541
+ 'q_head-64_kv_head-16_head-256': {
1542
+ 1024: (4, 16),
1543
+ 128: (1, 16),
1544
+ 2048: (4, 16),
1545
+ 256: (2, 16),
1546
+ 4096: (4, 16),
1547
+ 512: (4, 16),
1548
+ 8192: (4, 16),
1549
+ },
1550
+ 'q_head-64_kv_head-2_head-128': {
1551
+ 1024: (8, 8),
1552
+ 128: (1, 16),
1553
+ 2048: (8, 16),
1554
+ 256: (1, 16),
1555
+ 4096: (8, 16),
1556
+ 512: (4, 16),
1557
+ 8192: (8, 32),
1558
+ },
1559
+ 'q_head-64_kv_head-2_head-256': {
1560
+ 1024: (4, 8),
1561
+ 128: (1, 8),
1562
+ 2048: (16, 16),
1563
+ 256: (2, 8),
1564
+ 4096: (8, 16),
1565
+ 512: (4, 8),
1566
+ 8192: (8, 16),
1567
+ },
1568
+ 'q_head-64_kv_head-4_head-128': {
1569
+ 1024: (8, 32),
1570
+ 128: (1, 8),
1571
+ 2048: (16, 16),
1572
+ 256: (1, 32),
1573
+ 4096: (8, 32),
1574
+ 512: (4, 32),
1575
+ 8192: (16, 32),
1576
+ },
1577
+ 'q_head-64_kv_head-4_head-256': {
1578
+ 1024: (4, 16),
1579
+ 128: (1, 8),
1580
+ 2048: (8, 32),
1581
+ 256: (1, 8),
1582
+ 4096: (8, 32),
1583
+ 512: (4, 16),
1584
+ 8192: (8, 32),
1585
+ },
1586
+ 'q_head-64_kv_head-8_head-128': {
1587
+ 1024: (8, 16),
1588
+ 128: (1, 32),
1589
+ 2048: (4, 32),
1590
+ 256: (2, 64),
1591
+ 4096: (4, 32),
1592
+ 512: (4, 32),
1593
+ 8192: (16, 32),
1594
+ },
1595
+ 'q_head-64_kv_head-8_head-256': {
1596
+ 1024: (8, 32),
1597
+ 128: (1, 8),
1598
+ 2048: (8, 32),
1599
+ 256: (2, 16),
1600
+ 4096: (4, 32),
1601
+ 512: (4, 16),
1602
+ 8192: (8, 32),
1603
+ },
1604
+ 'q_head-8_kv_head-1_head-128': {
1605
+ 2048: (8, 32),
1606
+ 4096: (8, 16),
1607
+ 512: (4, 128),
1608
+ 8192: (16, 32),
1609
+ },
1610
+ 'q_head-8_kv_head-1_head-256': {
1611
+ 128: (1, 8),
1612
+ 2048: (8, 16),
1613
+ 8192: (8, 32),
1614
+ },
1615
+ 'q_head-8_kv_head-2_head-128': {
1616
+ 128: (1, 64),
1617
+ 256: (2, 64),
1618
+ 4096: (16, 32),
1619
+ 512: (4, 64),
1620
+ 8192: (16, 128),
1621
+ },
1622
+ 'q_head-8_kv_head-2_head-256': {
1623
+ 1024: (8, 128),
1624
+ 128: (1, 32),
1625
+ 8192: (8, 128),
1626
+ },
1627
+ 'q_head-8_kv_head-4_head-128': {
1628
+ 128: (1, 16),
1629
+ 256: (2, 32),
1630
+ 4096: (16, 32),
1631
+ 512: (4, 8),
1632
+ },
1633
+ 'q_head-8_kv_head-4_head-256': {
1634
+ 128: (1, 32),
1635
+ 2048: (8, 128),
1636
+ 256: (2, 32),
1637
+ 512: (4, 16),
1638
+ },
1639
+ }
1640
+ },
1641
+ 256: {
1642
+ 'q_bfloat16_kv_bfloat16': {
1643
+ 'q_head-128_kv_head-1_head-128': {
1644
+ 1024: (2, 16),
1645
+ 2048: (4, 8),
1646
+ 256: (1, 8),
1647
+ 4096: (8, 8),
1648
+ 512: (2, 8),
1649
+ 8192: (8, 16),
1650
+ },
1651
+ 'q_head-128_kv_head-1_head-256': {
1652
+ 1024: (4, 8),
1653
+ 2048: (4, 8),
1654
+ 256: (1, 8),
1655
+ 4096: (8, 8),
1656
+ 512: (2, 8),
1657
+ 8192: (8, 8),
1658
+ },
1659
+ 'q_head-128_kv_head-16_head-128': {
1660
+ 1024: (4, 16),
1661
+ 2048: (4, 16),
1662
+ 256: (1, 16),
1663
+ 4096: (4, 16),
1664
+ 512: (2, 16),
1665
+ 8192: (4, 16),
1666
+ },
1667
+ 'q_head-128_kv_head-16_head-256': {
1668
+ 1024: (2, 8),
1669
+ 2048: (2, 8),
1670
+ 256: (1, 8),
1671
+ 4096: (2, 8),
1672
+ 512: (2, 8),
1673
+ 8192: (2, 8),
1674
+ },
1675
+ 'q_head-128_kv_head-2_head-128': {
1676
+ 1024: (4, 8),
1677
+ 2048: (8, 8),
1678
+ 256: (1, 16),
1679
+ 4096: (8, 8),
1680
+ 512: (2, 8),
1681
+ 8192: (8, 16),
1682
+ },
1683
+ 'q_head-128_kv_head-2_head-256': {
1684
+ 1024: (4, 8),
1685
+ 2048: (4, 8),
1686
+ 256: (1, 8),
1687
+ 4096: (8, 8),
1688
+ 512: (1, 8),
1689
+ 8192: (8, 8),
1690
+ },
1691
+ 'q_head-128_kv_head-4_head-128': {
1692
+ 1024: (4, 16),
1693
+ 2048: (4, 16),
1694
+ 256: (1, 32),
1695
+ 4096: (8, 16),
1696
+ 512: (2, 32),
1697
+ 8192: (4, 16),
1698
+ },
1699
+ 'q_head-128_kv_head-4_head-256': {
1700
+ 1024: (2, 8),
1701
+ 2048: (4, 16),
1702
+ 256: (1, 8),
1703
+ 4096: (8, 8),
1704
+ 512: (2, 8),
1705
+ 8192: (4, 16),
1706
+ },
1707
+ 'q_head-128_kv_head-8_head-128': {
1708
+ 1024: (4, 16),
1709
+ 2048: (4, 32),
1710
+ 256: (1, 32),
1711
+ 4096: (4, 32),
1712
+ 512: (2, 16),
1713
+ 8192: (2, 32),
1714
+ },
1715
+ 'q_head-128_kv_head-8_head-256': {
1716
+ 1024: (4, 16),
1717
+ 2048: (2, 16),
1718
+ 256: (1, 8),
1719
+ 4096: (2, 16),
1720
+ 512: (2, 16),
1721
+ 8192: (2, 16),
1722
+ },
1723
+ 'q_head-16_kv_head-1_head-128': {
1724
+ 1024: (2, 32),
1725
+ 2048: (8, 16),
1726
+ 256: (1, 32),
1727
+ 4096: (8, 32),
1728
+ 512: (1, 64),
1729
+ 8192: (8, 32),
1730
+ },
1731
+ 'q_head-16_kv_head-1_head-256': {
1732
+ 1024: (4, 32),
1733
+ 2048: (4, 16),
1734
+ 256: (1, 32),
1735
+ 4096: (8, 16),
1736
+ 512: (2, 8),
1737
+ 8192: (8, 16),
1738
+ },
1739
+ 'q_head-16_kv_head-2_head-128': {
1740
+ 1024: (4, 16),
1741
+ 2048: (4, 32),
1742
+ 256: (1, 8),
1743
+ 4096: (4, 64),
1744
+ 512: (2, 16),
1745
+ 8192: (8, 128),
1746
+ },
1747
+ 'q_head-16_kv_head-2_head-256': {
1748
+ 1024: (4, 32),
1749
+ 2048: (4, 16),
1750
+ 256: (1, 64),
1751
+ 4096: (8, 32),
1752
+ 512: (2, 16),
1753
+ 8192: (4, 32),
1754
+ },
1755
+ 'q_head-16_kv_head-4_head-128': {
1756
+ 1024: (2, 64),
1757
+ 2048: (2, 64),
1758
+ 256: (1, 64),
1759
+ 4096: (4, 32),
1760
+ 512: (2, 128),
1761
+ 8192: (8, 32),
1762
+ },
1763
+ 'q_head-16_kv_head-4_head-256': {
1764
+ 1024: (2, 64),
1765
+ 2048: (8, 32),
1766
+ 256: (1, 32),
1767
+ 4096: (4, 128),
1768
+ 512: (2, 16),
1769
+ 8192: (4, 32),
1770
+ },
1771
+ 'q_head-16_kv_head-8_head-128': {
1772
+ 1024: (4, 64),
1773
+ 2048: (4, 32),
1774
+ 256: (1, 8),
1775
+ 4096: (2, 128),
1776
+ 512: (2, 64),
1777
+ 8192: (8, 128),
1778
+ },
1779
+ 'q_head-16_kv_head-8_head-256': {
1780
+ 1024: (4, 64),
1781
+ 2048: (4, 128),
1782
+ 256: (1, 16),
1783
+ 4096: (4, 128),
1784
+ 512: (1, 32),
1785
+ 8192: (4, 128),
1786
+ },
1787
+ 'q_head-2_kv_head-1_head-128': {
1788
+ 1024: (4, 64),
1789
+ 2048: (8, 128),
1790
+ 256: (1, 64),
1791
+ 4096: (8, 256),
1792
+ 512: (2, 64),
1793
+ 8192: (8, 256),
1794
+ },
1795
+ 'q_head-2_kv_head-1_head-256': {
1796
+ 1024: (4, 128),
1797
+ 2048: (8, 32),
1798
+ 256: (1, 32),
1799
+ 4096: (8, 256),
1800
+ 512: (2, 32),
1801
+ 8192: (4, 32),
1802
+ },
1803
+ 'q_head-32_kv_head-1_head-128': {
1804
+ 1024: (2, 32),
1805
+ 2048: (4, 16),
1806
+ 256: (1, 64),
1807
+ 4096: (8, 16),
1808
+ 512: (2, 32),
1809
+ 8192: (8, 64),
1810
+ },
1811
+ 'q_head-32_kv_head-1_head-256': {
1812
+ 1024: (4, 8),
1813
+ 2048: (8, 16),
1814
+ 256: (1, 16),
1815
+ 4096: (8, 16),
1816
+ 512: (2, 16),
1817
+ 8192: (8, 16),
1818
+ },
1819
+ 'q_head-32_kv_head-16_head-128': {
1820
+ 1024: (4, 64),
1821
+ 2048: (4, 64),
1822
+ 256: (1, 64),
1823
+ 4096: (4, 64),
1824
+ 512: (2, 32),
1825
+ 8192: (4, 64),
1826
+ },
1827
+ 'q_head-32_kv_head-16_head-256': {
1828
+ 1024: (2, 32),
1829
+ 2048: (2, 32),
1830
+ 256: (1, 32),
1831
+ 4096: (2, 32),
1832
+ 512: (2, 32),
1833
+ 8192: (2, 32),
1834
+ },
1835
+ 'q_head-32_kv_head-2_head-128': {
1836
+ 1024: (4, 16),
1837
+ 2048: (8, 16),
1838
+ 256: (1, 8),
1839
+ 4096: (4, 32),
1840
+ 512: (2, 16),
1841
+ 8192: (8, 32),
1842
+ },
1843
+ 'q_head-32_kv_head-2_head-256': {
1844
+ 1024: (2, 16),
1845
+ 2048: (8, 16),
1846
+ 256: (1, 32),
1847
+ 4096: (8, 16),
1848
+ 512: (2, 16),
1849
+ 8192: (8, 32),
1850
+ },
1851
+ 'q_head-32_kv_head-4_head-128': {
1852
+ 1024: (4, 64),
1853
+ 2048: (8, 32),
1854
+ 256: (1, 16),
1855
+ 4096: (4, 128),
1856
+ 512: (2, 16),
1857
+ 8192: (4, 128),
1858
+ },
1859
+ 'q_head-32_kv_head-4_head-256': {
1860
+ 1024: (4, 16),
1861
+ 2048: (2, 32),
1862
+ 256: (1, 32),
1863
+ 4096: (8, 32),
1864
+ 512: (2, 32),
1865
+ 8192: (4, 32),
1866
+ },
1867
+ 'q_head-32_kv_head-8_head-128': {
1868
+ 1024: (4, 128),
1869
+ 2048: (4, 128),
1870
+ 256: (1, 32),
1871
+ 4096: (4, 128),
1872
+ 512: (2, 16),
1873
+ 8192: (2, 64),
1874
+ },
1875
+ 'q_head-32_kv_head-8_head-256': {
1876
+ 1024: (2, 64),
1877
+ 2048: (2, 32),
1878
+ 256: (1, 16),
1879
+ 4096: (4, 64),
1880
+ 512: (1, 32),
1881
+ 8192: (4, 64),
1882
+ },
1883
+ 'q_head-4_kv_head-1_head-128': {
1884
+ 1024: (4, 16),
1885
+ 2048: (8, 16),
1886
+ 256: (1, 128),
1887
+ 4096: (4, 128),
1888
+ 512: (2, 128),
1889
+ 8192: (8, 32),
1890
+ },
1891
+ 'q_head-4_kv_head-1_head-256': {
1892
+ 1024: (4, 16),
1893
+ 2048: (4, 32),
1894
+ 256: (1, 64),
1895
+ 4096: (8, 64),
1896
+ 512: (2, 64),
1897
+ 8192: (4, 64),
1898
+ },
1899
+ 'q_head-4_kv_head-2_head-128': {
1900
+ 1024: (4, 256),
1901
+ 2048: (8, 128),
1902
+ 256: (1, 64),
1903
+ 4096: (8, 256),
1904
+ 512: (1, 64),
1905
+ 8192: (8, 128),
1906
+ },
1907
+ 'q_head-4_kv_head-2_head-256': {
1908
+ 1024: (4, 32),
1909
+ 2048: (4, 32),
1910
+ 256: (1, 8),
1911
+ 4096: (8, 64),
1912
+ 512: (2, 64),
1913
+ 8192: (4, 64),
1914
+ },
1915
+ 'q_head-64_kv_head-1_head-128': {
1916
+ 1024: (2, 8),
1917
+ 2048: (8, 16),
1918
+ 256: (1, 32),
1919
+ 4096: (8, 16),
1920
+ 512: (2, 16),
1921
+ 8192: (8, 8),
1922
+ },
1923
+ 'q_head-64_kv_head-1_head-256': {
1924
+ 1024: (4, 8),
1925
+ 2048: (8, 8),
1926
+ 256: (1, 8),
1927
+ 4096: (4, 8),
1928
+ 512: (1, 16),
1929
+ 8192: (8, 16),
1930
+ },
1931
+ 'q_head-64_kv_head-16_head-128': {
1932
+ 1024: (2, 32),
1933
+ 2048: (4, 32),
1934
+ 256: (1, 16),
1935
+ 4096: (2, 32),
1936
+ 512: (2, 32),
1937
+ 8192: (4, 32),
1938
+ },
1939
+ 'q_head-64_kv_head-16_head-256': {
1940
+ 1024: (2, 16),
1941
+ 2048: (2, 16),
1942
+ 256: (1, 16),
1943
+ 4096: (2, 16),
1944
+ 512: (2, 16),
1945
+ 8192: (2, 16),
1946
+ },
1947
+ 'q_head-64_kv_head-2_head-128': {
1948
+ 1024: (4, 16),
1949
+ 2048: (8, 16),
1950
+ 256: (1, 8),
1951
+ 4096: (8, 16),
1952
+ 512: (2, 32),
1953
+ 8192: (8, 16),
1954
+ },
1955
+ 'q_head-64_kv_head-2_head-256': {
1956
+ 1024: (2, 8),
1957
+ 2048: (4, 16),
1958
+ 256: (1, 16),
1959
+ 4096: (4, 16),
1960
+ 512: (2, 8),
1961
+ 8192: (4, 32),
1962
+ },
1963
+ 'q_head-64_kv_head-4_head-128': {
1964
+ 1024: (4, 16),
1965
+ 2048: (8, 32),
1966
+ 256: (1, 32),
1967
+ 4096: (8, 32),
1968
+ 512: (2, 64),
1969
+ 8192: (4, 32),
1970
+ },
1971
+ 'q_head-64_kv_head-4_head-256': {
1972
+ 1024: (4, 32),
1973
+ 2048: (8, 16),
1974
+ 256: (1, 16),
1975
+ 4096: (4, 16),
1976
+ 512: (2, 16),
1977
+ 8192: (4, 32),
1978
+ },
1979
+ 'q_head-64_kv_head-8_head-128': {
1980
+ 1024: (4, 16),
1981
+ 2048: (2, 32),
1982
+ 256: (1, 8),
1983
+ 4096: (8, 32),
1984
+ 512: (2, 64),
1985
+ 8192: (4, 32),
1986
+ },
1987
+ 'q_head-64_kv_head-8_head-256': {
1988
+ 1024: (4, 32),
1989
+ 2048: (4, 32),
1990
+ 256: (1, 8),
1991
+ 4096: (4, 32),
1992
+ 512: (2, 16),
1993
+ 8192: (4, 32),
1994
+ },
1995
+ 'q_head-8_kv_head-1_head-128': {
1996
+ 1024: (4, 8),
1997
+ 2048: (8, 64),
1998
+ 256: (1, 32),
1999
+ 4096: (8, 64),
2000
+ 512: (2, 32),
2001
+ 8192: (8, 32),
2002
+ },
2003
+ 'q_head-8_kv_head-1_head-256': {
2004
+ 1024: (2, 16),
2005
+ 2048: (8, 8),
2006
+ 256: (1, 64),
2007
+ 4096: (8, 64),
2008
+ 512: (2, 16),
2009
+ 8192: (8, 64),
2010
+ },
2011
+ 'q_head-8_kv_head-2_head-128': {
2012
+ 1024: (4, 64),
2013
+ 2048: (8, 16),
2014
+ 256: (1, 16),
2015
+ 4096: (8, 32),
2016
+ 512: (2, 128),
2017
+ 8192: (8, 32),
2018
+ },
2019
+ 'q_head-8_kv_head-2_head-256': {
2020
+ 1024: (2, 32),
2021
+ 2048: (2, 32),
2022
+ 256: (1, 32),
2023
+ 4096: (4, 64),
2024
+ 512: (2, 16),
2025
+ 8192: (4, 64),
2026
+ },
2027
+ 'q_head-8_kv_head-4_head-128': {
2028
+ 1024: (4, 256),
2029
+ 2048: (4, 32),
2030
+ 256: (1, 64),
2031
+ 4096: (8, 64),
2032
+ 512: (2, 64),
2033
+ 8192: (4, 64),
2034
+ },
2035
+ 'q_head-8_kv_head-4_head-256': {
2036
+ 1024: (4, 64),
2037
+ 2048: (4, 64),
2038
+ 256: (1, 64),
2039
+ 4096: (4, 128),
2040
+ 512: (2, 64),
2041
+ 8192: (4, 128),
2042
+ },
2043
+ }
2044
+ },
2045
+ 64: {
2046
+ 'q_bfloat16_kv_bfloat16': {
2047
+ 'q_head-128_kv_head-1_head-128': {
2048
+ 1024: (8, 16),
2049
+ 128: (2, 16),
2050
+ 2048: (16, 16),
2051
+ 256: (4, 8),
2052
+ 512: (4, 16),
2053
+ 64: (1, 8),
2054
+ },
2055
+ 'q_head-128_kv_head-1_head-256': {
2056
+ 1024: (16, 8),
2057
+ 2048: (32, 8),
2058
+ 256: (2, 8),
2059
+ 512: (8, 8),
2060
+ 64: (1, 8),
2061
+ 8192: (32, 8),
2062
+ },
2063
+ 'q_head-128_kv_head-16_head-128': {
2064
+ 1024: (16, 16),
2065
+ 128: (2, 16),
2066
+ 256: (2, 8),
2067
+ 512: (8, 16),
2068
+ 64: (1, 8),
2069
+ },
2070
+ 'q_head-128_kv_head-16_head-256': {
2071
+ 128: (2, 8),
2072
+ 256: (4, 8),
2073
+ 4096: (8, 8),
2074
+ 512: (8, 8),
2075
+ 64: (1, 8),
2076
+ },
2077
+ 'q_head-128_kv_head-2_head-128': {
2078
+ 1024: (16, 16),
2079
+ 2048: (16, 8),
2080
+ 256: (4, 8),
2081
+ 4096: (16, 16),
2082
+ 512: (8, 16),
2083
+ 64: (1, 8),
2084
+ 8192: (32, 16),
2085
+ },
2086
+ 'q_head-128_kv_head-2_head-256': {
2087
+ 1024: (16, 8),
2088
+ 2048: (16, 8),
2089
+ 256: (4, 8),
2090
+ 4096: (32, 8),
2091
+ },
2092
+ 'q_head-128_kv_head-4_head-128': {
2093
+ 1024: (16, 8),
2094
+ 128: (1, 8),
2095
+ 2048: (16, 8),
2096
+ 4096: (16, 16),
2097
+ 512: (8, 32),
2098
+ 64: (1, 32),
2099
+ 8192: (16, 32),
2100
+ },
2101
+ 'q_head-128_kv_head-4_head-256': {
2102
+ 1024: (8, 8),
2103
+ 128: (2, 8),
2104
+ 2048: (16, 8),
2105
+ 256: (4, 8),
2106
+ 4096: (32, 32),
2107
+ 64: (1, 8),
2108
+ 8192: (32, 32),
2109
+ },
2110
+ 'q_head-128_kv_head-8_head-128': {
2111
+ 1024: (8, 16),
2112
+ 4096: (8, 16),
2113
+ 64: (1, 8),
2114
+ 8192: (8, 32),
2115
+ },
2116
+ 'q_head-128_kv_head-8_head-256': {
2117
+ 128: (2, 8),
2118
+ 256: (4, 8),
2119
+ 4096: (16, 16),
2120
+ 64: (1, 8),
2121
+ 8192: (8, 16),
2122
+ },
2123
+ 'q_head-16_kv_head-1_head-128': {
2124
+ 1024: (16, 8),
2125
+ 128: (2, 16),
2126
+ 2048: (16, 64),
2127
+ 256: (4, 8),
2128
+ 4096: (32, 64),
2129
+ 512: (8, 16),
2130
+ 64: (1, 128),
2131
+ 8192: (32, 128),
2132
+ },
2133
+ 'q_head-16_kv_head-1_head-256': {
2134
+ 1024: (8, 16),
2135
+ 128: (2, 32),
2136
+ 2048: (32, 8),
2137
+ 256: (4, 64),
2138
+ 4096: (32, 16),
2139
+ 512: (8, 8),
2140
+ 64: (1, 16),
2141
+ 8192: (32, 16),
2142
+ },
2143
+ 'q_head-16_kv_head-2_head-128': {
2144
+ 1024: (16, 16),
2145
+ 128: (2, 64),
2146
+ 2048: (16, 16),
2147
+ 256: (4, 128),
2148
+ 4096: (32, 32),
2149
+ 512: (8, 64),
2150
+ 64: (1, 16),
2151
+ 8192: (32, 64),
2152
+ },
2153
+ 'q_head-16_kv_head-2_head-256': {
2154
+ 1024: (16, 16),
2155
+ 128: (2, 8),
2156
+ 2048: (16, 32),
2157
+ 256: (4, 8),
2158
+ 4096: (8, 32),
2159
+ 512: (8, 16),
2160
+ 64: (1, 8),
2161
+ 8192: (32, 32),
2162
+ },
2163
+ 'q_head-16_kv_head-4_head-128': {
2164
+ 1024: (8, 64),
2165
+ 128: (2, 32),
2166
+ 2048: (16, 32),
2167
+ 256: (4, 128),
2168
+ 4096: (16, 32),
2169
+ 512: (4, 128),
2170
+ 64: (1, 16),
2171
+ 8192: (16, 128),
2172
+ },
2173
+ 'q_head-16_kv_head-4_head-256': {
2174
+ 1024: (16, 32),
2175
+ 128: (2, 32),
2176
+ 2048: (16, 128),
2177
+ 256: (4, 32),
2178
+ 4096: (16, 128),
2179
+ 512: (4, 32),
2180
+ 64: (1, 8),
2181
+ 8192: (16, 32),
2182
+ },
2183
+ 'q_head-16_kv_head-8_head-128': {
2184
+ 1024: (8, 64),
2185
+ 128: (2, 32),
2186
+ 2048: (8, 64),
2187
+ 256: (4, 64),
2188
+ 4096: (32, 64),
2189
+ 512: (8, 8),
2190
+ 64: (1, 16),
2191
+ 8192: (8, 128),
2192
+ },
2193
+ 'q_head-16_kv_head-8_head-256': {
2194
+ 1024: (8, 128),
2195
+ 128: (2, 8),
2196
+ 2048: (8, 64),
2197
+ 256: (4, 32),
2198
+ 4096: (8, 128),
2199
+ 512: (8, 64),
2200
+ 64: (1, 8),
2201
+ 8192: (8, 128),
2202
+ },
2203
+ 'q_head-2_kv_head-1_head-128': {
2204
+ 1024: (16, 256),
2205
+ 128: (1, 8),
2206
+ 2048: (32, 32),
2207
+ 256: (4, 16),
2208
+ 4096: (32, 64),
2209
+ 512: (8, 256),
2210
+ 64: (1, 256),
2211
+ 8192: (32, 128),
2212
+ },
2213
+ 'q_head-2_kv_head-1_head-256': {
2214
+ 1024: (8, 64),
2215
+ 2048: (16, 64),
2216
+ 256: (2, 32),
2217
+ 4096: (32, 128),
2218
+ 512: (8, 32),
2219
+ 8192: (32, 64),
2220
+ },
2221
+ 'q_head-32_kv_head-1_head-128': {
2222
+ 1024: (16, 16),
2223
+ 128: (2, 16),
2224
+ 2048: (16, 16),
2225
+ 256: (4, 8),
2226
+ 4096: (32, 16),
2227
+ 512: (8, 16),
2228
+ 64: (1, 32),
2229
+ 8192: (32, 32),
2230
+ },
2231
+ 'q_head-32_kv_head-1_head-256': {
2232
+ 1024: (8, 16),
2233
+ 128: (2, 16),
2234
+ 2048: (16, 8),
2235
+ 256: (4, 16),
2236
+ 4096: (32, 32),
2237
+ 512: (8, 16),
2238
+ 64: (1, 16),
2239
+ 8192: (32, 16),
2240
+ },
2241
+ 'q_head-32_kv_head-16_head-128': {
2242
+ 1024: (16, 64),
2243
+ 128: (2, 64),
2244
+ 2048: (16, 64),
2245
+ 256: (2, 32),
2246
+ 4096: (16, 64),
2247
+ 512: (8, 32),
2248
+ 64: (1, 8),
2249
+ 8192: (16, 64),
2250
+ },
2251
+ 'q_head-32_kv_head-16_head-256': {
2252
+ 1024: (8, 32),
2253
+ 128: (2, 8),
2254
+ 2048: (8, 32),
2255
+ 256: (4, 8),
2256
+ 4096: (8, 32),
2257
+ 512: (8, 32),
2258
+ 64: (1, 16),
2259
+ 8192: (4, 32),
2260
+ },
2261
+ 'q_head-32_kv_head-2_head-128': {
2262
+ 1024: (16, 16),
2263
+ 128: (2, 32),
2264
+ 2048: (16, 16),
2265
+ 256: (4, 8),
2266
+ 4096: (32, 64),
2267
+ 512: (8, 32),
2268
+ 64: (1, 8),
2269
+ 8192: (32, 64),
2270
+ },
2271
+ 'q_head-32_kv_head-2_head-256': {
2272
+ 1024: (16, 32),
2273
+ 128: (2, 8),
2274
+ 2048: (32, 32),
2275
+ 256: (4, 8),
2276
+ 4096: (16, 32),
2277
+ 512: (8, 32),
2278
+ 64: (1, 8),
2279
+ 8192: (32, 32),
2280
+ },
2281
+ 'q_head-32_kv_head-4_head-128': {
2282
+ 1024: (8, 32),
2283
+ 128: (1, 64),
2284
+ 2048: (32, 16),
2285
+ 256: (4, 32),
2286
+ 4096: (16, 16),
2287
+ 512: (8, 16),
2288
+ 64: (1, 8),
2289
+ 8192: (16, 32),
2290
+ },
2291
+ 'q_head-32_kv_head-4_head-256': {
2292
+ 1024: (8, 32),
2293
+ 128: (2, 16),
2294
+ 2048: (16, 32),
2295
+ 256: (4, 16),
2296
+ 4096: (16, 32),
2297
+ 512: (4, 16),
2298
+ 64: (1, 16),
2299
+ 8192: (16, 32),
2300
+ },
2301
+ 'q_head-32_kv_head-8_head-128': {
2302
+ 1024: (16, 32),
2303
+ 128: (2, 16),
2304
+ 2048: (16, 32),
2305
+ 256: (2, 16),
2306
+ 4096: (32, 32),
2307
+ 512: (8, 32),
2308
+ 64: (1, 16),
2309
+ 8192: (32, 32),
2310
+ },
2311
+ 'q_head-32_kv_head-8_head-256': {
2312
+ 1024: (8, 32),
2313
+ 128: (2, 16),
2314
+ 2048: (8, 64),
2315
+ 256: (4, 16),
2316
+ 4096: (16, 64),
2317
+ 512: (8, 32),
2318
+ 64: (1, 16),
2319
+ 8192: (8, 64),
2320
+ },
2321
+ 'q_head-4_kv_head-1_head-128': {
2322
+ 1024: (16, 32),
2323
+ 128: (2, 16),
2324
+ 2048: (32, 128),
2325
+ 256: (4, 8),
2326
+ 4096: (32, 16),
2327
+ 512: (4, 32),
2328
+ 64: (1, 32),
2329
+ 8192: (32, 128),
2330
+ },
2331
+ 'q_head-4_kv_head-1_head-256': {
2332
+ 1024: (16, 128),
2333
+ 128: (1, 32),
2334
+ 2048: (32, 32),
2335
+ 256: (4, 32),
2336
+ 4096: (32, 64),
2337
+ 512: (8, 64),
2338
+ 64: (1, 128),
2339
+ 8192: (32, 64),
2340
+ },
2341
+ 'q_head-4_kv_head-2_head-128': {
2342
+ 1024: (16, 256),
2343
+ 128: (2, 256),
2344
+ 2048: (32, 32),
2345
+ 256: (4, 8),
2346
+ 4096: (32, 64),
2347
+ 512: (8, 32),
2348
+ 64: (1, 32),
2349
+ 8192: (32, 64),
2350
+ },
2351
+ 'q_head-4_kv_head-2_head-256': {
2352
+ 1024: (8, 64),
2353
+ 128: (2, 32),
2354
+ 2048: (32, 128),
2355
+ 256: (4, 8),
2356
+ 4096: (32, 128),
2357
+ 512: (8, 16),
2358
+ 64: (1, 16),
2359
+ 8192: (16, 128),
2360
+ },
2361
+ 'q_head-64_kv_head-1_head-128': {
2362
+ 1024: (16, 16),
2363
+ 128: (2, 16),
2364
+ 2048: (32, 16),
2365
+ 256: (4, 8),
2366
+ 4096: (32, 16),
2367
+ 512: (8, 8),
2368
+ 64: (1, 16),
2369
+ },
2370
+ 'q_head-64_kv_head-1_head-256': {
2371
+ 1024: (16, 16),
2372
+ 128: (2, 16),
2373
+ 2048: (32, 8),
2374
+ 256: (2, 8),
2375
+ 4096: (32, 8),
2376
+ 512: (8, 8),
2377
+ 64: (1, 8),
2378
+ },
2379
+ 'q_head-64_kv_head-16_head-128': {
2380
+ 1024: (16, 32),
2381
+ 128: (2, 16),
2382
+ 256: (4, 16),
2383
+ 4096: (8, 32),
2384
+ 512: (8, 16),
2385
+ 64: (1, 16),
2386
+ 8192: (16, 32),
2387
+ },
2388
+ 'q_head-64_kv_head-16_head-256': {
2389
+ 1024: (4, 16),
2390
+ 128: (2, 16),
2391
+ 2048: (8, 16),
2392
+ 256: (4, 16),
2393
+ 4096: (8, 16),
2394
+ 512: (8, 16),
2395
+ 64: (1, 16),
2396
+ 8192: (8, 16),
2397
+ },
2398
+ 'q_head-64_kv_head-2_head-128': {
2399
+ 1024: (16, 16),
2400
+ 128: (2, 32),
2401
+ 2048: (32, 32),
2402
+ 256: (4, 16),
2403
+ 4096: (32, 16),
2404
+ 512: (8, 64),
2405
+ 64: (1, 32),
2406
+ },
2407
+ 'q_head-64_kv_head-2_head-256': {
2408
+ 1024: (16, 16),
2409
+ 128: (2, 16),
2410
+ 2048: (32, 16),
2411
+ 256: (4, 8),
2412
+ 4096: (16, 16),
2413
+ 512: (8, 8),
2414
+ 64: (1, 8),
2415
+ 8192: (32, 16),
2416
+ },
2417
+ 'q_head-64_kv_head-4_head-128': {
2418
+ 1024: (8, 16),
2419
+ 128: (1, 8),
2420
+ 2048: (16, 32),
2421
+ 256: (4, 8),
2422
+ 4096: (16, 16),
2423
+ 512: (8, 64),
2424
+ 64: (1, 8),
2425
+ 8192: (16, 32),
2426
+ },
2427
+ 'q_head-64_kv_head-4_head-256': {
2428
+ 1024: (16, 16),
2429
+ 2048: (16, 32),
2430
+ 256: (4, 8),
2431
+ 4096: (16, 16),
2432
+ 64: (1, 8),
2433
+ 8192: (16, 32),
2434
+ },
2435
+ 'q_head-64_kv_head-8_head-128': {
2436
+ 1024: (16, 64),
2437
+ 128: (2, 16),
2438
+ 2048: (16, 32),
2439
+ 256: (4, 16),
2440
+ 4096: (16, 64),
2441
+ 64: (1, 32),
2442
+ 8192: (16, 32),
2443
+ },
2444
+ 'q_head-64_kv_head-8_head-256': {
2445
+ 1024: (8, 32),
2446
+ 128: (2, 8),
2447
+ 2048: (16, 32),
2448
+ 256: (4, 16),
2449
+ 4096: (16, 32),
2450
+ 512: (8, 32),
2451
+ 64: (1, 8),
2452
+ 8192: (16, 32),
2453
+ },
2454
+ 'q_head-8_kv_head-1_head-128': {
2455
+ 1024: (16, 64),
2456
+ 128: (2, 64),
2457
+ 2048: (32, 32),
2458
+ 256: (4, 128),
2459
+ 4096: (32, 32),
2460
+ 512: (8, 8),
2461
+ 64: (1, 128),
2462
+ 8192: (32, 32),
2463
+ },
2464
+ 'q_head-8_kv_head-1_head-256': {
2465
+ 1024: (16, 64),
2466
+ 128: (2, 32),
2467
+ 2048: (32, 32),
2468
+ 256: (4, 16),
2469
+ 4096: (32, 64),
2470
+ 512: (8, 8),
2471
+ 64: (1, 32),
2472
+ 8192: (32, 32),
2473
+ },
2474
+ 'q_head-8_kv_head-2_head-128': {
2475
+ 1024: (16, 64),
2476
+ 128: (2, 64),
2477
+ 2048: (32, 32),
2478
+ 256: (4, 128),
2479
+ 4096: (32, 32),
2480
+ 512: (8, 128),
2481
+ 64: (1, 16),
2482
+ 8192: (32, 32),
2483
+ },
2484
+ 'q_head-8_kv_head-2_head-256': {
2485
+ 1024: (16, 128),
2486
+ 128: (2, 64),
2487
+ 2048: (32, 32),
2488
+ 256: (4, 8),
2489
+ 4096: (16, 32),
2490
+ 512: (8, 64),
2491
+ 64: (1, 16),
2492
+ 8192: (32, 128),
2493
+ },
2494
+ 'q_head-8_kv_head-4_head-128': {
2495
+ 1024: (16, 32),
2496
+ 128: (2, 32),
2497
+ 2048: (32, 64),
2498
+ 256: (4, 32),
2499
+ 4096: (16, 64),
2500
+ 512: (8, 64),
2501
+ 64: (1, 16),
2502
+ 8192: (16, 64),
2503
+ },
2504
+ 'q_head-8_kv_head-4_head-256': {
2505
+ 1024: (8, 32),
2506
+ 128: (2, 32),
2507
+ 2048: (8, 128),
2508
+ 256: (4, 64),
2509
+ 4096: (8, 128),
2510
+ 512: (8, 128),
2511
+ 64: (1, 64),
2512
+ 8192: (8, 128),
2513
+ },
2514
+ }
2515
+ },
2516
+ },
2517
+ 'TPU v7': {
2518
+ 256: {
2519
+ 'q_bfloat16_kv_float8_e4m3fn': {
2520
+ 'q_head-4_kv_head-2_head-128': {
2521
+ 8192: (8, 128),
2522
+ 256: (1, 64),
2523
+ 512: (2, 128),
2524
+ 1024: (4, 256),
2525
+ 2048: (8, 64),
2526
+ 4096: (16, 256),
2527
+ },
2528
+ 'q_head-8_kv_head-2_head-128': {
2529
+ 8192: (16, 32),
2530
+ 256: (1, 64),
2531
+ 512: (2, 32),
2532
+ 1024: (4, 64),
2533
+ 2048: (8, 64),
2534
+ 4096: (16, 128),
2535
+ },
2536
+ 'q_head-32_kv_head-16_head-128': {
2537
+ 8192: (8, 32),
2538
+ 256: (1, 32),
2539
+ 512: (2, 32),
2540
+ 1024: (4, 16),
2541
+ 2048: (8, 32),
2542
+ 4096: (16, 16),
2543
+ },
2544
+ 'q_head-32_kv_head-8_head-128': {
2545
+ 8192: (8, 32),
2546
+ 256: (1, 16),
2547
+ 512: (2, 16),
2548
+ 1024: (4, 32),
2549
+ 2048: (8, 64),
2550
+ 4096: (8, 32),
2551
+ },
2552
+ 'q_head-64_kv_head-4_head-128': {
2553
+ 8192: (8, 32),
2554
+ 256: (1, 32),
2555
+ 512: (2, 16),
2556
+ 1024: (4, 64),
2557
+ 2048: (8, 8),
2558
+ 4096: (8, 32),
2559
+ },
2560
+ 'q_head-128_kv_head-2_head-128': {
2561
+ 8192: (4, 16),
2562
+ 256: (1, 16),
2563
+ 512: (2, 16),
2564
+ 1024: (4, 8),
2565
+ 2048: (4, 8),
2566
+ 4096: (8, 16),
2567
+ },
2568
+ 'q_head-64_kv_head-8_head-128': {
2569
+ 256: (1, 16),
2570
+ 512: (2, 32),
2571
+ 1024: (4, 32),
2572
+ 2048: (8, 16),
2573
+ 4096: (8, 32),
2574
+ 8192: (8, 32),
2575
+ },
2576
+ 'q_head-8_kv_head-4_head-128': {
2577
+ 256: (1, 128),
2578
+ 512: (2, 128),
2579
+ 1024: (4, 16),
2580
+ 2048: (8, 16),
2581
+ 4096: (16, 32),
2582
+ 8192: (16, 128),
2583
+ },
2584
+ 'q_head-16_kv_head-2_head-128': {
2585
+ 256: (1, 128),
2586
+ 512: (2, 16),
2587
+ 1024: (4, 8),
2588
+ 2048: (8, 16),
2589
+ 4096: (16, 32),
2590
+ 8192: (16, 32),
2591
+ },
2592
+ 'q_head-128_kv_head-4_head-128': {
2593
+ 256: (1, 16),
2594
+ 512: (2, 8),
2595
+ 1024: (4, 16),
2596
+ 2048: (8, 32),
2597
+ 4096: (8, 16),
2598
+ 8192: (8, 16),
2599
+ },
2600
+ 'q_head-64_kv_head-32_head-128': {
2601
+ 256: (1, 8),
2602
+ 512: (2, 8),
2603
+ 1024: (4, 8),
2604
+ 2048: (4, 8),
2605
+ 4096: (8, 8),
2606
+ 8192: (8, 8),
2607
+ },
2608
+ 'q_head-64_kv_head-16_head-128': {
2609
+ 256: (1, 16),
2610
+ 512: (2, 16),
2611
+ 1024: (4, 16),
2612
+ 2048: (8, 16),
2613
+ 4096: (8, 16),
2614
+ 8192: (8, 16),
2615
+ },
2616
+ 'q_head-128_kv_head-8_head-128': {
2617
+ 256: (1, 8),
2618
+ 512: (2, 8),
2619
+ 1024: (4, 16),
2620
+ 2048: (8, 16),
2621
+ 4096: (8, 16),
2622
+ 8192: (8, 16),
2623
+ },
2624
+ 'q_head-2_kv_head-2_head-128': {
2625
+ 256: (1, 256),
2626
+ 512: (2, 128),
2627
+ 1024: (4, 16),
2628
+ 2048: (8, 128),
2629
+ 4096: (16, 32),
2630
+ 8192: (16, 32),
2631
+ },
2632
+ 'q_head-32_kv_head-2_head-128': {
2633
+ 256: (1, 32),
2634
+ 512: (2, 8),
2635
+ 1024: (4, 8),
2636
+ 2048: (8, 32),
2637
+ 4096: (16, 16),
2638
+ 8192: (16, 16),
2639
+ },
2640
+ 'q_head-16_kv_head-4_head-128': {
2641
+ 256: (1, 64),
2642
+ 512: (2, 128),
2643
+ 1024: (4, 32),
2644
+ 2048: (8, 16),
2645
+ 4096: (8, 32),
2646
+ 8192: (16, 32),
2647
+ },
2648
+ 'q_head-64_kv_head-2_head-128': {
2649
+ 256: (1, 16),
2650
+ 512: (2, 16),
2651
+ 1024: (4, 64),
2652
+ 2048: (8, 8),
2653
+ 4096: (16, 16),
2654
+ 8192: (8, 16),
2655
+ },
2656
+ 'q_head-128_kv_head-16_head-128': {
2657
+ 256: (1, 8),
2658
+ 512: (2, 8),
2659
+ 1024: (4, 8),
2660
+ 2048: (8, 8),
2661
+ 4096: (8, 8),
2662
+ 8192: (8, 8),
2663
+ },
2664
+ 'q_head-16_kv_head-8_head-128': {
2665
+ 256: (1, 32),
2666
+ 512: (2, 64),
2667
+ 1024: (4, 64),
2668
+ 2048: (8, 64),
2669
+ 4096: (16, 64),
2670
+ 8192: (8, 128),
2671
+ },
2672
+ 'q_head-32_kv_head-4_head-128': {
2673
+ 256: (1, 16),
2674
+ 512: (2, 16),
2675
+ 1024: (4, 32),
2676
+ 2048: (8, 32),
2677
+ 4096: (8, 128),
2678
+ 8192: (16, 16),
2679
+ },
2680
+ },
2681
+ 'q_bfloat16_kv_bfloat16': {
2682
+ 'q_head-8_kv_head-4_head-256': {
2683
+ 2048: (8, 64),
2684
+ 4096: (16, 32),
2685
+ 8192: (16, 32),
2686
+ 256: (1, 8),
2687
+ 512: (2, 64),
2688
+ 1024: (4, 16),
2689
+ },
2690
+ 'q_head-16_kv_head-4_head-128': {
2691
+ 256: (1, 32),
2692
+ 512: (2, 16),
2693
+ 1024: (4, 64),
2694
+ 2048: (8, 16),
2695
+ 4096: (16, 64),
2696
+ 8192: (16, 32),
2697
+ },
2698
+ 'q_head-32_kv_head-16_head-256': {
2699
+ 4096: (2, 16),
2700
+ 8192: (4, 16),
2701
+ 256: (1, 16),
2702
+ 512: (2, 16),
2703
+ 1024: (2, 16),
2704
+ 2048: (2, 16),
2705
+ },
2706
+ 'q_head-32_kv_head-2_head-256': {
2707
+ 1024: (4, 16),
2708
+ 2048: (8, 8),
2709
+ 4096: (16, 32),
2710
+ 8192: (16, 16),
2711
+ 256: (1, 8),
2712
+ 512: (2, 16),
2713
+ },
2714
+ 'q_head-64_kv_head-2_head-128': {
2715
+ 4096: (16, 16),
2716
+ 8192: (16, 16),
2717
+ 256: (1, 16),
2718
+ 512: (2, 8),
2719
+ 1024: (4, 32),
2720
+ 2048: (8, 8),
2721
+ },
2722
+ 'q_head-64_kv_head-16_head-128': {
2723
+ 256: (1, 16),
2724
+ 512: (2, 16),
2725
+ 1024: (4, 16),
2726
+ 2048: (4, 16),
2727
+ 4096: (4, 16),
2728
+ 8192: (8, 16),
2729
+ },
2730
+ 'q_head-128_kv_head-8_head-256': {
2731
+ 1024: (2, 8),
2732
+ 2048: (8, 8),
2733
+ 4096: (8, 8),
2734
+ 8192: (8, 8),
2735
+ 256: (1, 8),
2736
+ 512: (2, 8),
2737
+ },
2738
+ 'q_head-4_kv_head-2_head-128': {
2739
+ 2048: (8, 16),
2740
+ 4096: (16, 8),
2741
+ 8192: (8, 64),
2742
+ 256: (1, 64),
2743
+ 512: (2, 256),
2744
+ 1024: (4, 64),
2745
+ },
2746
+ 'q_head-4_kv_head-1_head-256': {
2747
+ 8192: (16, 16),
2748
+ 256: (1, 128),
2749
+ 512: (2, 64),
2750
+ 1024: (4, 8),
2751
+ 2048: (8, 32),
2752
+ 4096: (16, 32),
2753
+ },
2754
+ 'q_head-128_kv_head-2_head-128': {
2755
+ 256: (1, 8),
2756
+ 512: (2, 8),
2757
+ 1024: (4, 32),
2758
+ 2048: (8, 8),
2759
+ 4096: (8, 16),
2760
+ 8192: (8, 16),
2761
+ },
2762
+ 'q_head-64_kv_head-2_head-256': {
2763
+ 256: (1, 8),
2764
+ 512: (2, 16),
2765
+ 1024: (4, 16),
2766
+ 2048: (8, 16),
2767
+ 4096: (16, 16),
2768
+ 8192: (16, 16),
2769
+ },
2770
+ 'q_head-128_kv_head-16_head-128': {
2771
+ 256: (1, 8),
2772
+ 512: (2, 8),
2773
+ 1024: (4, 8),
2774
+ 2048: (4, 8),
2775
+ 4096: (4, 8),
2776
+ 8192: (8, 8),
2777
+ },
2778
+ 'q_head-4_kv_head-2_head-256': {
2779
+ 256: (1, 64),
2780
+ 512: (2, 32),
2781
+ 1024: (4, 16),
2782
+ 2048: (8, 32),
2783
+ 4096: (16, 32),
2784
+ 8192: (16, 128),
2785
+ },
2786
+ 'q_head-32_kv_head-4_head-128': {
2787
+ 256: (1, 32),
2788
+ 512: (2, 8),
2789
+ 1024: (4, 64),
2790
+ 2048: (8, 64),
2791
+ 4096: (16, 32),
2792
+ 8192: (16, 32),
2793
+ },
2794
+ 'q_head-8_kv_head-1_head-128': {
2795
+ 256: (1, 64),
2796
+ 512: (2, 16),
2797
+ 1024: (4, 32),
2798
+ 2048: (8, 32),
2799
+ 4096: (8, 32),
2800
+ 8192: (16, 32),
2801
+ },
2802
+ 'q_head-64_kv_head-16_head-256': {
2803
+ 256: (1, 8),
2804
+ 512: (2, 8),
2805
+ 1024: (2, 8),
2806
+ 2048: (2, 8),
2807
+ 4096: (4, 8),
2808
+ 8192: (4, 8),
2809
+ },
2810
+ 'q_head-16_kv_head-4_head-256': {
2811
+ 256: (1, 8),
2812
+ 512: (2, 8),
2813
+ 1024: (4, 32),
2814
+ 2048: (8, 32),
2815
+ 4096: (16, 32),
2816
+ 8192: (16, 32),
2817
+ },
2818
+ 'q_head-16_kv_head-2_head-256': {
2819
+ 256: (1, 8),
2820
+ 512: (2, 16),
2821
+ 1024: (4, 8),
2822
+ 2048: (4, 32),
2823
+ 4096: (16, 32),
2824
+ 8192: (16, 32),
2825
+ },
2826
+ 'q_head-32_kv_head-16_head-128': {
2827
+ 4096: (4, 32),
2828
+ 8192: (8, 32),
2829
+ 256: (1, 32),
2830
+ 512: (2, 16),
2831
+ 1024: (4, 32),
2832
+ 2048: (4, 32),
2833
+ },
2834
+ 'q_head-32_kv_head-2_head-128': {
2835
+ 1024: (4, 16),
2836
+ 2048: (8, 128),
2837
+ 256: (1, 16),
2838
+ 4096: (16, 32),
2839
+ 512: (2, 8),
2840
+ 8192: (16, 32),
2841
+ },
2842
+ 'q_head-64_kv_head-1_head-256': {
2843
+ 4096: (16, 8),
2844
+ 8192: (16, 16),
2845
+ 256: (1, 8),
2846
+ 512: (2, 8),
2847
+ 1024: (4, 16),
2848
+ 2048: (8, 8),
2849
+ },
2850
+ 'q_head-64_kv_head-8_head-256': {
2851
+ 256: (1, 8),
2852
+ 512: (2, 8),
2853
+ 1024: (4, 16),
2854
+ 2048: (8, 8),
2855
+ 4096: (8, 16),
2856
+ 8192: (8, 16),
2857
+ },
2858
+ 'q_head-128_kv_head-8_head-128': {
2859
+ 2048: (8, 16),
2860
+ 4096: (8, 16),
2861
+ 8192: (8, 16),
2862
+ 256: (1, 8),
2863
+ 512: (2, 16),
2864
+ 1024: (4, 16),
2865
+ },
2866
+ 'q_head-2_kv_head-1_head-256': {
2867
+ 2048: (8, 64),
2868
+ 4096: (8, 64),
2869
+ 8192: (16, 16),
2870
+ 256: (1, 64),
2871
+ 512: (2, 128),
2872
+ 1024: (4, 256),
2873
+ },
2874
+ 'q_head-4_kv_head-1_head-128': {
2875
+ 8192: (16, 32),
2876
+ 256: (1, 128),
2877
+ 512: (2, 32),
2878
+ 1024: (4, 128),
2879
+ 2048: (8, 64),
2880
+ 4096: (16, 32),
2881
+ },
2882
+ 'q_head-64_kv_head-32_head-128': {
2883
+ 256: (1, 8),
2884
+ 512: (2, 8),
2885
+ 1024: (2, 8),
2886
+ 2048: (2, 8),
2887
+ 4096: (2, 8),
2888
+ 8192: (2, 8),
2889
+ },
2890
+ 'q_head-128_kv_head-2_head-256': {
2891
+ 256: (1, 8),
2892
+ 512: (2, 8),
2893
+ 1024: (4, 16),
2894
+ 2048: (8, 16),
2895
+ 4096: (8, 8),
2896
+ 8192: (8, 16),
2897
+ },
2898
+ 'q_head-16_kv_head-8_head-128': {
2899
+ 256: (1, 64),
2900
+ 512: (2, 32),
2901
+ 1024: (4, 64),
2902
+ 2048: (8, 128),
2903
+ 4096: (16, 32),
2904
+ 8192: (16, 64),
2905
+ },
2906
+ 'q_head-64_kv_head-4_head-128': {
2907
+ 256: (1, 16),
2908
+ 512: (2, 8),
2909
+ 1024: (4, 16),
2910
+ 2048: (8, 16),
2911
+ 4096: (16, 16),
2912
+ 8192: (16, 16),
2913
+ },
2914
+ 'q_head-16_kv_head-1_head-128': {
2915
+ 256: (1, 32),
2916
+ 512: (2, 32),
2917
+ 1024: (4, 32),
2918
+ 2048: (8, 16),
2919
+ 4096: (16, 32),
2920
+ 8192: (16, 32),
2921
+ },
2922
+ 'q_head-32_kv_head-4_head-256': {
2923
+ 256: (1, 8),
2924
+ 512: (2, 32),
2925
+ 1024: (4, 16),
2926
+ 2048: (8, 32),
2927
+ 4096: (16, 16),
2928
+ 8192: (16, 32),
2929
+ },
2930
+ 'q_head-8_kv_head-1_head-256': {
2931
+ 256: (1, 32),
2932
+ 512: (2, 64),
2933
+ 1024: (4, 16),
2934
+ 2048: (8, 8),
2935
+ 4096: (16, 64),
2936
+ 8192: (16, 16),
2937
+ },
2938
+ 'q_head-128_kv_head-4_head-128': {
2939
+ 256: (1, 8),
2940
+ 512: (2, 16),
2941
+ 1024: (4, 16),
2942
+ 2048: (8, 16),
2943
+ 4096: (8, 16),
2944
+ 8192: (8, 8),
2945
+ },
2946
+ 'q_head-16_kv_head-8_head-256': {
2947
+ 256: (1, 16),
2948
+ 512: (2, 64),
2949
+ 1024: (4, 64),
2950
+ 2048: (8, 64),
2951
+ 4096: (8, 32),
2952
+ 8192: (8, 64),
2953
+ },
2954
+ 'q_head-8_kv_head-4_head-128': {
2955
+ 256: (1, 128),
2956
+ 512: (2, 8),
2957
+ 1024: (4, 8),
2958
+ 2048: (8, 128),
2959
+ 4096: (16, 32),
2960
+ 8192: (16, 32),
2961
+ },
2962
+ 'q_head-32_kv_head-1_head-128': {
2963
+ 256: (1, 32),
2964
+ 512: (2, 32),
2965
+ 1024: (4, 16),
2966
+ 2048: (8, 8),
2967
+ 4096: (16, 32),
2968
+ 8192: (8, 16),
2969
+ },
2970
+ 'q_head-64_kv_head-4_head-256': {
2971
+ 256: (1, 8),
2972
+ 512: (2, 8),
2973
+ 1024: (4, 16),
2974
+ 2048: (8, 16),
2975
+ 4096: (8, 32),
2976
+ 8192: (8, 32),
2977
+ },
2978
+ 'q_head-2_kv_head-1_head-128': {
2979
+ 256: (1, 32),
2980
+ 512: (2, 16),
2981
+ 1024: (4, 128),
2982
+ 2048: (8, 64),
2983
+ 4096: (16, 64),
2984
+ 8192: (16, 64),
2985
+ },
2986
+ 'q_head-16_kv_head-1_head-256': {
2987
+ 256: (1, 8),
2988
+ 512: (2, 32),
2989
+ 1024: (4, 8),
2990
+ 2048: (4, 8),
2991
+ 4096: (16, 8),
2992
+ 8192: (16, 32),
2993
+ },
2994
+ 'q_head-32_kv_head-8_head-128': {
2995
+ 256: (1, 32),
2996
+ 512: (2, 16),
2997
+ 1024: (4, 64),
2998
+ 2048: (8, 64),
2999
+ 4096: (16, 32),
3000
+ 8192: (16, 32),
3001
+ },
3002
+ 'q_head-8_kv_head-2_head-128': {
3003
+ 256: (1, 16),
3004
+ 512: (2, 16),
3005
+ 1024: (4, 32),
3006
+ 2048: (8, 32),
3007
+ 4096: (16, 64),
3008
+ 8192: (16, 32),
3009
+ },
3010
+ 'q_head-64_kv_head-1_head-128': {
3011
+ 256: (1, 8),
3012
+ 512: (2, 16),
3013
+ 1024: (4, 8),
3014
+ 2048: (8, 16),
3015
+ 4096: (8, 8),
3016
+ 8192: (16, 8),
3017
+ },
3018
+ 'q_head-128_kv_head-4_head-256': {
3019
+ 256: (1, 8),
3020
+ 512: (2, 16),
3021
+ 1024: (4, 8),
3022
+ 2048: (8, 16),
3023
+ 4096: (8, 8),
3024
+ 8192: (8, 16),
3025
+ },
3026
+ 'q_head-32_kv_head-1_head-256': {
3027
+ 256: (1, 8),
3028
+ 512: (2, 16),
3029
+ 1024: (4, 8),
3030
+ 2048: (8, 16),
3031
+ 4096: (16, 16),
3032
+ 8192: (8, 16),
3033
+ },
3034
+ 'q_head-64_kv_head-8_head-128': {
3035
+ 256: (1, 16),
3036
+ 512: (2, 16),
3037
+ 1024: (4, 32),
3038
+ 2048: (8, 32),
3039
+ 4096: (8, 32),
3040
+ 8192: (16, 16),
3041
+ },
3042
+ 'q_head-16_kv_head-2_head-128': {
3043
+ 256: (1, 16),
3044
+ 512: (2, 16),
3045
+ 1024: (4, 8),
3046
+ 2048: (8, 32),
3047
+ 4096: (16, 32),
3048
+ 8192: (16, 16),
3049
+ },
3050
+ 'q_head-32_kv_head-8_head-256': {
3051
+ 256: (1, 16),
3052
+ 512: (2, 32),
3053
+ 1024: (4, 32),
3054
+ 2048: (8, 32),
3055
+ 4096: (8, 32),
3056
+ 8192: (8, 32),
3057
+ },
3058
+ 'q_head-8_kv_head-2_head-256': {
3059
+ 256: (1, 8),
3060
+ 512: (2, 64),
3061
+ 1024: (4, 64),
3062
+ 2048: (8, 32),
3063
+ 4096: (16, 64),
3064
+ 8192: (16, 128),
3065
+ },
3066
+ 'q_head-128_kv_head-1_head-128': {
3067
+ 256: (1, 16),
3068
+ 512: (2, 8),
3069
+ 1024: (4, 8),
3070
+ 2048: (4, 8),
3071
+ 4096: (16, 8),
3072
+ 8192: (16, 8),
3073
+ },
3074
+ 'q_head-128_kv_head-1_head-256': {
3075
+ 256: (1, 8),
3076
+ 512: (2, 8),
3077
+ 1024: (2, 8),
3078
+ 2048: (4, 8),
3079
+ 4096: (8, 8),
3080
+ 8192: (16, 8),
3081
+ },
3082
+ },
3083
+ },
3084
+ 128: {
3085
+ 'q_bfloat16_kv_float8_e4m3fn': {
3086
+ 'q_head-16_kv_head-8_head-128': {
3087
+ 8192: (16, 64),
3088
+ 128: (1, 32),
3089
+ 256: (2, 64),
3090
+ 512: (4, 128),
3091
+ 1024: (8, 128),
3092
+ 2048: (16, 128),
3093
+ 4096: (32, 32),
3094
+ },
3095
+ 'q_head-16_kv_head-4_head-128': {
3096
+ 8192: (32, 64),
3097
+ 128: (1, 8),
3098
+ 256: (2, 16),
3099
+ 512: (4, 32),
3100
+ 1024: (8, 16),
3101
+ 2048: (16, 128),
3102
+ 4096: (32, 64),
3103
+ },
3104
+ 'q_head-32_kv_head-2_head-128': {
3105
+ 8192: (16, 64),
3106
+ 128: (1, 8),
3107
+ 256: (2, 32),
3108
+ 512: (4, 16),
3109
+ 1024: (8, 16),
3110
+ 2048: (16, 64),
3111
+ 4096: (32, 32),
3112
+ },
3113
+ 'q_head-64_kv_head-2_head-128': {
3114
+ 8192: (32, 32),
3115
+ 128: (1, 8),
3116
+ 256: (2, 8),
3117
+ 512: (4, 32),
3118
+ 1024: (8, 16),
3119
+ 2048: (16, 32),
3120
+ 4096: (16, 32),
3121
+ },
3122
+ 'q_head-64_kv_head-32_head-128': {
3123
+ 8192: (8, 8),
3124
+ 128: (1, 8),
3125
+ 256: (2, 8),
3126
+ 512: (4, 8),
3127
+ 1024: (8, 8),
3128
+ 2048: (8, 8),
3129
+ 4096: (16, 8),
3130
+ },
3131
+ 'q_head-128_kv_head-16_head-128': {
3132
+ 8192: (16, 8),
3133
+ 128: (1, 8),
3134
+ 256: (2, 8),
3135
+ 512: (4, 8),
3136
+ 1024: (8, 8),
3137
+ 2048: (16, 8),
3138
+ 4096: (16, 8),
3139
+ },
3140
+ 'q_head-32_kv_head-4_head-128': {
3141
+ 128: (1, 32),
3142
+ 256: (2, 32),
3143
+ 512: (4, 32),
3144
+ 1024: (8, 64),
3145
+ 2048: (16, 64),
3146
+ 4096: (16, 64),
3147
+ 8192: (32, 32),
3148
+ },
3149
+ 'q_head-128_kv_head-2_head-128': {
3150
+ 128: (1, 16),
3151
+ 256: (2, 8),
3152
+ 512: (4, 16),
3153
+ 1024: (8, 8),
3154
+ 2048: (8, 16),
3155
+ 4096: (16, 16),
3156
+ 8192: (16, 16),
3157
+ },
3158
+ 'q_head-4_kv_head-2_head-128': {
3159
+ 128: (1, 64),
3160
+ 256: (2, 128),
3161
+ 512: (4, 128),
3162
+ 1024: (8, 32),
3163
+ 2048: (16, 64),
3164
+ 4096: (32, 64),
3165
+ 8192: (32, 32),
3166
+ },
3167
+ 'q_head-2_kv_head-2_head-128': {
3168
+ 256: (2, 128),
3169
+ 8192: (32, 64),
3170
+ 512: (4, 256),
3171
+ 128: (1, 16),
3172
+ 1024: (8, 64),
3173
+ 2048: (16, 256),
3174
+ 4096: (32, 64),
3175
+ },
3176
+ 'q_head-16_kv_head-2_head-128': {
3177
+ 8192: (16, 256),
3178
+ 128: (1, 16),
3179
+ 256: (2, 16),
3180
+ 512: (4, 32),
3181
+ 1024: (8, 16),
3182
+ 2048: (16, 256),
3183
+ 4096: (32, 64),
3184
+ },
3185
+ 'q_head-128_kv_head-8_head-128': {
3186
+ 8192: (8, 16),
3187
+ 128: (1, 16),
3188
+ 256: (2, 8),
3189
+ 512: (4, 16),
3190
+ 1024: (8, 16),
3191
+ 2048: (16, 16),
3192
+ 4096: (16, 16),
3193
+ },
3194
+ 'q_head-32_kv_head-8_head-128': {
3195
+ 128: (1, 64),
3196
+ 256: (2, 64),
3197
+ 512: (4, 16),
3198
+ 1024: (8, 32),
3199
+ 2048: (16, 64),
3200
+ 4096: (32, 32),
3201
+ 8192: (16, 64),
3202
+ },
3203
+ 'q_head-8_kv_head-2_head-128': {
3204
+ 128: (1, 32),
3205
+ 256: (2, 128),
3206
+ 512: (4, 256),
3207
+ 1024: (4, 64),
3208
+ 2048: (16, 64),
3209
+ 4096: (16, 32),
3210
+ 8192: (32, 64),
3211
+ },
3212
+ 'q_head-32_kv_head-16_head-128': {
3213
+ 128: (1, 32),
3214
+ 256: (2, 32),
3215
+ 512: (4, 32),
3216
+ 1024: (8, 32),
3217
+ 2048: (16, 32),
3218
+ 4096: (16, 32),
3219
+ 8192: (16, 32),
3220
+ },
3221
+ 'q_head-64_kv_head-4_head-128': {
3222
+ 128: (1, 32),
3223
+ 256: (2, 16),
3224
+ 512: (4, 16),
3225
+ 1024: (8, 32),
3226
+ 2048: (16, 32),
3227
+ 4096: (32, 16),
3228
+ 8192: (16, 32),
3229
+ },
3230
+ 'q_head-128_kv_head-4_head-128': {
3231
+ 128: (1, 8),
3232
+ 256: (2, 16),
3233
+ 512: (4, 16),
3234
+ 1024: (8, 16),
3235
+ 2048: (16, 16),
3236
+ 4096: (16, 32),
3237
+ 8192: (16, 16),
3238
+ },
3239
+ 'q_head-8_kv_head-4_head-128': {
3240
+ 128: (1, 64),
3241
+ 256: (2, 128),
3242
+ 512: (4, 128),
3243
+ 1024: (8, 64),
3244
+ 2048: (16, 32),
3245
+ 4096: (32, 64),
3246
+ 8192: (32, 128),
3247
+ },
3248
+ 'q_head-64_kv_head-8_head-128': {
3249
+ 128: (1, 16),
3250
+ 256: (2, 32),
3251
+ 512: (4, 32),
3252
+ 1024: (8, 16),
3253
+ 2048: (16, 32),
3254
+ 4096: (16, 32),
3255
+ 8192: (16, 32),
3256
+ },
3257
+ 'q_head-64_kv_head-16_head-128': {
3258
+ 128: (1, 16),
3259
+ 256: (2, 16),
3260
+ 512: (4, 16),
3261
+ 1024: (8, 8),
3262
+ 2048: (16, 16),
3263
+ 4096: (16, 16),
3264
+ 8192: (16, 16),
3265
+ },
3266
+ },
3267
+ 'q_bfloat16_kv_bfloat16': {
3268
+ 'q_head-4_kv_head-2_head-128': {
3269
+ 128: (1, 32),
3270
+ 256: (2, 16),
3271
+ 512: (4, 16),
3272
+ 8192: (32, 32),
3273
+ 1024: (8, 128),
3274
+ 2048: (16, 128),
3275
+ 4096: (32, 16),
3276
+ },
3277
+ 'q_head-2_kv_head-1_head-128': {
3278
+ 512: (4, 128),
3279
+ 2048: (8, 64),
3280
+ 256: (2, 8),
3281
+ 1024: (8, 32),
3282
+ 4096: (32, 16),
3283
+ 128: (1, 16),
3284
+ 8192: (32, 32),
3285
+ },
3286
+ 'q_head-16_kv_head-8_head-128': {
3287
+ 256: (2, 32),
3288
+ 512: (4, 16),
3289
+ 1024: (8, 64),
3290
+ 2048: (16, 64),
3291
+ 4096: (32, 64),
3292
+ 8192: (32, 32),
3293
+ 128: (1, 64),
3294
+ },
3295
+ 'q_head-32_kv_head-4_head-256': {
3296
+ 1024: (8, 32),
3297
+ 2048: (16, 16),
3298
+ 4096: (32, 16),
3299
+ 8192: (32, 32),
3300
+ 128: (1, 16),
3301
+ 256: (2, 16),
3302
+ 512: (4, 8),
3303
+ },
3304
+ 'q_head-64_kv_head-4_head-128': {
3305
+ 4096: (32, 16),
3306
+ 8192: (16, 32),
3307
+ 128: (1, 32),
3308
+ 256: (2, 16),
3309
+ 512: (4, 8),
3310
+ 1024: (8, 8),
3311
+ 2048: (16, 32),
3312
+ },
3313
+ 'q_head-16_kv_head-8_head-256': {
3314
+ 128: (1, 32),
3315
+ 256: (2, 16),
3316
+ 512: (4, 32),
3317
+ 1024: (8, 64),
3318
+ 2048: (16, 64),
3319
+ 4096: (16, 32),
3320
+ 8192: (16, 32),
3321
+ },
3322
+ 'q_head-16_kv_head-1_head-128': {
3323
+ 4096: (32, 32),
3324
+ 8192: (32, 16),
3325
+ 128: (1, 32),
3326
+ 256: (2, 32),
3327
+ 512: (4, 64),
3328
+ 1024: (8, 8),
3329
+ 2048: (8, 32),
3330
+ },
3331
+ 'q_head-64_kv_head-32_head-128': {
3332
+ 1024: (4, 8),
3333
+ 2048: (4, 8),
3334
+ 4096: (4, 8),
3335
+ 8192: (8, 8),
3336
+ 128: (1, 8),
3337
+ 256: (2, 8),
3338
+ 512: (4, 8),
3339
+ },
3340
+ 'q_head-128_kv_head-4_head-128': {
3341
+ 128: (1, 8),
3342
+ 256: (2, 8),
3343
+ 512: (4, 8),
3344
+ 1024: (8, 8),
3345
+ 2048: (16, 8),
3346
+ 4096: (16, 16),
3347
+ 8192: (16, 16),
3348
+ },
3349
+ 'q_head-8_kv_head-1_head-256': {
3350
+ 1024: (8, 32),
3351
+ 2048: (8, 32),
3352
+ 4096: (16, 32),
3353
+ 8192: (32, 16),
3354
+ 128: (1, 64),
3355
+ 256: (2, 32),
3356
+ 512: (4, 32),
3357
+ },
3358
+ 'q_head-32_kv_head-1_head-128': {
3359
+ 128: (1, 32),
3360
+ 256: (2, 32),
3361
+ 512: (4, 32),
3362
+ 1024: (8, 16),
3363
+ 2048: (8, 16),
3364
+ 4096: (16, 16),
3365
+ 8192: (32, 8),
3366
+ },
3367
+ 'q_head-64_kv_head-4_head-256': {
3368
+ 128: (1, 8),
3369
+ 256: (2, 16),
3370
+ 512: (4, 8),
3371
+ 1024: (8, 8),
3372
+ 2048: (16, 32),
3373
+ 4096: (16, 16),
3374
+ 8192: (32, 16),
3375
+ },
3376
+ 'q_head-2_kv_head-1_head-256': {
3377
+ 512: (4, 8),
3378
+ 4096: (32, 128),
3379
+ 256: (2, 64),
3380
+ 1024: (8, 8),
3381
+ 8192: (32, 64),
3382
+ 128: (1, 128),
3383
+ 2048: (16, 32),
3384
+ },
3385
+ 'q_head-16_kv_head-1_head-256': {
3386
+ 128: (1, 32),
3387
+ 256: (2, 32),
3388
+ 512: (4, 16),
3389
+ 1024: (8, 8),
3390
+ 2048: (16, 8),
3391
+ 4096: (16, 16),
3392
+ 8192: (32, 16),
3393
+ },
3394
+ 'q_head-32_kv_head-8_head-128': {
3395
+ 128: (1, 16),
3396
+ 256: (2, 16),
3397
+ 512: (4, 32),
3398
+ 1024: (8, 16),
3399
+ 2048: (16, 64),
3400
+ 4096: (32, 32),
3401
+ 8192: (32, 32),
3402
+ },
3403
+ 'q_head-8_kv_head-2_head-128': {
3404
+ 128: (1, 64),
3405
+ 256: (2, 32),
3406
+ 512: (4, 64),
3407
+ 1024: (8, 16),
3408
+ 2048: (16, 16),
3409
+ 4096: (16, 32),
3410
+ 8192: (32, 16),
3411
+ },
3412
+ 'q_head-64_kv_head-1_head-128': {
3413
+ 128: (1, 16),
3414
+ 256: (2, 8),
3415
+ 512: (4, 8),
3416
+ 1024: (8, 8),
3417
+ 2048: (8, 8),
3418
+ 4096: (16, 8),
3419
+ 8192: (16, 8),
3420
+ },
3421
+ 'q_head-4_kv_head-2_head-256': {
3422
+ 128: (1, 64),
3423
+ 1024: (8, 16),
3424
+ 256: (2, 16),
3425
+ 512: (4, 16),
3426
+ 2048: (16, 32),
3427
+ 4096: (16, 16),
3428
+ 8192: (32, 32),
3429
+ },
3430
+ 'q_head-8_kv_head-1_head-128': {
3431
+ 512: (4, 32),
3432
+ 1024: (8, 16),
3433
+ 2048: (16, 64),
3434
+ 128: (1, 16),
3435
+ 4096: (16, 16),
3436
+ 256: (2, 128),
3437
+ 8192: (32, 32),
3438
+ },
3439
+ 'q_head-16_kv_head-4_head-256': {
3440
+ 256: (2, 16),
3441
+ 512: (4, 8),
3442
+ 1024: (8, 16),
3443
+ 2048: (16, 32),
3444
+ 128: (1, 64),
3445
+ 4096: (32, 64),
3446
+ 8192: (32, 64),
3447
+ },
3448
+ 'q_head-32_kv_head-4_head-128': {
3449
+ 1024: (8, 8),
3450
+ 2048: (16, 16),
3451
+ 4096: (32, 16),
3452
+ 128: (1, 32),
3453
+ 256: (2, 8),
3454
+ 8192: (32, 32),
3455
+ 512: (4, 8),
3456
+ },
3457
+ 'q_head-64_kv_head-2_head-256': {
3458
+ 4096: (32, 8),
3459
+ 8192: (32, 16),
3460
+ 128: (1, 8),
3461
+ 256: (2, 8),
3462
+ 512: (4, 32),
3463
+ 1024: (8, 16),
3464
+ 2048: (16, 16),
3465
+ },
3466
+ 'q_head-64_kv_head-16_head-256': {
3467
+ 512: (4, 8),
3468
+ 1024: (4, 8),
3469
+ 2048: (4, 8),
3470
+ 128: (1, 8),
3471
+ 4096: (8, 8),
3472
+ 256: (2, 8),
3473
+ 8192: (8, 8),
3474
+ },
3475
+ 'q_head-128_kv_head-2_head-256': {
3476
+ 128: (1, 8),
3477
+ 256: (2, 16),
3478
+ 512: (4, 16),
3479
+ 1024: (8, 8),
3480
+ 2048: (16, 8),
3481
+ 4096: (32, 8),
3482
+ 8192: (16, 16),
3483
+ },
3484
+ 'q_head-128_kv_head-16_head-128': {
3485
+ 2048: (8, 8),
3486
+ 4096: (8, 8),
3487
+ 8192: (16, 8),
3488
+ 128: (1, 8),
3489
+ 256: (2, 8),
3490
+ 512: (4, 8),
3491
+ 1024: (8, 8),
3492
+ },
3493
+ 'q_head-32_kv_head-1_head-256': {
3494
+ 128: (1, 8),
3495
+ 256: (2, 8),
3496
+ 512: (4, 16),
3497
+ 1024: (8, 16),
3498
+ 2048: (16, 8),
3499
+ 4096: (16, 16),
3500
+ 8192: (32, 32),
3501
+ },
3502
+ 'q_head-128_kv_head-4_head-256': {
3503
+ 128: (1, 16),
3504
+ 256: (2, 8),
3505
+ 512: (4, 8),
3506
+ 1024: (8, 8),
3507
+ 2048: (16, 16),
3508
+ 4096: (16, 16),
3509
+ 8192: (16, 16),
3510
+ },
3511
+ 'q_head-64_kv_head-8_head-128': {
3512
+ 128: (1, 32),
3513
+ 256: (2, 16),
3514
+ 512: (4, 8),
3515
+ 1024: (8, 16),
3516
+ 2048: (16, 32),
3517
+ 4096: (32, 16),
3518
+ 8192: (16, 32),
3519
+ },
3520
+ 'q_head-16_kv_head-2_head-128': {
3521
+ 128: (1, 16),
3522
+ 256: (2, 32),
3523
+ 512: (4, 16),
3524
+ 1024: (8, 16),
3525
+ 2048: (16, 16),
3526
+ 4096: (32, 32),
3527
+ 8192: (32, 32),
3528
+ },
3529
+ 'q_head-32_kv_head-8_head-256': {
3530
+ 128: (1, 16),
3531
+ 256: (2, 16),
3532
+ 512: (4, 8),
3533
+ 1024: (8, 32),
3534
+ 2048: (16, 16),
3535
+ 4096: (16, 32),
3536
+ 8192: (16, 32),
3537
+ },
3538
+ 'q_head-128_kv_head-1_head-128': {
3539
+ 128: (1, 32),
3540
+ 256: (2, 8),
3541
+ 512: (4, 8),
3542
+ 1024: (8, 8),
3543
+ 2048: (8, 8),
3544
+ 4096: (16, 8),
3545
+ 8192: (16, 16),
3546
+ },
3547
+ 'q_head-8_kv_head-2_head-256': {
3548
+ 128: (1, 64),
3549
+ 256: (2, 16),
3550
+ 512: (4, 64),
3551
+ 1024: (8, 16),
3552
+ 2048: (8, 32),
3553
+ 4096: (16, 16),
3554
+ 8192: (32, 64),
3555
+ },
3556
+ 'q_head-32_kv_head-16_head-128': {
3557
+ 128: (1, 32),
3558
+ 256: (2, 32),
3559
+ 512: (4, 16),
3560
+ 1024: (8, 16),
3561
+ 2048: (8, 32),
3562
+ 4096: (8, 32),
3563
+ 8192: (8, 32),
3564
+ },
3565
+ 'q_head-64_kv_head-1_head-256': {
3566
+ 128: (1, 8),
3567
+ 256: (2, 8),
3568
+ 512: (4, 8),
3569
+ 1024: (8, 8),
3570
+ 2048: (8, 8),
3571
+ 4096: (16, 8),
3572
+ 8192: (32, 16),
3573
+ },
3574
+ 'q_head-8_kv_head-4_head-128': {
3575
+ 128: (1, 64),
3576
+ 256: (2, 128),
3577
+ 512: (4, 32),
3578
+ 1024: (8, 16),
3579
+ 2048: (16, 64),
3580
+ 4096: (32, 128),
3581
+ 8192: (32, 32),
3582
+ },
3583
+ 'q_head-32_kv_head-2_head-128': {
3584
+ 128: (1, 8),
3585
+ 256: (2, 16),
3586
+ 512: (4, 32),
3587
+ 1024: (8, 8),
3588
+ 2048: (16, 8),
3589
+ 4096: (32, 32),
3590
+ 8192: (32, 16),
3591
+ },
3592
+ 'q_head-128_kv_head-8_head-128': {
3593
+ 128: (1, 8),
3594
+ 256: (2, 8),
3595
+ 512: (4, 8),
3596
+ 1024: (8, 16),
3597
+ 2048: (16, 16),
3598
+ 4096: (16, 16),
3599
+ 8192: (16, 16),
3600
+ },
3601
+ 'q_head-4_kv_head-1_head-256': {
3602
+ 128: (1, 64),
3603
+ 256: (2, 32),
3604
+ 512: (4, 16),
3605
+ 1024: (8, 64),
3606
+ 2048: (16, 16),
3607
+ 4096: (32, 32),
3608
+ 8192: (32, 32),
3609
+ },
3610
+ 'q_head-64_kv_head-8_head-256': {
3611
+ 128: (1, 8),
3612
+ 256: (2, 8),
3613
+ 512: (4, 8),
3614
+ 1024: (8, 8),
3615
+ 2048: (16, 8),
3616
+ 4096: (16, 16),
3617
+ 8192: (16, 16),
3618
+ },
3619
+ 'q_head-16_kv_head-2_head-256': {
3620
+ 128: (1, 8),
3621
+ 256: (2, 32),
3622
+ 512: (4, 32),
3623
+ 1024: (8, 16),
3624
+ 2048: (16, 16),
3625
+ 4096: (32, 16),
3626
+ 8192: (32, 16),
3627
+ },
3628
+ 'q_head-128_kv_head-1_head-256': {
3629
+ 128: (1, 8),
3630
+ 256: (2, 8),
3631
+ 512: (4, 8),
3632
+ 1024: (8, 8),
3633
+ 2048: (8, 8),
3634
+ 4096: (32, 8),
3635
+ 8192: (32, 8),
3636
+ },
3637
+ 'q_head-32_kv_head-16_head-256': {
3638
+ 128: (1, 16),
3639
+ 256: (2, 16),
3640
+ 512: (4, 16),
3641
+ 1024: (4, 16),
3642
+ 2048: (4, 16),
3643
+ 4096: (4, 16),
3644
+ 8192: (8, 16),
3645
+ },
3646
+ 'q_head-64_kv_head-2_head-128': {
3647
+ 128: (1, 16),
3648
+ 256: (2, 8),
3649
+ 512: (4, 8),
3650
+ 1024: (8, 16),
3651
+ 2048: (16, 32),
3652
+ 4096: (16, 32),
3653
+ 8192: (32, 16),
3654
+ },
3655
+ 'q_head-8_kv_head-4_head-256': {
3656
+ 128: (1, 32),
3657
+ 256: (2, 32),
3658
+ 512: (4, 16),
3659
+ 1024: (8, 16),
3660
+ 2048: (8, 32),
3661
+ 4096: (32, 32),
3662
+ 8192: (32, 32),
3663
+ },
3664
+ 'q_head-128_kv_head-8_head-256': {
3665
+ 128: (1, 8),
3666
+ 256: (2, 8),
3667
+ 512: (4, 8),
3668
+ 1024: (8, 8),
3669
+ 2048: (8, 8),
3670
+ 4096: (16, 8),
3671
+ 8192: (16, 8),
3672
+ },
3673
+ 'q_head-32_kv_head-2_head-256': {
3674
+ 128: (1, 8),
3675
+ 256: (2, 8),
3676
+ 512: (4, 8),
3677
+ 1024: (8, 16),
3678
+ 2048: (16, 16),
3679
+ 4096: (32, 32),
3680
+ 8192: (32, 32),
3681
+ },
3682
+ 'q_head-64_kv_head-16_head-128': {
3683
+ 128: (1, 16),
3684
+ 256: (2, 16),
3685
+ 512: (4, 8),
3686
+ 1024: (4, 16),
3687
+ 2048: (8, 16),
3688
+ 4096: (8, 16),
3689
+ 8192: (8, 16),
3690
+ },
3691
+ 'q_head-16_kv_head-4_head-128': {
3692
+ 128: (1, 32),
3693
+ 256: (2, 32),
3694
+ 512: (4, 128),
3695
+ 1024: (8, 32),
3696
+ 2048: (8, 16),
3697
+ 4096: (16, 32),
3698
+ 8192: (32, 64),
3699
+ },
3700
+ 'q_head-4_kv_head-1_head-128': {
3701
+ 128: (1, 32),
3702
+ 256: (2, 8),
3703
+ 512: (4, 256),
3704
+ 1024: (8, 8),
3705
+ 2048: (16, 16),
3706
+ 4096: (32, 32),
3707
+ 8192: (32, 64),
3708
+ },
3709
+ 'q_head-128_kv_head-2_head-128': {
3710
+ 128: (1, 16),
3711
+ 256: (2, 8),
3712
+ 512: (4, 8),
3713
+ 1024: (8, 16),
3714
+ 2048: (16, 8),
3715
+ 4096: (32, 8),
3716
+ 8192: (16, 16),
3717
+ },
3718
+ },
3719
+ },
3720
+ },
3721
+ }
3722
+
3723
+
3724
+ def get_tuned_block_sizes(
3725
+ q_dtype,
3726
+ kv_dtype,
3727
+ actual_num_q_heads,
3728
+ actual_num_kv_heads,
3729
+ head_dim,
3730
+ page_size,
3731
+ max_num_tokens,
3732
+ pages_per_seq,
3733
+ ) -> tuple[int, int]:
3734
+ """Search tuned values for (num_kv_pages_per_blk, num_queries_per_blk)."""
3735
+
3736
+ # Set default block sizes for each tpu_version.
3737
+ tpu_version = get_tpu_version()
3738
+ if tpu_version < 4:
3739
+ raise NotImplementedError('TPU version must be 4 or higher.')
3740
+ match tpu_version:
3741
+ case 4:
3742
+ # TPUv4 has much smaller VMEM size so we pick fixed block sizes.
3743
+ bkv_p, bq = (512 // page_size, 32)
3744
+ case 7:
3745
+ bkv_p, bq = (4096 // page_size, 32)
3746
+ case _:
3747
+ bkv_p, bq = (2048 // page_size, 32)
3748
+
3749
+ keys = get_lookup_keys(
3750
+ page_size,
3751
+ q_dtype,
3752
+ kv_dtype,
3753
+ actual_num_q_heads,
3754
+ actual_num_kv_heads,
3755
+ head_dim,
3756
+ page_size * pages_per_seq,
3757
+ )
3758
+ device, page_size, dtypes, head_dims, max_model_len = keys
3759
+
3760
+ try:
3761
+ bkv_p, bq = TUNED_BLOCK_SIZES[device][page_size][dtypes][head_dims][
3762
+ max_model_len]
3763
+ except KeyError:
3764
+ logger.warning_once(
3765
+ 'Couldn`t find tuned sizes for the RPA v3 kernel with %s', keys)
3766
+
3767
+ return (min(pages_per_seq, bkv_p), min(max_num_tokens, bq))
3768
+
3769
+
3770
def get_lookup_keys(
    page_size,
    q_dtype,
    kv_dtype,
    num_q_heads,
    num_kv_heads,
    head_dim,
    max_model_len,
):
    """Build the key tuple used to look up tuned block sizes.

    The raw arguments are first normalized by ``get_simplified_raw_key``;
    the normalized values are then combined with the current device name
    into the keys consumed by ``get_tuned_block_sizes``.

    Returns:
        A 5-tuple ``(device_name, page_size, dtypes_key, heads_key,
        max_model_len)`` where ``dtypes_key`` has the form
        ``'q_<q dtype>_kv_<kv dtype>'`` and ``heads_key`` has the form
        ``'q_head-<n>_kv_head-<m>_head-<dim>'``.
    """
    simplified = get_simplified_raw_key(
        page_size,
        q_dtype,
        kv_dtype,
        num_q_heads,
        num_kv_heads,
        head_dim,
        max_model_len,
    )
    (
        norm_page_size,
        q_name,
        kv_name,
        norm_q_heads,
        norm_kv_heads,
        norm_head_dim,
        norm_max_len,
    ) = simplified

    # Assemble the keys one by one, in the order they appear in the result.
    device_name = get_device_name()
    page_size_key = next_power_of_2(norm_page_size)
    dtypes_key = f'q_{q_name}_kv_{kv_name}'
    heads_key = (
        f'q_head-{norm_q_heads}_kv_head-{norm_kv_heads}_head-{norm_head_dim}')
    max_len_key = next_power_of_2(norm_max_len)

    return (device_name, page_size_key, dtypes_key, heads_key, max_len_key)
3805
+
3806
+
3807
def get_simplified_raw_key(
    page_size,
    q_dtype,
    kv_dtype,
    actual_num_q_heads,
    actual_num_kv_heads,
    head_dim,
    max_model_len,
):
    """Normalize raw attention parameters into a simplified key tuple.

    Head counts are padded according to the dtype packing reported by
    ``get_dtype_packing`` and then passed through ``next_power_of_2``;
    ``page_size`` and ``max_model_len`` are passed through
    ``next_power_of_2`` and ``head_dim`` through ``align_to(..., 128)``.

    Returns:
        A 7-tuple ``(page_size, q_dtype_name, kv_dtype_name, num_q_heads,
        num_kv_heads, head_dim, max_model_len)`` of normalized values.

    Raises:
        AssertionError: If ``actual_num_q_heads`` is not a multiple of
            ``actual_num_kv_heads``, or if the aligned doubled KV head
            count is odd.
    """
    assert actual_num_q_heads % actual_num_kv_heads == 0
    q_heads_per_kv = actual_num_q_heads // actual_num_kv_heads

    q_packing = get_dtype_packing(q_dtype)
    kv_packing = get_dtype_packing(kv_dtype)

    # The KV head count is doubled before alignment and halved again in the
    # returned key (presumably K and V are packed together - TODO confirm).
    aligned_kv_heads_x2 = align_to(actual_num_kv_heads * 2, kv_packing)
    aligned_q_heads_per_kv = align_to(q_heads_per_kv, q_packing)
    assert aligned_kv_heads_x2 % 2 == 0

    page_size_key = next_power_of_2(page_size)
    q_dtype_name = jnp.dtype(q_dtype).name
    kv_dtype_name = jnp.dtype(kv_dtype).name
    q_heads_key = next_power_of_2(aligned_q_heads_per_kv * actual_num_kv_heads)
    kv_heads_key = next_power_of_2(aligned_kv_heads_x2) // 2
    head_dim_key = align_to(head_dim, 128)
    max_len_key = next_power_of_2(max_model_len)

    return (
        page_size_key,
        q_dtype_name,
        kv_dtype_name,
        q_heads_key,
        kv_heads_key,
        head_dim_key,
        max_len_key,
    )