rc_foundry-0.1.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
rf3/metrics/metric_utils.py
ADDED

@@ -0,0 +1,192 @@

from itertools import combinations
from typing import Union

import numpy as np
import torch
from jaxtyping import Bool, Float
from numpy.typing import NDArray


def find_bin_midpoints(
    max_distance: float, num_bins: int, device: Union[str, torch.device] = "cpu"
) -> Float[torch.Tensor, "num_bins"]:
    """
    Find the bin midpoints for a given binning scheme. Used to take the expectation of values when converting
    binned predictions to unbinned predictions. Assumes the minimum of the scheme is 0.
    Args:
        max_distance: float, maximum distance
        num_bins: int, number of bins
        device: device to run on
    Returns:
        pae_midpoints: [num_bins], bin midpoints
    """
    bin_size = max_distance / num_bins
    bins = torch.linspace(
        bin_size, max_distance - bin_size, num_bins - 1, device=device
    )
    midpoints = (bins[1:] + bins[:-1]) / 2
    midpoints = torch.cat(
        [(bins[0] - bin_size / 2)[None], midpoints, bins[-1:] + bin_size / 2]
    )

    return midpoints


def unbin_logits(
    logits: Float[torch.Tensor, "B num_bins L X"], max_distance: float, num_bins: int
) -> Float[torch.Tensor, "B L X"]:
    """
    Unbin the logits to get the expected-value matrix.
    Args:
        logits: [B, num_bins, L, X], binned logits where X is 23 for plddt and L for pae and pde
        max_distance: float, maximum distance
        num_bins: int, number of bins
    Returns:
        unbinned: [B, L, X], unbinned matrix
    """
    midpoints = find_bin_midpoints(max_distance, num_bins, device=logits.device)
    probabilities = torch.nn.Softmax(dim=1)(logits).detach().float()
    unbinned = (probabilities * midpoints[None, :, None, None]).sum(dim=1)
    return unbinned


def create_chainwise_masks_1d(
    ch_label: NDArray[np.str_], device: Union[str, torch.device] = "cpu"
) -> dict[str, Bool[torch.Tensor, "L"]]:
    """
    Create 1D chainwise masks for a set of chain labels.
    Args:
        ch_label: np.ndarray [L], chain labels
        device: torch.device, device to run on
    Returns:
        ch_masks: dict mapping each chain letter to the elements to score for that chain
    """
    unique_chains = np.unique(ch_label)
    ch_masks = {}
    for chain in unique_chains:
        indices = torch.from_numpy(ch_label == chain).to(
            dtype=torch.bool, device=device
        )
        ch_masks[chain] = indices
    return ch_masks


def create_chainwise_masks_2d(
    ch_label: NDArray[np.str_], device: Union[str, torch.device] = "cpu"
) -> dict[str, Bool[torch.Tensor, "L L"]]:
    """
    Create 2D chainwise masks for a set of chain labels.
    Args:
        ch_label: np.ndarray [L], chain labels
        device: torch.device, device to run on
    Returns:
        ch_masks: dict mapping each chain letter to the pairs to score for that chain
    """
    unique_chains = np.unique(ch_label)
    ch_masks = {}
    for chain in unique_chains:
        indices = torch.from_numpy(ch_label == chain)
        mask = torch.outer(indices, indices).to(dtype=torch.bool, device=device)
        ch_masks[chain] = mask
    return ch_masks


def create_interface_masks_2d(
    ch_label: NDArray[np.str_], device: Union[str, torch.device] = "cpu"
) -> dict[tuple[str, str], Bool[torch.Tensor, "L L"]]:
    """
    Create interface masks for a set of chain labels.
    Args:
        ch_label: np.ndarray [L], chain labels
        device: torch.device, device to run on
    Returns:
        pairs_to_score: dict mapping chain pairs to boolean masks
    """
    unique_chains = np.unique(ch_label)
    pairs_to_score = {}
    for chain_i, chain_j in combinations(unique_chains, 2):
        chain_i_indices = torch.from_numpy(ch_label == chain_i)
        chain_j_indices = torch.from_numpy(ch_label == chain_j)
        to_be_scored = torch.outer(chain_i_indices, chain_j_indices).to(
            dtype=torch.bool, device=device
        ) + torch.outer(chain_j_indices, chain_i_indices).to(
            dtype=torch.bool, device=device
        )
        pairs_to_score[(chain_i, chain_j)] = to_be_scored
    return pairs_to_score


def compute_mean_over_subsampled_pairs(
    matrix_to_mean: Float[torch.Tensor, "B L M"],
    pairs_to_score: Bool[torch.Tensor, "L M"],
    eps: float = 1e-6,
) -> Float[torch.Tensor, "B"]:
    """
    Compute the mean over a subsample of pairs in a 2D matrix. Returns a tensor with an element for each batch.
    Args:
        matrix_to_mean: tensor of shape (batch, L, M)
        pairs_to_score: 2D tensor of shape (L, M) with 1s where pairs should be scored and 0s elsewhere
        eps: small epsilon value to avoid division by zero
    Returns:
        1D tensor of shape (batch,) with the mean over the subsampled pairs for each batch
    """
    B, L, M = matrix_to_mean.shape
    assert matrix_to_mean.shape == (
        B,
        L,
        M,
    ), "Matrix to mean should be of shape (batch, L, M)"
    assert pairs_to_score.shape == (L, M), "Pairs to score should be of shape (L, M)"
    batch = (matrix_to_mean * pairs_to_score).sum(dim=(-1, -2)) / (
        pairs_to_score.sum() + eps
    )
    assert batch.shape == (B,), "Batch should be of shape (batch,)"
    return batch


def compute_min_over_subsampled_pairs(
    matrix_to_min: Float[torch.Tensor, "B L M"],
    pairs_to_score: Bool[torch.Tensor, "L M"],
) -> Float[torch.Tensor, "B"]:
    """
    Compute the min over a subsample of pairs in a 2D matrix. Returns a tensor with an element for each batch.
    Args:
        matrix_to_min: tensor of shape (batch, L, M)
        pairs_to_score: 2D tensor of shape (L, M) with 1s where pairs should be scored and 0s elsewhere
    Returns:
        1D tensor of shape (batch,) with the min over the subsampled pairs for each batch
    """
    B, L, M = matrix_to_min.shape
    assert matrix_to_min.shape == (
        B,
        L,
        M,
    ), "Matrix to min should be of shape (batch, L, M)"
    assert pairs_to_score.shape == (L, M), "Pairs to score should be of shape (L, M)"
    # Use torch.where to efficiently mask without cloning the entire matrix
    # This broadcasts pairs_to_score across the batch dimension
    masked_matrix = torch.where(
        pairs_to_score.bool(),  # condition (L, M) -> broadcasts to (B, L, M)
        matrix_to_min,  # if True: use original values (B, L, M)
        torch.tensor(
            float("inf"), device=matrix_to_min.device, dtype=matrix_to_min.dtype
        ),  # if False: use inf
    )

    # Flatten the last two dimensions and compute min across them
    batch = masked_matrix.view(B, -1).min(dim=-1)[0]

    assert batch.shape == (B,), "Batch should be of shape (batch,)"
    return batch


def spread_batch_into_dictionary(batch: Float[torch.Tensor, "B"]) -> dict[int, float]:
    """
    Given a batch of data, create a dictionary with keys as the batch index and values as the corresponding data.
    Args:
        batch: 1D tensor of shape (B,)
    Returns:
        Dictionary mapping batch indices to float values
    """
    assert len(batch.shape) == 1, f"Batch should be a 1d tensor, {batch}"
    return {i: data.item() for i, data in enumerate(batch)}
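Together these helpers form a small workflow: unbin per-pair logits into expected values, build pair masks from chain labels, and reduce over a masked region. A minimal usage sketch with illustrative shapes and values (the import path matches the one used by rf3/metrics/predicted_error.py below):

import numpy as np
import torch

from rf3.metrics.metric_utils import (
    compute_mean_over_subsampled_pairs,
    create_interface_masks_2d,
    spread_batch_into_dictionary,
    unbin_logits,
)

# Illustrative inputs: batch of 1, 64 bins up to 32 Å, two 3-residue chains A and B
logits = torch.randn(1, 64, 6, 6)  # (B, num_bins, L, L) binned PAE-style logits
expected = unbin_logits(logits, max_distance=32.0, num_bins=64)  # (1, 6, 6)

ch_label = np.array(["A", "A", "A", "B", "B", "B"])
interface_masks = create_interface_masks_2d(ch_label)  # {("A", "B"): (6, 6) bool mask}

# Mean expected error restricted to the A-B interface, one value per batch element
mean_ab = compute_mean_over_subsampled_pairs(expected, interface_masks[("A", "B")])
print(spread_batch_into_dictionary(mean_ab))  # e.g. {0: 15.9}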
rf3/metrics/predicted_error.py
ADDED

@@ -0,0 +1,134 @@

from typing import Any

import torch
from rf3.metrics.metric_utils import find_bin_midpoints

from foundry.metrics.metric import Metric


def compute_ptm(
    pae: torch.Tensor,
    to_calculate: torch.Tensor | None,
    max_distance: float = 32,
    bin_count: int = 64,
):
    """Compute the predicted TM-score (PTM) from the predicted aligned error (PAE).

    Args:
        pae: Predicted aligned error logits of shape [D, I, I, bin_count].
        to_calculate: Boolean [I, I] tensor indicating which residue pairs to score; all pairs are scored if None.
        max_distance: Maximum distance of the binning scheme.
        bin_count: Number of PAE bins.

    Returns:
        ptm: Computed predicted TM-score of shape [D].
    """
    D, I = pae.shape[0], pae.shape[1]
    if to_calculate is None:
        to_calculate = torch.ones((I, I), dtype=torch.bool, device=pae.device)

    bin_centers = find_bin_midpoints(
        max_distance, bin_count, device=pae.device
    )  # TODO: get this from config
    pae = torch.nn.Softmax(dim=-1)(pae).detach().float()
    normalization_factor = 1.24 * (max(I, 19) - 15.0) ** (1 / 3) - 1.8
    denominator = 1 / (1 + (bin_centers / normalization_factor) ** 2)
    pae = pae * denominator[None, None, None, :]  # Broadcast to match pae shape

    pae = pae.sum(dim=-1)  # Sum over the bin dimension
    pae = (pae * to_calculate[None]).sum(dim=-1) / (to_calculate.sum(dim=-1) + 1e-6)
    ptm = pae.max(dim=-1).values
    assert ptm.shape == (D,)
    return ptm


class ComputePTM(Metric):
    @property
    def kwargs_to_compute_args(self) -> dict[str, Any]:
        return {
            "pae": ("network_output", "pae"),
            "asym_id": ("network_input", "f", "asym_id"),
        }

    def compute(
        self,
        pae: torch.Tensor,
        asym_id: torch.Tensor,
    ) -> dict[str, float]:
        """Compute the predicted TM-score (PTM) from the predicted aligned error (PAE).
        Args:
            pae: Predicted aligned error tensor.
            asym_id: Per-residue chain identifiers (unused here; all pairs are scored).
        Returns:
            ptm: Computed predicted TM-scores, one entry per batch element.
        """
        ptm = compute_ptm(pae, None)
        # split the batch dimension into separate keys in the output dictionary
        ptm = ptm.cpu().numpy()
        ptm = {f"ptm_{i}": ptm[i] for i in range(len(ptm))}
        return ptm


class ComputeIPTM(Metric):
    @property
    def kwargs_to_compute_args(self) -> dict[str, Any]:
        return {
            "pae": ("network_output", "pae"),
            "asym_id": ("network_input", "f", "asym_id"),
            "is_ligand": ("network_input", "f", "is_ligand"),
        }

    def compute(
        self,
        pae: torch.Tensor,
        asym_id: torch.Tensor,
        is_ligand: torch.Tensor,
    ) -> dict[str, float]:
        """Compute the predicted interface TM-score (iPTM) from the predicted aligned error (PAE).
        Args:
            pae: Predicted aligned error tensor.
            asym_id: Per-residue chain identifiers.
            is_ligand: Per-residue flag that is 1 for ligand residues and 0 for polymer residues.
        Returns:
            iptm: Computed interface TM-scores, overall and per interface type.
        """
        to_calculate = asym_id[None, :] != asym_id[:, None]
        iptm = compute_ptm(pae, to_calculate)

        # make protein and ligand masks
        protein_mask = is_ligand == 0
        ligand_mask = is_ligand == 1
        # build masks for protein-protein, protein-ligand, and ligand-ligand interfaces
        protein_protein_mask = (
            protein_mask[None, :] & protein_mask[:, None]
        ) * to_calculate
        protein_ligand_mask = (
            (protein_mask[None, :] & ligand_mask[:, None])
            | (ligand_mask[None, :] & protein_mask[:, None])
        ) * to_calculate
        ligand_ligand_mask = (
            ligand_mask[None, :] & ligand_mask[:, None]
        ) * to_calculate
        # calculate iptm for each interface type
        iptm_protein_protein = compute_ptm(pae, protein_protein_mask)
        iptm_protein_ligand = compute_ptm(pae, protein_ligand_mask)
        iptm_ligand_ligand = compute_ptm(pae, ligand_ligand_mask)

        # split the batch dimension into separate keys in the output dictionary
        iptm = iptm.cpu().numpy()
        iptm = {f"iptm_{i}": iptm[i] for i in range(len(iptm))}
        iptm_protein_protein = iptm_protein_protein.cpu().numpy()
        iptm_protein_protein = {
            f"iptm_protein_protein_{i}": iptm_protein_protein[i]
            for i in range(len(iptm_protein_protein))
        }
        iptm_protein_ligand = iptm_protein_ligand.cpu().numpy()
        iptm_protein_ligand = {
            f"iptm_protein_ligand_{i}": iptm_protein_ligand[i]
            for i in range(len(iptm_protein_ligand))
        }
        iptm_ligand_ligand = iptm_ligand_ligand.cpu().numpy()
        iptm_ligand_ligand = {
            f"iptm_ligand_ligand_{i}": iptm_ligand_ligand[i]
            for i in range(len(iptm_ligand_ligand))
        }
        iptm.update(iptm_protein_protein)
        iptm.update(iptm_protein_ligand)
        iptm.update(iptm_ligand_ligand)
        return iptm
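A hedged sanity-check sketch for compute_ptm on random logits; the (D, I, I, bin_count) layout and the inter-chain mask construction are inferred from the shape handling above rather than documented by the package:

import torch

from rf3.metrics.predicted_error import compute_ptm

D, I, bin_count = 2, 10, 64
pae_logits = torch.randn(D, I, I, bin_count)  # per-pair binned PAE logits

ptm = compute_ptm(pae_logits, to_calculate=None)  # overall pTM, shape (D,)

# Interface pTM: score only pairs of residues from different chains
asym_id = torch.tensor([0] * 5 + [1] * 5)
inter_chain = asym_id[None, :] != asym_id[:, None]  # (I, I) bool mask
iptm = compute_ptm(pae_logits, inter_chain)
print(ptm.shape, iptm.shape)  # torch.Size([2]) torch.Size([2])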
rf3/metrics/rasa.py
ADDED
@@ -0,0 +1,108 @@

import numpy as np
from atomworks.ml.transforms.sasa import calculate_atomwise_rasa
from beartype.typing import Any
from biotite.structure import AtomArrayStack

from foundry.metrics.metric import Metric


class UnresolvedRegionRASA(Metric):
    """
    This metric computes the RASA score for unresolved regions in a protein structure.
    The RASA score is defined as the ratio of the solvent-accessible surface area (SASA)
    of a residue in a protein structure to the SASA of the same residue in an extended conformation.
    """

    def __init__(
        self,
        probe_radius: float = 1.4,
        atom_radii: str | np.ndarray = "ProtOr",
        point_number: int = 100,
        include_resolved: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.probe_radius = probe_radius
        self.atom_radii = atom_radii
        self.point_number = point_number
        self.include_resolved = include_resolved

    @property
    def kwargs_to_compute_args(self) -> dict[str, Any]:
        return {
            "predicted_atom_array_stack": ("predicted_atom_array_stack",),
            "ground_truth_atom_array_stack": ("ground_truth_atom_array_stack",),
        }

    def compute(
        self,
        predicted_atom_array_stack: AtomArrayStack,
        ground_truth_atom_array_stack: AtomArrayStack,
    ) -> dict[str, Any]:
        """Compute the RASA score for unresolved regions in a protein structure.

        Args:
            predicted_atom_array_stack (AtomArrayStack): Stack of predicted protein structures.
            ground_truth_atom_array_stack (AtomArrayStack): Stack of ground truth protein structures.

        Uses the instance settings:
            probe_radius (float): Van-der-Waals radius of the probe in Angstrom; 1.4 corresponds to water.
            atom_radii (str | np.ndarray): Atom radii set to use for the calculation; defaults to "ProtOr".
            point_number (int): Number of points sampled by the Shrake-Rupley algorithm when calculating SASA.
            include_resolved (bool): Whether to also report the RASA score of resolved regions.

        Returns:
            dict: A dictionary containing the RASA scores and other relevant information.
        """

        # find unresolved regions
        # (polymer atoms with occupancy 0.0)
        atoms_to_score_unresolved = ground_truth_atom_array_stack.is_polymer & (
            ground_truth_atom_array_stack.occupancy == 0.0
        )

        # find resolved regions (polymer atoms with occupancy > 0.0)
        atoms_to_score_resolved = ground_truth_atom_array_stack.is_polymer & (
            ground_truth_atom_array_stack.occupancy > 0.0
        )

        unresolved_rasas = []
        resolved_rasas = []

        # Calculate RASA
        for atom_array in predicted_atom_array_stack:
            try:
                rasa = calculate_atomwise_rasa(
                    atom_array=atom_array,
                    probe_radius=self.probe_radius,
                    atom_radii=self.atom_radii,
                    point_number=self.point_number,
                )
                unresolved_rasas.append(rasa[atoms_to_score_unresolved].mean())
                if self.include_resolved:
                    resolved_rasas.append(rasa[atoms_to_score_resolved].mean())
            except KeyError:
                unresolved_rasas.append(np.nan)
                if self.include_resolved:
                    resolved_rasas.append(np.nan)

        # Calculate the mean RASA scores.
        # Pattern-match other metrics by appending "_i" to the metric name to represent multiple batches
        # (e.g., "unresolved_polymer_rasa_0", "unresolved_polymer_rasa_1", etc.)
        unresolved_rasa = np.nanmean(unresolved_rasas)
        output_dictionary = {
            f"unresolved_polymer_rasa_{i}": rasa
            for i, rasa in enumerate(unresolved_rasas)
        }
        output_dictionary["mean_unresolved_polymer_rasa"] = unresolved_rasa

        # ... and add resolved region RASA scores if flag is enabled
        if self.include_resolved:
            resolved_rasa = np.nanmean(resolved_rasas)
            output_dictionary.update(
                {
                    f"resolved_polymer_rasa_{i}": rasa
                    for i, rasa in enumerate(resolved_rasas)
                }
            )
            output_dictionary["mean_resolved_polymer_rasa"] = resolved_rasa

        return output_dictionary
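The unresolved/resolved split above depends only on two per-atom annotations of the ground-truth stack, is_polymer and occupancy. A minimal sketch of just that masking step on a toy biotite structure (the annotation values are made up for illustration):

import numpy as np
from biotite.structure import AtomArray, stack

# Toy ground truth: three polymer atoms (one unresolved) and one ligand atom
gt = AtomArray(4)
gt.coord = np.zeros((4, 3), dtype=np.float32)
gt.set_annotation("is_polymer", np.array([True, True, True, False]))
gt.set_annotation("occupancy", np.array([1.0, 0.0, 1.0, 1.0]))
gt_stack = stack([gt])

# Same masks as UnresolvedRegionRASA.compute builds
unresolved = gt_stack.is_polymer & (gt_stack.occupancy == 0.0)  # [False, True, False, False]
resolved = gt_stack.is_polymer & (gt_stack.occupancy > 0.0)     # [True, False, True, False]
print(unresolved, resolved)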
rf3/metrics/selected_distances.py
ADDED

@@ -0,0 +1,91 @@

import numpy as np
from atomworks.ml.utils import nested_dict
from atomworks.ml.utils.selection import (
    get_mask_from_atom_selection,
    parse_selection_string,
)
from beartype.typing import Any
from biotite.structure import AtomArrayStack

from foundry.metrics.metric import Metric


class SelectedAtomByAtomDistances(Metric):
    """Computes all-by-all 2D distances given a list of selection strings"""

    def compute_from_kwargs(self, **kwargs: Any) -> dict[str, Any]:
        """Override parent class to handle the optional selection_strings parameter"""
        compute_inputs = {
            "atom_array_stack": nested_dict.getitem(
                kwargs, key="predicted_atom_array_stack"
            )
        }

        # Add selection_strings only if it exists
        try:
            compute_inputs["selection_strings"] = nested_dict.getitem(
                kwargs, key=("extra_info", "selection_strings")
            )
        except (KeyError, IndexError, TypeError):
            pass

        return self.compute(**compute_inputs)

    def compute(
        self,
        atom_array_stack: AtomArrayStack,
        selection_strings: list[str] | None = None,
    ) -> dict[str, Any]:
        # Short-circuit if no selection strings are provided
        if not selection_strings:
            return {}

        # ... select the specified atoms
        mask = np.zeros(atom_array_stack.array_length(), dtype=bool)
        atom_selections = [parse_selection_string(s) for s in selection_strings]
        for atom_selection in atom_selections:
            mask |= get_mask_from_atom_selection(atom_array_stack, atom_selection)
        selected_atom_array_stack = atom_array_stack[:, mask]

        # Create views with added dimensions for broadcasting
        # coord is (D, L, 3); we want pairwise distances for each D
        coord_i = selected_atom_array_stack.coord[:, :, np.newaxis, :]  # (D, L, 1, 3)
        coord_j = selected_atom_array_stack.coord[:, np.newaxis, :, :]  # (D, 1, L, 3)

        # Calculate pairwise differences and distances
        differences = coord_i - coord_j  # broadcasts to (D, L, L, 3)
        distances = np.linalg.norm(differences, axis=-1)  # (D, L, L)

        # Compute the mean and standard deviation across the D dimension
        mean_distances = np.mean(distances, axis=0)  # Shape: (L, L)
        std_distances = np.std(distances, axis=0)  # Shape: (L, L)

        # Name the features with the chain_id, res_name, res_id, atom_name
        def _format_atom_id(chain_id, res_name, res_id, atom_name):
            return f"{chain_id}/{res_name}/{res_id}/{atom_name}"

        vectorized_format = np.vectorize(_format_atom_id)
        id = vectorized_format(
            selected_atom_array_stack.chain_id,
            selected_atom_array_stack.res_name,
            selected_atom_array_stack.res_id,
            selected_atom_array_stack.atom_name,
        )

        # Create a 2D numpy array of pair names by concatenating the atom ids ...
        id_i = np.char.add(id, "-")
        id_II = np.char.add(id_i[:, np.newaxis], id[np.newaxis, :])

        # ... and store the results in a dictionary, naming the columns with the concatenated id
        results = {}
        for i in range(len(id)):
            for j in range(
                i + 1, len(id)
            ):  # Only consider j > i to avoid symmetric duplicates
                col_id = id_II[i, j]
                mean = mean_distances[i, j]
                std = std_distances[i, j]
                results[f"{col_id}_mean"] = mean
                results[f"{col_id}_std"] = std

        return results
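The distance computation itself is plain NumPy broadcasting. A self-contained sketch of the same pattern on random coordinates (D models, L selected atoms; names and shapes are illustrative):

import numpy as np

D, L = 5, 4
coord = np.random.rand(D, L, 3)  # (D, L, 3) stacked model coordinates

diff = coord[:, :, np.newaxis, :] - coord[:, np.newaxis, :, :]  # (D, L, L, 3)
dist = np.linalg.norm(diff, axis=-1)  # (D, L, L) pairwise distances per model

mean_d = dist.mean(axis=0)  # (L, L) mean over models
std_d = dist.std(axis=0)    # (L, L) model-to-model spread
iu = np.triu_indices(L, k=1)  # upper triangle, mirroring the i < j loop above
print(mean_d[iu], std_d[iu])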