rc-foundry 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180) hide show
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
rf3/loss/loss.py ADDED
@@ -0,0 +1,179 @@
1
+ import logging
2
+
3
+ import torch
4
+
5
+ logger = logging.getLogger(__name__)
6
+
7
+
8
def calc_ddihedralmse_dxyz(a, b, c, d, true_dih, eps=1e-6):
    """Closed-form gradient of the squared dihedral-angle error w.r.t. atom coordinates.

    For every quadruple of atoms ``(a, b, c, d)`` the dihedral angle is computed and
    the gradient of ``(dihedral - true_dih) ** 2`` with respect to each of the four
    atom positions is assembled analytically through the chain rule. Unlike the
    original implementation, this does NOT use autograd.

    Args:
        a, b, c, d: Coordinates of the four atoms of each chiral center. All four
            share one shape ``(..., 3)``; the caller in this file passes
            ``(batch, n_centers, 3)``.
        true_dih: Target dihedral angle of each center. It is flattened and tiled
            ``a.shape[0]`` times, so both ``(n_centers,)`` and ``(n_centers, 1)``
            are accepted.
        eps: Small constant guarding the divisions and the ``atan2`` arguments.

    Returns:
        Tensor of shape ``(*a.shape[:-1], 4, 3)`` holding the (un-averaged) gradient
        of the squared error w.r.t. ``a``, ``b``, ``c`` and ``d``, stacked in that
        order along the second-to-last axis.
    """
    og_shape = a.shape

    # Tile the per-center targets once per leading (batch) entry, then flatten so
    # that they line up with the flattened coordinates below. The dtype is aligned
    # with the coordinates because ``torch.bmm`` does not promote mixed dtypes.
    true_dih = (
        true_dih.reshape(1, -1).repeat(og_shape[0], 1).reshape(-1).to(dtype=a.dtype)
    )

    # Flatten every leading dimension: (..., 3) -> (N, 3).
    # ``reshape`` (rather than ``view``) also accepts non-contiguous slices.
    a = a.reshape(-1, 3)
    b = b.reshape(-1, 3)
    c = c.reshape(-1, 3)
    d = d.reshape(-1, 3)

    batch_size = a.shape[0]
    # Batch-aware identity matrix, created with the inputs' dtype and device so the
    # matmul/bmm calls below never see mixed dtypes (e.g. float64 coordinates).
    I = torch.eye(3, dtype=a.dtype, device=a.device).unsqueeze(0).repeat(
        batch_size, 1, 1
    )

    # Bond vectors
    b0 = a - b
    b1 = c - b
    b2 = d - c

    # Normalised central bond
    b1_norm = torch.norm(b1, dim=-1, keepdim=True)
    b1n = b1 / (b1_norm + eps)

    # Components of b0 / b2 orthogonal to the central bond
    v = b0 - torch.sum(b0 * b1n, dim=-1, keepdim=True) * b1n
    w = b2 - torch.sum(b2 * b1n, dim=-1, keepdim=True) * b1n

    # Dihedral components and angle
    x = torch.sum(v * w, dim=-1)
    y = torch.sum(torch.cross(b1n, v, dim=-1) * w, dim=-1)
    dih = torch.atan2(y + eps, x + eps)

    # Derivatives of the bond vectors w.r.t. the atom positions
    db0_da = I
    db0_db = -I
    db1_db = -I
    db1_dc = I
    db2_dc = -I
    db2_dd = I

    # d(squared error)/d(dihedral); deliberately NOT divided by the batch size
    dmse_ddih = 2 * (dih - true_dih)

    # d(atan2(y, x))/dx and /dy
    ddih_dx = -y / (x**2 + y**2 + eps)
    ddih_dy = x / (x**2 + y**2 + eps)

    # x = v . w and y = (b1n x v) . w
    dx_dv = w
    dx_dw = v
    dy_dv = -torch.cross(b1n, w, dim=-1)
    dy_dw = torch.cross(b1n, v, dim=-1)

    # NOTE: dv_db1n / dw_db1n are stored transposed relative to the Jacobian
    # convention used in the chain rule; they are transposed at their use sites.
    dw_db1n = -torch.sum(b2 * b1n, dim=-1, keepdim=True).unsqueeze(-1) * I - torch.bmm(
        b2.unsqueeze(-1), b1n.unsqueeze(1)
    )
    dv_db1n = -torch.sum(b0 * b1n, dim=-1, keepdim=True).unsqueeze(-1) * I - torch.bmm(
        b0.unsqueeze(-1), b1n.unsqueeze(1)
    )

    db1n_db1 = (b1_norm + eps).unsqueeze(-1) * I / (b1_norm**2 + eps).unsqueeze(
        -1
    ) - torch.bmm(b1.unsqueeze(-1), b1.unsqueeze(1)) / (b1_norm**2 + eps).unsqueeze(-1)

    # Projection onto the plane orthogonal to b1n
    dv_db0 = I - torch.bmm(b1n.unsqueeze(-1), b1n.unsqueeze(1))
    dw_db2 = I - torch.bmm(b1n.unsqueeze(-1), b1n.unsqueeze(1))

    # Bring scalars to (N, 1, 1) and vectors to row form (N, 1, 3) for bmm
    ddih_dx = ddih_dx.view(-1, 1, 1)
    ddih_dy = ddih_dy.view(-1, 1, 1)
    dmse_ddih = dmse_ddih.view(-1, 1, 1)
    dx_dv = dx_dv.unsqueeze(1)
    dx_dw = dx_dw.unsqueeze(1)
    dy_dv = dy_dv.unsqueeze(1)
    dy_dw = dy_dw.unsqueeze(1)

    # --- wrt a (only through b0 -> v)
    dv_da = torch.matmul(dv_db0, db0_da)
    ddih_da = torch.bmm((ddih_dx * dx_dv), dv_da) + torch.bmm((ddih_dy * dy_dv), dv_da)
    dmse_da = torch.bmm(dmse_ddih, ddih_da)

    # --- wrt b (through b0 and b1)
    db1n_db = torch.matmul(db1n_db1, db1_db)
    dv_db = torch.matmul(dv_db0, db0_db) + torch.matmul(
        dv_db1n.transpose(-1, -2), db1n_db
    )
    dw_db = torch.matmul(dw_db1n.transpose(-1, -2), db1n_db)
    dx_db = torch.bmm(dx_dv, dv_db) + torch.bmm(dx_dw, dw_db)
    dy_db = torch.bmm(dy_dv, dv_db) + torch.bmm(dy_dw, dw_db)
    ddih_db = torch.bmm(ddih_dx, dx_db) + torch.bmm(ddih_dy, dy_db)
    dmse_db = torch.bmm(dmse_ddih, ddih_db)

    # --- wrt c (through b1 and b2)
    db1n_dc = torch.matmul(db1n_db1, db1_dc)
    dv_dc = torch.matmul(dv_db1n.transpose(-1, -2), db1n_dc)
    dw_dc = torch.matmul(dw_db2, db2_dc) + torch.matmul(
        dw_db1n.transpose(-1, -2), db1n_dc
    )
    dx_dc = torch.bmm(dx_dv, dv_dc) + torch.bmm(dx_dw, dw_dc)
    dy_dc = torch.bmm(dy_dv, dv_dc) + torch.bmm(dy_dw, dw_dc)
    ddih_dc = torch.bmm(ddih_dx, dx_dc) + torch.bmm(ddih_dy, dy_dc)
    dmse_dc = torch.bmm(dmse_ddih, ddih_dc)

    # --- wrt d (only through b2 -> w)
    dw_dd = torch.matmul(dw_db2, db2_dd)
    ddih_dd = torch.bmm((ddih_dx * dx_dw), dw_dd) + torch.bmm((ddih_dy * dy_dw), dw_dd)
    dmse_dd = torch.bmm(dmse_ddih, ddih_dd)

    # Restore the original leading dimensions and stack the four per-atom gradients
    per_atom = [
        g.reshape(og_shape).unsqueeze(-2) for g in (dmse_da, dmse_db, dmse_dc, dmse_dd)
    ]
    return torch.cat(per_atom, dim=-2)
