boltz_vsynthes-1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. boltz/__init__.py +7 -0
  2. boltz/data/__init__.py +0 -0
  3. boltz/data/const.py +1184 -0
  4. boltz/data/crop/__init__.py +0 -0
  5. boltz/data/crop/affinity.py +164 -0
  6. boltz/data/crop/boltz.py +296 -0
  7. boltz/data/crop/cropper.py +45 -0
  8. boltz/data/feature/__init__.py +0 -0
  9. boltz/data/feature/featurizer.py +1230 -0
  10. boltz/data/feature/featurizerv2.py +2208 -0
  11. boltz/data/feature/symmetry.py +602 -0
  12. boltz/data/filter/__init__.py +0 -0
  13. boltz/data/filter/dynamic/__init__.py +0 -0
  14. boltz/data/filter/dynamic/date.py +76 -0
  15. boltz/data/filter/dynamic/filter.py +24 -0
  16. boltz/data/filter/dynamic/max_residues.py +37 -0
  17. boltz/data/filter/dynamic/resolution.py +34 -0
  18. boltz/data/filter/dynamic/size.py +38 -0
  19. boltz/data/filter/dynamic/subset.py +42 -0
  20. boltz/data/filter/static/__init__.py +0 -0
  21. boltz/data/filter/static/filter.py +26 -0
  22. boltz/data/filter/static/ligand.py +37 -0
  23. boltz/data/filter/static/polymer.py +299 -0
  24. boltz/data/module/__init__.py +0 -0
  25. boltz/data/module/inference.py +307 -0
  26. boltz/data/module/inferencev2.py +429 -0
  27. boltz/data/module/training.py +684 -0
  28. boltz/data/module/trainingv2.py +660 -0
  29. boltz/data/mol.py +900 -0
  30. boltz/data/msa/__init__.py +0 -0
  31. boltz/data/msa/mmseqs2.py +235 -0
  32. boltz/data/pad.py +84 -0
  33. boltz/data/parse/__init__.py +0 -0
  34. boltz/data/parse/a3m.py +134 -0
  35. boltz/data/parse/csv.py +100 -0
  36. boltz/data/parse/fasta.py +138 -0
  37. boltz/data/parse/mmcif.py +1239 -0
  38. boltz/data/parse/mmcif_with_constraints.py +1607 -0
  39. boltz/data/parse/schema.py +1851 -0
  40. boltz/data/parse/yaml.py +68 -0
  41. boltz/data/sample/__init__.py +0 -0
  42. boltz/data/sample/cluster.py +283 -0
  43. boltz/data/sample/distillation.py +57 -0
  44. boltz/data/sample/random.py +39 -0
  45. boltz/data/sample/sampler.py +49 -0
  46. boltz/data/tokenize/__init__.py +0 -0
  47. boltz/data/tokenize/boltz.py +195 -0
  48. boltz/data/tokenize/boltz2.py +396 -0
  49. boltz/data/tokenize/tokenizer.py +24 -0
  50. boltz/data/types.py +777 -0
  51. boltz/data/write/__init__.py +0 -0
  52. boltz/data/write/mmcif.py +305 -0
  53. boltz/data/write/pdb.py +171 -0
  54. boltz/data/write/utils.py +23 -0
  55. boltz/data/write/writer.py +330 -0
  56. boltz/main.py +1292 -0
  57. boltz/model/__init__.py +0 -0
  58. boltz/model/layers/__init__.py +0 -0
  59. boltz/model/layers/attention.py +132 -0
  60. boltz/model/layers/attentionv2.py +111 -0
  61. boltz/model/layers/confidence_utils.py +231 -0
  62. boltz/model/layers/dropout.py +34 -0
  63. boltz/model/layers/initialize.py +100 -0
  64. boltz/model/layers/outer_product_mean.py +98 -0
  65. boltz/model/layers/pair_averaging.py +135 -0
  66. boltz/model/layers/pairformer.py +337 -0
  67. boltz/model/layers/relative.py +58 -0
  68. boltz/model/layers/transition.py +78 -0
  69. boltz/model/layers/triangular_attention/__init__.py +0 -0
  70. boltz/model/layers/triangular_attention/attention.py +189 -0
  71. boltz/model/layers/triangular_attention/primitives.py +409 -0
  72. boltz/model/layers/triangular_attention/utils.py +380 -0
  73. boltz/model/layers/triangular_mult.py +212 -0
  74. boltz/model/loss/__init__.py +0 -0
  75. boltz/model/loss/bfactor.py +49 -0
  76. boltz/model/loss/confidence.py +590 -0
  77. boltz/model/loss/confidencev2.py +621 -0
  78. boltz/model/loss/diffusion.py +171 -0
  79. boltz/model/loss/diffusionv2.py +134 -0
  80. boltz/model/loss/distogram.py +48 -0
  81. boltz/model/loss/distogramv2.py +105 -0
  82. boltz/model/loss/validation.py +1025 -0
  83. boltz/model/models/__init__.py +0 -0
  84. boltz/model/models/boltz1.py +1286 -0
  85. boltz/model/models/boltz2.py +1249 -0
  86. boltz/model/modules/__init__.py +0 -0
  87. boltz/model/modules/affinity.py +223 -0
  88. boltz/model/modules/confidence.py +481 -0
  89. boltz/model/modules/confidence_utils.py +181 -0
  90. boltz/model/modules/confidencev2.py +495 -0
  91. boltz/model/modules/diffusion.py +844 -0
  92. boltz/model/modules/diffusion_conditioning.py +116 -0
  93. boltz/model/modules/diffusionv2.py +677 -0
  94. boltz/model/modules/encoders.py +639 -0
  95. boltz/model/modules/encodersv2.py +565 -0
  96. boltz/model/modules/transformers.py +322 -0
  97. boltz/model/modules/transformersv2.py +261 -0
  98. boltz/model/modules/trunk.py +688 -0
  99. boltz/model/modules/trunkv2.py +828 -0
  100. boltz/model/modules/utils.py +303 -0
  101. boltz/model/optim/__init__.py +0 -0
  102. boltz/model/optim/ema.py +389 -0
  103. boltz/model/optim/scheduler.py +99 -0
  104. boltz/model/potentials/__init__.py +0 -0
  105. boltz/model/potentials/potentials.py +497 -0
  106. boltz/model/potentials/schedules.py +32 -0
  107. boltz_vsynthes-1.0.0.dist-info/METADATA +151 -0
  108. boltz_vsynthes-1.0.0.dist-info/RECORD +112 -0
  109. boltz_vsynthes-1.0.0.dist-info/WHEEL +5 -0
  110. boltz_vsynthes-1.0.0.dist-info/entry_points.txt +2 -0
  111. boltz_vsynthes-1.0.0.dist-info/licenses/LICENSE +21 -0
  112. boltz_vsynthes-1.0.0.dist-info/top_level.txt +1 -0
boltz/model/layers/triangular_attention/attention.py
@@ -0,0 +1,189 @@
+# Copyright 2021 AlQuraishi Laboratory
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from functools import partial, partialmethod
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+from boltz.model.layers.triangular_attention.primitives import (
+    Attention,
+    LayerNorm,
+    Linear,
+)
+from boltz.model.layers.triangular_attention.utils import (
+    chunk_layer,
+    permute_final_dims,
+)
+
+
+class TriangleAttention(nn.Module):
+    """Implement Algorithm 12."""
+
+    def __init__(
+        self,
+        c_in: int,
+        c_hidden: int,
+        no_heads: int,
+        starting: bool = True,
+        inf: float = 1e9,
+    ) -> None:
+        super().__init__()
+
+        self.c_in = c_in
+        self.c_hidden = c_hidden
+        self.no_heads = no_heads
+        self.starting = starting
+        self.inf = inf
+
+        self.layer_norm = LayerNorm(self.c_in)
+
+        self.linear = Linear(c_in, self.no_heads, bias=False, init="normal")
+
+        self.mha = Attention(
+            self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads
+        )
+
+    @torch.jit.ignore
+    def _chunk(
+        self,
+        x: torch.Tensor,
+        tri_bias: torch.Tensor,
+        mask_bias: torch.Tensor,
+        mask: torch.Tensor,
+        chunk_size: int,
+        use_kernels: bool = False,
+    ) -> torch.Tensor:
+ """Compute triangle attention.
71
+
72
+ Parameters
73
+ ----------
74
+ x : torch.Tensor
75
+ Input tensor of shape [*, I, J, C_in]
76
+ biases : list[torch.Tensor]
77
+ List of bias tensors of shape [*, H, I, J]
78
+ chunk_size : int
79
+ Size of chunks for memory efficient computation
80
+ use_kernels : bool, default=False
81
+ Whether to use optimized CUDA kernels
82
+
83
+ Returns
84
+ -------
85
+ torch.Tensor
86
+ Output tensor of shape [*, I, J, C_in]
87
+
88
+ """
89
+        mha_inputs = {
+            "q_x": x,
+            "kv_x": x,
+            "tri_bias": tri_bias,
+            "mask_bias": mask_bias,
+            "mask": mask,
+        }
+
+        return chunk_layer(
+            partial(
+                self.mha,
+                use_kernels=use_kernels,
+            ),
+            mha_inputs,
+            chunk_size=chunk_size,
+            no_batch_dims=len(x.shape[:-2]),
+            _out=None,
+        )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+        chunk_size: Optional[int] = None,
+        use_kernels: bool = False,
+    ) -> torch.Tensor:
+        """Compute triangle attention.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input tensor of shape [*, I, J, C_in]
+        mask : torch.Tensor, optional
+            Attention mask of shape [*, I, J]
+        chunk_size : int, optional
+            Size of chunks for memory efficient computation
+        use_kernels : bool, default=False
+            Whether to use optimized CUDA kernels
+
+        Returns
+        -------
+        torch.Tensor
+            Output tensor of shape [*, I, J, C_in]
+
+        """
+        if mask is None:
+            # [*, I, J]
+            mask = x.new_ones(
+                x.shape[:-1],
+            )
+
+        if not self.starting:
+            x = x.transpose(-2, -3)
+            mask = mask.transpose(-1, -2)
+
+        # [*, I, J, C_in]
+        x = self.layer_norm(x)
+
+        # [*, I, 1, 1, J]
+        mask = mask[..., :, None, None, :]
+        mask_bias = self.inf * (mask - 1)
+
+        # [*, H, I, J]
+        triangle_bias = permute_final_dims(self.linear(x), (2, 0, 1))
+
+        # [*, 1, H, I, J]
+        triangle_bias = triangle_bias.unsqueeze(-4)
+
+        if chunk_size is not None and not use_kernels:
+            x = self._chunk(
+                x,
+                triangle_bias,
+                mask_bias,
+                mask,
+                chunk_size,
+                use_kernels=use_kernels,
+            )
+        else:
+            x = self.mha(
+                x,
+                x,
+                triangle_bias,
+                mask_bias,
+                mask,
+                use_kernels=use_kernels,
+            )
+
+        if not self.starting:
+            x = x.transpose(-2, -3)
+
+        return x
+
+
+# Implements Algorithm 13
+TriangleAttentionStartingNode = TriangleAttention
+
+
+class TriangleAttentionEndingNode(TriangleAttention):
+    """Implement Algorithm 14."""
+
+    __init__ = partialmethod(TriangleAttention.__init__, starting=False)
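
For orientation, a minimal usage sketch of the module above; it is not part of the packaged sources, the dimensions are illustrative, and it assumes the wheel and its dependencies (including cuequivariance_torch and einops) are installed. TriangleAttentionStartingNode and TriangleAttentionEndingNode differ only in the starting flag, which transposes the pair representation before and after attention.

import torch

from boltz.model.layers.triangular_attention.attention import (
    TriangleAttentionEndingNode,
    TriangleAttentionStartingNode,
)

# Pair representation: batch of 1, I = J = 32 tokens, C_in = 64 channels.
x = torch.randn(1, 32, 32, 64)
mask = torch.ones(1, 32, 32)

start = TriangleAttentionStartingNode(c_in=64, c_hidden=16, no_heads=4)
end = TriangleAttentionEndingNode(c_in=64, c_hidden=16, no_heads=4)

# Residual updates around the starting node, then the ending node.
# Pass chunk_size=... to trade speed for memory on large pair tensors.
x = x + start(x, mask=mask)
x = x + end(x, mask=mask)

print(x.shape)  # torch.Size([1, 32, 32, 64])

Chunking splits the attention call over sub-batches via chunk_layer to bound peak memory, while use_kernels=True routes to the cuequivariance_torch triangle kernel defined in primitives.py below.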
boltz/model/layers/triangular_attention/primitives.py
@@ -0,0 +1,409 @@
+# Copyright 2021 AlQuraishi Laboratory
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Callable, List, Optional, Tuple
+
+import torch
+from cuequivariance_torch.primitives.triangle import triangle_attention
+from einops import rearrange
+from torch import nn
+
+from boltz.model.layers import initialize
+from boltz.model.layers.triangular_attention.utils import (
+    flatten_final_dims,
+    permute_final_dims,
+)
+
+
+class Linear(nn.Linear):
+    """
+    A Linear layer with built-in nonstandard initializations. Called just
+    like torch.nn.Linear.
+
+    Implements the initializers in 1.11.4, plus some additional ones found
+    in the code.
+    """
+
+    def __init__(
+        self,
+        in_dim: int,
+        out_dim: int,
+        bias: bool = True,
+        init: str = "default",
+        init_fn: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None,
+        precision=None,
+    ):
+        """Initialize the linear layer.
+
+        Parameters
+        ----------
+        in_dim : int
+            The final dimension of inputs to the layer
+        out_dim : int
+            The final dimension of layer outputs
+        bias : bool, default=True
+            Whether to learn an additive bias
+        init : str, default='default'
+            The initializer to use. Choose from:
+
+            - "default": LeCun fan-in truncated normal initialization
+            - "relu": He initialization w/ truncated normal distribution
+            - "glorot": Fan-average Glorot uniform initialization
+            - "gating": Weights=0, Bias=1
+            - "normal": Normal initialization with std=1/sqrt(fan_in)
+            - "final": Weights=0, Bias=0
+
+            Overridden by init_fn if the latter is not None.
+        init_fn : callable, optional
+            A custom initializer taking weight and bias as inputs.
+            Overrides init if not None.
+
+        """
+        super().__init__(in_dim, out_dim, bias=bias)
+
+        if bias:
+            with torch.no_grad():
+                self.bias.fill_(0)
+
+        with torch.no_grad():
+            if init_fn is not None:
+                init_fn(self.weight, self.bias)
+            else:
+                if init == "default":
+                    initialize.lecun_normal_init_(self.weight)
+                elif init == "relu":
+                    initialize.he_normal_init_(self.weight)
+                elif init == "glorot":
+                    initialize.glorot_uniform_init_(self.weight)
+                elif init == "gating":
+                    initialize.gating_init_(self.weight)
+                    if bias:
+                        self.bias.fill_(1.0)
+                elif init == "normal":
+                    initialize.normal_init_(self.weight)
+                elif init == "final":
+                    initialize.final_init_(self.weight)
+                else:
+                    raise ValueError("Invalid init string.")
+
+        self.precision = precision
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        d = input.dtype
+        if self.precision is not None:
+            with torch.autocast("cuda", enabled=False):
+                bias = (
+                    self.bias.to(dtype=self.precision)
+                    if self.bias is not None
+                    else None
+                )
+                return nn.functional.linear(
+                    input.to(dtype=self.precision),
+                    self.weight.to(dtype=self.precision),
+                    bias,
+                ).to(dtype=d)
+
+        if d is torch.bfloat16:
+            with torch.autocast("cuda", enabled=False):
+                bias = self.bias.to(dtype=d) if self.bias is not None else None
+                return nn.functional.linear(input, self.weight.to(dtype=d), bias)
+
+        return nn.functional.linear(input, self.weight, self.bias)
+
+
+class LayerNorm(nn.Module):
+    def __init__(self, c_in, eps=1e-5):
+        super(LayerNorm, self).__init__()
+
+        self.c_in = (c_in,)
+        self.eps = eps
+
+        self.weight = nn.Parameter(torch.ones(c_in))
+        self.bias = nn.Parameter(torch.zeros(c_in))
+
+    def forward(self, x):
+        d = x.dtype
+        if d is torch.bfloat16:
+            with torch.autocast("cuda", enabled=False):
+                out = nn.functional.layer_norm(
+                    x,
+                    self.c_in,
+                    self.weight.to(dtype=d),
+                    self.bias.to(dtype=d),
+                    self.eps,
+                )
+        else:
+            out = nn.functional.layer_norm(
+                x,
+                self.c_in,
+                self.weight,
+                self.bias,
+                self.eps,
+            )
+
+        return out
+
+
+@torch.jit.ignore
+def softmax_no_cast(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
+    """
+    Softmax, but without automatic casting to fp32 when the input is of
+    type bfloat16
+    """
+    d = t.dtype
+    if d is torch.bfloat16:
+        with torch.autocast("cuda", enabled=False):
+            s = torch.nn.functional.softmax(t, dim=dim)
+    else:
+        s = torch.nn.functional.softmax(t, dim=dim)
+
+    return s
+
+
+# @torch.jit.script
+def _attention(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    biases: List[torch.Tensor],
+) -> torch.Tensor:
+    # [*, H, C_hidden, K]
+    key = permute_final_dims(key, (1, 0))
+
+    # [*, H, Q, K]
+    a = torch.matmul(query, key)
+
+    for b in biases:
+        a += b
+
+    a = softmax_no_cast(a, -1)
+
+    # [*, H, Q, C_hidden]
+    a = torch.matmul(a, value)
+
+    return a
+
+
+@torch.compiler.disable
+def kernel_triangular_attn(q, k, v, tri_bias, mask, scale):
+    return triangle_attention(q, k, v, tri_bias, mask=mask, scale=scale)
+
+
+class Attention(nn.Module):
+    """
+    Standard multi-head attention using AlphaFold's default layer
+    initialization. Allows multiple bias vectors.
+    """
+
+    def __init__(
+        self,
+        c_q: int,
+        c_k: int,
+        c_v: int,
+        c_hidden: int,
+        no_heads: int,
+        gating: bool = True,
+    ):
+        """Initialize the attention layer.
+
+        Parameters
+        ----------
+        c_q : int
+            Input dimension of query data
+        c_k : int
+            Input dimension of key data
+        c_v : int
+            Input dimension of value data
+        c_hidden : int
+            Per-head hidden dimension
+        no_heads : int
+            Number of attention heads
+        gating : bool, default=True
+            Whether the output should be gated using query data
+
+        """
+        super().__init__()
+
+        self.c_q = c_q
+        self.c_k = c_k
+        self.c_v = c_v
+        self.c_hidden = c_hidden
+        self.no_heads = no_heads
+        self.gating = gating
+
+        # DISCREPANCY: c_hidden is not the per-head channel dimension, as
+        # stated in the supplement, but the overall channel dimension.
+
+        self.linear_q = Linear(
+            self.c_q, self.c_hidden * self.no_heads, bias=False, init="glorot"
+        )
+        self.linear_k = Linear(
+            self.c_k, self.c_hidden * self.no_heads, bias=False, init="glorot"
+        )
+        self.linear_v = Linear(
+            self.c_v, self.c_hidden * self.no_heads, bias=False, init="glorot"
+        )
+        self.linear_o = Linear(
+            self.c_hidden * self.no_heads, self.c_q, bias=False, init="final"
+        )
+
+        self.linear_g = None
+        if self.gating:
+            self.linear_g = Linear(
+                self.c_q, self.c_hidden * self.no_heads, bias=False, init="gating"
+            )
+
+        self.sigmoid = nn.Sigmoid()
+
+    def _prep_qkv(
+        self, q_x: torch.Tensor, kv_x: torch.Tensor, apply_scale: bool = True
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        # [*, Q/K/V, H * C_hidden]
+        q = self.linear_q(q_x)
+        k = self.linear_k(kv_x)
+        v = self.linear_v(kv_x)
+
+        # [*, Q/K, H, C_hidden]
+        q = q.view(q.shape[:-1] + (self.no_heads, -1))
+        k = k.view(k.shape[:-1] + (self.no_heads, -1))
+        v = v.view(v.shape[:-1] + (self.no_heads, -1))
+
+        # [*, H, Q/K, C_hidden]
+        q = q.transpose(-2, -3)
+        k = k.transpose(-2, -3)
+        v = v.transpose(-2, -3)
+
+        if apply_scale:
+            q /= math.sqrt(self.c_hidden)
+
+        return q, k, v
+
+    def _wrap_up(self, o: torch.Tensor, q_x: torch.Tensor) -> torch.Tensor:
+        if self.linear_g is not None:
+            g = self.sigmoid(self.linear_g(q_x))
+
+            # [*, Q, H, C_hidden]
+            g = g.view(g.shape[:-1] + (self.no_heads, -1))
+            o = o * g
+
+        # [*, Q, H * C_hidden]
+        o = flatten_final_dims(o, 2)
+
+        # [*, Q, C_q]
+        o = self.linear_o(o)
+
+        return o
+
+    def forward(
+        self,
+        q_x: torch.Tensor,
+        kv_x: torch.Tensor,
+        tri_bias: torch.Tensor,
+        mask_bias: torch.Tensor,
+        mask: torch.Tensor,
+        use_kernels: bool = False,
+    ) -> torch.Tensor:
+        """Compute attention.
+
+        Parameters
+        ----------
+        q_x : torch.Tensor
+            [*, Q, C_q] query data
+        kv_x : torch.Tensor
+            [*, K, C_k] key data
+        tri_bias : torch.Tensor
+            [*, H, Q, K] triangular bias
+        mask_bias : torch.Tensor
+            [*, H, Q, K] mask bias
+        mask : torch.Tensor
+            [*, Q, K] mask
+        use_kernels : bool, default=False
+            Whether to use optimized CUDA kernels
+
+        Returns
+        -------
+        [*, Q, C_q] attention update
+
+        """
+        # Attention kernel applies scaling internally
+        q, k, v = self._prep_qkv(
+            q_x,
+            kv_x,
+            apply_scale=not use_kernels,
+        )
+
+        if use_kernels:
+            scale = 1.0 / math.sqrt(self.c_hidden)
+            o = kernel_triangular_attn(
+                q,
+                k,
+                v,
+                tri_bias=tri_bias,
+                mask=mask.bool(),
+                scale=scale,
+            )
+            o = o.transpose(-2, -3)
+        else:
+            biases = [mask_bias, tri_bias]
+            o = _attention(q, k, v, biases)
+            o = o.transpose(-2, -3)
+
+        o = self._wrap_up(o, q_x)
+
+        return o
+
+
+def _trifast_attn(q, k, v, biases):
+    orig_n_dims = len(q.shape)
+
+    if len(biases) != 2:
+        raise ValueError(f"Trifast expects two bias terms, found {len(biases)}")
+
+    mask, b = biases
+
+    if len(b.shape) == 5:
+        # Sometimes there is an extra batch dim -- why?
+        b = b.squeeze(1)
+
+    if orig_n_dims == 4:
+        # add fake batch dim
+        q = q.unsqueeze(0)
+        k = k.unsqueeze(0)
+        v = v.unsqueeze(0)
+        # b = b.unsqueeze(0) not sure why this and only this has a batch dim?
+        mask = mask.unsqueeze(0)
+
+    if len(q.shape) != 5:
+        raise ValueError(f"Trifast expects q/k/v to be 5D, found {len(q.shape)}")
+
+    # Reorder q/k/v
+    q = rearrange(q, "b i h j d -> b h i j d")
+    k = rearrange(k, "b i h j d -> b h i j d")
+    v = rearrange(v, "b i h j d -> b h i j d")
+
+    # Make mask the right shape.
+    mask = rearrange(mask, "b i () () j -> b i j").bool()
+
+    # Delay import to here to avoid initializing cuda too early
+    from trifast import triangle_attention
+
+    o = triangle_attention(q, k, v, b, mask)
+    o = rearrange(o, "b h i j d -> b i j h d")
+
+    # Remove the batch dim if we added it.
+    if orig_n_dims == 4:
+        o = o.squeeze(0)
+    return o
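
To make the bias plumbing concrete, here is a small sketch, not part of the packaged sources, that builds the two additive biases consumed by Attention.forward by hand, mirroring what TriangleAttention.forward does above; the shapes and sizes are illustrative, and it assumes the wheel and its dependencies are importable.

import torch

from boltz.model.layers.triangular_attention.primitives import Attention, Linear
from boltz.model.layers.triangular_attention.utils import permute_final_dims

c_in, c_hidden, no_heads = 32, 8, 4
x = torch.randn(1, 8, 8, c_in)        # pair activations, I = J = 8
mask = torch.ones(1, 8, 8)
mask[..., -2:] = 0                    # hide the last two columns

# Additive mask bias: 0 where kept, -1e9 where masked ([*, I, 1, 1, J]).
inf = 1e9
mask_5d = mask[..., :, None, None, :]
mask_bias = inf * (mask_5d - 1)

# Per-head triangular bias projected from the pair representation ([*, 1, H, I, J]).
tri_proj = Linear(c_in, no_heads, bias=False, init="normal")
tri_bias = permute_final_dims(tri_proj(x), (2, 0, 1)).unsqueeze(-4)

mha = Attention(c_in, c_in, c_in, c_hidden, no_heads)
out = mha(x, x, tri_bias, mask_bias, mask_5d)   # mask itself is only consumed on the kernel path
print(out.shape)  # torch.Size([1, 8, 8, 32])

On the non-kernel path the two biases are simply added to the pre-softmax logits in _attention, which is why masking is expressed as a large negative additive term rather than a boolean mask.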