rc-foundry 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import torch
|
|
3
|
+
from atomworks.ml.transforms.base import Transform
|
|
4
|
+
from rfd3.inference.symmetry.frames import (
|
|
5
|
+
framecoords_to_RTs,
|
|
6
|
+
unpack_vector,
|
|
7
|
+
)
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class AddSymmetryFeats(Transform):
    """
    Add atom_array symmetry features to the data features.

    Copies the requested per-atom symmetry annotations from ``data["atom_array"]``
    into ``data["feats"]`` and additionally builds ``data["feats"]["sym_transform"]``,
    a dict mapping ``str(transform_id)`` to a ``(R, T)`` rotation/translation pair.

    Arguments:
        symmetry_features: Names of the atom_array symmetry annotations to add
            to the data features.
    Returns:
        data: The data with the atom_array symmetry features added to the data features.
    """

    def __init__(
        self,
        symmetry_features=(
            "sym_transform_id",
            "sym_entity_id",
            "is_sym_asu",
        ),
    ):
        # Immutable (tuple) default avoids the mutable-default-argument pitfall;
        # copy into a list so a caller-supplied sequence is never shared.
        self.symmetry_feats = list(symmetry_features)

    def forward(self, data):
        atom_array = data["atom_array"]
        # Decompose the per-atom frame annotations into per-transform (R, T) pairs
        transforms_dict = self.make_transforms_dict(atom_array)
        data["feats"]["sym_transform"] = transforms_dict  # {str(id): tuple (R,T)}
        # Copy the requested symmetry annotations atom-wise
        for feature_name in self.symmetry_feats:
            feature_array = atom_array.get_annotation(feature_name)
            data["feats"][feature_name] = feature_array
        return data

    def make_transforms_dict(self, atom_array):
        """
        Build ``{str(transform_id): (R, T)}`` from the vectorized frame annotations.

        Unpacks the packed frame coordinates (origin, x-point, y-point) stored
        per atom, deduplicates consecutive rows (annotations repeat for every
        atom of a transform), and converts each frame into a rotation/translation
        pair. Transforms with id ``-1`` are skipped.
        """

        def _unpacked_tensor(annotation_name):
            # Unpack each vectorized frame annotation row into a plain list,
            # then stack into a single tensor.
            return torch.tensor(
                [
                    np.asarray(unpack_vector(packed)).tolist()
                    for packed in atom_array.get_annotation(annotation_name)
                ]
            )

        # get decomposed frames from atom array (unpacking the vectorized frames)
        Oris = _unpacked_tensor("sym_transform_Ori")
        Xs = _unpacked_tensor("sym_transform_X")
        Ys = _unpacked_tensor("sym_transform_Y")
        TIDs = torch.from_numpy(atom_array.get_annotation("sym_transform_id"))

        # Annotations repeat for every atom of a transform; keep one row each.
        # NOTE(review): assumes consecutive duplicates align across all four
        # annotations (i.e. each transform's atoms are contiguous) — confirm.
        Oris = torch.unique_consecutive(Oris, dim=0)
        Xs = torch.unique_consecutive(Xs, dim=0)
        Ys = torch.unique_consecutive(Ys, dim=0)
        TIDs = torch.unique_consecutive(TIDs, dim=0)

        # the case in which there is only rotation (no translation), Ori = [0,0,0];
        # broadcast the single all-zero origin row to one origin per transform.
        if len(Oris) == 1 and (Oris == 0).all():
            Oris = Oris.repeat(len(Xs), 1)
        Rs, Ts = framecoords_to_RTs(Oris, Xs, Ys)

        transforms_dict = {}
        for R, T, transform_id in zip(Rs, Ts, TIDs):
            if transform_id.item() == -1:
                continue
            transforms_dict[str(transform_id.item())] = (R, T)
        return transforms_dict
