rc-foundry 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
mpnn/transforms/feature_aggregation/user_settings.py
@@ -0,0 +1,347 @@
+ """
+ This module contains the FeaturizeUserSettings transform that sets
+ mode-specific and common user features required by MPNN models.
+ """
+
+ from typing import Any
+
+ import numpy as np
+ from atomworks.io.utils.atom_array_plus import AtomArrayPlus
+ from atomworks.io.utils.selection import get_annotation
+ from atomworks.ml.transforms._checks import (
+     check_atom_array_annotation,
+ )
+ from atomworks.ml.transforms.base import Transform
+ from atomworks.ml.utils.token import (
+     get_token_starts,
+     spread_token_wise,
+ )
+
+
+ class FeaturizeUserSettings(Transform):
+     """
+     Transform for featurizing user settings into MPNN model inputs.
+     """
+
+     def __init__(
+         self,
+         is_inference: bool = False,
+         minimal_return: bool = False,
+         train_structure_noise_default: float = 0.1,
+     ):
+         """
+         Initialize the FeaturizeUserSettings transform.
+
+         Args:
+             is_inference (bool): Whether this is inference mode. Defaults to
+                 False (training mode).
+             minimal_return (bool): Whether to return minimal intermediate
+                 features. Defaults to False.
+             train_structure_noise_default (float): Default standard
+                 deviation of Gaussian noise to add to atomic coordinates during
+                 training for data augmentation. Defaults to 0.1.
+         """
+         self.is_inference = is_inference
+         self.minimal_return = minimal_return
+         self.train_structure_noise_default = train_structure_noise_default
+
+     def check_input(self, data: dict[str, Any]) -> None:
+         """Check that the atomize annotation exists in the data."""
+         check_atom_array_annotation(data, ["atomize"])
+
+         # Check that the scalar user settings have the correct types.
+         if data.get("structure_noise", None) is not None:
+             if not isinstance(data["structure_noise"], (float, int)):
+                 raise TypeError("structure_noise must be a float or int")
+
+         if data.get("decode_type", None) is not None:
+             if not isinstance(data["decode_type"], str):
+                 raise TypeError("decode_type must be a string")
+
+         if data.get("causality_pattern", None) is not None:
+             if not isinstance(data["causality_pattern"], str):
+                 raise TypeError("causality_pattern must be a string")
+
+         if (
+             data.get("initialize_sequence_embedding_with_ground_truth", None)
+             is not None
+         ):
+             if not isinstance(
+                 data["initialize_sequence_embedding_with_ground_truth"], bool
+             ):
+                 raise TypeError(
+                     "initialize_sequence_embedding_with_ground_truth must be a bool"
+                 )
+
+         if data.get("atomize_side_chains", None) is not None:
+             if not isinstance(data["atomize_side_chains"], bool):
+                 raise TypeError("atomize_side_chains must be a bool")
+
+         if data.get("repeat_sample_num", None) is not None:
+             if not isinstance(data["repeat_sample_num"], int):
+                 raise TypeError("repeat_sample_num must be an int")
+
+         if data.get("features_to_return", None) is not None:
+             if not isinstance(data["features_to_return"], dict):
+                 raise TypeError("features_to_return must be a dict")
+             for key, value in data["features_to_return"].items():
+                 if not isinstance(key, str):
+                     raise TypeError("features_to_return keys must be strings")
+                 if not isinstance(value, list):
+                     raise TypeError("features_to_return values must be lists")
+
+         # Check that the array-wide user settings are consistent across all
+         # atoms in each token.
+         atom_array = data["atom_array"]
+         token_starts = get_token_starts(atom_array)
+         token_level_array = atom_array[token_starts]
+         keys_to_check = [
+             "mpnn_designed_residue_mask",
+             "mpnn_temperature",
+             "mpnn_symmetry_equivalence_group",
+             "mpnn_symmetry_weight",
+             "mpnn_bias",
+         ]
+         for key in keys_to_check:
+             atom_values = get_annotation(atom_array, key)
+             if atom_values is not None:
+                 token_values = get_annotation(token_level_array, key)
+                 if not np.all(
+                     atom_values == spread_token_wise(atom_array, token_values)
+                 ):
+                     raise ValueError(
+                         f"All atoms in each token must have the same value for {key}"
+                     )
+
+         # Check pair keys to ensure that token-level pairs are unique.
+         pair_keys_to_check = [
+             "mpnn_pair_bias",
+         ]
+         token_idx = spread_token_wise(atom_array, np.arange(len(token_level_array)))
+         # Only validate 2D annotations if atom_array supports them.
+         if isinstance(atom_array, AtomArrayPlus):
+             annotations_2d = atom_array.get_annotation_2d_categories()
+             for key in pair_keys_to_check:
+                 if key in annotations_2d:
+                     annotation = atom_array.get_annotation_2d(key)
+                     pairs = annotation.pairs
+                     seen_token_pairs = set()
+                     for i, j in pairs:
+                         token_pair = (token_idx[i], token_idx[j])
+                         if token_pair in seen_token_pairs:
+                             raise ValueError(
+                                 f"Token-level pairs must be unique for {key}, "
+                                 "i.e. each token pair should be represented by "
+                                 "only one atom pair across the tokens."
+                             )
+                         seen_token_pairs.add(token_pair)
+
+     def forward(self, data: dict[str, Any]) -> dict[str, Any]:
+         """Apply user settings to the input features."""
+         # +-------- Scalar User Settings --------+
+         # structure_noise (float): the standard deviation of the Gaussian
+         # noise to add to the input coordinates, in Angstroms.
+         structure_noise = data.get("structure_noise", None)
+         if structure_noise is None:
+             structure_noise = (
+                 0.0 if self.is_inference else self.train_structure_noise_default
+             )
+
+         # decode_type (str): the type of decoding to use.
+         #   - "teacher_forcing": Use teacher forcing for the decoder, where
+         #     the decoder attends to the ground-truth sequence S for all
+         #     previously decoded residues.
+         #   - "auto_regressive": Use auto-regressive decoding, where the
+         #     decoder attends to the sequence and decoder representation of
+         #     residues that have already been decoded (using the predicted
+         #     sequence).
+         decode_type = data.get("decode_type", None)
+         if decode_type is None:
+             decode_type = "auto_regressive" if self.is_inference else "teacher_forcing"
+
+         # causality_pattern (str): The pattern of causality to use for the
+         # decoder. For all causality patterns, the decoding order is
+         # randomized.
+         #   - "auto_regressive": Residues can attend to the sequence and
+         #     decoder representation of residues that have already been
+         #     decoded (NOTE: as mentioned above, the order is randomized).
+         #   - "unconditional": Residues cannot attend to the sequence or
+         #     decoder representation of any other residues.
+         #   - "conditional": Residues can attend to the sequence and decoder
+         #     representation of all other residues.
+         #   - "conditional_minus_self": Residues can attend to the sequence
+         #     and decoder representation of all other residues, except for
+         #     themselves (as destination nodes).
+         causality_pattern = data.get("causality_pattern", None)
+         if causality_pattern is None:
+             causality_pattern = "auto_regressive"
+
+         # initialize_sequence_embedding_with_ground_truth (bool):
+         #   - True: Initialize the sequence embedding with the ground-truth
+         #     sequence S. If doing auto-regressive decoding, also initialize
+         #     S_sampled with the ground-truth sequence S, which should only
+         #     affect the application of pair bias.
+         #   - False: Initialize the sequence embedding with zeros. If doing
+         #     auto-regressive decoding, initialize S_sampled with unknown
+         #     residues.
+         initialize_sequence_embedding_with_ground_truth = data.get(
+             "initialize_sequence_embedding_with_ground_truth", None
+         )
+         if initialize_sequence_embedding_with_ground_truth is None:
+             initialize_sequence_embedding_with_ground_truth = (
+                 False if self.is_inference else True
+             )
+
+         # atomize_side_chains (bool): Whether to atomize side chains of fixed
+         # residues.
+         atomize_side_chains = data.get("atomize_side_chains", None)
+         if atomize_side_chains is None:
+             if data["model_type"] == "ligand_mpnn":
+                 atomize_side_chains = False if self.is_inference else True
+             else:
+                 atomize_side_chains = False
+
+         # repeat_sample_num (int, optional): Number of times to repeat the
+         # samples along the batch dimension. If None, no repetition is
+         # performed. If greater than 1, the samples are repeated along the
+         # batch dimension; in that case B must be 1, since repeating samples
+         # is not supported when more than one sample is provided in the batch.
+         # NOTE: default is None, so no conditional is needed.
+         repeat_sample_num = data.get("repeat_sample_num", None)
+
+         # features_to_return (dict, optional): dictionary determining which
+         # features to return from the model. If None, return all features
+         # (including modified input features, graph features, encoder
+         # features, and decoder features). Otherwise, expects a dictionary
+         # with the following key-value pairs:
+         #   - "input_features": list - the input features to return.
+         #   - "graph_features": list - the graph features to return.
+         #   - "encoder_features": list - the encoder features to return.
+         #   - "decoder_features": list - the decoder features to return.
+         features_to_return = data.get("features_to_return", None)
+         if features_to_return is None:
+             if self.minimal_return:
+                 features_to_return = {
+                     "input_features": [
+                         "mask_for_loss",
+                     ],
+                     "decoder_features": ["log_probs", "S_sampled", "S_argmax"],
+                 }
+
+         # Save the scalar settings.
+         data["input_features"].update(
+             {
+                 "structure_noise": structure_noise,
+                 "decode_type": decode_type,
+                 "causality_pattern": causality_pattern,
+                 "initialize_sequence_embedding_with_ground_truth": initialize_sequence_embedding_with_ground_truth,
+                 "atomize_side_chains": atomize_side_chains,
+                 "repeat_sample_num": repeat_sample_num,
+                 "features_to_return": features_to_return,
+             }
+         )
+
+         # +-------- Array-Wide User Settings --------+
+         # Extract atom array.
+         atom_array = data["atom_array"]
+
+         # Subset to non-atomized atoms.
+         non_atomized_array = atom_array[~atom_array.atomize]
+
+         # Get token starts for non-atomized tokens.
+         non_atomized_token_starts = get_token_starts(non_atomized_array)
+         non_atomized_token_level = non_atomized_array[non_atomized_token_starts]
+
+         # Project token indices for non-atomized tokens.
+         non_atomized_token_idx = spread_token_wise(
+             non_atomized_array, np.arange(len(non_atomized_token_level))
+         )
+
+         if get_annotation(non_atomized_array, "mpnn_designed_residue_mask") is not None:
+             designed_residue_mask = (
+                 non_atomized_token_level.mpnn_designed_residue_mask.astype(np.bool_)
+             )
+         else:
+             designed_residue_mask = None
+
+         if get_annotation(non_atomized_array, "mpnn_temperature") is not None:
+             temperature = non_atomized_token_level.mpnn_temperature.astype(np.float32)
+         else:
+             if self.is_inference:
+                 temperature = 0.1 * np.ones(
+                     len(non_atomized_token_level), dtype=np.float32
+                 )
+             else:
+                 temperature = None
+
+         if (
+             get_annotation(non_atomized_array, "mpnn_symmetry_equivalence_group")
+             is not None
+         ):
+             symmetry_equivalence_group = (
+                 non_atomized_token_level.mpnn_symmetry_equivalence_group.astype(
+                     np.int32
+                 )
+             )
+         else:
+             symmetry_equivalence_group = None
+
+         if get_annotation(non_atomized_array, "mpnn_symmetry_weight") is not None:
+             symmetry_weight = non_atomized_token_level.mpnn_symmetry_weight.astype(
+                 np.float32
+             )
+         else:
+             symmetry_weight = None
+
+         if get_annotation(non_atomized_array, "mpnn_bias") is not None:
+             bias = non_atomized_token_level.mpnn_bias.astype(np.float32)
+         else:
+             bias = None
+
+         if (
+             isinstance(non_atomized_array, AtomArrayPlus)
+             and "mpnn_pair_bias" in non_atomized_array.get_annotation_2d_categories()
+         ):
+             pair_bias_sparse = non_atomized_array.get_annotation_2d("mpnn_pair_bias")
+             pair_bias = np.zeros(
+                 (
+                     len(non_atomized_token_level),
+                     pair_bias_sparse.values.shape[1],
+                     len(non_atomized_token_level),
+                     pair_bias_sparse.values.shape[2],
+                 ),
+                 dtype=np.float32,
+             )
+             for idx in range(pair_bias_sparse.values.shape[0]):
+                 i, j, pair_bias_ij = pair_bias_sparse[idx]
+                 token_idx_i = non_atomized_token_idx[i]
+                 token_idx_j = non_atomized_token_idx[j]
+                 pair_bias[token_idx_i, :, token_idx_j, :] = pair_bias_ij
+         else:
+             pair_bias = None
+
+         # Save the array-wide settings.
+         data["input_features"].update(
+             {
+                 "designed_residue_mask": designed_residue_mask,
+                 "temperature": temperature,
+                 "symmetry_equivalence_group": symmetry_equivalence_group,
+                 "symmetry_weight": symmetry_weight,
+                 "bias": bias,
+                 "pair_bias": pair_bias,
+             }
+         )
+
+         return data
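For orientation, the following is a minimal usage sketch of FeaturizeUserSettings; it is not part of the released wheel. It assumes the atomworks pipeline conventions visible above: `data` carries an `atom_array` with an `atomize` annotation, a `model_type` key, and a pre-initialized `input_features` dict, with `atom_array` itself produced by earlier pipeline steps.

# Hypothetical usage sketch (not in the package); assumes an AtomArray
# named `atom_array` with an "atomize" annotation from earlier transforms.
from mpnn.transforms.feature_aggregation.user_settings import FeaturizeUserSettings

transform = FeaturizeUserSettings(is_inference=True)
data = {
    "atom_array": atom_array,
    "model_type": "ligand_mpnn",   # drives the atomize_side_chains default
    "input_features": {},          # forward() updates this dict in place
    "decode_type": "auto_regressive",
}
transform.check_input(data)        # raises TypeError/ValueError on bad settings
data = transform.forward(data)
# At inference with no per-token mpnn_temperature annotation, the transform
# fills in a uniform temperature of 0.1 for every non-atomized token, and
# structure_noise defaults to 0.0.
print(sorted(data["input_features"]))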
mpnn/transforms/polymer_ligand_interface.py
@@ -0,0 +1,164 @@
+ """
+ Utilities for computing polymer-ligand interface atoms.
+
+ This module provides a transform to identify and annotate polymer atoms that
+ are at the interface with ligand molecules, defined as atoms within a specified
+ distance threshold.
+ """
+
+ from typing import Any
+
+ import numpy as np
+ from atomworks.ml.transforms._checks import check_atom_array_annotation
+ from atomworks.ml.transforms.base import Transform
+ from biotite.structure import AtomArray, CellList
+
+
+ class ComputePolymerLigandInterface(Transform):
+     """
+     Compute polymer and ligand atoms at the polymer-ligand interface and
+     annotate the atom array with interface labels.
+
+     An interface atom is defined as any polymer atom that is within the
+     distance_threshold of any ligand atom, or vice versa.
+
+     Args:
+         distance_threshold (float): Maximum distance in Angstroms for
+             considering atoms to be at the interface.
+     """
+
+     def __init__(self, distance_threshold: float):
+         self.distance_threshold = distance_threshold
+
+     def check_input(self, data: dict[str, Any]) -> None:
+         """Check that required annotations are present."""
+         check_atom_array_annotation(
+             {"atom_array": data["atom_array"]}, required=["element", "atomize"]
+         )
+
+     def forward(self, data: dict[str, Any]) -> dict[str, Any]:
+         """Compute the polymer-ligand interface and update the atom array."""
+         atom_array = data["atom_array"]
+
+         # Create a copy to avoid modifying the original.
+         result_array = atom_array.copy()
+
+         # Identify polymer and ligand atoms.
+         polymer_mask, ligand_mask = self._identify_polymer_and_ligand_atoms(
+             result_array
+         )
+
+         if not np.any(polymer_mask) or not np.any(ligand_mask):
+             # If no polymer or ligand atoms are found, return empty annotations.
+             result_array.set_annotation(
+                 "at_polymer_ligand_interface",
+                 np.zeros(result_array.array_length(), dtype=bool),
+             )
+         else:
+             # Extract coordinates for the interface calculation.
+             polymer_atoms = result_array[polymer_mask]
+             ligand_atoms = result_array[ligand_mask]
+
+             # Compute interface atoms using an efficient spatial search.
+             (polymer_interface_indices, ligand_interface_indices) = (
+                 self._compute_interface_atoms(
+                     polymer_atoms,
+                     ligand_atoms,
+                     polymer_mask,
+                     ligand_mask,
+                     self.distance_threshold,
+                 )
+             )
+
+             # Annotate the atom array with interface information.
+             result_array = self._annotate_interface_results(
+                 result_array, polymer_interface_indices, ligand_interface_indices
+             )
+
+         data["atom_array"] = result_array
+         return data
+
+     def _identify_polymer_and_ligand_atoms(
+         self, atom_array: AtomArray
+     ) -> tuple[np.ndarray, np.ndarray]:
+         """Identify polymer and ligand atoms in the atom array."""
+         # Exclude atoms with invalid (NaN) coordinates.
+         has_valid_coords = (~np.isnan(atom_array.coord)).any(axis=1)
+
+         ligand_mask = atom_array.atomize & has_valid_coords
+         polymer_mask = ~atom_array.atomize & has_valid_coords
+
+         return polymer_mask, ligand_mask
+
+     def _compute_interface_atoms(
+         self,
+         polymer_atoms: AtomArray,
+         ligand_atoms: AtomArray,
+         polymer_mask: np.ndarray,
+         ligand_mask: np.ndarray,
+         distance_threshold: float,
+     ) -> tuple[np.ndarray, np.ndarray]:
+         """
+         Compute interface atoms using spatial data structures.
+
+         Returns:
+             Tuple containing:
+                 - polymer_indices: Global indices of polymer atoms at the interface
+                 - ligand_indices: Global indices of ligand atoms at the interface
+         """
+         # Build a CellList for the ligand atoms.
+         ligand_cell_list = CellList(ligand_atoms, cell_size=distance_threshold)
+
+         # Find polymer atoms within the threshold of any ligand atom.
+         polymer_at_interface_mask = ligand_cell_list.get_atoms(
+             polymer_atoms.coord, distance_threshold, as_mask=True
+         )
+         polymer_interface_local_indices = np.where(
+             np.any(polymer_at_interface_mask, axis=1)
+         )[0]
+
+         # Convert local indices to global indices.
+         global_polymer_indices = np.where(polymer_mask)[0]
+         polymer_interface_indices = global_polymer_indices[
+             polymer_interface_local_indices
+         ]
+
+         # Build a CellList for the polymer atoms.
+         polymer_cell_list = CellList(polymer_atoms, cell_size=distance_threshold)
+
+         # Find ligand atoms within the threshold of any polymer atom.
+         ligand_at_interface_mask = polymer_cell_list.get_atoms(
+             ligand_atoms.coord, distance_threshold, as_mask=True
+         )
+         ligand_interface_local_indices = np.where(
+             np.any(ligand_at_interface_mask, axis=1)
+         )[0]
+
+         # Convert local indices to global indices.
+         global_ligand_indices = np.where(ligand_mask)[0]
+         ligand_interface_indices = global_ligand_indices[ligand_interface_local_indices]
+
+         return (polymer_interface_indices, ligand_interface_indices)
+
+     def _annotate_interface_results(
+         self,
+         atom_array: AtomArray,
+         polymer_interface_indices: np.ndarray,
+         ligand_interface_indices: np.ndarray,
+     ) -> AtomArray:
+         """Annotate the atom array with interface calculation results."""
+         n_atoms = atom_array.array_length()
+
+         # Initialize the interface annotation.
+         at_polymer_ligand_interface = np.zeros(n_atoms, dtype=bool)
+
+         # Mark interface atoms.
+         at_polymer_ligand_interface[polymer_interface_indices] = True
+         at_polymer_ligand_interface[ligand_interface_indices] = True
+
+         # Add the annotation to the atom array.
+         atom_array.set_annotation(
+             "at_polymer_ligand_interface", at_polymer_ligand_interface
+         )
+         return atom_array
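Similarly, a minimal sketch of invoking ComputePolymerLigandInterface, again not part of the released wheel. The 5.0 Angstrom cutoff is an illustrative value, and `atom_array` is assumed to carry the `element` and `atomize` annotations that check_input requires.

# Hypothetical usage sketch (not in the package).
from mpnn.transforms.polymer_ligand_interface import ComputePolymerLigandInterface

transform = ComputePolymerLigandInterface(distance_threshold=5.0)  # 5.0 A: illustrative
transform.check_input({"atom_array": atom_array})
data = transform.forward({"atom_array": atom_array})

# The transform adds a boolean per-atom annotation: True marks polymer atoms
# within the cutoff of any ligand atom and vice versa (atoms with NaN
# coordinates are excluded before the CellList search).
interface_mask = data["atom_array"].at_polymer_ligand_interface
print(f"{interface_mask.sum()} atoms at the polymer-ligand interface")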