rc-foundry 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
|
@@ -0,0 +1,347 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This module contains the FeaturizeUserSettings transform that sets
|
|
3
|
+
mode-specific and common user features required by MPNN models.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
from atomworks.io.utils.atom_array_plus import AtomArrayPlus
|
|
10
|
+
from atomworks.io.utils.selection import get_annotation
|
|
11
|
+
from atomworks.ml.transforms._checks import (
|
|
12
|
+
check_atom_array_annotation,
|
|
13
|
+
)
|
|
14
|
+
from atomworks.ml.transforms.base import Transform
|
|
15
|
+
from atomworks.ml.utils.token import (
|
|
16
|
+
get_token_starts,
|
|
17
|
+
spread_token_wise,
|
|
18
|
+
)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class FeaturizeUserSettings(Transform):
    """
    Transform for featurizing user settings to MPNN model inputs.

    Reads optional user-provided settings from the ``data`` dict (scalar keys
    such as ``structure_noise`` / ``decode_type`` and per-token atom-array
    annotations such as ``mpnn_temperature``), fills in mode-dependent
    defaults, and writes the resolved values into ``data["input_features"]``.
    """

    def __init__(
        self,
        is_inference: bool = False,
        minimal_return: bool = False,
        train_structure_noise_default: float = 0.1,
    ):
        """
        Initialize the FeaturizeUserSettings transform.

        Args:
            is_inference (bool): Whether this is inference mode. Defaults to
                False (training mode). Controls the defaults chosen for
                structure noise, decode type, sequence-embedding
                initialization, side-chain atomization, and temperature.
            minimal_return (bool): Whether to return minimal intermediate
                features. Defaults to False.
            train_structure_noise_default (float): Default standard
                deviation of Gaussian noise to add to atomic coordinates during
                training for data augmentation. Defaults to 0.1.
        """
        self.is_inference = is_inference
        self.minimal_return = minimal_return
        self.train_structure_noise_default = train_structure_noise_default

    def check_input(self, data: dict[str, Any]) -> None:
        """Validate user settings before featurization.

        Checks, in order: (1) the ``atomize`` annotation exists on the atom
        array; (2) scalar user settings, when present, have the expected
        Python types; (3) token-wise annotations are constant within each
        token; (4) sparse 2D pair annotations reference each token pair at
        most once.

        Raises:
            TypeError: If a scalar setting has the wrong type.
            ValueError: If a per-token annotation varies within a token, or a
                token pair is referenced by more than one atom pair.
        """
        check_atom_array_annotation(data, ["atomize"])

        # Check that the scalar user settings have the correct types.
        # NOTE(review): all scalar settings are optional — a missing key or an
        # explicit None skips its check entirely.
        if data.get("structure_noise", None) is not None:
            if not isinstance(data["structure_noise"], (float, int)):
                raise TypeError("structure_noise must be a float or int")

        if data.get("decode_type", None) is not None:
            if not isinstance(data["decode_type"], str):
                raise TypeError("decode_type must be a string")

        if data.get("causality_pattern", None) is not None:
            if not isinstance(data["causality_pattern"], str):
                raise TypeError("causality_pattern must be a string")

        if (
            data.get("initialize_sequence_embedding_with_ground_truth", None)
            is not None
        ):
            if not isinstance(
                data["initialize_sequence_embedding_with_ground_truth"], bool
            ):
                raise TypeError(
                    "initialize_sequence_embedding_with_ground_truth must be a bool"
                )

        if data.get("atomize_side_chains", None) is not None:
            if not isinstance(data["atomize_side_chains"], bool):
                raise TypeError("atomize_side_chains must be a bool")

        if data.get("repeat_sample_num", None) is not None:
            if not isinstance(data["repeat_sample_num"], int):
                raise TypeError("repeat_sample_num must be an int")

        if data.get("features_to_return", None) is not None:
            if not isinstance(data["features_to_return"], dict):
                raise TypeError("features_to_return must be a dict")
            for key, value in data["features_to_return"].items():
                if not isinstance(key, str):
                    raise TypeError("features_to_return keys must be strings")
                if not isinstance(value, list):
                    raise TypeError("features_to_return values must be lists")

        # Check that the array-wide user settings are consistent across all
        # atoms in each token. Each annotation is compared against itself
        # after collapsing to one value per token and re-spreading atom-wise;
        # any mismatch means the annotation varied within a token.
        atom_array = data["atom_array"]
        token_starts = get_token_starts(atom_array)
        token_level_array = atom_array[token_starts]
        keys_to_check = [
            "mpnn_designed_residue_mask",
            "mpnn_temperature",
            "mpnn_symmetry_equivalence_group",
            "mpnn_symmetry_weight",
            "mpnn_bias",
        ]
        for key in keys_to_check:
            atom_values = get_annotation(atom_array, key)
            if atom_values is not None:
                token_values = get_annotation(token_level_array, key)
                if not np.all(
                    atom_values == spread_token_wise(atom_array, token_values)
                ):
                    raise ValueError(
                        f"All atoms in each token must have the same value for {key}"
                    )

        # Check pair keys such that token-level pairs are unique.
        pair_keys_to_check = [
            "mpnn_pair_bias",
        ]
        # Map each atom to its token index so atom-pair indices can be lifted
        # to token-pair indices.
        token_idx = spread_token_wise(atom_array, np.arange(len(token_level_array)))
        # Only validate 2D annotations if atom_array supports them
        # (plain biotite AtomArrays have no 2D-annotation API).
        if isinstance(atom_array, AtomArrayPlus):
            annotations_2d = atom_array.get_annotation_2d_categories()
            for key in pair_keys_to_check:
                if key in annotations_2d:
                    annotation = atom_array.get_annotation_2d(key)
                    pairs = annotation.pairs
                    seen_token_pairs = set()
                    # NOTE(review): (i, j) and (j, i) are treated as distinct
                    # token pairs here — confirm this asymmetry is intended.
                    for i, j in pairs:
                        token_pair = (token_idx[i], token_idx[j])
                        if token_pair in seen_token_pairs:
                            raise ValueError(
                                f"Token-level pairs must be unique for {key}"
                                " i.e. token pairs should be represented using "
                                "only one atom pair across the tokens."
                            )
                        seen_token_pairs.add(token_pair)

    def forward(self, data: dict[str, Any]) -> dict[str, Any]:
        """Apply user settings to the input features.

        Resolves each scalar setting (falling back to mode-dependent
        defaults), extracts per-token settings from the non-atomized portion
        of the atom array, and stores everything in
        ``data["input_features"]``.

        Args:
            data: Pipeline dict; must contain ``atom_array``,
                ``input_features`` (a dict, updated in place), and
                ``model_type``; optionally any of the scalar user settings.

        Returns:
            The same ``data`` dict with ``input_features`` updated.
        """
        # +-------- Scalar User Settings --------- +
        # structure_noise (float): the standard deviation of the Gaussian
        # noise to add to the input coordinates, in Angstroms.
        structure_noise = data.get("structure_noise", None)
        if structure_noise is None:
            # No noise at inference; augmentation-level noise during training.
            structure_noise = (
                0.0 if self.is_inference else self.train_structure_noise_default
            )

        # decode_type (str): the type of decoding to use.
        #   - "teacher_forcing": Use teacher forcing for the
        #     decoder, where the decoder attends to the ground
        #     truth sequence S for all previously decoded
        #     residues.
        #   - "auto_regressive": Use auto-regressive decoding,
        #     where the decoder attends to the sequence and
        #     decoder representation of residues that have
        #     already been decoded (using the predicted sequence).
        decode_type = data.get("decode_type", None)
        if decode_type is None:
            decode_type = "auto_regressive" if self.is_inference else "teacher_forcing"

        # causality_pattern (str): The pattern of causality to use for the
        # decoder. For all causality patterns, the decoding order is randomized.
        #   - "auto_regressive": Use an auto-regressive causality
        #     pattern, where residues can attend to the sequence
        #     and decoder representation of residues that have
        #     already been decoded (NOTE: as mentioned above,
        #     this will be randomized).
        #   - "unconditional": Residues cannot attend to the
        #     sequence or decoder representation of any other
        #     residues.
        #   - "conditional": Residues can attend to the sequence
        #     and decoder representation of all other residues.
        #   - "conditional_minus_self": Residues can attend to the
        #     sequence and decoder representation of all other
        #     residues, except for themselves (as destination
        #     nodes).
        causality_pattern = data.get("causality_pattern", None)
        if causality_pattern is None:
            causality_pattern = "auto_regressive"

        # initialize_sequence_embedding_with_ground_truth (bool):
        #   - True: Initialize the sequence embedding with the ground truth
        #     sequence S.
        #       - If doing auto-regressive decoding, also
        #         initialize S_sampled with the ground truth
        #         sequence S, which should only affect the
        #         application of pair bias.
        #   - False: Initialize the sequence embedding with zeros.
        #       - If doing auto-regressive decoding, initialize
        #         S_sampled with unknown residues.
        initialize_sequence_embedding_with_ground_truth = data.get(
            "initialize_sequence_embedding_with_ground_truth", None
        )
        if initialize_sequence_embedding_with_ground_truth is None:
            initialize_sequence_embedding_with_ground_truth = (
                False if self.is_inference else True
            )

        # atomize_side_chains (bool): Whether to atomize side chains of fixed
        # residues. Default depends on both the model type and the mode:
        # only ligand_mpnn training atomizes side chains by default.
        atomize_side_chains = data.get("atomize_side_chains", None)
        if atomize_side_chains is None:
            if data["model_type"] == "ligand_mpnn":
                atomize_side_chains = False if self.is_inference else True
            else:
                atomize_side_chains = False

        # repeat_sample_num (int, optional): Number of times to
        # repeat the samples along the batch dimension. If None,
        # no repetition is performed. If greater than 1, the
        # samples are repeated along the batch dimension. If
        # greater than 1, B must be 1, since repeating samples
        # along the batch dimension is not supported when more
        # than one sample is provided in the batch.
        # NOTE: default is None, so no conditional needed.
        repeat_sample_num = data.get("repeat_sample_num", None)

        # features_to_return (dict, optional): dictionary
        # determining which features to return from the model. If
        # None, return all features (including modified input
        # features, graph features, encoder features, and decoder
        # features). Otherwise, expects a dictionary with the
        # following key, value pairs:
        #   - "input_features": list - the input features to return.
        #   - "graph_features": list - the graph features to return.
        #   - "encoder_features": list - the encoder features to
        #     return.
        #   - "decoder_features": list - the decoder features to
        #     return.
        features_to_return = data.get("features_to_return", None)
        if features_to_return is None:
            # When minimal_return is False, features_to_return stays None,
            # which downstream interprets as "return everything".
            if self.minimal_return:
                features_to_return = {
                    "input_features": [
                        "mask_for_loss",
                    ],
                    "decoder_features": ["log_probs", "S_sampled", "S_argmax"],
                }

        # Save the scalar settings.
        data["input_features"].update(
            {
                "structure_noise": structure_noise,
                "decode_type": decode_type,
                "causality_pattern": causality_pattern,
                "initialize_sequence_embedding_with_ground_truth": initialize_sequence_embedding_with_ground_truth,
                "atomize_side_chains": atomize_side_chains,
                "repeat_sample_num": repeat_sample_num,
                "features_to_return": features_to_return,
            }
        )

        # +-------- Array-Wide User Settings --------- +
        # Extract atom array.
        atom_array = data["atom_array"]

        # Subset to non-atomized. Per-token MPNN settings only apply to
        # residue-level (non-atomized) tokens.
        non_atomized_array = atom_array[~atom_array.atomize]

        # Get token starts for non-atomized tokens.
        non_atomized_token_starts = get_token_starts(non_atomized_array)
        non_atomized_token_level = non_atomized_array[non_atomized_token_starts]

        # Project token indices for non-atomized tokens.
        non_atomized_token_idx = spread_token_wise(
            non_atomized_array, np.arange(len(non_atomized_token_level))
        )

        # Each annotation below is optional: absent annotations yield None so
        # downstream code can distinguish "unset" from an explicit value.
        if get_annotation(non_atomized_array, "mpnn_designed_residue_mask") is not None:
            designed_residue_mask = (
                non_atomized_token_level.mpnn_designed_residue_mask.astype(np.bool_)
            )
        else:
            designed_residue_mask = None

        if get_annotation(non_atomized_array, "mpnn_temperature") is not None:
            temperature = non_atomized_token_level.mpnn_temperature.astype(np.float32)
        else:
            if self.is_inference:
                # Default sampling temperature of 0.1 per token at inference.
                temperature = 0.1 * np.ones(
                    len(non_atomized_token_level), dtype=np.float32
                )
            else:
                temperature = None

        if (
            get_annotation(non_atomized_array, "mpnn_symmetry_equivalence_group")
            is not None
        ):
            symmetry_equivalence_group = (
                non_atomized_token_level.mpnn_symmetry_equivalence_group.astype(
                    np.int32
                )
            )
        else:
            symmetry_equivalence_group = None

        if get_annotation(non_atomized_array, "mpnn_symmetry_weight") is not None:
            symmetry_weight = non_atomized_token_level.mpnn_symmetry_weight.astype(
                np.float32
            )
        else:
            symmetry_weight = None

        if get_annotation(non_atomized_array, "mpnn_bias") is not None:
            bias = non_atomized_token_level.mpnn_bias.astype(np.float32)
        else:
            bias = None

        # Densify the sparse per-atom-pair bias into a per-token-pair tensor.
        # NOTE(review): assumes sparse ``values`` is indexed
        # (pair, dim1, dim2) so the dense tensor is
        # (n_tokens, dim1, n_tokens, dim2) — confirm against the
        # AtomArrayPlus 2D-annotation contract.
        if (
            isinstance(non_atomized_array, AtomArrayPlus)
            and "mpnn_pair_bias" in non_atomized_array.get_annotation_2d_categories()
        ):
            pair_bias_sparse = non_atomized_array.get_annotation_2d("mpnn_pair_bias")
            pair_bias = np.zeros(
                (
                    len(non_atomized_token_level),
                    pair_bias_sparse.values.shape[1],
                    len(non_atomized_token_level),
                    pair_bias_sparse.values.shape[2],
                ),
                dtype=np.float32,
            )
            for idx in range(pair_bias_sparse.values.shape[0]):
                i, j, pair_bias_ij = pair_bias_sparse[idx]
                token_idx_i = non_atomized_token_idx[i]
                token_idx_j = non_atomized_token_idx[j]
                pair_bias[token_idx_i, :, token_idx_j, :] = pair_bias_ij

        else:
            pair_bias = None

        # Save the array-wide settings.
        data["input_features"].update(
            {
                "designed_residue_mask": designed_residue_mask,
                "temperature": temperature,
                "symmetry_equivalence_group": symmetry_equivalence_group,
                "symmetry_weight": symmetry_weight,
                "bias": bias,
                "pair_bias": pair_bias,
            }
        )

        return data
