rc-foundry 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
|
@@ -0,0 +1,807 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Design Transforms for the Atom14 pipeline
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from typing import Any, Dict
|
|
6
|
+
|
|
7
|
+
import biotite.structure as struc
|
|
8
|
+
import numpy as np
|
|
9
|
+
import torch
|
|
10
|
+
import torch.nn.functional as F
|
|
11
|
+
from atomworks.constants import (
|
|
12
|
+
ELEMENT_NAME_TO_ATOMIC_NUMBER,
|
|
13
|
+
)
|
|
14
|
+
from atomworks.enums import GroundTruthConformerPolicy
|
|
15
|
+
from atomworks.io.utils.selection import get_residue_starts
|
|
16
|
+
from atomworks.ml.transforms._checks import (
|
|
17
|
+
check_contains_keys,
|
|
18
|
+
check_is_instance,
|
|
19
|
+
)
|
|
20
|
+
from atomworks.ml.transforms.af3_reference_molecule import (
|
|
21
|
+
_encode_atom_names_like_af3,
|
|
22
|
+
get_af3_reference_molecule_features,
|
|
23
|
+
)
|
|
24
|
+
from atomworks.ml.transforms.base import (
|
|
25
|
+
Transform,
|
|
26
|
+
)
|
|
27
|
+
from atomworks.ml.utils.geometry import (
|
|
28
|
+
masked_center,
|
|
29
|
+
random_rigid_augmentation,
|
|
30
|
+
)
|
|
31
|
+
from atomworks.ml.utils.token import (
|
|
32
|
+
apply_token_wise,
|
|
33
|
+
get_token_starts,
|
|
34
|
+
)
|
|
35
|
+
from biotite.structure import AtomArray
|
|
36
|
+
from rfd3.constants import VIRTUAL_ATOM_ELEMENT_NAME
|
|
37
|
+
from rfd3.transforms.conditioning_base import (
|
|
38
|
+
UnindexFlaggedTokens,
|
|
39
|
+
get_motif_features,
|
|
40
|
+
)
|
|
41
|
+
from rfd3.transforms.rasa import discretize_rasa
|
|
42
|
+
from rfd3.transforms.util_transforms import (
|
|
43
|
+
AssignTypes,
|
|
44
|
+
add_backbone_and_sidechain_annotations,
|
|
45
|
+
get_af3_token_representative_masks,
|
|
46
|
+
)
|
|
47
|
+
from rfd3.transforms.virtual_atoms import PadTokensWithVirtualAtoms
|
|
48
|
+
|
|
49
|
+
from foundry.utils.ddp import RankedLogger # noqa
|
|
50
|
+
|
|
51
|
+
#####################################################################################################
|
|
52
|
+
# Other design transforms
|
|
53
|
+
#####################################################################################################
|
|
54
|
+
|
|
55
|
+
# Module-level logger used by the transforms below (e.g. occupancy warnings).
# NOTE(review): rank_zero_only=True presumably restricts output to the rank-0 process
# in distributed runs — confirm against foundry.utils.ddp.RankedLogger.
ranked_logger = RankedLogger(__name__, rank_zero_only=True)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class SubsampleToTypes(Transform):
    """
    Remove all atoms whose type is not one of ``allowed_types``.

    Possible allowed types:
        - is_protein
        - is_ligand
        - is_dna
        - is_rna

    ``allowed_types`` may be a list of ``is_*`` annotation names, a single
    annotation name, or the literal string ``"ALL"`` to disable subsampling.
    """

    requires_previous_transforms = [AssignTypes]

    def __init__(
        self,
        allowed_types: list | str | None = None,
    ):
        """
        Args:
            allowed_types: Annotation names (each must start with ``"is_"``) whose
                atoms are kept, a single such name, or ``"ALL"``. Defaults to
                ``["is_protein"]``.

        Raises:
            ValueError: If any allowed type does not start with ``"is_"``.
        """
        # Avoid a shared mutable default; keep the historical default value.
        if allowed_types is None:
            allowed_types = ["is_protein"]
        elif isinstance(allowed_types, str) and allowed_types != "ALL":
            # A bare string would otherwise be iterated character by character.
            allowed_types = [allowed_types]
        self.allowed_types = allowed_types

        if self.allowed_types != "ALL":
            for k in self.allowed_types:
                if not k.startswith("is_"):
                    raise ValueError(f"Allowed types must start with 'is_', got {k}")

    def check_input(self, data: dict):
        check_contains_keys(data, ["atom_array"])

    def forward(self, data):
        """
        Keep only atoms for which at least one allowed type annotation is truthy.

        Raises:
            ValueError: If no atoms remain, or if no protein atoms remain.
        """
        atom_array = data["atom_array"]

        # ... Subsampling: union of all allowed type masks
        if self.allowed_types != "ALL":
            is_allowed = np.zeros(atom_array.array_length(), dtype=bool)
            for allowed_type in self.allowed_types:
                is_allowed |= np.asarray(
                    atom_array.get_annotation(allowed_type), dtype=bool
                )
            atom_array = atom_array[is_allowed]

        # ... Assert any atoms remain
        if atom_array.array_length() == 0:
            raise ValueError(
                "No tokens found in the atom array! Example ID: {}".format(
                    data.get("example_id", "unknown")
                )
            )

        # ... Assert any protein remains
        if atom_array.is_protein.sum() == 0:
            raise ValueError(
                "No protein atoms found in the atom array. Example ID: {}".format(
                    data.get("example_id", "unknown")
                )
            )

        data["atom_array"] = atom_array
        return data
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
class CreateDesignReferenceFeatures(Transform):
    """
    Build the reference ("ref_*") and motif-conditioning features used for design.

    Traditional AF3 will create a bunch of reference features based on the sequence
    and molecular identity. For design we do not have access to the sequence, so
    these features are only computed for atoms whose sequence is fixed (and that are
    not unindexed motif atoms); all other atoms get zero/default values.

    This is also a good place to add atom-level features as explicit conditioning
    or implicit classifier-free guidance.

    Reduces time to process from ~0.5 to ~0.1 s on avg.

    Adds to ``data["feats"]`` (created if missing):
        ref_atom_name_chars, ref_pos, ref_mask, ref_element, ref_charge,
        ref_space_uid, ref_pos_is_ground_truth, has_zero_occupancy,
        ref_is_motif_atom_with_fixed_coord, ref_is_motif_atom_unindexed,
        ref_motif_token_type, motif_pos
    """

    requires_previous_transforms = [UnindexFlaggedTokens, AssignTypes]

    def __init__(
        self,
        generate_conformers,
        generate_conformers_for_non_protein_only,
        provide_reference_conformer_when_unmasked,
        ground_truth_conformer_policy,
        provide_elements_for_unindexed_components,
        use_element_for_atom_names_of_atomized_tokens,
        **kwargs,
    ):
        """
        Args:
            generate_conformers: If True, always compute AF3 reference features for
                the sequence-fixed atoms; otherwise only when one of them is a
                diffused (non-motif) ligand atom.
            generate_conformers_for_non_protein_only: If True, protein atoms are
                excluded from the sequence-fixed subset.
            provide_reference_conformer_when_unmasked: If True, the generated
                reference conformer positions are copied into ``ref_pos``.
            ground_truth_conformer_policy: Name of a ``GroundTruthConformerPolicy``
                member; only used when ``provide_reference_conformer_when_unmasked``
                is True, otherwise ``IGNORE`` is used.
            provide_elements_for_unindexed_components: Stored, but currently not
                used by ``forward``.
            use_element_for_atom_names_of_atomized_tokens: Forwarded to
                ``get_af3_reference_molecule_features``.
            **kwargs: Extra keyword arguments forwarded to
                ``get_af3_reference_molecule_features``; they override the defaults.
        """
        self.generate_conformers = generate_conformers
        self.generate_conformers_for_non_protein_only = (
            generate_conformers_for_non_protein_only
        )
        self.provide_reference_conformer_when_unmasked = (
            provide_reference_conformer_when_unmasked
        )
        if provide_reference_conformer_when_unmasked:
            self.ground_truth_conformer_policy = GroundTruthConformerPolicy[
                ground_truth_conformer_policy
            ]
        else:
            self.ground_truth_conformer_policy = GroundTruthConformerPolicy.IGNORE

        self.provide_elements_for_unindexed_components = (
            provide_elements_for_unindexed_components
        )

        # Right-hand side of `|` wins, so user kwargs override these defaults.
        self.conformer_generation_kwargs = {
            "conformer_generation_timeout": 2.0,
            "use_element_for_atom_names_of_atomized_tokens": use_element_for_atom_names_of_atomized_tokens,
        } | kwargs

    def check_input(self, data: dict):
        check_contains_keys(data, ["atom_array"])

    def forward(self, data: dict) -> dict:
        atom_array = data["atom_array"]
        n_atoms = atom_array.array_length()
        token_starts = get_token_starts(atom_array)
        n_tokens = len(token_starts)

        # ... Default (all-zero) reference features
        ref_pos = np.zeros_like(atom_array.coord, dtype=np.float32)
        ref_mask = np.zeros((n_atoms,), dtype=bool)
        ref_charge = np.zeros((n_atoms,), dtype=np.int8)
        ref_pos_is_ground_truth = np.zeros((n_atoms,), dtype=bool)
        # Elements stay zero unless filled from the sequence-fixed subset below.
        # TODO: `provide_elements_for_unindexed_components` is not applied here yet.
        ref_element = np.zeros((n_atoms,), dtype=np.int64)

        # ... For atom names, provide all (spoofed) names in the atom array
        ref_atom_name_chars = _encode_atom_names_like_af3(atom_array.atom_name)

        # ... One reference-space uid per residue
        res_start_ends = get_residue_starts(atom_array, add_exclusive_stop=True)
        ref_space_uid = struc.segments.spread_segment_wise(
            res_start_ends, np.arange(len(res_start_ends) - 1, dtype=np.int64)
        )

        motif_features = get_motif_features(atom_array)
        is_motif_atom = motif_features["is_motif_atom"]
        is_motif_token = motif_features["is_motif_token"][token_starts]

        # ... Flag atoms with zero occupancy (ignored at inference time)
        has_zero_occupancy = atom_array.occupancy == 0.0
        if data["is_inference"] and has_zero_occupancy.any():
            ranked_logger.warning(
                "Found zero occupancy in input, setting occupancy to 1"
            )
            has_zero_occupancy = np.full_like(has_zero_occupancy, False)

        # ... Token features for token type;
        # [1, 0, 0]: non-motif
        # [0, 1, 0]: indexed motif
        # [0, 0, 1]: unindexed motif
        is_motif_token_unindexed = atom_array.is_motif_atom_unindexed[token_starts]
        motif_token_class = np.zeros((n_tokens,), dtype=np.int8)
        motif_token_class[is_motif_token] = 1
        motif_token_class[is_motif_token_unindexed] = 2
        motif_token_type = np.eye(3, dtype=np.int8)[
            motif_token_class
        ]  # one-hot, (L, 3)

        # ... Provide GT as reference coordinates even when unfixed (zero for non-motif)
        motif_pos = np.nan_to_num(atom_array.coord.copy()) * is_motif_atom[..., None]

        # ... Subset where we are allowed to use ground truth sequence information
        has_sequence = (
            atom_array.is_motif_atom_with_fixed_seq
            & ~atom_array.is_motif_atom_unindexed
        )  # (n_atoms,)
        if self.generate_conformers_for_non_protein_only:
            has_sequence = has_sequence & ~atom_array.is_protein

        if np.any(has_sequence):
            atom_array_unmasked = atom_array[has_sequence]

            # We always want to generate conformers if there are ligand atoms that are
            # diffused. `or` short-circuits, so `is_motif_atom` is only read when needed.
            if self.generate_conformers or np.logical_and(
                atom_array_unmasked.is_ligand, ~atom_array_unmasked.is_motif_atom
            ).any():
                atom_array_unmasked.set_annotation(
                    "ground_truth_conformer_policy",
                    np.full(
                        atom_array_unmasked.array_length(),
                        self.ground_truth_conformer_policy.value,
                    ),
                )

                # Conformer generation needs ground-truth atom names: use a copy so
                # the (possibly spoofed) names on atom_array_unmasked are untouched.
                atom_array_with_gt_names = atom_array_unmasked.copy()
                atom_array_with_gt_names.atom_name = (
                    atom_array_with_gt_names.gt_atom_name
                )
                reference_features_unmasked = get_af3_reference_molecule_features(
                    atom_array_with_gt_names,
                    cached_residue_level_data=data.get("cached_residue_level_data"),
                    **self.conformer_generation_kwargs,
                )[0]  # returns a tuple; the feature dict is the first element

                ref_atom_name_chars[has_sequence] = reference_features_unmasked[
                    "ref_atom_name_chars"
                ]
                ref_mask[has_sequence] = reference_features_unmasked["ref_mask"]
                ref_element[has_sequence] = reference_features_unmasked["ref_element"]
                ref_charge[has_sequence] = reference_features_unmasked["ref_charge"]
                ref_pos_is_ground_truth[has_sequence] = reference_features_unmasked[
                    "ref_pos_is_ground_truth"
                ]

                # If requested, include the reference conformers for unmasked atoms
                if self.provide_reference_conformer_when_unmasked:
                    ref_pos[has_sequence] = reference_features_unmasked["ref_pos"]
            else:
                # Generate simple features (charge + atomic number) for the subset only
                ref_charge[has_sequence] = atom_array_unmasked.charge
                if "atomic_number" in atom_array.get_annotation_categories():
                    ref_element[has_sequence] = atom_array_unmasked.atomic_number
                else:
                    # Must map the *subset's* elements so lengths match the mask.
                    ref_element[has_sequence] = np.vectorize(
                        ELEMENT_NAME_TO_ATOMIC_NUMBER.get
                    )(atom_array_unmasked.element)

        reference_features = {
            "ref_atom_name_chars": ref_atom_name_chars,  # (n_atoms, 4)
            "ref_pos": ref_pos,  # (n_atoms, 3)
            "ref_mask": ref_mask,  # (n_atoms)
            "ref_element": ref_element,  # (n_atoms)
            "ref_charge": ref_charge,  # (n_atoms)
            "ref_space_uid": ref_space_uid,  # (n_atoms)
            "ref_pos_is_ground_truth": ref_pos_is_ground_truth,  # (n_atoms)
            "has_zero_occupancy": has_zero_occupancy,  # (n_atoms)
            # Conditional masks
            "ref_is_motif_atom_with_fixed_coord": atom_array.is_motif_atom_with_fixed_coord.copy(),  # (n_atoms)
            "ref_is_motif_atom_unindexed": atom_array.is_motif_atom_unindexed.copy(),
            "ref_motif_token_type": motif_token_type,  # (n_tokens, 3) # 3 types of token
            "motif_pos": motif_pos,  # (n_atoms, 3) # GT pos for motif atoms
        }

        if "feats" not in data:
            data["feats"] = {}
        data["feats"].update(reference_features)

        return data