134
+
135
+
136
def calc_chiral_grads_flat_impl(
    xyz, chiral_centers, chiral_center_dihedral_angles, no_grad_on_chiral_center
):
    """Scatter the closed-form chiral (dihedral) loss gradient back onto the atoms.

    Args:
        xyz: torch.Tensor, shape (batch, n_atoms, 3). NOTE: ``requires_grad`` is
            switched on in place, as in the original implementation.
        chiral_centers: torch.Tensor (long), shape (n_centers, 4); atom indices of
            the four atoms that define each chiral dihedral.
        chiral_center_dihedral_angles: torch.Tensor (float) with one target angle
            per center. It is forwarded unchanged to ``calc_ddihedralmse_dxyz``.
        no_grad_on_chiral_center: If truthy, the gradient of the first atom of every
            quadruple (index 0, treated as the chiral center itself) is zeroed.

    Returns:
        grads: torch.Tensor, shape (batch, n_atoms, 3), same dtype and device as
        ``xyz``. Contributions of atoms that take part in several centers are summed.
    """
    # (We want to track the gradient of the dihedral angle loss with respect to the xyz coordinates)
    xyz.requires_grad_(True)

    # ``zeros_like`` keeps dtype and device identical to ``xyz`` on every return path
    grads = torch.zeros_like(xyz)

    # Edge case: no chiral centers, return zero gradients
    if chiral_centers.shape[0] == 0:
        return grads

    # Coordinates of the four atoms of every center: (batch, n_centers, 4, 3)
    chiral_dih = xyz[:, chiral_centers, :]

    chiral_grads = calc_ddihedralmse_dxyz(
        chiral_dih[..., 0, :],
        chiral_dih[..., 1, :],
        chiral_dih[..., 2, :],
        chiral_dih[..., 3, :],
        chiral_center_dihedral_angles,
    )  # (batch, n_centers, 4, 3)

    if no_grad_on_chiral_center:
        chiral_grads[:, :, 0] = 0.0  # no gradient on chiral center

    # Accumulate the per-quadruple gradients onto their atoms; index_add_ sums
    # contributions when the same atom index occurs more than once.
    grads.index_add_(
        1,
        chiral_centers.flatten(),
        chiral_grads.flatten(start_dim=1, end_dim=2),
    )

    return grads
