boltz_vsynthes-1.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- boltz/__init__.py +7 -0
- boltz/data/__init__.py +0 -0
- boltz/data/const.py +1184 -0
- boltz/data/crop/__init__.py +0 -0
- boltz/data/crop/affinity.py +164 -0
- boltz/data/crop/boltz.py +296 -0
- boltz/data/crop/cropper.py +45 -0
- boltz/data/feature/__init__.py +0 -0
- boltz/data/feature/featurizer.py +1230 -0
- boltz/data/feature/featurizerv2.py +2208 -0
- boltz/data/feature/symmetry.py +602 -0
- boltz/data/filter/__init__.py +0 -0
- boltz/data/filter/dynamic/__init__.py +0 -0
- boltz/data/filter/dynamic/date.py +76 -0
- boltz/data/filter/dynamic/filter.py +24 -0
- boltz/data/filter/dynamic/max_residues.py +37 -0
- boltz/data/filter/dynamic/resolution.py +34 -0
- boltz/data/filter/dynamic/size.py +38 -0
- boltz/data/filter/dynamic/subset.py +42 -0
- boltz/data/filter/static/__init__.py +0 -0
- boltz/data/filter/static/filter.py +26 -0
- boltz/data/filter/static/ligand.py +37 -0
- boltz/data/filter/static/polymer.py +299 -0
- boltz/data/module/__init__.py +0 -0
- boltz/data/module/inference.py +307 -0
- boltz/data/module/inferencev2.py +429 -0
- boltz/data/module/training.py +684 -0
- boltz/data/module/trainingv2.py +660 -0
- boltz/data/mol.py +900 -0
- boltz/data/msa/__init__.py +0 -0
- boltz/data/msa/mmseqs2.py +235 -0
- boltz/data/pad.py +84 -0
- boltz/data/parse/__init__.py +0 -0
- boltz/data/parse/a3m.py +134 -0
- boltz/data/parse/csv.py +100 -0
- boltz/data/parse/fasta.py +138 -0
- boltz/data/parse/mmcif.py +1239 -0
- boltz/data/parse/mmcif_with_constraints.py +1607 -0
- boltz/data/parse/schema.py +1851 -0
- boltz/data/parse/yaml.py +68 -0
- boltz/data/sample/__init__.py +0 -0
- boltz/data/sample/cluster.py +283 -0
- boltz/data/sample/distillation.py +57 -0
- boltz/data/sample/random.py +39 -0
- boltz/data/sample/sampler.py +49 -0
- boltz/data/tokenize/__init__.py +0 -0
- boltz/data/tokenize/boltz.py +195 -0
- boltz/data/tokenize/boltz2.py +396 -0
- boltz/data/tokenize/tokenizer.py +24 -0
- boltz/data/types.py +777 -0
- boltz/data/write/__init__.py +0 -0
- boltz/data/write/mmcif.py +305 -0
- boltz/data/write/pdb.py +171 -0
- boltz/data/write/utils.py +23 -0
- boltz/data/write/writer.py +330 -0
- boltz/main.py +1292 -0
- boltz/model/__init__.py +0 -0
- boltz/model/layers/__init__.py +0 -0
- boltz/model/layers/attention.py +132 -0
- boltz/model/layers/attentionv2.py +111 -0
- boltz/model/layers/confidence_utils.py +231 -0
- boltz/model/layers/dropout.py +34 -0
- boltz/model/layers/initialize.py +100 -0
- boltz/model/layers/outer_product_mean.py +98 -0
- boltz/model/layers/pair_averaging.py +135 -0
- boltz/model/layers/pairformer.py +337 -0
- boltz/model/layers/relative.py +58 -0
- boltz/model/layers/transition.py +78 -0
- boltz/model/layers/triangular_attention/__init__.py +0 -0
- boltz/model/layers/triangular_attention/attention.py +189 -0
- boltz/model/layers/triangular_attention/primitives.py +409 -0
- boltz/model/layers/triangular_attention/utils.py +380 -0
- boltz/model/layers/triangular_mult.py +212 -0
- boltz/model/loss/__init__.py +0 -0
- boltz/model/loss/bfactor.py +49 -0
- boltz/model/loss/confidence.py +590 -0
- boltz/model/loss/confidencev2.py +621 -0
- boltz/model/loss/diffusion.py +171 -0
- boltz/model/loss/diffusionv2.py +134 -0
- boltz/model/loss/distogram.py +48 -0
- boltz/model/loss/distogramv2.py +105 -0
- boltz/model/loss/validation.py +1025 -0
- boltz/model/models/__init__.py +0 -0
- boltz/model/models/boltz1.py +1286 -0
- boltz/model/models/boltz2.py +1249 -0
- boltz/model/modules/__init__.py +0 -0
- boltz/model/modules/affinity.py +223 -0
- boltz/model/modules/confidence.py +481 -0
- boltz/model/modules/confidence_utils.py +181 -0
- boltz/model/modules/confidencev2.py +495 -0
- boltz/model/modules/diffusion.py +844 -0
- boltz/model/modules/diffusion_conditioning.py +116 -0
- boltz/model/modules/diffusionv2.py +677 -0
- boltz/model/modules/encoders.py +639 -0
- boltz/model/modules/encodersv2.py +565 -0
- boltz/model/modules/transformers.py +322 -0
- boltz/model/modules/transformersv2.py +261 -0
- boltz/model/modules/trunk.py +688 -0
- boltz/model/modules/trunkv2.py +828 -0
- boltz/model/modules/utils.py +303 -0
- boltz/model/optim/__init__.py +0 -0
- boltz/model/optim/ema.py +389 -0
- boltz/model/optim/scheduler.py +99 -0
- boltz/model/potentials/__init__.py +0 -0
- boltz/model/potentials/potentials.py +497 -0
- boltz/model/potentials/schedules.py +32 -0
- boltz_vsynthes-1.0.0.dist-info/METADATA +151 -0
- boltz_vsynthes-1.0.0.dist-info/RECORD +112 -0
- boltz_vsynthes-1.0.0.dist-info/WHEEL +5 -0
- boltz_vsynthes-1.0.0.dist-info/entry_points.txt +2 -0
- boltz_vsynthes-1.0.0.dist-info/licenses/LICENSE +21 -0
- boltz_vsynthes-1.0.0.dist-info/top_level.txt +1 -0
boltz/model/modules/transformers.py
@@ -0,0 +1,322 @@
# started from code from https://github.com/lucidrains/alphafold3-pytorch, MIT License, Copyright (c) 2024 Phil Wang

from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
from torch import nn, sigmoid
from torch.nn import (
    LayerNorm,
    Linear,
    Module,
    ModuleList,
    Sequential,
)

from boltz.model.layers.attention import AttentionPairBias
from boltz.model.modules.utils import LinearNoBias, SwiGLU, default


class AdaLN(Module):
    """Adaptive Layer Normalization"""

    def __init__(self, dim, dim_single_cond):
        """Initialize the adaptive layer normalization.

        Parameters
        ----------
        dim : int
            The input dimension.
        dim_single_cond : int
            The single condition dimension.

        """
        super().__init__()
        self.a_norm = LayerNorm(dim, elementwise_affine=False, bias=False)
        self.s_norm = LayerNorm(dim_single_cond, bias=False)
        self.s_scale = Linear(dim_single_cond, dim)
        self.s_bias = LinearNoBias(dim_single_cond, dim)

    def forward(self, a, s):
        a = self.a_norm(a)
        s = self.s_norm(s)
        a = sigmoid(self.s_scale(s)) * a + self.s_bias(s)
        return a


class ConditionedTransitionBlock(Module):
    """Conditioned Transition Block"""

    def __init__(self, dim_single, dim_single_cond, expansion_factor=2):
        """Initialize the conditioned transition block.

        Parameters
        ----------
        dim_single : int
            The single dimension.
        dim_single_cond : int
            The single condition dimension.
        expansion_factor : int, optional
            The expansion factor, by default 2

        """
        super().__init__()

        self.adaln = AdaLN(dim_single, dim_single_cond)

        dim_inner = int(dim_single * expansion_factor)
        self.swish_gate = Sequential(
            LinearNoBias(dim_single, dim_inner * 2),
            SwiGLU(),
        )
        self.a_to_b = LinearNoBias(dim_single, dim_inner)
        self.b_to_a = LinearNoBias(dim_inner, dim_single)

        output_projection_linear = Linear(dim_single_cond, dim_single)
        nn.init.zeros_(output_projection_linear.weight)
        nn.init.constant_(output_projection_linear.bias, -2.0)

        self.output_projection = nn.Sequential(output_projection_linear, nn.Sigmoid())

    def forward(
        self,
        a,
        s,
    ):
        a = self.adaln(a, s)
        b = self.swish_gate(a) * self.a_to_b(a)
        a = self.output_projection(s) * self.b_to_a(b)

        return a


class DiffusionTransformer(Module):
    """Diffusion Transformer"""

    def __init__(
        self,
        depth,
        heads,
        dim=384,
        dim_single_cond=None,
        dim_pairwise=128,
        activation_checkpointing=False,
        offload_to_cpu=False,
    ):
        """Initialize the diffusion transformer.

        Parameters
        ----------
        depth : int
            The depth.
        heads : int
            The number of heads.
        dim : int, optional
            The dimension, by default 384
        dim_single_cond : int, optional
            The single condition dimension, by default None
        dim_pairwise : int, optional
            The pairwise dimension, by default 128
        activation_checkpointing : bool, optional
            Whether to use activation checkpointing, by default False
        offload_to_cpu : bool, optional
            Whether to offload to CPU, by default False

        """
        super().__init__()
        self.activation_checkpointing = activation_checkpointing
        dim_single_cond = default(dim_single_cond, dim)

        self.layers = ModuleList()
        for _ in range(depth):
            if activation_checkpointing:
                self.layers.append(
                    checkpoint_wrapper(
                        DiffusionTransformerLayer(
                            heads,
                            dim,
                            dim_single_cond,
                            dim_pairwise,
                        ),
                        offload_to_cpu=offload_to_cpu,
                    )
                )
            else:
                self.layers.append(
                    DiffusionTransformerLayer(
                        heads,
                        dim,
                        dim_single_cond,
                        dim_pairwise,
                    )
                )

    def forward(
        self,
        a,
        s,
        z,
        mask=None,
        to_keys=None,
        multiplicity=1,
        model_cache=None,
    ):
        for i, layer in enumerate(self.layers):
            layer_cache = None
            if model_cache is not None:
                prefix_cache = "layer_" + str(i)
                if prefix_cache not in model_cache:
                    model_cache[prefix_cache] = {}
                layer_cache = model_cache[prefix_cache]
            a = layer(
                a,
                s,
                z,
                mask=mask,
                to_keys=to_keys,
                multiplicity=multiplicity,
                layer_cache=layer_cache,
            )
        return a


class DiffusionTransformerLayer(Module):
    """Diffusion Transformer Layer"""

    def __init__(
        self,
        heads,
        dim=384,
        dim_single_cond=None,
        dim_pairwise=128,
    ):
        """Initialize the diffusion transformer layer.

        Parameters
        ----------
        heads : int
            The number of heads.
        dim : int, optional
            The dimension, by default 384
        dim_single_cond : int, optional
            The single condition dimension, by default None
        dim_pairwise : int, optional
            The pairwise dimension, by default 128

        """
        super().__init__()

        dim_single_cond = default(dim_single_cond, dim)

        self.adaln = AdaLN(dim, dim_single_cond)

        self.pair_bias_attn = AttentionPairBias(
            c_s=dim, c_z=dim_pairwise, num_heads=heads, initial_norm=False
        )

        self.output_projection_linear = Linear(dim_single_cond, dim)
        nn.init.zeros_(self.output_projection_linear.weight)
        nn.init.constant_(self.output_projection_linear.bias, -2.0)

        self.output_projection = nn.Sequential(
            self.output_projection_linear, nn.Sigmoid()
        )
        self.transition = ConditionedTransitionBlock(
            dim_single=dim, dim_single_cond=dim_single_cond
        )

    def forward(
        self,
        a,
        s,
        z,
        mask=None,
        to_keys=None,
        multiplicity=1,
        layer_cache=None,
    ):
        b = self.adaln(a, s)
        b = self.pair_bias_attn(
            s=b,
            z=z,
            mask=mask,
            multiplicity=multiplicity,
            to_keys=to_keys,
            model_cache=layer_cache,
        )
        b = self.output_projection(s) * b

        # NOTE: Added residual connection!
        a = a + b
        a = a + self.transition(a, s)
        return a


class AtomTransformer(Module):
    """Atom Transformer"""

    def __init__(
        self,
        attn_window_queries=None,
        attn_window_keys=None,
        **diffusion_transformer_kwargs,
    ):
        """Initialize the atom transformer.

        Parameters
        ----------
        attn_window_queries : int, optional
            The attention window queries, by default None
        attn_window_keys : int, optional
            The attention window keys, by default None
        diffusion_transformer_kwargs : dict
            The diffusion transformer keyword arguments

        """
        super().__init__()
        self.attn_window_queries = attn_window_queries
        self.attn_window_keys = attn_window_keys
        self.diffusion_transformer = DiffusionTransformer(
            **diffusion_transformer_kwargs
        )

    def forward(
        self,
        q,
        c,
        p,
        to_keys=None,
        mask=None,
        multiplicity=1,
        model_cache=None,
    ):
        W = self.attn_window_queries
        H = self.attn_window_keys

        if W is not None:
            B, N, D = q.shape
            NW = N // W

            # reshape tokens
            q = q.view((B * NW, W, -1))
            c = c.view((B * NW, W, -1))
            if mask is not None:
                mask = mask.view(B * NW, W)
            p = p.view((p.shape[0] * NW, W, H, -1))

            to_keys_new = lambda x: to_keys(x.view(B, NW * W, -1)).view(B * NW, H, -1)
        else:
            to_keys_new = None

        # main transformer
        q = self.diffusion_transformer(
            a=q,
            s=c,
            z=p,
            mask=mask.float(),
            multiplicity=multiplicity,
            to_keys=to_keys_new,
            model_cache=model_cache,
        )

        if W is not None:
            q = q.view((B, NW * W, D))

        return q
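The sketch below is not part of the wheel; it is a minimal, illustrative instantiation of the `DiffusionTransformer` from `boltz/model/modules/transformers.py` above. It assumes the package is installed, that `AttentionPairBias` (from `boltz.model.layers.attention`) accepts a pairwise tensor `z` of shape `(batch, n, n, dim_pairwise)`, and that the windowless default path (`to_keys=None`, no cache) is supported; all sizes are arbitrary.

```python
import torch

from boltz.model.modules.transformers import DiffusionTransformer

# Hypothetical sizes, chosen only for illustration.
model = DiffusionTransformer(depth=2, heads=4, dim=64, dim_pairwise=16)

a = torch.randn(1, 8, 64)     # per-token activations being denoised
s = torch.randn(1, 8, 64)     # single (per-token) conditioning
z = torch.randn(1, 8, 8, 16)  # pairwise conditioning (assumed shape)
mask = torch.ones(1, 8)       # all tokens valid

out = model(a, s, z, mask=mask)  # expected to keep the shape of `a`
```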
boltz/model/modules/transformersv2.py
@@ -0,0 +1,261 @@
# started from code from https://github.com/lucidrains/alphafold3-pytorch, MIT License, Copyright (c) 2024 Phil Wang

import torch
from torch import nn, sigmoid
from torch.nn import (
    LayerNorm,
    Linear,
    Module,
    ModuleList,
    Sequential,
)

from boltz.model.layers.attentionv2 import AttentionPairBias
from boltz.model.modules.utils import LinearNoBias, SwiGLU, default


class AdaLN(Module):
    """Algorithm 26"""

    def __init__(self, dim, dim_single_cond):
        super().__init__()
        self.a_norm = LayerNorm(dim, elementwise_affine=False, bias=False)
        self.s_norm = LayerNorm(dim_single_cond, bias=False)
        self.s_scale = Linear(dim_single_cond, dim)
        self.s_bias = LinearNoBias(dim_single_cond, dim)

    def forward(self, a, s):
        a = self.a_norm(a)
        s = self.s_norm(s)
        a = sigmoid(self.s_scale(s)) * a + self.s_bias(s)
        return a


class ConditionedTransitionBlock(Module):
    """Algorithm 25"""

    def __init__(self, dim_single, dim_single_cond, expansion_factor=2):
        super().__init__()

        self.adaln = AdaLN(dim_single, dim_single_cond)

        dim_inner = int(dim_single * expansion_factor)
        self.swish_gate = Sequential(
            LinearNoBias(dim_single, dim_inner * 2),
            SwiGLU(),
        )
        self.a_to_b = LinearNoBias(dim_single, dim_inner)
        self.b_to_a = LinearNoBias(dim_inner, dim_single)

        output_projection_linear = Linear(dim_single_cond, dim_single)
        nn.init.zeros_(output_projection_linear.weight)
        nn.init.constant_(output_projection_linear.bias, -2.0)

        self.output_projection = nn.Sequential(output_projection_linear, nn.Sigmoid())

    def forward(
        self,
        a,  # Float['... d']
        s,
    ):  # -> Float['... d']:
        a = self.adaln(a, s)
        b = self.swish_gate(a) * self.a_to_b(a)
        a = self.output_projection(s) * self.b_to_a(b)

        return a


class DiffusionTransformer(Module):
    """Algorithm 23"""

    def __init__(
        self,
        depth,
        heads,
        dim=384,
        dim_single_cond=None,
        pair_bias_attn=True,
        activation_checkpointing=False,
        post_layer_norm=False,
    ):
        super().__init__()
        self.activation_checkpointing = activation_checkpointing
        dim_single_cond = default(dim_single_cond, dim)
        self.pair_bias_attn = pair_bias_attn

        self.layers = ModuleList()
        for _ in range(depth):
            self.layers.append(
                DiffusionTransformerLayer(
                    heads,
                    dim,
                    dim_single_cond,
                    post_layer_norm,
                )
            )

    def forward(
        self,
        a,  # Float['bm n d'],
        s,  # Float['bm n ds'],
        bias=None,  # Float['b n n dp']
        mask=None,  # Bool['b n'] | None = None
        to_keys=None,
        multiplicity=1,
    ):
        if self.pair_bias_attn:
            B, N, M, D = bias.shape
            L = len(self.layers)
            bias = bias.view(B, N, M, L, D // L)

        for i, layer in enumerate(self.layers):
            if self.pair_bias_attn:
                bias_l = bias[:, :, :, i]
            else:
                bias_l = None

            if self.activation_checkpointing and self.training:
                a = torch.utils.checkpoint.checkpoint(
                    layer,
                    a,
                    s,
                    bias_l,
                    mask,
                    to_keys,
                    multiplicity,
                )

            else:
                a = layer(
                    a,  # Float['bm n d'],
                    s,  # Float['bm n ds'],
                    bias_l,  # Float['b n n dp']
                    mask,  # Bool['b n'] | None = None
                    to_keys,
                    multiplicity,
                )
        return a


class DiffusionTransformerLayer(Module):
    """Algorithm 23"""

    def __init__(
        self,
        heads,
        dim=384,
        dim_single_cond=None,
        post_layer_norm=False,
    ):
        super().__init__()

        dim_single_cond = default(dim_single_cond, dim)

        self.adaln = AdaLN(dim, dim_single_cond)
        self.pair_bias_attn = AttentionPairBias(
            c_s=dim, num_heads=heads, compute_pair_bias=False
        )

        self.output_projection_linear = Linear(dim_single_cond, dim)
        nn.init.zeros_(self.output_projection_linear.weight)
        nn.init.constant_(self.output_projection_linear.bias, -2.0)

        self.output_projection = nn.Sequential(
            self.output_projection_linear, nn.Sigmoid()
        )
        self.transition = ConditionedTransitionBlock(
            dim_single=dim, dim_single_cond=dim_single_cond
        )

        if post_layer_norm:
            self.post_lnorm = nn.LayerNorm(dim)
        else:
            self.post_lnorm = nn.Identity()

    def forward(
        self,
        a,  # Float['bm n d'],
        s,  # Float['bm n ds'],
        bias=None,  # Float['b n n dp']
        mask=None,  # Bool['b n'] | None = None
        to_keys=None,
        multiplicity=1,
    ):
        b = self.adaln(a, s)

        k_in = b
        if to_keys is not None:
            k_in = to_keys(b)
            mask = to_keys(mask.unsqueeze(-1)).squeeze(-1)

        if self.pair_bias_attn:
            b = self.pair_bias_attn(
                s=b,
                z=bias,
                mask=mask,
                multiplicity=multiplicity,
                k_in=k_in,
            )
        else:
            b = self.no_pair_bias_attn(s=b, mask=mask, k_in=k_in)

        b = self.output_projection(s) * b

        a = a + b
        a = a + self.transition(a, s)

        a = self.post_lnorm(a)
        return a


class AtomTransformer(Module):
    """Algorithm 7"""

    def __init__(
        self,
        attn_window_queries,
        attn_window_keys,
        **diffusion_transformer_kwargs,
    ):
        super().__init__()
        self.attn_window_queries = attn_window_queries
        self.attn_window_keys = attn_window_keys
        self.diffusion_transformer = DiffusionTransformer(
            **diffusion_transformer_kwargs
        )

    def forward(
        self,
        q,  # Float['b m d'],
        c,  # Float['b m ds'],
        bias,  # Float['b m m dp']
        to_keys,
        mask,  # Bool['b m'] | None = None
        multiplicity=1,
    ):
        W = self.attn_window_queries
        H = self.attn_window_keys

        B, N, D = q.shape
        NW = N // W

        # reshape tokens
        q = q.view((B * NW, W, -1))
        c = c.view((B * NW, W, -1))
        mask = mask.view(B * NW, W)
        bias = bias.view((bias.shape[0] * NW, W, H, -1))

        to_keys_new = lambda x: to_keys(x.view(B, NW * W, -1)).view(B * NW, H, -1)

        # main transformer
        q = self.diffusion_transformer(
            a=q,
            s=c,
            bias=bias,
            mask=mask.float(),
            multiplicity=multiplicity,
            to_keys=to_keys_new,
        )

        q = q.view((B, NW * W, D))
        return q
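As with the first file, the following is only an illustrative sketch, not code shipped in the wheel. It assumes the packed pair bias passed to the v2 `DiffusionTransformer` has last dimension `depth * heads`, so that each layer's slice (`D // depth`) matches the per-head bias expected by `attentionv2.AttentionPairBias` with `compute_pair_bias=False`; that packing convention is an assumption, as are the tensor sizes.

```python
import torch

from boltz.model.modules.transformersv2 import DiffusionTransformer

depth, heads = 2, 4
model = DiffusionTransformer(depth=depth, heads=heads, dim=64)

a = torch.randn(1, 8, 64)                   # per-token activations
s = torch.randn(1, 8, 64)                   # single conditioning
bias = torch.randn(1, 8, 8, depth * heads)  # packed pair bias, one heads-wide slice per layer (assumed)
mask = torch.ones(1, 8)

out = model(a, s, bias=bias, mask=mask)  # expected to keep the shape of `a`
```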