rc-foundry 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
rf3/loss/loss.py
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
logger = logging.getLogger(__name__)
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def calc_ddihedralmse_dxyz(a, b, c, d, true_dih, eps=1e-6):
|
|
9
|
+
"""
|
|
10
|
+
Calculates the gradient of the dihedral angle with respect to the xyz coordinates using the closed form derivative.
|
|
11
|
+
a, b, c, and d are atoms participating in the chiral center. true_dih is the true dihedral angle.
|
|
12
|
+
|
|
13
|
+
Unlike the original implementation, this does NOT use autograd.
|
|
14
|
+
"""
|
|
15
|
+
# I need to reshape this from n_symm, batch, n, 3 to n_symm * batch * n, 3)
|
|
16
|
+
og_shape = a.shape
|
|
17
|
+
# Expand the dihedral by the batch dimension to match n_atoms*batchs
|
|
18
|
+
true_dih = true_dih.unsqueeze(0).repeat(a.shape[0], 1)
|
|
19
|
+
a = a.view(-1, 3)
|
|
20
|
+
b = b.view(-1, 3)
|
|
21
|
+
c = c.view(-1, 3)
|
|
22
|
+
d = d.view(-1, 3)
|
|
23
|
+
true_dih = true_dih.view(-1)
|
|
24
|
+
|
|
25
|
+
batch_size = a.shape[0] # Support for batch size
|
|
26
|
+
I = (
|
|
27
|
+
torch.eye(3).unsqueeze(0).repeat(batch_size, 1, 1).to(a.device)
|
|
28
|
+
) # Make batch-aware identity matrix
|
|
29
|
+
|
|
30
|
+
# Compute b0, b1, b2
|
|
31
|
+
b0 = a - b
|
|
32
|
+
b1 = c - b
|
|
33
|
+
b2 = d - c
|
|
34
|
+
|
|
35
|
+
# Normalize b1
|
|
36
|
+
b1_norm = torch.norm(b1, dim=-1, keepdim=True)
|
|
37
|
+
b1n = b1 / (b1_norm + eps)
|
|
38
|
+
|
|
39
|
+
# Compute orthogonal components v and w
|
|
40
|
+
v = b0 - torch.sum(b0 * b1n, dim=-1, keepdim=True) * b1n
|
|
41
|
+
w = b2 - torch.sum(b2 * b1n, dim=-1, keepdim=True) * b1n
|
|
42
|
+
|
|
43
|
+
# Dihedral components x and y
|
|
44
|
+
x = torch.sum(v * w, dim=-1)
|
|
45
|
+
y = torch.sum(torch.cross(b1n, v, dim=-1) * w, dim=-1)
|
|
46
|
+
|
|
47
|
+
# Dihedral angle
|
|
48
|
+
dih = torch.atan2(y + eps, x + eps)
|
|
49
|
+
|
|
50
|
+
# Compute MSE loss and manual gradients
|
|
51
|
+
# mse_loss = torch.mean(torch.square(dih - true_dih))
|
|
52
|
+
# mse_loss = torch.sum(torch.square(dih - true_dih))
|
|
53
|
+
|
|
54
|
+
# Define matrices and gradients, adapted for batch
|
|
55
|
+
db0_db = -I
|
|
56
|
+
db1_db = -I
|
|
57
|
+
db1_dc = I
|
|
58
|
+
db2_dc = -I
|
|
59
|
+
db0_da = I
|
|
60
|
+
db2_dd = I
|
|
61
|
+
# dmse_ddih = 2 * (dih - true_dih) / batch_size
|
|
62
|
+
dmse_ddih = 2 * (dih - true_dih)
|
|
63
|
+
ddih_dx = -y / (x**2 + y**2 + eps)
|
|
64
|
+
ddih_dy = x / (x**2 + y**2 + eps)
|
|
65
|
+
dy_dv = -torch.cross(b1n, w, dim=-1)
|
|
66
|
+
dy_dw = torch.cross(b1n, v, dim=-1)
|
|
67
|
+
dx_dv = w
|
|
68
|
+
dx_dw = v
|
|
69
|
+
|
|
70
|
+
dw_db1n = -torch.sum(b2 * b1n, dim=-1, keepdim=True).unsqueeze(-1) * I - torch.bmm(
|
|
71
|
+
b2.unsqueeze(-1), b1n.unsqueeze(1)
|
|
72
|
+
)
|
|
73
|
+
|
|
74
|
+
db1n_db1 = (b1_norm + eps).unsqueeze(-1) * I / (b1_norm**2 + eps).unsqueeze(
|
|
75
|
+
-1
|
|
76
|
+
) - torch.bmm(b1.unsqueeze(-1), b1.unsqueeze(1)) / (b1_norm**2 + eps).unsqueeze(-1)
|
|
77
|
+
|
|
78
|
+
dv_db1n = -torch.sum(b0 * b1n, dim=-1, keepdim=True).unsqueeze(-1) * I - torch.bmm(
|
|
79
|
+
b0.unsqueeze(-1), b1n.unsqueeze(1)
|
|
80
|
+
)
|
|
81
|
+
dv_db0 = I - torch.bmm(b1n.unsqueeze(-1), b1n.unsqueeze(1))
|
|
82
|
+
dw_db2 = I - torch.bmm(b1n.unsqueeze(-1), b1n.unsqueeze(1))
|
|
83
|
+
|
|
84
|
+
# Adjust sizes now for efficiency
|
|
85
|
+
ddih_dx = ddih_dx.view(-1, 1, 1)
|
|
86
|
+
ddih_dy = ddih_dy.view(-1, 1, 1)
|
|
87
|
+
dmse_ddih = dmse_ddih.view(-1, 1, 1)
|
|
88
|
+
dx_dv = dx_dv.unsqueeze(1)
|
|
89
|
+
dx_dw = dx_dw.unsqueeze(1)
|
|
90
|
+
dy_dv = dy_dv.unsqueeze(1)
|
|
91
|
+
dy_dw = dy_dw.unsqueeze(1)
|
|
92
|
+
|
|
93
|
+
# Gradient computations
|
|
94
|
+
# wrt a
|
|
95
|
+
dv_da = torch.matmul(dv_db0, db0_da)
|
|
96
|
+
ddih_da = torch.bmm((ddih_dx * dx_dv), dv_da) + torch.bmm((ddih_dy * dy_dv), dv_da)
|
|
97
|
+
dmse_da = torch.bmm(dmse_ddih, ddih_da)
|
|
98
|
+
|
|
99
|
+
# wrt b
|
|
100
|
+
db1n_db = torch.matmul(db1n_db1, db1_db)
|
|
101
|
+
dv_db = torch.matmul(dv_db0, db0_db) + torch.matmul(
|
|
102
|
+
dv_db1n.transpose(-1, -2), db1n_db
|
|
103
|
+
)
|
|
104
|
+
dw_db = torch.matmul(dw_db1n.transpose(-1, -2), db1n_db)
|
|
105
|
+
dx_db = torch.bmm(dx_dv, dv_db) + torch.bmm(dx_dw, dw_db)
|
|
106
|
+
dy_db = torch.bmm(dy_dv, dv_db) + torch.bmm(dy_dw, dw_db)
|
|
107
|
+
ddih_db = torch.bmm(ddih_dx, dx_db) + torch.bmm(ddih_dy, dy_db)
|
|
108
|
+
dmse_db = torch.bmm(dmse_ddih, ddih_db)
|
|
109
|
+
|
|
110
|
+
# wrt c
|
|
111
|
+
db1n_dc = torch.matmul(db1n_db1, db1_dc)
|
|
112
|
+
dv_dc = torch.matmul(dv_db1n.transpose(-1, -2), db1n_dc)
|
|
113
|
+
dw_dc = torch.matmul(dw_db2, db2_dc) + torch.matmul(
|
|
114
|
+
dw_db1n.transpose(-1, -2), db1n_dc
|
|
115
|
+
)
|
|
116
|
+
dx_dc = torch.bmm(dx_dv, dv_dc) + torch.bmm(dx_dw, dw_dc)
|
|
117
|
+
dy_dc = torch.bmm(dy_dv, dv_dc) + torch.bmm(dy_dw, dw_dc)
|
|
118
|
+
ddih_dc = torch.bmm(ddih_dx, dx_dc) + torch.bmm(ddih_dy, dy_dc)
|
|
119
|
+
dmse_dc = torch.bmm(dmse_ddih, ddih_dc)
|
|
120
|
+
|
|
121
|
+
# wrt d
|
|
122
|
+
dw_dd = torch.matmul(dw_db2, db2_dd)
|
|
123
|
+
ddih_dd = torch.bmm((ddih_dx * dx_dw), dw_dd) + torch.bmm((ddih_dy * dy_dw), dw_dd)
|
|
124
|
+
dmse_dd = torch.bmm(dmse_ddih, ddih_dd)
|
|
125
|
+
|
|
126
|
+
# Reshape gradients back to original shape and prep for cat
|
|
127
|
+
dmse_da = dmse_da.view(og_shape).unsqueeze(-2)
|
|
128
|
+
dmse_db = dmse_db.view(og_shape).unsqueeze(-2)
|
|
129
|
+
dmse_dc = dmse_dc.view(og_shape).unsqueeze(-2)
|
|
130
|
+
dmse_dd = dmse_dd.view(og_shape).unsqueeze(-2)
|
|
131
|
+
|
|
132
|
+
grads = torch.cat([dmse_da, dmse_db, dmse_dc, dmse_dd], dim=-2)
|
|
133
|
+
return grads
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def calc_chiral_grads_flat_impl(
    xyz, chiral_centers, chiral_center_dihedral_angles, no_grad_on_chiral_center
):
    """Gradient of the chiral-center dihedral loss w.r.t. atom coordinates,
    using the closed-form derivative.

    Args:
        xyz: torch.Tensor, shape (batch, n_atoms, 3)
        chiral_centers: torch.Tensor (long), shape (n_centers, 4) — per center,
            the indices of the four atoms forming the dihedral
        chiral_center_dihedral_angles: torch.Tensor (float), shape (n_centers, 1)
        no_grad_on_chiral_center: if True, zero the gradient on the first
            (center) atom of each quadruple

    Returns:
        grads: torch.Tensor, shape (batch, n_atoms, 3)
    """
    # Track the gradient of the dihedral loss with respect to the coordinates.
    xyz.requires_grad_(True)

    # No chiral centers: nothing contributes, return all-zero gradients.
    if chiral_centers.shape[0] == 0:
        return torch.zeros(xyz.shape, device=xyz.device)

    # Gather the four atom positions per chiral center: (batch, n_centers, 4, 3)
    atom_quads = xyz[:, chiral_centers, :]

    accumulated = torch.zeros_like(xyz).to(xyz.device)
    # Closed-form per-center gradients: (batch, n_centers, 4, 3)
    per_center = calc_ddihedralmse_dxyz(
        atom_quads[..., 0, :],
        atom_quads[..., 1, :],
        atom_quads[..., 2, :],
        atom_quads[..., 3, :],
        chiral_center_dihedral_angles,
    )

    if no_grad_on_chiral_center:
        # Suppress the gradient on the chiral-center atom itself.
        per_center[:, :, 0] = 0.0

    # Scatter-add the per-center atom gradients back onto the atom axis; atoms
    # shared between centers accumulate their contributions.
    accumulated.index_add_(
        1,
        chiral_centers.flatten(),
        per_center.flatten(start_dim=1, end_dim=2),
    )

    return accumulated
