rc-foundry 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
@@ -0,0 +1,308 @@
1
+ import logging
2
+
3
+ import biotite.structure as struc
4
+ import numpy as np
5
+ from rfd3.constants import (
6
+ ATOM14_ATOM_NAMES,
7
+ association_schemes_stripped,
8
+ )
9
+ from rfd3.transforms.conditioning_base import get_motif_features
10
+ from rfd3.transforms.hbonds_hbplus import calculate_hbonds
11
+
12
+ from foundry.metrics.metric import Metric
13
+ from foundry.utils.ddp import RankedLogger
14
+
15
# Import-time side effect: configure the root logger at INFO level.
# NOTE(review): per the stdlib docs this does nothing if the root logger already
# has handlers, but configuring logging from a library module can still change
# an application's defaults; consider moving it to an entry point.
logging.basicConfig(level=logging.INFO)
# Module-level logger used by get_hbond_metrics and HbondMetrics.compute.
# rank_zero_only=False presumably means every distributed rank logs
# (confirm against foundry.utils.ddp.RankedLogger).
global_logger = RankedLogger(__name__, rank_zero_only=False)
17
+
18
+
19
def simplified_processing_atom_array(atom_arrays, central_atom="CB", threshold=0.5):
    """
    Allows for sequence extraction from cleaned up virtual atoms. Needed for hbond metrics.

    For every residue that contains a virtual atom (an atom name starting with
    ``"V"``), atoms lying closer than ``threshold`` to the residue's central
    atom are treated as collapsed virtual atoms and dropped. The remaining atom
    set is matched against the ``"atom14"`` association schemes to assign a
    residue name and real atom names. Residues matching no scheme are labelled
    ``"UNK"``. Residues without virtual atoms are kept unchanged.

    Args:
        atom_arrays: Iterable of atom arrays, one per sample.
        central_atom: Reference atom name for detecting collapsed virtual atoms.
            For a residue whose CA and CB lie within ``threshold`` of each other
            (glycine-like), ``"CA"`` is used for that residue only.
        threshold: Distance below which an atom counts as collapsed onto the
            central atom.

    Returns:
        List of processed atom arrays, one per input. Elements are re-inferred
        from the (possibly renamed) atom names.
    """
    atom14_schemes = association_schemes_stripped["atom14"]
    final_atom_array = []

    for atom_array in atom_arrays:
        cur_atom_array_list = []

        # Residue boundaries: a new residue starts wherever res_id changes.
        res_ids = atom_array.res_id
        res_start_indices = np.concatenate(
            [[0], np.where(res_ids[1:] != res_ids[:-1])[0] + 1]
        )
        res_end_indices = np.concatenate([res_start_indices[1:], [len(res_ids)]])

        for start, end in zip(res_start_indices, res_end_indices):
            cur_res_atom_array = atom_array[start:end]

            # Residues without virtual ("V*") atoms already have a known sequence.
            if not any(name.startswith("V") for name in cur_res_atom_array.atom_name):
                cur_atom_array_list.append(cur_res_atom_array)
                continue

            # Glycine fallback to CA. This is a per-residue variable so the
            # fallback does not leak into later residues or arrays.
            res_central_atom = central_atom
            ca_coord = cur_res_atom_array.coord[cur_res_atom_array.atom_name == "CA"]
            cb_coord = cur_res_atom_array.coord[cur_res_atom_array.atom_name == "CB"]
            if np.linalg.norm(ca_coord - cb_coord) < threshold:
                res_central_atom = "CA"

            central_mask = cur_res_atom_array.atom_name == res_central_atom
            central_coord = cur_res_atom_array.coord[central_mask][0]
            dists = np.linalg.norm(cur_res_atom_array.coord - central_coord, axis=-1)
            # Collapsed virtual atoms, excluding the central atom itself.
            is_virtual = (dists < threshold) & ~central_mask
            cur_res_atom_array = cur_res_atom_array[~is_virtual]

            # Slot of each remaining atom in the atom14 naming scheme. It does not
            # depend on the candidate restype, so it is computed once here.
            atom_name_idx = np.array(
                [
                    np.where(ATOM14_ATOM_NAMES == atom_name)[0][0]
                    for atom_name in cur_res_atom_array.atom_name
                ],
                dtype=int,
            )
            atom14_mask = np.zeros_like(ATOM14_ATOM_NAMES, dtype=bool)
            atom14_mask[atom_name_idx] = True

            for restype, atom_names in atom14_schemes.items():
                if restype in ("UNK", "MSK"):
                    continue
                atom_names = np.array(atom_names)
                # A restype matches when every occupied slot is defined for it
                # and every unoccupied slot is undefined (None).
                if all(x is not None for x in atom_names[atom14_mask]) and all(
                    x is None for x in atom_names[~atom14_mask]
                ):
                    cur_res_atom_array.res_name = np.array(
                        [restype] * len(cur_res_atom_array)
                    )
                    # Index by atom_name_idx rather than the boolean mask, so each
                    # atom gets the name of its own slot. The mask yields names in
                    # scheme order, which mislabels atoms not already in that order.
                    cur_res_atom_array.atom_name = np.asarray(
                        atom_names[atom_name_idx], dtype=str
                    )
                    break
            else:
                # No atom14 scheme matches this set of atoms.
                cur_res_atom_array.res_name = np.array(
                    ["UNK"] * len(cur_res_atom_array)
                )
            cur_atom_array_list.append(cur_res_atom_array)

        cur_atom_array = struc.concatenate(cur_atom_array_list)
        cur_atom_array.element = struc.infer_elements(cur_atom_array.atom_name)
        final_atom_array.append(cur_atom_array)

    return final_atom_array