|
|
@@ -0,0 +1,552 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Class-based motif masking system
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import logging
|
|
6
|
+
from abc import ABC, abstractmethod
|
|
7
|
+
|
|
8
|
+
import networkx as nx
|
|
9
|
+
import numpy as np
|
|
10
|
+
from atomworks.ml.utils.token import (
|
|
11
|
+
apply_token_wise,
|
|
12
|
+
get_token_starts,
|
|
13
|
+
spread_token_wise,
|
|
14
|
+
)
|
|
15
|
+
from biotite.structure import AtomArray, get_residue_starts
|
|
16
|
+
from rfd3.transforms.conditioning_utils import (
|
|
17
|
+
random_condition,
|
|
18
|
+
sample_island_tokens,
|
|
19
|
+
sample_subgraph_atoms,
|
|
20
|
+
)
|
|
21
|
+
|
|
22
|
+
# networkx >= 3.0 removed `from_numpy_matrix`; alias it to `from_numpy_array`
# so code (and dependencies) written against the old API keeps working.
nx.from_numpy_matrix = nx.from_numpy_array
logger = logging.getLogger(__name__)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
#################################################################################
|
|
27
|
+
# Transform for creating training conditions
|
|
28
|
+
#################################################################################
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class TrainingCondition(ABC):
    """
    Base class for applying conditioning during training.

    Subclasses decide which parts of an example are treated as motif
    (conditioning) versus diffused, and how often the condition is applied.
    """

    # Identifier for this condition; concrete subclasses set a string value.
    name = None

    def __init__(self, frequency):
        # Relative frequency/probability with which this condition is sampled.
        self.frequency = frequency

    @abstractmethod
    def is_valid_for_example(self, data) -> bool:
        """
        Returns true whether this mask can be applied to the data instance.

        E.g. only use this transform if data metadata contains key or if data contains type
        """

    @abstractmethod
    def sample(self, data) -> AtomArray:
        """
        Set which atoms should be made into tokens.
        """
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class IslandCondition(TrainingCondition):
    """
    Select islands as motif and assign conditioning strategies.

    Non-protein tokens are always treated as motif; protein tokens are sampled
    island-wise via ``sample_island_tokens``. Atom-level selection can then
    restrict motif residues to backbone atoms or to sampled bond subgraphs.
    """

    def __init__(
        self,
        *,
        name,
        frequency,
        island_sampling_kwargs,
        p_diffuse_motif_sidechains,
        p_diffuse_subgraph_atoms,
        subgraph_sampling_kwargs,
        p_fix_motif_coordinates,
        p_fix_motif_sequence,
        p_unindex_motif_tokens,
        p_include_oxygen_in_backbone_mask=0.95,
    ):
        self.name = name
        self.frequency = frequency

        # Token selection
        self.island_sampling_kwargs = island_sampling_kwargs

        # Atom selection
        self.p_diffuse_motif_sidechains = p_diffuse_motif_sidechains
        # Was previously hard-coded to 0.95; now configurable (default unchanged).
        self.p_include_oxygen_in_backbone_mask = p_include_oxygen_in_backbone_mask
        self.p_diffuse_subgraph_atoms = p_diffuse_subgraph_atoms
        self.subgraph_sampling_kwargs = subgraph_sampling_kwargs

        # Additional conditioning selection
        self.p_fix_motif_coordinates = p_fix_motif_coordinates
        self.p_fix_motif_sequence = p_fix_motif_sequence
        self.p_unindex_motif_tokens = p_unindex_motif_tokens

    def is_valid_for_example(self, data) -> bool:
        # Island sampling only makes sense when at least one protein atom exists.
        return bool(np.any(data["atom_array"].is_protein))

    def sample_motif_tokens(self, atom_array):
        """
        Samples what tokens should be considered motif.

        Non-protein tokens are always motif; protein tokens are selected
        island-wise. Returns an atom-wise boolean array (token flags spread
        back onto atoms).
        """
        token_level_array = atom_array[get_token_starts(atom_array)]

        # initialize motif tokens as all non-protein tokens
        is_motif_token = np.asarray(~token_level_array.is_protein, dtype=bool).copy()
        n_protein_tokens = np.sum(token_level_array.is_protein)
        islands_mask = sample_island_tokens(
            n_protein_tokens,
            **self.island_sampling_kwargs,
        )
        is_motif_token[token_level_array.is_protein] = islands_mask

        # TODO: Atoms with covalent bonds should be motif, needs FlagAndReassignCovalentModifications transform prior to this
        # atom_with_coval_bond = token_level_array.covale # (n_atoms, )
        # is_motif_token[atom_with_coval_bond] = True

        return spread_token_wise(atom_array, is_motif_token)

    def sample_motif_atoms(self, atom_array):
        """
        Samples which atoms in motif tokens should be masked.
        This handles the case where you want the sidechain of a residue to not be motif.

        Argument attrs:
            - is_motif_token

        Reads annotations: ``is_motif_token``, ``atom_name``, ``occupancy``.
        """
        is_motif_atom = np.asarray(atom_array.is_motif_token, dtype=bool).copy()

        # NOTE: branch order matters for RNG-stream stability — the second
        # random draw is only made when the first branch is not taken.
        if random_condition(self.p_diffuse_motif_sidechains):
            backbone_atoms = ["N", "C", "CA"]
            if random_condition(self.p_include_oxygen_in_backbone_mask):
                backbone_atoms.append("O")
            is_motif_atom = is_motif_atom & np.isin(
                atom_array.atom_name, backbone_atoms
            )
        elif random_condition(self.p_diffuse_subgraph_atoms):
            is_motif_atom = sample_motif_subgraphs(
                atom_array=atom_array,
                **self.subgraph_sampling_kwargs,
            )

        # We also only want resolved atoms to be motif
        is_motif_atom = is_motif_atom & (atom_array.occupancy > 0.0)

        return is_motif_atom

    def sample(self, data):
        """Attach all motif / conditioning annotations for one training example."""
        atom_array = data["atom_array"]

        atom_array.set_annotation(
            "is_motif_token", self.sample_motif_tokens(atom_array)
        )
        atom_array.set_annotation("is_motif_atom", self.sample_motif_atoms(atom_array))

        # After selecting the motif, we need to decide what conditioning strategy to apply
        atom_array = sample_conditioning_strategy(
            atom_array,
            p_fix_motif_sequence=self.p_fix_motif_sequence,
            p_fix_motif_coordinates=self.p_fix_motif_coordinates,
            p_unindex_motif_tokens=self.p_unindex_motif_tokens,
        )

        atom_array.set_annotation(
            "is_motif_atom_unindexed_motif_breakpoint",
            sample_unindexed_breaks(
                atom_array,
                remove_random_break=data["conditions"]["unindex_remove_random_break"],
                insert_random_break=data["conditions"]["unindex_insert_random_break"],
                leak_global_index=data["conditions"]["unindex_leak_global_index"],
            ),
        )

        return atom_array
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
class PPICondition(TrainingCondition):
    """Get condition indicating what is motif and what is to be diffused for protein-protein interaction training."""

    name = "ppi"

    def is_valid_for_example(self, data):
        # NOTE(review): this validity check is not side-effect free — it caches
        # the query pn_units on `self`, stores the chosen binder in `data`, and
        # annotates the atom array. `sample` relies on those side effects, so
        # it must be called after (and only after) this method returns True.
        # Extract relevant data
        atom_array = data["atom_array"]
        self.query_pn_unit_iids = data.get("query_pn_unit_iids")

        # Compute protein pn_unit_iids
        protein_pn_unit_iids = []
        for pn_unit_iid in np.unique(atom_array.pn_unit_iid):
            pn_unit_atom_array = atom_array[atom_array.pn_unit_iid == pn_unit_iid]
            pn_unit_is_protein = np.unique(pn_unit_atom_array.is_protein)

            if all(pn_unit_is_protein):  # Exclude cases of chimeric ligands
                protein_pn_unit_iids.append(pn_unit_iid)

        # This mask is intended to operate on binary protein-protein interfaces:
        # exactly two distinct query pn_units are required.
        if (
            self.query_pn_unit_iids is None
            or len(self.query_pn_unit_iids) != 2
            or len(np.unique(self.query_pn_unit_iids)) != 2
        ):
            return False

        # Both query pn_units must be pure-protein chains.
        elif not all(
            [pn_unit in protein_pn_unit_iids for pn_unit in self.query_pn_unit_iids]
        ):
            return False

        else:
            # Randomly select one of the two query pn_unit_iids to be the binder
            # NOTE: Could also do this based on if only one will work uncropped, but since that
            # strategy will not always be applied, enforcing it here would bias the training data.
            binder_pn_unit = np.random.choice(self.query_pn_unit_iids)
            data["binder_pn_unit"] = binder_pn_unit
            atom_array.set_annotation(
                "is_binder_pn_unit", atom_array.pn_unit_iid == binder_pn_unit
            )
            return True

    # TODO: If I want to have multiple possible strategies for motif assignment (e.g. motif scaffolding for the binder)
    # should probably just have this function sample between them with a set of probabilities specified in the config.
    # Anything that makes it this far will have to be a valid PPI example with an assigned binder chain.
    def sample(self, data):
        atom_array = data["atom_array"]

        # Set `is_motif_token`: everything except the binder chain is motif.
        # NOTE: In the future, we may want to diffuse part of the target or fix part of the binder
        is_motif_token = atom_array.pn_unit_iid != data["binder_pn_unit"]
        atom_array.set_annotation("is_motif_token", is_motif_token)

        # Set `is_motif_atom_with_fixed_seq`
        is_motif_atom_with_fixed_seq = (
            is_motif_token.copy()
        )  # We fix the target sequence in binder design
        atom_array.set_annotation(
            "is_motif_atom_with_fixed_seq", is_motif_atom_with_fixed_seq
        )

        # Set `is_motif_atom`
        is_motif_atom = (
            is_motif_token.copy()
        )  # The PPI mask should apply to all or no atoms of a token
        atom_array.set_annotation("is_motif_atom", is_motif_atom)

        # Set `is_motif_atom_with_fixed_pos`
        is_motif_atom_with_fixed_coord = (
            is_motif_token.copy()
        )  # We fully fix the target atom positions (at least for now)
        atom_array.set_annotation(
            "is_motif_atom_with_fixed_coord", is_motif_atom_with_fixed_coord
        )

        # Set `is_motif_atom_without_index` (all False: zeros_like of a bool mask)
        is_motif_atom_unindexed = np.zeros_like(
            is_motif_token
        )  # We want fixed indices for the target
        atom_array.set_annotation("is_motif_atom_unindexed", is_motif_atom_unindexed)

        # Set `is_motif_atom_unindexed_motif_breakpoint` (all False as well)
        is_motif_atom_unindexed_motif_breakpoint = np.zeros_like(is_motif_token)
        atom_array.set_annotation(
            "is_motif_atom_unindexed_motif_breakpoint",
            is_motif_atom_unindexed_motif_breakpoint,
        )
        return atom_array