|
rf3/metrics/chiral.py
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from atomworks.io.transforms.atom_array import ensure_atom_array_stack
|
|
3
|
+
from atomworks.ml.transforms.af3_reference_molecule import (
|
|
4
|
+
get_af3_reference_molecule_features,
|
|
5
|
+
)
|
|
6
|
+
from atomworks.ml.transforms.chirals import add_af3_chiral_features
|
|
7
|
+
from atomworks.ml.transforms.rdkit_utils import get_rdkit_chiral_centers
|
|
8
|
+
from beartype.typing import Any
|
|
9
|
+
from biotite.structure import AtomArray, AtomArrayStack
|
|
10
|
+
from jaxtyping import Bool, Float
|
|
11
|
+
from rf3.kinematics import get_dih
|
|
12
|
+
|
|
13
|
+
from foundry.metrics.metric import Metric
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def calc_chiral_metrics_masked(
    pred: Float[torch.Tensor, "B L ... 3"],
    chirals: Float[torch.Tensor, "n_chiral 5"],
    mask: Bool[torch.Tensor, "I"],
):
    """Compute chirality metrics over the masked chiral centers:

    - n_chiral_centers (B): number of (valid, deduplicated) chiral centers
    - chiral_loss_mean (B): sum of squared chiral-angle errors over mask size
    - percent_correct_chirality (B): fraction of centers with correct handedness

    Args:
        pred: predicted coords (B, L, :, 3)
        chirals: chiral-center table (n_chiral, 5); the first four columns are
            atom indices of the dihedral, the last is the ideal angle it
            should form
        mask: Boolean mask of shape (I) flagging valid positions (e.g.
            non-NaN coordinates, desired residue type)

    Returns:
        Dict with the three metrics above, or an empty dict when there are no
        chiral centers / no unmasked positions.
    """
    # Nothing to score: no chiral centers, or everything is masked out.
    if not chirals.numel() or not mask.sum():
        return {}

    atom_idx = chirals[..., :-1].long()  # (n_chiral, 4)
    ideal_angle = chirals[..., -1]  # (n_chiral,)

    # Coordinates of the four atoms of each chiral center: (B, n_chiral, 4, 3)
    quad_coords = pred[:, atom_idx]

    # Predicted dihedral angle per chiral center: (B, n_chiral)
    pred_angle = get_dih(
        quad_coords[..., 0, :],
        quad_coords[..., 1, :],
        quad_coords[..., 2, :],
        quad_coords[..., 3, :],
    )

    angle_error = pred_angle - ideal_angle  # (B, n_chiral)
    # Handedness is correct when predicted and ideal angles share a sign.
    sign_matches = torch.sign(pred_angle) == torch.sign(ideal_angle)  # (B, n_chiral)

    # A chiral center may be listed multiple times (different atom orderings);
    # to avoid over-counting, keep only the first row of each run sharing the
    # same leading atom index.
    neg_inf = torch.tensor(
        [-float("inf")], device=chirals.device, dtype=chirals.dtype
    )
    prev_leading = torch.cat([neg_inf, chirals[:-1, 0]], dim=0)
    is_first_row = chirals[:, 0] != prev_leading

    # A center is valid iff ALL four of its atoms are unmasked — and it is the
    # first occurrence of that center.
    valid_center = mask[atom_idx].all(dim=-1) & is_first_row

    n_valid = valid_center.sum(dim=-1)
    frac_correct = sign_matches[:, valid_center].sum(dim=-1) / n_valid  # [B]
    sq_err_sum = torch.square(angle_error[:, valid_center]).sum(dim=-1)  # [B]

    return {
        "chiral_loss_mean": sq_err_sum / mask.sum(),  # [B]
        "n_chiral_centers": n_valid,  # [B]
        "percent_correct_chirality": frac_correct,  # [B]
    }
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def compute_chiral_metrics(
    predicted_atom_array_stack: AtomArrayStack | AtomArray,
    ground_truth_atom_array_stack: AtomArrayStack | AtomArray,
    chiral_feats: Float[torch.Tensor, "n_chiral 5"] | None = None,
):
    """Compute chiral metrics from the predicted and ground truth atom arrays.

    If chiral features are not directly provided, they are re-computed from
    the ground-truth AtomArray (via the RDKit reference molecules).

    Returns:
        dict: Chiral metrics, separated for polymers and non-polymers:
            - n_chiral_centers: number of chiral centers in the structure
            - chiral_loss_mean: mean of the squared errors of chiral angles
            - percent_correct_chirality: fraction of correctly predicted
              chiral centers
    """
    predicted_atom_array_stack = ensure_atom_array_stack(predicted_atom_array_stack)
    ground_truth_atom_array_stack = ensure_atom_array_stack(
        ground_truth_atom_array_stack
    )

    # Chirality does not depend on our data augmentation, so a single
    # ground-truth model suffices.
    gt_array = ground_truth_atom_array_stack[0]

    if chiral_feats is None:
        # Rebuild chiral features from the ground truth.
        _, rdkit_mols = get_af3_reference_molecule_features(gt_array)
        chiral_centers = get_rdkit_chiral_centers(rdkit_mols)
        chiral_feats = add_af3_chiral_features(gt_array, chiral_centers, rdkit_mols)

    device = chiral_feats.device
    pred_coords = torch.from_numpy(predicted_atom_array_stack.coord).to(device=device)

    polymer_mask = torch.from_numpy(gt_array.is_polymer).to(device=device)
    # Dihedral angles cannot be compared where the ground-truth coordinates
    # are NaN, so restrict to fully-resolved atoms.
    valid_coord_mask = ~torch.isnan(
        torch.from_numpy(gt_array.coord)
    ).any(dim=1).to(device=device)

    chiral_metrics = {}
    for category, category_mask in (
        ("polymer", polymer_mask),
        ("non_polymer", ~polymer_mask),
    ):
        # Score the chiral centers restricted to this category.
        result = calc_chiral_metrics_masked(
            pred_coords,
            chiral_feats,
            mask=category_mask & valid_coord_mask,
        )

        if not result:
            # No chiral centers in this category - skip.
            continue

        # Store results, averaged over the diffusion batch.
        if result["n_chiral_centers"] > 0:
            chiral_metrics[f"{category}_n_chiral_centers"] = result[
                "n_chiral_centers"
            ].item()
            chiral_metrics[f"{category}_chiral_loss_mean"] = (
                result["chiral_loss_mean"].mean().item()
            )
            chiral_metrics[f"{category}_percent_correct_chirality"] = (
                result["percent_correct_chirality"].mean().item()
            )

    return chiral_metrics
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
class ChiralLoss(Metric):
    """Metric wrapper that reports chirality metrics for predicted structures
    by delegating to :func:`compute_chiral_metrics`."""

    @property
    def kwargs_to_compute_args(self) -> dict[str, Any]:
        # Map each compute() parameter to the key (or key path) under which
        # the framework exposes the corresponding value.
        return {
            "predicted_atom_array_stack": "predicted_atom_array_stack",
            "ground_truth_atom_array_stack": "ground_truth_atom_array_stack",
            "chiral_feats": ("network_input", "f", "chiral_feats"),
        }

    def compute(
        self,
        predicted_atom_array_stack: AtomArrayStack | AtomArray,
        ground_truth_atom_array_stack: AtomArrayStack | AtomArray,
        chiral_feats: Float[torch.Tensor, "n_chiral 5"] = None,
    ):
        """Return the chiral-metric dict for the given prediction/ground truth."""
        return compute_chiral_metrics(
            predicted_atom_array_stack,
            ground_truth_atom_array_stack,
            chiral_feats=chiral_feats,
        )