107
+
108
+
109
def calculate_hbond_stats(
    input_atom_array_stack,
    output_atom_array_stack,
    cutoff_HA_dist=3,
    cutoff_DA_distance=3.5,
    inference_metrics=False,
):
    """
    Measure how well output structures recover the requested motif H-bond atoms.

    Output arrays are first passed through ``simplified_processing_atom_array``.
    H-bonds are then computed with ``calculate_hbonds`` from
    ``rfd3.transforms.hbonds_hbplus``. Donor/acceptor flags on the output are
    restricted to motif atoms.

    Each sample whose input requests at least one active donor or acceptor is
    scored by the fraction of requested donor (acceptor) atoms that carry the
    matching flag in the output. A kind with no requested atoms scores 1.0.

    Note: the input arrays are modified in place. Missing ``active_donor`` /
    ``active_acceptor`` annotations are added (all False), and any
    ``coord_to_be_noised`` annotation is deleted.

    Args:
        input_atom_array_stack: Reference atom arrays carrying the requested
            ``active_donor`` / ``active_acceptor`` annotations and ``gt_atom_name``.
        output_atom_array_stack: Predicted atom arrays, one per input.
        cutoff_HA_dist: Accepted for API compatibility; currently not forwarded
            to ``calculate_hbonds``.
        cutoff_DA_distance: Accepted for API compatibility; currently not
            forwarded to ``calculate_hbonds``.
        inference_metrics: If True, return a dict instead of a tuple.

    Returns:
        If ``inference_metrics`` is False, a tuple
        ``(mean_correct_donor_pct, mean_correct_acceptor_pct, mean_num_hbonds)``.
        This is ``(0, 0, 0)`` when no sample is evaluated.

        If ``inference_metrics`` is True, a dict with the same averages plus:
        ``"hbonds"`` and ``"output_atom_array"`` from the last evaluated sample,
        and the summed ``"total_number_donors_acceptors"``. When no sample is
        evaluated, the numeric fields are ``""`` and ``"output_atom_array"`` is None.

    Raises:
        AssertionError: If the two stacks have different lengths.
    """
    output_atom_array_stack = simplified_processing_atom_array(output_atom_array_stack)
    assert len(input_atom_array_stack) == len(output_atom_array_stack)

    total_correct_donors_percent = 0.0
    total_correct_acceptors_percent = 0.0
    total_number_donors_acceptors = 0
    total_number_hbonds = 0
    num_valid_samples = 0
    # Results of the most recently *evaluated* sample. The zip loop variable
    # would instead point at a skipped sample if trailing samples are skipped.
    last_hbonds = []
    last_output_atom_array = None

    for input_atom_array, output_atom_array in zip(
        input_atom_array_stack, output_atom_array_stack
    ):
        # Ensure required annotations exist
        for annotation in ("active_donor", "active_acceptor"):
            if annotation not in input_atom_array.get_annotation_categories():
                input_atom_array.set_annotation(
                    annotation, np.zeros(len(input_atom_array), dtype=bool)
                )

        given_donors = np.sum(input_atom_array.active_donor)
        given_acceptors = np.sum(input_atom_array.active_acceptor)

        # Skip samples with no donors or acceptors
        if given_donors == 0 and given_acceptors == 0:
            continue

        # Clean up coordinate annotations
        for atom_array in (input_atom_array, output_atom_array):
            if "coord_to_be_noised" in atom_array.get_annotation_categories():
                atom_array.del_annotation("coord_to_be_noised")

        # Calculate hydrogen bonds
        output_atom_array, hbonds, motif_diffused_hbond_count = calculate_hbonds(
            output_atom_array,
        )

        # Keep donor/acceptor flags on motif atoms only
        hbond_types = np.vstack(
            (output_atom_array.active_donor, output_atom_array.active_acceptor)
        ).T
        motif_mask = np.array(get_motif_features(output_atom_array)["is_motif_atom"])
        hbond_types[:, 0] *= motif_mask
        hbond_types[:, 1] *= motif_mask
        output_atom_array.set_annotation("active_donor", hbond_types[:, 0])
        output_atom_array.set_annotation("active_acceptor", hbond_types[:, 1])

        # Count correct predictions
        correct_donors = _count_correct_hbond_atoms(
            input_atom_array, output_atom_array, "active_donor"
        )
        correct_acceptors = _count_correct_hbond_atoms(
            input_atom_array, output_atom_array, "active_acceptor"
        )

        # A kind with nothing requested counts as fully recovered.
        correct_donor_pct = correct_donors / given_donors if given_donors > 0 else 1.0
        correct_acceptor_pct = (
            correct_acceptors / given_acceptors if given_acceptors > 0 else 1.0
        )

        # Accumulate totals
        total_correct_donors_percent += correct_donor_pct
        total_correct_acceptors_percent += correct_acceptor_pct
        total_number_donors_acceptors += np.sum(hbond_types)
        total_number_hbonds += motif_diffused_hbond_count
        num_valid_samples += 1
        last_hbonds = hbonds
        last_output_atom_array = output_atom_array

    if num_valid_samples == 0:
        if inference_metrics:
            return {
                "correct_donor_percent": "",
                "correct_acceptor_percent": "",
                "num_hbonds": "",
                "hbonds": [],
                "total_number_donors_acceptors": "",
                "output_atom_array": None,
            }
        return 0, 0, 0

    avg_donor_pct = total_correct_donors_percent / num_valid_samples
    avg_acceptor_pct = total_correct_acceptors_percent / num_valid_samples
    avg_hbonds = total_number_hbonds / num_valid_samples

    if inference_metrics:
        return {
            "correct_donor_percent": avg_donor_pct,
            "correct_acceptor_percent": avg_acceptor_pct,
            "num_hbonds": avg_hbonds,
            "hbonds": last_hbonds,
            "total_number_donors_acceptors": total_number_donors_acceptors,
            "output_atom_array": last_output_atom_array,
        }

    return avg_donor_pct, avg_acceptor_pct, avg_hbonds