|
|
265
|
+
|
|
266
|
+
|
|
267
|
+
##############################################################################################
|
|
268
|
+
# Additional conditioning classes
|
|
269
|
+
##############################################################################################
|
|
270
|
+
|
|
271
|
+
|
|
272
|
+
class SubtypeCondition(TrainingCondition):
    """
    Selects specific subtypes of atoms as motif and assigns conditioning strategies.

    Every (resolved) atom of the configured subtype annotations becomes motif
    with fixed sequence; coordinates are fixed only when ``fix_pos`` is set,
    and no unindexing is ever applied.
    """

    name = "subtype"

    def __init__(self, frequency: float, subtype: list[str], fix_pos: bool = False):
        self.frequency = frequency
        self.subtype = subtype
        self.fix_pos = fix_pos

    def is_valid_for_example(self, data):
        """
        For subtype conditioning, the example must contain the specified subtype.
        """
        annotation_counts = [
            data["atom_array"].get_annotation(name).sum() for name in self.subtype
        ]
        return bool(np.any(annotation_counts))

    def sample(self, data):
        atom_array = data["atom_array"]

        # Motif = all resolved atoms of the requested subtypes.
        motif_mask = generate_subtype_mask(atom_array, self.subtype)
        motif_mask = prune_unresolved_motif(atom_array, motif_mask)

        atom_array.set_annotation("is_motif_token", motif_mask)
        atom_array.set_annotation("is_motif_atom", motif_mask)
        atom_array.set_annotation("is_motif_atom_with_fixed_seq", motif_mask)

        # Coordinates are only fixed when explicitly configured.
        if self.fix_pos:
            fixed_coord = motif_mask
        else:
            fixed_coord = np.zeros(len(atom_array), dtype=bool)
        atom_array.set_annotation("is_motif_atom_with_fixed_coord", fixed_coord)

        # This condition never unindexes tokens.
        atom_array.set_annotation(
            "is_motif_atom_unindexed", np.zeros(len(atom_array), dtype=bool)
        )
        atom_array.set_annotation(
            "is_motif_atom_unindexed_motif_breakpoint",
            np.zeros(len(atom_array), dtype=bool),
        )

        return atom_array
