rc-foundry 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
mpnn/trainers/mpnn.py
ADDED
@@ -0,0 +1,193 @@
import torch
from beartype.typing import Any
from lightning_utilities import apply_to_collection
from mpnn.loss.nll_loss import LabelSmoothedNLLLoss
from mpnn.metrics.nll import NLL, InterfaceNLL
from mpnn.metrics.sequence_recovery import (
    InterfaceSequenceRecovery,
    SequenceRecovery,
)
from mpnn.model.mpnn import LigandMPNN, ProteinMPNN
from omegaconf import DictConfig

from foundry.metrics.metric import MetricManager
from foundry.trainers.fabric import FabricTrainer
from foundry.utils.ddp import RankedLogger
from foundry.utils.torch import assert_no_nans

ranked_logger = RankedLogger(__name__, rank_zero_only=True)


class MPNNTrainer(FabricTrainer):
    """Standard Trainer for MPNN-style models"""

    def __init__(
        self,
        *,
        model_type: str,
        loss: DictConfig | dict | None = None,
        metrics: DictConfig | dict | None = None,
        **kwargs,
    ):
        """
        See `FabricTrainer` for the additional initialization arguments.

        Args:
            model_type (str): Type of model to use ("protein_mpnn" or
                "ligand_mpnn").
            loss (DictConfig | dict | None): Configuration for the loss
                function. If None, default parameters will be used.
            metrics (DictConfig | dict | None): Configuration for the metrics.
                Ignored - metrics are hard-coded.
        """
        super().__init__(**kwargs)

        self.model_type = model_type

        # Metrics
        metrics = {
            "nll": NLL(),
            "sequence_recovery": SequenceRecovery(),
        }
        if self.model_type == "ligand_mpnn":
            metrics["interface_nll"] = InterfaceNLL()
            metrics["interface_sequence_recovery"] = InterfaceSequenceRecovery()
        self.metrics = MetricManager(metrics)

        # Loss
        loss_params = loss if loss else {}
        self.loss = LabelSmoothedNLLLoss(**loss_params)

    def construct_model(self):
        """Construct the model with hard-coded parameters."""
        with self.fabric.init_module():
            ranked_logger.info(f"Instantiating {self.model_type} model...")

            # Hard-coded model selection
            if self.model_type == "protein_mpnn":
                model = ProteinMPNN()
            elif self.model_type == "ligand_mpnn":
                model = LigandMPNN()
            else:
                raise ValueError(f"Invalid model type: {self.model_type}")

            # Initialize model weights
            model.apply(model.init_weights)

        self.initialize_or_update_trainer_state({"model": model})

    def training_step(
        self,
        batch: Any,
        batch_idx: int,
        is_accumulating: bool,
    ) -> None:
        """
        Training step, running forward and backward passes.

        Args:
            batch (Any): The current batch; can be of any form.
            batch_idx (int): The index of the current batch.
            is_accumulating (bool): Whether we are accumulating gradients
                (i.e., not yet calling optimizer.step()). If this is the case,
                we should skip the synchronization during the backward pass.

        Returns:
            None; we call `loss.backward()` directly, and store the outputs in
            `self._current_train_return`.
        """
        model = self.state["model"]
        assert model.training, "Model must be training!"

        network_input = batch

        with self.fabric.no_backward_sync(model, enabled=is_accumulating):
            # Forward pass
            network_output = model.forward(network_input)
            assert_no_nans(
                network_output["decoder_features"],
                msg="network_output['decoder_features'] "
                + f"for batch_idx: {batch_idx}",
            )

            total_loss, loss_dict = self.loss(
                network_input=batch,
                network_output=network_output,
                loss_input={},
            )

            # Backward pass
            self.fabric.backward(total_loss)

        # Optionally compute training metrics
        train_return = {"total_loss": total_loss, "loss_dict": loss_dict}

        # Store the outputs without gradients for use in logging,
        # callbacks, learning rate schedulers, etc.
        self._current_train_return = apply_to_collection(
            train_return,
            dtype=torch.Tensor,
            function=lambda x: x.detach(),
        )

    def validation_step(
        self,
        batch: Any,
        batch_idx: int,
        compute_metrics: bool = True,
    ) -> dict:
        """
        Validation step, running forward pass and computing validation
        metrics.

        Args:
            batch (Any): The current batch; can be of any form.
            batch_idx (int): The index of the current batch.
            compute_metrics (bool): Whether to compute metrics. If False, we
                will not compute metrics, and the output will be None. Set to
                False during the inference pipeline, where we need the network
                output but cannot compute metrics (since we do not have the
                ground truth).

        Returns:
            dict: Output dictionary containing the validation metrics and
                network output.
        """
        model = self.state["model"]
        assert not model.training, "Model must be in evaluation mode during validation!"

        network_input = batch

        # Forward pass
        network_output = model.forward(network_input)

        assert_no_nans(
            network_output["decoder_features"],
            msg="network_output['decoder_features'] " + f"for batch_idx: {batch_idx}",
        )

        metrics_output = {}
        if compute_metrics:
            # Compute all metrics using MetricManager
            metrics_output = self.metrics(
                network_input=batch,
                network_output=network_output,
                extra_info={},
            )

        # Avoid gradients in stored values to prevent memory leaks
        if metrics_output:
            metrics_output = apply_to_collection(
                metrics_output, torch.Tensor, lambda x: x.detach()
            )

        network_output = apply_to_collection(
            network_output, torch.Tensor, lambda x: x.detach()
        )

        validation_return = {
            "metrics_output": metrics_output,
            "network_output": network_output,
        }

        return validation_return
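
For orientation, a minimal sketch of driving this trainer by hand (illustrative only; the `FabricTrainer` base class in foundry/trainers/fabric.py defines the actual constructor arguments and fit loop, which this diff does not show, and `batch` stands in for a prepared MPNN feature batch):

    trainer = MPNNTrainer(model_type="ligand_mpnn")  # loss=None -> LabelSmoothedNLLLoss defaults
    trainer.construct_model()                        # builds LigandMPNN under fabric.init_module()
    trainer.state["model"].eval()                    # validation_step asserts eval mode
    out = trainer.validation_step(batch, batch_idx=0, compute_metrics=True)
    # out["metrics_output"] is expected to hold NLL, sequence recovery, and
    # (for ligand_mpnn) the interface variants; out["network_output"] holds
    # the detached network outputs.
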
mpnn/transforms/feature_aggregation/mpnn.py
ADDED
@@ -0,0 +1,184 @@
from typing import Any

import numpy as np
from atomworks.common import KeyToIntMapper
from atomworks.ml.transforms._checks import (
    check_atom_array_annotation,
)
from atomworks.ml.transforms.base import Transform
from atomworks.ml.transforms.encoding import atom_array_to_encoding
from atomworks.ml.utils.token import get_token_starts
from mpnn.transforms.feature_aggregation.token_encodings import MPNN_TOKEN_ENCODING


class EncodeMPNNNonAtomizedTokens(Transform):
    """Encode non-atomized tokens for MPNN with X, X_m, and S features.

    Creates:
        - X: (L, 37, 3) coordinates for non-atomized tokens
        - X_m: (L, 37) mask for atom existence and occupancy > occupancy_threshold
        - S: (L) sequence encoding

    Args:
        occupancy_threshold (float): Minimum occupancy to consider atom as present. Defaults to 0.5.
    """

    def __init__(self, occupancy_threshold: float = 0.5):
        self.occupancy_threshold = occupancy_threshold
        self.encoding = MPNN_TOKEN_ENCODING

    def check_input(self, data: dict[str, Any]) -> None:
        check_atom_array_annotation(data, ["atomize", "res_name", "occupancy"])

    def forward(self, data: dict[str, Any]) -> dict[str, Any]:
        atom_array = data["atom_array"]

        # Check that atom_array is not empty.
        assert len(atom_array) > 0, "atom_array cannot be empty"

        # Get non-atomized tokens only
        non_atomized_mask = ~atom_array.atomize
        non_atomized_array = atom_array[non_atomized_mask]

        assert len(non_atomized_array) > 0, "No non-atomized atoms found"

        if len(non_atomized_array) == 0:
            # No non-atomized tokens, create empty arrays
            data["input_features"].update(
                {
                    "X": np.zeros((0, 37, 3), dtype=np.float32),
                    "X_m": np.zeros((0, 37), dtype=np.bool_),
                    "S": np.zeros((0,), dtype=np.int64),
                }
            )
            return data

        # Encode using the MPNN token encoding
        encoded = atom_array_to_encoding(
            non_atomized_array,
            encoding=self.encoding,
            default_coord=0.0,  # Use 0.0 instead of NaN for MPNN
            occupancy_threshold=self.occupancy_threshold,
        )

        # Create X: coordinates (L, 37, 3)
        X = encoded["xyz"].astype(np.float32)

        # Create X_m: mask for existence and occupancy > threshold (L, 37)
        # encoded["mask"] already considers occupancy, we just need to check if atoms exist in encoding
        X_m = encoded["mask"].astype(np.bool_)

        # Create S: sequence encoding (L,)
        S = encoded["seq"].astype(np.int64)

        data["input_features"].update(
            {
                "X": X,
                "X_m": X_m,
                "S": S,
            }
        )

        # Check that we have at least one non-atomized token.
        L = X.shape[0]
        assert L > 0, "At least one non-atomized token should be present"

        return data


class FeaturizeNonAtomizedTokens(Transform):
    """Add additional features for non-atomized tokens: R_idx, chain_labels, residue_mask."""

    def check_input(self, data: dict[str, Any]) -> None:
        check_atom_array_annotation(
            data, ["atomize", "within_chain_res_idx", "chain_iid"]
        )

    def forward(self, data: dict[str, Any]) -> dict[str, Any]:
        atom_array = data["atom_array"]

        # Subset to non-atomized.
        non_atomized_array = atom_array[~atom_array.atomize]

        # Get token starts for non-atomized tokens
        non_atomized_token_starts = get_token_starts(non_atomized_array)
        non_atomized_token_level = non_atomized_array[non_atomized_token_starts]

        if len(non_atomized_token_level) == 0:
            # No non-atomized tokens
            data["input_features"].update(
                {
                    "R_idx": np.zeros((0,), dtype=np.int32),
                    "chain_labels": np.zeros((0,), dtype=np.int64),
                    "residue_mask": np.zeros((0,), dtype=np.bool_),
                }
            )
            return data

        # R_idx: residue indices within chains (0-indexed)
        R_idx = non_atomized_token_level.within_chain_res_idx.astype(np.int32)

        # chain_labels: convert chain_iid to unique integers
        chain_mapper = KeyToIntMapper()
        chain_labels = np.array(
            [
                chain_mapper(chain_iid)
                for chain_iid in non_atomized_token_level.chain_iid
            ],
            dtype=np.int64,
        )

        # residue_mask: all 1's for non-atomized tokens
        residue_mask = np.ones(len(non_atomized_token_level), dtype=np.bool_)

        data["input_features"].update(
            {
                "R_idx": R_idx,
                "chain_labels": chain_labels,
                "residue_mask": residue_mask,
            }
        )
        return data


class FeaturizeAtomizedTokens(Transform):
    """Add features for atomized tokens: Y, Y_t, Y_m."""

    def check_input(self, data: dict[str, Any]) -> None:
        check_atom_array_annotation(data, ["atomize", "atomic_number"])

    def forward(self, data: dict[str, Any]) -> dict[str, Any]:
        atom_array = data["atom_array"]

        # Get atomized tokens only
        atomized_mask = atom_array.atomize
        atomized_array = atom_array[atomized_mask]

        if len(atomized_array) == 0:
            # No atomized tokens
            data["input_features"].update(
                {
                    "Y": np.zeros((0, 3), dtype=np.float32),
                    "Y_t": np.zeros((0,), dtype=np.int32),
                    "Y_m": np.zeros((0,), dtype=np.bool_),
                }
            )
            return data

        # Y: coordinates of atomized tokens (n_atomized, 3)
        Y = atomized_array.coord.astype(np.float32)

        # Y_t: atomic numbers of atomized tokens (n_atomized,)
        Y_t = atomized_array.atomic_number.astype(np.int32)

        # Y_m: mask for atomized tokens (all 1's since they exist) (n_atomized,)
        Y_m = np.ones(len(atomized_array), dtype=np.bool_)

        data["input_features"].update(
            {
                "Y": Y,
                "Y_t": Y_t,
                "Y_m": Y_m,
            }
        )
        return data
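
Taken together, the three transforms populate the coordinate, sequence, and per-residue features that the MPNN featurizer consumes. A minimal sketch of applying them in sequence (illustrative; the actual composition lives in mpnn/pipelines/mpnn.py, and `atom_array` is assumed to carry the annotations each `check_input` requires):

    data = {"atom_array": atom_array, "input_features": {}}
    for transform in (
        EncodeMPNNNonAtomizedTokens(occupancy_threshold=0.5),
        FeaturizeNonAtomizedTokens(),
        FeaturizeAtomizedTokens(),
    ):
        transform.check_input(data)
        data = transform.forward(data)
    # data["input_features"] now holds X, X_m, S, R_idx, chain_labels,
    # residue_mask, Y, Y_t, and Y_m.
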
mpnn/transforms/feature_aggregation/polymer_ligand_interface.py
ADDED
@@ -0,0 +1,76 @@
"""
Feature aggregation for polymer-ligand interface masks.

This module provides transforms to compute interface masks for polymer residues
that are at the interface with ligand molecules.
"""

from typing import Any

import numpy as np
from atomworks.ml.transforms._checks import check_atom_array_annotation
from atomworks.ml.transforms.base import Transform
from atomworks.ml.utils.token import get_token_starts


class FeaturizePolymerLigandInterfaceMask(Transform):
    """
    Compute a polymer mask indicating which residues are at the polymer-ligand
    interface.

    This transform processes an atom array to identify polymer residues that
    have any atoms within the specified distance threshold of ligand atoms.
    It expects that the atom array already has the
    'at_polymer_ligand_interface' annotation computed by the
    ComputePolymerLigandInterface transform.
    """

    def check_input(self, data: dict[str, Any]) -> None:
        """Check that required annotations are present."""
        check_atom_array_annotation(
            {"atom_array": data["atom_array"]},
            required=["element", "atomize", "at_polymer_ligand_interface"],
        )

    def forward(self, data: dict[str, Any]) -> dict[str, Any]:
        """Compute polymer-ligand interface mask and add to input_features."""
        atom_array = data["atom_array"]

        # Get interface annotation that should already be computed
        interface_atoms = atom_array.at_polymer_ligand_interface

        # Get token starts to map atoms to residues
        token_starts = get_token_starts(atom_array)

        # Create residue-level interface mask for all tokens
        all_residue_interface_mask = np.zeros(len(token_starts), dtype=bool)

        # For each token (residue), check if any of its atoms are at the
        # interface.
        for i, start_idx in enumerate(token_starts):
            if i < len(token_starts) - 1:
                end_idx = token_starts[i + 1]
            else:
                end_idx = len(atom_array)

            # Check if any atom in this residue is at the interface
            residue_atoms = interface_atoms[start_idx:end_idx]
            all_residue_interface_mask[i] = np.any(residue_atoms)

        # Get token-level atomize annotation
        token_level_array = atom_array[token_starts]
        non_atomized_mask = ~token_level_array.atomize

        # Get interface mask for non-atomized residues only
        polymer_interface_mask = all_residue_interface_mask[non_atomized_mask]

        # Initialize input_features if it doesn't exist.
        if "input_features" not in data:
            data["input_features"] = {}

        # Add the interface mask to input_features
        data["input_features"]["polymer_ligand_interface_mask"] = (
            polymer_interface_mask.astype(np.bool_)
        )

        return data
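
The per-residue loop in `forward` scans each [start, next_start) atom segment; the same reduction can be written without an explicit Python loop. A sketch of an equivalent vectorized form (illustrative only, assuming `token_starts` is a non-empty, strictly increasing array of segment start indices):

    import numpy as np

    def residue_interface_mask(interface_atoms: np.ndarray, token_starts: np.ndarray) -> np.ndarray:
        """True for each token whose atom segment contains an interface atom."""
        # np.add.reduceat sums each [token_starts[i], token_starts[i + 1])
        # segment (the last segment runs to the end of the array); a nonzero
        # sum means the residue has at least one flagged atom.
        return np.add.reduceat(interface_atoms.astype(np.int64), token_starts) > 0
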
mpnn/transforms/feature_aggregation/token_encodings.py
ADDED
@@ -0,0 +1,132 @@
from atomworks.constants import AA_LIKE_CHEM_TYPES, STANDARD_AA, UNKNOWN_AA
from atomworks.ml.encoding_definitions import TokenEncoding

# Token ordering for MPNN.
token_order = STANDARD_AA + (UNKNOWN_AA,)

# Token ordering for old versions of MPNN.
legacy_token_order = (
    "ALA",
    "CYS",
    "ASP",
    "GLU",
    "PHE",
    "GLY",
    "HIS",
    "ILE",
    "LYS",
    "LEU",
    "MET",
    "ASN",
    "PRO",
    "GLN",
    "ARG",
    "SER",
    "THR",
    "VAL",
    "TRP",
    "TYR",
    "UNK",
)

# Atom ordering for new versions of MPNN.
atom_order = (
    "N",
    "CA",
    "C",
    "O",
    "CB",
    "CG",
    "CG1",
    "CG2",
    "OG",
    "OG1",
    "SG",
    "CD",
    "CD1",
    "CD2",
    "ND1",
    "ND2",
    "OD1",
    "OD2",
    "SD",
    "CE",
    "CE1",
    "CE2",
    "CE3",
    "NE",
    "NE1",
    "NE2",
    "OE1",
    "OE2",
    "CH2",
    "NH1",
    "NH2",
    "OH",
    "CZ",
    "CZ2",
    "CZ3",
    "NZ",
    "OXT",
)

# Atom ordering for old versions of MPNN.
legacy_atom_order = (
    "N",
    "CA",
    "C",
    "CB",
    "O",
    "CG",
    "CG1",
    "CG2",
    "OG",
    "OG1",
    "SG",
    "CD",
    "CD1",
    "CD2",
    "ND1",
    "ND2",
    "OD1",
    "OD2",
    "SD",
    "CE",
    "CE1",
    "CE2",
    "CE3",
    "NE",
    "NE1",
    "NE2",
    "OE1",
    "OE2",
    "CH2",
    "NH1",
    "NH2",
    "OH",
    "CZ",
    "CZ2",
    "CZ3",
    "NZ",
    "OXT",
)

# Token encoding for MPNN.
MPNN_TOKEN_ENCODING = TokenEncoding(
    token_atoms={token: atom_order for token in token_order},
    chemcomp_type_to_unknown={chem_type: "UNK" for chem_type in AA_LIKE_CHEM_TYPES},
)

# Token encoding for versions of MPNN using the legacy token order and
# new atom order.
MPNN_LEGACY_TOKEN_ENCODING = TokenEncoding(
    token_atoms={token: atom_order for token in legacy_token_order},
    chemcomp_type_to_unknown={chem_type: "UNK" for chem_type in AA_LIKE_CHEM_TYPES},
)

# Token encoding for versions of MPNN using the legacy token order and
# legacy atom order.
MPNN_LEGACY_TOKEN_LEGACY_ATOM_ENCODING = TokenEncoding(
    token_atoms={token: legacy_atom_order for token in legacy_token_order},
    chemcomp_type_to_unknown={chem_type: "UNK" for chem_type in AA_LIKE_CHEM_TYPES},
)
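
The three encodings differ only in token and atom ordering, which matters when loading legacy checkpoints: class indices produced under one ordering must be remapped before they can be compared with the other. A sketch of such a remap (hypothetical helper, not part of the package; assumes UNKNOWN_AA == "UNK" and that STANDARD_AA contains the twenty standard three-letter residue codes, so every legacy token appears in token_order):

    def legacy_to_current_index() -> dict[int, int]:
        """Map each legacy class index to its index in token_order."""
        return {
            legacy_idx: token_order.index(token)
            for legacy_idx, token in enumerate(legacy_token_order)
        }
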