218
+
219
+
220
def _count_correct_hbond_atoms(input_atom_array, output_atom_array, annotation_type):
    """
    Count requested H-bond atoms that are flagged the same way in the output.

    An input atom ``i`` is requested when ``getattr(input, annotation_type)[i] == 1``.
    It counts as correct when at least one output atom has the same ``chain_iid``
    and ``res_id``, has ``atom_name == input.gt_atom_name[i]``, and is flagged
    with ``annotation_type``. ``np.any`` is used because ``bool()`` on a multi-atom
    selection raises.

    Args:
        input_atom_array: Reference atom array with ``annotation_type``,
            ``chain_iid``, ``res_id`` and ``gt_atom_name`` annotations.
        output_atom_array: Predicted atom array with ``annotation_type``,
            ``chain_iid``, ``res_id`` and ``atom_name`` annotations.
        annotation_type: ``"active_donor"`` or ``"active_acceptor"``.

    Returns:
        Number of requested atoms recovered in the output.
    """
    target_indices = np.flatnonzero(getattr(input_atom_array, annotation_type) == 1)
    # Read the output flags once instead of re-slicing the whole array per target.
    output_flags = np.asarray(getattr(output_atom_array, annotation_type), dtype=bool)

    correct_count = 0
    for idx in target_indices:
        match = (
            (output_atom_array.chain_iid == input_atom_array.chain_iid[idx])
            & (output_atom_array.res_id == input_atom_array.res_id[idx])
            & (output_atom_array.atom_name == input_atom_array.gt_atom_name[idx])
        )
        if np.any(output_flags[match]):
            correct_count += 1

    return correct_count