|
|
344
|
+
|
|
345
|
+
|
|
346
|
+
class FeaturizeAtoms(Transform):
    """
    Add atom-wise RASA and hydrogen-bond donor/acceptor conditioning features.

    Writes into ``data["feats"]`` (created if missing):
        ref_atomwise_rasa: one-hot over ``n_bins`` RASA bins with the last column
            dropped. The last bin is the one assigned when no RASA information is
            available.
        active_donor / active_acceptor: int64 tensors built from the matching
            annotation, or zeros when the annotation is absent.
    """

    def __init__(self, n_bins=4):
        self.n_bins = n_bins

    def _select_rasa_bins(self, data, atom_array):
        """Pick the integer RASA bin array to one-hot encode, by precedence."""
        categories = atom_array.get_annotation_categories()
        # 1) Pre-binned values supplied at inference time
        if data["is_inference"] and "rasa_bin" in categories:
            return atom_array.get_annotation("rasa_bin").copy()
        # 2) Continuous RASA that still needs discretizing
        if "rasa" in categories:
            return discretize_rasa(
                atom_array,
                n_bins=self.n_bins - 1,
                keep_protein_motif=data["conditions"]["keep_protein_motif_rasa"],
            )
        # 3) No information: everything goes into the last bin
        return np.full(atom_array.array_length(), self.n_bins - 1, dtype=np.int64)

    def forward(self, data):
        atom_array = data["atom_array"]
        feats = data.setdefault("feats", {})

        bins = torch.from_numpy(self._select_rasa_bins(data, atom_array)).long()
        one_hot = F.one_hot(bins, num_classes=self.n_bins).numpy()
        # The last bin is never fed to the model
        feats["ref_atomwise_rasa"] = one_hot[..., :-1]

        for name in ("active_donor", "active_acceptor"):
            if name in atom_array.get_annotation_categories():
                raw = np.float64(atom_array.get_annotation(name))
            else:
                raw = np.zeros(len(atom_array))
            feats[name] = torch.tensor(raw).long()

        return data