|
|
319
|
+
|
|
320
|
+
|
|
321
|
+
################# need mask -> condition refactor
|
|
322
|
+
def prune_unresolved_motif(atom_array, mask):
    """
    Prune the mask to only include resolved atoms, all-or-nothing per token.

    An atom stays True only if it is resolved (occupancy > 0) AND selected by
    ``mask``; additionally, if ANY atom of a token fails that combined test,
    the whole token is set to False.

    Args:
        atom_array: Array-like with ``occupancy`` and ``token_id`` annotations.
        mask: Boolean array over atoms.

    Returns:
        Boolean array over atoms (same shape as ``mask``).
    """
    # Restrict to resolved atoms
    combined_mask = mask & (atom_array.occupancy > 0.0)

    # Vectorized per-token all-reduce (replaces an O(tokens * atoms) loop):
    # a token survives only if every one of its atoms survived above.
    _, token_index = np.unique(atom_array.token_id, return_inverse=True)
    token_fully_kept = np.ones(token_index.max() + 1 if len(token_index) else 0, dtype=bool)
    np.logical_and.at(token_fully_kept, token_index, combined_mask)

    return combined_mask & token_fully_kept[token_index]
|
|
343
|
+
|
|
344
|
+
|
|
345
|
+
def generate_subtype_mask(atom_array, subtypes):
    """
    Generate a mask for a specific subtype list of atoms.

    Each entry of ``subtypes`` names a boolean annotation on the atom array
    (e.g. ``is_protein``, ``is_ligand``, ``is_dna``); the per-subtype masks
    are OR-combined into a single atom-wise mask.

    Raises:
        ValueError: If a requested subtype annotation does not exist.
    """
    available = atom_array.get_annotation_categories()
    collected = []
    for subtype in subtypes:
        if subtype not in available:
            raise ValueError(f"Subtype {subtype} not found in atom array annotations.")
        collected.append(atom_array.get_annotation(subtype))
    # OR-combine all subtype masks
    return np.logical_or.reduce(collected)
