rc-foundry 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180) hide show
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
@@ -0,0 +1,523 @@
1
+ import numpy as np
2
+ from atomworks.enums import ChainType
3
+ from atomworks.ml.transforms._checks import (
4
+ check_atom_array_annotation,
5
+ check_contains_keys,
6
+ )
7
+ from atomworks.ml.transforms.base import Transform
8
+ from atomworks.ml.transforms.crop import resize_crop_info_if_too_many_atoms
9
+ from atomworks.ml.utils.token import (
10
+ get_token_count,
11
+ spread_token_wise,
12
+ )
13
+ from biotite.structure.basepairs import (
14
+ _check_dssr_criteria,
15
+ _get_proximate_residues,
16
+ get_residue_masks,
17
+ get_residue_starts_for,
18
+ )
19
+ from scipy.spatial import distance_matrix
20
+
21
+
22
def protein_dna_contact_contiguous_crop_mask(
    atom_array,
    protein_contact_atoms,
    dna_contact_atoms,
    contact_dist_cutoff,
    protein_expand_min,
    protein_expand_max,
    dna_expand_min,
    dna_expand_max,
):
    """
    Build crop information around one randomly sampled protein-DNA contact.

    A contacting DNA atom / protein atom pair is sampled first. The protein
    side is then grown along its chain and the DNA side is grown along its
    strand (plus base-paired partner residues) by randomly drawn amounts.

    Parameters
    ----------
    atom_array
        Structure to crop.
    protein_contact_atoms, dna_contact_atoms
        Atom selections forwarded to
        ``identify_and_sample_protein_dna_contact``.
    contact_dist_cutoff
        Distance cutoff forwarded to the contact search.
    protein_expand_min, protein_expand_max
        Bounds passed to ``np.random.randint`` for the left and right protein
        expansion (each side is drawn independently).
    dna_expand_min, dna_expand_max
        Bounds passed to ``np.random.randint`` for the left and right DNA
        expansion (each side is drawn independently).

    Returns
    -------
    dict
        Keys ``type``, ``requires_crop``, ``crop_atom_idxs``,
        ``crop_token_idxs`` and ``atom_array`` (the uncropped input).

    Raises
    ------
    ValueError
        If the selected region contains more than 300 tokens.
    """
    dna_seed, protein_seed = identify_and_sample_protein_dna_contact(
        atom_array, protein_contact_atoms, dna_contact_atoms, contact_dist_cutoff
    )

    # NOTE: the order of the random draws below is significant for
    # reproducibility under a fixed global NumPy seed.
    protein_left = np.random.randint(protein_expand_min, protein_expand_max)
    protein_right = np.random.randint(protein_expand_min, protein_expand_max)
    protein_part = expand_connected_component_mask(
        atom_array, protein_seed, protein_left, protein_right
    )

    dna_left = np.random.randint(dna_expand_min, dna_expand_max)
    dna_right = np.random.randint(dna_expand_min, dna_expand_max)
    dna_part = get_dna_mask(atom_array, dna_seed, dna_left, dna_right)

    # Atom-level union of both selections.
    keep = np.logical_or(protein_part, dna_part)
    needs_crop = np.any(keep)
    kept_atom_idxs = np.where(keep)[0]

    # Give every token a running id, spread it to atoms, and read off the
    # ids of the kept atoms.
    per_token_ids = np.arange(get_token_count(atom_array), dtype=np.uint32)
    kept_token_idxs = spread_token_wise(atom_array, per_token_ids)[keep]

    if get_token_count(atom_array[keep]) > 300:
        raise ValueError(
            "Noncanonical DNAs are causing token count explosion, skipping..."
        )

    crop_info = {
        "type": "ProteinDNAContactContiguousCrop",
        "requires_crop": needs_crop,
        "crop_atom_idxs": kept_atom_idxs,
        "crop_token_idxs": kept_token_idxs,
        "atom_array": atom_array,
    }
    return crop_info
69
+
70
+
71
def atom_array_from_contact_dict(atom_array, contact_atoms):
    """
    Select the atoms whose residue name / atom name pair is listed in a dict.

    Parameters
    ----------
    atom_array
        Iterable of atoms exposing ``res_name`` and ``atom_name`` that also
        supports boolean-list indexing.
    contact_atoms : dict
        Maps a residue name to the collection of atom names to keep for that
        residue. Residues missing from the dict contribute no atoms.

    Returns
    -------
    The subset of ``atom_array`` matching the dict, in original order.
    """
    # ``dict.get`` with an empty default folds the "residue listed?" and
    # "atom listed for that residue?" checks into a single lookup.
    mask = [
        row.atom_name in contact_atoms.get(row.res_name, ())
        for row in atom_array
    ]
    return atom_array[mask]
