rc-foundry 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180) hide show
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
@@ -0,0 +1,541 @@
1
+ # from atomworks.ml.utils.token import get_token_masks, get_token_starts
2
+ from typing import Any
3
+
4
+ import biotite.structure as struc
5
+ import numpy as np
6
+ from assertpy import assert_that
7
+ from atomworks.ml.preprocessing.utils.structure_utils import (
8
+ get_atom_mask_from_cell_list,
9
+ )
10
+ from atomworks.ml.transforms._checks import (
11
+ check_atom_array_annotation,
12
+ check_contains_keys,
13
+ check_is_instance,
14
+ )
15
+ from atomworks.ml.transforms.atom_array import atom_id_to_atom_idx, atom_id_to_token_idx
16
+ from atomworks.ml.transforms.base import Transform
17
+ from atomworks.ml.transforms.crop import (
18
+ get_spatial_crop_center,
19
+ get_token_count,
20
+ resize_crop_info_if_too_many_atoms,
21
+ )
22
+ from atomworks.ml.utils.token import (
23
+ get_af3_token_center_coords,
24
+ get_af3_token_center_masks,
25
+ get_token_starts,
26
+ spread_token_wise,
27
+ )
28
+ from biotite.structure import AtomArray
29
+ from rfd3.transforms.conditioning_utils import sample_island_tokens
30
+ from scipy.spatial import KDTree
31
+
32
+ from foundry.utils.ddp import RankedLogger
33
+
34
+ ranked_logger = RankedLogger(__name__, rank_zero_only=True)
35
+
36
+ # NOTE: This transform is based off of `rf_diffusion_aa.rf_diffusion.ppi.FindHotspotsTrainingTransform`
37
+ # However, this is progressing piecewise, and many features of that transform are not yet implemented.
38
+ # If this seems to be working, those should definitely be added in the future!
39
+
40
+ # NOTE: In contrast to RFD, we are providing hotspots at the atom level, not the residue level.
41
+ # Future hotspot subsampling schemes might want to avoid giving redundant information via (say) bonded atoms
42
+
43
+
44
def get_hotspot_atoms(atom_array, binder_pn_unit_iid, distance_cutoff=4.5):
    """Return an atom-level boolean mask marking hotspot atoms.

    Hotspots are atoms on non-binder chains that lie within ``distance_cutoff``
    of at least one atom of the binder chain (i.e. minimum pairwise atom
    distance to the binder is below the cutoff).

    Args:
        atom_array (AtomArray): The atom array containing the protein structure.
        binder_pn_unit_iid (str): The chain ID of the binder (diffused chain).
        distance_cutoff (float): The interchain distance cutoff that defines hotspot atoms.

    Returns:
        np.ndarray: Boolean mask over ``atom_array``; True marks hotspot atoms.
    """
    # Distance queries are only possible for atoms with finite coordinates.
    has_nan_coord = np.isnan(atom_array.coord).any(axis=1)
    resolved = atom_array[~has_nan_coord]

    is_binder = resolved.pn_unit_iid == binder_pn_unit_iid
    binder_coords = resolved[is_binder].coord

    # Neighbor search: which resolved atoms lie near any binder atom?
    neighbor_grid = struc.CellList(resolved, cell_size=distance_cutoff)
    pairwise_contacts = get_atom_mask_from_cell_list(
        binder_coords, neighbor_grid, len(resolved), distance_cutoff
    )  # (n_binder_atoms, n_resolved_atoms)
    near_binder = pairwise_contacts.any(axis=0)  # (n_resolved_atoms,)

    # A hotspot must be close to the binder but not part of it.
    resolved_hotspot_mask = near_binder & ~is_binder

    # Lift the mask over resolved atoms back onto the full atom array.
    hotspot_mask = np.zeros(len(atom_array), dtype=bool)
    hotspot_mask[~has_nan_coord] = resolved_hotspot_mask

    return hotspot_mask