|
|
359
|
+
|
|
360
|
+
|
|
361
|
+
##############################################################################################
|
|
362
|
+
# Shared assignment functions
|
|
363
|
+
##############################################################################################
|
|
364
|
+
|
|
365
|
+
|
|
366
|
+
def sample_motif_subgraphs(
    atom_array,
    residue_p_seed_furthest_from_o,
    residue_n_bond_expectation,
    hetatom_n_bond_expectation,
    residue_p_fix_all,
    hetatom_p_fix_all,
):
    """
    Returns a boolean mask over atoms, indicating which atoms are part of the sampled motif.
    Sampling is performed per residue, with sidechains optionally excluded based on bond-based neighborhood sampling.

    Handles both protein residues and heteroatoms (e.g., ligands).

    Args:
        atom_array: AtomArray with annotations is_motif_token, is_protein, occupancy, res_id.
        residue_p_seed_furthest_from_o: For protein residues, probability of seeding
            the subgraph at the atom furthest from the backbone O.
        residue_n_bond_expectation: Expected subgraph bond count for residues.
        hetatom_n_bond_expectation: Expected subgraph bond count for heteroatoms.
        residue_p_fix_all: Probability of keeping all atoms of a residue.
        hetatom_p_fix_all: Probability of keeping all atoms of a heteroatom group.

    Returns:
        is_motif_atom: np.ndarray of shape (n_atoms,) with True for sampled motif atoms.
    """
    is_motif_token = atom_array.is_motif_token.copy()
    is_motif_atom = is_motif_token.copy()
    starts = get_residue_starts(atom_array, add_exclusive_stop=True)

    for start, end in zip(starts[:-1], starts[1:]):
        # Tokens are residue-aligned here, so checking the first atom suffices.
        if not is_motif_token[start]:
            continue

        # Fancy indexing (instead of the previous O(n) `np.isin` over an arange)
        # so the subset is still an independent copy of the residue's atoms.
        atom_array_subset = atom_array[np.arange(start, end)]
        assert atom_array_subset.array_length() > 0

        args = {
            "p_seed_furthest_from_o": residue_p_seed_furthest_from_o,
            "n_bond_expectation": residue_n_bond_expectation,
            "p_fix_all": residue_p_fix_all,
        }
        if not atom_array_subset.is_protein.all():
            # Heteroatoms: no backbone O to seed from; use hetatom parameters.
            args.update(
                {
                    "p_seed_furthest_from_o": 0.0,
                    "n_bond_expectation": hetatom_n_bond_expectation,
                    "p_fix_all": hetatom_p_fix_all,
                }
            )
        try:
            mask = sample_subgraph_atoms(atom_array_subset, **args)
        except Exception as e:
            # Best-effort fallback: keep the whole residue as motif.
            logger.warning(
                "Failed to sample subgraph motif atoms for %s. Error: %s",
                atom_array_subset.res_name[0],
                e,
            )
            mask = np.ones(atom_array_subset.array_length(), dtype=bool)

        is_motif_atom[start:end] = mask

    # We also only want resolved atoms to be motif
    is_motif_atom = is_motif_atom & (atom_array.occupancy > 0.0)

    return is_motif_atom