|
|
397
|
+
|
|
398
|
+
|
|
399
|
+
class AddIsXFeats(Transform):
    """
    Add boolean atom-level masks to ``data["feats"]`` for each name listed in ``X``.

    Assigned to the atom array (via ``add_backbone_and_sidechain_annotations``)
    and returned as features:
        - is_backbone
        - is_sidechain
    Computed and returned as features only:
        - is_virtual (element equals VIRTUAL_ATOM_ELEMENT_NAME)
        - is_central (token representative atom for ``central_atom``)
        - is_ca (CA for indexed tokens, first atom of each unindexed motif token)
        - is_motif_token_with_fully_fixed_coord (token-wise "all atoms fixed")
    Copied from existing annotations (requires previous assignment):
        - is_motif_atom_with_fixed_coord
        - is_motif_atom_unindexed
        - is_motif_atom_with_fixed_seq
    """

    requires_previous_transforms = [
        AssignTypes,
        PadTokensWithVirtualAtoms,
        "UnindexFlaggedTokens",
    ]

    def __init__(
        self,
        X,
        central_atom,
        extra_atom_level_feats: list[str] | None = None,
        extra_token_level_feats: list[str] | None = None,
    ):
        """
        Args:
            X: Collection of feature names to compute.
            central_atom: Passed to ``get_af3_token_representative_masks`` for ``is_central``.
            extra_atom_level_feats: Stored on the instance; not used by ``forward``.
            extra_token_level_feats: Stored on the instance; not used by ``forward``.
        """
        self.X = X
        self.central_atom = central_atom
        self.update_atom_array = False
        # None sentinels avoid sharing one mutable default list across instances.
        self.extra_atom_level_feats = (
            extra_atom_level_feats if extra_atom_level_feats is not None else []
        )
        self.extra_token_level_feats = (
            extra_token_level_feats if extra_token_level_feats is not None else []
        )

    def check_input(self, data):
        check_contains_keys(data, ["atom_array", "feats"])

    @staticmethod
    def _compute_is_ca(atom_array) -> np.ndarray:
        """Atom mask: CA for indexed tokens, first atom of each unindexed motif token."""
        is_unindexed = np.asarray(atom_array.is_motif_atom_unindexed, dtype=bool)
        is_ca = np.zeros(atom_array.array_length(), dtype=bool)

        # Scatter each sub-mask back to its true positions instead of concatenating,
        # so the result is correct even if unindexed atoms are not trailing.
        if (~is_unindexed).any():
            is_ca[~is_unindexed] = get_af3_token_representative_masks(
                atom_array[~is_unindexed], central_atom="CA"
            )

        if is_unindexed.any():
            atom_array_unindexed = atom_array[is_unindexed]
            # Ensure is_ca marks exactly one atom (the first) per unindexed token
            starts = get_token_starts(atom_array_unindexed, add_exclusive_stop=True)
            first_atom = np.zeros(atom_array_unindexed.array_length(), dtype=bool)
            first_atom[starts[:-1]] = True
            is_ca[is_unindexed] = first_atom

        return is_ca

    def forward(self, data: dict) -> dict:
        # Use the returned array: it is the one guaranteed to carry the new annotations.
        atom_array = add_backbone_and_sidechain_annotations(data["atom_array"])
        feats = data["feats"]

        # ... Basic features
        if "is_backbone" in self.X:
            feats["is_backbone"] = torch.from_numpy(
                atom_array.get_annotation("is_backbone")
            ).to(dtype=torch.bool)

        if "is_sidechain" in self.X:
            feats["is_sidechain"] = torch.from_numpy(
                atom_array.get_annotation("is_sidechain")
            ).to(dtype=torch.bool)

        # Virtual atom feats
        if "is_virtual" in self.X:
            feats["is_virtual"] = atom_array.element == VIRTUAL_ATOM_ELEMENT_NAME

        # Atom-level motif flags copied from existing annotations
        for name in (
            "is_motif_atom_with_fixed_coord",
            "is_motif_atom_with_fixed_seq",
            "is_motif_atom_unindexed",
        ):
            if name in self.X:
                feats[name] = atom_array.get_annotation(name).astype(bool)

        if "is_motif_token_with_fully_fixed_coord" in self.X:
            feats["is_motif_token_with_fully_fixed_coord"] = apply_token_wise(
                atom_array,
                atom_array.is_motif_atom_with_fixed_coord.astype(bool),
                function=lambda x: np.all(x, axis=-1),
            )

        # ... Central and CA
        if "is_central" in self.X:
            feats["is_central"] = get_af3_token_representative_masks(
                atom_array, central_atom=self.central_atom
            )

        if "is_ca" in self.X:
            feats["is_ca"] = self._compute_is_ca(atom_array)

        return data
