boltz_vsynthes-1.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- boltz/__init__.py +7 -0
- boltz/data/__init__.py +0 -0
- boltz/data/const.py +1184 -0
- boltz/data/crop/__init__.py +0 -0
- boltz/data/crop/affinity.py +164 -0
- boltz/data/crop/boltz.py +296 -0
- boltz/data/crop/cropper.py +45 -0
- boltz/data/feature/__init__.py +0 -0
- boltz/data/feature/featurizer.py +1230 -0
- boltz/data/feature/featurizerv2.py +2208 -0
- boltz/data/feature/symmetry.py +602 -0
- boltz/data/filter/__init__.py +0 -0
- boltz/data/filter/dynamic/__init__.py +0 -0
- boltz/data/filter/dynamic/date.py +76 -0
- boltz/data/filter/dynamic/filter.py +24 -0
- boltz/data/filter/dynamic/max_residues.py +37 -0
- boltz/data/filter/dynamic/resolution.py +34 -0
- boltz/data/filter/dynamic/size.py +38 -0
- boltz/data/filter/dynamic/subset.py +42 -0
- boltz/data/filter/static/__init__.py +0 -0
- boltz/data/filter/static/filter.py +26 -0
- boltz/data/filter/static/ligand.py +37 -0
- boltz/data/filter/static/polymer.py +299 -0
- boltz/data/module/__init__.py +0 -0
- boltz/data/module/inference.py +307 -0
- boltz/data/module/inferencev2.py +429 -0
- boltz/data/module/training.py +684 -0
- boltz/data/module/trainingv2.py +660 -0
- boltz/data/mol.py +900 -0
- boltz/data/msa/__init__.py +0 -0
- boltz/data/msa/mmseqs2.py +235 -0
- boltz/data/pad.py +84 -0
- boltz/data/parse/__init__.py +0 -0
- boltz/data/parse/a3m.py +134 -0
- boltz/data/parse/csv.py +100 -0
- boltz/data/parse/fasta.py +138 -0
- boltz/data/parse/mmcif.py +1239 -0
- boltz/data/parse/mmcif_with_constraints.py +1607 -0
- boltz/data/parse/schema.py +1851 -0
- boltz/data/parse/yaml.py +68 -0
- boltz/data/sample/__init__.py +0 -0
- boltz/data/sample/cluster.py +283 -0
- boltz/data/sample/distillation.py +57 -0
- boltz/data/sample/random.py +39 -0
- boltz/data/sample/sampler.py +49 -0
- boltz/data/tokenize/__init__.py +0 -0
- boltz/data/tokenize/boltz.py +195 -0
- boltz/data/tokenize/boltz2.py +396 -0
- boltz/data/tokenize/tokenizer.py +24 -0
- boltz/data/types.py +777 -0
- boltz/data/write/__init__.py +0 -0
- boltz/data/write/mmcif.py +305 -0
- boltz/data/write/pdb.py +171 -0
- boltz/data/write/utils.py +23 -0
- boltz/data/write/writer.py +330 -0
- boltz/main.py +1292 -0
- boltz/model/__init__.py +0 -0
- boltz/model/layers/__init__.py +0 -0
- boltz/model/layers/attention.py +132 -0
- boltz/model/layers/attentionv2.py +111 -0
- boltz/model/layers/confidence_utils.py +231 -0
- boltz/model/layers/dropout.py +34 -0
- boltz/model/layers/initialize.py +100 -0
- boltz/model/layers/outer_product_mean.py +98 -0
- boltz/model/layers/pair_averaging.py +135 -0
- boltz/model/layers/pairformer.py +337 -0
- boltz/model/layers/relative.py +58 -0
- boltz/model/layers/transition.py +78 -0
- boltz/model/layers/triangular_attention/__init__.py +0 -0
- boltz/model/layers/triangular_attention/attention.py +189 -0
- boltz/model/layers/triangular_attention/primitives.py +409 -0
- boltz/model/layers/triangular_attention/utils.py +380 -0
- boltz/model/layers/triangular_mult.py +212 -0
- boltz/model/loss/__init__.py +0 -0
- boltz/model/loss/bfactor.py +49 -0
- boltz/model/loss/confidence.py +590 -0
- boltz/model/loss/confidencev2.py +621 -0
- boltz/model/loss/diffusion.py +171 -0
- boltz/model/loss/diffusionv2.py +134 -0
- boltz/model/loss/distogram.py +48 -0
- boltz/model/loss/distogramv2.py +105 -0
- boltz/model/loss/validation.py +1025 -0
- boltz/model/models/__init__.py +0 -0
- boltz/model/models/boltz1.py +1286 -0
- boltz/model/models/boltz2.py +1249 -0
- boltz/model/modules/__init__.py +0 -0
- boltz/model/modules/affinity.py +223 -0
- boltz/model/modules/confidence.py +481 -0
- boltz/model/modules/confidence_utils.py +181 -0
- boltz/model/modules/confidencev2.py +495 -0
- boltz/model/modules/diffusion.py +844 -0
- boltz/model/modules/diffusion_conditioning.py +116 -0
- boltz/model/modules/diffusionv2.py +677 -0
- boltz/model/modules/encoders.py +639 -0
- boltz/model/modules/encodersv2.py +565 -0
- boltz/model/modules/transformers.py +322 -0
- boltz/model/modules/transformersv2.py +261 -0
- boltz/model/modules/trunk.py +688 -0
- boltz/model/modules/trunkv2.py +828 -0
- boltz/model/modules/utils.py +303 -0
- boltz/model/optim/__init__.py +0 -0
- boltz/model/optim/ema.py +389 -0
- boltz/model/optim/scheduler.py +99 -0
- boltz/model/potentials/__init__.py +0 -0
- boltz/model/potentials/potentials.py +497 -0
- boltz/model/potentials/schedules.py +32 -0
- boltz_vsynthes-1.0.0.dist-info/METADATA +151 -0
- boltz_vsynthes-1.0.0.dist-info/RECORD +112 -0
- boltz_vsynthes-1.0.0.dist-info/WHEEL +5 -0
- boltz_vsynthes-1.0.0.dist-info/entry_points.txt +2 -0
- boltz_vsynthes-1.0.0.dist-info/licenses/LICENSE +21 -0
- boltz_vsynthes-1.0.0.dist-info/top_level.txt +1 -0
boltz/model/layers/triangular_attention/attention.py
@@ -0,0 +1,189 @@

# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial, partialmethod
from typing import Optional

import torch
import torch.nn as nn

from boltz.model.layers.triangular_attention.primitives import (
    Attention,
    LayerNorm,
    Linear,
)
from boltz.model.layers.triangular_attention.utils import (
    chunk_layer,
    permute_final_dims,
)


class TriangleAttention(nn.Module):
    """Implement Algorithm 12."""

    def __init__(
        self,
        c_in: int,
        c_hidden: int,
        no_heads: int,
        starting: bool = True,
        inf: float = 1e9,
    ) -> None:
        super().__init__()

        self.c_in = c_in
        self.c_hidden = c_hidden
        self.no_heads = no_heads
        self.starting = starting
        self.inf = inf

        self.layer_norm = LayerNorm(self.c_in)

        self.linear = Linear(c_in, self.no_heads, bias=False, init="normal")

        self.mha = Attention(
            self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads
        )

    @torch.jit.ignore
    def _chunk(
        self,
        x: torch.Tensor,
        tri_bias: torch.Tensor,
        mask_bias: torch.Tensor,
        mask: torch.Tensor,
        chunk_size: int,
        use_kernels: bool = False,
    ) -> torch.Tensor:
        """Compute triangle attention in chunks.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape [*, I, J, C_in]
        tri_bias : torch.Tensor
            Triangular attention bias of shape [*, 1, H, I, J]
        mask_bias : torch.Tensor
            Additive mask bias of shape [*, I, 1, 1, J]
        mask : torch.Tensor
            Attention mask of shape [*, I, 1, 1, J]
        chunk_size : int
            Size of chunks for memory efficient computation
        use_kernels : bool, default=False
            Whether to use optimized CUDA kernels

        Returns
        -------
        torch.Tensor
            Output tensor of shape [*, I, J, C_in]

        """
        mha_inputs = {
            "q_x": x,
            "kv_x": x,
            "tri_bias": tri_bias,
            "mask_bias": mask_bias,
            "mask": mask,
        }

        return chunk_layer(
            partial(
                self.mha,
                use_kernels=use_kernels,
            ),
            mha_inputs,
            chunk_size=chunk_size,
            no_batch_dims=len(x.shape[:-2]),
            _out=None,
        )

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = None,
        use_kernels: bool = False,
    ) -> torch.Tensor:
        """Compute triangle attention.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape [*, I, J, C_in]
        mask : torch.Tensor, optional
            Attention mask of shape [*, I, J]
        chunk_size : int, optional
            Size of chunks for memory efficient computation
        use_kernels : bool, default=False
            Whether to use optimized CUDA kernels

        Returns
        -------
        torch.Tensor
            Output tensor of shape [*, I, J, C_in]

        """
        if mask is None:
            # [*, I, J]
            mask = x.new_ones(
                x.shape[:-1],
            )

        if not self.starting:
            x = x.transpose(-2, -3)
            mask = mask.transpose(-1, -2)

        # [*, I, J, C_in]
        x = self.layer_norm(x)

        # [*, I, 1, 1, J]
        mask = mask[..., :, None, None, :]
        mask_bias = self.inf * (mask - 1)

        # [*, H, I, J]
        triangle_bias = permute_final_dims(self.linear(x), (2, 0, 1))

        # [*, 1, H, I, J]
        triangle_bias = triangle_bias.unsqueeze(-4)

        if chunk_size is not None and not use_kernels:
            x = self._chunk(
                x,
                triangle_bias,
                mask_bias,
                mask,
                chunk_size,
                use_kernels=use_kernels,
            )
        else:
            x = self.mha(
                x,
                x,
                triangle_bias,
                mask_bias,
                mask,
                use_kernels=use_kernels,
            )

        if not self.starting:
            x = x.transpose(-2, -3)

        return x


# Implements Algorithm 13
TriangleAttentionStartingNode = TriangleAttention


class TriangleAttentionEndingNode(TriangleAttention):
    """Implement Algorithm 14."""

    __init__ = partialmethod(TriangleAttention.__init__, starting=False)
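For orientation, a minimal usage sketch of the module added above. This is not part of the package: the module path is inferred from the +189 line count in the listing, and running it assumes boltz is installed along with cuequivariance_torch, which primitives.py (next hunk) imports at module load. Tensor shapes follow the docstrings: x is [*, I, J, C_in], mask is [*, I, J].

import torch

from boltz.model.layers.triangular_attention.attention import (
    TriangleAttention,
    TriangleAttentionEndingNode,
)

pair = torch.randn(1, 32, 32, 64)   # pair representation, [batch, I, J, C_in]
mask = torch.ones(1, 32, 32)        # [batch, I, J], 1 = keep, 0 = mask out

attn_start = TriangleAttention(c_in=64, c_hidden=16, no_heads=4, starting=True)
attn_end = TriangleAttentionEndingNode(c_in=64, c_hidden=16, no_heads=4)

out = attn_start(pair, mask=mask, chunk_size=None, use_kernels=False)
out = attn_end(out, mask=mask)      # output keeps the [1, 32, 32, 64] shape

TriangleAttentionEndingNode only pre-binds starting=False, which transposes the two pair dimensions before and after attention, as the forward method above shows.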
boltz/model/layers/triangular_attention/primitives.py
@@ -0,0 +1,409 @@

# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Callable, List, Optional, Tuple

import torch
from cuequivariance_torch.primitives.triangle import triangle_attention
from einops import rearrange
from torch import nn

from boltz.model.layers import initialize
from boltz.model.layers.triangular_attention.utils import (
    flatten_final_dims,
    permute_final_dims,
)


class Linear(nn.Linear):
    """
    A Linear layer with built-in nonstandard initializations. Called just
    like torch.nn.Linear.

    Implements the initializers in 1.11.4, plus some additional ones found
    in the code.
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        bias: bool = True,
        init: str = "default",
        init_fn: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None,
        precision=None,
    ):
        """Initialize the linear layer.

        Parameters
        ----------
        in_dim : int
            The final dimension of inputs to the layer
        out_dim : int
            The final dimension of layer outputs
        bias : bool, default=True
            Whether to learn an additive bias
        init : str, default='default'
            The initializer to use. Choose from:

            - "default": LeCun fan-in truncated normal initialization
            - "relu": He initialization w/ truncated normal distribution
            - "glorot": Fan-average Glorot uniform initialization
            - "gating": Weights=0, Bias=1
            - "normal": Normal initialization with std=1/sqrt(fan_in)
            - "final": Weights=0, Bias=0

            Overridden by init_fn if the latter is not None.
        init_fn : callable, optional
            A custom initializer taking weight and bias as inputs.
            Overrides init if not None.

        """
        super().__init__(in_dim, out_dim, bias=bias)

        if bias:
            with torch.no_grad():
                self.bias.fill_(0)

        with torch.no_grad():
            if init_fn is not None:
                init_fn(self.weight, self.bias)
            else:
                if init == "default":
                    initialize.lecun_normal_init_(self.weight)
                elif init == "relu":
                    initialize.he_normal_init_(self.weight)
                elif init == "glorot":
                    initialize.glorot_uniform_init_(self.weight)
                elif init == "gating":
                    initialize.gating_init_(self.weight)
                    if bias:
                        self.bias.fill_(1.0)
                elif init == "normal":
                    initialize.normal_init_(self.weight)
                elif init == "final":
                    initialize.final_init_(self.weight)
                else:
                    raise ValueError("Invalid init string.")

        self.precision = precision

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        d = input.dtype
        if self.precision is not None:
            with torch.autocast("cuda", enabled=False):
                bias = (
                    self.bias.to(dtype=self.precision)
                    if self.bias is not None
                    else None
                )
                return nn.functional.linear(
                    input.to(dtype=self.precision),
                    self.weight.to(dtype=self.precision),
                    bias,
                ).to(dtype=d)

        if d is torch.bfloat16:
            with torch.autocast("cuda", enabled=False):
                bias = self.bias.to(dtype=d) if self.bias is not None else None
                return nn.functional.linear(input, self.weight.to(dtype=d), bias)

        return nn.functional.linear(input, self.weight, self.bias)


class LayerNorm(nn.Module):
    def __init__(self, c_in, eps=1e-5):
        super(LayerNorm, self).__init__()

        self.c_in = (c_in,)
        self.eps = eps

        self.weight = nn.Parameter(torch.ones(c_in))
        self.bias = nn.Parameter(torch.zeros(c_in))

    def forward(self, x):
        d = x.dtype
        if d is torch.bfloat16:
            with torch.autocast("cuda", enabled=False):
                out = nn.functional.layer_norm(
                    x,
                    self.c_in,
                    self.weight.to(dtype=d),
                    self.bias.to(dtype=d),
                    self.eps,
                )
        else:
            out = nn.functional.layer_norm(
                x,
                self.c_in,
                self.weight,
                self.bias,
                self.eps,
            )

        return out


@torch.jit.ignore
def softmax_no_cast(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """
    Softmax, but without automatic casting to fp32 when the input is of
    type bfloat16
    """
    d = t.dtype
    if d is torch.bfloat16:
        with torch.autocast("cuda", enabled=False):
            s = torch.nn.functional.softmax(t, dim=dim)
    else:
        s = torch.nn.functional.softmax(t, dim=dim)

    return s


# @torch.jit.script
def _attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    biases: List[torch.Tensor],
) -> torch.Tensor:
    # [*, H, C_hidden, K]
    key = permute_final_dims(key, (1, 0))

    # [*, H, Q, K]
    a = torch.matmul(query, key)

    for b in biases:
        a += b

    a = softmax_no_cast(a, -1)

    # [*, H, Q, C_hidden]
    a = torch.matmul(a, value)

    return a


@torch.compiler.disable
def kernel_triangular_attn(q, k, v, tri_bias, mask, scale):
    return triangle_attention(q, k, v, tri_bias, mask=mask, scale=scale)


class Attention(nn.Module):
    """
    Standard multi-head attention using AlphaFold's default layer
    initialization. Allows multiple bias vectors.
    """

    def __init__(
        self,
        c_q: int,
        c_k: int,
        c_v: int,
        c_hidden: int,
        no_heads: int,
        gating: bool = True,
    ):
        """Initialize the attention layer.

        Parameters
        ----------
        c_q : int
            Input dimension of query data
        c_k : int
            Input dimension of key data
        c_v : int
            Input dimension of value data
        c_hidden : int
            Per-head hidden dimension
        no_heads : int
            Number of attention heads
        gating : bool, default=True
            Whether the output should be gated using query data

        """
        super().__init__()

        self.c_q = c_q
        self.c_k = c_k
        self.c_v = c_v
        self.c_hidden = c_hidden
        self.no_heads = no_heads
        self.gating = gating

        # DISCREPANCY: c_hidden is not the per-head channel dimension, as
        # stated in the supplement, but the overall channel dimension.

        self.linear_q = Linear(
            self.c_q, self.c_hidden * self.no_heads, bias=False, init="glorot"
        )
        self.linear_k = Linear(
            self.c_k, self.c_hidden * self.no_heads, bias=False, init="glorot"
        )
        self.linear_v = Linear(
            self.c_v, self.c_hidden * self.no_heads, bias=False, init="glorot"
        )
        self.linear_o = Linear(
            self.c_hidden * self.no_heads, self.c_q, bias=False, init="final"
        )

        self.linear_g = None
        if self.gating:
            self.linear_g = Linear(
                self.c_q, self.c_hidden * self.no_heads, bias=False, init="gating"
            )

        self.sigmoid = nn.Sigmoid()

    def _prep_qkv(
        self, q_x: torch.Tensor, kv_x: torch.Tensor, apply_scale: bool = True
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # [*, Q/K/V, H * C_hidden]
        q = self.linear_q(q_x)
        k = self.linear_k(kv_x)
        v = self.linear_v(kv_x)

        # [*, Q/K, H, C_hidden]
        q = q.view(q.shape[:-1] + (self.no_heads, -1))
        k = k.view(k.shape[:-1] + (self.no_heads, -1))
        v = v.view(v.shape[:-1] + (self.no_heads, -1))

        # [*, H, Q/K, C_hidden]
        q = q.transpose(-2, -3)
        k = k.transpose(-2, -3)
        v = v.transpose(-2, -3)

        if apply_scale:
            q /= math.sqrt(self.c_hidden)

        return q, k, v

    def _wrap_up(self, o: torch.Tensor, q_x: torch.Tensor) -> torch.Tensor:
        if self.linear_g is not None:
            g = self.sigmoid(self.linear_g(q_x))

            # [*, Q, H, C_hidden]
            g = g.view(g.shape[:-1] + (self.no_heads, -1))
            o = o * g

        # [*, Q, H * C_hidden]
        o = flatten_final_dims(o, 2)

        # [*, Q, C_q]
        o = self.linear_o(o)

        return o

    def forward(
        self,
        q_x: torch.Tensor,
        kv_x: torch.Tensor,
        tri_bias: torch.Tensor,
        mask_bias: torch.Tensor,
        mask: torch.Tensor,
        use_kernels: bool = False,
    ) -> torch.Tensor:
        """Compute attention.

        Parameters
        ----------
        q_x : torch.Tensor
            [*, Q, C_q] query data
        kv_x : torch.Tensor
            [*, K, C_k] key data
        tri_bias : torch.Tensor
            [*, H, Q, K] triangular bias
        mask_bias : torch.Tensor
            [*, H, Q, K] mask bias
        mask : torch.Tensor
            [*, Q, K] mask
        use_kernels : bool, default=False
            Whether to use optimized CUDA kernels

        Returns
        -------
        torch.Tensor
            [*, Q, C_q] attention update

        """
        # Attention kernel applies scaling internally
        q, k, v = self._prep_qkv(
            q_x,
            kv_x,
            apply_scale=not use_kernels,
        )

        if use_kernels:
            scale = 1.0 / math.sqrt(self.c_hidden)
            o = kernel_triangular_attn(
                q,
                k,
                v,
                tri_bias=tri_bias,
                mask=mask.bool(),
                scale=scale,
            )
            o = o.transpose(-2, -3)
        else:
            biases = [mask_bias, tri_bias]
            o = _attention(q, k, v, biases)
            o = o.transpose(-2, -3)

        o = self._wrap_up(o, q_x)

        return o


def _trifast_attn(q, k, v, biases):
    orig_n_dims = len(q.shape)

    if len(biases) != 2:
        raise ValueError(f"Trifast expects two bias terms, found {len(biases)}")

    mask, b = biases

    if len(b.shape) == 5:
        # Sometimes there is an extra batch dim -- why?
        b = b.squeeze(1)

    if orig_n_dims == 4:
        # add fake batch dim
        q = q.unsqueeze(0)
        k = k.unsqueeze(0)
        v = v.unsqueeze(0)
        # b = b.unsqueeze(0) not sure why this and only this has a batch dim?
        mask = mask.unsqueeze(0)

    if len(q.shape) != 5:
        raise ValueError(f"Trifast expects q/k/v to be 5D, found {len(q.shape)}")

    # Reorder q/k/v
    q = rearrange(q, "b i h j d -> b h i j d")
    k = rearrange(k, "b i h j d -> b h i j d")
    v = rearrange(v, "b i h j d -> b h i j d")

    # Make mask the right shape.
    mask = rearrange(mask, "b i () () j -> b i j").bool()

    # Delay import to here to avoid initializing cuda too early
    from trifast import triangle_attention

    o = triangle_attention(q, k, v, b, mask)
    o = rearrange(o, "b h i j d -> b i j h d")

    # Remove the batch dim if we added it.
    if orig_n_dims == 4:
        o = o.squeeze(0)
    return o
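A second sketch, again not from the package: it drives Attention directly with biases built the same way TriangleAttention.forward builds them in the previous hunk, staying on the pure-PyTorch _attention branch (use_kernels=False) so neither the cuequivariance kernel nor trifast is invoked. The random tri_bias here merely stands in for the Linear projection that normally produces it; shapes and the 1e9 masking constant mirror the code above.

import torch

from boltz.model.layers.triangular_attention.primitives import Attention

B, I, J, C, H, D = 1, 8, 8, 32, 4, 8   # batch, rows, cols, pair channels, heads, head dim

x = torch.randn(B, I, J, C)            # pair activations, [*, I, J, C_in]
mask = torch.ones(B, I, J)             # [*, I, J], 1 = keep, 0 = mask out

# [*, I, 1, 1, J] additive mask bias, large negative where masked
mask_bias = 1e9 * (mask[..., :, None, None, :] - 1)

# [*, 1, H, I, J] triangular bias; random here for illustration only
tri_bias = torch.randn(B, 1, H, I, J)

mha = Attention(c_q=C, c_k=C, c_v=C, c_hidden=D, no_heads=H)
out = mha(x, x, tri_bias, mask_bias, mask, use_kernels=False)
assert out.shape == (B, I, J, C)       # attention update has the same shape as x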