236
+
237
+
238
def get_hbond_metrics(atom_array=None):
    """
    Summarise motif H-bonds for a single (inference) atom array.

    The array serves as both reference and prediction in
    ``calculate_hbond_stats(..., inference_metrics=True)``. Copies are passed, so
    the caller's array is not modified.

    Args:
        atom_array: Atom array to analyse. If None, a warning is logged and
            None is returned.

    Returns:
        A dict with:

        - ``"donor_atom_names"``, ``"acceptor_atom_names"``: sorted lists of
          ``"<atom>_<resn>_<resi>"`` strings.
        - ``"hbond_connections"``: sorted list of ``"<donor>-<acceptor>"``
          strings built from those labels.
        - ``"correct_donor_percent"``, ``"correct_acceptor_percent"``,
          ``"num_hbonds"``: floats.

        An empty dict is returned, with a warning, when no atom is annotated as
        an active donor/acceptor or when the calculation raises.
    """
    if atom_array is None:
        global_logger.warning("atom_array is None")
        return None

    def _label(hb, side):
        # side is "d" (donor) or "a" (acceptor).
        return f"{hb[side + '_atom']}_{hb[side + '_resn']}_{hb[side + '_resi']}"

    try:
        output = calculate_hbond_stats(
            [atom_array.copy()], [atom_array.copy()], inference_metrics=True
        )
        # calculate_hbond_stats reports "no sample evaluated" with a None array
        # and empty-string metrics, which float() cannot parse.
        if output["output_atom_array"] is None:
            global_logger.warning(
                "Could not calculate hbond metrics: no active donor/acceptor atoms"
            )
            return {}

        hbonds = output["hbonds"]
        # Sort the de-duplicated labels so the output order is deterministic
        # (set iteration order of strings varies between runs).
        return {
            "donor_atom_names": sorted({_label(hb, "d") for hb in hbonds}),
            "acceptor_atom_names": sorted({_label(hb, "a") for hb in hbonds}),
            "hbond_connections": sorted(
                {f"{_label(hb, 'd')}-{_label(hb, 'a')}" for hb in hbonds}
            ),
            "correct_donor_percent": float(output["correct_donor_percent"]),
            "correct_acceptor_percent": float(output["correct_acceptor_percent"]),
            "num_hbonds": float(output["num_hbonds"]),
        }

    except Exception as e:
        # Best effort: metric failures must not abort inference.
        global_logger.warning(f"Could not calculate hbond metrics: {e}")
        return {}
271
+
272
+
273
class HbondMetrics(Metric):
    """
    Batch-level hydrogen-bond recovery metrics for motif donors/acceptors.

    Wraps ``calculate_hbond_stats`` and reports its per-sample averages.
    """

    def __init__(
        self,
        cutoff_HA_dist: float = 3,
        cutoff_DA_distance: float = 3.5,
    ):
        super().__init__()
        # Both cutoffs are handed to calculate_hbond_stats on every compute().
        self.cutoff_HA_dist = cutoff_HA_dist
        self.cutoff_DA_distance = cutoff_DA_distance

    @property
    def kwargs_to_compute_args(self):
        # Each compute() keyword is read from the identically named input.
        arg_names = ("ground_truth_atom_array_stack", "predicted_atom_array_stack")
        return {name: (name,) for name in arg_names}

    def compute(self, *, ground_truth_atom_array_stack, predicted_atom_array_stack):
        """Return mean donor/acceptor recovery and mean H-bond count, or ``{}`` on error."""
        try:
            stats_kwargs = dict(
                input_atom_array_stack=ground_truth_atom_array_stack,
                output_atom_array_stack=predicted_atom_array_stack,
                cutoff_HA_dist=self.cutoff_HA_dist,
                cutoff_DA_distance=self.cutoff_DA_distance,
            )
            donor_pct, acceptor_pct, hbond_count = calculate_hbond_stats(
                **stats_kwargs
            )
        except Exception as err:
            global_logger.error(
                f"Error calculating hydrogen bond metrics: {err} | Skipping"
            )
            return {}

        metrics = {}
        metrics["mean_correct_donors_percent"] = float(donor_pct)
        metrics["mean_correct_acceptors_percent"] = float(acceptor_pct)
        metrics["mean_num_hbonds"] = float(hbond_count)
        return metrics
@@ -0,0 +1,389 @@
1
import logging
from typing import Literal

import biotite.structure as struc
import numpy as np
from atomworks.enums import ChainType
from atomworks.io.transforms.atom_array import remove_hydrogens
from rfd3.constants import (
    ATOM14_ATOM_NAMES,
    SELECTION_NONPROTEIN,
    SELECTION_PROTEIN,
    association_schemes_stripped,
)
from rfd3.transforms.conditioning_base import get_motif_features
from rfd3.transforms.hbonds import (
    add_hydrogen_atom_positions,
    calculate_hbonds,
)

from foundry.utils.ddp import RankedLogger

try:
    from foundry.metrics.base import Metric
except ImportError:
    # foundry/metrics/base.py is not in this package's file list; the Metric
    # base class ships in foundry.metrics.metric.
    from foundry.metrics.metric import Metric
