rc-foundry 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180) hide show
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
@@ -0,0 +1,76 @@
1
+ import numpy as np
2
+ import torch
3
+ from atomworks.ml.transforms.base import Transform
4
+ from rfd3.inference.symmetry.frames import (
5
+ framecoords_to_RTs,
6
+ unpack_vector,
7
+ )
8
+
9
+
10
class AddSymmetryFeats(Transform):
    """
    Add atom-array symmetry features to the data features.

    Copies the per-atom symmetry annotations listed in ``symmetry_features``
    into ``data["feats"]`` and stores a dictionary of symmetry transforms
    (rotation/translation pairs keyed by transform id) under
    ``data["feats"]["sym_transform"]``.

    Arguments:
        symmetry_features: Names of the atom-array annotations to copy into
            the data features.

    Returns:
        data: The data with the atom-array symmetry features added to the
            data features.
    """

    def __init__(
        self,
        # NOTE: immutable tuple default avoids the shared mutable
        # default-argument pitfall; callers may still pass a list.
        symmetry_features=(
            "sym_transform_id",
            "sym_entity_id",
            "is_sym_asu",
        ),
    ):
        self.symmetry_feats = symmetry_features

    def forward(self, data):
        atom_array = data["atom_array"]
        # Get frames from atom_array
        transforms_dict = self.make_transforms_dict(atom_array)
        data["feats"]["sym_transform"] = transforms_dict  # {str(id): tuple (R,T)}
        # Add the requested symmetry features atom-wise
        for feature_name in self.symmetry_feats:
            feature_array = atom_array.get_annotation(feature_name)
            data["feats"][feature_name] = feature_array
        return data

    @staticmethod
    def _unpack_frame_annotation(atom_array, annotation_name):
        """Unpack one vector-packed frame annotation into an (n_atoms, 3) tensor."""
        return torch.tensor(
            [
                np.asarray(unpack_vector(vec)).tolist()
                for vec in atom_array.get_annotation(annotation_name)
            ]
        )

    def make_transforms_dict(self, atom_array):
        """
        Build ``{str(transform_id): (R, T)}`` from the packed frame annotations.

        Consecutive duplicate rows are collapsed with ``torch.unique_consecutive``
        so that atoms sharing a transform contribute a single frame; transform
        id -1 is treated as "no transform" and skipped.
        """
        transforms_dict = {}
        # get decomposed frames from atom array (unpacking the vectorized frames)
        Oris = self._unpack_frame_annotation(atom_array, "sym_transform_Ori")
        Xs = self._unpack_frame_annotation(atom_array, "sym_transform_X")
        Ys = self._unpack_frame_annotation(atom_array, "sym_transform_Y")
        TIDs = torch.from_numpy(atom_array.get_annotation("sym_transform_id"))

        Oris = torch.unique_consecutive(Oris, dim=0)
        Xs = torch.unique_consecutive(Xs, dim=0)
        Ys = torch.unique_consecutive(Ys, dim=0)
        TIDs = torch.unique_consecutive(TIDs, dim=0)
        # the case in which there is only rotation (no translation), Ori = [0,0,0]
        if len(Oris) == 1 and (Oris == 0).all():
            Oris = Oris.repeat(len(Xs), 1)
        Rs, Ts = framecoords_to_RTs(Oris, Xs, Ys)

        for R, T, transform_id in zip(Rs, Ts, TIDs):
            if transform_id.item() == -1:
                continue
            transforms_dict[str(transform_id.item())] = (R, T)
        return transforms_dict
@@ -0,0 +1,552 @@
1
+ """
2
+ Class-based motif masking system
3
+ """
4
+
5
+ import logging
6
+ from abc import ABC, abstractmethod
7
+
8
+ import networkx as nx
9
+ import numpy as np
10
+ from atomworks.ml.utils.token import (
11
+ apply_token_wise,
12
+ get_token_starts,
13
+ spread_token_wise,
14
+ )
15
+ from biotite.structure import AtomArray, get_residue_starts
16
+ from rfd3.transforms.conditioning_utils import (
17
+ random_condition,
18
+ sample_island_tokens,
19
+ sample_subgraph_atoms,
20
+ )
21
+
22
+ nx.from_numpy_matrix = nx.from_numpy_array
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
+ #################################################################################
27
+ # Transform for creating training conditions
28
+ #################################################################################
29
+
30
+
31
class TrainingCondition(ABC):
    """Abstract base for training-time conditioning strategies.

    A subclass decides whether it applies to a given example and, when it
    does, samples which atoms/tokens become conditioning ("motif") input.
    """

    # Optional human-readable identifier; subclasses may override.
    name = None

    def __init__(self, frequency):
        # How often this condition should be applied during training.
        self.frequency = frequency

    @abstractmethod
    def is_valid_for_example(self, data) -> bool:
        """
        Returns true whether this mask can be applied to the data instance

        E.g. only use this transform if data metadata contains key or if data contains type
        """

    @abstractmethod
    def sample(self, data) -> AtomArray:
        """
        Set which atoms should be made into tokens
        """