|
|
427
|
+
|
|
428
|
+
|
|
429
|
+
def sample_conditioning_strategy(
    atom_array,
    p_fix_motif_sequence,
    p_fix_motif_coordinates,
    p_unindex_motif_tokens,
):
    """
    Sample and attach the three motif-conditioning annotations.

    Decides (in this fixed order, matching the original call sequence) whether
    the motif's sequence is revealed, whether its coordinates are fixed, and
    whether its tokens are unindexed, storing each decision as an atom-wise
    annotation on ``atom_array``.
    """
    fixed_seq = sample_is_motif_atom_with_fixed_seq(
        atom_array, p_fix_motif_sequence=p_fix_motif_sequence
    )
    atom_array.set_annotation("is_motif_atom_with_fixed_seq", fixed_seq)

    fixed_coord = sample_fix_motif_coordinates(
        atom_array, p_fix_motif_coordinates=p_fix_motif_coordinates
    )
    atom_array.set_annotation("is_motif_atom_with_fixed_coord", fixed_coord)

    unindexed = sample_unindexed_atoms(
        atom_array, p_unindex_motif_tokens=p_unindex_motif_tokens
    )
    atom_array.set_annotation("is_motif_atom_unindexed", unindexed)

    return atom_array
|
|
457
|
+
|
|
458
|
+
|
|
459
|
+
def sample_is_motif_atom_with_fixed_seq(atom_array, p_fix_motif_sequence):
    """
    Decide, atom-wise, whether the motif sequence is revealed (fixed).

    With probability ``p_fix_motif_sequence`` every motif-token atom keeps its
    sequence; otherwise no protein atoms do. Non-protein atoms always have
    their identity revealed.

    Reads annotations: ``is_motif_token``, ``is_protein``.
    """
    if random_condition(p_fix_motif_sequence):
        fixed_seq = atom_array.is_motif_token.copy()
    else:
        fixed_seq = np.zeros(atom_array.array_length(), dtype=bool)

    # By default reveal sequence for non-protein entities (ligands etc.)
    return fixed_seq | ~atom_array.is_protein