81
+
82
+
83
def get_secondary_structure_types(atom_array: AtomArray) -> np.ndarray:
    """Get the secondary structure types for a given atom array.

    For now, only three categories are one-hot encoded: helix, sheet, and loop.

    Args:
        atom_array (AtomArray): Structure with ``token_id`` and ``chain_iid`` annotations.

    Returns:
        np.ndarray: Boolean array of shape (n_atoms, 3) with columns
            [is_helix, is_sheet, is_loop], spread atom-wise from token-level SSE.
    """
    ss_types = np.zeros((atom_array.array_length(), 3), dtype=bool)

    # HACK: Temporarily overwrite res_id with token_id so that the sse_array will
    # have length n_tokens.
    actual_res_id = atom_array.res_id.copy()
    try:
        atom_array.res_id = atom_array.token_id

        # annotate_sse detects chainbreaks based on res_id discontinuities, so we
        # shift each chain's res_ids by its chain index to create discontinuities.
        _, chain_offsets = np.unique(atom_array.chain_iid, return_inverse=True)
        atom_array.res_id += chain_offsets

        # Compute secondary structure information (token-level)
        sse_array = struc.annotate_sse(atom_array)
        assert len(sse_array) == len(
            np.unique(atom_array.token_id)
        ), "SSE array length does not match number of tokens."
    finally:
        # Bug fix: restore the original res_id even if annotate_sse (or the
        # assert) raises, so the temporary mutation never leaks to callers.
        atom_array.res_id = actual_res_id

    sse_array = spread_token_wise(atom_array, sse_array)
    ss_types[:, 0] = sse_array == "a"  # helix
    ss_types[:, 1] = sse_array == "b"  # sheet
    ss_types[:, 2] = sse_array == "c"  # loop/coil

    return ss_types
113
+
114
+
115
class AddGlobalIsNonLoopyFeature(Transform):
    """Add feature indicating whether the global loop content in the non-motif region is below 30%.

    For this initial implementation, only three categories will be one-hot encoded: helix, sheet, and loop.

    Sets three annotations on the atom array:
      - ``is_loop_gt``: ground-truth per-atom loop indicator.
      - ``is_non_loopy``: +1 if the non-motif loop fraction is < 0.3, else -1.
      - ``is_non_loopy_atom_level``: copy of ``is_non_loopy`` (HACK so the flag
        can also be consumed as an atom-level feature).
    """

    def check_input(self, data: dict[str, Any]) -> None:
        check_contains_keys(data, ["atom_array"])
        check_is_instance(data, "atom_array", AtomArray)
        check_atom_array_annotation(data, ["is_motif_token"])

    def forward(self, data: dict[str, Any]) -> dict[str, Any]:
        atom_array = data["atom_array"]

        # Compute all ground-truth secondary structure types for the binder chain.
        # For now boolean, later could include distances as in RFD. But maybe that's better as a 2D condition
        gt_secondary_structures = get_secondary_structure_types(atom_array)
        atom_array.set_annotation("is_loop_gt", gt_secondary_structures[:, 2])

        is_motif_atom = atom_array.is_motif_token
        non_motif_loop = atom_array.is_loop_gt[~is_motif_atom]
        # Robustness fix: np.mean of an empty selection is NaN (with a
        # RuntimeWarning), and NaN < 0.3 evaluated to False. Guard explicitly so
        # an all-motif array is "loopy" (-1) without the warning.
        is_non_loopy = non_motif_loop.size > 0 and non_motif_loop.mean() < 0.3
        is_non_loopy_annot = np.full(
            atom_array.array_length(), 1 if is_non_loopy else -1, dtype=int
        )

        atom_array.set_annotation("is_non_loopy", is_non_loopy_annot)

        # HACK: Enables adding as atom-level features as well
        atom_array.set_annotation("is_non_loopy_atom_level", is_non_loopy_annot)

        return data