|
|
525
|
+
|
|
526
|
+
|
|
527
|
+
# Multiplier on MotifCenterRandomAugmentation's sigma_perturb for "ppi" examples.
PPI_PERTURB_SCALE: float = 2.0
# Multiplier on AugmentNoise's sigma_perturb_com for "ppi" examples.
PPI_PERTURB_COM_SCALE: float = 1.5
|
|
529
|
+
|
|
530
|
+
|
|
531
|
+
class MotifCenterRandomAugmentation(Transform):
    """
    Center, randomly translate and randomly rotate the coordinates to be noised.

    Only applied during training; at inference this is a no-op because the
    sampler handles centering at every step.
    """

    requires_previous_transforms = ["BatchStructuresForDiffusionNoising"]

    def __init__(
        self,
        batch_size,
        sigma_perturb,
        center_option,
    ):
        """
        Randomly augments the coordinates of the motif center for diffusion training.
        During inference, this behaviour is handled by the sampler at every step.

        Args:
            batch_size: Number of random offsets/rotations drawn per call; presumably
                equal to the leading dimension of ``coord_atom_lvl_to_be_noised``.
            sigma_perturb: Std of the random translation added after centering
                (multiplied by ``PPI_PERTURB_SCALE`` for "ppi" examples).
            center_option: ``"motif"`` centers on fixed-coordinate motif atoms,
                ``"diffuse"`` on the remaining atoms; any other value, or no fixed
                motif, centers on all atoms in ``mask_atom_lvl``.
        """
        self.batch_size = batch_size
        self.sigma_perturb = sigma_perturb
        self.center_option = center_option

    def check_input(self, data: dict):
        pass

    def forward(self, data):
        """
        Applies the center + random offset + random rotation augmentation.

        The augmented coordinates replace ``data["coord_atom_lvl_to_be_noised"]``,
        so the same rotated ground-truth coordinates are supplied as input feature.
        """
        if data["is_inference"]:
            return data  # ori token behaviour set when creating atom array & in sampler

        xyz = data["coord_atom_lvl_to_be_noised"]  # (batch_size, n_atoms, 3)
        mask_atom_lvl = (
            data["ground_truth"]["mask_atom_lvl"]
            & ~data["feats"]["is_motif_atom_unindexed"]
        )  # Avoid double weighting

        # Handle the different centering options
        is_motif_atom_with_fixed_coord = torch.tensor(
            data["atom_array"].is_motif_atom_with_fixed_coord, dtype=torch.bool
        )
        has_fixed_motif = bool(torch.any(is_motif_atom_with_fixed_coord))
        if has_fixed_motif and self.center_option == "motif":
            center_mask = is_motif_atom_with_fixed_coord.clone()
        elif has_fixed_motif and self.center_option == "diffuse":
            center_mask = ~is_motif_atom_with_fixed_coord
        else:
            center_mask = torch.ones(mask_atom_lvl.shape, dtype=torch.bool)

        mask_atom_lvl = mask_atom_lvl & center_mask
        mask_atom_lvl_expanded = mask_atom_lvl.expand(xyz.shape[0], -1)

        # Masked center during training (nb not motif mask - just non-zero occupancy)
        xyz = masked_center(xyz, mask_atom_lvl_expanded)

        # Random offset; PPI examples use a larger perturbation
        sigma_perturb = self.sigma_perturb
        if data.get("sampled_condition_name") == "ppi":
            sigma_perturb = sigma_perturb * PPI_PERTURB_SCALE
        offset = torch.randn((self.batch_size, 3), device=xyz.device)[:, None, :]
        xyz = xyz + offset * sigma_perturb  # previously ignored the PPI scaling

        # Apply random spin
        xyz = random_rigid_augmentation(xyz, batch_size=self.batch_size, s=0)
        data["coord_atom_lvl_to_be_noised"] = xyz

        return data