rf3/metrics/chiral.py ADDED
@@ -0,0 +1,179 @@
1
+ import torch
2
+ from atomworks.io.transforms.atom_array import ensure_atom_array_stack
3
+ from atomworks.ml.transforms.af3_reference_molecule import (
4
+ get_af3_reference_molecule_features,
5
+ )
6
+ from atomworks.ml.transforms.chirals import add_af3_chiral_features
7
+ from atomworks.ml.transforms.rdkit_utils import get_rdkit_chiral_centers
8
+ from beartype.typing import Any
9
+ from biotite.structure import AtomArray, AtomArrayStack
10
+ from jaxtyping import Bool, Float
11
+ from rf3.kinematics import get_dih
12
+
13
+ from foundry.metrics.metric import Metric
14
+
15
+
16
def calc_chiral_metrics_masked(
    pred: Float[torch.Tensor, "B L ... 3"],
    chirals: Float[torch.Tensor, "n_chiral 5"],
    mask: Bool[torch.Tensor, "I"],
):
    """Score predicted chirality against the ideal dihedral angles.

    Args:
        pred: Predicted coordinates; indexed along its second axis with the atom
            indices stored in ``chirals``.
        chirals: One row per chiral dihedral. The first four columns are atom indices,
            the last column is the ideal angle those four atoms should form.
        mask: Boolean per-atom mask of positions to consider (e.g. non-NaN
            coordinates, desired residue type).

    Returns:
        An empty dict when ``chirals`` is empty or ``mask`` selects nothing.
        Otherwise a dict with:
            - "chiral_loss_mean": summed squared angle error over the retained rows,
              divided by ``mask.sum()`` (the number of selected atoms).
            - "n_chiral_centers": number of retained rows.
            - "percent_correct_chirality": fraction of retained rows whose predicted
              angle has the same sign as the ideal angle.
        A row is retained iff all four of its atoms are selected by ``mask`` and its
        first index differs from the first index of the preceding row.
    """
    if not (chirals.numel() and mask.sum()):
        # ... nothing to evaluate; exit
        return {}

    atom_indices = chirals[..., :-1].long()
    ideal_dih = chirals[..., -1]

    # ... coordinates of the four atoms of every row, then their dihedral angle
    quad_xyz = pred[:, atom_indices]
    pred_dih = get_dih(
        quad_xyz[..., 0, :],
        quad_xyz[..., 1, :],
        quad_xyz[..., 2, :],
        quad_xyz[..., 3, :],
    )

    squared_error = torch.square(pred_dih - ideal_dih)
    same_sign = torch.sign(pred_dih) == torch.sign(ideal_dih)

    # To avoid over-counting, keep a single row per run of rows sharing the same
    # first index: compare each row's first index against its predecessor's, with
    # -inf standing in as the predecessor of the very first row.
    sentinel = torch.tensor(
        [-float("inf")], device=chirals.device, dtype=chirals.dtype
    )
    previous_first_index = torch.cat([sentinel, chirals[:-1, 0]], dim=0)
    starts_new_center = chirals[:, 0] != previous_first_index

    # ... a row is usable iff ALL four atoms are inside the mask
    all_atoms_selected = mask[atom_indices].all(dim=-1)
    keep = all_atoms_selected & starts_new_center
    n_kept = keep.sum(dim=-1)

    return {
        "chiral_loss_mean": squared_error[:, keep].sum(dim=-1) / mask.sum(),
        "n_chiral_centers": n_kept,
        "percent_correct_chirality": same_sign[:, keep].sum(dim=-1) / n_kept,
    }
