rc_foundry-0.1.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
rfd3/model/layers/encoders.py

@@ -0,0 +1,417 @@

```python
import functools
import logging

import torch
import torch.nn as nn
from rfd3.model.layers.block_utils import (
    bucketize_scaled_distogram,
    pairwise_mean_pool,
)
from rfd3.model.layers.blocks import (
    Downcast,
    LocalAtomTransformer,
    OneDFeatureEmbedder,
    PositionPairDistEmbedder,
    RelativePositionEncodingWithIndexRemoval,
    SinusoidalDistEmbed,
)
from rfd3.model.layers.chunked_pairwise import (
    ChunkedPairwiseEmbedder,
    ChunkedPositionPairDistEmbedder,
    ChunkedSinusoidalDistEmbed,
)
from rfd3.model.layers.layer_utils import (
    RMSNorm,
    Transition,
    linearNoBias,
)
from rfd3.model.layers.pairformer_layers import PairformerBlock

from foundry.common import exists
from foundry.training.checkpoint import activation_checkpointing

logger = logging.getLogger(__name__)


class TokenInitializer(nn.Module):
    """
    Token embedding module for RFD3
    """

    def __init__(
        self,
        c_s,
        c_z,
        c_atom,
        c_atompair,
        relative_position_encoding,
        n_pairformer_blocks,
        pairformer_block,
        downcast,
        token_1d_features,
        atom_1d_features,
        atom_transformer,
        use_chunked_pll=False,  # New parameter for memory optimization
    ):
        super().__init__()

        # Store chunked mode flag
        self.use_chunked_pll = use_chunked_pll

        # Features
        self.atom_1d_embedder_1 = OneDFeatureEmbedder(atom_1d_features, c_s)
        self.atom_1d_embedder_2 = OneDFeatureEmbedder(atom_1d_features, c_atom)
        self.token_1d_embedder = OneDFeatureEmbedder(token_1d_features, c_s)

        self.downcast_atom = Downcast(c_atom=c_s, c_token=c_s, c_s=None, **downcast)
        self.transition_post_token = Transition(c=c_s, n=2)
        self.transition_post_atom = Transition(c=c_s, n=2)
        self.process_s_init = nn.Sequential(
            RMSNorm(c_s),
            linearNoBias(c_s, c_s),
        )

        # Operations to mix into Z_II and S_I
        self.to_z_init_i = linearNoBias(c_s, c_z)
        self.to_z_init_j = linearNoBias(c_s, c_z)
        self.relative_position_encoding = RelativePositionEncodingWithIndexRemoval(
            c_z=c_z, **relative_position_encoding
        )
        self.relative_position_encoding2 = RelativePositionEncodingWithIndexRemoval(
            c_z=c_z, **relative_position_encoding
        )
        self.process_token_bonds = linearNoBias(1, c_z)

        # Processing of Z_init
        self.process_z_init = nn.Sequential(
            RMSNorm(c_z * 2),
            linearNoBias(c_z * 2, c_z),
        )
        self.transition_1 = nn.ModuleList(
            [
                Transition(c=c_z, n=2),
                Transition(c=c_z, n=2),
            ]
        )
        self.ref_pos_embedder_tok = PositionPairDistEmbedder(c_z, embed_frame=False)

        # Pairformer without triangle updates
        self.transformer_stack = nn.ModuleList(
            [
                PairformerBlock(c_s=c_s, c_z=c_z, **pairformer_block)
                for _ in range(n_pairformer_blocks)
            ]
        )

        #############################################################################
        # Token track processing
        self.process_s_trunk = nn.Sequential(RMSNorm(c_s), linearNoBias(c_s, c_atom))
        self.process_single_l = nn.Sequential(
            nn.ReLU(), linearNoBias(c_atom, c_atompair)
        )
        self.process_single_m = nn.Sequential(
            nn.ReLU(), linearNoBias(c_atom, c_atompair)
        )
        self.process_z = nn.Sequential(RMSNorm(c_z), linearNoBias(c_z, c_atompair))

        # ALWAYS create these MLPs - they will be shared between chunked and standard modes
        self.motif_pos_embedder = SinusoidalDistEmbed(c_atompair=c_atompair)
        self.ref_pos_embedder = PositionPairDistEmbedder(c_atompair, embed_frame=False)
        self.pair_mlp = nn.Sequential(
            nn.ReLU(),
            linearNoBias(c_atompair, c_atompair),
            nn.ReLU(),
            linearNoBias(c_atompair, c_atompair),
            nn.ReLU(),
            linearNoBias(c_atompair, c_atompair),
        )

        # Atom pair feature processing
        if self.use_chunked_pll:
            # Initialize chunked embedders and share the trained MLPs!
            self.chunked_pairwise_embedder = ChunkedPairwiseEmbedder(
                c_atompair=c_atompair,
                motif_pos_embedder=ChunkedSinusoidalDistEmbed(c_atompair=c_atompair),
                ref_pos_embedder=ChunkedPositionPairDistEmbedder(
                    c_atompair, embed_frame=False
                ),
                process_single_l=self.process_single_l,  # Share trained parameters!
                process_single_m=self.process_single_m,  # Share trained parameters!
                process_z=self.process_z,  # Share trained parameters!
                pair_mlp=self.pair_mlp,  # Share trained parameters!
            )
        self.process_pll = linearNoBias(c_atompair, c_atompair)
        self.project_pll = linearNoBias(c_atompair, c_z)

        if atom_transformer["n_blocks"] > 0:
            self.atom_transformer = LocalAtomTransformer(
                c_atom=c_atom, c_s=None, c_atompair=c_atompair, **atom_transformer
            )
        else:
            self.atom_transformer = None

        # Post-processing
        # self.process_s_post = nn.Sequential(
        #     RMSNorm(c_s),
        #     linearNoBias(c_s, c_s),
        # )
        # self.process_z_post = nn.Sequential(
        #     RMSNorm(c_z),
        #     linearNoBias(c_z, c_z),
        # )

    def forward(self, f):
        """
        Provides initial representation for atom and token representations
        """
        tok_idx = f["atom_to_token_map"]
        L = len(tok_idx)
        f["ref_atom_name_chars"] = f["ref_atom_name_chars"].reshape(L, -1)
        I = len(f["restype"])

        def init_tokens():
            # Embed token features
            S_I = self.token_1d_embedder(f, I)
            S_I = S_I + self.transition_post_token(S_I)

            # Embed atom features and downcast to token features
            S_I = self.downcast_atom(
                Q_L=self.atom_1d_embedder_1(f, L), A_I=S_I, tok_idx=tok_idx
            )
            S_I = S_I + self.transition_post_atom(S_I)
            S_I = self.process_s_init(S_I)

            # Embed Z_II
            Z_init_II = self.to_z_init_i(S_I).unsqueeze(-3) + self.to_z_init_j(
                S_I
            ).unsqueeze(-2)
            Z_init_II = Z_init_II + self.relative_position_encoding(f)
            Z_init_II = Z_init_II + self.process_token_bonds(
                f["token_bonds"].unsqueeze(-1).float()
            )

            # Embed reference coordinates of ligands
            token_id = f["ref_space_uid"][f["is_ca"]]
            valid_mask = (token_id.unsqueeze(-1) == token_id.unsqueeze(-2)).unsqueeze(
                -1
            )
            Z_init_II = Z_init_II + self.ref_pos_embedder_tok(
                f["ref_pos"][f["is_ca"]], valid_mask
            )

            # Run a small transformer to provide position encodings to single.
            for block in self.transformer_stack:
                S_I, Z_init_II = block(S_I, Z_init_II)

            # Also cat the relative position encoding and mix
            Z_init_II = torch.cat(
                [
                    Z_init_II,
                    self.relative_position_encoding2(f),
                ],
                dim=-1,
            )
            Z_init_II = self.process_z_init(Z_init_II)
            for b in range(2):
                Z_init_II = Z_init_II + self.transition_1[b](Z_init_II)

            return {"S_init_I": S_I, "Z_init_II": Z_init_II}

        @activation_checkpointing
        def init_atoms(S_init_I, Z_init_II):
            Q_L_init = self.atom_1d_embedder_2(f, L)
            C_L = Q_L_init + self.process_s_trunk(S_init_I)[..., tok_idx, :]

            if self.use_chunked_pll:
                # Chunked mode: return embedder for later sparse computation
                return {
                    "Q_L_init": Q_L_init,
                    "C_L": C_L,
                    "chunked_pairwise_embedder": self.chunked_pairwise_embedder,
                    "S_I": S_init_I,
                    "Z_II": Z_init_II,
                }
            else:
                # Original full P_LL computation
                ##################################################################################
                # Embed motif coordinates
                valid_mask = (
                    f["is_motif_atom_with_fixed_coord"].unsqueeze(-1)
                    & f["is_motif_atom_with_fixed_coord"].unsqueeze(-2)
                ).unsqueeze(-1)
                P_LL = self.motif_pos_embedder(
                    f["motif_pos"], valid_mask
                )  # (L, L, c_atompair)

                # Embed ref pos
                atoms_in_same_token = (
                    f["ref_space_uid"].unsqueeze(-1) == f["ref_space_uid"].unsqueeze(-2)
                ).unsqueeze(-1)
                # Only consider ref_pos for atoms given seq (otherwise ref_pos is 0, doesn't make sense to compute)
                atoms_has_seq = (
                    f["is_motif_atom_with_fixed_seq"].unsqueeze(-1)
                    & f["is_motif_atom_with_fixed_seq"].unsqueeze(-2)
                ).unsqueeze(-1)
                valid_mask = atoms_in_same_token & atoms_has_seq
                P_LL = P_LL + self.ref_pos_embedder(f["ref_pos"], valid_mask)

                ##################################################################################

                P_LL = P_LL + (
                    self.process_single_l(C_L).unsqueeze(-2)
                    + self.process_single_m(C_L).unsqueeze(-3)
                )
                P_LL = (
                    P_LL
                    + self.process_z(Z_init_II)[..., tok_idx, :, :][..., tok_idx, :]
                )
                P_LL = P_LL + self.pair_mlp(P_LL)
                P_LL = P_LL.contiguous()

                # Pool P_LL to token level to provide atom-level resolution for token track
                pooled_atom_level_features = pairwise_mean_pool(
                    pairwise_atom_features=self.process_pll(P_LL).unsqueeze(0),
                    atom_to_token_map=tok_idx,
                    I=int(tok_idx.max().item()) + 1,
                    dtype=P_LL.dtype,
                ).squeeze(0)
                Z_init_II = Z_init_II + self.project_pll(pooled_atom_level_features)

                # Mix atom conditioning features via sequence-local attention
                if exists(self.atom_transformer):
                    C_L = self.atom_transformer(
                        C_L.unsqueeze(0), None, P_LL, indices=None, f=f, X_L=None
                    ).squeeze(0)

                return {
                    "Q_L_init": Q_L_init,
                    "C_L": C_L,
                    "P_LL": P_LL,
                    "S_I": S_init_I,
                    "Z_II": Z_init_II,
                }

        tokens = init_tokens()
        return init_atoms(**tokens)
```
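Inside `init_atoms`, the atom-pair map `P_LL` is pooled back to token-pair resolution with `pairwise_mean_pool`, which is imported from `rfd3.model.layers.block_utils` and whose body is not shown in this excerpt. As a rough, illustrative sketch only (not the package's implementation), a block-wise mean pool over an `atom_to_token_map` could look like this:

```python
import torch
import torch.nn.functional as F


def pairwise_mean_pool_sketch(pairwise_atom_features, atom_to_token_map, I, dtype=torch.float32):
    """Average atom-pair features over each (token_i, token_j) block.

    pairwise_atom_features: (B, L, L, C) atom-pair features
    atom_to_token_map:      (L,) long tensor of token indices in [0, I)
    returns:                (B, I, I, C) token-pair features
    """
    one_hot = F.one_hot(atom_to_token_map, I).to(dtype)       # (L, I) membership matrix
    counts = one_hot.sum(dim=0)                                # atoms per token, (I,)
    # Sum atom-pair features into token-pair blocks, then divide by block size.
    pooled = torch.einsum("blmc,li,mj->bijc", pairwise_atom_features.to(dtype), one_hot, one_hot)
    denom = (counts[:, None] * counts[None, :]).clamp(min=1)   # (I, I) block sizes
    return pooled / denom[None, :, :, None]
```

`TokenInitializer` then projects the pooled features with `project_pll` and adds them back into `Z_init_II`, which is how atom-level geometry reaches the token track.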
```python
class DiffusionTokenEncoder(nn.Module):
    def __init__(
        self,
        c_s,
        c_z,
        c_token,
        c_atompair,
        sigma_data,
        n_pairformer_blocks,
        pairformer_block,
        use_distogram,
        use_self,
        use_sinusoidal_distogram_embedder=True,
        **_,
    ):
        super().__init__()

        # Sequence processing
        self.transition_1 = nn.ModuleList(
            [
                Transition(c=c_s, n=2),
                Transition(c=c_s, n=2),
            ]
        )

        # Post-processing of z
        self.n_bins_distogram = 65  # n bins for both self distogram and distogram
        n_bins_noise = self.n_bins_distogram
        self.use_self = use_self
        self.use_distogram = use_distogram
        self.use_sinusoidal_distogram_embedder = use_sinusoidal_distogram_embedder
        if self.use_distogram:
            if self.use_sinusoidal_distogram_embedder:
                self.dist_embedder = SinusoidalDistEmbed(c_atompair=c_z)
                n_bins_noise = c_z
            else:
                self.bucketize_fn = functools.partial(
                    bucketize_scaled_distogram,
                    min_dist=1,
                    max_dist=30,
                    sigma_data=sigma_data,
                    n_bins=self.n_bins_distogram,
                )
        cat_c_z = (
            c_z
            + int(self.use_distogram) * n_bins_noise
            + int(self.use_self) * self.n_bins_distogram
        )
        self.process_z = nn.Sequential(
            RMSNorm(cat_c_z),
            linearNoBias(cat_c_z, c_z),
        )

        self.transition_2 = nn.ModuleList(
            [
                Transition(c=c_z, n=2),
                Transition(c=c_z, n=2),
            ]
        )

        # Pairformer without triangle updates
        self.pairformer_stack = nn.ModuleList(
            [
                PairformerBlock(c_s=c_s, c_z=c_z, **pairformer_block)
                for _ in range(n_pairformer_blocks)
            ]
        )

    def forward(self, f, R_L, S_init_I, Z_init_II, C_L, P_LL, **kwargs):
        B = R_L.shape[0]
        """
        Pools atom-level features to token-level features and encodes them into Z_II, S_I and prepares A_I.
        """

        @activation_checkpointing
        def token_embed(S_init_I, Z_init_II):
            S_I = S_init_I
            for b in range(2):
                S_I = S_I + self.transition_1[b](S_I)

            Z_II = Z_init_II.unsqueeze(0).expand(B, -1, -1, -1)  # B, I, I, c_z

            Z_II_list = [Z_II]
            if self.use_distogram:
                # Noise / self conditioning pair
                if self.use_sinusoidal_distogram_embedder:
                    mask = f["is_motif_atom_with_fixed_coord"][f["is_ca"]]
                    mask = (mask[None, :] != mask[:, None]).unsqueeze(
                        -1
                    )  # remove off-diagonals where distances don't make sense across time
                    D_LL = self.dist_embedder(R_L[..., f["is_ca"], :], ~mask)
                else:
                    D_LL = self.bucketize_fn(
                        R_L[..., f["is_ca"], :]
                    )  # [B, L, I, n_bins]
                Z_II_list.append(D_LL)
            if self.use_self:
                D_II_self = kwargs.get("D_II_self")
                if D_II_self is None:
                    D_II_self = torch.zeros(
                        Z_II.shape[:-1] + (self.n_bins_distogram,),
                        device=Z_II.device,
                        dtype=Z_II.dtype,
                    )
                Z_II_list.append(D_II_self)
            Z_II = torch.cat(Z_II_list, dim=-1)

            # Flatten concatenated dims
            Z_II = self.process_z(Z_II)

            for b in range(2):
                Z_II = Z_II + self.transition_2[b](Z_II)

            # Pairformer to mix
            for block in self.pairformer_stack:
                S_I, Z_II = block(S_I, Z_II)

            return S_I, Z_II

        return token_embed(S_init_I, Z_init_II)
```
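In `DiffusionTokenEncoder.__init__` above, the width of the concatenated pair feature `cat_c_z` depends on which distogram channels are enabled. A small worked example with hypothetical channel sizes (the numbers below are made up, not package defaults):

```python
# Hypothetical sizes, for illustration only (not package defaults).
c_z = 128
use_distogram, use_sinusoidal_distogram_embedder, use_self = True, True, True
n_bins_distogram = 65
n_bins_noise = c_z if (use_distogram and use_sinusoidal_distogram_embedder) else n_bins_distogram
cat_c_z = c_z + int(use_distogram) * n_bins_noise + int(use_self) * n_bins_distogram
print(cat_c_z)  # 128 + 128 + 65 = 321; process_z then projects this back down to c_z
```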
rfd3/model/layers/layer_utils.py

@@ -0,0 +1,197 @@

```python
import math
from functools import partial

import numpy as np
import torch
import torch.nn as nn
from torch.nn.functional import silu

from foundry.training.checkpoint import activation_checkpointing
from foundry.utils.ddp import RankedLogger

ranked_logger = RankedLogger(__name__, rank_zero_only=True)
try:
    from apex.normalization.fused_layer_norm import FusedRMSNorm

    ranked_logger.info("Fused RMSNorm enabled!")
    RMSNorm_ = FusedRMSNorm
except (ImportError, ModuleNotFoundError):
    ranked_logger.warning(
        "Using nn.RMSNorm instead of apex.normalization.fused_layer_norm.FusedRMSNorm."
        "Ensure you're using the correct apptainer"
    )
    RMSNorm_ = nn.RMSNorm


# Allow bias=False to be passed for RMSNorm
def RMSNorm(*args, **kwargs):
    if "bias" in kwargs:
        kwargs.pop("bias")
    return RMSNorm_(*args, **kwargs)


SWAP_LAYER_NORM_FOR_RMS_NORM = True
RMSNorm = RMSNorm if SWAP_LAYER_NORM_FOR_RMS_NORM else nn.LayerNorm
linearNoBias = partial(torch.nn.Linear, bias=False)
```
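The `RMSNorm` factory above silently drops a `bias` keyword, so call sites written against `nn.LayerNorm` work unchanged whether the apex fused kernel or `nn.RMSNorm` is picked up. A minimal usage sketch (sizes are arbitrary):

```python
import torch

norm = RMSNorm(128, bias=False)  # bias is discarded; FusedRMSNorm if apex is installed, else nn.RMSNorm
proj = linearNoBias(128, 256)    # shorthand for nn.Linear(128, 256, bias=False)
x = torch.randn(4, 10, 128)
y = proj(norm(x))                # (4, 10, 256)
```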
```python
class EmbeddingLayer(nn.Linear):
    """
    Specialized linear layer for correct weight initialization for embedding layers.

    Embedding layers are functionally a multiplication of an N channel input by an NxC weight matrix to produce an
    embedding of length C. However, we compute the components separately with a ModuleDict, then sum at the end, for
    embedding reusability and interoperability purposes.

    This layer uses Xavier initialization as described in [1]_.

    References
    ----------
    .. [1] Glorot, Xavier, and Yoshua Bengio. "Understanding the difficulty
       of training deep feedforward neural networks." (2010)
       http://proceedings.mlr.press/v9/glorot10a.html
    """

    def __init__(
        self,
        this_in_features,
        total_embedding_features,
        out_features,
        device=None,
        dtype=None,
    ):
        self.total_embedding_features = total_embedding_features
        self.out_features = out_features
        super().__init__(
            this_in_features, out_features, bias=False, device=device, dtype=dtype
        )
        self.reset_parameters()

    def reset_parameters(self, **kwargs):
        super().reset_parameters()
        a = math.sqrt(6.0 / float(self.total_embedding_features + self.out_features))
        nn.init._no_grad_uniform_(self.weight, -a, a)


def collapse(x, L):
    return x.reshape((L, x.numel() // L))


class MultiDimLinear(nn.Linear):
    def __init__(self, in_features, out_shape, norm=False, **kwargs):
        self.out_shape = out_shape
        out_features = np.prod(out_shape)
        super().__init__(in_features, out_features, **kwargs)
        if norm:
            self.ln = RMSNorm((out_features,))
            self.use_ln = True
        else:
            self.use_ln = False
        self.reset_parameters()

    def reset_parameters(self, **kwargs) -> None:
        super().reset_parameters()
        nn.init.xavier_uniform_(self.weight)

    def forward(self, x):
        out = super().forward(x)
        if self.use_ln:
            out = self.ln(out)
        return out.reshape(x.shape[:-1] + self.out_shape)
```
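`MultiDimLinear` projects the last dimension and reshapes the result into a multi-dimensional output, e.g. per-head channels. A minimal usage sketch with made-up sizes:

```python
import torch

layer = MultiDimLinear(in_features=32, out_shape=(4, 8))  # 32 -> 4*8, viewed as (..., 4, 8)
x = torch.randn(2, 10, 32)
print(layer(x).shape)  # torch.Size([2, 10, 4, 8])
```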
```python
class LinearBiasInit(nn.Linear):
    def __init__(self, *args, biasinit, **kwargs):
        assert biasinit == -2.0  # Sanity check
        self.biasinit = biasinit
        super().__init__(*args, **kwargs)

    def reset_parameters(self) -> None:
        super().reset_parameters()
        self.bias.data.fill_(self.biasinit)


class Transition(nn.Module):
    def __init__(self, n, c):
        super().__init__()
        self.layer_norm_1 = RMSNorm(c)
        self.linear_1 = linearNoBias(c, n * c)
        self.linear_2 = linearNoBias(c, n * c)
        self.linear_3 = linearNoBias(n * c, c)

    @activation_checkpointing
    def forward(
        self,
        X,
    ):
        X = self.layer_norm_1(X)
        A = self.linear_1(X)
        B = self.linear_2(X)
        X = self.linear_3(silu(A) * B)
        return X
```
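`Transition` is a SwiGLU-style gated feed-forward block: two expansion projections, one passed through SiLU as the gate, followed by a contraction back to `c` channels, so the output shape matches the input and the block can be applied residually as in the encoders above. A minimal usage sketch with arbitrary sizes:

```python
import torch

t = Transition(c=64, n=2)  # expands 64 -> 128 inside the gate, contracts back to 64
x = torch.randn(2, 10, 64)
x = x + t(x)               # residual update; shape stays (2, 10, 64)
```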
```python
class AdaLN(nn.Module):
    def __init__(self, c_a, c_s, n=2):
        super().__init__()
        self.ln_a = RMSNorm(normalized_shape=(c_a,), elementwise_affine=False)
        self.ln_s = RMSNorm(normalized_shape=(c_s,), bias=False)
        self.to_gain = nn.Sequential(
            nn.Linear(c_s, c_a),
            nn.Sigmoid(),
        )
        self.to_bias = linearNoBias(c_s, c_a)

    def forward(
        self,
        Ai,  # [B, I, C_a]
        Si,  # [B, I, C_s]
    ):
        """
        Output:
            [B, I, C_a]
        """
        Ai = self.ln_a(Ai)
        Si = self.ln_s(Si)
        return self.to_gain(Si) * Ai + self.to_bias(Si)
```
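`AdaLN` is adaptive layer normalization: the conditioning track `Si` is normalized and mapped to a per-channel sigmoid gain and a bias, which modulate the normalized activations `Ai`. A minimal usage sketch (sizes are arbitrary):

```python
import torch

ada = AdaLN(c_a=64, c_s=32)
Ai = torch.randn(2, 10, 64)  # activations to modulate
Si = torch.randn(2, 10, 32)  # conditioning signal
out = ada(Ai, Si)            # (2, 10, 64): sigmoid gain * norm(Ai) + bias, both derived from Si
```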
```python
def create_batch_dimension_if_not_present(batched_n_dim):
    """
    Decorator for adapting a function which expects batched arguments with ndim `batched_n_dim` also
    accept unbatched arguments.
    """

    def wrap(f):
        def _wrap(arg):
            inserted_batch_dim = False
            if arg.ndim == batched_n_dim - 1:
                arg = arg[None]
                inserted_batch_dim = True
            elif arg.ndim == batched_n_dim:
                pass
            else:
                raise Exception(
                    f"arg must have {batched_n_dim - 1} or {batched_n_dim} dimensions, got shape {arg.shape=}"
                )
            o = f(arg)

            if inserted_batch_dim:
                assert o.shape[0] == 1, f"{o.shape=}[0] != 1"
                return o[0]
            return o

        return _wrap

    return wrap
```
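The decorator above lets a function written for batched inputs also accept a single unbatched example: a leading batch dimension is added on the way in and stripped again on the way out. A minimal usage sketch (the decorated function here is invented for illustration):

```python
import torch

@create_batch_dimension_if_not_present(batched_n_dim=3)
def center(x):  # expects (B, N, C)
    return x - x.mean(dim=-2, keepdim=True)

single = torch.randn(10, 3)  # unbatched (N, C)
out = center(single)         # runs as (1, 10, 3) internally, returned as (10, 3)
```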
```python
def unpack_args_for_checkpointing(arg_names):
    def wrap(f):
        def _wrap(*args):
            f = args[0]
            return f(**dict(zip(arg_names, args)))

        return _wrap

    return wrap
```