|
|
609
|
+
|
|
610
|
+
|
|
611
|
+
class AugmentNoise(Transform):
    """
    Scaled perturbation to the offset between motif and diffused region based on time.

    Edits ``data["noise"]`` in place so that:

    * atoms flagged in ``data["feats"]["is_motif_atom_with_fixed_coord"]`` receive zero
      noise, and
    * during training (or at inference when there are no fixed-coordinate motif atoms),
      all non-fixed atoms of a batch element are shifted by one shared random offset
      ``eps ~ N(0, sigma_perturb_com**2)``, scaled by ``clip((t / t_max) ** 3, 0, 1)``
      where ``t`` is ``data["t"]``.
    """

    requires_previous_transforms = ["SampleEDMNoise", "AddIsXFeats"]

    def __init__(
        self,
        sigma_perturb_com,
        batch_size,
        t_max=38,
    ):
        """
        Args:
            sigma_perturb_com: Standard deviation of the per-batch-element offset. Multiplied
                by ``PPI_PERTURB_COM_SCALE`` when ``data["sampled_condition_name"] == "ppi"``.
            batch_size: Number of offsets drawn. Must broadcast against the leading dimension
                of ``data["noise"]`` and ``data["t"]``.
            t_max: Value of ``data["t"]`` at and above which the offset is applied at full
                scale. Defaults to 38, the value previously hard-coded in ``forward``.
        """
        self.sigma_perturb_com = sigma_perturb_com
        self.batch_size = batch_size
        self.t_max = t_max

    def check_input(self, data: dict):
        check_contains_keys(data, ["noise", "coord_atom_lvl_to_be_noised"])
        check_contains_keys(data, ["feats"])

    def forward(self, data: dict) -> dict:
        is_fixed = data["feats"]["is_motif_atom_with_fixed_coord"]
        noise = data["noise"]

        # Add perturbation to the centre-of-mass of the non-fixed atoms. At inference this is
        # skipped when a fixed-coordinate motif exists, keeping the motif/diffused offset intact.
        if not data["is_inference"] or not is_fixed.any():
            sigma_perturb_com = self.sigma_perturb_com
            if data["sampled_condition_name"] == "ppi":
                sigma_perturb_com = sigma_perturb_com * PPI_PERTURB_COM_SCALE
            device = data["coord_atom_lvl_to_be_noised"].device
            eps = torch.randn(self.batch_size, 3, device=device) * sigma_perturb_com
            # Cubic ramp in t: small offsets for small t, full strength once t >= t_max.
            eps = eps * torch.clip((data["t"][:, None] / self.t_max) ** 3, min=0, max=1)
            noise[..., ~is_fixed, :] += eps[:, None, :]

        # Zero out noise going into the motif (done once, after any perturbation).
        noise[..., is_fixed, :] = 0.0

        assert data["coord_atom_lvl_to_be_noised"].shape == noise.shape
        return data