+ """
54
+
55
+
56
class IslandCondition(TrainingCondition):
    """
    Select "islands" of protein tokens as motif and assign conditioning strategies.
    """

    def __init__(
        self,
        *,
        name,
        frequency,
        island_sampling_kwargs,
        p_diffuse_motif_sidechains,
        p_diffuse_subgraph_atoms,
        subgraph_sampling_kwargs,
        p_fix_motif_coordinates,
        p_fix_motif_sequence,
        p_unindex_motif_tokens,
    ):
        self.name = name
        self.frequency = frequency

        # Token selection
        self.island_sampling_kwargs = island_sampling_kwargs

        # Atom selection
        self.p_diffuse_motif_sidechains = p_diffuse_motif_sidechains
        self.p_include_oxygen_in_backbone_mask = 0.95
        self.p_diffuse_subgraph_atoms = p_diffuse_subgraph_atoms
        self.subgraph_sampling_kwargs = subgraph_sampling_kwargs

        # Additional conditioning selection
        self.p_fix_motif_coordinates = p_fix_motif_coordinates
        self.p_fix_motif_sequence = p_fix_motif_sequence
        self.p_unindex_motif_tokens = p_unindex_motif_tokens

    def is_valid_for_example(self, data) -> bool:
        # Island sampling only makes sense with at least one protein atom.
        return bool(np.any(data["atom_array"].is_protein))

    def sample_motif_tokens(self, atom_array):
        """
        Sample which tokens are considered motif (result is spread atom-wise).
        """
        token_level_array = atom_array[get_token_starts(atom_array)]
        protein_token_mask = token_level_array.is_protein

        # Start with every non-protein token as motif, then sample islands
        # among the protein tokens only.
        is_motif_token = np.asarray(~protein_token_mask, dtype=bool).copy()
        is_motif_token[protein_token_mask] = sample_island_tokens(
            np.sum(protein_token_mask),
            **self.island_sampling_kwargs,
        )

        # TODO: Atoms with covalent bonds should be motif, needs FlagAndReassignCovalentModifications transform prior to this
        # atom_with_coval_bond = token_level_array.covale # (n_atoms, )
        # is_motif_token[atom_with_coval_bond] = True

        return spread_token_wise(atom_array, is_motif_token)

    def sample_motif_atoms(self, atom_array):
        """
        Sample which atoms inside motif tokens remain motif, e.g. allowing a
        residue's sidechain to be diffused while its backbone stays fixed.

        Argument attrs:
            - is_motif_token
        """
        is_motif_atom = np.asarray(atom_array.is_motif_token, dtype=bool).copy()

        # NOTE: the two random_condition draws below are mutually exclusive
        # (elif) — call order must be preserved for RNG reproducibility.
        if random_condition(self.p_diffuse_motif_sidechains):
            backbone = ["N", "C", "CA"]
            if random_condition(self.p_include_oxygen_in_backbone_mask):
                backbone.append("O")
            is_motif_atom &= np.isin(atom_array.atom_name, backbone)
        elif random_condition(self.p_diffuse_subgraph_atoms):
            is_motif_atom = sample_motif_subgraphs(
                atom_array=atom_array,
                **self.subgraph_sampling_kwargs,
            )

        # We also only want resolved atoms to be motif
        return is_motif_atom & (atom_array.occupancy > 0.0)

    def sample(self, data):
        atom_array = data["atom_array"]

        atom_array.set_annotation(
            "is_motif_token", self.sample_motif_tokens(atom_array)
        )
        atom_array.set_annotation("is_motif_atom", self.sample_motif_atoms(atom_array))

        # After selecting the motif, decide which conditioning strategy to apply.
        atom_array = sample_conditioning_strategy(
            atom_array,
            p_fix_motif_sequence=self.p_fix_motif_sequence,
            p_fix_motif_coordinates=self.p_fix_motif_coordinates,
            p_unindex_motif_tokens=self.p_unindex_motif_tokens,
        )

        conditions = data["conditions"]
        atom_array.set_annotation(
            "is_motif_atom_unindexed_motif_breakpoint",
            sample_unindexed_breaks(
                atom_array,
                remove_random_break=conditions["unindex_remove_random_break"],
                insert_random_break=conditions["unindex_insert_random_break"],
                leak_global_index=conditions["unindex_leak_global_index"],
            ),
        )

        return atom_array
174
+
175
+
176
class PPICondition(TrainingCondition):
    """Get condition indicating what is motif and what is to be diffused for protein-protein interaction training."""

    name = "ppi"

    def is_valid_for_example(self, data):
        """Accept only binary protein-protein interfaces and pick a binder side.

        NOTE(review): when the example is valid this method also mutates
        ``data`` (stores ``binder_pn_unit`` and sets the ``is_binder_pn_unit``
        annotation) — downstream ``sample`` relies on that.
        """
        # Extract relevant data
        atom_array = data["atom_array"]
        self.query_pn_unit_iids = data.get("query_pn_unit_iids")

        # Compute pn_units whose atoms are all protein (excludes chimeric ligands)
        protein_pn_unit_iids = []
        for pn_unit_iid in np.unique(atom_array.pn_unit_iid):
            pn_unit_is_protein = np.unique(
                atom_array[atom_array.pn_unit_iid == pn_unit_iid].is_protein
            )
            if all(pn_unit_is_protein):
                protein_pn_unit_iids.append(pn_unit_iid)

        # This mask is intended to operate on binary protein-protein interfaces
        queries = self.query_pn_unit_iids
        if queries is None or len(queries) != 2 or len(np.unique(queries)) != 2:
            return False
        if not all(pn_unit in protein_pn_unit_iids for pn_unit in queries):
            return False

        # Randomly select one of the two query pn_unit_iids to be the binder
        # NOTE: Could also do this based on if only one will work uncropped, but since that
        # strategy will not always be applied, enforcing it here would bias the training data.
        binder_pn_unit = np.random.choice(queries)
        data["binder_pn_unit"] = binder_pn_unit
        atom_array.set_annotation(
            "is_binder_pn_unit", atom_array.pn_unit_iid == binder_pn_unit
        )
        return True

    # TODO: If I want to have multiple possible strategies for motif assignment (e.g. motif scaffolding for the binder)
    # should probably just have this function sample between them with a set of probabilities specified in the config.
    # Anything that makes it this far will have to be a valid PPI example with an assigned binder chain.
    def sample(self, data):
        atom_array = data["atom_array"]

        # Set `is_motif_token`: everything that is not the binder (the target).
        # NOTE: In the future, we may want to diffuse part of the target or fix part of the binder
        is_motif_token = atom_array.pn_unit_iid != data["binder_pn_unit"]
        atom_array.set_annotation("is_motif_token", is_motif_token)

        # We fix the target sequence in binder design.
        atom_array.set_annotation(
            "is_motif_atom_with_fixed_seq", is_motif_token.copy()
        )

        # The PPI mask should apply to all or no atoms of a token.
        atom_array.set_annotation("is_motif_atom", is_motif_token.copy())

        # We fully fix the target atom positions (at least for now).
        atom_array.set_annotation(
            "is_motif_atom_with_fixed_coord", is_motif_token.copy()
        )

        # We want fixed indices for the target: nothing unindexed, no breaks.
        atom_array.set_annotation(
            "is_motif_atom_unindexed", np.zeros_like(is_motif_token)
        )
        atom_array.set_annotation(
            "is_motif_atom_unindexed_motif_breakpoint",
            np.zeros_like(is_motif_token),
        )
        return atom_array
265
+
266
+
267
+ ##############################################################################################
268
+ # Additional conditioning classes
269
+ ##############################################################################################
270
+
271
+
272
class SubtypeCondition(TrainingCondition):
    """
    Selects specific subtypes of atoms (e.g. is_dna) as motif and assigns
    conditioning strategies.
    """

    name = "subtype"

    def __init__(self, frequency: float, subtype: list[str], fix_pos: bool = False):
        self.frequency = frequency
        self.subtype = subtype
        self.fix_pos = fix_pos

    def is_valid_for_example(self, data):
        """
        For subtype conditioning, the example must contain the specified subtype
        """
        counts = [
            data["atom_array"].get_annotation(subtype).sum() for subtype in self.subtype
        ]
        return bool(np.any(counts))

    def sample(self, data):
        atom_array = data["atom_array"]
        n_atoms = len(atom_array)

        # Motif = requested subtypes, restricted to fully-resolved residues.
        is_motif = prune_unresolved_motif(
            atom_array, generate_subtype_mask(atom_array, self.subtype)
        )
        atom_array.set_annotation("is_motif_token", is_motif)
        atom_array.set_annotation("is_motif_atom", is_motif)
        atom_array.set_annotation("is_motif_atom_with_fixed_seq", is_motif)

        # Coordinates are fixed only when requested.
        if self.fix_pos:
            atom_array.set_annotation("is_motif_atom_with_fixed_coord", is_motif)
        else:
            atom_array.set_annotation(
                "is_motif_atom_with_fixed_coord", np.zeros(n_atoms, dtype=bool)
            )

        # Indices always stay fixed: nothing unindexed, no breakpoints.
        atom_array.set_annotation(
            "is_motif_atom_unindexed", np.zeros(n_atoms, dtype=bool)
        )
        atom_array.set_annotation(
            "is_motif_atom_unindexed_motif_breakpoint",
            np.zeros(n_atoms, dtype=bool),
        )

        return atom_array
319
+
320
+
321
################# need mask -> condition refactor
def prune_unresolved_motif(atom_array, mask):
    """
    Prune the mask to only include resolved atoms, at whole-token granularity.

    An atom is resolved when its occupancy is > 0. Any token that contains at
    least one unresolved atom (or an atom excluded by ``mask``) is removed
    from the mask entirely.

    Args:
        atom_array: AtomArray with ``occupancy`` and ``token_id`` annotations.
        mask: Boolean array of shape (n_atoms,).

    Returns:
        Boolean array of shape (n_atoms,): True only for atoms whose token is
        fully resolved and fully selected by ``mask``.
    """
    # Restrict to resolved atoms.
    combined_mask = mask & (atom_array.occupancy > 0.0)

    # Zero out every token that has any excluded/unresolved atom.
    # Vectorized per-token AND (replaces a quadratic per-token Python loop).
    _, token_index = np.unique(atom_array.token_id, return_inverse=True)
    token_all_ok = np.ones(token_index.max() + 1 if len(token_index) else 0, dtype=bool)
    np.logical_and.at(token_all_ok, token_index, combined_mask)
    return token_all_ok[token_index]
343
+
344
+
345
def generate_subtype_mask(atom_array, subtypes):
    """
    Generate a mask for a specific subtype list of atoms.
    E.g. is_protein, is_ligand, is_dna etc.

    Args:
        atom_array: AtomArray carrying boolean subtype annotations.
        subtypes: Annotation names to OR together.

    Returns:
        Boolean array of shape (n_atoms,); an atom is selected when at least
        one requested annotation is truthy for it. An empty ``subtypes`` list
        yields an all-False mask.

    Raises:
        ValueError: If a requested subtype annotation does not exist.
    """
    # Hoist the category lookup out of the loop (it is loop-invariant).
    known_annotations = atom_array.get_annotation_categories()
    all_masks = []
    for subtype in subtypes:
        if subtype not in known_annotations:
            raise ValueError(f"Subtype {subtype} not found in atom array annotations.")
        all_masks.append(atom_array.get_annotation(subtype))
    # np.logical_or.reduce raises on an empty operand list; no subtypes
    # requested simply selects nothing.
    if not all_masks:
        return np.zeros(atom_array.array_length(), dtype=bool)
    # Combine all masks using logical OR
    return np.logical_or.reduce(all_masks)
359
+
360
+
361
+ ##############################################################################################
362
+ # Shared assignment functions
363
+ ##############################################################################################
364
+
365
+
366
def sample_motif_subgraphs(
    atom_array,
    residue_p_seed_furthest_from_o,
    residue_n_bond_expectation,
    hetatom_n_bond_expectation,
    residue_p_fix_all,
    hetatom_p_fix_all,
):
    """
    Returns a boolean mask over atoms, indicating which atoms are part of the sampled motif.
    Sampling is performed per residue, with sidechains optionally excluded based on bond-based neighborhood sampling.

    Handles both protein residues and heteroatoms (e.g., ligands).

    Args:
        atom_array: AtomArray with annotations is_motif_token, is_protein, occupancy, res_id.
        residue_p_seed_furthest_from_o: Seed-selection probability for protein residues.
        residue_n_bond_expectation: Expected bond-neighborhood size for protein residues.
        hetatom_n_bond_expectation: Expected bond-neighborhood size for heteroatom groups.
        residue_p_fix_all: Probability of keeping an entire protein residue.
        hetatom_p_fix_all: Probability of keeping an entire heteroatom group.

    Returns:
        is_motif_atom: np.ndarray of shape (n_atoms,) with True for sampled motif atoms.
    """
    is_motif_token = atom_array.is_motif_token.copy()
    is_motif_atom = is_motif_token.copy()
    starts = get_residue_starts(atom_array, add_exclusive_stop=True)

    for start, end in zip(starts[:-1], starts[1:]):
        # Only residues flagged as motif tokens are (re)sampled.
        if not is_motif_token[start]:
            continue

        # Slice the atoms of the current residue directly; the previous
        # np.isin(idxs, idxs[start:end]) scan was an accidental O(n) pass per
        # residue (O(n^2) overall) producing the same subset.
        atom_array_subset = atom_array[start:end]
        assert atom_array_subset.array_length() > 0

        args = {
            "p_seed_furthest_from_o": residue_p_seed_furthest_from_o,
            "n_bond_expectation": residue_n_bond_expectation,
            "p_fix_all": residue_p_fix_all,
        }
        if not atom_array_subset.is_protein.all():
            # Heteroatom groups (e.g. ligands) use their own parameters and
            # never seed from the backbone oxygen.
            args.update(
                {
                    "p_seed_furthest_from_o": 0.0,
                    "n_bond_expectation": hetatom_n_bond_expectation,
                    "p_fix_all": hetatom_p_fix_all,
                }
            )
        try:
            mask = sample_subgraph_atoms(atom_array_subset, **args)
        except Exception as e:
            # Best-effort fallback: keep the whole residue as motif.
            logger.warning(
                f"Failed to sample subgraph motif atoms for {atom_array_subset.res_name[0]}. Error: {e}"
            )
            mask = np.ones(atom_array_subset.array_length(), dtype=bool)

        is_motif_atom[start:end] = mask

    # We also only want resolved atoms to be motif
    is_motif_atom = is_motif_atom & (atom_array.occupancy > 0.0)

    return is_motif_atom
427
+
428
+
429
def sample_conditioning_strategy(
    atom_array,
    p_fix_motif_sequence,
    p_fix_motif_coordinates,
    p_unindex_motif_tokens,
):
    """Sample and attach the three motif-conditioning annotations.

    Sets ``is_motif_atom_with_fixed_seq``, ``is_motif_atom_with_fixed_coord``
    and ``is_motif_atom_unindexed`` on ``atom_array`` and returns it. The
    three samplers are called in this fixed order (each may consume RNG).
    """
    fixed_seq = sample_is_motif_atom_with_fixed_seq(
        atom_array, p_fix_motif_sequence=p_fix_motif_sequence
    )
    atom_array.set_annotation("is_motif_atom_with_fixed_seq", fixed_seq)

    fixed_coord = sample_fix_motif_coordinates(
        atom_array, p_fix_motif_coordinates=p_fix_motif_coordinates
    )
    atom_array.set_annotation("is_motif_atom_with_fixed_coord", fixed_coord)

    unindexed = sample_unindexed_atoms(
        atom_array, p_unindex_motif_tokens=p_unindex_motif_tokens
    )
    atom_array.set_annotation("is_motif_atom_unindexed", unindexed)

    return atom_array
457
+
458
+
459
def sample_is_motif_atom_with_fixed_seq(atom_array, p_fix_motif_sequence):
    """
    Samples what kind of conditioning to apply to motif tokens.

    With probability ``p_fix_motif_sequence`` the whole motif keeps its
    sequence; non-protein atoms always have their sequence revealed.

    Argument attrs:
        - is_motif_token
    """
    if random_condition(p_fix_motif_sequence):
        fixed_seq = atom_array.is_motif_token.copy()
    else:
        fixed_seq = np.zeros(atom_array.array_length(), dtype=bool)

    # By default reveal sequence for non-protein
    return fixed_seq | ~atom_array.is_protein
474
+
475
+
476
def sample_fix_motif_coordinates(atom_array, p_fix_motif_coordinates):
    """
    Universal function to decide if atoms' coords are fixed in the point cloud for conditioning.

    With probability ``p_fix_motif_coordinates`` all motif atoms keep their
    coordinates; otherwise none do.

    Argument attrs:
        - is_motif_atom
    """
    if not random_condition(p_fix_motif_coordinates):
        return np.zeros(atom_array.array_length(), dtype=bool)
    return atom_array.is_motif_atom.copy()
488
+
489
+
490
def sample_unindexed_atoms(atom_array, p_unindex_motif_tokens):
    """
    Samples which atoms in motif tokens should be flagged for unindexing.

    With probability ``p_unindex_motif_tokens`` all motif atoms are flagged;
    non-residue atoms are never flagged.

    Argument attrs:
        - is_motif_atom
        - is_residue
    """
    if random_condition(p_unindex_motif_tokens):
        unindexed = atom_array.is_motif_atom.copy()
    else:
        unindexed = np.zeros(atom_array.array_length(), dtype=bool)

    # ensure non-residue atoms are not already flagged
    return unindexed & atom_array.is_residue
508
+
509
+
510
def sample_unindexed_breaks(
    atom_array,
    remove_random_break=False,
    insert_random_break=False,
    leak_global_index=False,
):
    """
    Sample token-level breakpoints between unindexed motif regions.

    A "break" marks where one unindexed region ends and another begins, based
    on discontinuities in residue ids among unindexed tokens. Returns an
    atom-wise boolean array (token values spread to their atoms).

    Args:
        atom_array: AtomArray with an ``is_motif_atom_unindexed`` annotation.
        remove_random_break: Randomly merge two discontiguous regions by
            clearing one existing break.
        insert_random_break: Randomly split a region by adding one break.
        leak_global_index: Clear all breaks among unindexed tokens, leaking
            their global indexing.

    Returns:
        np.ndarray of shape (n_atoms,), True at unindexed-region breakpoints.
    """
    # Token is "unindexed" if any of its atoms is flagged unindexed.
    is_unindexed_token = apply_token_wise(
        atom_array,
        atom_array.is_motif_atom_unindexed.copy(),
        function=lambda x: np.any(x),
    )
    starts = get_token_starts(atom_array)
    token_idxs = np.arange(len(starts))
    breaks_all = np.zeros(len(starts), dtype=bool)

    if is_unindexed_token.sum() == 1:
        # A single unindexed token is its own (trivial) region/breakpoint.
        # NOTE(review): this branch ignores leak_global_index — confirm intended.
        breaks_all = is_unindexed_token
    elif np.any(is_unindexed_token):
        # ... Subset to unindexed tokens
        unindexed_token_starts = starts[is_unindexed_token]
        unindexed_token_resid = atom_array[unindexed_token_starts].res_id
        # Break wherever consecutive unindexed tokens are not sequential residues.
        breaks = np.diff(unindexed_token_resid) != 1  # (M-1,)

        # ... Connect discontiguous regions
        if remove_random_break and np.any(breaks):
            break_idx = np.random.choice(np.flatnonzero(breaks), size=1, replace=False)
            breaks[break_idx] = False

        # ... Disconnect contiguous regions
        if insert_random_break:
            break_idx = np.random.choice(np.arange(len(breaks)), size=1, replace=False)
            breaks[break_idx] = True

        # Force a break at the first diff position, then prepend False so the
        # array aligns with the M unindexed tokens (first token never breaks).
        # NOTE(review): net effect is an unconditional break at the second
        # unindexed token — confirm this is the intended alignment.
        breaks[0] = True
        breaks = np.concatenate([np.array([False], dtype=bool), breaks])

        # ... Remove all breaks to leak global indices:
        if leak_global_index:
            # Scalar False broadcasts over the assignment below.
            breaks = False

        breaks_all[token_idxs[is_unindexed_token]] = breaks

    return spread_token_wise(atom_array, breaks_all)