83
+
84
+
85
def compute_chiral_metrics(
    predicted_atom_array_stack: AtomArrayStack | AtomArray,
    ground_truth_atom_array_stack: AtomArrayStack | AtomArray,
    chiral_feats: Float[torch.Tensor, "n_chiral 5"] | None = None,
):
    """Compute chiral metrics for polymer and non-polymer atoms separately.

    When ``chiral_feats`` is not supplied, the features are regenerated from the
    first model of the ground-truth stack.

    Returns:
        dict: For each category (``polymer`` / ``non_polymer``) that has at least one
        chiral center, the keys
            - ``{category}_n_chiral_centers``
            - ``{category}_chiral_loss_mean``
            - ``{category}_percent_correct_chirality``
        where the last two are averaged over the models of the predicted stack.
    """
    predicted_atom_array_stack = ensure_atom_array_stack(predicted_atom_array_stack)
    ground_truth_atom_array_stack = ensure_atom_array_stack(
        ground_truth_atom_array_stack
    )

    # (The first model is enough - chirality does not depend on our data augmentation)
    reference = ground_truth_atom_array_stack[0]

    if chiral_feats is None:
        # ... rebuild the chiral features from the reference structure
        _, rdkit_mols = get_af3_reference_molecule_features(reference)
        chiral_feats = add_af3_chiral_features(
            reference, get_rdkit_chiral_centers(rdkit_mols), rdkit_mols
        )

    device = chiral_feats.device
    X_L = torch.from_numpy(predicted_atom_array_stack.coord).to(device=device)
    is_polymer = torch.from_numpy(reference.is_polymer).to(device=device)

    # (Atoms with NaN ground-truth coordinates are excluded from the comparison)
    has_nan_coord = torch.isnan(torch.from_numpy(reference.coord)).any(dim=1)
    has_valid_coord = ~has_nan_coord.to(device=device)

    masks_by_category = {"polymer": is_polymer, "non_polymer": ~is_polymer}

    chiral_metrics = {}
    for category, category_mask in masks_by_category.items():
        result = calc_chiral_metrics_masked(
            X_L,
            chiral_feats,
            mask=category_mask & has_valid_coord,
        )

        # ... skip categories without any (valid) chiral center
        if not result or not result["n_chiral_centers"] > 0:
            continue

        chiral_metrics[f"{category}_n_chiral_centers"] = result[
            "n_chiral_centers"
        ].item()
        chiral_metrics[f"{category}_chiral_loss_mean"] = (
            result["chiral_loss_mean"].mean().item()
        )
        chiral_metrics[f"{category}_percent_correct_chirality"] = (
            result["percent_correct_chirality"].mean().item()
        )

    return chiral_metrics