21
+
22
# Import-time side effect: configure the root logger at INFO level.
# NOTE(review): per the stdlib docs this does nothing if the root logger already
# has handlers, but configuring logging from a library module can still change
# an application's defaults; consider moving it to an entry point.
logging.basicConfig(level=logging.INFO)
# Module-level logger used by HbondMetrics.compute.
# rank_zero_only=False presumably means every distributed rank logs
# (confirm against foundry.utils.ddp.RankedLogger).
global_logger = RankedLogger(__name__, rank_zero_only=False)
24
+
25
+
26
def simplified_processing_atom_array(atom_arrays, central_atom="CB", threshold=0.5):
    """
    Allows for sequence extraction from cleaned up virtual atoms. Needed for hbond metrics.

    For every residue that contains a virtual atom (an atom name starting with
    ``"V"``), atoms lying closer than ``threshold`` to the residue's central
    atom are treated as collapsed virtual atoms and dropped. The remaining atom
    set is matched against the ``"atom14"`` association schemes to assign a
    residue name and real atom names. Residues matching no scheme are labelled
    ``"UNK"``. Residues without virtual atoms are kept unchanged.

    Args:
        atom_arrays: Iterable of atom arrays, one per sample.
        central_atom: Reference atom name for detecting collapsed virtual atoms.
            For a residue whose CA and CB lie within ``threshold`` of each other
            (glycine-like), ``"CA"`` is used for that residue only.
        threshold: Distance below which an atom counts as collapsed onto the
            central atom.

    Returns:
        List of processed atom arrays, one per input. Elements are re-inferred
        from the (possibly renamed) atom names.
    """
    atom14_schemes = association_schemes_stripped["atom14"]
    final_atom_array = []

    for atom_array in atom_arrays:
        cur_atom_array_list = []

        # Residue boundaries: a new residue starts wherever res_id changes.
        res_ids = atom_array.res_id
        res_start_indices = np.concatenate(
            [[0], np.where(res_ids[1:] != res_ids[:-1])[0] + 1]
        )
        res_end_indices = np.concatenate([res_start_indices[1:], [len(res_ids)]])

        for start, end in zip(res_start_indices, res_end_indices):
            cur_res_atom_array = atom_array[start:end]

            # Residues without virtual ("V*") atoms already have a known sequence.
            if not any(name.startswith("V") for name in cur_res_atom_array.atom_name):
                cur_atom_array_list.append(cur_res_atom_array)
                continue

            # Glycine has no CB. When virtual atoms are padded around CB, its CB
            # coordinates are set to CA, so a near-zero CA-CB distance means "use
            # CA". This is a per-residue variable so the fallback does not leak
            # into later residues or arrays.
            res_central_atom = central_atom
            ca_coord = cur_res_atom_array.coord[cur_res_atom_array.atom_name == "CA"]
            cb_coord = cur_res_atom_array.coord[cur_res_atom_array.atom_name == "CB"]
            if np.linalg.norm(ca_coord - cb_coord) < threshold:
                res_central_atom = "CA"

            central_mask = cur_res_atom_array.atom_name == res_central_atom
            # There should be exactly one central atom.
            central_coord = cur_res_atom_array.coord[central_mask][0]
            dists = np.linalg.norm(cur_res_atom_array.coord - central_coord, axis=-1)
            # Virtual atoms are selected by distance, excluding the central atom itself.
            is_virtual = (dists < threshold) & ~central_mask
            cur_res_atom_array = cur_res_atom_array[~is_virtual]

            # Slot of each remaining atom in the atom14 naming scheme, e.g.
            # [N, CA, C, O, CB, V6, V2] -> [0, 1, 2, 3, 4, 11, 7]. It does not
            # depend on the candidate restype, so it is computed once here.
            atom_name_idx = np.array(
                [
                    np.where(ATOM14_ATOM_NAMES == atom_name)[0][0]
                    for atom_name in cur_res_atom_array.atom_name
                ],
                dtype=int,
            )
            atom14_mask = np.zeros_like(ATOM14_ATOM_NAMES, dtype=bool)
            atom14_mask[atom_name_idx] = True

            for restype, atom_names in atom14_schemes.items():
                if restype in ("UNK", "MSK"):
                    continue
                atom_names = np.array(atom_names)
                # A restype matches when every occupied slot is defined for it
                # and every unoccupied slot is undefined (None).
                if all(x is not None for x in atom_names[atom14_mask]) and all(
                    x is None for x in atom_names[~atom14_mask]
                ):
                    cur_res_atom_array.res_name = np.array(
                        [restype] * len(cur_res_atom_array)
                    )
                    # Index by atom_name_idx rather than the boolean mask, so each
                    # atom gets the name of its own slot. The mask yields names in
                    # scheme order, which would swap e.g. the V6/V2 atoms above.
                    cur_res_atom_array.atom_name = np.asarray(
                        atom_names[atom_name_idx], dtype=str
                    )
                    break
            else:
                # No atom14 scheme matches this set of atoms.
                cur_res_atom_array.res_name = np.array(
                    ["UNK"] * len(cur_res_atom_array)
                )
            cur_atom_array_list.append(cur_res_atom_array)

        cur_atom_array = struc.concatenate(cur_atom_array_list)
        cur_atom_array.element = struc.infer_elements(cur_atom_array.atom_name)
        final_atom_array.append(cur_atom_array)

    return final_atom_array
