rc-foundry 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180) hide show
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
@@ -0,0 +1,407 @@
1
+ from typing import Any, Literal, Tuple
2
+
3
+ import biotite.structure as struc
4
+ import hydride
5
+ import numpy as np
6
+ from atomworks.io.transforms.atom_array import remove_hydrogens
7
+ from atomworks.io.utils.ccd import atom_array_from_ccd_code
8
+ from atomworks.ml.transforms._checks import (
9
+ check_atom_array_annotation,
10
+ check_contains_keys,
11
+ check_is_instance,
12
+ )
13
+ from atomworks.ml.transforms.base import Transform
14
+ from biotite.structure import AtomArray, AtomArrayStack
15
+ from rfd3.constants import SELECTION_NONPROTEIN, SELECTION_PROTEIN
16
+
17
+ from foundry.utils.ddp import RankedLogger
18
+
19
+ ranked_logger = RankedLogger()
20
+
21
+ HYDROGEN_LIKE_SYMBOLS = ("H", "H2", "D", "T")
22
+
23
+
24
+ # TODO: Once the cifutils submodule is bumped, we can use the built-in add_hydrogen_atom_positions function
25
def add_hydrogen_atom_positions(
    atom_array: AtomArray | AtomArrayStack,
) -> AtomArray | AtomArrayStack:
    """Add hydrogens using the biotite-supported hydride library.

    Existing hydrogens are stripped with ``remove_hydrogens`` and new ones are
    placed by ``hydride.add_hydrogen``. Atoms with NaN coordinates are
    temporarily given small random coordinates so that hydride can run, and
    are set back to NaN in the returned structure.

    Args:
        atom_array (AtomArray | AtomArrayStack): The structure to protonate.
            If it has no ``charge`` annotation, one is added to this input
            (looked up per atom from the CCD, falling back to 0).

    Returns:
        AtomArray | AtomArrayStack: The structure with hydrogens added.

    Raises:
        TypeError: If the hydrogen-free structure is neither an ``AtomArray``
            nor an ``AtomArrayStack``.
    """

    def _get_charge_from_ccd_code(atom):
        try:
            ccd_array = atom_array_from_ccd_code(atom.res_name)
            charge = ccd_array[
                ccd_array.atom_name.tolist().index(atom.atom_name)
            ].charge
        except Exception:
            # res_name not found in the CCD, or atom_name not found in ccd_array
            charge = 0
        return charge

    if "charge" not in atom_array.get_annotation_categories():
        charges = np.vectorize(_get_charge_from_ccd_code)(atom_array)
        atom_array.set_annotation("charge", charges)

    array = remove_hydrogens(atom_array)

    # Only copy the residue-level annotations that the input actually carries.
    present_annotations = set(atom_array.get_annotation_categories())
    fields_to_copy_from_residue_if_present = [
        field
        for field in (
            "auth_seq_id",
            "label_entity_id",
            "is_can_prot",
            "is_can_nucl",
            "is_sm",
            "chain_type",
        )
        if field in present_annotations
    ]

    def _copy_missing_annotations_residue_wise(
        arr_to_copy_from: AtomArray,
        arr_to_update: AtomArray,
        fields_to_copy_from_residue_if_present: list[str],
    ) -> AtomArray:
        """Copy specified annotations residue-wise from one AtomArray to another. Updates annotations in-place."""
        residue_starts = struc.get_residue_starts(arr_to_copy_from)
        first_atom_of_each_residue = arr_to_copy_from[residue_starts]
        for field in fields_to_copy_from_residue_if_present:
            per_residue_values = getattr(first_atom_of_each_residue, field)
            arr_to_update.set_annotation(
                field, struc.spread_residue_wise(arr_to_update, per_residue_values)
            )
        return arr_to_update

    def _handle_nan_coords(atom_array, noise_level=1e-3):
        """Replace NaN coordinates in-place with small random offsets around 0."""
        coords = atom_array.coord
        nan_mask = np.isnan(coords)
        coords[nan_mask] = np.random.uniform(
            -noise_level, noise_level, size=nan_mask.sum()
        )
        atom_array.coord = coords
        return atom_array, nan_mask

    def _add_hydrogens_to_single_array(heavy_atoms: AtomArray) -> AtomArray:
        """Protonate one AtomArray and restore NaNs on its unresolved atoms."""
        if heavy_atoms.bonds is None:
            heavy_atoms.bonds = struc.connect_via_distances(heavy_atoms)

        # hydride cannot work with NaN coordinates, so fill them temporarily
        heavy_atoms, nan_mask = _handle_nan_coords(heavy_atoms)
        with_hydrogens, original_atom_mask = hydride.add_hydrogen(heavy_atoms)

        # Put the NaNs back. `coord[mask, :]` would return a copy (boolean
        # indexing), so the original atoms are addressed by integer index to
        # make sure the assignment reaches the real coordinate array.
        if nan_mask.any():
            original_atom_indices = np.flatnonzero(original_atom_mask)
            nan_rows, nan_cols = np.nonzero(nan_mask)
            with_hydrogens.coord[original_atom_indices[nan_rows], nan_cols] = np.nan
            # NOTE(review): hydrogens placed on unresolved atoms keep their
            # placeholder coordinates — confirm whether they should be NaN too.

        return _copy_missing_annotations_residue_wise(
            heavy_atoms, with_hydrogens, fields_to_copy_from_residue_if_present
        )

    if isinstance(array, AtomArrayStack):
        return struc.stack([_add_hydrogens_to_single_array(model) for model in array])
    if isinstance(array, AtomArray):
        return _add_hydrogens_to_single_array(array)
    raise TypeError(
        f"Expected an AtomArray or AtomArrayStack, got {type(array).__name__}."
    )
