mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. mslk/__init__.py +56 -0
  2. mslk/attention/__init__.py +7 -0
  3. mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
  4. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
  5. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
  6. mslk/attention/flash_attn/__init__.py +22 -0
  7. mslk/attention/flash_attn/ampere_helpers.py +104 -0
  8. mslk/attention/flash_attn/barrier.py +72 -0
  9. mslk/attention/flash_attn/benchmark.py +269 -0
  10. mslk/attention/flash_attn/blackwell_helpers.py +754 -0
  11. mslk/attention/flash_attn/block_info.py +109 -0
  12. mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
  13. mslk/attention/flash_attn/block_sparsity.py +219 -0
  14. mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
  15. mslk/attention/flash_attn/copy_utils.py +341 -0
  16. mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
  17. mslk/attention/flash_attn/fast_math.py +22 -0
  18. mslk/attention/flash_attn/flash_bwd.py +1262 -0
  19. mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
  20. mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
  21. mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
  22. mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
  23. mslk/attention/flash_attn/flash_fwd.py +2471 -0
  24. mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
  25. mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
  26. mslk/attention/flash_attn/hopper_helpers.py +102 -0
  27. mslk/attention/flash_attn/interface.py +1771 -0
  28. mslk/attention/flash_attn/mask.py +610 -0
  29. mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
  30. mslk/attention/flash_attn/named_barrier.py +32 -0
  31. mslk/attention/flash_attn/pack_gqa.py +165 -0
  32. mslk/attention/flash_attn/paged_kv.py +176 -0
  33. mslk/attention/flash_attn/pipeline.py +273 -0
  34. mslk/attention/flash_attn/seqlen_info.py +139 -0
  35. mslk/attention/flash_attn/softmax.py +583 -0
  36. mslk/attention/flash_attn/testing.py +424 -0
  37. mslk/attention/flash_attn/tile_scheduler.py +720 -0
  38. mslk/attention/flash_attn/utils.py +860 -0
  39. mslk/attention/fmha/__init__.py +967 -0
  40. mslk/attention/fmha/_triton/__init__.py +6 -0
  41. mslk/attention/fmha/_triton/available.py +50 -0
  42. mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
  43. mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
  44. mslk/attention/fmha/attn_bias.py +2186 -0
  45. mslk/attention/fmha/attn_bias_utils.py +536 -0
  46. mslk/attention/fmha/ck.py +508 -0
  47. mslk/attention/fmha/ck_decoder.py +141 -0
  48. mslk/attention/fmha/ck_splitk.py +204 -0
  49. mslk/attention/fmha/common.py +598 -0
  50. mslk/attention/fmha/cutlass.py +461 -0
  51. mslk/attention/fmha/cutlass_blackwell.py +560 -0
  52. mslk/attention/fmha/dispatch.py +224 -0
  53. mslk/attention/fmha/flash.py +862 -0
  54. mslk/attention/fmha/flash3.py +858 -0
  55. mslk/attention/fmha/flash_mtia.py +245 -0
  56. mslk/attention/fmha/merge_training.py +192 -0
  57. mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
  58. mslk/attention/fmha/torch_attention_compat.py +154 -0
  59. mslk/attention/fmha/tree_attention.py +718 -0
  60. mslk/attention/fmha/triton_splitk.py +1378 -0
  61. mslk/attention/fmha/unbind.py +130 -0
  62. mslk/attention/fmha/utils/__init__.py +6 -0
  63. mslk/attention/fmha/utils/bench.py +74 -0
  64. mslk/attention/fmha/utils/cpp_lib.py +148 -0
  65. mslk/attention/fmha/utils/op_common.py +65 -0
  66. mslk/attention/gqa_attn_splitk/__init__.py +11 -0
  67. mslk/bench/comm/__init__.py +7 -0
  68. mslk/bench/comm/comm_bench.py +255 -0
  69. mslk/bench/common/__init__.py +5 -0
  70. mslk/bench/common/utils.py +148 -0
  71. mslk/bench/conv/__init__.py +7 -0
  72. mslk/bench/conv/conv_bench.py +551 -0
  73. mslk/bench/conv/conv_ops.py +213 -0
  74. mslk/bench/gemm/__init__.py +7 -0
  75. mslk/bench/gemm/gemm_bench.py +859 -0
  76. mslk/bench/gemm/gemm_ops.py +3342 -0
  77. mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
  78. mslk/bench/moe/__init__.py +7 -0
  79. mslk/bench/moe/gather_scatter_bench.py +356 -0
  80. mslk/bench/quantize/quantize_bench.py +345 -0
  81. mslk/bench/quantize/quantize_ops.py +266 -0
  82. mslk/comm/__init__.py +11 -0
  83. mslk/conv/__init__.py +11 -0
  84. mslk/gemm/__init__.py +18 -0
  85. mslk/gemm/triton/__init__.py +7 -0
  86. mslk/gemm/triton/fp8_gemm.py +2702 -0
  87. mslk/gemm/triton/grouped_gemm.py +1132 -0
  88. mslk/gemm/triton/matmul_perf_model.py +237 -0
  89. mslk/gemm/triton/utils.py +128 -0
  90. mslk/kv_cache/__init__.py +11 -0
  91. mslk/moe/__init__.py +26 -0
  92. mslk/moe/activation.py +291 -0
  93. mslk/moe/gather_scatter.py +739 -0
  94. mslk/moe/layers.py +1240 -0
  95. mslk/moe/shuffling.py +421 -0
  96. mslk/mslk.so +0 -0
  97. mslk/quantize/__init__.py +11 -0
  98. mslk/quantize/shuffle.py +306 -0
  99. mslk/quantize/triton/__init__.py +7 -0
  100. mslk/quantize/triton/fp4_quantize.py +5942 -0
  101. mslk/quantize/triton/fp8_quantize.py +1902 -0
  102. mslk/testing/__init__.py +7 -0
  103. mslk/testing/attributes.py +60 -0
  104. mslk/testing/rocm.py +91 -0
  105. mslk/utils/__init__.py +7 -0
  106. mslk/utils/torch/__init__.py +7 -0
  107. mslk/utils/torch/library.py +150 -0
  108. mslk/utils/triton/__init__.py +7 -0
  109. mslk/utils/triton/fp8_utils.py +72 -0
  110. mslk/utils/triton/utils.py +128 -0
  111. mslk/version.py +11 -0
  112. mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
  113. mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
  114. mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
  115. mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
  116. mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
mslk/attention/fmha/attn_bias.py
@@ -0,0 +1,2186 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # pyre-unsafe
7
+ """
8
+ This file contains biases that can be used as the `attn_bias` argument in
9
+ :attr:`xformers.ops.memory_efficient_attention`.
10
+ Essentially, a bias is a Tensor which is added to ``Q @ K.t`` before
11
+ computing the ``softmax``.
12
+
13
+
14
+ The goal of having custom-made classes (instead of dense tensors) is that
15
+ we want to avoid having to load the biases from memory in the kernel, for
16
+ performance reasons. We also want to be able to know beforehand which
17
+ parts of the attention matrix we will need to compute (e.g. causal masks).
18
+
19
+
20
+ Some very common biases are LowerTriangularMask and BlockDiagonalMask.
21
+ """
22
+
23
+ import math
24
+ from dataclasses import dataclass
25
+ from typing import (
26
+ Any,
27
+ cast,
28
+ ClassVar,
29
+ Iterable,
30
+ List,
31
+ Optional,
32
+ Sequence,
33
+ Tuple,
34
+ Type,
35
+ Union,
36
+ )
37
+
38
+ import torch
39
+
40
+
41
+ def _to_device(t: torch.Tensor, device: torch.device) -> torch.Tensor:
42
+ if t.device == device:
43
+ return t
44
+ if device == torch.device("cpu"):
45
+ return t.to(device)
46
+
47
+ return t.to(device, non_blocking=True)
48
+
49
+
50
+ def _to_device_tensor(seq: Sequence[int], dtype: torch.dtype, device: torch.device):
51
+ if device == torch.device("cpu"):
52
+ return torch.tensor(seq, dtype=dtype)
53
+
54
+ return torch.tensor(seq, dtype=dtype, pin_memory=True).to(device, non_blocking=True)
55
+
56
+
57
+ class AttentionBias:
58
+ """Base class for a custom bias that can be applied \
59
+ as the attn_bias argument in
60
+ :attr:`xformers.ops.memory_efficient_attention`.
61
+
62
+ That function has the ability to add a tensor, the
63
+ attention bias, to the QK^T matrix before it is used
64
+ in the softmax part of the attention calculation.
65
+ The attention bias tensor with shape
66
+ (B or 1, n_queries, n_keys)
67
+ can be given as the attn_bias input.
68
+ The most common use case for an attention bias is
69
+ to contain only zeros and negative infinities, which forms
70
+ a mask so that some queries only attend to some keys.
71
+
72
+ Children of this class define alternative ways to express, via the
73
+ attn_bias input, an attention bias which forms such a mask, for
74
+ some common cases.
75
+
76
+ When using an :attr:`xformers.ops.AttentionBias`
77
+ instead of a :attr:`torch.Tensor`, the mask matrix does
78
+ not need to be materialized, and can be
79
+ hardcoded into some kernels for better performance.
80
+
81
+ See:
82
+
83
+ - :attr:`xformers.ops.fmha.attn_bias.LowerTriangularMask`
84
+ - :attr:`xformers.ops.fmha.attn_bias.LowerTriangularFromBottomRightMask`
85
+ - :attr:`xformers.ops.fmha.attn_bias.LowerTriangularMaskWithTensorBias`
86
+ - :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalMask`
87
+ - :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask`
88
+
89
+ """
90
+
91
+ def materialize(
92
+ self,
93
+ shape: Tuple[int, ...],
94
+ dtype: torch.dtype = torch.float32,
95
+ device: Union[str, torch.device] = "cpu",
96
+ ) -> torch.Tensor:
97
+ """
98
+ Materializes the bias as a `torch.Tensor`. This is very slow
99
+ and we don't attempt to make it fast. Only use for debugging/testing.
100
+
101
+ Shape should be like `[*, q_seqlen, k_seqlen]`
102
+ """
103
+ raise NotImplementedError()
104
+
105
+
106
+ def _get_default_bias_device(device: Optional[torch.device] = None) -> torch.device:
107
+ if device is None:
108
+ if torch.cuda.is_available():
109
+ return torch.device("cuda")
110
+ if torch.mtia.is_available():
111
+ return torch.device("mtia")
112
+ return torch.device("cpu")
113
+ return device
114
+
115
+
116
+ def _materialize_causal_mask(
117
+ shape: Tuple[int, ...],
118
+ dtype: torch.dtype = torch.float32,
119
+ device: Union[str, torch.device] = "cpu",
120
+ *,
121
+ window_size: Optional[int] = None,
122
+ from_bottomright: bool = False,
123
+ ) -> torch.Tensor:
124
+ create_as = dtype if dtype is not torch.bfloat16 else torch.float32
125
+ tensor = torch.full( # type: ignore
126
+ shape,
127
+ dtype=create_as,
128
+ fill_value=1,
129
+ device=device,
130
+ )
131
+
132
+ num_queries, num_keys = shape[-2:]
133
+ shift = 0
134
+ if from_bottomright:
135
+ shift = num_keys - num_queries
136
+
137
+ mask = torch.tril(tensor, diagonal=shift).to(dtype) # type: ignore
138
+ if window_size is not None:
139
+ mask = torch.triu(mask, diagonal=shift - window_size + 1)
140
+ mask = torch.log(mask)
141
+ return mask.to(dtype)
142
+
143
+
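The helper above is the common building block for the causal masks that follow. A minimal sketch of its behaviour, assuming the wheel's module layout above makes it importable as ``mslk.attention.fmha.attn_bias``:

.. code-block:: python

    import torch
    from mslk.attention.fmha import attn_bias

    # Plain causal mask: 0 where attention is allowed, -inf elsewhere.
    m = attn_bias._materialize_causal_mask((3, 5), dtype=torch.float32)
    # Bottom-right aligned: the last query lines up with the last key.
    m_br = attn_bias._materialize_causal_mask((3, 5), from_bottomright=True)
    # Sliding window: each query keeps at most `window_size` keys.
    m_win = attn_bias._materialize_causal_mask(
        (3, 5), window_size=2, from_bottomright=True
    )
    print(m.exp())  # 1.0 = allowed, 0.0 = masked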
144
+ class LowerTriangularMask(AttentionBias):
145
+ """
146
+ A lower-triangular (aka causal) mask
147
+
148
+ A query Q cannot attend to a key which is farther from the
149
+ initial key than Q is from the initial query.
150
+
151
+ See also :attr:`LowerTriangularFromBottomRightMask` if the number
152
+ of queries is not equal to the number of keys/values.
153
+ """
154
+
155
+ def to(self, device: torch.device) -> "LowerTriangularMask":
156
+ assert type(self) is LowerTriangularMask, "Please implement in subclass"
157
+ return self
158
+
159
+ def materialize(
160
+ self,
161
+ shape: Tuple[int, ...],
162
+ dtype: torch.dtype = torch.float32,
163
+ device: Union[str, torch.device] = "cpu",
164
+ ) -> torch.Tensor:
165
+ return _materialize_causal_mask(shape, dtype=dtype, device=device)
166
+
167
+ def add_bias(self, bias: torch.Tensor) -> "LowerTriangularMaskWithTensorBias":
168
+ """
169
+ Creates a new causal mask with an arbitrary ``torch.Tensor`` bias
170
+ """
171
+ return LowerTriangularMaskWithTensorBias(bias)
172
+
173
+
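A short usage sketch for the class above (illustrative only, same import-path assumption as before):

.. code-block:: python

    import torch
    from mslk.attention.fmha.attn_bias import LowerTriangularMask

    bias = LowerTriangularMask()
    # 4 queries x 4 keys: after exp(), row i has ones at keys 0..i.
    print(bias.materialize((4, 4)).exp())

    # Stack an arbitrary additive tensor bias (e.g. relative-position
    # scores) on top of the causal structure.
    rel = torch.randn(4, 4)
    print(bias.add_bias(rel).materialize((4, 4)))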
174
+ @dataclass
175
+ class LocalAttentionFromBottomRightMask(AttentionBias):
176
+ """
177
+ A local attention mask
178
+
179
+ The query at position :math:`q` can attend the key at position :math:`k` if
180
+ :math:`q - window\\_left <= k + s <= q + window\\_right`
181
+
182
+ With :math:`s = num\\_queries - num\\_keys`
183
+
184
+ :Example:
185
+
186
+ .. code-block:: python
187
+
188
+ import torch
189
+ from xformers.ops import fmha
190
+
191
+ bias = fmha.attn_bias.LocalAttentionFromBottomRightMask(window_left=1, window_right=2)
192
+ print(bias.materialize(shape=(4, 4)).exp())
193
+ print(bias.materialize(shape=(4, 5)).exp())
194
+
195
+ .. code-block:: text
196
+
197
+ # 4x4
198
+ tensor([[1., 1., 1., 0.],
199
+ [1., 1., 1., 1.],
200
+ [0., 1., 1., 1.],
201
+ [0., 0., 1., 1.]])
202
+
203
+ # 4x5
204
+ tensor([[1., 1., 1., 1., 0.],
205
+ [0., 1., 1., 1., 1.],
206
+ [0., 0., 1., 1., 1.],
207
+ [0., 0., 0., 1., 1.]])
208
+
209
+ :Illustration:
210
+
211
+ .. figure:: /_static/local_attn.png
212
+ :width: 240px
213
+
214
+ The total window size is :math:`window\\_left + 1 + window\\_right`
215
+ """
216
+
217
+ window_left: int
218
+ window_right: int
219
+
220
+ def to(self, device) -> "LocalAttentionFromBottomRightMask":
221
+ return self
222
+
223
+ def __post_init__(self) -> None:
224
+ if self.window_left < 0:
225
+ raise ValueError(
226
+ "Invalid window value passed to "
227
+ "`LocalAttentionFromBottomRightMask`: expected "
228
+ f"`window_left >= 0` but got window_left={self.window_left}"
229
+ )
230
+ if self.window_right < 0:
231
+ raise ValueError(
232
+ "Invalid window value passed to "
233
+ "`LocalAttentionFromBottomRightMask`: expected "
234
+ f"`window_right >= 0` but got window_right={self.window_right}"
235
+ )
236
+
237
+ def materialize(
238
+ self,
239
+ shape: Tuple[int, ...],
240
+ dtype: torch.dtype = torch.float32,
241
+ device: Union[str, torch.device] = "cpu",
242
+ ) -> torch.Tensor:
243
+ create_as = dtype if dtype is not torch.bfloat16 else torch.float32
244
+ mask = torch.full( # type: ignore
245
+ shape,
246
+ dtype=create_as,
247
+ fill_value=1,
248
+ device=device,
249
+ )
250
+
251
+ mask = _apply_locality_on_mask(mask, self.window_left, self.window_right)
252
+ mask = torch.log(mask)
253
+ return mask.to(dtype)
254
+
255
+
256
+ class LowerTriangularFromBottomRightMask(AttentionBias):
257
+ """
258
+ A causal mask.
259
+
260
+ This mask is exactly the same as :attr:`LowerTriangularMask` when there is
261
+ the same number of queries and keys.
262
+ When the number of queries is different from the number of keys,
263
+ it is a triangular mask shifted so that the last query can attend to
264
+ the last key.
265
+ In other words, a query Q cannot attend to a key which is nearer the
266
+ final key than Q is to the final query.
267
+
268
+
269
+ .. figure:: /_static/causal_bottom_right.png
270
+
271
+ The difference between :attr:`LowerTriangularMask` (left) and
272
+ :attr:`LowerTriangularFromBottomRightMask` (right). They become
273
+ equivalent if the number of queries equals the number of keys.
274
+ """
275
+
276
+ def to(self, device: torch.device) -> "LowerTriangularFromBottomRightMask":
277
+ assert type(self) is LowerTriangularFromBottomRightMask, (
278
+ "Please implement in subclass"
279
+ )
280
+ return self
281
+
282
+ def materialize(
283
+ self,
284
+ shape: Tuple[int, ...],
285
+ dtype: torch.dtype = torch.float32,
286
+ device: Union[str, torch.device] = "cpu",
287
+ ) -> torch.Tensor:
288
+ return _materialize_causal_mask(
289
+ shape, dtype=dtype, device=device, from_bottomright=True
290
+ )
291
+
292
+ def make_local_attention(
293
+ self, window_size: int
294
+ ) -> "LowerTriangularFromBottomRightLocalAttentionMask":
295
+ """
296
+ Create a new bias which combines local + causal attention.
297
+
298
+ See :attr:`LowerTriangularFromBottomRightLocalAttentionMask`
299
+ """
300
+ return LowerTriangularFromBottomRightLocalAttentionMask(window_size)
301
+
302
+
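A sketch of how the two causal variants differ when the number of queries and keys differ:

.. code-block:: python

    from mslk.attention.fmha.attn_bias import (
        LowerTriangularFromBottomRightMask,
        LowerTriangularMask,
    )

    # 2 queries, 4 keys. Top-left alignment lets query 0 see only key 0;
    # bottom-right alignment shifts the diagonal so the last query sees
    # every key and query 0 sees all keys but the last one.
    print(LowerTriangularMask().materialize((2, 4)).exp())
    print(LowerTriangularFromBottomRightMask().materialize((2, 4)).exp())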
303
+ @dataclass
304
+ class LowerTriangularFromBottomRightLocalAttentionMask(
305
+ LowerTriangularFromBottomRightMask
306
+ ):
307
+ """
308
+ A mask that combines both :attr:`LowerTriangularFromBottomRightMask` and
309
+ local attention.
310
+
311
+ A query whose distance from the final query is X cannot attend to a key
312
+ whose distance to the final key is either of:
313
+
314
+ * less than X (i.e. "causal attention", same as :attr:`LowerTriangularFromBottomRightMask`)
315
+ * greater than or equal to X + window_size (i.e. "local attention")
316
+
317
+
318
+ .. figure:: /_static/causal_bottom_right_local.png
319
+
320
+ The mask from :attr:`LowerTriangularFromBottomRightLocalAttentionMask`.
321
+ The green area is calculated, and the grey area is masked out.
322
+ """
323
+
324
+ _window_size: int
325
+
326
+ def to(
327
+ self, device: torch.device
328
+ ) -> "LowerTriangularFromBottomRightLocalAttentionMask":
329
+ assert type(self) is LowerTriangularFromBottomRightLocalAttentionMask, (
330
+ "Please implement in subclass"
331
+ )
332
+ return self
333
+
334
+ def __post_init__(self) -> None:
335
+ if self._window_size <= 0:
336
+ raise ValueError(
337
+ f"Expected `window_size > 0`, but window_size={self._window_size}"
338
+ )
339
+
340
+ def materialize(
341
+ self,
342
+ shape: Tuple[int, ...],
343
+ dtype: torch.dtype = torch.float32,
344
+ device: Union[str, torch.device] = "cpu",
345
+ ) -> torch.Tensor:
346
+ return _materialize_causal_mask(
347
+ shape,
348
+ dtype=dtype,
349
+ device=device,
350
+ window_size=self._window_size,
351
+ from_bottomright=True,
352
+ )
353
+
354
+
355
+ class LowerTriangularMaskWithTensorBias(LowerTriangularMask):
356
+ """A lower-triangular (aka causal) mask with an additive bias"""
357
+
358
+ def __init__(self, bias: torch.Tensor) -> None:
359
+ self._bias = bias
360
+
361
+ def to(self, device: torch.device) -> "LowerTriangularMaskWithTensorBias":
362
+ assert type(self) is LowerTriangularMaskWithTensorBias, (
363
+ "Please implement in subclass"
364
+ )
365
+ return LowerTriangularMaskWithTensorBias(_to_device(self._bias, device))
366
+
367
+ def materialize(
368
+ self,
369
+ shape: Tuple[int, ...],
370
+ dtype: torch.dtype = torch.float32,
371
+ device: Union[str, torch.device] = "cpu",
372
+ ) -> torch.Tensor:
373
+ return super().materialize(shape, dtype=dtype, device=device) + self._bias
374
+
375
+
376
+ @dataclass
377
+ class _SeqLenInfo:
378
+ """
379
+ (Internal) Represents the division of a dimension into blocks.
380
+
381
+ For example, to represent a dimension of length 7 divided into
382
+ three blocks of lengths 2, 3 and 2, use `from_seqlens([2, 3, 2])`.
383
+ The members will be:
384
+ max_seqlen: 3
385
+ min_seqlen: 2
386
+ seqstart_py: [0, 2, 5, 7]
387
+ seqstart: torch.IntTensor([0, 2, 5, 7])
388
+ """
389
+
390
+ seqstart: torch.Tensor
391
+ max_seqlen: int
392
+ min_seqlen: int
393
+ seqstart_py: List[int]
394
+
395
+ def to(self, device: torch.device) -> "_SeqLenInfo":
396
+ assert type(self) is _SeqLenInfo, "Please implement in subclass"
397
+ if self.seqstart.device == device:
398
+ return self
399
+ return _SeqLenInfo(
400
+ seqstart=_to_device(self.seqstart, device),
401
+ max_seqlen=self.max_seqlen,
402
+ min_seqlen=self.min_seqlen,
403
+ seqstart_py=self.seqstart_py,
404
+ )
405
+
406
+ def intervals(self) -> Iterable[Tuple[int, int]]:
407
+ yield from zip(self.seqstart_py, self.seqstart_py[1:])
408
+
409
+ @classmethod
410
+ def _get_seqstart(
411
+ cls, seqlens: Iterable[int], *, device: torch.device
412
+ ) -> Tuple[int, int, List[int], torch.Tensor]:
413
+ """
414
+ Given sequence lengths, returns the min/max value and the sequence start
415
+ positions (offsets), with first element being 0 (returned in list and Tensor).
416
+ """
417
+
418
+ assert not isinstance(seqlens, torch.Tensor)
419
+ seqstart_py = [0]
420
+ max_seqlen = -1
421
+ min_seqlen = -1
422
+ for seqlen in seqlens:
423
+ min_seqlen = min(min_seqlen, seqlen) if min_seqlen != -1 else seqlen
424
+ max_seqlen = max(max_seqlen, seqlen)
425
+ seqstart_py.append(seqstart_py[len(seqstart_py) - 1] + seqlen)
426
+ seqstart = _to_device_tensor(seqstart_py, dtype=torch.int32, device=device)
427
+
428
+ return (min_seqlen, max_seqlen, seqstart_py, seqstart)
429
+
430
+ @classmethod
431
+ def from_seqlens(
432
+ cls, seqlens: Iterable[int], *, device: Optional[torch.device] = None
433
+ ) -> "_SeqLenInfo":
434
+ """
435
+ Input tensors are assumed to be in shape [B, M, *]
436
+ """
437
+ device = _get_default_bias_device(device)
438
+ min_seqlen, max_seqlen, seqstart_py, seqstart = cls._get_seqstart(
439
+ seqlens, device=device
440
+ )
441
+
442
+ return cls(
443
+ max_seqlen=max_seqlen,
444
+ min_seqlen=min_seqlen,
445
+ seqstart=seqstart,
446
+ seqstart_py=seqstart_py,
447
+ )
448
+
449
+ def from_seqlens_inplace(self, seqlens: Iterable[int]) -> None:
450
+ """
451
+ Perform in-place update. You can only update with the same shape.
452
+ Can be useful with CUDA graphs.
453
+ """
454
+ min_seqlen, max_seqlen, seqstart_py, seqstart = self._get_seqstart(
455
+ seqlens, device=self.seqstart.device
456
+ )
457
+
458
+ assert len(seqstart_py) == len(self.seqstart_py), (
459
+ f"Old / New len {len(self.seqstart_py)} / {len(seqstart_py)}, "
460
+ f"Contents {self.seqstart_py} / {seqstart_py}"
461
+ )
462
+ assert self.max_seqlen >= max_seqlen, (
463
+ f"For inplace update, new max_seqlen {max_seqlen} "
464
+ f"cannot exceed the previous max_seqlen {self.max_seqlen}"
465
+ )
466
+ for i in range(len(seqstart_py)):
467
+ self.seqstart_py[i] = seqstart_py[i]
468
+ self.seqstart.copy_(seqstart, non_blocking=True)
469
+
470
+ def split(
471
+ self, x: torch.Tensor, batch_sizes: Optional[Sequence[int]] = None
472
+ ) -> List[torch.Tensor]:
473
+ if self.seqstart_py[-1] != x.shape[1] or x.shape[0] != 1:
474
+ raise ValueError(
475
+ f"Invalid `torch.Tensor` of shape {x.shape}, expected format "
476
+ f"(B, M, *) with B=1 and M={self.seqstart_py[-1]}\n"
477
+ f" seqstart: {self.seqstart_py}"
478
+ )
479
+ if batch_sizes is None:
480
+ batch_sizes = [1] * (len(self.seqstart_py) - 1)
481
+ split_chunks = []
482
+ it = 0
483
+ for batch_size in batch_sizes:
484
+ split_chunks.append(
485
+ self.seqstart_py[it + batch_size] - self.seqstart_py[it]
486
+ )
487
+ it += batch_size
488
+ return [
489
+ tensor.reshape([bs, -1, *tensor.shape[2:]])
490
+ for bs, tensor in zip(batch_sizes, x.split(split_chunks, dim=1))
491
+ ]
492
+
493
+
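A small sketch of the bookkeeping this internal class does (run on CPU, same import-path assumption):

.. code-block:: python

    import torch
    from mslk.attention.fmha.attn_bias import _SeqLenInfo

    # A packed dimension of length 7 made of blocks of lengths 2, 3 and 2.
    info = _SeqLenInfo.from_seqlens([2, 3, 2], device=torch.device("cpu"))
    print(info.seqstart_py)                  # [0, 2, 5, 7]
    print(list(info.intervals()))            # [(0, 2), (2, 5), (5, 7)]
    print(info.min_seqlen, info.max_seqlen)  # 2 3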
494
+ @dataclass
495
+ class _PaddedSeqLenInfo(_SeqLenInfo):
496
+ """
497
+ (Internal) Represents the division of a dimension into blocks which are
498
+ padded out to the same total length.
499
+
500
+ For example, to represent a dimension of length 12 with space for
501
+ three blocks of length 4, but where the occupied lengths are
502
+ 2, 3 and 2, use `from_seqlens_padded([2, 3, 2], 4)`.
503
+
504
+ The layout along the dimension is
505
+
506
+ 0 ─► block 0
507
+ block 0
508
+ <space>
509
+ <space>
510
+ 4 ─► block 1
511
+ block 1
512
+ block 1
513
+ <space>
514
+ 8 ─► block 2
515
+ block 2
516
+ <space>
517
+ <space>
518
+ 12 ─►
519
+
520
+ The members will be:
521
+ max_seqlen: 3
522
+ min_seqlen: 2
523
+ seqstart_py: [0, 4, 8, 12]
524
+ seqstart: torch.IntTensor([0, 4, 8, 12])
525
+ seqlen_py: [2, 3, 2]
526
+ seqlen: torch.IntTensor([2, 3, 2])
527
+ padding: 4
528
+ """
529
+
530
+ seqlen: torch.Tensor
531
+ seqlen_py: List[int]
532
+ padding: int
533
+ # From parent: seqstart[i] contains the start position
534
+ # of the i-th sequence
535
+ # seqstart: torch.Tensor
536
+
537
+ def __post_init__(self) -> None:
538
+ assert len(self.seqstart_py) == len(self.seqlen_py) + 1
539
+
540
+ def to(self, device: torch.device) -> "_PaddedSeqLenInfo":
541
+ assert type(self) is _PaddedSeqLenInfo, "Please implement in subclass"
542
+ if self.seqlen.device == device:
543
+ return self
544
+ return _PaddedSeqLenInfo(
545
+ # _SeqLenInfo
546
+ seqstart=_to_device(self.seqstart, device),
547
+ max_seqlen=self.max_seqlen,
548
+ min_seqlen=self.min_seqlen,
549
+ seqstart_py=self.seqstart_py,
550
+ # _PaddedSeqLenInfo
551
+ seqlen=_to_device(self.seqlen, device),
552
+ seqlen_py=self.seqlen_py,
553
+ padding=self.padding,
554
+ )
555
+
556
+ def intervals(self) -> Iterable[Tuple[int, int]]:
557
+ for (start, _), length in zip(super().intervals(), self.seqlen_py):
558
+ yield start, start + length
559
+
560
+ @classmethod
561
+ def from_seqlens(
562
+ cls, seqlens: Iterable[int], *, device: Optional[torch.device] = None
563
+ ) -> "_SeqLenInfo":
564
+ raise RuntimeError(
565
+ "Use either `_SeqLenInfo.from_seqlens` or `_PaddedSeqLenInfo.from_seqlens_padded`"
566
+ )
567
+
568
+ @classmethod
569
+ def from_seqlens_padded(
570
+ cls,
571
+ seqlens: Sequence[int],
572
+ padding: int,
573
+ *,
574
+ device: Optional[torch.device] = None,
575
+ ) -> "_PaddedSeqLenInfo":
576
+ """
577
+ Input tensors are assumed to be in shape [B, M, *]
578
+ seqstart = padding * torch.arange(batch_size)
579
+ """
580
+ assert not isinstance(seqlens, torch.Tensor)
581
+ assert all(seqlen <= padding for seqlen in seqlens), (
582
+ f"Seqlens {seqlens} Padding {padding}"
583
+ )
584
+ device = _get_default_bias_device(device)
585
+ seqstart_py = list(range(0, len(seqlens) * padding + 1, padding))
586
+ seqlen = _to_device_tensor(seqlens, dtype=torch.int32, device=device)
587
+ return cls(
588
+ seqlen=seqlen,
589
+ seqlen_py=seqlens if isinstance(seqlens, list) else list(seqlens),
590
+ max_seqlen=max(seqlens),
591
+ min_seqlen=min(seqlens),
592
+ seqstart=_to_device_tensor(seqstart_py, dtype=torch.int32, device=device),
593
+ seqstart_py=seqstart_py,
594
+ padding=padding,
595
+ )
596
+
597
+ def from_seqlens_padded_inplace(self, seqlens: Sequence[int]) -> None:
598
+ """
599
+ Perform in-place update. You can only update with the same shape.
600
+ Can be useful with CUDA graphs.
601
+ Note: we don't update padding because it would already have been baked
602
+ into CUDA graphs during generation.
603
+ """
604
+ assert not isinstance(seqlens, torch.Tensor)
605
+ assert all(seqlen <= self.padding for seqlen in seqlens), (
606
+ f"Seqlens {seqlens} Padding {self.padding}"
607
+ )
608
+ seqlen_tensor = torch.tensor(seqlens, dtype=torch.int32)
609
+
610
+ assert len(self.seqlen_py) == len(seqlens), (
611
+ f"Old/New len {len(self.seqlen_py)} / {len(seqlens)}, "
612
+ f"Contents {self.seqlen_py} / {seqlens}"
613
+ )
614
+ assert self.max_seqlen >= max(seqlens), (
615
+ f"For inplace update, new max_seqlen {max(seqlens)} "
616
+ f"cannot exceed the previous max_seqlen {self.max_seqlen}"
617
+ )
618
+
619
+ for i in range(len(self.seqlen_py)):
620
+ self.seqlen_py[i] = seqlens[i]
621
+
622
+ self.seqlen.copy_(seqlen_tensor, non_blocking=True)
623
+
624
+ def split(
625
+ self, x: torch.Tensor, batch_sizes: Optional[Sequence[int]] = None
626
+ ) -> List[torch.Tensor]:
627
+ raise NotImplementedError("_PaddedSeqLenInfo.split")
628
+
629
+
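The padded variant in a sketch, matching the `[2, 3, 2]` / padding 4 example from the docstring above:

.. code-block:: python

    import torch
    from mslk.attention.fmha.attn_bias import _PaddedSeqLenInfo

    # Three sequences of lengths 2, 3 and 2, each stored in a slot of size 4.
    info = _PaddedSeqLenInfo.from_seqlens_padded(
        [2, 3, 2], padding=4, device=torch.device("cpu")
    )
    print(info.seqstart_py)        # [0, 4, 8, 12]  (slot starts)
    print(list(info.intervals()))  # [(0, 2), (4, 7), (8, 10)]  (occupied ranges)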
630
+ @dataclass
631
+ class _GappySeqInfo(_SeqLenInfo):
632
+ """
633
+ (Internal) Flexible equivalent of _PaddedSeqLenInfo. There are two
634
+ distinct semantics.
635
+
636
+ (1) For non-paged masks:
637
+ Represents the division of a dimension into blocks which may lie
638
+ anywhere. Each block has a start and a length. The final start is the total
639
+ length of the dimension.
640
+
641
+ For example, to represent a dimension of length 14 as follows with
642
+ three occupied lengths of
643
+ 6, 3 and 1, use `from_seqlens_gappy([0, 7, 12, 14], [6, 3, 1])`.
644
+
645
+ The layout along the dimension is
646
+
647
+ 0 ─► block 0
648
+ block 0
649
+ block 0
650
+ block 0
651
+ 4 ─► block 0
652
+ block 0
653
+ <space>
654
+ block 1
655
+ 8 ─► block 1
656
+ block 1
657
+ <space>
658
+ <space>
659
+ 12 ─► block 2
660
+ <space>
661
+
662
+ The members will be:
663
+ max_seqlen: 6
664
+ min_seqlen: 1
665
+ seqstart_py: [0, 7, 12, 14]
666
+ seqstart: torch.IntTensor([0, 7, 12, 14])
667
+ seqlen_py: [6, 3, 1]
668
+ seqlen: torch.IntTensor([6, 3, 1])
669
+
670
+ (2) For paged masks:
671
+ The notional space is divided into batch-size-many blocks.
672
+ seqstart and seqstart_py are offsets in the block, not in
673
+ the whole space, and the extra last element is not important.
674
+ And seqlen is the index of the last key in the block.
675
+ Otherwise as above.
676
+ """
677
+
678
+ seqlen: torch.Tensor
679
+ seqlen_py: Sequence[int]
680
+ # From parent: seqstart[i] contains the start position
681
+ # of the i-th sequence
682
+ # seqstart: torch.Tensor
683
+
684
+ def to(self, device: torch.device) -> "_GappySeqInfo":
685
+ assert type(self) is _GappySeqInfo, "Please implement in subclass"
686
+ if self.seqlen.device == device:
687
+ return self
688
+ return _GappySeqInfo(
689
+ # _SeqLenInfo
690
+ seqstart=_to_device(self.seqstart, device),
691
+ max_seqlen=self.max_seqlen,
692
+ min_seqlen=self.min_seqlen,
693
+ seqstart_py=self.seqstart_py,
694
+ # _GappySeqInfo
695
+ seqlen=_to_device(self.seqlen, device),
696
+ seqlen_py=self.seqlen_py,
697
+ )
698
+
699
+ def intervals(self) -> Iterable[Tuple[int, int]]:
700
+ for (start, _), length in zip(super().intervals(), self.seqlen_py):
701
+ yield start, start + length
702
+
703
+ @classmethod
704
+ def from_seqlens(
705
+ cls, seqlens: Iterable[int], *, device: Optional[torch.device] = None
706
+ ) -> "_SeqLenInfo":
707
+ raise NotImplementedError()
708
+
709
+ @classmethod
710
+ def from_seqlens_gappy(
711
+ cls,
712
+ seqstarts: Sequence[int],
713
+ seqlens: Sequence[int],
714
+ paged: bool,
715
+ *,
716
+ device: torch.device,
717
+ ) -> "_GappySeqInfo":
718
+ assert not isinstance(seqlens, torch.Tensor)
719
+ seqstart_py = list(seqstarts)
720
+ if len(seqlens) == 0:
721
+ raise ValueError("No elements")
722
+ if len(seqstarts) - len(seqlens) != 1:
723
+ raise ValueError(
724
+ f"len(seqstarts)={len(seqstarts)} should be len(seqlens) + 1 = {len(seqlens) + 1}"
725
+ )
726
+ max_seqlen = max(seqlens)
727
+ min_seqlen = min(seqlens)
728
+ if paged:
729
+ seqstart_py.append(-1)
730
+ seqlens = [i + j for i, j in zip(seqstart_py, seqlens)]
731
+ seqlen = _to_device_tensor(seqlens, dtype=torch.int32, device=device)
732
+ return cls(
733
+ seqlen=seqlen,
734
+ seqlen_py=seqlens,
735
+ max_seqlen=max_seqlen,
736
+ min_seqlen=min_seqlen,
737
+ seqstart=_to_device_tensor(seqstart_py, dtype=torch.int32, device=device),
738
+ seqstart_py=seqstart_py,
739
+ )
740
+
741
+ def split(
742
+ self, x: torch.Tensor, batch_sizes: Optional[Sequence[int]] = None
743
+ ) -> List[torch.Tensor]:
744
+ raise NotImplementedError("_GappySeqInfo.split")
745
+
746
+
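A sketch of the non-paged usage described in the docstring above (the paged form reinterprets seqstart/seqlen as noted there):

.. code-block:: python

    import torch
    from mslk.attention.fmha.attn_bias import _GappySeqInfo

    # Blocks starting at 0, 7 and 12 in a dimension of length 14,
    # with occupied lengths 6, 3 and 1.
    info = _GappySeqInfo.from_seqlens_gappy(
        [0, 7, 12, 14], [6, 3, 1], paged=False, device=torch.device("cpu")
    )
    print(info.seqstart_py)                  # [0, 7, 12, 14]
    print(info.min_seqlen, info.max_seqlen)  # 1 6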
747
+ @dataclass
748
+ class BlockDiagonalMask(AttentionBias):
749
+ """
750
+ A block-diagonal mask that can be passed as ``attn_bias``
751
+ argument to :attr:`xformers.ops.memory_efficient_attention`.
752
+
753
+ Queries and Keys are each divided into the same number of blocks.
754
+ Queries in block i only attend to keys in block i.
755
+
756
+ .. figure:: /_static/block_diag_bias.png
757
+
758
+ This bias can be used to handle a batch of sequences of
759
+ different lengths, via :attr:`BlockDiagonalMask.from_tensor_list`
760
+
761
+ :Example:
762
+
763
+ .. code-block:: python
764
+
765
+ import torch
766
+ from xformers.ops import fmha
767
+
768
+ K = 16
769
+ dtype = torch.float16
770
+ device = "cuda"
771
+ list_x = [
772
+ torch.randn([1, 3, 1, K], dtype=dtype, device=device),
773
+ torch.randn([1, 6, 1, K], dtype=dtype, device=device),
774
+ torch.randn([1, 2, 1, K], dtype=dtype, device=device),
775
+ ]
776
+ attn_bias, x = fmha.BlockDiagonalMask.from_tensor_list(list_x)
777
+ linear = torch.nn.Linear(K, K * 3).to(device=device, dtype=dtype)
778
+
779
+ q, k, v = linear(x).reshape([1, -1, 1, 3, K]).unbind(-2)
780
+ out = fmha.memory_efficient_attention(q, k, v, attn_bias=attn_bias)
781
+ list_out = attn_bias.split(out)
782
+ print(list_out[0].shape) # [1, 3, 1, K]
783
+ assert tuple(list_out[0].shape) == (1, 3, 1, K)
784
+
785
+ """
786
+
787
+ q_seqinfo: _SeqLenInfo
788
+ k_seqinfo: _SeqLenInfo
789
+ _batch_sizes: Optional[Sequence[int]] = None
790
+
791
+ def to(self, device) -> "BlockDiagonalMask":
792
+ assert type(self) is BlockDiagonalMask, "Please implement in subclass"
793
+ return BlockDiagonalMask(
794
+ q_seqinfo=self.q_seqinfo.to(device),
795
+ k_seqinfo=self.k_seqinfo.to(device),
796
+ _batch_sizes=self._batch_sizes,
797
+ )
798
+
799
+ def _create_block_mask(
800
+ self,
801
+ shape: Tuple[int, ...],
802
+ dtype: torch.dtype = torch.float32,
803
+ device: Union[str, torch.device] = "cpu",
804
+ ) -> torch.Tensor:
805
+ return torch.zeros(
806
+ shape,
807
+ dtype=dtype,
808
+ device=device,
809
+ )
810
+
811
+ def materialize(
812
+ self,
813
+ shape: Tuple[int, ...],
814
+ dtype: torch.dtype = torch.float32,
815
+ device: Union[str, torch.device] = "cpu",
816
+ ) -> torch.Tensor:
817
+ """Materialize the attention bias - for debugging & testing"""
818
+ assert shape[-1] == self.k_seqinfo.seqstart_py[-1], (
819
+ shape[-1],
820
+ self.k_seqinfo.seqstart_py[-1],
821
+ )
822
+ assert shape[-2] == self.q_seqinfo.seqstart_py[-1], (
823
+ shape[-2],
824
+ self.q_seqinfo.seqstart_py[-1],
825
+ )
826
+ mask = torch.empty(shape[-2:], dtype=dtype, device=device)
827
+ mask.fill_(-math.inf)
828
+ for (q_start, q_end), (k_start, k_end) in zip(
829
+ self.q_seqinfo.intervals(),
830
+ self.k_seqinfo.intervals(),
831
+ ):
832
+ mask[q_start:q_end, k_start:k_end] = self._create_block_mask(
833
+ (q_end - q_start, k_end - k_start),
834
+ dtype=dtype,
835
+ device=device,
836
+ )
837
+ for _ in range(len(shape) - 2):
838
+ mask = mask.unsqueeze(0)
839
+ return mask.expand(shape)
840
+
841
+ @classmethod
842
+ def from_seqlens(
843
+ cls,
844
+ q_seqlen: Sequence[int],
845
+ kv_seqlen: Optional[Sequence[int]] = None,
846
+ *,
847
+ device: Optional[torch.device] = None,
848
+ ) -> "BlockDiagonalMask":
849
+ """Creates a :attr:`BlockDiagonalMask` from a list of tensors lengths for query and key/value.
850
+
851
+ Args:
852
+ q_seqlen (Union[Sequence[int], torch.Tensor]): List or tensor of sequence lengths for query tensors
853
+ kv_seqlen (Union[Sequence[int], torch.Tensor], optional): List or tensor of sequence lengths for key/value.
854
+ (Defaults to ``q_seqlen``.)
855
+ Returns:
856
+ BlockDiagonalMask
857
+ """
858
+ device = _get_default_bias_device(device)
859
+ assert kv_seqlen is None or len(q_seqlen) == len(kv_seqlen)
860
+ q_seqinfo = _SeqLenInfo.from_seqlens(q_seqlen, device=device)
861
+ if kv_seqlen is None or q_seqlen == kv_seqlen:
862
+ k_seqinfo = q_seqinfo
863
+ else:
864
+ k_seqinfo = _SeqLenInfo.from_seqlens(kv_seqlen, device=device)
865
+ return cls(q_seqinfo=q_seqinfo, k_seqinfo=k_seqinfo)
866
+
867
+ @classmethod
868
+ def from_tensor_list(
869
+ cls,
870
+ tensors: Sequence[torch.Tensor],
871
+ ) -> Tuple["BlockDiagonalMask", torch.Tensor]:
872
+ """Creates a :attr:`BlockDiagonalMask` from a list of tensors, and returns the tensors
873
+ concatenated on the sequence length dimension
874
+
875
+ .. figure:: /_static/block_diag_cat_split.png
876
+
877
+ See also :attr:`BlockDiagonalMask.split` to split the returned
878
+ :attr:`torch.Tensor` back to a list of tensors of varying sequence length
879
+
880
+ Args:
881
+ tensors (Sequence[torch.Tensor]): A list of tensors of shape ``[B, M_i, *]``.
882
+ All tensors should have the same dimension and the same batch size ``B``, but
883
+ they can have different sequence length ``M``.
884
+
885
+ Returns:
886
+ Tuple[BlockDiagonalMask, torch.Tensor]: The corresponding bias for the attention
887
+ along with `tensors` concatenated on the sequence length dimension, with shape ``[1, sum_i{M_i}, *]``
888
+ """
889
+ batch_sizes = [tensor.shape[0] for tensor in tensors]
890
+ seqlens = []
891
+ for x in tensors:
892
+ for _ in range(x.shape[0]):
893
+ seqlens.append(x.shape[1])
894
+ block_diag = cls.from_seqlens(seqlens)
895
+ block_diag._batch_sizes = batch_sizes
896
+ tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in tensors)
897
+ concat_tensors = torch.cat(tensors_bs1, dim=1)
898
+ return block_diag, concat_tensors
899
+
900
+ @classmethod
901
+ def from_tensor_lists_qkv(
902
+ cls,
903
+ tensors_q: Sequence[torch.Tensor],
904
+ tensors_k: Sequence[torch.Tensor],
905
+ tensors_v: Optional[Sequence[torch.Tensor]] = None,
906
+ ) -> Tuple["BlockDiagonalMask", torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
907
+ assert len(tensors_q) == len(tensors_k)
908
+ assert tensors_v is None or len(tensors_v) == len(tensors_q)
909
+ batch_sizes = [tensor.shape[0] for tensor in tensors_q]
910
+ q_seqlens, kv_seqlens = [], []
911
+ for i, (q, k) in enumerate(zip(tensors_q, tensors_k)):
912
+ assert q.shape[0] == k.shape[0]
913
+ q_seqlens += [q.shape[1]] * q.shape[0]
914
+ kv_seqlens += [k.shape[1]] * k.shape[0]
915
+ assert tensors_v is None or tensors_v[i].shape[:2] == k.shape[:2]
916
+ block_diag = cls.from_seqlens(q_seqlens, kv_seqlens)
917
+ block_diag._batch_sizes = batch_sizes
918
+ return (
919
+ block_diag,
920
+ torch.cat([x.reshape([1, -1, *x.shape[2:]]) for x in tensors_q], dim=1),
921
+ torch.cat([x.reshape([1, -1, *x.shape[2:]]) for x in tensors_k], dim=1),
922
+ torch.cat([x.reshape([1, -1, *x.shape[2:]]) for x in tensors_v], dim=1)
923
+ if tensors_v is not None
924
+ else None,
925
+ )
926
+
927
+ def split_queries(self, tensor: torch.Tensor) -> Sequence[torch.Tensor]:
928
+ return self.q_seqinfo.split(tensor, self._batch_sizes)
929
+
930
+ def split_kv(self, tensor: torch.Tensor) -> Sequence[torch.Tensor]:
931
+ return self.k_seqinfo.split(tensor, self._batch_sizes)
932
+
933
+ def split(self, tensor: torch.Tensor) -> Sequence[torch.Tensor]:
934
+ """The inverse operation of :attr:`BlockDiagonalMask.from_tensor_list`
935
+
936
+ Args:
937
+ tensor (torch.Tensor): Tensor of tokens of shape ``[1, sum_i{M_i}, *]``
938
+
939
+ Returns:
940
+ Sequence[torch.Tensor]: A list of tokens with possibly different sequence lengths
941
+ """
942
+ assert self.q_seqinfo is self.k_seqinfo
943
+ return self.q_seqinfo.split(tensor, self._batch_sizes)
944
+
945
+ def make_causal(self) -> "BlockDiagonalCausalMask":
946
+ """Makes each block causal"""
947
+ return BlockDiagonalCausalMask(
948
+ q_seqinfo=self.q_seqinfo,
949
+ k_seqinfo=self.k_seqinfo,
950
+ _batch_sizes=self._batch_sizes,
951
+ )
952
+
953
+ def make_causal_from_bottomright(self) -> "BlockDiagonalCausalFromBottomRightMask":
954
+ """Makes each block causal with a possible non-causal prefix"""
955
+ return BlockDiagonalCausalFromBottomRightMask(
956
+ q_seqinfo=self.q_seqinfo,
957
+ k_seqinfo=self.k_seqinfo,
958
+ _batch_sizes=self._batch_sizes,
959
+ )
960
+
961
+ def make_local_attention(
962
+ self, window_size: int
963
+ ) -> "BlockDiagonalCausalLocalAttentionMask":
964
+ """Experimental: Makes each block causal with local attention"""
965
+ return BlockDiagonalCausalLocalAttentionMask(
966
+ q_seqinfo=self.q_seqinfo,
967
+ k_seqinfo=self.k_seqinfo,
968
+ _batch_sizes=self._batch_sizes,
969
+ _window_size=window_size,
970
+ )
971
+
972
+ def make_local_attention_from_bottomright(
973
+ self, window_size: int
974
+ ) -> "BlockDiagonalCausalLocalAttentionFromBottomRightMask":
975
+ """Experimental: Makes each block causal with local attention, start from bottom right"""
976
+ return BlockDiagonalCausalLocalAttentionFromBottomRightMask(
977
+ q_seqinfo=self.q_seqinfo,
978
+ k_seqinfo=self.k_seqinfo,
979
+ _batch_sizes=self._batch_sizes,
980
+ _window_size=window_size,
981
+ )
982
+
983
+
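Besides ``from_tensor_list`` shown in the docstring, the mask can be built directly from sequence lengths; a sketch:

.. code-block:: python

    import torch
    from mslk.attention.fmha.attn_bias import BlockDiagonalMask

    # Two packed sequences of lengths 2 and 3 (5 tokens in total).
    bias = BlockDiagonalMask.from_seqlens([2, 3], device=torch.device("cpu"))
    print(bias.materialize((5, 5)).exp())  # ones inside each block, zeros across blocks
    # Additionally make every block causal.
    print(bias.make_causal().materialize((5, 5)).exp())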
984
+ @dataclass
985
+ class BlockDiagonalCausalMask(BlockDiagonalMask):
986
+ """
987
+ Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalMask`, except that each block is causal.
988
+
989
+ Queries and Keys are each divided into the same number of blocks.
990
+ A query Q in block i cannot attend to a key which is not in block i,
991
+ nor one which is farther from the initial key in block i than Q
992
+ is from the initial query in block i.
993
+ """
994
+
995
+ def to(self, device) -> "BlockDiagonalCausalMask":
996
+ assert type(self) is BlockDiagonalCausalMask, "Please implement in subclass"
997
+ return BlockDiagonalCausalMask(
998
+ q_seqinfo=self.q_seqinfo.to(device),
999
+ k_seqinfo=self.k_seqinfo.to(device),
1000
+ _batch_sizes=self._batch_sizes,
1001
+ )
1002
+
1003
+ def _create_block_mask(
1004
+ self,
1005
+ shape: Tuple[int, ...],
1006
+ dtype: torch.dtype = torch.float32,
1007
+ device: Union[str, torch.device] = "cpu",
1008
+ ) -> torch.Tensor:
1009
+ return LowerTriangularMask().materialize(
1010
+ shape,
1011
+ dtype=dtype,
1012
+ device=device,
1013
+ )
1014
+
1015
+
1016
+ @dataclass
1017
+ class BlockDiagonalCausalFromBottomRightMask(BlockDiagonalMask):
1018
+ """
1019
+ Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalMask`, except that each block is causal.
1020
+ This mask allows for a non-causal prefix
1021
+ NOTE: Each block should have `num_keys >= num_queries` otherwise the forward pass is not
1022
+ defined (softmax of vector of `-inf` in the attention)
1023
+
1024
+ Queries and keys are each divided into the same number of blocks.
1025
+ A query Q in block i cannot attend to a key which is not in block i,
1026
+ nor one which is nearer the final key in block i than Q is to the
1027
+ final query in block i.
1028
+ """
1029
+
1030
+ def to(self, device) -> "BlockDiagonalCausalFromBottomRightMask":
1031
+ assert type(self) is BlockDiagonalCausalFromBottomRightMask, (
1032
+ "Please implement in subclass"
1033
+ )
1034
+ return BlockDiagonalCausalFromBottomRightMask(
1035
+ q_seqinfo=self.q_seqinfo.to(device),
1036
+ k_seqinfo=self.k_seqinfo.to(device),
1037
+ _batch_sizes=self._batch_sizes,
1038
+ )
1039
+
1040
+ def __post_init__(self) -> None:
1041
+ for i, ((q_start, q_end), (k_start, k_end)) in enumerate(
1042
+ zip(
1043
+ self.q_seqinfo.intervals(),
1044
+ self.k_seqinfo.intervals(),
1045
+ )
1046
+ ):
1047
+ num_queries = q_end - q_start
1048
+ num_keys = k_end - k_start
1049
+ if num_keys < num_queries:
1050
+ raise ValueError(
1051
+ f"Block #{i} has num_keys={num_keys} and num_queries={num_queries}."
1052
+ " Expected `num_keys >= num_queries`"
1053
+ )
1054
+
1055
+ def _create_block_mask(
1056
+ self,
1057
+ shape: Tuple[int, ...],
1058
+ dtype: torch.dtype = torch.float32,
1059
+ device: Union[str, torch.device] = "cpu",
1060
+ ) -> torch.Tensor:
1061
+ return LowerTriangularFromBottomRightMask().materialize(
1062
+ shape=shape, dtype=dtype, device=device
1063
+ )
1064
+
1065
+
1066
+ @dataclass
1067
+ class BlockDiagonalPaddedKeysMask(AttentionBias):
1068
+ """
1069
+ Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalMask`,
1070
+ except we support padding for k/v
1071
+
1072
+ The keys and values are divided into blocks which are padded out to
1073
+ the same total length.
1074
+ For example, if there is space for 12 keys, for three blocks of
1075
+ max length 4, but we only want to use the first 2, 3 and 2
1076
+ of each block, use `kv_padding=4` and `kv_seqlens=[2, 3, 2]`.
1077
+ The queries are divided into blocks, without padding, of lengths given by
1078
+ q_seqlen.
1079
+
1080
+ A query Q in block i cannot attend to a key which is not in block i,
1081
+ nor one which is not in use (i.e. in the padded area).
1082
+ """
1083
+
1084
+ q_seqinfo: _SeqLenInfo
1085
+ k_seqinfo: _PaddedSeqLenInfo
1086
+
1087
+ def to(self, device) -> "BlockDiagonalPaddedKeysMask":
1088
+ assert type(self) is BlockDiagonalPaddedKeysMask, "Please implement in subclass"
1089
+ return BlockDiagonalPaddedKeysMask(
1090
+ q_seqinfo=self.q_seqinfo.to(device),
1091
+ k_seqinfo=self.k_seqinfo.to(device),
1092
+ )
1093
+
1094
+ def _create_block_mask(
1095
+ self,
1096
+ shape: Tuple[int, ...],
1097
+ dtype: torch.dtype = torch.float32,
1098
+ device: Union[str, torch.device] = "cpu",
1099
+ ) -> torch.Tensor:
1100
+ return torch.zeros([1], device=device, dtype=dtype)
1101
+
1102
+ def materialize(
1103
+ self,
1104
+ shape: Tuple[int, ...],
1105
+ dtype: torch.dtype = torch.float32,
1106
+ device: Union[str, torch.device] = "cpu",
1107
+ ) -> torch.Tensor:
1108
+ """Materialize the attention bias - for debugging & testing"""
1109
+ if shape[-1] != self.k_seqinfo.seqstart_py[-1]:
1110
+ raise ValueError("k shapes wrong")
1111
+ if shape[-2] != self.q_seqinfo.seqstart_py[-1]:
1112
+ raise ValueError("q shapes wrong")
1113
+ mask = torch.empty(shape[-2:], dtype=dtype, device=device)
1114
+ mask.fill_(-math.inf)
1115
+ for (q_start, q_end), (k_start, k_end) in zip(
1116
+ self.q_seqinfo.intervals(),
1117
+ self.k_seqinfo.intervals(),
1118
+ ):
1119
+ mask[q_start:q_end, k_start:k_end] = self._create_block_mask(
1120
+ (q_end - q_start, k_end - k_start),
1121
+ dtype=dtype,
1122
+ device=device,
1123
+ )
1124
+ for _ in range(len(shape) - 2):
1125
+ mask = mask.unsqueeze(0)
1126
+ return mask.expand(shape)
1127
+
1128
+ @classmethod
1129
+ def from_seqlens(
1130
+ cls,
1131
+ q_seqlen: Sequence[int],
1132
+ kv_padding: int,
1133
+ kv_seqlen: Sequence[int],
1134
+ causal_diagonal: Any = None,
1135
+ *,
1136
+ device: Optional[torch.device] = None,
1137
+ ) -> "BlockDiagonalPaddedKeysMask":
1138
+ """Creates a :attr:`BlockDiagonalPaddedKeysMask` from a list of tensor
1139
+ lengths for query and key/value.
1140
+
1141
+ Args:
1142
+ q_seqlen (Sequence[int]): List or tensor of sequence lengths for query tensors
1143
+ kv_padding (int): Padding for k/v - also an upper bound on each individual key length
1144
+ kv_seqlen (Sequence[int]): List or tensor of sequence lengths for key/value.
1145
+ causal_diagonal: unused, for BC only
1146
+ Returns:
1147
+ BlockDiagonalPaddedKeysMask
1148
+ """
1149
+ device = _get_default_bias_device(device)
1150
+ assert kv_seqlen is None or len(q_seqlen) == len(kv_seqlen), (
1151
+ q_seqlen,
1152
+ kv_seqlen,
1153
+ )
1154
+ q_seqinfo = _SeqLenInfo.from_seqlens(q_seqlen, device=device)
1155
+ k_seqinfo = _PaddedSeqLenInfo.from_seqlens_padded(
1156
+ kv_seqlen, kv_padding, device=device
1157
+ )
1158
+ return cls(q_seqinfo=q_seqinfo, k_seqinfo=k_seqinfo)
1159
+
1160
+ def make_paged(
1161
+ self,
1162
+ block_tables: torch.Tensor,
1163
+ page_size: int,
1164
+ paged_type: Type["PagedBlockDiagonalPaddedKeysMask"],
1165
+ ) -> "PagedBlockDiagonalPaddedKeysMask":
1166
+ paged_bias = paged_type(
1167
+ q_seqinfo=self.q_seqinfo,
1168
+ k_seqinfo=_PaddedSeqLenInfo(
1169
+ seqstart=self.k_seqinfo.seqstart,
1170
+ seqstart_py=self.k_seqinfo.seqstart_py,
1171
+ seqlen=self.k_seqinfo.seqlen,
1172
+ seqlen_py=self.k_seqinfo.seqlen_py,
1173
+ padding=block_tables.shape[1] * page_size,
1174
+ max_seqlen=self.k_seqinfo.max_seqlen,
1175
+ min_seqlen=self.k_seqinfo.min_seqlen,
1176
+ ),
1177
+ block_tables=block_tables,
1178
+ page_size=page_size,
1179
+ )
1180
+ return paged_bias
1181
+
1182
+ def make_local_attention(
1183
+ self, window_left: int, window_right: int
1184
+ ) -> "BlockDiagonalLocalAttentionPaddedKeysMask":
1185
+ return BlockDiagonalLocalAttentionPaddedKeysMask(
1186
+ q_seqinfo=self.q_seqinfo,
1187
+ k_seqinfo=self.k_seqinfo,
1188
+ window_left=window_left,
1189
+ window_right=window_right,
1190
+ )
1191
+
1192
+
1193
+ @dataclass
1194
+ class BlockDiagonalCausalWithOffsetPaddedKeysMask(BlockDiagonalPaddedKeysMask):
1195
+ """
1196
+ Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask`,
1197
+ except an offset on causality is allowed for each block and we support padding for k/v
1198
+
1199
+ The keys and values are divided into blocks which are padded out to
1200
+ the same total length.
1201
+ For example, if there is space for 12 keys, for three blocks of
1202
+ max length 4, but we only want to use the first 2, 3 and 2
1203
+ of each block, use `kv_padding=4` and `kv_seqlens=[2, 3, 2]`.
1204
+ The queries are divided into blocks, without padding, of lengths given by
1205
+ q_seqlen.
1206
+
1207
+ A query Q in block i cannot attend to a key which is not in block i,
1208
+ nor one which is not in use (i.e. in the padded area),
1209
+ nor one which is nearer to the final key in block i
1210
+ than Q is to the final query in block i.
1211
+ """
1212
+
1213
+ causal_diagonal: Any = None # unused. Exists for BC only.
1214
+
1215
+ def to(self, device) -> "BlockDiagonalCausalWithOffsetPaddedKeysMask":
1216
+ assert type(self) is BlockDiagonalCausalWithOffsetPaddedKeysMask, (
1217
+ "Please implement in subclass"
1218
+ )
1219
+ return BlockDiagonalCausalWithOffsetPaddedKeysMask(
1220
+ q_seqinfo=self.q_seqinfo.to(device),
1221
+ k_seqinfo=self.k_seqinfo.to(device),
1222
+ )
1223
+
1224
+ def _create_block_mask(
1225
+ self,
1226
+ shape: Tuple[int, ...],
1227
+ dtype: torch.dtype = torch.float32,
1228
+ device: Union[str, torch.device] = "cpu",
1229
+ ) -> torch.Tensor:
1230
+ return LowerTriangularFromBottomRightMask().materialize(
1231
+ shape=shape, dtype=dtype, device=device
1232
+ )
1233
+
1234
+ @classmethod
1235
+ def from_seqlens(
1236
+ cls,
1237
+ q_seqlen: Sequence[int],
1238
+ kv_padding: int,
1239
+ kv_seqlen: Sequence[int],
1240
+ causal_diagonal: Any = None,
1241
+ *,
1242
+ device: Optional[torch.device] = None,
1243
+ ) -> "BlockDiagonalCausalWithOffsetPaddedKeysMask":
1244
+ """Creates a :attr:`BlockDiagonalCausalWithOffsetPaddedKeysMask` from a list of tensor
1245
+ lengths for query and key/value.
1246
+
1247
+ Args:
1248
+ q_seqlen (Sequence[int]): List or tensor of sequence lengths for query tensors
1249
+ kv_padding (int): Padding for k/v - also an upper bound on each individual key length
1250
+ kv_seqlen (Sequence[int]): List or tensor of sequence lengths for key/value.
1251
+ causal_diagonal: unused, for BC only
1252
+ Returns:
1253
+ BlockDiagonalCausalWithOffsetPaddedKeysMask
1254
+ """
1255
+ assert kv_seqlen is None or len(q_seqlen) == len(kv_seqlen), (
1256
+ q_seqlen,
1257
+ kv_seqlen,
1258
+ )
1259
+ device = _get_default_bias_device(device)
1260
+ q_seqinfo = _SeqLenInfo.from_seqlens(q_seqlen, device=device)
1261
+ k_seqinfo = _PaddedSeqLenInfo.from_seqlens_padded(
1262
+ kv_seqlen, kv_padding, device=device
1263
+ )
1264
+ return cls(q_seqinfo=q_seqinfo, k_seqinfo=k_seqinfo)
1265
+
1266
+
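A decode-style sketch for the class above: one new query per sequence, with keys/values stored in fixed-size padded slots:

.. code-block:: python

    import torch
    from mslk.attention.fmha.attn_bias import (
        BlockDiagonalCausalWithOffsetPaddedKeysMask,
    )

    # Three sequences, each contributing one query; K/V slots padded to 4,
    # with 2, 3 and 2 keys actually in use.
    bias = BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
        q_seqlen=[1, 1, 1],
        kv_padding=4,
        kv_seqlen=[2, 3, 2],
        device=torch.device("cpu"),
    )
    # 3 query rows x 12 padded key slots; each row sees only its own
    # block's in-use keys.
    print(bias.materialize((3, 12)).exp())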
1267
+ @dataclass
1268
+ class BlockDiagonalLocalAttentionPaddedKeysMask(BlockDiagonalPaddedKeysMask):
1269
+ """
1270
+ Like :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalCausalLocalAttentionPaddedKeysMask`,
1271
+ except that this is non-causal.
1272
+
1273
+ A query Q in block i cannot attend to a key which is not in block i,
1274
+ nor one which is not in use (i.e. in the padded area),
1275
+ nor one whose distance to the final key in block i
1276
+ is more than window_left further or window_right nearer
1277
+ than Q is to the final query in block i.
1278
+
1279
+ A query attends to at most window_left + window_right - 1 keys.
1280
+
1281
+ NOTE that if window_right is 0, then this is like a
1282
+ BlockDiagonalCausalLocalAttentionPaddedKeysMask whose window_size is equal to
1283
+ window_left - 1.
1284
+ """
1285
+
1286
+ window_left: int
1287
+ window_right: int
1288
+
1289
+ def to(self, device) -> "BlockDiagonalLocalAttentionPaddedKeysMask":
1290
+ assert type(self) is BlockDiagonalLocalAttentionPaddedKeysMask, (
1291
+ "Please implement in subclass"
1292
+ )
1293
+ return BlockDiagonalLocalAttentionPaddedKeysMask(
1294
+ q_seqinfo=self.q_seqinfo.to(device),
1295
+ k_seqinfo=self.k_seqinfo.to(device),
1296
+ window_left=self.window_left,
1297
+ window_right=self.window_right,
1298
+ )
1299
+
1300
+ def _create_block_mask(
1301
+ self,
1302
+ shape: Tuple[int, ...],
1303
+ dtype: torch.dtype = torch.float32,
1304
+ device: Union[str, torch.device] = "cpu",
1305
+ ) -> torch.Tensor:
1306
+ return LocalAttentionFromBottomRightMask(
1307
+ window_left=self.window_left, window_right=self.window_right
1308
+ ).materialize(shape=shape, dtype=dtype, device=device)
1309
+
1310
+ @classmethod
1311
+ def from_seqlens_local(
1312
+ cls,
1313
+ q_seqlen: Sequence[int],
1314
+ kv_padding: int,
1315
+ kv_seqlen: Sequence[int],
1316
+ window_left: int,
1317
+ window_right: int,
1318
+ ) -> "BlockDiagonalLocalAttentionPaddedKeysMask":
1319
+ assert kv_seqlen is None or len(q_seqlen) == len(kv_seqlen), (
1320
+ q_seqlen,
1321
+ kv_seqlen,
1322
+ )
1323
+ q_seqinfo = _SeqLenInfo.from_seqlens(q_seqlen)
1324
+ k_seqinfo = _PaddedSeqLenInfo.from_seqlens_padded(kv_seqlen, kv_padding)
1325
+ return cls(
1326
+ q_seqinfo=q_seqinfo,
1327
+ k_seqinfo=k_seqinfo,
1328
+ window_left=window_left,
1329
+ window_right=window_right,
1330
+ )
1331
+
1332
+
1333
+ @dataclass
1334
+ class BlockDiagonalCausalLocalAttentionPaddedKeysMask(BlockDiagonalPaddedKeysMask):
1335
+ """
1336
+ Like :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask`,
1337
+ except with a window size.
1338
+
1339
+ A query Q in block i cannot attend to a key which is not in block i,
1340
+ nor one which is not in use (i.e. in the padded area),
1341
+ nor one which is nearer to the final key in block i
1342
+ than Q is to the final query in block i, nor one that is at least
1343
+ window_size further from the final key in block i than Q is
1344
+ to the final query in block i.
1345
+ """
1346
+
1347
+ _window_size: int
1348
+
1349
+ def to(self, device) -> "BlockDiagonalCausalLocalAttentionPaddedKeysMask":
1350
+ assert type(self) is BlockDiagonalCausalLocalAttentionPaddedKeysMask, (
1351
+ "Please implement in subclass"
1352
+ )
1353
+ return BlockDiagonalCausalLocalAttentionPaddedKeysMask(
1354
+ q_seqinfo=self.q_seqinfo.to(device),
1355
+ k_seqinfo=self.k_seqinfo.to(device),
1356
+ _window_size=self._window_size,
1357
+ )
1358
+
1359
+ def _create_block_mask(
1360
+ self,
1361
+ shape: Tuple[int, ...],
1362
+ dtype: torch.dtype = torch.float32,
1363
+ device: Union[str, torch.device] = "cpu",
1364
+ ) -> torch.Tensor:
1365
+ return _materialize_causal_mask(
1366
+ shape=shape,
1367
+ dtype=dtype,
1368
+ device=device,
1369
+ window_size=self._window_size,
1370
+ from_bottomright=True,
1371
+ )
1372
+
1373
+ @classmethod
1374
+ def from_seqlens_local(
1375
+ cls,
1376
+ q_seqlen: Sequence[int],
1377
+ kv_padding: int,
1378
+ kv_seqlen: Sequence[int],
1379
+ window_size: int,
1380
+ ) -> "BlockDiagonalCausalLocalAttentionPaddedKeysMask":
1381
+ assert kv_seqlen is None or len(q_seqlen) == len(kv_seqlen), (
1382
+ q_seqlen,
1383
+ kv_seqlen,
1384
+ )
1385
+ q_seqinfo = _SeqLenInfo.from_seqlens(q_seqlen)
1386
+ k_seqinfo = _PaddedSeqLenInfo.from_seqlens_padded(kv_seqlen, kv_padding)
1387
+ return cls(q_seqinfo=q_seqinfo, k_seqinfo=k_seqinfo, _window_size=window_size)
1388
+
1389
+ # pyre-ignore[14]
1390
+ def make_paged(
1391
+ self,
1392
+ block_tables: torch.Tensor,
1393
+ page_size: int,
1394
+ paged_type: Type["PagedBlockDiagonalCausalLocalPaddedKeysMask"],
1395
+ ) -> "PagedBlockDiagonalCausalLocalPaddedKeysMask":
1396
+ paged_bias = paged_type(
1397
+ q_seqinfo=self.q_seqinfo,
1398
+ k_seqinfo=_PaddedSeqLenInfo(
1399
+ seqstart=self.k_seqinfo.seqstart,
1400
+ seqstart_py=self.k_seqinfo.seqstart_py,
1401
+ seqlen=self.k_seqinfo.seqlen,
1402
+ seqlen_py=self.k_seqinfo.seqlen_py,
1403
+ padding=block_tables.shape[1] * page_size,
1404
+ max_seqlen=self.k_seqinfo.max_seqlen,
1405
+ min_seqlen=self.k_seqinfo.min_seqlen,
1406
+ ),
1407
+ block_tables=block_tables,
1408
+ page_size=page_size,
1409
+ _window_size=self._window_size,
1410
+ )
1411
+ return paged_bias
1412
+
1413
+
1414
+ @dataclass
1415
+ class PagedBlockDiagonalPaddedKeysMask(AttentionBias):
1416
+ """
1417
+ Same as BlockDiagonalPaddedKeysMask, but for paged attention.
1418
+ block_tables has shape [batch_size, max_num_pages] and K/V have shape
1419
+ [1, max_num_pages * page_size, num_heads, head_dim]
1420
+ or [1, max_num_pages * page_size, num_groups, num_heads, head_dim]
1421
+ """
1422
+
1423
+ q_seqinfo: _SeqLenInfo
1424
+ k_seqinfo: _PaddedSeqLenInfo
1425
+ block_tables: torch.Tensor
1426
+ page_size: int
1427
+
1428
+ _UNPAGED_TYPE: ClassVar[Type[BlockDiagonalPaddedKeysMask]] = (
1429
+ BlockDiagonalPaddedKeysMask
1430
+ )
1431
+
1432
+ def to(self, device: torch.device) -> "PagedBlockDiagonalPaddedKeysMask":
1433
+ assert type(self) is PagedBlockDiagonalPaddedKeysMask, (
1434
+ "Please implement in subclass"
1435
+ )
1436
+ return PagedBlockDiagonalPaddedKeysMask(
1437
+ q_seqinfo=self.q_seqinfo.to(device),
1438
+ k_seqinfo=self.k_seqinfo.to(device),
1439
+ block_tables=_to_device(self.block_tables, device),
1440
+ page_size=self.page_size,
1441
+ )
1442
+
1443
+ def materialize(
1444
+ self,
1445
+ shape: Tuple[int, ...],
1446
+ dtype: torch.dtype = torch.float32,
1447
+ device: Union[str, torch.device] = "cpu",
1448
+ ) -> torch.Tensor:
1449
+ """Materialize the attention bias - for debugging & testing"""
1450
+ # First create a non-paged mask, then cut individual pages and
1451
+ # copy them to their places in the physical mask, using block tables
1452
+
1453
+ max_row_len = self.block_tables.shape[1] * self.page_size
1454
+ bias_nonpaged = self._UNPAGED_TYPE(
1455
+ q_seqinfo=self.q_seqinfo,
1456
+ k_seqinfo=_PaddedSeqLenInfo.from_seqlens_padded(
1457
+ self.k_seqinfo.seqlen_py, max_row_len
1458
+ ),
1459
+ )
1460
+ mask_nonpaged = bias_nonpaged.materialize(shape, dtype, device)
1461
+
1462
+ n_used_blocks = cast(int, self.block_tables.max().item() + 1)
1463
+ max_physical_len = n_used_blocks * self.page_size
1464
+ mask_paged = torch.empty(
1465
+ mask_nonpaged.shape[:-1] + (max_physical_len,), dtype=dtype, device=device
1466
+ )
1467
+ mask_paged.fill_(-math.inf)
1468
+ for b, (q_start, q_end) in enumerate(self.q_seqinfo.intervals()):
1469
+ for logical_page_idx in range(self.block_tables.shape[1]):
1470
+ physical_page_idx = cast(
1471
+ int, self.block_tables[b][logical_page_idx].item()
1472
+ )
1473
+ k_logical_start = b * max_row_len + logical_page_idx * self.page_size
1474
+ k_logical_end = k_logical_start + self.page_size
1475
+ k_physical_start = physical_page_idx * self.page_size
1476
+ k_physical_end = k_physical_start + self.page_size
1477
+ mask_paged[..., q_start:q_end, k_physical_start:k_physical_end] = (
1478
+ mask_nonpaged[..., q_start:q_end, k_logical_start:k_logical_end]
1479
+ )
1480
+ return mask_paged
1481
+
1482
+ @classmethod
1483
+ def from_seqlens(
1484
+ cls,
1485
+ q_seqlen: Sequence[int],
1486
+ kv_seqlen: Sequence[int],
1487
+ block_tables: torch.Tensor,
1488
+ page_size: int,
1489
+ *,
1490
+ device: Optional[torch.device] = None,
1491
+ ) -> "PagedBlockDiagonalPaddedKeysMask":
1492
+ """Creates a :attr:`PagedBlockDiagonalPaddedKeysMask` from a list of tensor
1493
+ lengths for query and key/value.
1494
+
1495
+ Args:
1496
+ q_seqlen (Sequence[int]): List or tensor of sequence lengths for query tensors
1497
+             kv_seqlen (Sequence[int]): List or tensor of sequence lengths for key/value.
1498
+             block_tables (torch.Tensor): [batch_size, max_num_pages] table mapping logical pages to physical pages.
1499
+             page_size (int): size of each page for paged attention.
1500
+ Returns:
1501
+ PagedBlockDiagonalPaddedKeysMask
1502
+ """
1503
+ assert len(q_seqlen) == len(kv_seqlen), (
1504
+ q_seqlen,
1505
+ kv_seqlen,
1506
+ )
1507
+ device = _get_default_bias_device(device)
1508
+ q_seqinfo = _SeqLenInfo.from_seqlens(q_seqlen, device=device)
1509
+ k_seqinfo = _PaddedSeqLenInfo.from_seqlens_padded(
1510
+ kv_seqlen, padding=block_tables.shape[1] * page_size, device=device
1511
+ )
1512
+ return cls(
1513
+ q_seqinfo=q_seqinfo,
1514
+ k_seqinfo=k_seqinfo,
1515
+ block_tables=block_tables,
1516
+ page_size=page_size,
1517
+ )
1518
+
1519
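A sketch (not part of the diff) of building this bias with `from_seqlens` and materializing it for inspection; the module path and the concrete sizes are assumptions:

```python
import torch

from mslk.attention.fmha.attn_bias import PagedBlockDiagonalPaddedKeysMask

page_size = 4
# 2 sequences, up to 3 logical pages each; entries index physical pages in the shared K/V pool.
block_tables = torch.tensor([[0, 1, 2], [3, 4, 5]], dtype=torch.int32)

bias = PagedBlockDiagonalPaddedKeysMask.from_seqlens(
    q_seqlen=[2, 3],
    kv_seqlen=[7, 10],
    block_tables=block_tables,
    page_size=page_size,
    device=torch.device("cpu"),
)

# Logical K length per sequence is max_num_pages * page_size = 12, so the logical shape is
# (total_q, batch_size * 12). The returned mask has n_used_blocks * page_size columns,
# which also comes to 24 here because all six physical pages are referenced.
mask = bias.materialize(shape=(2 + 3, 2 * 3 * page_size))
print(mask.shape)  # torch.Size([5, 24])
```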
+
1520
+ @dataclass
1521
+ class PagedBlockDiagonalCausalWithOffsetPaddedKeysMask(
1522
+ PagedBlockDiagonalPaddedKeysMask
1523
+ ):
1524
+ """
1525
+ Same as BlockDiagonalCausalWithOffsetPaddedKeysMask, but for paged attention.
1526
+ block_tables has shape [batch_size, max_num_pages] and K/V have shape
1527
+ [1, max_num_pages * page_size, num_heads, head_dim]
1528
+ or [1, max_num_pages * page_size, num_groups, num_heads, head_dim]
1529
+ """
1530
+
1531
+ _UNPAGED_TYPE = BlockDiagonalCausalWithOffsetPaddedKeysMask
1532
+
1533
+ def to(
1534
+ self, device: torch.device
1535
+ ) -> "PagedBlockDiagonalCausalWithOffsetPaddedKeysMask":
1536
+ assert type(self) is PagedBlockDiagonalCausalWithOffsetPaddedKeysMask, (
1537
+ "Please implement in subclass"
1538
+ )
1539
+ return PagedBlockDiagonalCausalWithOffsetPaddedKeysMask(
1540
+ q_seqinfo=self.q_seqinfo.to(device),
1541
+ k_seqinfo=self.k_seqinfo.to(device),
1542
+ block_tables=_to_device(self.block_tables, device),
1543
+ page_size=self.page_size,
1544
+ )
1545
+
1546
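For decode-style inference this subclass is typically built with one query per sequence; a small sketch (not part of the diff), with illustrative sizes:

```python
import torch

from mslk.attention.fmha.attn_bias import PagedBlockDiagonalCausalWithOffsetPaddedKeysMask

block_tables = torch.tensor([[0, 1], [2, 3]], dtype=torch.int32)
bias = PagedBlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
    q_seqlen=[1, 1],    # one new query token per sequence
    kv_seqlen=[5, 7],   # lengths of the cached keys each query may attend to
    block_tables=block_tables,
    page_size=4,
    device=torch.device("cpu"),
)
```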
+
1547
+ @dataclass
1548
+ class PagedBlockDiagonalCausalLocalPaddedKeysMask(PagedBlockDiagonalPaddedKeysMask):
1549
+ """
1550
+ Same as BlockDiagonalCausalLocalAttentionPaddedKeysMask, but for paged attention.
1551
+ block_tables has shape [batch_size, max_num_pages] and K/V have shape
1552
+ [1, max_num_pages * page_size, num_heads, head_dim]
1553
+ or [1, max_num_pages * page_size, num_groups, num_heads, head_dim]
1554
+ """
1555
+
1556
+ _window_size: int
1557
+
1558
+ _UNPAGED_TYPE: ClassVar[Type[BlockDiagonalCausalLocalAttentionPaddedKeysMask]] = (
1559
+ BlockDiagonalCausalLocalAttentionPaddedKeysMask
1560
+ )
1561
+
1562
+ def to(self, device: torch.device) -> "PagedBlockDiagonalCausalLocalPaddedKeysMask":
1563
+ assert type(self) is PagedBlockDiagonalCausalLocalPaddedKeysMask, (
1564
+ "Please implement in subclass"
1565
+ )
1566
+ return PagedBlockDiagonalCausalLocalPaddedKeysMask(
1567
+ q_seqinfo=self.q_seqinfo.to(device),
1568
+ k_seqinfo=self.k_seqinfo.to(device),
1569
+ block_tables=_to_device(self.block_tables, device),
1570
+ page_size=self.page_size,
1571
+ _window_size=self._window_size,
1572
+ )
1573
+
1574
+ def materialize(
1575
+ self,
1576
+ shape: Tuple[int, ...],
1577
+ dtype: torch.dtype = torch.float32,
1578
+ device: Union[str, torch.device] = "cpu",
1579
+ ) -> torch.Tensor:
1580
+ """Materialize the attention bias - for debugging & testing"""
1581
+ # First create a non-paged mask, then cut individual pages and
1582
+ # copy them to their places in the physical mask, using block tables
1583
+
1584
+ max_row_len = self.block_tables.shape[1] * self.page_size
1585
+ bias_nonpaged = self._UNPAGED_TYPE(
1586
+ q_seqinfo=self.q_seqinfo,
1587
+ k_seqinfo=_PaddedSeqLenInfo.from_seqlens_padded(
1588
+ self.k_seqinfo.seqlen_py, max_row_len
1589
+ ),
1590
+ _window_size=self._window_size,
1591
+ )
1592
+ mask_nonpaged = bias_nonpaged.materialize(shape, dtype, device)
1593
+
1594
+ n_used_blocks = cast(int, self.block_tables.max().item() + 1)
1595
+ max_physical_len = n_used_blocks * self.page_size
1596
+ mask_paged = torch.empty(
1597
+ mask_nonpaged.shape[:-1] + (max_physical_len,), dtype=dtype, device=device
1598
+ )
1599
+ mask_paged.fill_(-math.inf)
1600
+ for b, (q_start, q_end) in enumerate(self.q_seqinfo.intervals()):
1601
+ for logical_page_idx in range(self.block_tables.shape[1]):
1602
+ physical_page_idx = cast(
1603
+ int, self.block_tables[b][logical_page_idx].item()
1604
+ )
1605
+ k_logical_start = b * max_row_len + logical_page_idx * self.page_size
1606
+ k_logical_end = k_logical_start + self.page_size
1607
+ k_physical_start = physical_page_idx * self.page_size
1608
+ k_physical_end = k_physical_start + self.page_size
1609
+ mask_paged[..., q_start:q_end, k_physical_start:k_physical_end] = (
1610
+ mask_nonpaged[..., q_start:q_end, k_logical_start:k_logical_end]
1611
+ )
1612
+ return mask_paged
1613
+
1614
+ @classmethod
1615
+ def from_seqlens_local(
1616
+ cls,
1617
+ q_seqlen: Sequence[int],
1618
+ kv_seqlen: Sequence[int],
1619
+ block_tables: torch.Tensor,
1620
+ page_size: int,
1621
+ window_size: int,
1622
+ *,
1623
+ device: Optional[torch.device] = None,
1624
+ ) -> "PagedBlockDiagonalCausalLocalPaddedKeysMask":
1625
+ """Creates a :attr:`PagedBlockDiagonalCausalLocalPaddedKeysMask` from a list of tensor
1626
+ lengths for query and key/value.
1627
+
1628
+ Args:
1629
+ q_seqlen (Sequence[int]): List or tensor of sequence lengths for query tensors
1630
+ kv_seqlen (Sequence[int]): List or tensor of sequence lengths for key/value.
1632
+ block_tables: table mapping logical pages to physical for paged attention.
1633
+ page_size: size of each page for paged attention.
1634
+ window_size: size of the window for sliding window attention.
1635
+ Returns:
1636
+ PagedBlockDiagonalCausalLocalPaddedKeysMask
1637
+ """
1638
+ assert len(q_seqlen) == len(kv_seqlen), (
1639
+ q_seqlen,
1640
+ kv_seqlen,
1641
+ )
1642
+ device = _get_default_bias_device(device)
1643
+ q_seqinfo = _SeqLenInfo.from_seqlens(q_seqlen, device=device)
1644
+ k_seqinfo = _PaddedSeqLenInfo.from_seqlens_padded(
1645
+ kv_seqlen, padding=block_tables.shape[1] * page_size, device=device
1646
+ )
1647
+ return cls(
1648
+ q_seqinfo=q_seqinfo,
1649
+ k_seqinfo=k_seqinfo,
1650
+ block_tables=block_tables,
1651
+ page_size=page_size,
1652
+ _window_size=window_size,
1653
+ )
1654
+
1655
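A sketch (not part of the diff) combining paging with the sliding window via `from_seqlens_local`; the sizes and the logical materialization shape are assumptions based on the shapes described in the docstrings above:

```python
import torch

from mslk.attention.fmha.attn_bias import PagedBlockDiagonalCausalLocalPaddedKeysMask

block_tables = torch.tensor([[0, 1, 2], [3, 4, 5]], dtype=torch.int32)
bias = PagedBlockDiagonalCausalLocalPaddedKeysMask.from_seqlens_local(
    q_seqlen=[4, 2],
    kv_seqlen=[9, 6],
    block_tables=block_tables,
    page_size=4,
    window_size=5,   # each query sees at most the 5 most recent keys
    device=torch.device("cpu"),
)
# For debugging, materialize with the logical shape (total_q, batch * max_num_pages * page_size).
mask = bias.materialize(shape=(6, 2 * 3 * 4))
```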
+
1656
+ @dataclass
1657
+ class BlockDiagonalGappyKeysMask(AttentionBias):
1658
+ """
1659
+ Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalMask`,
1660
+ except k/v is gappy.
1661
+
1662
+ A query Q in block i only attends to a key which is in block i.
1663
+ """
1664
+
1665
+ q_seqinfo: _SeqLenInfo
1666
+ k_seqinfo: _GappySeqInfo
1667
+
1668
+ def to(self, device: torch.device) -> "BlockDiagonalGappyKeysMask":
1669
+ assert type(self) is BlockDiagonalGappyKeysMask, "Please implement in subclass"
1670
+ return BlockDiagonalGappyKeysMask(
1671
+ q_seqinfo=self.q_seqinfo.to(device),
1672
+ k_seqinfo=self.k_seqinfo.to(device),
1673
+ )
1674
+
1675
+ def materialize(
1676
+ self,
1677
+ shape: Tuple[int, ...],
1678
+ dtype: torch.dtype = torch.float32,
1679
+ device: Union[str, torch.device] = "cpu",
1680
+ ) -> torch.Tensor:
1681
+ """Materialize the attention bias - for debugging & testing"""
1682
+ if shape[-1] != self.k_seqinfo.seqstart_py[-1]:
1683
+ raise ValueError("k shapes wrong", (shape, self.k_seqinfo))
1684
+ if shape[-2] != self.q_seqinfo.seqstart_py[-1]:
1685
+ raise ValueError("q shapes wrong", (shape, self.q_seqinfo))
1686
+ mask = torch.empty(shape[-2:], dtype=dtype, device=device)
1687
+ mask.fill_(-math.inf)
1688
+ for (q_start, q_end), (k_start, k_end) in zip(
1689
+ self.q_seqinfo.intervals(),
1690
+ self.k_seqinfo.intervals(),
1691
+ ):
1692
+ mask[q_start:q_end, k_start:k_end] = 0
1693
+ for _ in range(len(shape) - 2):
1694
+ mask = mask.unsqueeze(0)
1695
+ return mask.expand(shape)
1696
+
1697
+ @classmethod
1698
+ def from_seqlens(
1699
+ cls,
1700
+ q_seqlen: Sequence[int],
1701
+ kv_seqstarts: Sequence[int],
1702
+ kv_seqlen: Sequence[int],
1703
+ *,
1704
+ device: Optional[torch.device] = None,
1705
+ ) -> "BlockDiagonalGappyKeysMask":
1706
+ """Creates a :attr:`BlockDiagonalGappyKeysMask` from a list of tensor
1707
+ lengths for query and key/value.
1708
+ """
1709
+ assert len(q_seqlen) == len(kv_seqlen), (
1710
+ q_seqlen,
1711
+ kv_seqlen,
1712
+ )
1713
+ device = _get_default_bias_device(device)
1714
+ q_seqinfo = _SeqLenInfo.from_seqlens(q_seqlen, device=device)
1715
+ k_seqinfo = _GappySeqInfo.from_seqlens_gappy(
1716
+ kv_seqstarts, kv_seqlen, False, device=device
1717
+ )
1718
+ return cls(q_seqinfo=q_seqinfo, k_seqinfo=k_seqinfo)
1719
+
1720
+ def make_paged(
1721
+ self,
1722
+ block_tables: torch.Tensor,
1723
+ page_size: int,
1724
+ notional_padding: int,
1725
+ paged_type: Type["PagedBlockDiagonalGappyKeysMask"],
1726
+ ) -> AttentionBias:
1727
+ """
1728
+ Assuming our keys actually live in separate blocks of length
1729
+ notional_padding, convert to a Paged version, avoiding GPU syncs.
1730
+ """
1731
+ if notional_padding % page_size:
1732
+ raise ValueError(
1733
+ "Notional padding should be divisible by the page size,"
1734
+ f" but got {notional_padding=}, {page_size=}."
1735
+ )
1736
+ max_row_len = block_tables.shape[1] * page_size
1737
+ new_seqstarts_py = [
1738
+ start - i * notional_padding
1739
+ for i, start in enumerate(self.k_seqinfo.seqstart_py[:-1])
1740
+ ]
1741
+ new_seqstarts_py.append(-1)
1742
+ assert all(0 <= i < max_row_len for i in new_seqstarts_py[:-1]), (
1743
+ f"{max_row_len=} {new_seqstarts_py=}"
1744
+ )
1745
+
1746
+ # Sequence info is duplicated on CPU and GPU,
1747
+ # but we process them independently to avoid GPU sync.
1748
+ batch_size = len(self.k_seqinfo.seqlen_py)
1749
+ notional_starts = notional_padding * torch.arange(
1750
+ batch_size + 1,
1751
+ device=block_tables.device,
1752
+ dtype=torch.int32,
1753
+ )
1754
+ new_seqstarts = self.k_seqinfo.seqstart - notional_starts
1755
+
1756
+ new_seqlens_py = [
1757
+ i + j for i, j in zip(new_seqstarts_py, self.k_seqinfo.seqlen_py)
1758
+ ]
1759
+ new_seqlens = self.k_seqinfo.seqlen + new_seqstarts[:-1]
1760
+
1761
+ k_seqinfo = _GappySeqInfo(
1762
+ seqlen=new_seqlens,
1763
+ seqlen_py=new_seqlens_py,
1764
+ max_seqlen=self.k_seqinfo.max_seqlen,
1765
+ min_seqlen=self.k_seqinfo.min_seqlen,
1766
+ seqstart=new_seqstarts,
1767
+ seqstart_py=new_seqstarts_py,
1768
+ )
1769
+ assert self.k_seqinfo.max_seqlen <= max_row_len
1770
+ paged_bias = paged_type(
1771
+ q_seqinfo=self.q_seqinfo,
1772
+ k_seqinfo=k_seqinfo,
1773
+ block_tables=block_tables,
1774
+ page_size=page_size,
1775
+ )
1776
+ return paged_bias
1777
+
1778
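A sketch (not part of the diff) of the gappy layout: both sequences' keys live in one shared buffer with unused stretches between them. `kv_seqstarts` is assumed to carry one start per sequence plus the total key length, matching the worked example in `PagedBlockDiagonalGappyKeysMask.from_seqlens` further down:

```python
import torch

from mslk.attention.fmha.attn_bias import BlockDiagonalGappyKeysMask

bias = BlockDiagonalGappyKeysMask.from_seqlens(
    q_seqlen=[3, 4],
    kv_seqstarts=[2, 12, 20],   # seq 0 keys start at 2, seq 1 at 12; the buffer holds 20 keys
    kv_seqlen=[5, 6],
    device=torch.device("cpu"),
)
mask = bias.materialize(shape=(7, 20))  # 0 inside each (query block, key span), -inf elsewhere
```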
+
1779
+ @dataclass
1780
+ class BlockDiagonalLocalAttentionFromBottomRightGappyKeysMask(
1781
+ BlockDiagonalGappyKeysMask
1782
+ ):
1783
+ """
1784
+ Like :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalGappyKeysMask`,
1785
+ except that this has local attention.
1786
+
1787
+ A query Q in block i cannot attend to a key which is not in block i,
1788
+ nor one which is not in use (i.e. in the padded area),
1789
+ nor one whose distance to the final key in block i
1790
+ is more than window_left further or window_right nearer
1791
+ than Q is to the final query in block i.
1792
+
1793
+     A query attends to at most window_left + window_right + 1 keys.
1794
+ """
1795
+
1796
+ window_left: int
1797
+ window_right: int
1798
+
1799
+ def to(self, device) -> "BlockDiagonalLocalAttentionFromBottomRightGappyKeysMask":
1800
+ assert type(self) is BlockDiagonalLocalAttentionFromBottomRightGappyKeysMask, (
1801
+ "Please implement in subclass"
1802
+ )
1803
+ return BlockDiagonalLocalAttentionFromBottomRightGappyKeysMask(
1804
+ q_seqinfo=self.q_seqinfo.to(device),
1805
+ k_seqinfo=self.k_seqinfo.to(device),
1806
+ window_left=self.window_left,
1807
+ window_right=self.window_right,
1808
+ )
1809
+
1810
+ def materialize(
1811
+ self,
1812
+ shape: Tuple[int, ...],
1813
+ dtype: torch.dtype = torch.float32,
1814
+ device: Union[str, torch.device] = "cpu",
1815
+ ) -> torch.Tensor:
1816
+ """Materialize the attention bias - for debugging & testing"""
1817
+ if shape[-1] != self.k_seqinfo.seqstart_py[-1]:
1818
+ raise ValueError("k shapes wrong", (shape, self.k_seqinfo))
1819
+ if shape[-2] != self.q_seqinfo.seqstart_py[-1]:
1820
+ raise ValueError("q shapes wrong", (shape, self.q_seqinfo))
1821
+ mask = torch.full(shape[-2:], fill_value=0, dtype=dtype, device=device)
1822
+ for (q_start, q_end), (k_start, k_end) in zip(
1823
+ self.q_seqinfo.intervals(),
1824
+ self.k_seqinfo.intervals(),
1825
+ ):
1826
+ mask[q_start:q_end, k_start:k_end] = 1
1827
+             # Restrict each block to the local attention window.
1828
+ mask[q_start:q_end, k_start:k_end] = _apply_locality_on_mask(
1829
+ mask[q_start:q_end, k_start:k_end], self.window_left, self.window_right
1830
+ )
1831
+ for _ in range(len(shape) - 2):
1832
+ mask = mask.unsqueeze(0)
1833
+ mask = mask.log()
1834
+ return mask.expand(shape)
1835
+
1836
+ @classmethod
1837
+ def from_seqlens_local_gappy(
1838
+ cls,
1839
+ q_seqlen: Sequence[int],
1840
+ kv_seqstarts: Sequence[int],
1841
+ kv_seqlen: Sequence[int],
1842
+ window_left: int,
1843
+ window_right: int,
1844
+ device: torch.device,
1845
+ ) -> "BlockDiagonalLocalAttentionFromBottomRightGappyKeysMask":
1846
+ assert kv_seqlen is None or len(q_seqlen) == len(kv_seqlen), (
1847
+ q_seqlen,
1848
+ kv_seqlen,
1849
+ )
1850
+ q_seqinfo = _SeqLenInfo.from_seqlens(q_seqlen)
1851
+ k_seqinfo = _GappySeqInfo.from_seqlens_gappy(
1852
+ kv_seqstarts, kv_seqlen, paged=False, device=device
1853
+ )
1854
+ return cls(
1855
+ q_seqinfo=q_seqinfo,
1856
+ k_seqinfo=k_seqinfo,
1857
+ window_left=window_left,
1858
+ window_right=window_right,
1859
+ )
1860
+
1861
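A sketch (not part of the diff) of the two-sided local window on gappy keys, using the same `kv_seqstarts` convention as the previous example:

```python
import torch

from mslk.attention.fmha.attn_bias import (
    BlockDiagonalLocalAttentionFromBottomRightGappyKeysMask,
)

bias = BlockDiagonalLocalAttentionFromBottomRightGappyKeysMask.from_seqlens_local_gappy(
    q_seqlen=[4, 3],
    kv_seqstarts=[0, 10, 20],
    kv_seqlen=[8, 9],
    window_left=3,
    window_right=1,
    device=torch.device("cpu"),
)
mask = bias.materialize(shape=(7, 20))  # log of the banded 0/1 block mask
```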
+
1862
+ @dataclass
1863
+ class BlockDiagonalCausalWithOffsetGappyKeysMask(BlockDiagonalGappyKeysMask):
1864
+ """
1865
+ Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask`,
1866
+ except k/v is gappy.
1867
+
1868
+ A query Q in block i cannot attend to a key which is not in block i,
1869
+ nor one which is nearer to the final key in block i
1870
+ than Q is to the final query in block i.
1871
+ """
1872
+
1873
+ def to(self, device: torch.device) -> "BlockDiagonalCausalWithOffsetGappyKeysMask":
1874
+ assert type(self) is BlockDiagonalCausalWithOffsetGappyKeysMask, (
1875
+ "Please implement in subclass"
1876
+ )
1877
+ return BlockDiagonalCausalWithOffsetGappyKeysMask(
1878
+ q_seqinfo=self.q_seqinfo.to(device),
1879
+ k_seqinfo=self.k_seqinfo.to(device),
1880
+ )
1881
+
1882
+ def materialize(
1883
+ self,
1884
+ shape: Tuple[int, ...],
1885
+ dtype: torch.dtype = torch.float32,
1886
+ device: Union[str, torch.device] = "cpu",
1887
+ ) -> torch.Tensor:
1888
+ """Materialize the attention bias - for debugging & testing"""
1889
+ if shape[-1] != self.k_seqinfo.seqstart_py[-1]:
1890
+ raise ValueError("k shapes wrong")
1891
+ if shape[-2] != self.q_seqinfo.seqstart_py[-1]:
1892
+ raise ValueError("q shapes wrong")
1893
+ mask = torch.empty(shape[-2:], dtype=dtype, device=device)
1894
+ mask.fill_(-math.inf)
1895
+ for (q_start, q_end), (k_start, k_end) in zip(
1896
+ self.q_seqinfo.intervals(),
1897
+ self.k_seqinfo.intervals(),
1898
+ ):
1899
+ mask[q_start:q_end, k_start:k_end] = (
1900
+ LowerTriangularFromBottomRightMask().materialize(
1901
+ shape=(q_end - q_start, k_end - k_start), dtype=dtype, device=device
1902
+ )
1903
+ )
1904
+
1905
+ for _ in range(len(shape) - 2):
1906
+ mask = mask.unsqueeze(0)
1907
+ return mask.expand(shape)
1908
+
1909
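A sketch (not part of the diff): the causal variant aligns the last key of each block with the last query of that block, so the last query attends to all of its block's keys and each earlier query drops the most recent keys accordingly:

```python
import torch

from mslk.attention.fmha.attn_bias import BlockDiagonalCausalWithOffsetGappyKeysMask

bias = BlockDiagonalCausalWithOffsetGappyKeysMask.from_seqlens(
    q_seqlen=[2, 3],
    kv_seqstarts=[0, 8, 16],
    kv_seqlen=[6, 7],
    device=torch.device("cpu"),
)
mask = bias.materialize(shape=(5, 16))
```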
+
1910
+ @dataclass
1911
+ class PagedBlockDiagonalGappyKeysMask(AttentionBias):
1912
+ """
1913
+     Same as BlockDiagonalGappyKeysMask, but for paged attention.
1914
+ block_tables has shape [batch_size, max_num_pages] and K/V have shape
1915
+ [1, max_num_pages * page_size, num_heads, head_dim]
1916
+ or [1, max_num_pages * page_size, num_groups, num_heads, head_dim]
1917
+ """
1918
+
1919
+ q_seqinfo: _SeqLenInfo
1920
+ k_seqinfo: _GappySeqInfo
1921
+ block_tables: torch.Tensor
1922
+ page_size: int
1923
+
1924
+ _UNPAGED_TYPE: ClassVar[Type[BlockDiagonalGappyKeysMask]] = (
1925
+ BlockDiagonalGappyKeysMask
1926
+ )
1927
+
1928
+ def to(self, device: torch.device) -> "PagedBlockDiagonalGappyKeysMask":
1929
+ assert type(self) is PagedBlockDiagonalGappyKeysMask, (
1930
+ "Please implement in subclass"
1931
+ )
1932
+ return PagedBlockDiagonalGappyKeysMask(
1933
+ q_seqinfo=self.q_seqinfo.to(device),
1934
+ k_seqinfo=self.k_seqinfo.to(device),
1935
+ block_tables=_to_device(self.block_tables, device),
1936
+ page_size=self.page_size,
1937
+ )
1938
+
1939
+ def materialize(
1940
+ self,
1941
+ shape: Tuple[int, ...],
1942
+ dtype: torch.dtype = torch.float32,
1943
+ device: Union[str, torch.device] = "cpu",
1944
+ ) -> torch.Tensor:
1945
+ """Materialize the attention bias - for debugging & testing"""
1946
+ # First create a non-paged mask, then cut individual pages and
1947
+ # copy them to their places in the physical mask, using block tables
1948
+
1949
+ max_row_len = self.block_tables.shape[1] * self.page_size
1950
+ new_seqstarts = [
1951
+ start + i * max_row_len
1952
+ for i, start in enumerate(self.k_seqinfo.seqstart_py[:-1])
1953
+ ] + [shape[-1]]
1954
+ new_seqlens = [
1955
+ end - start
1956
+ for start, end in zip(self.k_seqinfo.seqstart_py, self.k_seqinfo.seqlen_py)
1957
+ ]
1958
+ bias_nonpaged = self._UNPAGED_TYPE(
1959
+ q_seqinfo=self.q_seqinfo,
1960
+ k_seqinfo=_GappySeqInfo.from_seqlens_gappy(
1961
+ new_seqstarts,
1962
+ new_seqlens,
1963
+ False,
1964
+ device=torch.device(device),
1965
+ ),
1966
+ )
1967
+ mask_nonpaged = bias_nonpaged.materialize(shape, dtype, device)
1968
+
1969
+ n_used_blocks = cast(int, self.block_tables.max().item() + 1)
1970
+ max_physical_len = n_used_blocks * self.page_size
1971
+ mask_paged = torch.empty(
1972
+ mask_nonpaged.shape[:-1] + (max_physical_len,), dtype=dtype, device=device
1973
+ )
1974
+ mask_paged.fill_(-math.inf)
1975
+ for b, (q_start, q_end) in enumerate(self.q_seqinfo.intervals()):
1976
+ for logical_page_idx in range(self.block_tables.shape[1]):
1977
+ physical_page_idx = cast(
1978
+ int, self.block_tables[b][logical_page_idx].item()
1979
+ )
1980
+ k_logical_start = b * max_row_len + logical_page_idx * self.page_size
1981
+ k_logical_end = k_logical_start + self.page_size
1982
+ k_physical_start = physical_page_idx * self.page_size
1983
+ k_physical_end = k_physical_start + self.page_size
1984
+ mask_paged[..., q_start:q_end, k_physical_start:k_physical_end] = (
1985
+ mask_nonpaged[..., q_start:q_end, k_logical_start:k_logical_end]
1986
+ )
1987
+ return mask_paged
1988
+
1989
+ @classmethod
1990
+ def from_seqlens(
1991
+ cls,
1992
+ q_seqlen: Sequence[int],
1993
+ kv_seqstarts: Sequence[int],
1994
+ kv_seqlen: Sequence[int],
1995
+ block_tables: torch.Tensor,
1996
+ page_size: int,
1997
+ *,
1998
+ device: Optional[torch.device] = None,
1999
+ ) -> "PagedBlockDiagonalGappyKeysMask":
2000
+ """Creates a :attr:`PagedBlockDiagonalGappyKeysMask` from a list of tensor
2001
+ lengths for query and key/value.
2002
+
2003
+ Note that unlike :attr:`BlockDiagonalGappyKeysMask`, kv_seqstarts is
2004
+ addressing in a different space for each batch element. For example
2005
+ if you were doing a BlockDiagonalPaddedKeysMask with two batch
2006
+ elements and padding=100, but wanted to change it so that the first
2007
+ key is ignored, then you would use BlockDiagonalGappyKeysMask with kv_seqstarts
2008
+ [1, 101, 200]. But if you were using PagedBlockDiagonalPaddedKeysMask
2009
+ but wanted to ignore the first key, you would provide this function with
2010
+ kv_seqstarts = [1, 1].
2011
+ """
2012
+ assert len(q_seqlen) == len(kv_seqlen) == len(kv_seqstarts), (
2013
+ q_seqlen,
2014
+ kv_seqlen,
2015
+ kv_seqstarts,
2016
+ )
2017
+ device = block_tables.device if device is None else device
2018
+ q_seqinfo = _SeqLenInfo.from_seqlens(q_seqlen, device=device)
2019
+ k_seqinfo = _GappySeqInfo.from_seqlens_gappy(
2020
+ kv_seqstarts, kv_seqlen, True, device=device
2021
+ )
2022
+ return cls(
2023
+ q_seqinfo=q_seqinfo,
2024
+ k_seqinfo=k_seqinfo,
2025
+ block_tables=block_tables,
2026
+ page_size=page_size,
2027
+ )
2028
+
2029
+
2030
+ @dataclass
2031
+ class PagedBlockDiagonalCausalWithOffsetGappyKeysMask(PagedBlockDiagonalGappyKeysMask):
2032
+ """
2033
+ Same as BlockDiagonalCausalWithOffsetGappyKeysMask, but for paged attention.
2034
+ block_tables has shape [batch_size, max_num_pages] and K/V have shape
2035
+ [1, max_num_pages * page_size, num_heads, head_dim] or
2036
+ [1, max_num_pages * page_size, num_groups, num_heads, head_dim]
2037
+ """
2038
+
2039
+ _UNPAGED_TYPE = BlockDiagonalCausalWithOffsetGappyKeysMask
2040
+
2041
+ def to(
2042
+ self, device: torch.device
2043
+ ) -> "PagedBlockDiagonalCausalWithOffsetGappyKeysMask":
2044
+ assert type(self) is PagedBlockDiagonalCausalWithOffsetGappyKeysMask, (
2045
+ "Please implement in subclass"
2046
+ )
2047
+ return PagedBlockDiagonalCausalWithOffsetGappyKeysMask(
2048
+ q_seqinfo=self.q_seqinfo.to(device),
2049
+ k_seqinfo=self.k_seqinfo.to(device),
2050
+ block_tables=_to_device(self.block_tables, device),
2051
+ page_size=self.page_size,
2052
+ )
2053
+
2054
+
2055
+ @dataclass
2056
+ class BlockDiagonalCausalLocalAttentionMask(BlockDiagonalCausalMask):
2057
+ """
2058
+ (Experimental feature)
2059
+ Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask`.
2060
+ This makes the mask "local" and the attention pattern banded.
2061
+
2062
+ The ith query in a block only attends to keys in its block with index
2063
+ greater than i - window_size and less than or equal to i.
2064
+ """
2065
+
2066
+ _window_size: int = 0 # forced due to inheritance and default arguments
2067
+
2068
+ def to(self, device) -> "BlockDiagonalCausalLocalAttentionMask":
2069
+ assert type(self) is BlockDiagonalCausalLocalAttentionMask, (
2070
+ "Please implement in subclass"
2071
+ )
2072
+ return BlockDiagonalCausalLocalAttentionMask(
2073
+ q_seqinfo=self.q_seqinfo.to(device),
2074
+ k_seqinfo=self.k_seqinfo.to(device),
2075
+ _batch_sizes=self._batch_sizes,
2076
+ _window_size=self._window_size,
2077
+ )
2078
+
2079
+ def __post_init__(self):
2080
+ if self._window_size <= 0:
2081
+ raise ValueError(
2082
+ f"Expected `window_size > 0`, but window_size={self._window_size}"
2083
+ )
2084
+ q_seqlen = [
2085
+ y - x
2086
+ for x, y in zip(
2087
+ self.q_seqinfo.seqstart_py[:-1], self.q_seqinfo.seqstart_py[1:]
2088
+ )
2089
+ ]
2090
+ kv_seqlen = [
2091
+ y - x
2092
+ for x, y in zip(
2093
+ self.k_seqinfo.seqstart_py[:-1], self.k_seqinfo.seqstart_py[1:]
2094
+ )
2095
+ ]
2096
+ for q, k in zip(q_seqlen, kv_seqlen):
2097
+ if q - self._window_size >= k:
2098
+ # Each query only attends to keys no further than window_size back.
2099
+                 # When q >= k + window_size, there is a query for which the window doesn't reach any key.
2100
+ raise RuntimeError(
2101
+ f"No keys are attended in q_seqlen {q} k_seqlen {k} with sliding window {self._window_size}"
2102
+ )
2103
+
2104
+ def _create_block_mask(
2105
+ self,
2106
+ shape: Tuple[int, ...],
2107
+ dtype: torch.dtype = torch.float32,
2108
+ device: Union[str, torch.device] = "cpu",
2109
+ ) -> torch.Tensor:
2110
+ return _materialize_causal_mask(
2111
+ shape,
2112
+ dtype=dtype,
2113
+ device=device,
2114
+ window_size=self._window_size,
2115
+ )
2116
+
2117
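A sketch (not part of the diff). This class has no dedicated classmethod in the snippet above, so the example constructs it directly from the private `_SeqLenInfo` helper used elsewhere in this file; if this copy also keeps xformers' `BlockDiagonalCausalMask.make_local_attention(window_size)`, that is the usual way to obtain the same bias. Field defaults (e.g. `_batch_sizes`) are assumed to match upstream xformers:

```python
from mslk.attention.fmha.attn_bias import (
    BlockDiagonalCausalLocalAttentionMask,
    _SeqLenInfo,
)

q_seqlen = [5, 4]
bias = BlockDiagonalCausalLocalAttentionMask(
    q_seqinfo=_SeqLenInfo.from_seqlens(q_seqlen),
    k_seqinfo=_SeqLenInfo.from_seqlens(q_seqlen),  # self-attention: same lengths for K
    _window_size=3,  # query i attends to keys i-2..i within its block
)
mask = bias.materialize(shape=(9, 9))
```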
+
2118
+ @dataclass
2119
+ class BlockDiagonalCausalLocalAttentionFromBottomRightMask(
2120
+ BlockDiagonalCausalFromBottomRightMask
2121
+ ):
2122
+ """
2123
+ (Experimental feature)
2124
+ Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask`.
2125
+ This makes the mask "local" and the attention pattern banded.
2126
+
2127
+ A query with distance j from the last query in its block only attends to
2128
+ keys in the same block, and only those whose distance to the last key
2129
+ in the block is greater than or equal to j and less than window_size + j.
2130
+ """
2131
+
2132
+ _window_size: int = 0 # forced due to inheritance and default arguments
2133
+
2134
+ def to(self, device) -> "BlockDiagonalCausalLocalAttentionFromBottomRightMask":
2135
+ assert type(self) is BlockDiagonalCausalLocalAttentionFromBottomRightMask, (
2136
+ "Please implement in subclass"
2137
+ )
2138
+ return BlockDiagonalCausalLocalAttentionFromBottomRightMask(
2139
+ q_seqinfo=self.q_seqinfo.to(device),
2140
+ k_seqinfo=self.k_seqinfo.to(device),
2141
+ _batch_sizes=self._batch_sizes,
2142
+ _window_size=self._window_size,
2143
+ )
2144
+
2145
+ def __post_init__(self):
2146
+ super().__post_init__()
2147
+ if self._window_size <= 0:
2148
+ raise ValueError(
2149
+ f"Expected `window_size > 0`, but window_size={self._window_size}"
2150
+ )
2151
+
2152
+ def _create_block_mask(
2153
+ self,
2154
+ shape: Tuple[int, ...],
2155
+ dtype: torch.dtype = torch.float32,
2156
+ device: Union[str, torch.device] = "cpu",
2157
+ ) -> torch.Tensor:
2158
+ return _materialize_causal_mask(
2159
+ shape,
2160
+ dtype=dtype,
2161
+ device=device,
2162
+ window_size=self._window_size,
2163
+ from_bottomright=True,
2164
+ )
2165
+
2166
+
2167
+ torch._dynamo.allow_in_graph(LowerTriangularMask)
2168
+ torch._dynamo.allow_in_graph(LowerTriangularMaskWithTensorBias)
2169
+
2170
+ VARLEN_BIASES = (
2171
+ BlockDiagonalMask,
2172
+ BlockDiagonalGappyKeysMask,
2173
+ BlockDiagonalPaddedKeysMask,
2174
+ PagedBlockDiagonalPaddedKeysMask,
2175
+ PagedBlockDiagonalGappyKeysMask,
2176
+ )
2177
+
2178
+
2179
+ def _apply_locality_on_mask(mask: torch.Tensor, window_left: int, window_right: int):
2180
+ """Assumes 0s and 1s. If you need -inf and 1, take a log."""
2181
+ num_queries, num_keys = mask.shape[-2:]
2182
+ shift = num_keys - num_queries
2183
+
2184
+ mask = torch.triu(mask, diagonal=shift - window_left)
2185
+ mask = torch.tril(mask, diagonal=shift + window_right)
2186
+ return mask
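An illustrative check (not part of the diff) of the banding applied by `_apply_locality_on_mask` on a 4x6 block (shift = 2): `window_left=1`, `window_right=1` keeps window_left + window_right + 1 = 3 diagonals, clipped at the edges of the block:

```python
import torch

from mslk.attention.fmha.attn_bias import _apply_locality_on_mask

banded = _apply_locality_on_mask(torch.ones(4, 6), window_left=1, window_right=1)
print(banded)
# tensor([[0., 1., 1., 1., 0., 0.],
#         [0., 0., 1., 1., 1., 0.],
#         [0., 0., 0., 1., 1., 1.],
#         [0., 0., 0., 0., 1., 1.]])
```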