83
+
84
+
85
def identify_and_sample_protein_dna_contact(
    atom_array, protein_contact_atoms, dna_contact_atoms, contact_dist=4
):
    """
    Sample one protein-DNA atom pair that lies within ``contact_dist``.

    Parameters
    ----------
    atom_array
        Structure carrying ``chain_type``, ``atom_name``, ``res_name`` and
        ``coord``.
    protein_contact_atoms, dna_contact_atoms
        Selection of candidate atoms on each side. Each may be

        * a ``dict`` mapping residue name to atom names,
        * a ``list`` of atom names (restricted to the matching chain type), or
        * the string ``"all"`` (every atom of the matching chain type).
    contact_dist : float, optional
        Pairs closer than this distance count as contacts (default: 4).

    Returns
    -------
    tuple
        ``(dna_contact, prot_contact)``: the sampled DNA atom and protein atom.

    Raises
    ------
    ValueError
        If a selection has an unsupported type or value, or if no
        protein-DNA pair lies within ``contact_dist``.
    """
    is_protein = atom_array.chain_type == ChainType.POLYPEPTIDE_L
    is_dna = atom_array.chain_type == ChainType.DNA

    if isinstance(protein_contact_atoms, dict):
        # NOTE: the dict selection is not restricted by chain type on the
        # protein side; residue names alone decide membership.
        protein = atom_array_from_contact_dict(atom_array, protein_contact_atoms)
    elif isinstance(protein_contact_atoms, list):
        protein = atom_array[
            is_protein & np.isin(atom_array.atom_name, protein_contact_atoms)
        ]
    elif isinstance(protein_contact_atoms, str) and protein_contact_atoms == "all":
        protein = atom_array[is_protein]
    else:
        raise ValueError(
            "protein_contact_atoms must be a dict, a list of atom names or "
            f"'all', got {protein_contact_atoms!r}"
        )

    if isinstance(dna_contact_atoms, dict):
        # Restrict to DNA chains first, then apply the per-residue selection.
        dna = atom_array_from_contact_dict(atom_array[is_dna], dna_contact_atoms)
    elif isinstance(dna_contact_atoms, list):
        dna = atom_array[is_dna & np.isin(atom_array.atom_name, dna_contact_atoms)]
    elif isinstance(dna_contact_atoms, str) and dna_contact_atoms == "all":
        dna = atom_array[is_dna]
    else:
        raise ValueError(
            "dna_contact_atoms must be a dict, a list of atom names or "
            f"'all', got {dna_contact_atoms!r}"
        )

    # Rows index DNA atoms, columns index protein atoms.
    pdist = distance_matrix(dna.coord, protein.coord)
    contacts = np.argwhere(pdist < contact_dist)

    # Check explicitly instead of catching whatever the sampler raises on an
    # empty population, so unrelated errors are not masked.
    if len(contacts) == 0:
        raise ValueError("No protein-DNA contacts found")

    dna_idx, protein_idx = contacts[np.random.choice(len(contacts))]

    dna_contact = dna[dna_idx]
    prot_contact = protein[protein_idx]

    return dna_contact, prot_contact
131
+
132
+
133
def create_residue_mask(atom_array, first_atom_indices):
    """
    Build an atom-level boolean mask covering whole residues.

    Each entry of ``first_atom_indices`` points at one atom; every atom that
    shares both the ``res_id`` and the ``chain_id`` of any pointed-at atom is
    selected.

    Parameters
    ----------
    atom_array : biotite.structure.atom_array
        The atom array to create the mask for
    first_atom_indices : array-like
        Indices of the first atoms of the residues to select

    Returns
    -------
    numpy.ndarray
        Boolean mask that can be used to select all atoms of the specified residues
    """
    all_res_ids = atom_array.res_id
    all_chain_ids = atom_array.chain_id

    # Column vectors of the wanted (res_id, chain_id) pairs, so comparison
    # against the per-atom row vectors broadcasts to (n_targets, n_atoms).
    wanted_res_ids = all_res_ids[first_atom_indices][:, None]
    wanted_chain_ids = all_chain_ids[first_atom_indices][:, None]

    same_residue = (all_res_ids == wanted_res_ids) & (
        all_chain_ids == wanted_chain_ids
    )

    # An atom is selected when it matches at least one target residue.
    return same_residue.any(axis=0)