132
+
133
+
134
def check_atom_array_has_hydrogen(data: dict[str, Any]):
    """Check that the structure under the `atom_array` key contains hydrogens.

    Args:
        data: Dictionary whose ``"atom_array"`` entry exposes an ``element``
            array.

    Raises:
        ValueError: If no atom has the element symbol ``"H"``.
    """
    if not np.any(data["atom_array"].element == "H"):
        raise ValueError("Key `atom_array` in data has no hydrogens.")
140
+
141
+
142
def calculate_hbonds(
    atom_array: AtomArray,
    selection1: np.ndarray | None = None,
    selection2: np.ndarray | None = None,
    selection1_type: Literal["acceptor", "donor", "both"] = "both",
    cutoff_dist: float = 3,
    cutoff_angle: float = 120,
    donor_elements: Tuple[str, ...] = ("O", "N", "S", "F"),
    acceptor_elements: Tuple[str, ...] = ("O", "N", "S", "F"),
    periodic: bool = False,
) -> Tuple[np.ndarray, np.ndarray, AtomArray]:
    """
    Calculates Hbonds with biotite.structure.hbond.
    Flags each heavy atom that takes part in an Hbond as donor and/or acceptor.

    Atoms with NaN coordinates are excluded from the search; all returned
    indices refer to the original (unfiltered) `atom_array`.

    Args:
        atom_array (AtomArray): Expects the atom_array that contains hydrogens.

        selection1 and selection2 (np.ndarray, optional): Boolean masks over `atom_array` that limit the hydrogen bond
            search to specific sections of the model. If None is given, all atoms are used instead. (Default: None)

        selection1_type (Literal, optional): Determines the type of selection1. The type of selection2 is chosen accordingly ('both' or the opposite).
            (Default: 'both')
        cutoff_dist (float, optional): The maximal distance between the hydrogen and acceptor to be considered a hydrogen bond. (Default: 3)
        cutoff_angle (float, optional): The angle cutoff in degree between Donor-H..Acceptor to be considered a hydrogen bond. (Default: 120)
        donor_elements, acceptor_elements (tuple of str): Elements to be considered as possible donors or acceptors. (Default: O, N, S, F)
        periodic (bool, optional): If true, hydrogen bonds can also be detected in periodic boundary conditions. The box attribute of atoms is required in this case. (Default: False)

    Returns:
        triplets (np.ndarray): (n_hbonds, 3) integer indices into `atom_array`; column 0 is used as the donor
            and column 2 as the acceptor.
        types (np.ndarray): (len(atom_array), 2) float array of [is_active_donor, is_active_acceptor] per atom.
        atom_array (AtomArray): The input array, returned unchanged.
    """
    # Drop atoms with NaN coordinates before the search
    has_resolved_coordinates = ~np.isnan(atom_array.coord).any(axis=-1)
    resolved_array = atom_array[has_resolved_coordinates]

    # Position i in `resolved_array` corresponds to original_indices[i]
    original_indices = np.flatnonzero(has_resolved_coordinates)

    # Restrict the selections to the resolved atoms, if any were given
    if selection1 is not None:
        selection1 = selection1[has_resolved_coordinates]
    if selection2 is not None:
        selection2 = selection2[has_resolved_coordinates]

    # A selection of None means "all atoms" and is therefore never empty
    any_selection_empty = any(
        selection is not None and selection.sum() == 0
        for selection in (selection1, selection2)
    )

    if any_selection_empty:
        # no ligand, or ligand is of same type as selection1 (e.g. peptide)
        triplets = np.empty((0, 3), dtype=int)
    else:
        # Assumes an AtomArray; for an AtomArrayStack biotite returns an extra mask
        triplets = struc.hbond(
            resolved_array,
            selection1=selection1,
            selection2=selection2,
            selection1_type=selection1_type,
            cutoff_dist=cutoff_dist,
            cutoff_angle=cutoff_angle,
            donor_elements=donor_elements,
            acceptor_elements=acceptor_elements,
            periodic=periodic,
        )
        # Map indices of the NaN-filtered array back to the original array
        triplets = original_indices[np.asarray(triplets, dtype=int).reshape(-1, 3)]

    # [is_active_donor, is_active_acceptor] per atom of the original array
    types = np.zeros((len(atom_array), 2), dtype=float)
    if len(triplets) > 0:
        types[triplets[:, 0], 0] = 1.0
        types[triplets[:, 2], 1] = 1.0

    return triplets, types, atom_array
