rc_foundry-0.1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
rfd3/metrics/losses.py ADDED
@@ -0,0 +1,325 @@
+ import torch
+ import torch.nn as nn
+
+ from foundry.training.checkpoint import activation_checkpointing
+
+
+ class SequenceLoss(nn.Module):
+     def __init__(self, weight, min_t=0, max_t=torch.inf):
+         super().__init__()
+         self.weight = weight
+         self.min_t = min_t
+         self.max_t = max_t
+         self.loss_fn = nn.CrossEntropyLoss(reduction="none")
+
+     def forward(self, network_input, network_output, loss_input):
+         t = network_input["t"]  # (B,)
+         valid_t = (self.min_t <= t) & (t < self.max_t)  # bool mask over batch
+         n_valid_t = valid_t.sum()
+
+         # Grab network outputs
+         sequence_logits_I = network_output["sequence_logits_I"]  # (B, L, 32)
+         sequence_indices_I = network_output["sequence_indices_I"]  # (B, L)
+
+         if n_valid_t == 0:
+             # Zero loss that stays connected to the graph
+             zero = sequence_logits_I.sum() * 0.0
+             return zero, {
+                 "valid_t_fraction": torch.tensor([0.0]),
+                 "n_valid_t": torch.tensor([0.0]),
+             }
+
+         pred_seq = sequence_logits_I[valid_t]  # (V, L, 32)
+         gt_seq = loss_input["seq_token_lvl"]  # (L,)
+         gt_seq = gt_seq.unsqueeze(0).expand(n_valid_t, -1)  # (V, L)
+         w_seq = loss_input["sequence_valid_mask"]  # (L,)
+
+         # Cross-entropy token loss
+         token_loss = self.loss_fn(pred_seq.permute(0, 2, 1), gt_seq)  # (V, L)
+         token_loss = token_loss * w_seq[None]  # (V, L)
+         token_loss = token_loss.mean(dim=-1)  # (V,)
+
+         _, order = torch.sort(t[valid_t])  # low-t first
+         sequence_indices_I = sequence_indices_I[valid_t]
+         recovery = (sequence_indices_I == gt_seq).float()  # (V, L)
+         recovery = recovery[order]  # reorder by t
+         recovery = recovery[..., (w_seq > 0).bool()]  # (V, L_valid)
+         lowest_t_rec = recovery[0].mean()  # scalar
+
+         outs = {
+             "token_lvl_sequence_loss": token_loss.mean().detach(),
+             "seq_recovery": recovery.mean().detach(),
+             "lowest_t_seq_recovery": lowest_t_rec.detach(),
+             "valid_t_fraction": valid_t.float().mean().detach(),
+             "n_valid_t": n_valid_t.float(),
+         }
+         token_loss = torch.clamp(token_loss.mean(), max=4)
+         return self.weight * token_loss, outs
+
+
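For orientation, here is a minimal sketch of driving SequenceLoss with dummy tensors. The dictionary keys and the 32-way vocabulary follow the code above; every shape and value below is invented purely for illustration:

    import torch

    B, L = 4, 16  # toy batch of diffusion replicates and token count
    loss_fn = SequenceLoss(weight=1.0, min_t=0.0, max_t=2.0)
    network_input = {"t": torch.rand(B) * 1.5}  # all t land inside [min_t, max_t)
    network_output = {
        "sequence_logits_I": torch.randn(B, L, 32),
        "sequence_indices_I": torch.randint(0, 32, (B, L)),
    }
    loss_input = {
        "seq_token_lvl": torch.randint(0, 32, (L,)),
        "sequence_valid_mask": torch.ones(L),
    }
    loss, metrics = loss_fn(network_input, network_output, loss_input)
    # loss: weighted scalar; metrics: detached values for logging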
+ class DiffusionLoss(nn.Module):
+     def __init__(
+         self,
+         *,
+         weight,
+         sigma_data,
+         lddt_weight,
+         alpha_virtual_atom=1.0,
+         alpha_unindexed_diffused=1.0,
+         alpha_polar_residues=1.0,
+         alpha_ligand=2.0,
+         unindexed_t_alpha=1.0,
+         unindexed_norm_p=1.0,
+         lp_weight=0.0,
+         **_,  # discard args from old configs
+     ):
+         super().__init__()
+         self.weight = weight
+         self.lddt_weight = lddt_weight
+         self.sigma_data = sigma_data
+
+         self.alpha_unindexed_diffused = alpha_unindexed_diffused
+         self.alpha_virtual_atom = alpha_virtual_atom
+         self.unindexed_norm_p = unindexed_norm_p
+         self.unindexed_t_alpha = unindexed_t_alpha
+         self.lp_weight = lp_weight
+         self.alpha_ligand = alpha_ligand
+         self.alpha_polar_residues = alpha_polar_residues
+
+         self.get_lambda = (
+             lambda sigma: (sigma**2 + self.sigma_data**2)
+             / (sigma * self.sigma_data) ** 2
+         )
+
+     def forward(self, network_input, network_output, loss_input):
+         X_L = network_output["X_L"]  # (D, L, 3)
+         D = X_L.shape[0]
+         crd_mask_L = loss_input["crd_mask_L"]  # (L,)
+         crd_mask_L = crd_mask_L.unsqueeze(0).expand(D, -1)  # (D, L)
+         tok_idx = network_input["f"]["atom_to_token_map"]
+         t = network_input["t"]  # (D,)
+         is_original_unindexed_token = loss_input["is_original_unindexed_token"][tok_idx]
+         is_polar_atom = network_input["f"]["is_polar"][tok_idx]
+         is_ligand = network_input["f"]["is_ligand"][tok_idx]
+         is_virtual_atom = network_input["f"]["is_virtual"]  # (L,)
+         is_sidechain_atom = network_input["f"]["is_sidechain"]  # (L,)
+         is_sidechain_atom = is_sidechain_atom & ~is_virtual_atom
+
+         # Per-atom loss weights
+         w_L = torch.ones_like(tok_idx, dtype=X_L.dtype)
+         w_L[is_original_unindexed_token] = (
+             w_L[is_original_unindexed_token] * self.alpha_unindexed_diffused
+         )
+         w_L[is_virtual_atom] *= self.alpha_virtual_atom
+         w_L[is_ligand] *= self.alpha_ligand
+
+         # Upweight polar residues
+         w_L[is_polar_atom] *= self.alpha_polar_residues
+         w_L = w_L[None].expand(D, -1) * crd_mask_L
+
+         X_gt_L = torch.nan_to_num(loss_input["X_gt_L_in_input_frame"])
+         l_mse_L = w_L * torch.sum((X_L - X_gt_L) ** 2, dim=-1)
+         l_mse_L = torch.div(l_mse_L, 3 * torch.sum(crd_mask_L[0]) + 1e-4)  # (D, L)
+
+         if torch.any(is_original_unindexed_token):
+             t_exp = t[:, None].expand(-1, X_L.shape[1])  # (D, L)
+             t_exp = (
+                 t_exp * (~is_original_unindexed_token)
+                 + self.unindexed_t_alpha * t_exp * is_original_unindexed_token
+             )
+
+             l_global = (self.get_lambda(t_exp) * l_mse_L).sum(-1)
+
+             # Get renormalization factor to equalize expectation of the loss
+             r = self.get_lambda(t * self.unindexed_t_alpha) / self.get_lambda(t)
+             t_factor = crd_mask_L.sum(-1) / (
+                 r * crd_mask_L[:, is_original_unindexed_token].sum(-1)
+                 + crd_mask_L[:, ~is_original_unindexed_token].sum(-1)
+             )
+             assert t_factor.shape == (D,), t_factor.shape
+             l_global = l_global * t_factor
+         else:
+             l_global = self.get_lambda(t) * l_mse_L.sum(-1)
+
+         assert l_global.shape == (D,), l_global.shape
+
+         if torch.any(is_original_unindexed_token):
+             lp_norm_L = w_L * torch.linalg.norm(
+                 X_L - X_gt_L, ord=self.unindexed_norm_p, dim=-1
+             )  # (D, L)
+             lp_norm_unindexed_diffused = lp_norm_L * is_original_unindexed_token[None]
+             lp_norm_unindexed_diffused = torch.div(
+                 lp_norm_unindexed_diffused,
+                 self.alpha_unindexed_diffused
+                 * 3
+                 * torch.sum(is_original_unindexed_token)
+                 + 1e-4,
+             )  # (D, L)
+             lp_norm_unindexed_diffused = lp_norm_unindexed_diffused.sum(
+                 -1
+             ) * self.get_lambda(self.unindexed_t_alpha * t)
+
+             l_total = l_global + self.lp_weight * lp_norm_unindexed_diffused
+         else:
+             lp_norm_unindexed_diffused = None
+             lp_norm_L = None
+             l_total = l_global
+
+         # ... Aggregate
+         l_mse_total = torch.clamp(l_total, max=2)
+         assert l_mse_total.shape == (
+             D,
+         ), f"Expected l_total to be of shape (D,), got {l_total.shape}"
+         l_mse_total = torch.mean(l_mse_total)  # (D,) -> scalar
+
+         # ... Return
+         if self.lddt_weight > 0:
+             # ... Calculate LDDT loss at the beginning
+             smoothed_lddt_loss_, lddt_loss_dict = smoothed_lddt_loss(
+                 X_L,
+                 X_gt_L,
+                 crd_mask_L,
+                 network_input["f"]["is_dna"],
+                 network_input["f"]["is_rna"],
+                 tok_idx,
+                 return_extras=True,
+             )  # (D,)
+             l_total = l_mse_total + self.lddt_weight * smoothed_lddt_loss_.mean()
+         else:
+             lddt_loss_dict = {}
+             l_total = l_mse_total
+
+         # ... Return additional losses, split by noise level
+         t, indices = torch.sort(t)
+         l_mse_low, l_mse_high = torch.split(l_global[indices], [D // 2, D - D // 2])
+         loss_dict = {
+             "mse_loss_mean": l_mse_total,
+             "mse_loss_low_t": l_mse_low,
+             "mse_loss_high_t": l_mse_high,
+             "lp_norm": lp_norm_L,
+             "lp_norm_unindexed_diffused": lp_norm_unindexed_diffused,
+         } | lddt_loss_dict
+         loss_dict = {
+             k: torch.mean(v).detach() for k, v in loss_dict.items() if v is not None
+         }
+
+         return self.weight * l_total, loss_dict
+
+
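The get_lambda closure above matches the EDM-style weighting λ(σ) = (σ² + σ_data²) / (σ·σ_data)² (as in Karras et al., 2022), which roughly equalizes the loss contribution across noise levels. A quick numeric check, with a hypothetical σ_data (the real value is a constructor argument, not shown here):

    import torch

    sigma_data = 16.0  # hypothetical; supplied by the training config in practice
    get_lambda = lambda sigma: (sigma**2 + sigma_data**2) / (sigma * sigma_data) ** 2

    for sigma in (0.5, 16.0, 200.0):
        print(f"lambda({sigma}) = {get_lambda(torch.tensor(sigma)):.4f}")
    # ~4.0039, ~0.0078, ~0.0039: large weight at low noise, decaying
    # toward a 1 / sigma_data**2 floor at high noise.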
+ def smoothed_lddt_loss(
+     X_L,
+     X_gt_L,
+     crd_mask_L,
+     is_dna,
+     is_rna,
+     tok_idx,
+     is_virtual=None,
+     alpha_virtual=1.0,
+     return_extras=False,
+     eps=1e-6,
+ ):
+     @activation_checkpointing
+     def _dolddt(X_L, X_gt_L, crd_mask_L, is_dna, is_rna, tok_idx, eps, use_amp=True):
+         B, L = X_L.shape[:2]
+         first_index, second_index = torch.triu_indices(L, L, 1, device=X_L.device)
+
+         # compute the unique distances between all pairs of atoms
+         X_gt_L = X_gt_L.nan_to_num()
+
+         # only use native 1 (assumes dist map identical btwn all copies)
+         ground_truth_distances = torch.linalg.norm(
+             X_gt_L[0:1, first_index] - X_gt_L[0:1, second_index], dim=-1
+         )
+
+         # only score pairs that are close enough in the ground truth
+         is_na_L = is_dna[tok_idx][first_index] | is_rna[tok_idx][first_index]
+         pair_mask = torch.logical_and(
+             ground_truth_distances > 0,
+             ground_truth_distances < torch.where(is_na_L, 30.0, 15.0),
+         )
+         del is_na_L
+
+         # only score pairs that are resolved in the ground truth
+         pair_mask *= crd_mask_L[0:1, first_index] * crd_mask_L[0:1, second_index]
+
+         # don't score pairs that are in the same token
+         pair_mask *= tok_idx[None, first_index] != tok_idx[None, second_index]
+
+         _, valid_pairs = pair_mask.nonzero(as_tuple=True)
+         pair_mask = pair_mask[:, valid_pairs].to(X_L.dtype)
+         ground_truth_distances = ground_truth_distances[:, valid_pairs]
+         first_index, second_index = first_index[valid_pairs], second_index[valid_pairs]
+
+         predicted_distances = torch.linalg.norm(
+             X_L[:, first_index] - X_L[:, second_index], dim=-1
+         )
+
+         delta_distances = torch.abs(predicted_distances - ground_truth_distances + eps)
+         del predicted_distances, ground_truth_distances
+
+         if is_virtual is not None:
+             pair_mask[:, (is_virtual[first_index] * is_virtual[second_index])] *= (
+                 alpha_virtual
+             )
+
+         # I assume gradients flow better if we sum first rather than keeping everything in D, L...
+         lddt = (
+             0.25
+             * (
+                 torch.sum(torch.sigmoid(0.5 - delta_distances) * pair_mask, dim=(1))
+                 + torch.sum(torch.sigmoid(1.0 - delta_distances) * pair_mask, dim=(1))
+                 + torch.sum(torch.sigmoid(2.0 - delta_distances) * pair_mask, dim=(1))
+                 + torch.sum(torch.sigmoid(4.0 - delta_distances) * pair_mask, dim=(1))
+             )
+             / (torch.sum(pair_mask, dim=(1)) + eps)
+         )
+
+         if not return_extras:
+             return 1 - lddt
+
+         # ... hence we recalculate the per-pair values here and pick out the parts of interest
+         with torch.no_grad():
+             lddt_ = (
+                 0.25
+                 * (
+                     torch.sigmoid(0.5 - delta_distances)
+                     + torch.sigmoid(1.0 - delta_distances)
+                     + torch.sigmoid(2.0 - delta_distances)
+                     + torch.sigmoid(4.0 - delta_distances)
+                 )
+                 * pair_mask
+                 / (torch.sum(pair_mask, dim=(1)) + eps)
+             )
+
+             def filter_lddt(mask, scale=1.0):
+                 mask = mask.to(pair_mask.dtype)
+                 if mask.ndim > 1:
+                     mask = mask[0]
+                 mask = (mask[first_index] * mask[second_index])[None].expand(
+                     pair_mask.shape[0], -1
+                 )
+                 mask = (mask * pair_mask).to(bool)
+                 return (
+                     (1 - torch.sum(lddt_[:, mask[0]] * scale, dim=(1)))
+                     .mean()
+                     .detach()
+                     .cpu()
+                 )
+
+             extra_lddts = {}
+             extra_lddts["mean_lddt"] = filter_lddt(
+                 torch.full_like(crd_mask_L, 1.0, device=X_L.device)
+             )
+             extra_lddts["mean_lddt_dna"] = filter_lddt(is_dna[tok_idx])
+             extra_lddts["mean_lddt_rna"] = filter_lddt(is_rna[tok_idx])
+             extra_lddts["mean_lddt_protein"] = filter_lddt(
+                 ~is_dna[tok_idx] & ~is_rna[tok_idx]
+             )
+             # NOTE: This also seems to have issues at epoch level, as with n_valid_t
+             # before. Will leave as-is for now but may want to spoof as 0 in the future.
+             if is_virtual is not None:
+                 extra_lddts["mean_lddt_virtual"] = filter_lddt(
+                     is_virtual, scale=1 / alpha_virtual
+                 )
+                 extra_lddts["mean_lddt_non_virtual"] = filter_lddt(~is_virtual)
+
+         return 1 - lddt, extra_lddts
+
+     return _dolddt(X_L, X_gt_L, crd_mask_L, is_dna, is_rna, tok_idx, eps)
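A hedged usage sketch of smoothed_lddt_loss on toy tensors. The shapes follow the access patterns above; the boolean coordinate mask and the per-atom token map are assumptions about the expected dtypes, and foundry's activation_checkpointing decorator is assumed to pass through cleanly outside training:

    import torch

    D, L = 2, 8  # D diffusion replicates, L atoms (toy sizes)
    X_gt = torch.randn(1, L, 3).expand(D, -1, -1)  # identical ground truth per replicate
    X_pred = X_gt + 0.1 * torch.randn(D, L, 3)     # near-accurate predictions
    crd_mask = torch.ones(D, L, dtype=torch.bool)  # all atoms resolved
    tok_idx = torch.arange(L)                      # one token per atom
    is_dna = torch.zeros(L, dtype=torch.bool)
    is_rna = torch.zeros(L, dtype=torch.bool)

    loss = smoothed_lddt_loss(X_pred, X_gt, crd_mask, is_dna, is_rna, tok_idx)
    print(loss.shape)  # (D,), values in [0, 1]; smaller for closer predictions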
rfd3/metrics/metrics_utils.py ADDED
@@ -0,0 +1,118 @@
+ import itertools
+
+ import numpy as np
+ from atomworks.ml.preprocessing.utils.structure_utils import (
+     get_atom_mask_from_cell_list,
+ )
+ from atomworks.ml.utils.token import spread_token_wise
+ from biotite.structure import CellList, annotate_sse, gyration_radius
+ from rfd3.transforms.conditioning_base import get_motif_features
+
+
+ def get_ss_metrics_and_rg(
+     atom_array, ss_conditioning: dict[str, np.ndarray] | None = None
+ ):
+     """Compute secondary-structure metrics and the radius of gyration for an input structure.
+
+     Args:
+         atom_array (AtomArray): Input AtomArray.
+         ss_conditioning (dict[str, np.ndarray] | None): Dictionary mapping the keys "helix", "sheet",
+             and "loop" to the corresponding conditioning arrays. If None, secondary-structure
+             adherence is not computed.
+
+     NOTE: Biotite computes secondary structures using the P-SEA algorithm:
+         G. Labesse, N. Colloc'h, J. Pothier, and J. Mornon,
+         "P-SEA: a new efficient assignment of secondary structure from Cα trace of proteins,"
+         Bioinformatics, vol. 13, pp. 291-295, June 1997. doi: 10.1093/bioinformatics/13.3.291
+     """
+     # Compute secondary structure ("a" = helix, "b" = sheet, "c" = coil, "" = non-protein)
+     sse_array = annotate_sse(atom_array)
+     sse_array_prot_only = sse_array[sse_array != ""]
+
+     # Basic compositional statistics (fractions in [0, 1], despite the "percent" names)
+     pdb_helix_percent = np.mean(sse_array_prot_only == "a")
+     pdb_strand_percent = np.mean(sse_array_prot_only == "b")
+     pdb_coil_percent = np.mean(sse_array_prot_only == "c")
+     pdb_ss_percent = pdb_helix_percent + pdb_strand_percent
+
+     # Number of disjoint helices or sheets
+     num_structural_elements = 0
+     for k, _ in itertools.groupby(sse_array):
+         if k not in ["", "c"]:
+             num_structural_elements += 1
+
+     if ss_conditioning is not None:
+         ss_adherence_dict = {}
+         atom_level_sse_array = spread_token_wise(atom_array, input_data=sse_array)
+         for ss_annot, ss_type in zip(["a", "b", "c"], ["helix", "sheet", "loop"]):
+             metric_name = f"{ss_type}_conditioning_adherence"
+             expected_indices = np.where(ss_conditioning[ss_type])[0]
+
+             if len(expected_indices) > 0:
+                 ss_adherence = (
+                     atom_level_sse_array[expected_indices] == ss_annot
+                 ).mean()
+                 ss_adherence_dict[metric_name] = ss_adherence
+             else:
+                 # Would be misleading to give a numerical value if no conditioning of this type was provided
+                 ss_adherence_dict[metric_name] = np.nan
+
+     # Compute radius of gyration
+     radius_of_gyration = gyration_radius(atom_array)
+
+     # Return output metrics
+     output_metrics = {
+         "non_loop_fraction": pdb_ss_percent,
+         "loop_fraction": pdb_coil_percent,
+         "helix_fraction": pdb_helix_percent,
+         "sheet_fraction": pdb_strand_percent,
+         "num_ss_elements": num_structural_elements,
+         "radius_of_gyration": radius_of_gyration,
+     }
+
+     if ss_conditioning is not None:
+         output_metrics.update(ss_adherence_dict)
+
+     return output_metrics
+
+
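A sketch of calling get_ss_metrics_and_rg on a structure loaded with biotite; the file path is a placeholder, and ss_conditioning is left as None (when given, its boolean arrays are indexed at atom level per the adherence loop above):

    import biotite.structure.io.pdb as pdb

    structure = pdb.PDBFile.read("design_0.pdb").get_structure(model=1)  # placeholder path
    metrics = get_ss_metrics_and_rg(structure, ss_conditioning=None)
    print(metrics["helix_fraction"], metrics["num_ss_elements"], metrics["radius_of_gyration"])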
+ def _flatten_dict(d, parent="", sep="."):
+     """
+     Recursively flatten a nested dictionary.
+     E.g.:
+         {"a": {"b": 1, "c": 2}} --> {"a.b": 1, "a.c": 2}
+     """
+     flat = {}
+     for k, v in d.items():
+         name = f"{parent}{sep}{k}" if parent else k
+         if isinstance(v, dict):
+             flat.update(_flatten_dict(v, name, sep=sep))
+         else:
+             flat[name] = v
+     return flat
+
+
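_flatten_dict is presumably used to turn nested metric dictionaries into flat, logger-friendly keys, e.g.:

    nested = {"losses": {"mse": 0.12, "lddt": 0.05}, "step": 100}
    print(_flatten_dict(nested))
    # {'losses.mse': 0.12, 'losses.lddt': 0.05, 'step': 100}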
+ def get_hotspot_contacts(atom_array, hotspot_mask, distance_cutoff=4.5):
+     """Return the fraction of hotspot atoms that make an inter-chain contact with a diffused (non-motif) atom within a distance cutoff."""
+     cell_list = CellList(atom_array, cell_size=distance_cutoff)
+     hotspot_array = atom_array[hotspot_mask]
+
+     # Compute all contacts with hotspots
+     full_contacting_atom_mask = get_atom_mask_from_cell_list(
+         hotspot_array.coord, cell_list, len(atom_array), distance_cutoff
+     )  # (n_hotspots, n_atoms)
+
+     # We only count inter-chain contacts
+     interchain_mask = hotspot_array.pn_unit_iid[:, None] != atom_array.pn_unit_iid[None]
+     interchain_contacts_mask = full_contacting_atom_mask & interchain_mask
+
+     # We only count contacts to diffused (non-motif) atoms
+     diffused_interchain_contacts_mask = interchain_contacts_mask[
+         :, ~get_motif_features(atom_array)["is_motif_atom"]
+     ]
+
+     contacted_hotspots_mask = np.any(
+         diffused_interchain_contacts_mask, axis=1
+     )  # (n_hotspots,)
+
+     return float(contacted_hotspots_mask.mean())
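A hedged sketch of get_hotspot_contacts, reusing an atom array like the one loaded in the earlier example. The residue numbers are invented, and the array is assumed to carry the pn_unit_iid annotation and the motif annotations that get_motif_features reads:

    import numpy as np

    # Hypothetical hotspot selection: CA atoms of three target residues
    hotspot_mask = np.isin(atom_array.res_id, [45, 62, 88]) & (atom_array.atom_name == "CA")
    frac = get_hotspot_contacts(atom_array, hotspot_mask, distance_cutoff=4.5)
    print(f"{frac:.0%} of hotspot atoms contact a diffused atom on another chain")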