162
+
163
+
164
def expand_connected_component_mask(atom_array, origin, left_expand, right_expand):
    """
    Select atoms on the origin's chain inside a residue-index window.

    The window covers ``within_poly_res_idx`` values from
    ``center - left_expand`` (inclusive) up to ``center + right_expand``
    (exclusive), where ``center`` is the origin atom's ``within_poly_res_idx``.

    Parameters
    ----------
    atom_array
        Structure carrying ``chain_id`` and ``within_poly_res_idx``.
    origin
        Single atom that anchors the window.
    left_expand, right_expand : int
        Window extent towards lower / higher residue indices.

    Returns
    -------
    numpy.ndarray
        Atom-level boolean mask.
    """
    anchor = origin.within_poly_res_idx
    window = [*range(anchor - left_expand, anchor + right_expand)]

    on_origin_chain = atom_array.chain_id == origin.chain_id
    inside_window = np.isin(atom_array.within_poly_res_idx, window)
    return on_origin_chain & inside_window
173
+
174
+
175
def get_dna_mask(atom_array, origin, left_expand, right_expand):
    """
    Select a DNA strand window together with its base-paired partner residues.

    Parameters
    ----------
    atom_array
        Structure to select from.
    origin
        DNA atom anchoring the window on its own strand.
    left_expand, right_expand : int
        Window extent, forwarded to ``expand_connected_component_mask``.

    Returns
    -------
    numpy.ndarray
        Atom-level boolean mask covering the strand window and every residue
        paired (per ``base_pairs``) with a residue inside that window.
    """
    strand_mask = expand_connected_component_mask(
        atom_array, origin, left_expand, right_expand
    )

    # Each pair row holds the first-atom indices of the two paired residues.
    # Whichever side falls inside the window, collect the opposite side.
    partner_first_atom_indices = []
    for first, second in base_pairs(atom_array):
        if strand_mask[first]:
            partner_first_atom_indices.append(second)
        elif strand_mask[second]:
            partner_first_atom_indices.append(first)

    # Expand the collected first-atom indices to whole residues.
    partner_mask = create_residue_mask(atom_array, partner_first_atom_indices)

    return np.logical_or(strand_mask, partner_mask)