220
+
221
+
222
class CalculateHbonds(Transform):
    """Transform that annotates atoms taking part in hydrogen bonds.

    Hydrogens are added internally with `add_hydrogen_atom_positions` and
    removed again before the data dictionary is returned.
    """

    def __init__(
        self,
        selection1_type: Literal["acceptor", "donor", "both"] = "both",
        cutoff_dist: float = 3,
        cutoff_angle: float = 120,
        donor_elements: Tuple[str, ...] = ("O", "N", "S", "F"),
        acceptor_elements: Tuple[str, ...] = ("O", "N", "S", "F"),
        periodic: bool = False,
        make2d: bool = False,
    ):
        """
        Initialize the Hbonds transform.

        The two atom selections are not configurable here; `forward` derives
        them from `data["sampled_condition_name"]`.

        Args:
            selection1_type (Literal, optional): Determines the type of selection1. The type of selection2 is chosen accordingly ('both' or the opposite).
                (Default: 'both')
            cutoff_dist (float, optional): The maximal distance between the hydrogen and acceptor to be considered a hydrogen bond. (Default: 3)
            cutoff_angle (float, optional): The angle cutoff in degree between Donor-H..Acceptor to be considered a hydrogen bond. (Default: 120)
            donor_elements, acceptor_elements (tuple of str): Elements to be considered as possible donors or acceptors. (Default: O, N, S, F)
            periodic (bool, optional): If true, hydrogen bonds can also be detected in periodic boundary conditions. The box attribute of atoms is required in this case. (Default: False)
            make2d (bool, optional): Stored on the instance; not read anywhere in this class. (Default: False)
        """
        self.selection1_type = selection1_type
        self.cutoff_dist = cutoff_dist
        self.cutoff_angle = cutoff_angle
        self.donor_elements = donor_elements
        self.acceptor_elements = acceptor_elements
        self.periodic = periodic
        self.make2d = make2d

    def check_input(self, data: dict[str, Any]) -> None:
        """Validate that `data` holds an AtomArray with a `res_name` annotation."""
        check_contains_keys(data, ["atom_array"])
        check_is_instance(data, "atom_array", AtomArray)
        check_atom_array_annotation(data, ["res_name"])

        # The hydrogen check is disabled while hydrogen addition is being debugged
        # check_atom_array_has_hydrogen(data)

    def forward(self, data: dict) -> dict:
        """
        Calculates Hbonds and stores the result as per-atom annotations.

        Args:
            data: dict
                Must provide `atom_array`, `sampled_condition_name` and
                `conditions["hbond_subsample"]`. The atom array needs
                `chain_type` and `is_motif_atom` annotations.

        Returns:
            dict: The data dictionary, with
                - `atom_array` carrying float `active_donor` / `active_acceptor`
                  annotations (0. or 1. per atom) and no hydrogens;
                - `log_dict` extended with `hbond_total_count`,
                  `hbond_total_atoms` and `hbond_subsample_atoms`.
            If adding hydrogens fails, both annotations are all zero and
            `log_dict` is left untouched.
        """

        atom_array: AtomArray = data["atom_array"]

        try:
            atom_array = add_hydrogen_atom_positions(atom_array)

        except Exception as e:
            print(
                f"WARNING: problem adding hydrogens: {e}.\nThis example will get no hydrogen bond annotations."
            )
            # Use float zeros so the fallback has the same dtype as the
            # annotations written on the success path below.
            n_atoms = atom_array.array_length()
            atom_array.set_annotation("active_donor", np.zeros(n_atoms, dtype=float))
            atom_array.set_annotation("active_acceptor", np.zeros(n_atoms, dtype=float))
            data["atom_array"] = atom_array
            return data

        # These are the only two use-cases we have so far. Can be extended as needed
        if data["sampled_condition_name"] == "ppi":
            selection1_chain_types = ["POLYPEPTIDE(D)", "POLYPEPTIDE(L)"]
            selection2_chain_types = ["POLYPEPTIDE(D)", "POLYPEPTIDE(L)"]
            separate_selections_for_motif_and_diffused = True
        else:
            selection1_chain_types = SELECTION_PROTEIN
            selection2_chain_types = SELECTION_NONPROTEIN
            separate_selections_for_motif_and_diffused = False

        selection1 = np.isin(atom_array.chain_type, selection1_chain_types)
        selection2 = np.isin(atom_array.chain_type, selection2_chain_types)

        if separate_selections_for_motif_and_diffused:
            # Restrict to Hbonds between motif and diffused regions
            selection1 = selection1 & atom_array.is_motif_atom
            selection2 = selection2 & ~atom_array.is_motif_atom
        else:
            # Include fixed motif atoms for hbond calculations; selection1 is
            # everything that did not end up in selection2.
            selection2 |= np.array(atom_array.is_motif_atom, dtype=bool)
            selection1 = ~selection2

        hbonds, hbond_types, atom_array = calculate_hbonds(
            atom_array,
            selection1=selection1,
            selection2=selection2,
            selection1_type=self.selection1_type,
            cutoff_dist=self.cutoff_dist,
            cutoff_angle=self.cutoff_angle,
            donor_elements=self.donor_elements,
            acceptor_elements=self.acceptor_elements,
            periodic=self.periodic,
        )

        # Initialize log_dict if not present
        log_dict = data.setdefault("log_dict", {})

        # Log hbond statistics (before restricting to motif atoms)
        log_dict["hbond_total_count"] = len(hbonds)
        log_dict["hbond_total_atoms"] = hbond_types.sum()

        # Keep donor/acceptor flags on motif atoms only
        motif_mask = np.array(atom_array.is_motif_atom)
        final_hbond_types = hbond_types * motif_mask[:, None]

        # Subsample if requested and more than 3 motif atoms are flagged
        if data["conditions"]["hbond_subsample"] and np.sum(final_hbond_types) > 3:
            # Linear correlation: fewer hbonds = higher fraction
            base_fraction = 0.1  # minimum fraction (when many hbonds)
            max_fraction = 0.9  # maximum fraction (when few hbonds)
            n_hbonds = len(hbonds)
            max_hbonds = 50  # Expected maximum number of hbonds for scaling

            # Linear interpolation: fraction decreases linearly with number of hbonds
            fraction = max_fraction - (max_fraction - base_fraction) * min(
                n_hbonds / max_hbonds, 1.0
            )
            final_hbond_types = subsample_one_hot_np(final_hbond_types, fraction)

        # Set annotations and log subsample atoms
        atom_array.set_annotation("active_donor", final_hbond_types[:, 0])
        atom_array.set_annotation("active_acceptor", final_hbond_types[:, 1])
        log_dict["hbond_subsample_atoms"] = final_hbond_types.sum()

        # Remove hydrogens after processing
        atom_array = remove_hydrogens(atom_array)
        data["log_dict"] = log_dict
        data["atom_array"] = atom_array
        return data