|
|
649
|
+
|
|
650
|
+
|
|
651
|
+
class AddGroundTruthSequence(Transform):
    """
    Adds token level sequence to the ground truth.

    One residue name is taken per token (at each token start returned by
    ``get_token_starts``) and encoded with ``sequence_encoding``.

    Adds:
        ['ground_truth']['sequence_gt_I'] (torch.Tensor): The encoded ground truth token level
            sequence [L,]
        ['ground_truth']['sequence_valid_mask'] (torch.Tensor): Boolean mask [L,]; False where the
            token's residue name is one of ``_INVALID_RES_NAMES``, True otherwise.
    """

    # Residue names excluded from ``sequence_valid_mask``.
    _INVALID_RES_NAMES = ["UNK", "X", "DX", "<G>"]

    def __init__(self, sequence_encoding):
        # Must expose ``encode(res_names)`` returning a numpy array (consumed by torch.from_numpy).
        self.sequence_encoding = sequence_encoding

    def check_input(self, data):
        check_contains_keys(data, ["atom_array"])

    def forward(self, data: dict) -> dict:
        atom_array = data["atom_array"]
        res_names = atom_array.res_name[get_token_starts(atom_array)]
        # Encode before touching ``data`` so a failing encode leaves ``data`` unchanged.
        restype = self.sequence_encoding.encode(res_names)

        ground_truth = data.setdefault("ground_truth", {})
        ground_truth["sequence_gt_I"] = torch.from_numpy(restype)
        ground_truth["sequence_valid_mask"] = torch.from_numpy(
            ~np.isin(res_names, self._INVALID_RES_NAMES)
        )

        return data
