tpu_inference-0.12.0.dev20251222-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (260)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +67 -0
  6. tests/core/test_dp_scheduler.py +724 -0
  7. tests/core/test_init.py +63 -0
  8. tests/distributed/__init__.py +13 -0
  9. tests/distributed/test_distributed_utils.py +120 -0
  10. tests/distributed/test_tpu_connector.py +478 -0
  11. tests/e2e/__init__.py +13 -0
  12. tests/e2e/test_async_scheduler.py +211 -0
  13. tests/e2e/test_data_parallel.py +393 -0
  14. tests/e2e/test_local_disagg.py +257 -0
  15. tests/e2e/test_model_loader.py +268 -0
  16. tests/e2e/test_multi_modal_inference.py +111 -0
  17. tests/e2e/test_pipeline_parallel.py +265 -0
  18. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  19. tests/e2e/test_sampling_params.py +269 -0
  20. tests/e2e/test_speculative_decoding.py +291 -0
  21. tests/e2e/test_structured_decoding.py +46 -0
  22. tests/executors/__init__.py +13 -0
  23. tests/executors/test_ray_distributed_executor.py +199 -0
  24. tests/experimental/__init__.py +13 -0
  25. tests/experimental/test_llama3_jax_stashed.py +208 -0
  26. tests/kernels/__init__.py +13 -0
  27. tests/kernels/collectives/__init__.py +13 -0
  28. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  29. tests/kernels/fused_moe_v1_test.py +388 -0
  30. tests/kernels/gmm_test.py +205 -0
  31. tests/kernels/mla_v1_test.py +498 -0
  32. tests/kernels/quantized_matmul_kernel_test.py +159 -0
  33. tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
  34. tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
  35. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
  36. tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
  37. tests/layers/__init__.py +13 -0
  38. tests/layers/common/__init__.py +13 -0
  39. tests/layers/common/test_attention_interface.py +156 -0
  40. tests/layers/common/test_quantization.py +149 -0
  41. tests/layers/jax/__init__.py +13 -0
  42. tests/layers/jax/attention/__init__.py +13 -0
  43. tests/layers/jax/attention/test_common_attention.py +103 -0
  44. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  45. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  46. tests/layers/jax/moe/__init__.py +13 -0
  47. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  48. tests/layers/jax/sample/__init__.py +13 -0
  49. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  50. tests/layers/jax/sample/test_sampling.py +115 -0
  51. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  52. tests/layers/jax/test_layers.py +155 -0
  53. tests/layers/jax/test_qwix.py +969 -0
  54. tests/layers/jax/test_rope.py +93 -0
  55. tests/layers/jax/test_sharding.py +159 -0
  56. tests/layers/jax/test_transformer_block.py +152 -0
  57. tests/layers/vllm/__init__.py +13 -0
  58. tests/layers/vllm/test_attention.py +363 -0
  59. tests/layers/vllm/test_awq.py +405 -0
  60. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
  62. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
  63. tests/layers/vllm/test_fp8.py +17 -0
  64. tests/layers/vllm/test_mxfp4.py +297 -0
  65. tests/layers/vllm/test_unquantized.py +621 -0
  66. tests/layers/vllm/utils.py +72 -0
  67. tests/lora/__init__.py +13 -0
  68. tests/lora/conftest.py +46 -0
  69. tests/lora/test_bgmv.py +57 -0
  70. tests/lora/test_layers.py +666 -0
  71. tests/lora/test_lora.py +147 -0
  72. tests/lora/test_lora_perf.py +67 -0
  73. tests/lora/utils.py +88 -0
  74. tests/models/__init__.py +13 -0
  75. tests/models/common/__init__.py +13 -0
  76. tests/models/common/test_model_loader.py +455 -0
  77. tests/models/jax/__init__.py +13 -0
  78. tests/models/jax/test_deepseek_v3.py +401 -0
  79. tests/models/jax/test_llama3.py +184 -0
  80. tests/models/jax/test_llama4.py +298 -0
  81. tests/models/jax/test_llama_eagle3.py +197 -0
  82. tests/models/jax/test_llama_guard_4.py +242 -0
  83. tests/models/jax/test_qwen2.py +172 -0
  84. tests/models/jax/test_qwen2_5_vl.py +606 -0
  85. tests/models/jax/test_qwen3.py +169 -0
  86. tests/models/jax/test_weight_loading.py +180 -0
  87. tests/models/jax/utils/__init__.py +13 -0
  88. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  89. tests/platforms/__init__.py +13 -0
  90. tests/platforms/test_tpu_platform.py +54 -0
  91. tests/runner/__init__.py +13 -0
  92. tests/runner/test_block_table.py +395 -0
  93. tests/runner/test_input_batch.py +226 -0
  94. tests/runner/test_kv_cache.py +220 -0
  95. tests/runner/test_kv_cache_manager.py +498 -0
  96. tests/runner/test_multimodal_manager.py +429 -0
  97. tests/runner/test_persistent_batch_manager.py +84 -0
  98. tests/runner/test_speculative_decoding_manager.py +368 -0
  99. tests/runner/test_structured_decoding_manager.py +220 -0
  100. tests/runner/test_tpu_runner.py +202 -0
  101. tests/runner/test_tpu_runner_dp.py +1033 -0
  102. tests/runner/test_tpu_runner_mesh.py +200 -0
  103. tests/runner/test_utils.py +411 -0
  104. tests/spec_decode/__init__.py +13 -0
  105. tests/spec_decode/test_eagle3.py +311 -0
  106. tests/test_base.py +215 -0
  107. tests/test_envs.py +280 -0
  108. tests/test_tpu_info.py +134 -0
  109. tests/test_utils.py +193 -0
  110. tests/worker/__init__.py +13 -0
  111. tests/worker/tpu_worker_test.py +414 -0
  112. tpu_inference/__init__.py +67 -0
  113. tpu_inference/core/__init__.py +13 -0
  114. tpu_inference/core/core_tpu.py +786 -0
  115. tpu_inference/core/disagg_executor.py +118 -0
  116. tpu_inference/core/disagg_utils.py +49 -0
  117. tpu_inference/core/sched/__init__.py +13 -0
  118. tpu_inference/core/sched/dp_scheduler.py +814 -0
  119. tpu_inference/distributed/__init__.py +13 -0
  120. tpu_inference/distributed/jax_parallel_state.py +81 -0
  121. tpu_inference/distributed/tpu_connector.py +732 -0
  122. tpu_inference/distributed/utils.py +112 -0
  123. tpu_inference/env_override.py +9 -0
  124. tpu_inference/envs.py +191 -0
  125. tpu_inference/executors/__init__.py +13 -0
  126. tpu_inference/executors/ray_distributed_executor.py +399 -0
  127. tpu_inference/experimental/__init__.py +13 -0
  128. tpu_inference/experimental/llama3_jax_stashed.py +272 -0
  129. tpu_inference/kernels/__init__.py +13 -0
  130. tpu_inference/kernels/collectives/__init__.py +13 -0
  131. tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
  132. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
  133. tpu_inference/kernels/collectives/util.py +47 -0
  134. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  135. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  136. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  137. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  138. tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
  139. tpu_inference/kernels/megablox/__init__.py +13 -0
  140. tpu_inference/kernels/megablox/common.py +54 -0
  141. tpu_inference/kernels/megablox/gmm.py +646 -0
  142. tpu_inference/kernels/mla/__init__.py +13 -0
  143. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  144. tpu_inference/kernels/mla/v1/kernel.py +1340 -0
  145. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  146. tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
  147. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  148. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  149. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  150. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  151. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
  152. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
  153. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  154. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  155. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
  156. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
  157. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
  158. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
  159. tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
  160. tpu_inference/layers/__init__.py +13 -0
  161. tpu_inference/layers/common/__init__.py +13 -0
  162. tpu_inference/layers/common/attention_interface.py +403 -0
  163. tpu_inference/layers/common/attention_metadata.py +48 -0
  164. tpu_inference/layers/common/binary_search.py +295 -0
  165. tpu_inference/layers/common/quant_methods.py +23 -0
  166. tpu_inference/layers/common/quantization.py +270 -0
  167. tpu_inference/layers/common/sharding.py +600 -0
  168. tpu_inference/layers/jax/__init__.py +13 -0
  169. tpu_inference/layers/jax/attention/__init__.py +13 -0
  170. tpu_inference/layers/jax/attention/attention.py +268 -0
  171. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
  172. tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
  173. tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
  174. tpu_inference/layers/jax/base.py +165 -0
  175. tpu_inference/layers/jax/constants.py +101 -0
  176. tpu_inference/layers/jax/layers.py +315 -0
  177. tpu_inference/layers/jax/misc.py +30 -0
  178. tpu_inference/layers/jax/moe/__init__.py +13 -0
  179. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
  180. tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
  181. tpu_inference/layers/jax/moe/moe.py +249 -0
  182. tpu_inference/layers/jax/pp_utils.py +53 -0
  183. tpu_inference/layers/jax/rope.py +294 -0
  184. tpu_inference/layers/jax/rope_interface.py +228 -0
  185. tpu_inference/layers/jax/sample/__init__.py +13 -0
  186. tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
  187. tpu_inference/layers/jax/sample/sampling.py +110 -0
  188. tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
  189. tpu_inference/layers/jax/transformer_block.py +121 -0
  190. tpu_inference/layers/vllm/__init__.py +13 -0
  191. tpu_inference/layers/vllm/attention.py +221 -0
  192. tpu_inference/layers/vllm/fused_moe.py +502 -0
  193. tpu_inference/layers/vllm/linear_common.py +221 -0
  194. tpu_inference/layers/vllm/quantization/__init__.py +55 -0
  195. tpu_inference/layers/vllm/quantization/awq.py +221 -0
  196. tpu_inference/layers/vllm/quantization/common.py +124 -0
  197. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  198. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
  199. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
  200. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  201. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
  202. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
  203. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  204. tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
  205. tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
  206. tpu_inference/layers/vllm/sharding.py +244 -0
  207. tpu_inference/logger.py +10 -0
  208. tpu_inference/lora/__init__.py +13 -0
  209. tpu_inference/lora/torch_lora_ops.py +98 -0
  210. tpu_inference/lora/torch_punica_tpu.py +310 -0
  211. tpu_inference/models/__init__.py +13 -0
  212. tpu_inference/models/common/__init__.py +13 -0
  213. tpu_inference/models/common/model_loader.py +520 -0
  214. tpu_inference/models/jax/__init__.py +13 -0
  215. tpu_inference/models/jax/deepseek_v3.py +978 -0
  216. tpu_inference/models/jax/gpt_oss.py +508 -0
  217. tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
  218. tpu_inference/models/jax/llama3.py +436 -0
  219. tpu_inference/models/jax/llama4.py +643 -0
  220. tpu_inference/models/jax/llama_eagle3.py +350 -0
  221. tpu_inference/models/jax/llama_guard_4.py +375 -0
  222. tpu_inference/models/jax/qwen2.py +390 -0
  223. tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
  224. tpu_inference/models/jax/qwen3.py +318 -0
  225. tpu_inference/models/jax/utils/__init__.py +13 -0
  226. tpu_inference/models/jax/utils/file_utils.py +110 -0
  227. tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
  228. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  229. tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
  230. tpu_inference/models/jax/utils/weight_utils.py +621 -0
  231. tpu_inference/models/vllm/__init__.py +13 -0
  232. tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
  233. tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
  234. tpu_inference/platforms/__init__.py +16 -0
  235. tpu_inference/platforms/tpu_platform.py +258 -0
  236. tpu_inference/runner/__init__.py +13 -0
  237. tpu_inference/runner/block_table.py +122 -0
  238. tpu_inference/runner/compilation_manager.py +890 -0
  239. tpu_inference/runner/input_batch.py +435 -0
  240. tpu_inference/runner/kv_cache.py +166 -0
  241. tpu_inference/runner/kv_cache_manager.py +508 -0
  242. tpu_inference/runner/lora_utils.py +106 -0
  243. tpu_inference/runner/multimodal_manager.py +231 -0
  244. tpu_inference/runner/persistent_batch_manager.py +296 -0
  245. tpu_inference/runner/speculative_decoding_manager.py +262 -0
  246. tpu_inference/runner/structured_decoding_manager.py +101 -0
  247. tpu_inference/runner/tpu_runner.py +1768 -0
  248. tpu_inference/runner/utils.py +426 -0
  249. tpu_inference/spec_decode/__init__.py +13 -0
  250. tpu_inference/spec_decode/jax/__init__.py +13 -0
  251. tpu_inference/spec_decode/jax/eagle3.py +430 -0
  252. tpu_inference/tpu_info.py +92 -0
  253. tpu_inference/utils.py +345 -0
  254. tpu_inference/worker/__init__.py +13 -0
  255. tpu_inference/worker/tpu_worker.py +468 -0
  256. tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
  257. tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
  258. tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
  259. tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
  260. tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
tpu_inference/kernels/flash_attention/kernel.py
@@ -0,0 +1,772 @@
+ # SPDX-License-Identifier: Apache-2.0
+ """Flash Attention TPU kernel."""
+ from __future__ import annotations
+
+ import dataclasses
+ import functools
+ import math
+ from typing import Any, NamedTuple
+
+ import jax
+ import jax.numpy as jnp
+ from jax import lax
+ from jax.experimental import pallas as pl
+ from jax.experimental.pallas import tpu as pltpu
+
+ DEFAULT_MASK_VALUE = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max)
+ NUM_LANES = 128
+ NUM_SUBLANES = 8
+
+
+ class SegmentIds(NamedTuple):
+     """SegmentIds for Q and KV sequences.
+
+     SegmentIds are used to generate the segment mask, which prevents attention between
+     different segments in the input sequence. Each array is a list of ids
+     (integers).
+     Only tokens with the same id can attend to each other.
+
+     Attributes:
+       q: segment ids along the Q sequence.
+       kv: segment ids along the KV sequence.
+     """
+
+     q: jax.Array  # [batch_size, q_seq_len]
+     kv: jax.Array  # [batch_size, kv_seq_len]
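A minimal usage sketch (assuming the import path tpu_inference/kernels/flash_attention/kernel.py from the file list above): two packed sequences described with SegmentIds, where tokens attend only within their own segment because the mask is built from id equality.

import jax.numpy as jnp
from tpu_inference.kernels.flash_attention.kernel import SegmentIds

# Two sequences of lengths 3 and 2 packed into one row of length 5.
# The id values are arbitrary labels; only positions with equal ids
# may attend to one another.
segment_ids = SegmentIds(
    q=jnp.asarray([[0, 0, 0, 1, 1]], dtype=jnp.int32),   # [batch_size, q_seq_len]
    kv=jnp.asarray([[0, 0, 0, 1, 1]], dtype=jnp.int32),  # [batch_size, kv_seq_len]
)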
+
+
+ @dataclasses.dataclass(frozen=True)
+ class BlockSizes:
+     """Tile sizes parameterizing FlashAttention kernels.
+
+     Those parameters have negligible effect on numerics, but affect performance
+     greatly.
+     """
+     block_q: int
+     block_k_major: int
+     block_k: int
+     block_b: int
+
+     def __post_init__(self):
+
+         def verify_major_minor(prefix, suffix, major, minor):
+             if minor > major:
+                 raise ValueError(
+                     f"{prefix}{suffix}={minor} should be smaller than"
+                     f" {prefix}_major{suffix}={major}")
+             if major % minor != 0:
+                 raise ValueError(f"{prefix}{suffix}={minor} should divide"
+                                  f" {prefix}_major{suffix}={major}")
+
+         verify_major_minor("block_k", "", self.block_k_major, self.block_k)
+
+     @classmethod
+     def get_default(cls, batch_size, num_heads, q_seq_len, kv_len, d_model):
+         # TODO(apaszke,sharadmv): Select better parameters based on a heuristic.
+         del batch_size, num_heads, q_seq_len, kv_len, d_model  # Unused.
+         return BlockSizes(
+             block_q=128,
+             block_k_major=128,
+             block_k=128,
+             block_b=1,
+         )
+
+
+ @functools.partial(
+     jax.jit,
+     static_argnames=[
+         "causal",
+         "sm_scale",
+         "block_sizes",
+         "vmem_limit_bytes",
+         "debug",
+     ],
+ )
+ def flash_attention(
+     q,  # [batch_size, num_heads, q_seq_len, d_model]
+     k,  # [batch_size, num_heads, kv_seq_len, d_model]
+     v,  # [batch_size, num_heads, kv_seq_len, d_model]
+     ab=None,  # [batch_size, num_heads, q_seq_len, kv_seq_len]
+     segment_ids=None,  # q of [batch_size, q_seq_len] and kv of [batch_size, kv_seq_len]
+     *,
+     causal: bool = False,
+     sm_scale: float = 1.0,
+     block_sizes: BlockSizes | None = None,
+     vmem_limit_bytes: int,
+     debug: bool = False,
+ ):
+     batch_size, num_heads, q_seq_len, d_model = q.shape
+     batch_size_k, num_heads_k, kv_seq_len, d_model_k = k.shape
+     batch_size_v, num_heads_v, kv_seq_len_v, d_model_v = v.shape
+     if batch_size != batch_size_k or batch_size != batch_size_v:
+         raise ValueError(
+             f"Batch size mismatch: got {batch_size}, {batch_size_k} and"
+             f" {batch_size_v} (for q, k, v respectively)")
+     if num_heads != num_heads_k or num_heads != num_heads_v:
+         raise ValueError(
+             f"Head count mismatch: got {num_heads}, {num_heads_k},"
+             f" {num_heads_v} (for q, k, v respectively)")
+     if d_model != d_model_k:
+         raise ValueError(
+             f"Model dimension mismatch: got {d_model} and {d_model_k} (for q and k"
+             " respectively)")
+     if d_model != d_model_v:
+         raise NotImplementedError(
+             "V model dimension unequal to KV model dimension unsupported")
+     if kv_seq_len != kv_seq_len_v:
+         raise ValueError(
+             f"KV sequence length mismatch: got {kv_seq_len} and {kv_seq_len_v}"
+         )
+     if ab is not None:
+         if ab.shape != (batch_size, num_heads, q_seq_len, kv_seq_len):
+             raise ValueError(
+                 f"Attention bias shape mismatch: expected ({batch_size=},"
+                 f" {num_heads=}, {q_seq_len=}, {kv_seq_len=}), got {ab.shape}")
+     if segment_ids is not None:
+         if segment_ids.q.shape != (batch_size, q_seq_len):
+             raise ValueError(
+                 f"Q segment ids shape mismatch: expected ({batch_size=},"
+                 f" {q_seq_len=},), got {segment_ids.q.shape}")
+         if segment_ids.kv.shape != (batch_size, kv_seq_len):
+             raise ValueError(
+                 f"KV segment ids shape mismatch: expected ({batch_size=},"
+                 f" {kv_seq_len=},), got {segment_ids.kv.shape}")
+     if block_sizes is None:
+         block_sizes = BlockSizes.get_default(batch_size, num_heads, q_seq_len,
+                                              kv_seq_len, d_model)
+     # TODO (KWang1998 & hfan): tune the block sizes properly.
+     if kv_seq_len <= 92800:
+         # Override block_k/block_k_major to use `_flash_attention_kernel_single_batch_single_step`.
+         block_sizes = BlockSizes(block_q=block_sizes.block_q,
+                                  block_b=block_sizes.block_b,
+                                  block_k_major=kv_seq_len,
+                                  block_k=kv_seq_len)
+     return _flash_attention(q, k, v, ab, segment_ids, False, causal, sm_scale,
+                             block_sizes, vmem_limit_bytes, debug)
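A rough calling sketch for the entry point above, assuming a TPU runtime and the package import path from the file list; the shapes, block sizes, and VMEM budget are illustrative only. Note that vmem_limit_bytes is keyword-only with no default, sm_scale and block_sizes are static under jax.jit, and for KV lengths up to 92800 the wrapper overrides block_k/block_k_major to the full KV length (keeping block_q and block_b).

import jax
import jax.numpy as jnp
from tpu_inference.kernels.flash_attention.kernel import (BlockSizes,
                                                          flash_attention)

batch, heads, q_len, kv_len, head_dim = 1, 8, 1024, 1024, 128
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (batch, heads, q_len, head_dim), jnp.bfloat16)
k = jax.random.normal(kk, (batch, heads, kv_len, head_dim), jnp.bfloat16)
v = jax.random.normal(kv, (batch, heads, kv_len, head_dim), jnp.bfloat16)

out = flash_attention(
    q, k, v,
    causal=True,
    sm_scale=1.0 / head_dim**0.5,
    block_sizes=BlockSizes(block_q=256, block_k_major=1024, block_k=1024, block_b=1),
    vmem_limit_bytes=128 * 1024 * 1024,  # illustrative VMEM budget
)
print(out.shape)  # (1, 8, 1024, 128)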
+
+
+ def _flash_attention(
+     q,
+     k,
+     v,
+     ab,
+     segment_ids,
+     save_residuals,
+     causal,
+     sm_scale,
+     block_sizes,
+     vmem_limit_bytes,
+     debug,
+ ):
+     return _flash_attention_impl(
+         q,
+         k,
+         v,
+         ab,
+         segment_ids,
+         save_residuals,
+         causal,
+         sm_scale,
+         block_sizes.block_b,
+         block_sizes.block_q,
+         block_sizes.block_k_major,
+         block_sizes.block_k,
+         vmem_limit_bytes,
+         debug,
+     )
+
+
+ MIN_BLOCK_SIZE = 128
+ TRANS_B_DIM_NUMBERS = (((1, ), (1, )), ((), ()))
+
+
+ def below_or_on_diag(r, r_blk_size, c, c_blk_size):
+     # A block is considered below or on diagonal as long as the bottom left
+     # corner of the block is below or on diagonal.
+     return ((r + 1) * r_blk_size - 1) > (c * c_blk_size)
+
+
+ def _flash_attention_kernel(q_tile_ref, *args, **kwargs):
+     block_b = q_tile_ref.shape[0]
+     # If we're not going to tile the softmax, then we can avoid a bunch of VPU ops.
+     if kwargs["block_k"] == kwargs["kv_seq_len"]:
+         kernel = _flash_attention_kernel_single_batch_single_step
+     else:
+         kernel = _flash_attention_kernel_single_batch
+     for batch_idx in range(block_b):
+         kernel((batch_idx, 0), q_tile_ref, *args, **kwargs)
+
+
+ def _flash_attention_kernel_single_batch(
+     batch_idx: tuple[int, ...],
+     q_tile_ref,
+     k_tile_ref,
+     v_tile_ref,
+     ab_tile_ref,
+     q_segment_ids_tile_ref,
+     kv_segment_ids_tile_ref,  # Input arrays
+     o_tile_ref,  # Output arrays
+     l_ref,
+     m_ref,
+     m_scratch_ref,
+     l_scratch_ref,
+     acc_scratch_ref,
+     *,
+     causal,
+     sm_scale,
+     block_k,
+     kv_seq_len,
+     mask_value,
+ ):
+     block_k_major = k_tile_ref.shape[2]
+     block_q = q_tile_ref.shape[2]
+     head_dim = q_tile_ref.shape[-1]
+
+     kv_seq_idx = pl.program_id(3)
+
+     @pl.when(kv_seq_idx == 0)
+     def start_new_sequence():
+         m_scratch_ref[batch_idx] = jnp.full(m_scratch_ref.shape[2:], -jnp.inf,
+                                             jnp.float32)
+         l_scratch_ref[batch_idx] = jnp.zeros(l_scratch_ref.shape[2:],
+                                              jnp.float32)
+         acc_scratch_ref[batch_idx] = jnp.zeros(acc_scratch_ref.shape[2:],
+                                                jnp.float32)
+
+     q_seq_idx = pl.program_id(2)
+     if causal:
+         should_run = below_or_on_diag(q_seq_idx, block_q, kv_seq_idx,
+                                       block_k_major)
+     else:
+         should_run = True
+
+     @pl.when(should_run)
+     def run():
+
+         @pl.loop(0, block_k_major, step=block_k, unroll=True)
+         def _body(start_k):
+             m_prev = m_scratch_ref[batch_idx]
+             l_prev = l_scratch_ref[batch_idx]
+             q = q_tile_ref[batch_idx]  # [block_q, head_dim]
+             k = k_tile_ref[(*batch_idx, pl.dslice(start_k, block_k),
+                             slice(None))]  # [block_k, head_dim]
+
+             s = jax.lax.dot_general(
+                 q, k, TRANS_B_DIM_NUMBERS,
+                 preferred_element_type=jnp.float32)  # [block_q, block_k]
+
+             # Add attention bias if needed.
+             # TODO(tanburn) Should the attention bias be added before or after
+             # multiplication by sm_scale?
+             if ab_tile_ref is not None:
+                 ab = ab_tile_ref[(*batch_idx, pl.dslice(None),
+                                   pl.dslice(start_k,
+                                             block_k))].astype(jnp.float32)
+                 s += ab
+
+             if sm_scale != 1.0:
+                 s *= sm_scale
+
+             mask = None
+             if q_segment_ids_tile_ref is not None:
+                 repeats, rem = divmod(block_k, NUM_LANES)
+                 if rem:
+                     raise NotImplementedError(
+                         f"kv block size must be a multiple of {NUM_LANES}")
+                 q_segment_ids = pltpu.repeat(
+                     q_segment_ids_tile_ref[batch_idx[0]], repeats,
+                     axis=1)  # [block_q, block_k].
+                 kv_segment_ids = kv_segment_ids_tile_ref[
+                     batch_idx[0], :1,
+                     pl.dslice(start_k, block_k)]  # [1, block_k].
+                 mask = jnp.equal(q_segment_ids,
+                                  kv_segment_ids).astype(jnp.bool_)
+
+             if causal:
+                 mask_shape = (block_q, block_k)
+                 row_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 0)
+                 row_ids += q_seq_idx * block_q
+                 col_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 1)
+                 col_ids += kv_seq_idx * block_k_major + start_k
+                 causal_mask = col_ids <= row_ids
+                 mask = (causal_mask if mask is None else jnp.logical_and(
+                     mask, causal_mask))
+
+             s = s if mask is None else s + jnp.where(mask, 0.0, mask_value)
+
+             m_curr = jnp.max(s, axis=1)[:,
+                                         None]  # Row max, shape [block_q, 1].
+             m_next = jnp.maximum(m_prev, m_curr)  # Shape [block_q, 128].
+
+             block_k_repeats, rem = divmod(block_k, MIN_BLOCK_SIZE)
+             if rem:
+                 raise NotImplementedError(
+                     f"{block_k=} should be a multiple of {MIN_BLOCK_SIZE}")
+             p = jnp.exp(s - pltpu.repeat(m_next, block_k_repeats, 1))
+
+             alpha = jnp.exp(m_prev - m_next)  # Shape [block_q, 128].
+
+             l_corr = alpha * l_prev
+
+             l_next = jnp.sum(p, axis=1)[:,
+                                         None] + l_corr  # Shape [block_q, 128]
+
+             head_dim_repeats, rem = divmod(head_dim, MIN_BLOCK_SIZE)
+             l_broadcast = lambda l: pltpu.repeat(l, head_dim_repeats, 1)
+             if rem:
+                 if head_dim_repeats == 0:
+                     l_broadcast = lambda l: l[:, :head_dim]
+                 else:
+                     raise NotImplementedError(
+                         f"{head_dim=} should be a multiple of {MIN_BLOCK_SIZE} if larger"
+                     )
+             l_scratch_ref[batch_idx] = l_next
+             m_scratch_ref[batch_idx] = m_next
+
+             l_next_inv_safe = jnp.where(l_next == 0.0, 1.0, 1.0 / l_next)
+             acc_scratch_ref[batch_idx] *= l_broadcast(l_corr * l_next_inv_safe)
+             v = v_tile_ref[(*batch_idx, pl.dslice(start_k,
+                                                   block_k), slice(None))]
+             o_curr = jax.lax.dot(p.astype(v.dtype),
+                                  v,
+                                  preferred_element_type=jnp.float32)
+             acc_scratch_ref[batch_idx] += o_curr * l_broadcast(l_next_inv_safe)
+
+     @pl.when(kv_seq_idx == (kv_seq_len // block_k_major) - 1)
+     def store_output():
+         o_tile_ref[batch_idx] = acc_scratch_ref[batch_idx].astype(
+             o_tile_ref.dtype)
+         if l_ref is not None:
+             l_ref[batch_idx] = l_scratch_ref[batch_idx].astype(l_ref.dtype)
+         if m_ref is not None:
+             m_ref[batch_idx] = m_scratch_ref[batch_idx].astype(m_ref.dtype)
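The m_scratch/l_scratch/acc_scratch updates in _body above are the standard streaming-softmax recurrence. A simplified plain-jnp sketch of one step over a single [block_q, block_k] score tile (an assumption-laden paraphrase, not package code: the kernel additionally keeps the accumulator pre-divided by l via l_next_inv_safe and broadcasts m/l across 128 lanes, which is mathematically equivalent to dividing once at the end as done here):

import jax.numpy as jnp

def online_softmax_step(m_prev, l_prev, acc_prev, s, v_blk):
    # m_prev, l_prev: [block_q, 1]; acc_prev: [block_q, head_dim]
    # s: [block_q, block_k] scores for this KV tile; v_blk: [block_k, head_dim]
    m_curr = s.max(axis=1, keepdims=True)   # row max of the new tile
    m_next = jnp.maximum(m_prev, m_curr)    # updated running max
    alpha = jnp.exp(m_prev - m_next)        # rescale factor for the old state
    p = jnp.exp(s - m_next)                 # unnormalized weights for this tile
    l_next = alpha * l_prev + p.sum(axis=1, keepdims=True)
    acc_next = alpha * acc_prev + p @ v_blk  # unnormalized output accumulator
    return m_next, l_next, acc_next          # final output is acc / l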
+
+
+ # ruff: noqa #731
+ # ruff: noqa #741
+ def _flash_attention_kernel_single_batch_single_step(
+     batch_idx: tuple[int, ...],
+     q_tile_ref,
+     k_tile_ref,
+     v_tile_ref,
+     ab_tile_ref,
+     q_segment_ids_tile_ref,
+     kv_segment_ids_tile_ref,  # Input arrays
+     o_tile_ref,  # Output arrays
+     l_ref: Any | None = None,
+     m_ref: Any | None = None,
+     *,
+     causal,
+     sm_scale,
+     block_k,
+     kv_seq_len,
+     mask_value,
+ ):
+     block_k_major = k_tile_ref.shape[2]
+     block_q = q_tile_ref.shape[2]
+
+     assert kv_seq_len == block_k_major == block_k
+
+     q = q_tile_ref[batch_idx]  # [block_q, head_dim]
+     k = k_tile_ref[batch_idx]  # [block_k, head_dim]
+     s = jax.lax.dot_general(
+         q, k, TRANS_B_DIM_NUMBERS,
+         preferred_element_type=jnp.float32)  # [block_q, block_k]
+
+     if ab_tile_ref is not None:
+         s += ab_tile_ref[batch_idx].astype(jnp.float32)
+     if sm_scale != 1.0:
+         s *= sm_scale
+
+     mask = None
+     if q_segment_ids_tile_ref is not None:
+         repeats, rem = divmod(block_k, NUM_LANES)
+         if rem:
+             raise NotImplementedError(
+                 f"kv block size must be a multiple of {NUM_LANES}")
+         q_segment_ids = q_segment_ids_tile_ref[
+             batch_idx[0]]  # [block_q, NUM_LANES].
+         q_segment_ids = pltpu.repeat(q_segment_ids, repeats,
+                                      axis=1)  # [block_q, block_k].
+         kv_segment_ids = kv_segment_ids_tile_ref[batch_idx[0], :
+                                                  1]  # [1, block_k].
+         mask = jnp.equal(q_segment_ids, kv_segment_ids).astype(jnp.bool_)
+
+     if causal:
+         q_seq_idx = pl.program_id(2)
+         mask_shape = (block_q, block_k)
+         row_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 0)
+         row_ids += q_seq_idx * block_q
+         col_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 1)
+         causal_mask = col_ids <= row_ids
+         mask = causal_mask if mask is None else jnp.logical_and(
+             mask, causal_mask)
+     s = s if mask is None else s + jnp.where(mask, 0.0, mask_value)
+
+     m = jnp.max(s, axis=1)[:, None]
+     p = jnp.exp(s - m)
+     l = jnp.sum(p, axis=1)[:, None]
+     p /= l
+
+     if m_ref is not None:
+         m_ref[batch_idx] = lax.broadcast_in_dim(m, m_ref.shape[2:], range(2))
+     if l_ref is not None:
+         l_ref[batch_idx] = lax.broadcast_in_dim(l, l_ref.shape[2:], range(2))
+
+     v = v_tile_ref[batch_idx]
+     o_tile_ref[batch_idx] = jax.lax.dot(
+         p.astype(v.dtype), v,
+         preferred_element_type=jnp.float32).astype(o_tile_ref.dtype)
+
+
+ def _bytes(x: jax.Array | jax.ShapeDtypeStruct) -> int:
+     return math.prod(x.shape) * x.dtype.itemsize
+
+
+ def _fwd_cost_estimate(
+     q: jax.Array,
+     k: jax.Array,
+     v: jax.Array,
+     ab: jax.Array | None,
+     segment_ids: SegmentIds | None,
+     *,
+     causal: bool,
+     sm_scale: jax.Array | None,
+     kernel_inputs_specs,
+     kernel_outputs_specs,
+ ) -> pl.CostEstimate | None:
+     body_cost = pl.estimate_cost(mha_reference,
+                                  q,
+                                  k,
+                                  v,
+                                  ab,
+                                  segment_ids,
+                                  causal=causal,
+                                  sm_scale=sm_scale)
+     input_bytes = sum(_bytes(x) for x in jax.tree.leaves(kernel_inputs_specs))
+     output_bytes = sum(
+         _bytes(x) for x in jax.tree.leaves(kernel_outputs_specs))
+     return pl.CostEstimate(
+         flops=body_cost.flops,
+         transcendentals=body_cost.transcendentals,
+         bytes_accessed=input_bytes + output_bytes,
+     )
+
+
+ def _flash_attention_impl(
+     q,
+     k,
+     v,
+     ab,
+     segment_ids,
+     save_residuals,
+     causal,
+     sm_scale,
+     block_b,
+     block_q,
+     block_k_major,
+     block_k,
+     vmem_limit_bytes,
+     debug,
+ ):
+     batch_size, num_heads, q_seq_len, head_dim = q.shape
+     _, _, kv_seq_len, _ = k.shape
+     _verify_block("block_q",
+                   "q_seq_len",
+                   block_q,
+                   q_seq_len,
+                   should_divide=False)
+     _verify_block("block_k_major", "kv_seq_len", block_k_major, kv_seq_len)
+     _verify_block("block_k", "kv_seq_len", block_k, kv_seq_len)
+     _verify_block("block_b", "batch", block_b, batch_size, should_divide=False)
+
+     # TODO(apaszke): Tile over heads as well.
+     grid = (
+         pl.cdiv(batch_size, block_b),
+         num_heads,
+         pl.cdiv(q_seq_len, block_q),
+         kv_seq_len // block_k_major,
+     )
+
+     def q_index_map(batch_index, head_index, q_seq_index, _):
+         return (batch_index, head_index, q_seq_index, 0)
+
+     def kv_index_map(batch_index, head_index, q_seq_index, kv_seq_index):
+         if causal:
+             # If the kv block is skipped, prefetch the next valid kv block, i.e. the
+             # 0th one to be used for the next block_q rows.
+             next_kv_index = lax.select(
+                 below_or_on_diag(q_seq_index, block_q, kv_seq_index,
+                                  block_k_major),
+                 kv_seq_index,
+                 0,
+             )
+         else:
+             next_kv_index = kv_seq_index
+         return (batch_index, head_index, next_kv_index, 0)
+
+     def ab_index_map(batch_index, head_index, q_seq_index, kv_seq_index):
+         if causal:
+             should_run = below_or_on_diag(q_seq_index, block_q, kv_seq_index,
+                                           block_k_major)
+             # If the ab block is skipped, prefetch the next valid ab block, i.e. the
+             # 0th kv to be used for the next block_q rows.
+             next_q_index = lax.select(
+                 should_run,
+                 q_seq_index,
+                 lax.select(q_seq_index == (q_seq_len // block_q) - 1, 0,
+                            q_seq_index + 1),
+             )
+             next_kv_index = lax.select(should_run, kv_seq_index, 0)
+         else:
+             next_q_index = q_seq_index
+             next_kv_index = kv_seq_index
+
+         return (batch_index, head_index, next_q_index, next_kv_index)
+
+     def o_index_map(batch_index, head_index, q_seq_index, _):
+         return (batch_index, head_index, q_seq_index, 0)
+
+     def lm_index_map(batch_index, head_index, q_seq_index, _):
+         return (batch_index, head_index, q_seq_index, 0)
+
+     kernel = functools.partial(
+         _flash_attention_kernel,
+         causal=causal,
+         mask_value=DEFAULT_MASK_VALUE,
+         sm_scale=sm_scale,
+         block_k=block_k,
+         kv_seq_len=kv_seq_len,
+     )
+     out_shape = jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype)
+     out_shape = [out_shape]
+     out_specs = [pl.BlockSpec((block_b, 1, block_q, head_dim), o_index_map)]
+
+     if block_k != kv_seq_len:
+         m_scratch = pltpu.VMEM((block_b, 1, block_q, MIN_BLOCK_SIZE),
+                                jnp.float32)
+         l_scratch = pltpu.VMEM((block_b, 1, block_q, MIN_BLOCK_SIZE),
+                                jnp.float32)
+         acc_scratch = pltpu.VMEM((block_b, 1, block_q, head_dim), jnp.float32)
+         scratch_shapes = [m_scratch, l_scratch, acc_scratch]
+     else:
+         scratch_shapes = []
+
+     if save_residuals:
+         out_specs = [
+             *out_specs,
+             pl.BlockSpec((block_b, 1, block_q, MIN_BLOCK_SIZE), lm_index_map),
+             pl.BlockSpec((block_b, 1, block_q, MIN_BLOCK_SIZE), lm_index_map),
+         ]
+         l = jax.ShapeDtypeStruct(
+             (batch_size, num_heads, q_seq_len, MIN_BLOCK_SIZE),
+             dtype=jnp.float32)
+         m = jax.ShapeDtypeStruct(
+             (batch_size, num_heads, q_seq_len, MIN_BLOCK_SIZE),
+             dtype=jnp.float32)
+         out_shape = (*out_shape, l, m)
+     else:
+         out_specs = [*out_specs, None, None]
+         out_shape = (*out_shape, None, None)
+
+     ab_block_spec = (pl.BlockSpec(
+         (block_b, 1, block_q,
+          block_k_major), ab_index_map) if ab is not None else None)
+
+     q_segment_ids_spec = kv_segment_ids_spec = None
+     q_segment_ids = kv_segment_ids = None
+     if segment_ids is not None:
+
+         def q_segment_ids_index_map(batch_index, head_index, q_seq_index, _):
+             del head_index
+             return (batch_index, q_seq_index, 0)
+
+         def kv_segment_ids_index_map(batch_index, head_index, q_seq_index,
+                                      kv_seq_index):
+             del head_index
+             if causal:
+                 next_kv_index = lax.select(
+                     below_or_on_diag(q_seq_index, block_q, kv_seq_index,
+                                      block_k_major),
+                     kv_seq_index,
+                     0,
+                 )
+             else:
+                 next_kv_index = kv_seq_index
+             return (batch_index, 0, next_kv_index)
+
+         q_segment_ids_spec = pl.BlockSpec((block_b, block_q, NUM_LANES),
+                                           q_segment_ids_index_map)
+         kv_segment_ids_spec = pl.BlockSpec(
+             (block_b, NUM_SUBLANES, block_k_major), kv_segment_ids_index_map)
+
+         q_segment_ids = jax.lax.broadcast_in_dim(
+             segment_ids.q,
+             (batch_size, q_seq_len, NUM_LANES),
+             (
+                 0,
+                 1,
+             ),
+         )
+         kv_segment_ids = jax.lax.broadcast_in_dim(
+             segment_ids.kv,
+             (batch_size, NUM_SUBLANES, kv_seq_len),
+             (
+                 0,
+                 2,
+             ),
+         )
+
+     in_specs = [
+         pl.BlockSpec((block_b, 1, block_q, head_dim), q_index_map),
+         pl.BlockSpec((block_b, 1, block_k_major, head_dim), kv_index_map),
+         pl.BlockSpec((block_b, 1, block_k_major, head_dim), kv_index_map),
+         ab_block_spec,
+         q_segment_ids_spec,
+         kv_segment_ids_spec,
+     ]
+
+     o, *aux = pl.pallas_call(
+         kernel,
+         grid_spec=pltpu.PrefetchScalarGridSpec(
+             num_scalar_prefetch=0,
+             grid=grid,
+             in_specs=in_specs,
+             out_specs=out_specs,
+             scratch_shapes=scratch_shapes,
+         ),
+         out_shape=out_shape,
+         debug=debug,
+         compiler_params=pltpu.CompilerParams(
+             dimension_semantics=(
+                 "parallel",
+                 "parallel",
+                 "parallel",
+                 "arbitrary",
+             ),
+             vmem_limit_bytes=vmem_limit_bytes,
+         ),
+         cost_estimate=_fwd_cost_estimate(
+             q,
+             k,
+             v,
+             ab,
+             segment_ids,
+             causal=causal,
+             sm_scale=sm_scale,
+             kernel_inputs_specs=(q, k, v, ab, q_segment_ids, kv_segment_ids),
+             kernel_outputs_specs=out_shape,
+         ),
+     )(q, k, v, ab, q_segment_ids, kv_segment_ids)
+     if save_residuals:
+         l, m = (v[..., 0] for v in aux[-2:])
+         return (o, l, m)
+     else:
+         return o
+
+
+ # For autograd testing.
+ def mha_reference_no_custom_vjp(
+     q,
+     k,
+     v,
+     ab: jax.Array | None = None,
+     segment_ids: SegmentIds | None = None,
+     *,
+     causal: bool = False,
+     mask_value: float = DEFAULT_MASK_VALUE,
+     sm_scale: float = 1.0,
+     save_residuals: bool = False,
+ ):
+     logits = jnp.einsum("bhqc,bhkc->bhqk", q, k)
+     if ab is not None:
+         logits += ab
+     if sm_scale != 1.0:
+         logits *= sm_scale
+
+     mask = None
+     if segment_ids is not None:
+         mask = segment_ids.q[:, :, None] == segment_ids.kv[:, None, :]
+         mask = mask[:, None, :, :]
+
+     if causal:
+         _, _, q_seq_len, _ = q.shape
+         _, _, kv_seq_len, _ = k.shape
+         mask_shape = (q_seq_len, kv_seq_len)
+         row_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 0)
+         col_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 1)
+         causal_mask = (col_ids <= row_ids)[None, None, :, :]
+         mask = causal_mask if mask is None else jnp.logical_and(
+             mask, causal_mask)
+
+     logits = logits if mask is None else logits + jnp.where(
+         mask, 0.0, mask_value)
+
+     m = logits.max(axis=-1)
+     unnormalized = jnp.exp(logits - m[..., None])
+     l = unnormalized.sum(axis=-1)
+     weights = unnormalized / l[..., None]
+     out = jnp.einsum("bhqk,bhkc->bhqc", weights, v)
+     if save_residuals:
+         return out, l, m
+     return out
+
+
+ @functools.partial(jax.jit,
+                    static_argnames=["causal", "mask_value", "sm_scale"])
+ @jax.default_matmul_precision("bfloat16")
+ def mha_reference(
+     q,
+     k,
+     v,
+     ab,
+     segment_ids: SegmentIds | None = None,
+     causal: bool = False,
+     mask_value: float = DEFAULT_MASK_VALUE,
+     sm_scale=1.0,
+ ):
+     return _mha_reference(
+         q,
+         k,
+         v,
+         ab,
+         segment_ids,
+         causal=causal,
+         mask_value=mask_value,
+         sm_scale=sm_scale,
+         save_residuals=False,
+     )
+
+
+ def _mha_reference(
+     q,
+     k,
+     v,
+     ab,
+     segment_ids: SegmentIds | None,
+     causal: bool,
+     mask_value: float,
+     sm_scale: float,
+     save_residuals: bool,
+ ):
+     return mha_reference_no_custom_vjp(
+         q,
+         k,
+         v,
+         ab,
+         segment_ids,
+         causal=causal,
+         mask_value=mask_value,
+         sm_scale=sm_scale,
+         save_residuals=save_residuals,
+     )
+
+
+ def _verify_block(block_name, dim_name, block, dim, should_divide=True):
+     if block > dim:
+         raise ValueError(
+             f"{block_name}={block} should be smaller or equal to {dim_name}={dim}"
+         )
+     if should_divide and dim % block != 0:
+         raise ValueError(
+             f"{dim_name}={dim} should be divisible by {block_name}={block}")