|
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
import itertools
|
|
2
|
+
from typing import Any
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
from biotite.structure import AtomArrayStack
|
|
6
|
+
|
|
7
|
+
from foundry.metrics.metric import Metric
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class CountClashingChains(Metric):
    """Flags, per predicted model, whether any pair of polymer chains clash."""

    @property
    def kwargs_to_compute_args(self) -> dict[str, Any]:
        return {
            "X_L": ("network_output", "X_L"),
            "predicted_atom_array_stack": ("predicted_atom_array_stack",),
        }

    def compute(
        self,
        X_L: torch.Tensor,
        predicted_atom_array_stack: AtomArrayStack,
    ) -> dict[str, float]:
        """Detect clashing polymer-chain pairs in the predicted structure.

        A pair of polymer chains is considered clashing when more than 100
        atom pairs lie closer than 1.1 A, or when the number of clashing atom
        pairs exceeds half the atom count of the larger chain.

        Args:
            X_L: Predicted coordinates tensor; only its leading
                (diffusion-batch) dimension D is used here.
            predicted_atom_array_stack: AtomArrayStack containing the
                predicted structure (D models).

        Returns:
            Dict mapping ``has_clash_{i}`` to 0/1 for each of the D models.
        """
        D = X_L.shape[0]
        MIN_CLASH_DISTANCE = 1.1  # Minimum distance (Angstrom) to consider a clash

        # Iterate chain pairs in a deterministic order (plain `set` iteration
        # order varies between runs).
        pn_units = sorted(set(predicted_atom_array_stack.pn_unit_id))

        has_clash = torch.zeros((D), dtype=torch.bool)
        for chain_i, chain_j in itertools.combinations(pn_units, 2):
            chain_i_atoms = predicted_atom_array_stack[
                :, predicted_atom_array_stack.pn_unit_id == chain_i
            ]
            chain_j_atoms = predicted_atom_array_stack[
                :, predicted_atom_array_stack.pn_unit_id == chain_j
            ]
            # Only polymer-polymer chain pairs are scored for clashes.
            if not chain_i_atoms[0, 0].is_polymer or not chain_j_atoms[0, 0].is_polymer:
                continue

            # Pairwise inter-chain distances per model: (D, n_i, n_j).
            distances = torch.cdist(
                torch.from_numpy(chain_i_atoms.coord),
                torch.from_numpy(chain_j_atoms.coord),
            )
            num_clashes = (distances < MIN_CLASH_DISTANCE).sum(dim=-1).sum(dim=-1)

            # BUGFIX: normalize by the atom count of the larger chain. Stack
            # coords are (D, n_atoms, 3), so the atom count is shape[1];
            # shape[0] is the number of models and made the fraction wrong.
            larger_chain_atoms = max(
                chain_i_atoms.coord.shape[1], chain_j_atoms.coord.shape[1]
            )
            has_clash_pair = (num_clashes > 100) | (
                num_clashes / (larger_chain_atoms + 1e-6) > 0.5
            )
            has_clash = torch.logical_or(has_clash, has_clash_pair)

        assert has_clash.shape == (D,)
        # Unpack the diffusion-batch dimension into separate keys in the
        # output dictionary.
        clash_count_per_batch = {
            f"has_clash_{i}": int(has_clash[i]) for i in range(len(has_clash))
        }
        return clash_count_per_batch
|