374
+
375
+
376
def subsample_one_hot_np(array, fraction):
    """
    Subsamples a one-hot encoded NumPy array by randomly keeping a given fraction of the 1s.

    Works for arrays of any dimensionality. The input is never modified.
    Randomness comes from NumPy's global random state.

    Args:
        array (np.ndarray): One-hot array of 0s and 1s.
        fraction (float): Fraction of 1s to keep (0 < fraction <= 1).
            ``int(num_ones * fraction)`` entries are kept, i.e. rounded down.

    Returns:
        np.ndarray: Subsampled array with same shape and dtype.

    Raises:
        ValueError: If `fraction` is outside the range (0, 1].
    """
    if not (0 < fraction <= 1):
        raise ValueError("Fraction must be in the range (0, 1].")

    array = np.asanyarray(array)

    # One row of coordinates per entry equal to 1
    one_indices = np.argwhere(array == 1)
    keep_count = int(len(one_indices) * fraction)

    # Shuffle the rows and keep the first `keep_count` of them
    np.random.shuffle(one_indices)
    keep_indices = one_indices[:keep_count]

    new_array = np.zeros_like(array)
    if keep_count > 0:
        # Transposing yields one index array per axis, as fancy indexing expects
        new_array[tuple(keep_indices.T)] = 1

    return new_array