156
+
157
+
158
class ChiralLoss(Metric):
    """Metric wrapper around ``compute_chiral_metrics``."""

    @property
    def kwargs_to_compute_args(self) -> dict[str, Any]:
        # Maps each ``compute`` argument to the key (or nested key path) it is read from.
        return {
            "predicted_atom_array_stack": "predicted_atom_array_stack",
            "ground_truth_atom_array_stack": "ground_truth_atom_array_stack",
            "chiral_feats": ("network_input", "f", "chiral_feats"),
        }

    def compute(
        self,
        predicted_atom_array_stack: AtomArrayStack | AtomArray,
        ground_truth_atom_array_stack: AtomArrayStack | AtomArray,
        chiral_feats: Float[torch.Tensor, "n_chiral 5"] | None = None,
    ):
        """Return the chiral metrics for a predicted / ground-truth structure pair.

        Args:
            predicted_atom_array_stack: Predicted structure(s).
            ground_truth_atom_array_stack: Ground-truth structure(s).
            chiral_feats: Optional pre-computed chiral features; when ``None`` they
                are recomputed by ``compute_chiral_metrics``.

        Returns:
            The dictionary produced by ``compute_chiral_metrics``.
        """
        return compute_chiral_metrics(
            predicted_atom_array_stack,
            ground_truth_atom_array_stack,
            chiral_feats=chiral_feats,
        )
rf3/metrics/clashing_chains.py ADDED
@@ -0,0 +1,68 @@
1
+ import itertools
2
+ from typing import Any
3
+
4
+ import torch
5
+ from biotite.structure import AtomArrayStack
6
+
7
+ from foundry.metrics.metric import Metric
8
+
9
+
10
class CountClashingChains(Metric):
    """Flags predicted models in which two polymer chains sterically clash."""

    @property
    def kwargs_to_compute_args(self) -> dict[str, Any]:
        # Maps each ``compute`` argument to the key path it is read from.
        return {
            "X_L": ("network_output", "X_L"),
            "predicted_atom_array_stack": ("predicted_atom_array_stack",),
        }

    def compute(
        self,
        X_L: torch.Tensor,
        predicted_atom_array_stack: AtomArrayStack,
    ) -> dict[str, int]:
        """Determine, per predicted model, whether any two polymer chains clash.

        Two polymer chains clash in a model when the number of inter-chain atom pairs
        closer than 1.1 exceeds 100, or exceeds half the atom count of the larger of
        the two chains.

        Args:
            X_L: Tensor whose leading dimension is the number of predicted models
                ``D``; only its shape is used.
            predicted_atom_array_stack: AtomArrayStack containing the predicted
                structures, with ``pn_unit_id`` and ``is_polymer`` annotations.

        Returns:
            Dict with one ``has_clash_{i}`` entry per model, valued 1 if the model
            contains at least one clashing pair of polymer chains and 0 otherwise.
        """
        D = X_L.shape[0]
        MIN_CLASH_DISTANCE = 1.1  # Minimum distance to consider a clash

        # Slice every chain once (instead of once per pair) and keep polymers only.
        pn_unit_ids = predicted_atom_array_stack.pn_unit_id
        polymer_coords = []
        for pn_unit in set(pn_unit_ids):
            unit_atoms = predicted_atom_array_stack[:, pn_unit_ids == pn_unit]
            if unit_atoms[0, 0].is_polymer:
                # coord of a stack is (n_models, n_atoms, 3)
                polymer_coords.append(torch.from_numpy(unit_atoms.coord))

        has_clash = torch.zeros((D,), dtype=torch.bool)
        for coords_i, coords_j in itertools.combinations(polymer_coords, 2):
            distances = torch.cdist(coords_i, coords_j)  # (D, n_i, n_j)
            num_clashes = (distances < MIN_CLASH_DISTANCE).sum(dim=-1).sum(dim=-1)
            # Normalise by the ATOM count (axis 1); axis 0 is the model axis D.
            n_atoms = max(coords_i.shape[1], coords_j.shape[1])
            has_clash_pair = (num_clashes > 100) | (
                num_clashes / (n_atoms + 1e-6) > 0.5
            )
            has_clash = torch.logical_or(has_clash, has_clash_pair)

        # X_L and the stack must agree on the number of models
        assert has_clash.shape == (D,)

        # unpack the batch dimension into separate keys in the output dictionary
        return {f"has_clash_{i}": int(flag) for i, flag in enumerate(has_clash)}
+ return clash_count_per_batch