|
|
@@ -0,0 +1,164 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Utilities for computing polymer-ligand interface atoms.
|
|
3
|
+
|
|
4
|
+
This module provides a transform to identify and annotate polymer atoms that
|
|
5
|
+
are at the interface with ligand molecules, defined as atoms within a specified
|
|
6
|
+
distance threshold.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from typing import Any
|
|
10
|
+
|
|
11
|
+
import numpy as np
|
|
12
|
+
from atomworks.ml.transforms._checks import check_atom_array_annotation
|
|
13
|
+
from atomworks.ml.transforms.base import Transform
|
|
14
|
+
from biotite.structure import AtomArray, CellList
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class ComputePolymerLigandInterface(Transform):
    """
    Compute polymer and ligand atoms at the polymer-ligand interface and
    annotate the atom array with interface labels.

    An interface atom is defined as any polymer atom that is within the
    distance_threshold of any ligand atom, or vice versa. The result is a
    boolean ``at_polymer_ligand_interface`` annotation on the atom array;
    the input array is copied, never mutated.

    Args:
        distance_threshold (float): Maximum distance in Angstroms for
            considering atoms to be at the interface.
    """

    def __init__(self, distance_threshold: float):
        self.distance_threshold = distance_threshold

    def check_input(self, data: dict[str, Any]) -> None:
        """Check that required annotations are present."""
        check_atom_array_annotation(
            {"atom_array": data["atom_array"]}, required=["element", "atomize"]
        )

    def forward(self, data: dict[str, Any]) -> dict[str, Any]:
        """Compute polymer-ligand interface and update atom array.

        Args:
            data: Pipeline dict containing an ``atom_array``.

        Returns:
            The same dict with ``atom_array`` replaced by an annotated copy.
        """
        atom_array = data["atom_array"]

        # Create a copy to avoid modifying the original.
        result_array = atom_array.copy()

        # Identify polymer and ligand atoms (valid coordinates only).
        polymer_mask, ligand_mask = self._identify_polymer_and_ligand_atoms(
            result_array
        )

        # If no valid atoms on either side, there is no interface.
        if not np.any(polymer_mask) or not np.any(ligand_mask):
            # If no polymer or ligand atoms found, return empty annotations.
            result_array.set_annotation(
                "at_polymer_ligand_interface",
                np.zeros(result_array.array_length(), dtype=bool),
            )
        else:
            # Extract coordinates for interface calculation
            polymer_atoms = result_array[polymer_mask]
            ligand_atoms = result_array[ligand_mask]

            # Compute interface atoms using efficient spatial search.
            (polymer_interface_indices, ligand_interface_indices) = (
                self._compute_interface_atoms(
                    polymer_atoms,
                    ligand_atoms,
                    polymer_mask,
                    ligand_mask,
                    self.distance_threshold,
                )
            )

            # Annotate the atom array with interface information
            result_array = self._annotate_interface_results(
                result_array, polymer_interface_indices, ligand_interface_indices
            )

        data["atom_array"] = result_array
        return data

    def _identify_polymer_and_ligand_atoms(
        self, atom_array: AtomArray
    ) -> tuple[np.ndarray, np.ndarray]:
        """Identify polymer and ligand atoms in the atom array.

        Atomized atoms are treated as ligand; non-atomized atoms as polymer.
        Atoms with any NaN coordinate component are excluded from both masks.

        Returns:
            Tuple of boolean masks ``(polymer_mask, ligand_mask)`` over the
            full atom array.
        """
        # An atom is usable only when *no* coordinate component is NaN.
        # (Bug fix: the previous ``(~np.isnan(coord)).any(axis=1)`` accepted
        # atoms with a mix of NaN and finite components, and NaN coordinates
        # make the CellList distance search undefined.)
        has_valid_coords = ~np.isnan(atom_array.coord).any(axis=1)

        ligand_mask = atom_array.atomize & has_valid_coords
        polymer_mask = ~atom_array.atomize & has_valid_coords

        return polymer_mask, ligand_mask

    def _compute_interface_atoms(
        self,
        polymer_atoms: AtomArray,
        ligand_atoms: AtomArray,
        polymer_mask: np.ndarray,
        ligand_mask: np.ndarray,
        distance_threshold: float,
    ) -> tuple[np.ndarray, np.ndarray]:
        """
        Compute interface atoms using spatial data structures.

        Returns:
            Tuple containing:
                - polymer_indices: Global indices of polymer atoms at interface
                - ligand_indices: Global indices of ligand atoms at interface
        """
        # Build CellList for ligand atoms
        ligand_cell_list = CellList(ligand_atoms, cell_size=distance_threshold)

        # Find polymer atoms within threshold of any ligand.
        polymer_at_interface_mask = ligand_cell_list.get_atoms(
            polymer_atoms.coord, distance_threshold, as_mask=True
        )
        polymer_interface_local_indices = np.where(
            np.any(polymer_at_interface_mask, axis=1)
        )[0]

        # Convert local indices to global indices.
        global_polymer_indices = np.where(polymer_mask)[0]
        polymer_interface_indices = global_polymer_indices[
            polymer_interface_local_indices
        ]

        # Build CellList for polymer atoms.
        polymer_cell_list = CellList(polymer_atoms, cell_size=distance_threshold)

        # Find ligand atoms within threshold of any polymer.
        ligand_at_interface_mask = polymer_cell_list.get_atoms(
            ligand_atoms.coord, distance_threshold, as_mask=True
        )
        ligand_interface_local_indices = np.where(
            np.any(ligand_at_interface_mask, axis=1)
        )[0]

        # Convert local indices to global indices.
        global_ligand_indices = np.where(ligand_mask)[0]
        ligand_interface_indices = global_ligand_indices[ligand_interface_local_indices]

        return (polymer_interface_indices, ligand_interface_indices)

    def _annotate_interface_results(
        self,
        atom_array: AtomArray,
        polymer_interface_indices: np.ndarray,
        ligand_interface_indices: np.ndarray,
    ) -> AtomArray:
        """Annotate the atom array with interface calculation results.

        Sets a boolean ``at_polymer_ligand_interface`` annotation that is
        True exactly at the given global indices.
        """
        n_atoms = atom_array.array_length()

        # Initialize interface annotations.
        at_polymer_ligand_interface = np.zeros(n_atoms, dtype=bool)

        # Mark interface atoms.
        at_polymer_ligand_interface[polymer_interface_indices] = True
        at_polymer_ligand_interface[ligand_interface_indices] = True

        # Add annotation to atom array.
        atom_array.set_annotation(
            "at_polymer_ligand_interface", at_polymer_ligand_interface
        )
        return atom_array
|