rc-foundry 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
rfd3/transforms/conditioning_base.py
@@ -0,0 +1,508 @@
"""
Contains (a) global conditioning syntax and (b) transforms for the conditioning pipeline.

Conditioning pipeline:
    inference --- create_atom_array_from_design_specification ---|
                                                                 |---> CreateConditionedArray
    training --- SampleConditioningFlags ---|
"""

import ast
import copy
import logging

import biotite.structure as struc
import hydra
import networkx as nx
import numpy as np
from atomworks.ml.transforms._checks import (
    check_atom_array_annotation,
    check_contains_keys,
    check_is_instance,
)
from atomworks.ml.transforms.atom_array import (
    add_global_token_id_annotation,
    add_protein_termini_annotation,
)
from atomworks.ml.transforms.base import Transform
from atomworks.ml.utils.token import (
    apply_and_spread_token_wise,
    get_token_count,
    get_token_starts,
)
from biotite.structure import AtomArray
from rfd3.constants import (
    OPTIONAL_CONDITIONING_VALUES,
    REQUIRED_CONDITIONING_ANNOTATIONS,
)
from rfd3.transforms.conditioning_utils import random_condition
from rfd3.transforms.util_transforms import (
    add_representative_atom,
)

from foundry.common import exists

nx.from_numpy_matrix = nx.from_numpy_array
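# (Compatibility shim: nx.from_numpy_matrix was removed in networkx 3.0; the
# alias above keeps older call sites working against newer networkx releases.)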
logger = logging.getLogger(__name__)
NHEAVYPROT = 14


#################################################################################
# Base conditioning definitions
#################################################################################


def get_motif_features(atom_array):
    is_fixed = atom_array.is_motif_atom_with_fixed_coord.astype(bool)
    is_sequence_fixed = atom_array.is_motif_atom_with_fixed_seq.astype(bool)
    is_unindexed = atom_array.is_motif_atom_unindexed.astype(bool)

    # Motif atom if it has any conditioning
    is_motif_atom = is_fixed | is_sequence_fixed | is_unindexed
    is_motif_token = apply_and_spread_token_wise(
        atom_array, is_motif_atom, function=lambda x: np.any(x)
    )  # Has any atoms with conditioning

    return {"is_motif_atom": is_motif_atom, "is_motif_token": is_motif_token}


def set_default_conditioning_annotations(
    atom_array,
    motif=False,
    unindexed=False,
    mask=None,
    dtype=bool,
    additional: set | list = None,
):
    """
    Adds default annotations to the atom array

    Args:
        motif: True if default for a fully fixed motif, False if default for a fully diffused motif
        unindexed: True if the tokens in the atom array should be motif
        mask: boolean mask for array of which atoms to apply the assignments to.
    NB: In both cases, the defaults for unindexed are False
    """

    # All annotations set to True for motif
    fill = True if motif else False
    if mask is not None:
        # TODO: support defaulting to nulls
        check_has_required_conditioning_annotations(atom_array)
        trues = np.full(mask.sum(), True, dtype=dtype)
        falses = np.full(mask.sum(), False, dtype=dtype)

        atom_array.is_motif_atom_unindexed[mask] = trues if unindexed else falses
        atom_array.is_motif_atom_unindexed_motif_breakpoint[mask] = falses

        # Others:
        for annotation in REQUIRED_CONDITIONING_ANNOTATIONS:
            if annotation in [
                "is_motif_atom_unindexed",
                "is_motif_atom_unindexed_motif_breakpoint",
            ]:
                continue

            vals = copy.deepcopy(atom_array.get_annotation(annotation))
            vals[mask] = trues if fill else falses
            atom_array.set_annotation(annotation, vals)
    else:
        for annotation in REQUIRED_CONDITIONING_ANNOTATIONS:
            if annotation in [
                "is_motif_atom_unindexed",
            ]:
                atom_array.set_annotation(
                    annotation,
                    np.full(atom_array.array_length(), unindexed, dtype=dtype),
                )
            elif annotation in [
                "is_motif_atom_unindexed_motif_breakpoint",
            ]:
                atom_array.set_annotation(
                    annotation, np.full(atom_array.array_length(), False, dtype=dtype)
                )
            else:
                atom_array.set_annotation(
                    annotation, np.full(atom_array.array_length(), fill, dtype=dtype)
                )

    if additional is not None:
        for annot, val in OPTIONAL_CONDITIONING_VALUES.items():
            if (
                annot in additional
                and annot not in atom_array.get_annotation_categories()
            ):
                atom_array.set_annotation(
                    annot, np.full(atom_array.array_length(), val)
                )

    return atom_array
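
# Hypothetical example (values illustrative only): initialize a fully diffused
# array, then mark the first ten atoms as fixed motif:
#
#   atom_array = set_default_conditioning_annotations(atom_array, motif=False)
#   mask = np.zeros(atom_array.array_length(), dtype=bool)
#   mask[:10] = True
#   atom_array = set_default_conditioning_annotations(atom_array, motif=True, mask=mask)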


def check_has_required_conditioning_annotations(
    atom_array, required=REQUIRED_CONDITIONING_ANNOTATIONS
):
    """
    Checks if the atom array has the correct conditioning annotations
    """
    received = atom_array.get_annotation_categories()
    for required_annotation in required:
        if required_annotation not in received:
            raise InvalidSampledConditionException(
                f"Missing annotation category in atom_array: {required_annotation}"
            )
    return True


def convert_existing_annotations_to_bool(
    atom_array, annotations=REQUIRED_CONDITIONING_ANNOTATIONS
):
    # When loading from cif, annotations are loaded as strings when they should be boolean
    for annotation in annotations:
        if annotation not in atom_array.get_annotation_categories():
            continue
        tmp = atom_array.get_annotation(annotation).copy()
        atom_array.get_annotation(annotation).dtype = bool
        if isinstance(tmp[0], (str, np.str_, np.dtypes.StrDType)):
            tmp = np.array([ast.literal_eval(x) for x in tmp], dtype=bool)
        else:
            tmp = np.asarray(tmp, dtype=bool)
        atom_array.set_annotation(annotation, tmp)
    return atom_array
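
# Why the string branch exists (sketch): writing an AtomArray to CIF serializes
# custom annotation columns as text, so a boolean column round-trips as the
# strings "True"/"False", which ast.literal_eval recovers:
#
#   raw = np.array(["True", "False", "True"])
#   np.array([ast.literal_eval(x) for x in raw], dtype=bool)  # -> [True, False, True]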


def convert_existing_annotations_to_int(
    atom_array, annotations=REQUIRED_CONDITIONING_ANNOTATIONS
):
    # When loading from cif, annotations are loaded as strings/booleans when they should be integers
    for annotation in annotations:
        if annotation not in atom_array.get_annotation_categories():
            continue
        tmp = atom_array.get_annotation(annotation).copy()
        if isinstance(tmp[0], (str, np.str_, np.bool_, bool, np.dtypes.BoolDType)):
            tmp = np.array([int(x) for x in tmp], dtype=int)
        atom_array.set_annotation(annotation, tmp)
    return atom_array


class StrtoBoolforIsXFeatures(Transform):
    def check_input(self, *args, **kwargs):
        pass

    def __init__(self):
        pass

    def forward(self, data):
        atom_array = data["atom_array"]
        convert_existing_annotations_to_bool(atom_array)
        data["atom_array"] = atom_array
        return data


class InvalidSampledConditionException(Exception):
    def __init__(self, message="Error during sampling of condition."):
        self.message = message
        super().__init__(self.message)


#################################################################################
# Transforms for pipeline (training & inference)
#################################################################################


class SampleConditioningType(Transform):
    """
    Applies conditional assignments

    Args:
        train_conditions: List[RandomMask]
        seed (int): random seed, for controlling the masking results

    Returns:
        atom_array with three more annotations:
            - is_motif_token: tokens to be motif
            - is_motif_atom: atoms to be motif
            - is_motif_atom_with_fixed_seq: for which atom we know the true restype
    """

    requires_previous_transforms = [
        "AssignTypes",
    ]

    def __init__(
        self,
        *,
        train_conditions: dict,
        meta_conditioning_probabilities: dict,
        sequence_encoding,
    ):
        if exists(train_conditions):
            train_conditions = hydra.utils.instantiate(
                train_conditions, _recursive_=True
            )
        self.meta_conditioning_probabilities = meta_conditioning_probabilities
        self.train_conditions = train_conditions
        self.sequence_encoding = sequence_encoding

    def check_input(self, data: dict):
        assert not data["is_inference"], "This transform is only used during training!"
        check_contains_keys(data, ["atom_array"])
        check_is_instance(data, "atom_array", AtomArray)
        check_atom_array_annotation(data, ["pn_unit_id", "pn_unit_iid"])
        existing = [
            cat in REQUIRED_CONDITIONING_ANNOTATIONS
            for cat in data["atom_array"].get_annotation_categories()
        ]
        assert not any(
            existing
        ), "Conditioning annotations already set! found {}".format(existing)
        assert "conditions" in data, "Conditioning dict not initialized"

    def forward(self, data):
        valid_conditions = [
            cond
            for cond in self.train_conditions.values()
            if cond.frequency > 0 and cond.is_valid_for_example(data)
        ]

        if len(valid_conditions) == 0:
            raise InvalidSampledConditionException("No valid condition was found.")

        p_cond = np.array([cond.frequency for cond in valid_conditions])
        if p_cond.sum() == 0:
            raise InvalidSampledConditionException(
                "No valid condition was found with non-zero frequency."
            )
        p_cond = p_cond.astype(np.float64)
        p_cond /= p_cond.sum()
        i_cond = np.random.choice(np.arange(len(p_cond)), p=p_cond)
        cond = valid_conditions[i_cond]

        data["sampled_condition"] = cond
        data["sampled_condition_name"] = cond.name
        data["sampled_condition_cls"] = cond.__class__

        # Sample canonical conditioning flags for downstream processing
        for k, p in self.meta_conditioning_probabilities.items():
            data["conditions"][k] = random_condition(p)

        return data
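
# Sampling sketch (hypothetical frequencies): with train_conditions whose
# frequencies are {motif: 3, ppi: 1}, the normalized probabilities are
# p = [0.75, 0.25], so the motif condition is drawn three times as often:
#
#   p_cond = np.array([3.0, 1.0])
#   p_cond /= p_cond.sum()         # -> array([0.75, 0.25])
#   np.random.choice(2, p=p_cond)  # index 0 sampled three times as often as 1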


class SampleConditioningFlags(Transform):
    requires_previous_transforms = [
        "FlagAndReassignCovalentModifications",
        "AssignTypes",
        "SampleConditioningType",
    ]  # We use is_protein in the PPI training condition

    def check_input(self, data):
        assert not data[
            "is_inference"
        ], "This transform is only used during training! Validation using sampled conditions is not implemented yet"
        assert "sampled_condition" in data

    def forward(self, data: dict) -> dict:
        cond = data["sampled_condition"]

        # Sample canonical conditioning flags for the atom array
        atom_array = cond.sample(data)
        data["atom_array"] = atom_array

        return data


class UnindexFlaggedTokens(Transform):
    """
    Serves as the merge point between the training / inference conditioning pipelines
    """

    def __init__(self, central_atom):
        """
        Args:
            central_atom: The atom to use as the central atom for unindexed motifs.
        """
        super().__init__()
        self.central_atom = central_atom

    def check_input(self, data: dict):
        check_contains_keys(data, ["atom_array"])
        check_is_instance(data, "atom_array", AtomArray)

    def expand_unindexed_motifs(
        self, atom_array: AtomArray, pop_orig_tokens: bool
    ) -> AtomArray:
        """
        Takes the atom array and motif indices and pads the atom array to include unindexed motif atoms.

        is_motif_atom_unindexed - whether an atom is flagged to be a guidepost.
        During training, the original coordinates are left behind for the model to learn to diffuse;
        during inference, the original tokens are removed by default.
        """
        # Back up the original residue id for training metrics
        atom_array.set_annotation("orig_res_id", atom_array.res_id.copy())
        is_motif_atom_unindexed = atom_array.is_motif_atom_unindexed.copy()
        if not np.any(is_motif_atom_unindexed):
            return atom_array

        # ... A token is to be unindexed if any atoms in the token are unindexed
        max_resid = np.max(atom_array.res_id)
        starts = struc.get_residue_starts(atom_array, add_exclusive_stop=True)
        token_to_unindex = struc.spread_residue_wise(
            atom_array,
            struc.apply_residue_wise(
                atom_array,
                is_motif_atom_unindexed,
                function=lambda x: np.any(x),
            ),
        )
        assert token_to_unindex.sum() > 0, "No tokens to unindex!"
        idxs = np.arange(atom_array.array_length())
        unindexed_tokens = []
        for i, (start, end) in enumerate(zip(starts[:-1], starts[1:])):
            if not token_to_unindex[start]:
                continue
            subset_mask = np.isin(idxs, idxs[start:end])
            token = copy.deepcopy(atom_array[subset_mask])
            token = token[token.is_motif_atom_unindexed]
            token.res_id = token.res_id + max_resid
            token.is_C_terminus[:] = False
            token.is_N_terminus[:] = False
            assert token.is_protein.all(), f"Cannot unindex non-protein token: {token}"
            token = add_representative_atom(token, central_atom=self.central_atom)
            unindexed_tokens.append(token)

        # ... Remove the original tokens, e.g. during inference
        if pop_orig_tokens:
            atom_array = atom_array[~token_to_unindex]
            # Reassign termini features
            atom_array = add_protein_termini_annotation(atom_array)
        else:
            # Reset is_motif_atom and is_motif_atom_unindexed to contain no motif annotations where unindexed,
            # i.e. the model should view the original tokens the same as every other diffused token
            atom_array.is_motif_atom[token_to_unindex] = False
            atom_array.is_motif_atom_with_fixed_coord[token_to_unindex] = False
            atom_array.is_motif_token[token_to_unindex] = False
            atom_array.is_motif_atom_with_fixed_seq[token_to_unindex] = False
            atom_array.is_motif_atom_unindexed[token_to_unindex] = False
            atom_array.is_motif_atom_unindexed_motif_breakpoint[token_to_unindex] = (
                False
            )

        # Concatenate the unindexed parts to the end
        atom_array_full = struc.concatenate([atom_array] + unindexed_tokens)
        atom_array_to_concat = struc.concatenate(unindexed_tokens)
        # Ensure the tokens are recognised as separate
        n_unindexed_tokens = get_token_count(atom_array_to_concat)
        assert n_unindexed_tokens == len(
            unindexed_tokens
        ), f"Expected {len(unindexed_tokens)} but got {n_unindexed_tokens}"
        assert (
            get_token_count(atom_array_full)
            == get_token_count(atom_array) + n_unindexed_tokens
        ), (
            f"Failed to create uniquely recognised tokens after concatenation.\n"
            f"Concatenated tokens: {get_token_count(atom_array_full)}, unindexed: {n_unindexed_tokens}"
        )

        return atom_array_full
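
    # Sketch of the re-indexing above (hypothetical numbers): with max_resid = 120,
    # an unindexed motif residue originally at res_id 45 is re-appended as a
    # guidepost token with res_id 165, so guideposts occupy a residue-index range
    # disjoint from the diffused chain (orig_res_id keeps the original 45 for metrics).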

    def create_unindexed_masks(
        self,
        atom_array,
        is_inference=False,
    ):
        """
        Create an L,L boolean matrix indicating the tokens which should absolutely
        not know the relative positions of one another.

        False when positional leakage is allowed
        True when positional leakage is disallowed

        Used as input to the model's relative position encoding.

        breaks:
            boolean atom-wise array indicating which token breaks the group ids up.
            If all are False, all indices are leaked. If the first break of the unindexed tokens is
            True, the cross-motif couplings are leaked but not the global index.

        atom_array: padded atom array
        """
        token_starts = get_token_starts(atom_array)
        token_level_array = atom_array[token_starts]
        is_motif_token_unindexed = token_level_array.is_motif_atom_unindexed

        # ... Grab the breaks from the token-level array
        unindexed_token_level_array = token_level_array[is_motif_token_unindexed]
        breaks = unindexed_token_level_array.is_motif_atom_unindexed_motif_breakpoint

        leak_all = not np.any(breaks)
        if leak_all:
            if is_inference and np.any(is_motif_token_unindexed):
                logger.info("Indexing all unindexed components")
            L = len(token_starts)
            return np.zeros((L, L), dtype=bool), is_motif_token_unindexed

        # ... The first component of the mask is that no unindexed atoms should talk to indexed ones.
        mask = (
            is_motif_token_unindexed[:, None] == ~is_motif_token_unindexed[None, :]
        )  # [intra indexed + intra unindexed]

        # ... Then, within unindexed tokens, separate the islands based on where the token id breaks
        unindexed_all_LL = (
            is_motif_token_unindexed[:, None] & is_motif_token_unindexed[None, :]
        )  # [intra unindexed]

        ########################################################################################
        # Determine intra-unindexed resid leakage
        ########################################################################################
        # ... Mask out intra-unindexed off-diagonals as necessary
        group_ids = np.cumsum(breaks)
        mask_unindexed_MM = group_ids[:, None] != group_ids[None, :]
        mask[unindexed_all_LL] = mask_unindexed_MM.flatten()

        return mask, is_motif_token_unindexed
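
    # Worked example (hypothetical breaks): four unindexed tokens with
    # breaks = [True, False, True, False] give group_ids = [1, 1, 2, 2], so
    # relative positions are masked between the two islands but leaked within each:
    #
    #   breaks = np.array([True, False, True, False])
    #   group_ids = np.cumsum(breaks)             # -> [1, 1, 2, 2]
    #   group_ids[:, None] != group_ids[None, :]  # True off the two 2x2 diagonal blocks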

    def forward(self, data: dict):
        atom_array = data["atom_array"]
        if "feats" not in data:
            data["feats"] = {}

        # ... Ensure the conditioning flags are set correctly
        # NOTE: Join point for the inference and training conditioning pipelines
        check_has_required_conditioning_annotations(atom_array)

        is_unindexed_token = apply_and_spread_token_wise(
            atom_array,
            atom_array.is_motif_atom_unindexed.copy(),
            function=lambda x: np.any(x),
        )

        # Expand unindexed motifs if necessary
        atom_array_expanded = self.expand_unindexed_motifs(
            atom_array,
            pop_orig_tokens=data["is_inference"],
        )

        # Provide the atom-wise mask for the regions which should be diffused into the guideposts;
        # the original token was unindexed if any of its atoms were unindexed
        n_expanded_atoms = (
            atom_array_expanded.array_length() - atom_array.array_length()
        )
        mask = np.concatenate([is_unindexed_token, np.zeros(n_expanded_atoms)])
        if "ground_truth" not in data:
            data["ground_truth"] = {}
        data["ground_truth"]["is_original_unindexed_token"] = mask.astype(bool)

        # Reset global token IDs after possible padding
        atom_array_expanded = add_global_token_id_annotation(atom_array_expanded)

        # For unindexed scaffolding, we must provide an unindexing pair mask to ensure original positions aren't leaked to:
        # (I) the RPE of the token initializer and (II) the atom attention base sequence mask
        mask_II, mask_I = self.create_unindexed_masks(
            atom_array_expanded, is_inference=data["is_inference"]
        )
        data["feats"]["unindexing_pair_mask"] = mask_II
        data["feats"]["is_motif_token_unindexed"] = mask_I
        data["atom_array"] = atom_array_expanded
        return data
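
# Pipeline sketch (training side, following the module docstring; the ordering
# mirrors requires_previous_transforms and is an illustration, not an exact config):
#
#   AssignTypes -> SampleConditioningType -> SampleConditioningFlags -> UnindexFlaggedTokens
#
# i.e. sample a training condition, let it set the per-atom conditioning flags,
# then expand unindexed (guidepost) tokens and build the leakage masks before featurization.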