123
+
124
+
125
+ # Training comparison
126
def calculate_hbond_stats(
    input_atom_array_stack,
    output_atom_array_stack,
    selection1,
    selection2,
    selection1_type,
    cutoff_dist,
    cutoff_angle,
    donor_elements,
    acceptor_elements,
    periodic,
):
    """
    Compare the number of hbonds correctly recapitulated in the output atom array.

    Output arrays are first passed through ``simplified_processing_atom_array``.
    A sample is evaluated when its input requests at least one ``active_donor``
    or ``active_acceptor`` atom. For each evaluated sample:

    - distance-based bonds and hydrogens are added to the output;
    - H-bonds are computed with ``calculate_hbonds`` between ``selection1`` and
      ``selection2`` (the latter extended with every motif atom);
    - the fraction of requested donor (acceptor) atoms flagged as donors
      (acceptors) in the output is recorded, or 1.0 when none of that kind is
      requested.

    Note: ``coord_to_be_noised`` is deleted from evaluated input arrays in place.

    Args:
        input_atom_array_stack: Reference atom arrays carrying the requested
            ``active_donor`` / ``active_acceptor`` annotations and ``gt_atom_name``.
        output_atom_array_stack: Predicted atom arrays, one per input.
        selection1: Chain-type values compared against the output's ``chain_type``
            annotation to build the first selection. If None, no sample is evaluated.
        selection2: Chain-type values for the second selection (motif atoms are
            always added). If None, no sample is evaluated.
        selection1_type: Forwarded unchanged to ``calculate_hbonds``.
        cutoff_dist: Forwarded unchanged to ``calculate_hbonds``.
        cutoff_angle: Forwarded unchanged to ``calculate_hbonds``.
        donor_elements: Forwarded unchanged to ``calculate_hbonds``.
        acceptor_elements: Forwarded unchanged to ``calculate_hbonds``.
        periodic: Forwarded unchanged to ``calculate_hbonds``.

    Returns:
        Tuple ``(mean_correct_donor_pct, mean_correct_acceptor_pct,
        mean_num_hbonds)`` over evaluated samples, or ``(0, 0, 0)`` if none.

    Raises:
        AssertionError: If the two stacks have different lengths.
    """
    output_atom_array_stack = simplified_processing_atom_array(output_atom_array_stack)

    assert len(input_atom_array_stack) == len(
        output_atom_array_stack
    ), "Input and output atom arrays must have the same length"

    # Without both selections no sample can be evaluated (loop-invariant check).
    if selection1 is None or selection2 is None:
        return 0, 0, 0

    total_correct_donors_percent = 0.0
    total_correct_acceptors_percent = 0.0
    total_number_hbonds = 0
    num_valid_samples = 0

    for input_atom_array, output_atom_array in zip(
        input_atom_array_stack, output_atom_array_stack
    ):
        categories = input_atom_array.get_annotation_categories()
        if "active_donor" not in categories and "active_acceptor" not in categories:
            continue

        # A missing annotation means "nothing requested" for that kind.
        given_donors = _annotation_as_bool(input_atom_array, "active_donor")
        given_acceptors = _annotation_as_bool(input_atom_array, "active_acceptor")
        # Skip only when neither donors nor acceptors are requested.
        if not given_donors.any() and not given_acceptors.any():
            continue

        # Hack: temporarily use biotite to infer bonds (should be replaced with cifutils?).
        output_atom_array.bonds = struc.connect_via_distances(
            output_atom_array, default_bond_type=1
        )

        # Hack: delete coord_to_be_noised (if present) to work around an issue
        # when creating hydrogens; it is not used here.
        for atom_array in (input_atom_array, output_atom_array):
            if "coord_to_be_noised" in atom_array.get_annotation_categories():
                atom_array.del_annotation("coord_to_be_noised")

        output_atom_array = add_hydrogen_atom_positions(output_atom_array)

        cur_selection1 = np.isin(output_atom_array.chain_type, selection1)
        cur_selection2 = (
            np.isin(output_atom_array.chain_type, selection2)
            | get_motif_features(output_atom_array)["is_motif_atom"]
        )

        hbonds, hbond_types, output_atom_array = calculate_hbonds(
            output_atom_array,
            cur_selection1,
            cur_selection2,
            selection1_type=selection1_type,
            cutoff_dist=cutoff_dist,
            cutoff_angle=cutoff_angle,
            donor_elements=donor_elements,
            acceptor_elements=acceptor_elements,
            periodic=periodic,
        )

        output_atom_array.set_annotation("active_donor", hbond_types[:, 0])
        output_atom_array.set_annotation("active_acceptor", hbond_types[:, 1])
        output_atom_array = remove_hydrogens(output_atom_array)

        # A requested atom is recovered when the output atom with the same
        # chain, residue and (ground-truth) atom name carries the same flag.
        correct_donors = _count_recovered_hbond_atoms(
            input_atom_array, output_atom_array, given_donors, "active_donor"
        )
        correct_acceptors = _count_recovered_hbond_atoms(
            input_atom_array, output_atom_array, given_acceptors, "active_acceptor"
        )

        n_donors = int(given_donors.sum())
        n_acceptors = int(given_acceptors.sum())
        correct_hbond_donors_percent = (
            correct_donors / n_donors if n_donors > 0 else 1.0
        )
        correct_hbond_acceptors_percent = (
            correct_acceptors / n_acceptors if n_acceptors > 0 else 1.0
        )

        total_correct_donors_percent += correct_hbond_donors_percent
        total_correct_acceptors_percent += correct_hbond_acceptors_percent
        total_number_hbonds += len(hbonds)
        num_valid_samples += 1

    if num_valid_samples == 0:
        return 0, 0, 0
    return (
        total_correct_donors_percent / num_valid_samples,
        total_correct_acceptors_percent / num_valid_samples,
        total_number_hbonds / num_valid_samples,
    )