@@ -0,0 +1,246 @@
1
+ import os
2
+ import string
3
+ import subprocess
4
+ from datetime import datetime
5
+ from typing import Any, Tuple
6
+
7
+ import numpy as np
8
+ from atomworks.ml.transforms._checks import (
9
+ check_atom_array_annotation,
10
+ check_contains_keys,
11
+ check_is_instance,
12
+ )
13
+ from atomworks.ml.transforms.base import Transform
14
+ from biotite.structure import AtomArray
15
+ from biotite.structure.io.pdb import PDBFile
16
+
17
+
18
def save_atomarray_to_pdb(atom_array, output_path):
    """Write `atom_array` to a PDB file without modifying the input.

    The adjustments needed for PDB output are applied to a copy only:
    NaN coordinates are replaced with small random values, every
    `chain_iid` is mapped to a single-character chain ID, and `b_factor`
    is set to zero.

    Args:
        atom_array: Structure to write; must carry a `chain_iid` annotation.
        output_path: Destination path of the PDB file.

    Returns:
        tuple: ``(atom_array, nan_mask, inverted_chain_map)`` where
            `atom_array` is the unmodified input, `nan_mask` is the boolean
            mask of NaN entries in its coordinates, and `inverted_chain_map`
            maps each chain ID written to the file back to its `chain_iid`.

    Raises:
        ValueError: If there are more than 52 distinct `chain_iid` values.
    """
    # Work on a copy so the caller's coordinates, chain IDs and b-factors
    # are not clobbered by what is only a serialization detail.
    pdb_array = atom_array.copy()

    coords = pdb_array.coord
    nan_mask = np.isnan(coords)
    coords[nan_mask] = np.random.uniform(-1e-3, 1e-3, size=nan_mask.sum())
    pdb_array.coord = coords

    chain_iids = np.unique(pdb_array.chain_iid)
    if len(chain_iids) > 52:
        raise ValueError(
            "Too many chain_iids, cannot convert to PDB", "skipping HBPLUS"
        )

    # Single-character chain_iids keep their name; the remaining ones get the
    # first still-unused ASCII letter.
    unused_chain_ids = list(string.ascii_letters)
    chain_map = {}
    for chain_iid in chain_iids:
        if len(chain_iid) == 1:
            chain_map[chain_iid] = chain_iid
            if chain_iid in unused_chain_ids:
                unused_chain_ids.remove(chain_iid)
    for chain_iid in chain_iids:
        if len(chain_iid) > 1:
            chain_map[chain_iid] = unused_chain_ids.pop(0)

    inverted_chain_map = {v: k for k, v in chain_map.items()}
    pdb_array.chain_id = np.array([chain_map[i] for i in pdb_array.chain_iid])
    pdb_array.b_factor = np.zeros(len(pdb_array))

    pdb = PDBFile()
    pdb.set_structure(pdb_array)
    pdb.write(output_path)

    return atom_array, nan_mask, inverted_chain_map