196
+
197
+
198
class ProteinDNAContactContiguousCrop(Transform):
    """
    A transform that crops the DNA-protein contact region according to the
    contiguous region of contact.

    Args:
        protein_contact_type (str): The type of protein contact atoms to
            consider. Supported: 'backbone' (N, CA, C), 'all', or 'from_dict'
            (requires ``protein_contact_atom_dict``). Any other value raises
            ``ValueError``.
        dna_contact_type (str): The type of DNA contact atoms to consider.
            Supported: 'backbone' (P, OP1, OP2), 'base', or 'from_dict'
            (uses ``dna_contact_atom_dict`` when given). Any other value,
            including 'all', selects all DNA atoms.
        contact_distance_cutoff (float): The distance cutoff for considering
            two atoms to be in contact.
        protein_expand_min, protein_expand_max (int): Bounds for the random
            per-side expansion of the protein window.
        dna_expand_min, dna_expand_max (int): Bounds for the random per-side
            expansion of the DNA window.
        keep_uncropped_atom_array (bool): If True, store the input structure
            under ``data["uncropped_atom_array"]``.
        max_atoms_in_crop: Forwarded as ``max_atoms`` to
            ``resize_crop_info_if_too_many_atoms``.
        protein_contact_atom_dict (dict | None): Residue name -> atom names,
            used when ``protein_contact_type == 'from_dict'``.
        dna_contact_atom_dict (dict | None): Residue name -> atom names, used
            when ``dna_contact_type == 'from_dict'``.
    """

    def __init__(
        self,
        protein_contact_type,
        dna_contact_type,
        contact_distance_cutoff=10.0,
        protein_expand_min=15,
        protein_expand_max=40,
        dna_expand_min=3,
        dna_expand_max=10,
        keep_uncropped_atom_array: bool = False,
        max_atoms_in_crop=None,
        protein_contact_atom_dict=None,
        dna_contact_atom_dict=None,
    ):
        if protein_contact_type == "backbone":
            self.protein_contact_atoms = ["N", "CA", "C"]
        elif protein_contact_type == "all":
            self.protein_contact_atoms = "all"
        elif (
            protein_contact_type == "from_dict"
            and protein_contact_atom_dict is not None
        ):
            # Plain dict so the downstream ``isinstance(..., dict)`` dispatch
            # recognises mapping-like config objects as well.
            self.protein_contact_atoms = dict(protein_contact_atom_dict)
        else:
            # Previously this left ``protein_contact_atoms`` unset and only
            # failed later, inside ``forward``, with an AttributeError.
            raise ValueError(
                f"Unsupported protein_contact_type {protein_contact_type!r}: "
                "expected 'backbone', 'all', or 'from_dict' together with "
                "protein_contact_atom_dict"
            )

        if dna_contact_type == "backbone":
            self.dna_contact_atoms = ["P", "OP1", "OP2"]
        elif dna_contact_type == "base":
            self.dna_contact_atoms = {
                "DA": ["N7", "N6"],
                "DC": ["N4"],
                "DG": ["N7", "O6"],
                "DT": ["O4"],
            }
        elif dna_contact_type == "from_dict" and dna_contact_atom_dict is not None:
            self.dna_contact_atoms = dict(dna_contact_atom_dict)
        else:
            # Kept for backward compatibility: every other value selects all
            # DNA atoms.
            self.dna_contact_atoms = "all"

        self.protein_contact_type = protein_contact_type
        self.dna_contact_type = dna_contact_type

        self.protein_expand_min = protein_expand_min
        self.protein_expand_max = protein_expand_max
        self.dna_expand_min = dna_expand_min
        self.dna_expand_max = dna_expand_max
        self.contact_distance_cutoff = contact_distance_cutoff

        self.keep_uncropped_atom_array = keep_uncropped_atom_array
        self.max_atoms_in_crop = max_atoms_in_crop

    def check_input(self, data: dict):
        """Verify that ``data`` holds an atom array with the annotations read here."""
        check_contains_keys(data, ["atom_array"])
        # ``chain_type`` and ``within_poly_res_idx`` are read by the contact
        # search and the window expansion; failing here gives a clearer error
        # than an AttributeError deep inside the crop.
        check_atom_array_annotation(
            data, ["res_name", "chain_type", "within_poly_res_idx"]
        )

    def forward(self, data: dict) -> dict:
        """
        Crop ``data["atom_array"]`` around a sampled protein-DNA contact.

        On a crop, ``data["atom_array"]`` is replaced by the cropped structure
        and ``data["crop_info"]`` is set. The uncropped structure is stored
        under ``data["uncropped_atom_array"]`` when requested.
        """
        atom_array = data["atom_array"]

        crop_info = protein_dna_contact_contiguous_crop_mask(
            atom_array,
            self.protein_contact_atoms,
            self.dna_contact_atoms,
            self.contact_distance_cutoff,
            self.protein_expand_min,
            self.protein_expand_max,
            self.dna_expand_min,
            self.dna_expand_max,
        )

        crop_info = resize_crop_info_if_too_many_atoms(
            crop_info=crop_info,
            atom_array=atom_array,
            max_atoms=self.max_atoms_in_crop,
        )

        if self.keep_uncropped_atom_array:
            data["uncropped_atom_array"] = atom_array

        if crop_info["requires_crop"]:
            data["atom_array"] = atom_array[crop_info["crop_atom_idxs"]]
            data["crop_info"] = crop_info
        else:
            data["atom_array"] = atom_array
        return data
284
+
285
+
286
def fill_nan_coords_with_random(atoms, min_val=-50, max_val=50, seed=None):
    """
    Return a copy of ``atoms`` whose NaN coordinate entries are randomised.

    Parameters
    ----------
    atoms : biotite.structure.AtomArray
        The atom array containing coordinates to be filled
    min_val : float, optional
        Lower bound passed to ``np.random.uniform`` (default: -50)
    max_val : float, optional
        Upper bound passed to ``np.random.uniform`` (default: 50)
    seed : int, optional
        If given, the *global* NumPy random state is re-seeded with it
        before drawing.

    Returns
    -------
    biotite.structure.AtomArray
        A new AtomArray with NaN coordinates filled; the input is untouched.
    """
    result = atoms.copy()

    if seed is not None:
        np.random.seed(seed)

    # ``coord`` is modified in place below, so this must be the copy's own
    # coordinate array rather than a detached duplicate.
    xyz = result.coord
    missing = np.isnan(xyz)

    # One draw per missing scalar entry (not per atom).
    xyz[missing] = np.random.uniform(
        low=min_val, high=max_val, size=xyz[missing].shape
    )
    return result