def _annotation_as_bool(atom_array, name):
    """Return annotation *name* as a bool array, or all-False if it is absent."""
    if name in atom_array.get_annotation_categories():
        return np.asarray(getattr(atom_array, name), dtype=bool)
    return np.zeros(len(atom_array), dtype=bool)


def _count_recovered_hbond_atoms(input_atom_array, output_atom_array, given_mask, annotation):
    """
    Count atoms in *given_mask* whose matching output atom carries *annotation*.

    An output atom matches input atom ``i`` when it has the same ``chain_id`` and
    ``res_id`` and ``atom_name == input.gt_atom_name[i]``. ``np.any`` handles zero
    or several matches, where ``bool()`` on the selection is ambiguous.
    """
    output_flags = np.asarray(getattr(output_atom_array, annotation), dtype=bool)
    recovered = 0
    for idx in np.flatnonzero(given_mask):
        match = (
            (output_atom_array.chain_id == input_atom_array.chain_id[idx])
            & (output_atom_array.res_id == input_atom_array.res_id[idx])
            & (output_atom_array.atom_name == input_atom_array.gt_atom_name[idx])
        )
        if np.any(output_flags[match]):
            recovered += 1
    return recovered
273
+
274
+
275
+ # Inference comparison -> temporary fix to test out sm_hbonds; should be merged with the hbond transforms down the line
276
def get_hbond_metrics(atom_array=None):
    """
    Count hydrogen bonds in a single (inference) atom array.

    Works on a copy of *atom_array*:

    - bonds are inferred from distances;
    - hydrogens are added on a best-effort basis;
    - H-bonds are computed between a second selection (non-protein chains plus
      all fixed motif atoms) and a first selection (every other atom).

    Args:
        atom_array: Atom array to analyse. It must carry ``chain_type`` and
            ``is_motif_atom`` annotations. If None, a warning is logged and
            None is returned.

    Returns:
        Dict with integer ``"num_hbonds"``, ``"num_donors"`` and ``"num_acceptors"``.
    """
    if atom_array is None:
        global_logger.warning("atom_array is None")
        return None

    curr_copy = atom_array.copy()

    # Hack: temporarily use biotite to infer bonds (should be replaced with cifutils?).
    curr_copy.bonds = struc.connect_via_distances(curr_copy, default_bond_type=1)
    # Hack: delete coord_to_be_noised (if present) to work around an issue when
    # creating hydrogens; it is not used here.
    if "coord_to_be_noised" in curr_copy.get_annotation_categories():
        curr_copy.del_annotation("coord_to_be_noised")

    try:
        curr_copy = add_hydrogen_atom_positions(curr_copy)
    except Exception as e:
        # Best effort: continue with the hydrogen-free structure.
        global_logger.warning(f"Problem adding hydrogen: {e}")

    nonprotein_chain_types = np.array(
        [ChainType.as_enum(item).value for item in SELECTION_NONPROTEIN]
    )
    # Always include fixed motif atoms in the second selection. The first
    # selection is its complement, so a separate protein chain-type selection
    # is not needed.
    selection2 = np.isin(curr_copy.chain_type, nonprotein_chain_types) | np.array(
        curr_copy.is_motif_atom, dtype=bool
    )
    selection1 = ~selection2

    hbonds, hbond_types, curr_copy = calculate_hbonds(
        curr_copy,
        selection1=selection1,
        selection2=selection2,
    )

    return {
        "num_hbonds": int(len(hbonds)),
        "num_donors": int(np.sum(hbond_types[:, 0])),
        "num_acceptors": int(np.sum(hbond_types[:, 1])),
    }