57
+
58
+
59
def check_atom_array_has_hydrogen(data: dict[str, Any]):
    """Raise a ValueError unless `data["atom_array"]` has an atom with element "H"."""
    hydrogen_flags = data["atom_array"].element == "H"
    if np.any(hydrogen_flags):
        return
    raise ValueError("Key `atom_array` in data has no hydrogens.")
62
+
63
+
64
def calculate_hbonds(
    atom_array: AtomArray,
    cutoff_HA_dist: float = 3,
    cutoff_DA_distance: float = 3.5,
) -> Tuple[AtomArray, list, int]:
    """Detect hydrogen bonds with the external HBPLUS program.

    The structure is written to a PDB file in a private temporary directory,
    HBPLUS is run on it, and its ``.hb2`` output is parsed. Only hydrogen
    bonds between a motif atom and a non-motif atom are kept.

    Args:
        atom_array: Structure with `chain_iid`, `res_id`, `atom_name` and
            `is_motif_atom` annotations.
        cutoff_HA_dist: Value passed to HBPLUS via ``-h``.
        cutoff_DA_distance: Value passed to HBPLUS via ``-d``.

    Returns:
        tuple: ``(atom_array, motif_hbonds, n_motif_hbonds)``. The atom array
            carries float `active_donor` / `active_acceptor` annotations;
            `motif_hbonds` is a list of dicts, one per kept hydrogen bond.

    Raises:
        ValueError: If ``HBPLUS_PATH`` is not set, or a reported donor or
            acceptor cannot be matched to exactly one atom.
        FileNotFoundError: If HBPLUS did not produce an output file.
    """
    import shutil
    import tempfile

    # Validate the configuration before anything is written to disk
    hbplus_exe = os.environ.get("HBPLUS_PATH")
    if hbplus_exe is None or hbplus_exe == "":
        raise ValueError(
            "HBPLUS_PATH environment variable not set. "
            "Please set it to the path of the hbplus executable in order to calculate hydrogen bonds."
        )

    # HBPLUS is run from inside the temporary directory, so a relative
    # executable path has to be made absolute first.
    resolved_exe = shutil.which(hbplus_exe)
    if resolved_exe is not None:
        hbplus_exe = os.path.abspath(resolved_exe)

    # A private directory avoids file-name collisions between concurrent
    # workers and guarantees cleanup even when an error is raised.
    with tempfile.TemporaryDirectory(prefix="hbplus_") as workdir:
        pdb_name = "structure.pdb"
        hb2_path = os.path.join(workdir, "structure.hb2")

        atom_array, _, chain_map = save_atomarray_to_pdb(
            atom_array, os.path.join(workdir, pdb_name)
        )

        subprocess.run(
            [
                hbplus_exe,
                "-h",
                str(cutoff_HA_dist),
                "-d",
                str(cutoff_DA_distance),
                pdb_name,
                pdb_name,
            ],
            cwd=workdir,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            check=False,
        )

        if not os.path.exists(hb2_path):
            raise FileNotFoundError(
                f"HBPLUS ({hbplus_exe}) did not produce the expected output file 'structure.hb2'."
            )
        with open(hb2_path, "r") as hb2_file:
            hb2_lines = hb2_file.readlines()

    # The first 8 lines of the .hb2 file are skipped; the remaining records
    # are read by fixed column positions.
    hbonds = []
    for record in hb2_lines[8:]:
        if not record.strip():
            continue
        hbonds.append(
            {
                "d_chain": chain_map[record[0]],
                "d_resi": str(int(record[1:5].strip())),
                "d_resn": record[6:9].strip(),
                "d_ins": record[5].replace("-", " "),
                "d_atom": record[9:13].strip(),
                "a_chain": chain_map[record[14]],
                "a_resi": str(int(record[15:19].strip())),
                "a_resn": record[20:23].strip(),
                "a_ins": record[19].replace("-", " "),
                "a_atom": record[23:27].strip(),
                "dist": float(record[27:32].strip()),
            }
        )

    n_atoms = len(atom_array)
    donor_mask = np.zeros(n_atoms, dtype=bool)
    acceptor_mask = np.zeros(n_atoms, dtype=bool)

    motif_hbonds = []
    for item in hbonds:
        current_donor_mask = (
            (atom_array.chain_iid == item["d_chain"])
            & (atom_array.res_id == int(item["d_resi"]))
            & (atom_array.atom_name == item["d_atom"])
        )
        current_acceptor_mask = (
            (atom_array.chain_iid == item["a_chain"])
            & (atom_array.res_id == int(item["a_resi"]))
            & (atom_array.atom_name == item["a_atom"])
        )

        # Ensure that we can uniquely identify the donor and acceptor atoms
        if current_donor_mask.sum() != 1:
            raise ValueError(
                f"Unable to uniquely identify a donor atom with chain_iid={item['d_chain']}, res_id={item['d_resi']}, atom_name={item['d_atom']}."
            )
        if current_acceptor_mask.sum() != 1:
            raise ValueError(
                f"Unable to uniquely identify an acceptor atom with chain_iid={item['a_chain']}, res_id={item['a_resi']}, atom_name={item['a_atom']}."
            )

        current_donor_is_motif = atom_array.is_motif_atom[current_donor_mask][0]
        current_acceptor_is_motif = atom_array.is_motif_atom[current_acceptor_mask][0]

        # Only keep hbonds between the motif and diffused regions
        if current_donor_is_motif != current_acceptor_is_motif:
            motif_hbonds.append(item)
            donor_mask |= current_donor_mask
            acceptor_mask |= current_acceptor_mask

    # Float 0./1. flags per atom
    atom_array.set_annotation("active_donor", donor_mask.astype(float))
    atom_array.set_annotation("active_acceptor", acceptor_mask.astype(float))

    return atom_array, motif_hbonds, len(motif_hbonds)
