rc_foundry-0.1.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
rfd3/metrics/losses.py
ADDED
@@ -0,0 +1,325 @@

import torch
import torch.nn as nn

from foundry.training.checkpoint import activation_checkpointing


class SequenceLoss(nn.Module):
    def __init__(self, weight, min_t=0, max_t=torch.inf):
        super().__init__()
        self.weight = weight
        self.min_t = min_t
        self.max_t = max_t
        self.loss_fn = nn.CrossEntropyLoss(reduction="none")

    def forward(self, network_input, network_output, loss_input):
        t = network_input["t"]  # (B,)
        valid_t = (self.min_t <= t) & (t < self.max_t)  # bool mask over batch
        n_valid_t = valid_t.sum()

        # Grab network outputs
        sequence_logits_I = network_output["sequence_logits_I"]  # (B, L, 32)
        sequence_indices_I = network_output["sequence_indices_I"]  # (B, L)

        if n_valid_t == 0:
            zero = sequence_logits_I.sum() * 0.0
            return zero, {
                "valid_t_fraction": torch.tensor([0.0]),
                "n_valid_t": torch.tensor([0.0]),
            }

        pred_seq = sequence_logits_I[valid_t]  # (V, L, 32)
        gt_seq = loss_input["seq_token_lvl"]  # (L,)
        gt_seq = gt_seq.unsqueeze(0).expand(n_valid_t, -1)  # (V, L)
        w_seq = loss_input["sequence_valid_mask"]  # (L,)

        # Cross-entropy token loss
        token_loss = self.loss_fn(pred_seq.permute(0, 2, 1), gt_seq)  # (V, L)
        token_loss = token_loss * w_seq[None]  # (V, L)
        token_loss = token_loss.mean(dim=-1)  # (V,)

        _, order = torch.sort(t[valid_t])  # low-t first
        sequence_indices_I = sequence_indices_I[valid_t]
        recovery = (sequence_indices_I == gt_seq).float()  # (V, L)
        recovery = recovery[order]  # reorder by t
        recovery = recovery[..., (w_seq > 0).bool()]  # (V, L_valid)
        lowest_t_rec = recovery[0].mean()  # scalar

        outs = {
            "token_lvl_sequence_loss": token_loss.mean().detach(),
            "seq_recovery": recovery.mean().detach(),
            "lowest_t_seq_recovery": lowest_t_rec.detach(),
            "valid_t_fraction": valid_t.float().mean().detach(),
            "n_valid_t": n_valid_t.float(),
        }
        token_loss = torch.clamp(token_loss.mean(), max=4)
        return self.weight * token_loss, outs
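A minimal smoke test for SequenceLoss, assuming the (B, L, 32) shapes from the comments above; all tensors below are dummy values for illustration, not package API.

# Hypothetical smoke test; V = 32 matches the (B, L, 32) logits comment above.
import torch

loss_fn = SequenceLoss(weight=1.0, min_t=0.0, max_t=1.0)
B, L, V = 4, 16, 32  # batch copies, token positions, vocabulary size (assumed)

network_input = {"t": torch.rand(B)}  # per-copy diffusion time in [0, 1)
network_output = {
    "sequence_logits_I": torch.randn(B, L, V),
    "sequence_indices_I": torch.randint(0, V, (B, L)),  # e.g. argmax of logits
}
loss_input = {
    "seq_token_lvl": torch.randint(0, V, (L,)),  # ground-truth tokens
    "sequence_valid_mask": torch.ones(L),        # score every position
}

loss, metrics = loss_fn(network_input, network_output, loss_input)
print(loss.item(), metrics["seq_recovery"].item())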
class DiffusionLoss(nn.Module):
    def __init__(
        self,
        *,
        weight,
        sigma_data,
        lddt_weight,
        alpha_virtual_atom=1.0,
        alpha_unindexed_diffused=1.0,
        alpha_polar_residues=1.0,
        alpha_ligand=2.0,
        unindexed_t_alpha=1.0,
        unindexed_norm_p=1.0,
        lp_weight=0.0,
        **_,  # dump args from old configs
    ):
        super().__init__()
        self.weight = weight
        self.lddt_weight = lddt_weight
        self.sigma_data = sigma_data

        self.alpha_unindexed_diffused = alpha_unindexed_diffused
        self.alpha_virtual_atom = alpha_virtual_atom
        self.unindexed_norm_p = unindexed_norm_p
        self.unindexed_t_alpha = unindexed_t_alpha
        self.lp_weight = lp_weight
        self.alpha_ligand = alpha_ligand
        self.alpha_polar_residues = alpha_polar_residues

        self.get_lambda = (
            lambda sigma: (sigma**2 + self.sigma_data**2)
            / (sigma * self.sigma_data) ** 2
        )

    def forward(self, network_input, network_output, loss_input):
        X_L = network_output["X_L"]  # (D, L, 3)
        D = X_L.shape[0]
        crd_mask_L = loss_input["crd_mask_L"]  # (L,)
        crd_mask_L = crd_mask_L.unsqueeze(0).expand(D, -1)  # (D, L)
        tok_idx = network_input["f"]["atom_to_token_map"]
        t = network_input["t"]  # (D,)
        is_original_unindexed_token = loss_input["is_original_unindexed_token"][tok_idx]
        is_polar_atom = network_input["f"]["is_polar"][tok_idx]
        is_ligand = network_input["f"]["is_ligand"][tok_idx]
        is_virtual_atom = network_input["f"]["is_virtual"]  # (L,)
        is_sidechain_atom = network_input["f"]["is_sidechain"]  # (L,)
        is_sidechain_atom = is_sidechain_atom & ~is_virtual_atom

        w_L = torch.ones_like(tok_idx, dtype=X_L.dtype)
        w_L[is_original_unindexed_token] = (
            w_L[is_original_unindexed_token] * self.alpha_unindexed_diffused
        )
        w_L[is_virtual_atom] *= self.alpha_virtual_atom
        w_L[is_ligand] *= self.alpha_ligand

        # Upweight polar residues
        w_L[is_polar_atom] *= self.alpha_polar_residues
        w_L = w_L[None].expand(D, -1) * crd_mask_L

        X_gt_L = torch.nan_to_num(loss_input["X_gt_L_in_input_frame"])
        l_mse_L = w_L * torch.sum((X_L - X_gt_L) ** 2, dim=-1)
        l_mse_L = torch.div(l_mse_L, 3 * torch.sum(crd_mask_L[0]) + 1e-4)  # (D, L)

        if torch.any(is_original_unindexed_token):
            t_exp = t[:, None].expand(-1, X_L.shape[1])  # (D, L)
            t_exp = (
                t_exp * (~is_original_unindexed_token)
                + self.unindexed_t_alpha * t_exp * is_original_unindexed_token
            )

            l_global = (self.get_lambda(t_exp) * l_mse_L).sum(-1)

            # Get renormalization factor to equalize expectation of the loss
            r = self.get_lambda(t * self.unindexed_t_alpha) / self.get_lambda(t)
            t_factor = crd_mask_L.sum(-1) / (
                r * crd_mask_L[:, is_original_unindexed_token].sum(-1)
                + crd_mask_L[:, ~is_original_unindexed_token].sum(-1)
            )
            assert t_factor.shape == (D,), t_factor.shape
            l_global = l_global * t_factor
        else:
            l_global = self.get_lambda(t) * l_mse_L.sum(-1)

        assert l_global.shape == (D,), l_global.shape

        if torch.any(is_original_unindexed_token):
            lp_norm_L = w_L * torch.linalg.norm(
                X_L - X_gt_L, ord=self.unindexed_norm_p, dim=-1
            )  # (D, L)
            lp_norm_unindexed_diffused = lp_norm_L * is_original_unindexed_token[None]
            lp_norm_unindexed_diffused = torch.div(
                lp_norm_unindexed_diffused,
                self.alpha_unindexed_diffused
                * 3
                * torch.sum(is_original_unindexed_token)
                + 1e-4,
            )  # (D, L)
            lp_norm_unindexed_diffused = lp_norm_unindexed_diffused.sum(
                -1
            ) * self.get_lambda(self.unindexed_t_alpha * t)

            l_total = l_global + self.lp_weight * lp_norm_unindexed_diffused
        else:
            lp_norm_unindexed_diffused = None
            lp_norm_L = None
            l_total = l_global

        # ... Aggregate
        l_mse_total = torch.clamp(l_total, max=2)
        assert l_mse_total.shape == (
            D,
        ), f"Expected l_total to be of shape (D,), got {l_total.shape}"
        l_mse_total = torch.mean(l_mse_total)  # (D,) -> scalar

        # ... Return
        if self.lddt_weight > 0:
            # ... Calculate LDDT loss at the beginning
            smoothed_lddt_loss_, lddt_loss_dict = smoothed_lddt_loss(
                X_L,
                X_gt_L,
                crd_mask_L,
                network_input["f"]["is_dna"],
                network_input["f"]["is_rna"],
                tok_idx,
                return_extras=True,
            )  # (D,)
            l_total = l_mse_total + self.lddt_weight * smoothed_lddt_loss_.mean()
        else:
            lddt_loss_dict = {}
            l_total = l_mse_total
        # ... Return additional losses
        t, indices = torch.sort(t)
        l_mse_low, l_mse_high = torch.split(l_global[indices], [D // 2, D - D // 2])
        loss_dict = {
            "mse_loss_mean": l_mse_total,
            "mse_loss_low_t": l_mse_low,
            "mse_loss_high_t": l_mse_high,
            "lp_norm": lp_norm_L,
            "lp_norm_unindexed_diffused": lp_norm_unindexed_diffused,
        } | lddt_loss_dict
        loss_dict = {
            k: torch.mean(v).detach() for k, v in loss_dict.items() if v is not None
        }

        return self.weight * l_total, loss_dict
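The get_lambda closure appears to match the EDM-style loss weighting lambda(sigma) = (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2. Below is a hypothetical smoke test with random inputs; every key name and tensor shape is inferred from the forward pass above rather than confirmed API, and lddt_weight=0.0 keeps the sketch independent of smoothed_lddt_loss.

# Hypothetical smoke test (D copies, L atoms, T tokens; all sizes assumed).
import torch

D, L, T = 2, 24, 8
tok = torch.randint(0, T, (L,))

loss_fn = DiffusionLoss(weight=1.0, sigma_data=16.0, lddt_weight=0.0)

network_input = {
    "t": torch.rand(D) + 0.1,  # noise levels; offset avoids sigma == 0
    "f": {
        "atom_to_token_map": tok,
        "is_polar": torch.zeros(T, dtype=torch.bool),
        "is_ligand": torch.zeros(T, dtype=torch.bool),
        "is_virtual": torch.zeros(L, dtype=torch.bool),
        "is_sidechain": torch.zeros(L, dtype=torch.bool),
        "is_dna": torch.zeros(T, dtype=torch.bool),
        "is_rna": torch.zeros(T, dtype=torch.bool),
    },
}
network_output = {"X_L": torch.randn(D, L, 3)}
loss_input = {
    "crd_mask_L": torch.ones(L),  # all atoms resolved
    "is_original_unindexed_token": torch.zeros(T, dtype=torch.bool),
    "X_gt_L_in_input_frame": torch.randn(L, 3),  # broadcasts over D
}

loss, loss_dict = loss_fn(network_input, network_output, loss_input)
print(loss.item(), sorted(loss_dict))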
def smoothed_lddt_loss(
    X_L,
    X_gt_L,
    crd_mask_L,
    is_dna,
    is_rna,
    tok_idx,
    is_virtual=None,
    alpha_virtual=1.0,
    return_extras=False,
    eps=1e-6,
):
    @activation_checkpointing
    def _dolddt(X_L, X_gt_L, crd_mask_L, is_dna, is_rna, tok_idx, eps, use_amp=True):
        B, L = X_L.shape[:2]
        first_index, second_index = torch.triu_indices(L, L, 1, device=X_L.device)

        # compute the unique distances between all pairs of atoms
        X_gt_L = X_gt_L.nan_to_num()

        # only use native 1 (assumes dist map identical btwn all copies)
        ground_truth_distances = torch.linalg.norm(
            X_gt_L[0:1, first_index] - X_gt_L[0:1, second_index], dim=-1
        )

        # only score pairs that are close enough in the ground truth
        is_na_L = is_dna[tok_idx][first_index] | is_rna[tok_idx][first_index]
        pair_mask = torch.logical_and(
            ground_truth_distances > 0,
            ground_truth_distances < torch.where(is_na_L, 30.0, 15.0),
        )
        del is_na_L

        # only score pairs that are resolved in the ground truth
        pair_mask *= crd_mask_L[0:1, first_index] * crd_mask_L[0:1, second_index]

        # don't score pairs that are in the same token
        pair_mask *= tok_idx[None, first_index] != tok_idx[None, second_index]

        _, valid_pairs = pair_mask.nonzero(as_tuple=True)
        pair_mask = pair_mask[:, valid_pairs].to(X_L.dtype)
        ground_truth_distances = ground_truth_distances[:, valid_pairs]
        first_index, second_index = first_index[valid_pairs], second_index[valid_pairs]

        predicted_distances = torch.linalg.norm(
            X_L[:, first_index] - X_L[:, second_index], dim=-1
        )

        delta_distances = torch.abs(predicted_distances - ground_truth_distances + eps)
        del predicted_distances, ground_truth_distances

        if is_virtual is not None:
            pair_mask[:, (is_virtual[first_index] * is_virtual[second_index])] *= (
                alpha_virtual
            )

        # I assume gradients flow better if we sum first rather than keeping everything in D, L...
        lddt = (
            0.25
            * (
                torch.sum(torch.sigmoid(0.5 - delta_distances) * pair_mask, dim=(1))
                + torch.sum(torch.sigmoid(1.0 - delta_distances) * pair_mask, dim=(1))
                + torch.sum(torch.sigmoid(2.0 - delta_distances) * pair_mask, dim=(1))
                + torch.sum(torch.sigmoid(4.0 - delta_distances) * pair_mask, dim=(1))
            )
            / (torch.sum(pair_mask, dim=(1)) + eps)
        )

        if not return_extras:
            return 1 - lddt

        # ...Hence we recalculate the losses here and pick out the parts of interest
        with torch.no_grad():
            lddt_ = (
                0.25
                * (
                    torch.sigmoid(0.5 - delta_distances)
                    + torch.sigmoid(1.0 - delta_distances)
                    + torch.sigmoid(2.0 - delta_distances)
                    + torch.sigmoid(4.0 - delta_distances)
                )
                * pair_mask
                / (torch.sum(pair_mask, dim=(1)) + eps)
            )

            def filter_lddt(mask, scale=1.0):
                mask = mask.to(pair_mask.dtype)
                if mask.ndim > 1:
                    mask = mask[0]
                mask = (mask[first_index] * mask[second_index])[None].expand(
                    pair_mask.shape[0], -1
                )
                mask = (mask * pair_mask).to(bool)
                return (
                    (1 - torch.sum(lddt_[:, mask[0]] * scale, dim=(1)))
                    .mean()
                    .detach()
                    .cpu()
                )

            extra_lddts = {}
            extra_lddts["mean_lddt"] = filter_lddt(
                torch.full_like(crd_mask_L, 1.0, device=X_L.device)
            )
            extra_lddts["mean_lddt_dna"] = filter_lddt(is_dna[tok_idx])
            extra_lddts["mean_lddt_rna"] = filter_lddt(is_rna[tok_idx])
            extra_lddts["mean_lddt_protein"] = filter_lddt(
                ~is_dna[tok_idx] & ~is_rna[tok_idx]
            )
            # NOTE: This also seems to have issues at epoch level, as with n_valid_t
            # before. Will leave as-is for now but may want to spoof as 0 in the future.
            if is_virtual is not None:
                extra_lddts["mean_lddt_virtual"] = filter_lddt(
                    is_virtual, scale=1 / alpha_virtual
                )
                extra_lddts["mean_lddt_non_virtual"] = filter_lddt(~is_virtual)

        return 1 - lddt, extra_lddts

    return _dolddt(X_L, X_gt_L, crd_mask_L, is_dna, is_rna, tok_idx, eps)
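A self-contained sketch of calling smoothed_lddt_loss on random coordinates; the sizes are arbitrary assumptions, and the one-token-per-atom map means no pair is dropped by the same-token filter.

# Hypothetical shapes: D = 2 diffusion copies, L = 10 atoms.
import torch

D, L = 2, 10
X_gt = torch.randn(1, L, 3).expand(D, -1, -1)   # shared ground truth
X = X_gt + 0.5 * torch.randn(D, L, 3)           # noisy prediction
crd_mask = torch.ones(D, L, dtype=torch.bool)   # all atoms resolved
tok_idx = torch.arange(L)                       # one token per atom
is_dna = torch.zeros(L, dtype=torch.bool)
is_rna = torch.zeros(L, dtype=torch.bool)

one_minus_lddt, extras = smoothed_lddt_loss(
    X, X_gt, crd_mask, is_dna, is_rna, tok_idx, return_extras=True
)
print(one_minus_lddt.shape, extras["mean_lddt"])  # torch.Size([2]), scalar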
rfd3/metrics/metrics_utils.py
ADDED

@@ -0,0 +1,118 @@

import itertools

import numpy as np
from atomworks.ml.preprocessing.utils.structure_utils import (
    get_atom_mask_from_cell_list,
)
from atomworks.ml.utils.token import spread_token_wise
from biotite.structure import CellList, annotate_sse, gyration_radius

from rfd3.transforms.conditioning_base import get_motif_features


def get_ss_metrics_and_rg(
    atom_array, ss_conditioning: dict[str, np.ndarray] | None = None
):
    """Compute secondary structure metrics and the radius of gyration for a given input file.

    Args:
        atom_array (AtomArray): Input AtomArray
        ss_conditioning (dict[str, np.ndarray] | None): Dictionary mapping the keys "helix", "sheet", "loop" to the
            corresponding conditioning arrays. If None, secondary structure adherence is not computed.

    NOTE: Biotite computes secondary structures using the P-SEA algorithm:
        G. Labesse, N. Colloc'h, J. Pothier, J. Mornon,
        "P-SEA: a new efficient assignment of secondary structure from Ca trace of proteins,"
        Bioinformatics, vol. 13, pp. 291-295, June 1997. doi: 10.1093/bioinformatics/13.3.291
    """
    # Compute secondary structure
    sse_array = annotate_sse(atom_array)
    sse_array_prot_only = sse_array[sse_array != ""]

    # Basic compositional statistics
    pdb_helix_percent = np.mean(sse_array_prot_only == "a")
    pdb_strand_percent = np.mean(sse_array_prot_only == "b")
    pdb_coil_percent = np.mean(sse_array_prot_only == "c")
    pdb_ss_percent = pdb_helix_percent + pdb_strand_percent

    # Number of disjoint helices or sheets
    num_structural_elements = 0
    for k, _ in itertools.groupby(sse_array):
        if k not in ["", "c"]:
            num_structural_elements += 1

    if ss_conditioning is not None:
        ss_adherence_dict = {}
        atom_level_sse_array = spread_token_wise(atom_array, input_data=sse_array)
        for ss_annot, ss_type in zip(["a", "b", "c"], ["helix", "sheet", "loop"]):
            metric_name = f"{ss_type}_conditioning_adherence"
            expected_indices = np.where(ss_conditioning[ss_type])[0]

            if len(expected_indices) > 0:
                ss_adherence = (
                    atom_level_sse_array[expected_indices] == ss_annot
                ).mean()
                ss_adherence_dict[metric_name] = ss_adherence
            else:
                # Would be misleading to give a numerical value if no conditioning of this type was provided
                ss_adherence_dict[metric_name] = np.nan

    # Compute radius of gyration
    radius_of_gyration = gyration_radius(atom_array)

    # Return output metrics
    output_metrics = {
        "non_loop_fraction": pdb_ss_percent,
        "loop_fraction": pdb_coil_percent,
        "helix_fraction": pdb_helix_percent,
        "sheet_fraction": pdb_strand_percent,
        "num_ss_elements": num_structural_elements,
        "radius_of_gyration": radius_of_gyration,
    }

    if ss_conditioning is not None:
        output_metrics.update(ss_adherence_dict)

    return output_metrics
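A hedged usage sketch: load a structure with biotite and compute the metrics. The file path is hypothetical; any single-model protein structure works.

import biotite.structure.io.pdb as pdb

# Hypothetical input path for a designed structure.
atom_array = pdb.PDBFile.read("design_0.pdb").get_structure(model=1)

metrics = get_ss_metrics_and_rg(atom_array)
print(metrics["helix_fraction"], metrics["num_ss_elements"], metrics["radius_of_gyration"])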
def _flatten_dict(d, parent="", sep="."):
    """
    Recursively flatten a nested dictionary.
    E.g:
        {"a": {"b": 1, "c": 2}} --> {"a.b": 1, "a.c": 2}
    """
    flat = {}
    for k, v in d.items():
        name = f"{parent}{sep}{k}" if parent else k
        if isinstance(v, dict):
            flat.update(_flatten_dict(v, name, sep=sep))
        else:
            flat[name] = v
    return flat
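For instance, a quick check of the docstring's example with one extra nesting level (the metric names here are made up):

nested = {"lddt": {"mean": 0.82, "per_chain": {"A": 0.85}}, "clashes": 1}
assert _flatten_dict(nested) == {
    "lddt.mean": 0.82,
    "lddt.per_chain.A": 0.85,
    "clashes": 1,
}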
def get_hotspot_contacts(atom_array, hotspot_mask, distance_cutoff=4.5):
    """Get the fraction of hotspots that make an inter-chain contact to a diffused atom within a distance cutoff."""
    cell_list = CellList(atom_array, cell_size=distance_cutoff)
    hotspot_array = atom_array[hotspot_mask]

    # Compute all contacts with hotspots
    full_contacting_atom_mask = get_atom_mask_from_cell_list(
        hotspot_array.coord, cell_list, len(atom_array), distance_cutoff
    )  # (n_hotspots, n_atoms)

    # We only count interchain contacts
    interchain_mask = hotspot_array.pn_unit_iid[:, None] != atom_array.pn_unit_iid[None]
    interchain_contacts_mask = full_contacting_atom_mask & interchain_mask

    # We only count contacts to diffused atoms
    diffused_interchain_contacts_mask = interchain_contacts_mask[
        :, ~get_motif_features(atom_array)["is_motif_atom"]
    ]

    contacted_hotspots_mask = np.any(
        diffused_interchain_contacts_mask, axis=1
    )  # (n_hotspots,)

    return float(contacted_hotspots_mask.mean())
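The contact test above rests on biotite's CellList neighbor search; below is a self-contained sketch of that search on synthetic coordinates (plain biotite only, no rfd3 annotations; every name is local to the example).

import numpy as np
import biotite.structure as struc

# Four dummy atoms; the first, second, and fourth sit within 4.5 A of each other.
atoms = struc.AtomArray(4)
atoms.coord = np.array(
    [[0.0, 0.0, 0.0], [3.0, 0.0, 0.0], [20.0, 0.0, 0.0], [3.5, 0.0, 0.0]],
    dtype=np.float32,
)

cell_list = struc.CellList(atoms, cell_size=4.5)
neighbors = cell_list.get_atoms(atoms.coord[0], radius=4.5)
neighbors = neighbors[neighbors != -1]  # drop padding entries
print(neighbors)  # indices of atoms within the cutoff, e.g. [0 1 3]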