146
+
147
+
148
class Add1DSSFeature(Transform):
    """Add secondary structure features to training examples.

    For this initial implementation, only three categories will be one-hot encoded: helix, sheet, and loop.

    Ground-truth annotations (``is_helix_gt`` / ``is_sheet_gt`` / ``is_loop_gt``)
    are always written. The ``*_conditioning`` annotations are only populated when
    ``data["conditions"]["add_1d_ss_features"]`` is truthy, and then only on a
    randomly sampled set of contiguous token "islands" in the non-motif protein
    region.
    """

    def __init__(
        self,
        max_secondary_structure_frac_to_provide: float = 0.4,
        min_ss_island_len: int = 1,
        max_ss_island_len: int = 10,  # Might want to expand later, this is only average. Done for now to avoid over-conditioning.
        n_islands_max: int = 3,
    ):
        # max_secondary_structure_frac_to_provide: cap on the fraction of
        # SS-annotated positions that may receive conditioning.
        self.max_secondary_structure_frac_to_provide = (
            max_secondary_structure_frac_to_provide
        )
        # Island length bounds (in tokens) passed to sample_island_tokens.
        self.min_ss_island_len = min_ss_island_len
        self.max_ss_island_len = max_ss_island_len
        # Maximum number of islands to sample per example.
        self.n_islands_max = n_islands_max

    def check_input(self, data: dict[str, Any]) -> None:
        check_contains_keys(data, ["atom_array"])
        check_is_instance(data, "atom_array", AtomArray)
        check_atom_array_annotation(data, ["is_motif_token"])

    def forward(self, data: dict[str, Any]) -> dict[str, Any]:
        atom_array = data["atom_array"]

        # Compute all ground-truth secondary structure types for the binder chain.
        gt_secondary_structures = get_secondary_structure_types(atom_array)
        atom_array.set_annotation("is_helix_gt", gt_secondary_structures[:, 0])
        atom_array.set_annotation("is_sheet_gt", gt_secondary_structures[:, 1])
        atom_array.set_annotation("is_loop_gt", gt_secondary_structures[:, 2])

        # Ground truth is always annotated; conditioning only when requested.
        if not data["conditions"]["add_1d_ss_features"]:
            return data

        # Always add the secondary structure type annotation, even if all zeros
        atom_array.set_annotation(
            "is_helix_conditioning", np.zeros(atom_array.array_length(), dtype=bool)
        )
        atom_array.set_annotation(
            "is_sheet_conditioning", np.zeros(atom_array.array_length(), dtype=bool)
        )
        atom_array.set_annotation(
            "is_loop_conditioning", np.zeros(atom_array.array_length(), dtype=bool)
        )

        # Uniformly sample a number of tokens to receive secondary structure conditioning, up to the given maximum fraction
        # NOTE(review): gt_secondary_structures.sum() counts atom-level one-hot
        # entries, not residues/tokens as the variable name suggests — confirm
        # this is the intended cap.
        max_residues_with_ss_conditioning = int(
            np.ceil(
                gt_secondary_structures.sum()
                * self.max_secondary_structure_frac_to_provide
            )
        )

        # Compute islands within the subset that is diffused and has secondary structure types.
        token_level_array = atom_array[get_token_starts(atom_array)]
        is_motif_token = token_level_array.is_motif_token
        eligible_for_ss_info_mask = (
            ~is_motif_token
            & token_level_array.is_protein
            & (  # Protein atoms with NaN coordinates would have no secondary structure annotation
                token_level_array.is_helix_gt
                | token_level_array.is_sheet_gt
                | token_level_array.is_loop_gt
            )
        )
        # Sample a boolean mask over the eligible tokens only (length = #eligible).
        where_to_show_ss = sample_island_tokens(
            eligible_for_ss_info_mask.sum(),
            island_len_min=self.min_ss_island_len,
            island_len_max=self.max_ss_island_len,
            n_islands_min=1,
            n_islands_max=self.n_islands_max,
            max_length=max_residues_with_ss_conditioning,
        )

        # Convert this to a mask over the entire token-level atom array
        token_level_ss_mask = np.zeros(token_level_array.array_length(), dtype=bool)
        token_level_ss_mask[eligible_for_ss_info_mask] = where_to_show_ss
        ss_mask = spread_token_wise(atom_array, token_level_ss_mask)

        # Copy ground truth into the conditioning channels on the sampled islands.
        atom_array.is_helix_conditioning[ss_mask] = atom_array.is_helix_gt[ss_mask]
        atom_array.is_sheet_conditioning[ss_mask] = atom_array.is_sheet_gt[ss_mask]
        atom_array.is_loop_conditioning[ss_mask] = atom_array.is_loop_gt[ss_mask]

        return data