171
+
172
+
173
class CalculateHbondsPlus(Transform):
    """Transform that annotates hydrogen-bonded motif atoms using HBPLUS.

    NOTE(review): the original summary says the AtomArray is expected to
    contain hydrogens, but the hydrogen check in `check_input` is disabled —
    confirm the actual requirement.
    """

    def __init__(
        self,
        cutoff_HA_dist: float = 3,
        cutoff_DA_distance: float = 3.5,
    ):
        """Store the two distance cutoffs that are forwarded to `calculate_hbonds`."""
        self.cutoff_HA_dist = cutoff_HA_dist
        self.cutoff_DA_distance = cutoff_DA_distance

    def check_input(self, data: dict[str, Any]) -> None:
        """Validate that `data` holds an AtomArray with a `res_name` annotation."""
        check_contains_keys(data, ["atom_array"])
        check_is_instance(data, "atom_array", AtomArray)
        check_atom_array_annotation(data, ["res_name"])
        # check_atom_array_has_hydrogen(data)

    def forward(self, data: dict) -> dict:
        """Annotate `active_donor` / `active_acceptor` and log hbond statistics.

        Reads `data["atom_array"]` and `data["conditions"]["hbond_subsample"]`;
        writes `data["atom_array"]` and `data["log_dict"]`.
        """
        annotated_array, hbonds, _ = calculate_hbonds(
            data["atom_array"],
            cutoff_HA_dist=self.cutoff_HA_dist,
            cutoff_DA_distance=self.cutoff_DA_distance,
        )

        stats = data.setdefault("log_dict", {})

        # Columns: [donor, acceptor]; zero out everything outside the motif
        donor_acceptor_flags = np.vstack(
            (annotated_array.active_donor, annotated_array.active_acceptor)
        ).T
        for column in (0, 1):
            donor_acceptor_flags[:, column] *= np.array(annotated_array.is_motif_atom)
        stats["hbond_total_count"] = np.sum(donor_acceptor_flags)

        subsample_requested = data["conditions"]["hbond_subsample"]
        if subsample_requested and np.sum(donor_acceptor_flags) > 3:
            # The kept fraction shrinks linearly from 0.9 to 0.1 as the
            # number of flagged atoms grows towards 50.
            lowest_fraction, highest_fraction, saturation_count = 0.1, 0.9, 50
            flagged_count = np.sum(donor_acceptor_flags)
            fraction = highest_fraction - (highest_fraction - lowest_fraction) * min(
                flagged_count / saturation_count, 1.0
            )
            donor_acceptor_flags = subsample_one_hot_np(donor_acceptor_flags, fraction)

        annotated_array.set_annotation("active_donor", donor_acceptor_flags[:, 0])
        annotated_array.set_annotation("active_acceptor", donor_acceptor_flags[:, 1])
        stats["hbond_subsample_atoms"] = np.sum(donor_acceptor_flags)

        data["log_dict"] = stats
        data["atom_array"] = annotated_array

        return data