|
|
474
|
+
|
|
475
|
+
|
|
476
|
+
def sample_fix_motif_coordinates(atom_array, p_fix_motif_coordinates):
    """
    Universal function to decide if atoms' coords are fixed in the point cloud
    for conditioning.

    With probability ``p_fix_motif_coordinates`` every motif atom's position is
    fixed; otherwise none are.

    Reads annotations: ``is_motif_atom``.
    """
    if random_condition(p_fix_motif_coordinates):
        return atom_array.is_motif_atom.copy()
    return np.zeros(atom_array.array_length(), dtype=bool)
|
|
488
|
+
|
|
489
|
+
|
|
490
|
+
def sample_unindexed_atoms(atom_array, p_unindex_motif_tokens):
    """
    Samples which atoms in motif tokens should be flagged for unindexing.

    With probability ``p_unindex_motif_tokens`` all motif atoms are flagged;
    otherwise none. Non-residue atoms are never flagged.

    Reads annotations: ``is_motif_atom``, ``is_residue``.
    """
    if random_condition(p_unindex_motif_tokens):
        unindexed = atom_array.is_motif_atom.copy()
    else:
        unindexed = np.zeros(atom_array.array_length(), dtype=bool)

    # ensure non-residue atoms are not flagged
    return np.logical_and(unindexed, atom_array.is_residue)
|
|
508
|
+
|
|
509
|
+
|
|
510
|
+
def sample_unindexed_breaks(
    atom_array,
    remove_random_break=False,
    insert_random_break=False,
    leak_global_index=False,
):
    """
    Compute atom-wise breakpoint flags between unindexed motif segments.

    A token counts as "unindexed" if any of its atoms is flagged in
    ``is_motif_atom_unindexed``. Breakpoints are seeded at residue-id
    discontinuities among the unindexed tokens and can then be randomly
    removed (merging segments), randomly inserted (splitting segments), or
    cleared entirely (leaking the global index).

    Args:
        remove_random_break: Drop one randomly-chosen existing breakpoint.
        insert_random_break: Add one breakpoint at a random gap position.
        leak_global_index: Clear all breakpoints, revealing global indexing.

    Returns:
        Atom-wise boolean array (token-level flags spread back onto atoms).
    """
    # Token is unindexed if ANY of its atoms is flagged.
    is_unindexed_token = apply_token_wise(
        atom_array,
        atom_array.is_motif_atom_unindexed.copy(),
        function=lambda x: np.any(x),
    )
    starts = get_token_starts(atom_array)
    token_idxs = np.arange(len(starts))
    breaks_all = np.zeros(len(starts), dtype=bool)

    if is_unindexed_token.sum() == 1:
        # A single unindexed token is its own breakpoint.
        breaks_all = is_unindexed_token
    elif np.any(is_unindexed_token):
        # ... Subset to unindexed tokens
        unindexed_token_starts = starts[is_unindexed_token]
        unindexed_token_resid = atom_array[unindexed_token_starts].res_id
        # breaks[i] marks a gap between unindexed tokens i and i+1.
        breaks = np.diff(unindexed_token_resid) != 1  # (M-1,)

        # ... Connect discontiguous regions
        if remove_random_break and np.any(breaks):
            break_idx = np.random.choice(np.flatnonzero(breaks), size=1, replace=False)
            breaks[break_idx] = False

        # ... Disconnect contiguous regions
        if insert_random_break:
            break_idx = np.random.choice(np.arange(len(breaks)), size=1, replace=False)
            breaks[break_idx] = True

        # NOTE(review): the first gap is unconditionally forced to be a break,
        # then a leading False is prepended to make the array token-aligned —
        # confirm the intended convention is "flag marks the token after the
        # gap" and that the forced break at position 1 (not 0) is deliberate.
        breaks[0] = True
        breaks = np.concatenate([np.array([False], dtype=bool), breaks])

        # ... Remove all breaks to leak global indices:
        if leak_global_index:
            breaks = False  # scalar broadcast below clears every flag

        breaks_all[token_idxs[is_unindexed_token]] = breaks

    return spread_token_wise(atom_array, breaks_all)
|