236
+
237
+
238
class AddPPIHotspotFeature(Transform):
    """Add hotspot features to PPI training examples.

    Writes two annotations:
      - ``is_hotspot_gt``: all ground-truth hotspot atoms (see ``get_hotspot_atoms``).
      - ``is_atom_level_hotspot``: the randomly subsampled subset actually
        provided to the model as conditioning.
    """

    def __init__(
        self,
        max_hotspots_frac_to_provide: float = 0.2,
        hotspot_max_distance: float = 7.0,
    ):
        """
        Args:
            max_hotspots_frac_to_provide (float): Maximum fraction of ground-truth hotspots to add to the training example.
                The actual number added will be sampled uniformly from 0 to the number dictated by this parameter.
            hotspot_max_distance (float): Maximum distance to the binder for an atom to be considered a hotspot.
        """
        self.max_hotspots_frac_to_provide = max_hotspots_frac_to_provide
        self.hotspot_max_distance = hotspot_max_distance

    def check_input(self, data: dict[str, Any]) -> None:
        check_contains_keys(data, ["atom_array"])
        check_is_instance(data, "atom_array", AtomArray)
        check_atom_array_annotation(data, ["is_motif_token"])

    def forward(self, data: dict[str, Any]) -> dict[str, Any]:
        atom_array = data["atom_array"]

        # Always add the hotspot annotation, even if all zeros
        atom_array.set_annotation(
            "is_atom_level_hotspot", np.zeros(atom_array.array_length(), dtype=bool)
        )

        # Compute all ground-truth hotspots for the binder chain.
        # For now boolean, later could include distances as in RFD. But maybe that's better as a 2D condition
        is_hotspot_atom_mask = get_hotspot_atoms(
            atom_array,
            binder_pn_unit_iid=data["binder_pn_unit"],
            distance_cutoff=self.hotspot_max_distance,
        )
        atom_array.set_annotation("is_hotspot_gt", is_hotspot_atom_mask)

        # Uniformly sample a number of hotspots to include, up to the given maximum fraction.
        # Bug fix: the cap was previously recomputed inline inside randint, and
        # because np.random.randint's upper bound is exclusive the documented
        # maximum could never be drawn (with cap == 1 the transform always
        # provided 0 hotspots). Reuse the cap and make the range inclusive.
        max_hotspots_to_keep = int(
            np.ceil(int(is_hotspot_atom_mask.sum()) * self.max_hotspots_frac_to_provide)
        )
        if max_hotspots_to_keep == 0:
            ranked_logger.warning("No hotspots found in the input data")
            return data
        num_hotspots_to_keep = np.random.randint(0, max_hotspots_to_keep + 1)

        # Subsample hotspots to add.
        # For now random, later could add speckle_or_region from RFD
        true_hotspot_indices = np.where(is_hotspot_atom_mask)[0]
        hotspots_to_provide = np.random.choice(
            true_hotspot_indices, size=num_hotspots_to_keep, replace=False
        )
        atom_array.is_atom_level_hotspot[hotspots_to_provide] = True

        return data