228
+
229
+
230
def subsample_one_hot_np(array, fraction):
    """
    Subsamples a one-hot encoded NumPy array by randomly keeping a given fraction of the 1s.

    Works for arrays of any dimensionality. The input is never modified.
    Randomness comes from NumPy's global random state.

    Args:
        array (np.ndarray): One-hot array of 0s and 1s.
        fraction (float): Fraction of 1s to keep (0 < fraction <= 1).
            ``int(num_ones * fraction)`` entries are kept, i.e. rounded down.

    Returns:
        np.ndarray: Subsampled array with same shape and dtype.

    Raises:
        ValueError: If `fraction` is outside the range (0, 1].
    """
    if not (0 < fraction <= 1):
        raise ValueError("Fraction must be in the range (0, 1].")

    array = np.asanyarray(array)

    # One row of coordinates per entry equal to 1
    one_indices = np.argwhere(array == 1)
    keep_count = int(len(one_indices) * fraction)

    # Shuffle the rows and keep the first `keep_count` of them
    np.random.shuffle(one_indices)
    keep_indices = one_indices[:keep_count]

    new_array = np.zeros_like(array)
    if keep_count > 0:
        # Transposing yields one index array per axis, as fancy indexing expects
        new_array[tuple(keep_indices.T)] = 1

    return new_array