|
|
680
|
+
|
|
681
|
+
|
|
682
|
+
class AddAdditional1dFeaturesToFeats(Transform):
    """
    Adds any net.token_initializer.token_1d_features and
    net.diffusion_module.diffusion_atom_encoder.atom_1d_features present in the atomarray but
    not in data['feats'] to data['feats'].

    Features already in data['feats'] are left untouched. Atom features are stored per atom as
    float tensors of shape [n_atoms, n_dims]. Token features are averaged over the atoms of each
    token via ``apply_token_wise(..., np.mean, axis=0)``.

    Args:
        token_1d_features: Mapping of feature_name -> n_feature_dims (iterated with ``.items()``).
            Should be hydra interpolated from net.token_initializer.token_1d_features.
        atom_1d_features: Mapping of feature_name -> n_feature_dims (iterated with ``.items()``).
            Should be hydra interpolated from
            net.diffusion_module.diffusion_atom_encoder.atom_1d_features.
        autofill_zeros_if_not_present_in_atomarray: If True, a feature missing from both
            data['feats'] and the atomarray is filled with zeros instead of raising.
    """

    incompatible_previous_transforms = ["AddAdditional1dFeaturesToFeats"]

    def __init__(
        self,
        token_1d_features,
        atom_1d_features,
        autofill_zeros_if_not_present_in_atomarray=False,
    ):
        self.autofill = autofill_zeros_if_not_present_in_atomarray
        self.token_1d_features = token_1d_features
        self.atom_1d_features = atom_1d_features

    def check_input(self, data) -> None:
        check_contains_keys(data, ["atom_array"])
        check_is_instance(data, "atom_array", AtomArray)

    def generate_feature(self, feature_name, n_dims, data, feature_type):
        """Add one 1d feature to data['feats'] unless it is already present.

        Args:
            feature_name: Key in data['feats'] and name of the atomarray annotation.
            n_dims: Expected size of the feature's second dimension.
            data: Pipeline data dict; data['feats'] must already exist.
            feature_type: "token" to average per-atom values over each token. Any other value
                (the caller passes "atom") keeps values per atom. Also used in error messages.

        Returns:
            ``data`` with data['feats'][feature_name] set.

        Raises:
            ValueError: If the annotation cannot be shaped to [n_atoms, n_dims], or if the feature
                is absent from the atomarray while autofill is disabled.
        """
        if feature_name in data["feats"]:
            return data

        atom_array = data["atom_array"]
        if feature_name in atom_array.get_annotation_categories():
            feature_array = torch.tensor(atom_array.get_annotation(feature_name)).float()

            # ensure that feature_array is a 2d array with second dim n_dims; a 1d annotation
            # is only accepted for single-dimensional features
            if feature_array.ndim == 1 and n_dims == 1:
                feature_array = feature_array.unsqueeze(1)
            elif feature_array.ndim != 2:
                raise ValueError(
                    f"{feature_type} 1d_feature `{feature_name}` must be a 2d array, got {len(feature_array.shape)}d."
                )
            if feature_array.shape[1] != n_dims:
                raise ValueError(
                    f"{feature_type} 1d_feature `{feature_name}` dimensions in atomarray ({feature_array.shape[-1]}) does not match dimension declared in config, ({n_dims})"
                )
        elif self.autofill:
            feature_array = torch.zeros((len(atom_array), n_dims)).float()
        else:
            # not in data['feats'], not in atomarray, and autofill is off
            raise ValueError(
                f"{feature_type} 1d_feature `{feature_name}` is not present in atomarray, and autofill is False"
            )

        if feature_type == "token":
            # Reduce per-atom values to one row per token by averaging over the token's atoms.
            feature_array = torch.tensor(
                apply_token_wise(atom_array, feature_array.numpy(), np.mean, axis=0)
            ).float()

        data["feats"][feature_name] = feature_array
        return data

    def forward(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Checks if the 1d_features are present in data['feats']. If not present, adds them from the atomarray.
        If annotation is not present in atomarray, either autofills the feature with 0s or throws an error
        """
        data.setdefault("feats", {})

        for feature_name, n_dims in self.token_1d_features.items():
            data = self.generate_feature(feature_name, n_dims, data, "token")

        for feature_name, n_dims in self.atom_1d_features.items():
            data = self.generate_feature(feature_name, n_dims, data, "atom")

        return data
|
|
763
|
+
|
|
764
|
+
|
|
765
|
+
class FeaturizepLDDT(Transform):
    """
    Adds a token-level pLDDT conditioning feature as data['feats']['ref_plddt'].

    During training it provides:
        0 for unknown pLDDT (``skip`` is set, the atomarray has no ``b_factor`` annotation,
          or data['conditions']['featurize_plddt'] is falsy)
        +1 for high pLDDT (mean ``b_factor`` over all atoms >= ``bsplit``)
        -1 for low pLDDT (mean ``b_factor`` over all atoms < ``bsplit``)
    At inference the values are taken as-is from the ``ref_plddt`` annotation at each token
    start, or 0 when that annotation is absent.

    In every case the value is multiplied by the negated ``is_motif_token`` flag, so only
    non-motif (diffused) tokens carry a non-zero value.
    """

    def __init__(
        self,
        skip,
    ):
        """
        Args:
            skip: If truthy, training examples never derive pLDDT from ``b_factor`` (the feature
                is all zeros). Has no effect at inference.
        """
        self.skip = skip
        self.bsplit = 80  # Threshold for splitting pLDDT into high and low

    def check_input(self, data: dict) -> None:
        # Keys that ``forward`` reads or writes unconditionally.
        check_contains_keys(data, ["atom_array", "feats", "is_inference"])

    def forward(self, data: dict) -> dict:
        atom_array = data["atom_array"]
        token_starts = get_token_starts(atom_array)
        n_tokens = len(token_starts)
        zeros = np.zeros(n_tokens, dtype=int)

        if data["is_inference"]:
            if "ref_plddt" not in atom_array.get_annotation_categories():
                ref_plddt = zeros
            else:
                ref_plddt = atom_array.get_annotation("ref_plddt")[token_starts]
        elif (
            self.skip
            or "b_factor" not in atom_array.get_annotation_categories()
            or not data["conditions"]["featurize_plddt"]
        ):
            ref_plddt = zeros
        else:
            # One label for the whole example, from the mean b_factor over all atoms.
            mean_plddt = np.mean(atom_array.get_annotation("b_factor"))
            ref_plddt = np.full(n_tokens, 1 if mean_plddt >= self.bsplit else -1, dtype=int)

        # Provide only non-zero values for diffused tokens.
        # NOTE(review): assumes is_motif_token is a boolean array; `~` on an int array would
        # produce -1/-2 instead of 1/0 — confirm against get_motif_features.
        ref_plddt = (
            ~(get_motif_features(atom_array)["is_motif_token"][token_starts])
            * ref_plddt
        )
        data["feats"]["ref_plddt"] = ref_plddt
        return data
|