303
+
304
+
305
class PPIFullBinderCropSpatial(Transform):
    """Crop which keeps the entire binder chain, then crops spatially around the given interface.

    Args:
        crop_size (int): The maximum number of tokens to crop. Must be greater than 0.
        jitter_scale (float, optional): The scale of the jitter to apply to the crop center.
            This is to break ties between atoms with the same spatial distance. Defaults to 1e-3.
        crop_center_cutoff_distance (float, optional): The cutoff distance to consider for
            selecting crop centers. Measured in Angstroms. Defaults to 15.0.
        keep_uncropped_atom_array (bool, optional): Whether to keep the uncropped atom array in the data.
            If `True`, the uncropped atom array will be stored in the `crop_info` dictionary
            under the key `"atom_array"`. Defaults to `False`.
        force_crop (bool, optional): Whether to force crop even if the atom array is already small enough.
            Defaults to `False`.
        max_atoms_in_crop (int, optional): Maximum number of atoms allowed in a crop. If None, no resizing is performed.
            Defaults to None.
    """

    def __init__(
        self,
        crop_size: int,
        jitter_scale: float = 1e-3,
        crop_center_cutoff_distance: float = 15.0,
        keep_uncropped_atom_array: bool = False,
        force_crop: bool = False,
        max_atoms_in_crop: int | None = None,
    ):
        # Token budget and spatial-selection parameters.
        self.crop_size = crop_size
        self.jitter_scale = jitter_scale
        self.crop_center_cutoff_distance = crop_center_cutoff_distance
        self.max_atoms_in_crop = max_atoms_in_crop
        # Bookkeeping flags.
        self.keep_uncropped_atom_array = keep_uncropped_atom_array
        self.force_crop = force_crop

    def check_input(self, data: dict):
        check_contains_keys(data, ["atom_array"])
        check_is_instance(data, "atom_array", AtomArray)
        check_atom_array_annotation(data, ["pn_unit_iid", "atomize", "atom_id"])

    def forward(self, data: dict) -> dict:
        atom_array = data["atom_array"]

        # Prefer caller-supplied query units; otherwise consider every pn_unit.
        supplied_query_units = data.get("query_pn_unit_iids")
        if supplied_query_units:
            query_pn_units = supplied_query_units
        else:
            query_pn_units = np.unique(atom_array.pn_unit_iid)
            ranked_logger.info(
                f"No query PN unit(s) provided for spatial crop. Randomly selecting from {query_pn_units}."
            )

        if "binder_pn_unit" not in data:
            raise ValueError("Data dict must contain 'binder_pn_unit' key.")

        # Build the crop (binder kept in full), then shrink it if it exceeds
        # the atom budget.
        crop_info = crop_spatial_keep_full_binder(
            atom_array=atom_array,
            query_pn_unit_iids=query_pn_units,
            binder_pn_unit_iid=data["binder_pn_unit"],
            crop_size=self.crop_size,
            jitter_scale=self.jitter_scale,
            crop_center_cutoff_distance=self.crop_center_cutoff_distance,
            force_crop=self.force_crop,
        )
        crop_info = resize_crop_info_if_too_many_atoms(
            crop_info=crop_info,
            atom_array=atom_array,
            max_atoms=self.max_atoms_in_crop,
        )

        data["crop_info"] = {"type": self.__class__.__name__} | crop_info

        if self.keep_uncropped_atom_array:
            data["crop_info"]["atom_array"] = atom_array

        # Replace the atom array with the cropped selection.
        data["atom_array"] = atom_array[crop_info["crop_atom_idxs"]]

        return data