322
+
323
+
324
class HbondMetrics(Metric):
    """
    Batch-level hydrogen-bond recovery metrics.

    Wraps ``calculate_hbond_stats`` and reports its per-sample averages under
    ``mean_correct_donors_percent``, ``mean_correct_acceptors_percent`` and
    ``mean_num_hbonds``.
    """

    def __init__(
        self,
        selection1: list[str] = SELECTION_PROTEIN,
        selection2: list[str] = SELECTION_NONPROTEIN,
        selection1_type: Literal["acceptor", "donor", "both"] = "both",
        cutoff_dist: float = 3.0,
        cutoff_angle: float = 120.0,
        donor_elements: list[str] = ["N", "O", "S", "F"],
        acceptor_elements: list[str] = ["N", "O", "S", "F"],
        periodic: bool = False,
    ):
        """
        Args:
            selection1: Chain-type names (converted with ``ChainType.as_enum``)
                whose atoms form the first H-bond selection.
            selection2: Chain-type names whose atoms, plus motif atoms, form the
                second H-bond selection.
            selection1_type: Forwarded to ``calculate_hbonds``.
            cutoff_dist: Forwarded to ``calculate_hbonds``.
            cutoff_angle: Forwarded to ``calculate_hbonds``.
            donor_elements: Forwarded to ``calculate_hbonds``; copied on storage.
            acceptor_elements: Forwarded to ``calculate_hbonds``; copied on storage.
            periodic: Forwarded to ``calculate_hbonds``.
        """
        super().__init__()

        # Store chain-type *values*; calculate_hbond_stats compares them with the
        # ``chain_type`` annotation via np.isin.
        self.selection1 = np.array(
            [ChainType.as_enum(item).value for item in selection1]
        )
        self.selection2 = np.array(
            [ChainType.as_enum(item).value for item in selection2]
        )

        self.selection1_type = selection1_type
        self.cutoff_dist = cutoff_dist
        self.cutoff_angle = cutoff_angle
        # Copy the element lists. The default lists are single objects created
        # once at definition time, so storing them directly would alias them
        # across every instance built with defaults.
        self.donor_elements = (
            None if donor_elements is None else list(donor_elements)
        )
        self.acceptor_elements = (
            None if acceptor_elements is None else list(acceptor_elements)
        )
        self.periodic = periodic

    @property
    def kwargs_to_compute_args(self):
        # compute() keywords map one-to-one onto identically named inputs.
        return {
            "ground_truth_atom_array_stack": ("ground_truth_atom_array_stack",),
            "predicted_atom_array_stack": ("predicted_atom_array_stack",),
        }

    def compute(self, *, ground_truth_atom_array_stack, predicted_atom_array_stack):
        """Return mean donor/acceptor recovery and mean H-bond count, or ``{}`` on error."""
        try:
            (
                mean_correct_donors_percent,
                mean_correct_acceptors_percent,
                mean_num_hbonds,
            ) = calculate_hbond_stats(
                input_atom_array_stack=ground_truth_atom_array_stack,
                output_atom_array_stack=predicted_atom_array_stack,
                selection1=self.selection1,
                selection2=self.selection2,
                selection1_type=self.selection1_type,
                cutoff_dist=self.cutoff_dist,
                cutoff_angle=self.cutoff_angle,
                donor_elements=self.donor_elements,
                acceptor_elements=self.acceptor_elements,
                periodic=self.periodic,
            )
        except Exception as e:
            # Best effort: a metric failure should not abort training/validation.
            global_logger.error(
                f"Error calculating hydrogen bond metrics: {e} | Skipping"
            )
            return {}

        # Aggregate output for batch-level metrics
        return {
            "mean_correct_donors_percent": float(mean_correct_donors_percent),
            "mean_correct_acceptors_percent": float(mean_correct_acceptors_percent),
            "mean_num_hbonds": float(mean_num_hbonds),
        }