328
+
329
+
330
def base_pairs(atom_array, min_atoms_per_base=3, unique=True):
    """
    Use DSSR criteria to find the DNA base pairs in an :class:`atom_array`.

    This is an adaptation of the base-pair detection in
    ``biotite.structure.basepairs`` and relies on that module's private
    helpers. Unlike the upstream algorithm, this version only considers
    atoms whose ``chain_type`` is ``ChainType.DNA`` and whose residue name
    is one of ``DA``, ``DG``, ``DT`` or ``DC``; RNA and modified
    nucleotides are ignored. NaN coordinates among the considered atoms are
    replaced by random values before the proximity search, so results for
    residues with unresolved atoms depend on the global NumPy random state.

    The DSSR criteria, as documented upstream, are as follows
    :footcite:`Lu2015`:

    (i) Distance between base origins <=15 Å

    (ii) Vertical separation between the base planes <=2.5 Å

    (iii) Angle between the base normal vectors <=65°

    (iv) Absence of stacking between the two bases

    (v) Presence of at least one hydrogen bond involving a base atom

    Parameters
    ----------
    atom_array : atom_array
        The :class:`atom_array` to find base pairs in.
    min_atoms_per_base : integer, optional (default: 3)
        The number of atoms a nucleotide's base must have to be
        considered a candidate for a base pair.
    unique : bool, optional (default: True)
        If ``True``, each base is assumed to be only paired with one
        other base. If multiple pairings are plausible, the pairing with
        the most hydrogen bonds is selected.

    Returns
    -------
    basepairs : ndarray, dtype=int, shape=(n,2)
        Each row is equivalent to one base pair and contains the first
        indices of the residues corresponding to each base, expressed as
        indices into the input ``atom_array``.

    Notes
    -----
    The bases from the standard reference frame described in
    :footcite:`Olson2001` were modified such that only the base atoms
    are implemented.
    Sugar atoms (specifically C1') were disregarded, as nucleosides such
    as PSU do not possess the usual N-glycosidic linkage, thus leading to
    inaccurate results.

    The vertical separation is implemented as the scalar
    projection of the distance vectors between the base origins
    according to :footcite:`Lu1997` onto the averaged base normal
    vectors.

    The presence of base stacking is assumed if the following criteria
    are met :footcite:`Gabb1996`:

    (i) Distance between aromatic ring centers <=4.5 Å

    (ii) Angle between the ring normal vectors <=23°

    (iii) Angle between normalized distance vector between two ring
          centers and both bases' normal vectors <=40°

    Please note that ring normal vectors are assumed to be equal to the
    base normal vectors.

    For structures without hydrogens the accuracy of the algorithm is
    limited as the hydrogen bonds can only be checked for
    plausibility.
    A hydrogen bond is considered as plausible if a cutoff of 3.6 Å
    between N/O atom pairs is met. 3.6Å was chosen as hydrogen bonds are
    typically 1.5-2.5Å in length. N-H and O-H bonds have a length of
    1.00Å and 0.96Å respectively. Thus, including some buffer, a 3.6Å
    cutoff should cover all hydrogen bonds.

    Examples
    --------
    Example carried over from the upstream documentation (the names
    ``load_structure`` and ``path_to_structures`` are not defined in this
    module). Compute the base pairs for the structure with the PDB ID 1QXB:

    >>> from os.path import join
    >>> dna_helix = load_structure(join(path_to_structures, "base_pairs", "1qxb.cif"))
    >>> basepairs = base_pairs(dna_helix)
    >>> print(dna_helix[basepairs].res_name)
    [['DC' 'DG']
     ['DG' 'DC']
     ['DC' 'DG']
     ['DG' 'DC']
     ['DA' 'DT']
     ['DA' 'DT']
     ['DT' 'DA']
     ['DT' 'DA']
     ['DC' 'DG']
     ['DG' 'DC']
     ['DC' 'DG']
     ['DG' 'DC']]

    References
    ----------

    .. footbibliography::
    """
    # Restrict the search to canonical DNA residues on DNA chains
    dna_boolean = np.logical_and(
        atom_array.chain_type == ChainType.DNA,
        np.isin(atom_array.res_name, ["DA", "DG", "DT", "DC"]),
    )

    # Get the nucleotides for the given atom_array
    # Disregard the phosphate-backbone
    non_phosphate_boolean = ~np.isin(
        atom_array.atom_name, ["O5'", "P", "OP1", "OP2", "OP3", "HOP2", "HOP3"]
    )

    # Combine the two boolean masks
    boolean_mask = np.logical_and(non_phosphate_boolean, dna_boolean)

    # Get only nucleosides
    nucleosides = atom_array[boolean_mask]

    # Get the base pair candidates according to a N/O cutoff distance,
    # where each base is identified as the first index of its respective
    # residue
    n_o_mask = np.isin(nucleosides.element, ["N", "O"])

    # Replace NaN coordinates by random values before the proximity search
    # (presumably so the distance computation does not fail on unresolved
    # atoms - TODO confirm). The copy keeps atom order, so ``n_o_mask``
    # computed above still lines up.
    nucleosides = fill_nan_coords_with_random(nucleosides)
    basepair_candidates, n_o_matches = _get_proximate_residues(
        nucleosides, n_o_mask, 3.6
    )

    # Contains the plausible base pairs
    basepairs = []
    # Contains the number of hydrogens for each plausible base pair
    basepairs_hbonds = []

    # Get the residue masks for each residue
    base_masks = get_residue_masks(nucleosides, basepair_candidates.flatten())

    # Group every two masks together for easy iteration (each 'row' is
    # respective to a row in ``basepair_candidates``)
    base_masks = base_masks.reshape(
        (basepair_candidates.shape[0], 2, nucleosides.shape[0])
    )

    for (base1_index, base2_index), (base1_mask, base2_mask), n_o_pairs in zip(
        basepair_candidates, base_masks, n_o_matches
    ):
        base1 = nucleosides[base1_mask]
        base2 = nucleosides[base2_mask]

        # NOTE(review): the result is treated as -1 for "criteria not met"
        # and None for "no hydrogens available" - confirm against the
        # biotite helper.
        hbonds = _check_dssr_criteria((base1, base2), min_atoms_per_base, unique)

        # If no hydrogens are present use the number N/O pairs to
        # decide between multiple pairing possibilities.

        if hbonds is None:
            # Each N/O-pair is detected twice. Thus, the number of
            # matches must be divided by two.
            hbonds = n_o_pairs / 2
        if hbonds != -1:
            basepairs.append((base1_index, base2_index))
            if unique:
                basepairs_hbonds.append(hbonds)

    basepair_array = np.array(basepairs)

    if unique:
        # Contains all non-unique base pairs that are flagged to be
        # removed
        to_remove = []

        # Get all bases that have non-unique pairing interactions
        base_indices, occurrences = np.unique(basepairs, return_counts=True)
        for base_index, occurrence in zip(base_indices, occurrences):
            if occurrence > 1:
                # Write the non-unique base pairs to a dictionary as
                # 'index: number of hydrogen bonds'
                remove_candidates = {}
                for i, row in enumerate(np.asarray(basepair_array == base_index)):
                    if np.any(row):
                        remove_candidates[i] = basepairs_hbonds[i]
                # Flag all non-unique base pairs for removal except the
                # one that has the most hydrogen bonds
                del remove_candidates[max(remove_candidates, key=remove_candidates.get)]
                to_remove += list(remove_candidates.keys())
        # Remove all flagged base pairs from the output `ndarray`
        basepair_array = np.delete(basepair_array, to_remove, axis=0)

    # Remap values to original atom array: indices so far refer to the
    # filtered ``nucleosides`` array, so translate them through
    # ``boolean_mask`` and then snap each one to the first atom of its
    # residue in the unfiltered ``atom_array``.
    if len(basepair_array) > 0:
        basepair_array = np.where(boolean_mask)[0][basepair_array]
        for i, row in enumerate(basepair_array):
            basepair_array[i] = get_residue_starts_for(atom_array, row)
    return basepair_array