381
+
382
+
383
def crop_spatial_keep_full_binder(
    atom_array: AtomArray,
    query_pn_unit_iids: list[str],
    binder_pn_unit_iid: str,
    crop_size: int,
    jitter_scale: float = 1e-3,
    crop_center_cutoff_distance: float = 15.0,
    force_crop: bool = False,
) -> dict:
    """
    Crop spatial tokens around a given `crop_center` by keeping the entire binder chain, then taking nearest
    neighbors (with jitter) until reaching the `crop_size`.

    Args:
        - atom_array (AtomArray): The atom array to crop.
        - query_pn_unit_iids (list[str]): List of query polymer/non-polymer unit instance IDs.
        - binder_pn_unit_iid (str): The polymer/non-polymer unit instance ID corresponding to the binder.
        - crop_size (int): The maximum number of tokens to crop.
        - jitter_scale (float, optional): Scale of jitter to apply when calculating distances.
          Defaults to 1e-3.
        - crop_center_cutoff_distance (float, optional): Maximum distance from query units to
          consider for crop center. Defaults to 15.0 Angstroms.
        - force_crop (bool, optional): Whether to force crop even if the atom array is already small enough.
          Defaults to False.

    Returns:
        dict: A dictionary containing crop information, including:
            - requires_crop (bool): Whether cropping was necessary.
            - crop_center_atom_id (int or np.nan): ID of the atom chosen as crop center.
            - crop_center_atom_idx (int or np.nan): Index of the atom chosen as crop center.
            - crop_center_token_idx (int or np.nan): Index of the token containing the crop center.
            - crop_token_idxs (np.ndarray): Indices of tokens included in the crop.
            - crop_atom_idxs (np.ndarray): Indices of atoms included in the crop.

    Raises:
        ValueError: If `binder_pn_unit_iid` is not among `query_pn_unit_iids`.

    Note:
        This function implements the spatial cropping procedure as described in AlphaFold 3 and AlphaFold 2 Multimer.

    References:
        - AF3 https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-024-07487-w/MediaObjects/41586_2024_7487_MOESM1_ESM.pdf
        - AF2 Multimer https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2.full.pdf
    """
    if binder_pn_unit_iid not in query_pn_unit_iids:
        raise ValueError(
            f"Binder polymer/non-polymer unit instance ID '{binder_pn_unit_iid}' "
            f"not found in query polymer/non-polymer unit instance IDs: {query_pn_unit_iids}"
        )
    n_tokens = get_token_count(atom_array)
    requires_crop = n_tokens > crop_size

    # ... get binder information (token-level and atom-level membership masks)
    binder_token_mask = (
        atom_array[get_af3_token_center_masks(atom_array)].pn_unit_iid
        == binder_pn_unit_iid
    )
    binder_atom_mask = atom_array.pn_unit_iid == binder_pn_unit_iid
    n_binder_tokens = get_token_count(atom_array[binder_atom_mask])

    if force_crop or requires_crop:
        # Get possible crop centers (near the query units, per cutoff distance)
        can_be_crop_center = get_spatial_crop_center(
            atom_array, query_pn_unit_iids, crop_center_cutoff_distance
        )

        # ... sample crop center atom
        crop_center_atom_id = np.random.choice(atom_array[can_be_crop_center].atom_id)
        crop_center_atom_idx = atom_id_to_atom_idx(atom_array, crop_center_atom_id)

        # ... sample crop, excluding the binder polymer/non-polymer unit
        # NOTE(review): if the binder alone has >= crop_size tokens, the
        # effective crop_size below is <= 0 and get_spatial_crop_excluding_mask
        # will assert — confirm upstream guarantees binder < crop_size.
        token_coords = get_af3_token_center_coords(atom_array)
        crop_center_token_idx = atom_id_to_token_idx(atom_array, crop_center_atom_id)
        is_token_in_crop = get_spatial_crop_excluding_mask(
            token_coords,
            crop_center_token_idx,
            crop_size=crop_size
            - n_binder_tokens,  # reserve space for the binder tokens
            mask_to_exclude=binder_token_mask,
            jitter_scale=jitter_scale,
        )
        # ... spread token-level crop mask to atom-level
        is_atom_in_crop = spread_token_wise(atom_array, is_token_in_crop)

        # ... add in binder tokens and atoms (binder is always kept in full)
        is_token_in_crop = is_token_in_crop | binder_token_mask
        is_atom_in_crop = is_atom_in_crop | binder_atom_mask
    else:
        # ... no need to crop since the atom array is already small enough
        crop_center_atom_id = np.nan
        crop_center_atom_idx = np.nan
        crop_center_token_idx = np.nan
        is_atom_in_crop = np.ones(len(atom_array), dtype=bool)
        is_token_in_crop = np.ones(n_tokens, dtype=bool)

    return {
        "requires_crop": requires_crop,  # whether cropping was necessary
        "crop_center_atom_id": crop_center_atom_id,  # atom_id of crop center
        "crop_center_atom_idx": crop_center_atom_idx,  # atom_idx of crop center
        "crop_center_token_idx": crop_center_token_idx,  # token_idx of crop center
        "crop_token_idxs": np.where(is_token_in_crop)[0],  # token_idxs in crop
        "crop_atom_idxs": np.where(is_atom_in_crop)[0],  # atom_idxs in crop
    }
483
+
484
+
485
def get_spatial_crop_excluding_mask(
    coord: np.ndarray,
    crop_center_idx: int,
    crop_size: int,
    mask_to_exclude: np.ndarray,
    jitter_scale: float = 1e-3,
) -> np.ndarray:
    """Select the `crop_size` tokens nearest a crop center, skipping excluded tokens.

    Crops spatial tokens around `crop_center_idx`, keeping nearest neighbors
    (with jitter to break distance ties) and excluding tokens flagged in
    `mask_to_exclude`, until `crop_size` neighbors are collected.

    Args:
        coord (np.ndarray): Token-level coordinates of shape (N, 3), in Angstroms.
        crop_center_idx (int): Index of the token used as the crop center.
        crop_size (int): Number of nearest neighbors to include in the crop.
        mask_to_exclude (np.ndarray): Boolean mask of tokens to leave out of the crop.
        jitter_scale (float): Scale of the Gaussian jitter added to the coordinates.

    Returns:
        np.ndarray: Boolean mask of shape (N,); True marks tokens inside the crop.
    """
    assert_that(coord.ndim).is_equal_to(2)
    assert_that(coord.shape[1]).is_equal_to(3)
    assert_that(crop_center_idx).is_less_than(coord.shape[0])
    assert_that(crop_size).is_greater_than(0)
    assert_that(jitter_scale).is_greater_than_or_equal_to(0)

    # Small Gaussian jitter breaks ties between equidistant tokens.
    jittered = coord
    if jitter_scale > 0:
        jittered = coord + np.random.normal(scale=jitter_scale, size=coord.shape)

    center = jittered[crop_center_idx]

    # Eligible tokens: finite coordinates (NaN marks unknown token centers,
    # i.e. unoccupied tokens) and not explicitly excluded.
    eligible = np.isfinite(jittered).all(axis=1) & ~mask_to_exclude

    # KD-tree over the eligible tokens only, for efficient neighbor queries.
    tree = KDTree(jittered[eligible])

    # Query the `crop_size` nearest neighbors of the center; when fewer than
    # `crop_size` points exist, missing neighbors come back as index `tree.n`
    # and are dropped here.
    _, neighbor_idxs = tree.query(center, k=crop_size, p=2)
    neighbor_idxs = neighbor_idxs[neighbor_idxs < tree.n]

    # Map tree-local neighbor indices back to positions in the full array.
    crop_mask = np.zeros(jittered.shape[0], dtype=bool)
    crop_mask[np.where(eligible)[0][neighbor_idxs]] = True

    return crop_mask
@@ -0,0 +1,116 @@
1
+ import numpy as np
2
+ from atomworks.ml.transforms.base import Transform
3
+ from atomworks.ml.transforms.sasa import calculate_atomwise_rasa
4
+ from atomworks.ml.utils.token import apply_and_spread_token_wise
5
+
6
+
7
class CalculateRASA(Transform):
    """Transform that annotates each atom of an AtomArray with its relative SASA ("rasa")."""

    def __init__(
        self,
        probe_radius: float = 1.4,
        atom_radii: str | np.ndarray = "ProtOr",
        point_number: int = 100,
        requires_ligand=False,
    ):
        """
        Args:
            probe_radius (float, optional): Van-der-Waals radius of the probe in Angstrom.
                Defaults to 1.4 (for water).
            atom_radii (str | np.ndarray, optional): Atom radii set to use for calculation.
                Defaults to "ProtOr", which yields no SASA for hydrogen atoms and some
                other atoms (e.g. ions or certain atoms with charges).
            point_number (int, optional): Number of points in the Shrake-Rupley algorithm
                to sample for calculating SASA. Defaults to 100.
            requires_ligand (bool, optional): If True, skip structures that contain no
                ligand atoms. Defaults to False.
        """
        self.probe_radius = probe_radius
        self.atom_radii = atom_radii
        self.point_number = point_number
        self.requires_ligand = requires_ligand

    def forward(self, data):
        atom_array = data["atom_array"]

        # Optionally skip examples without any ligand atom.
        if self.requires_ligand and not np.any(atom_array.is_ligand):
            return data

        # Compute exact per-atom relative SASA and attach it as an annotation.
        rasa = calculate_atomwise_rasa(
            atom_array, self.probe_radius, self.atom_radii, self.point_number
        )
        atom_array.set_annotation("rasa", rasa)

        data["atom_array"] = atom_array
        return data
42
+
43
+
44
def discretize_rasa(atom_array, low=0, high=0.2, n_bins=3, keep_protein_motif=False):
    """Discretize per-atom RASA into integer bins; excluded atoms get bin `n_bins`.

    An atom is included only when its RASA is not NaN and it is a motif token;
    protein atoms are additionally excluded unless `keep_protein_motif` is True.

    Args:
        atom_array: Array with `rasa`, `is_motif_token`, and `is_protein` annotations.
        low (float): Lower edge of the binning range.
        high (float): Upper edge of the binning range.
        n_bins (int): Number of bin edges; included atoms land in [0, n_bins - 1].
        keep_protein_motif (bool): Whether protein motif atoms stay in the binning.

    Returns:
        np.ndarray: Integer bin per atom; `n_bins` marks excluded atoms.
    """
    included = ~np.isnan(atom_array.rasa) & atom_array.is_motif_token
    if not keep_protein_motif:
        included = included & ~atom_array.is_protein

    # e.g. low=0, high=0.2, n_bins=3 -> edges [0.0, 0.1, 0.2]
    edges = np.linspace(low, high, n_bins)
    # Shift by one: the first digitize bin would imply negative RASA.
    bins = np.digitize(atom_array.rasa, edges, right=False) - 1
    # Route excluded atoms to an additional, unused bin.
    bins[~included] = n_bins
    return bins
57
+
58
+
59
class SetZeroOccOnDeltaRASA(Transform):
    """Zero out occupancy for sidechains whose RASA grew significantly.

    Recomputes atom-wise RASA (used to measure whether RASA changed during
    cropping) and, for tokens that have become significantly more exposed,
    sets the occupancy of their sidechain atoms to zero.
    """

    # Needs the initial "rasa" annotation to compare against...
    requires_previous_transforms = [CalculateRASA]
    # ...and must run before transforms that alter atom names / features.
    incompatible_previous_transforms = [
        "PadWithVirtualAtoms",  # must have the same atom names
        "CreateDesignReferenceFeatures",
        "AggregateFeaturesLikeAF3WithoutMSA",
    ]

    def __init__(
        self,
        probe_radius: float = 1.4,
        atom_radii: str | np.ndarray = "ProtOr",
        point_number: int = 100,
    ):
        self.probe_radius = probe_radius
        self.atom_radii = atom_radii
        self.point_number = point_number

    def check_input(self, data):
        assert "rasa" in data["atom_array"].get_annotation_categories()

    def forward(self, data):
        atom_array = data["atom_array"]

        previous_rasa = atom_array.rasa
        current_rasa = calculate_atomwise_rasa(
            atom_array, self.probe_radius, self.atom_radii, self.point_number
        )

        # Compare within the [0, 0.2] band; values above 0.2 are already
        # considered exposed, so further increases there are ignored.
        delta = np.clip(current_rasa, a_min=0, a_max=0.2) - np.clip(
            previous_rasa, a_min=0, a_max=0.2
        )
        became_exposed = np.nan_to_num(delta) > 0.075

        # A token counts as exposed if any of its atoms became exposed.
        token_became_exposed = apply_and_spread_token_wise(
            atom_array,
            became_exposed,
            function=np.any,
        )

        # Backbone atoms keep their occupancy; only residue sidechains are zeroed.
        is_backbone = np.isin(atom_array.atom_name, ["N", "CA", "C", "O"])
        sidechain_mask = ~is_backbone & atom_array.is_residue

        atom_array.occupancy[token_became_exposed & sidechain_mask] = 0.0

        data["atom_array"] = atom_array

        return data