rc-foundry 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180) hide show
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
@@ -0,0 +1,415 @@
1
+ import random
2
+ import re
3
+ from typing import List
4
+
5
+ import numpy as np
6
+ from atomworks.ml.encoding_definitions import AF3SequenceEncoding
7
+ from biotite.structure import AtomArray
8
+ from rfd3.constants import (
9
+ TIP_BY_RESTYPE,
10
+ )
11
+
12
+ from foundry.common import exists
13
+ from foundry.utils.ddp import RankedLogger
14
+
15
+ global_logger = RankedLogger(__name__, rank_zero_only=False)
16
+ sequence_encoding = AF3SequenceEncoding()
17
+ _aa_like_res_names = sequence_encoding.all_res_names[sequence_encoding.is_aa_like]
18
+
19
+
20
+ #################################################################################
21
+ # Component / contig parsing
22
+ #################################################################################
23
+
24
+
25
+ class ComponentValidationError(ValueError):
26
+ """Raised when contig/component inputs cannot be parsed or validated."""
27
+
28
+ def __init__(
29
+ self,
30
+ message: str,
31
+ *,
32
+ component: str | None = None,
33
+ details: dict | None = None,
34
+ ):
35
+ self.component = component
36
+ self.details = details or {}
37
+ prefix = f"[component={component}] " if component else ""
38
+ suffix = f" Details: {self.details}" if self.details else ""
39
+ super().__init__(f"{prefix}{message}{suffix}")
40
+
41
+
42
class ComponentStr(str):
    """String naming a single component, e.g. "A1" or "B12". Previously named `contig_string`."""

    def split_component(v):
        """Return the result of `split_contig` applied to this string (`v` is the instance)."""
        parts = split_contig(v)
        return parts
47
+
48
+
49
def split_contig(x):
    """
    Split a single-residue contig such as 'A20' into its chain and residue index.

    The first character is taken as the chain ID and everything after it must
    parse as a non-negative integer.

    Args:
        x: Contig string in 'ChainIDResID' form (e.g. 'A20').

    Returns:
        Two-element list ``[chain, idx]`` with ``chain`` a str and ``idx`` an int.

    Raises:
        ComponentValidationError: if ``x`` cannot be split/parsed, or if the
            residue index is negative.
    """
    try:
        chain = str(x[0])
        idx = int(x[1:])
    except Exception as e:
        raise ComponentValidationError(
            f"Invalid contig format: '{x}'. Expected format is 'ChainIDResID' (e.g. 'A20').",
            component=str(x),
        ) from e

    # Checked outside the try block so this specific error is not caught by the
    # broad handler above and re-reported as a generic format error.
    if idx < 0:
        raise ComponentValidationError(
            "Residue index must be a non-negative integer.", component=str(x)
        )
    return [chain, idx]
64
+
65
+
66
def extract_pn_unit_info(contig):
    """
    Break a substring like A20-21 or A20 into its terms: A, 20, 21.

    The pattern is matched at the start of `contig` only (one chain letter,
    digits, and an optional '-digits' tail); text after the match is not
    inspected. A single index such as A20 yields start == end.

    Returns:
        Tuple (pn_unit_id, start, end) with start and end as ints.

    Raises:
        ComponentValidationError: if `contig` does not begin with the pattern.
    """
    found = re.match(r"([A-Za-z])(\d+)(?:-(\d+))?", contig)
    if found is None:
        raise ComponentValidationError(
            "Invalid contig format. Expected 'ChainIDStart-Stop' or 'ChainIDIdx'.",
            component=contig,
        )

    chain_letter, first, last = found.groups()
    start = int(first)
    # No explicit end given -> the segment is the single residue `start`.
    end = int(last) if last else start
    return chain_letter, start, end
83
+
84
+
85
def get_design_pattern_with_constraints(contig, length=None):
    """
    Convert the contig string to separate modules, sampling the free-segment lengths.

    e.g. '1-5,A20-21,1-5,A25-25,1-5,A30-30,/0,1-5' with length = 10-10 may be
    converted to [2, A20, A21, 2, A25, 3, A30, /0, 3].
    Integers represent number of free residues to put there.

    Args:
        contig: Comma-separated parts. A part containing a letter is a fixed
            segment ('A20' or 'A20-21'), '/0' is passed through unchanged, and
            any other part is a free segment written as 'N' or 'MIN-MAX'.
        length: Optional total length, fixed residues included, written as
            'N' or 'MIN-MAX' (an int is also accepted). When None the total
            is only bounded by a large sentinel (9999).

    Returns:
        List in contig order: one 'ChainIDResID' string per fixed residue, one
        randomly sampled int per free segment, and '/0' entries.

    Raises:
        ComponentValidationError: if a fixed part cannot be parsed or no
            segment length satisfies the given ranges/length.
        ValueError: if a free segment or `length` is not made of integers.
    """
    fixed, free, passthrough = "fixed", "free", "passthrough"

    # One (kind, payload) entry per contig part, kept in input order.
    layout = []
    variable_ranges = []
    num_motif_residues = 0

    for part in contig.split(","):
        if any(c.isalpha() for c in part):  # parts containing letters are fixed
            pn_unit_id, pn_unit_start, pn_unit_end = extract_pn_unit_info(part)
            layout.append((fixed, (pn_unit_id, pn_unit_start, pn_unit_end)))
            num_motif_residues += pn_unit_end - pn_unit_start + 1
        elif part == "/0":
            layout.append((passthrough, part))
        else:
            if "-" in part:
                start, end = map(int, part.split("-"))
            else:
                start = end = int(part)
            layout.append((free, len(variable_ranges)))
            variable_ranges.append((start, end))

    if length is None:
        remaining_length_min, remaining_length_max = 0, 9999
    else:
        length = str(length)  # accept an int as well as 'N' / 'MIN-MAX'
        if "-" in length:
            remaining_length_min, remaining_length_max = map(int, length.split("-"))
        else:
            remaining_length_min = remaining_length_max = int(length)
        # The requested total includes the fixed residues; keep only the free budget.
        remaining_length_min -= num_motif_residues
        remaining_length_max -= num_motif_residues

    # Running totals of the minima / maxima of the segments not yet sampled.
    later_min_total = sum(lo for lo, _ in variable_ranges)
    later_max_total = sum(hi for _, hi in variable_ranges)

    num_free_atoms = []
    for min_value, max_value in variable_ranges:
        later_min_total -= min_value
        later_max_total -= max_value

        # Bounds that still let the remaining segments reach the total length.
        valid_min = max(min_value, remaining_length_min - later_max_total)
        valid_max = min(max_value, remaining_length_max - later_min_total)

        # Raised for every infeasible case (also when no length was given), so
        # callers get a descriptive error instead of randint's "empty range".
        if valid_min > valid_max:
            raise ComponentValidationError(
                "No valid selections possible with the given constraints.",
                details={"contig": contig, "length": length},
            )

        selected_value = random.randint(valid_min, valid_max)
        num_free_atoms.append(selected_value)
        remaining_length_min -= selected_value
        remaining_length_max -= selected_value

    atoms_with_motif = []
    for kind, payload in layout:
        if kind == fixed:
            pn_unit_id, pn_unit_start, pn_unit_end = payload
            atoms_with_motif.extend(
                f"{pn_unit_id}{index}"
                for index in range(pn_unit_start, pn_unit_end + 1)
            )
        elif kind == free:
            atoms_with_motif.append(num_free_atoms[payload])
        else:
            atoms_with_motif.append(payload)

    return atoms_with_motif
175
+
176
+
177
def get_motif_components_and_breaks(unindexed_contig, index_all=False):
    """
    Split a contig string into per-residue components plus a parallel list of breaks.

    Breaks let the contig state where the motif is interrupted, so that, say,
    residues aren't glued together by the model. Used for parsing unindexed inputs.

    e.g.:
        contig="A14,A15,A16" -> components=[A14, A15, A16] breaks=[True, True, True]
        contig="A14-15,A16" -> components=[A14, A15, A16] breaks=[True, False, True]

    args:
        unindexed_contig: Contig string for unindexed tokens, see above for example on how
            positional encodings between contigs can be selectively leaked
        index_all: No breaks are used, allows for full indexing of concatenated tokens
            Can use cleanup if this is the desired way to provide motif tokens.

    returns:
        (components, breaks): two lists of equal length. Parts without letters
        ('/0' or a plain number) are kept as strings and get a break of None;
        the first break is always forced to True before `index_all` is applied.
    """
    components, breaks = [], []

    for part in unindexed_contig.split(","):
        has_letter = any(ch.isalpha() for ch in part)
        if not has_letter:
            # '/0' passes straight through; any other letter-free part must be a
            # plain number, so ranges are rejected here.
            if part != "/0" and "-" in part:
                raise ComponentValidationError(
                    "Partial unindexing without fixed length is not supported.",
                    component=part,
                )
            components.append(part)
            breaks.append(None)
            continue

        # ... Parse possibilities: A11 | A11-12 | A11-11
        pn_unit_id, first, last = extract_pn_unit_info(part)
        # ... Break only at the first residue of each segment
        for res_index in range(first, last + 1):
            components.append(f"{pn_unit_id}{res_index}")
            breaks.append(res_index == first)

    breaks[0] = True  # Decouple unindexed region from global index
    if index_all:
        global_logger.info("Unindexing all residues")
        breaks = [None if flag is None else False for flag in breaks]
    return components, breaks
231
+
232
+
233
+ #################################################################################
234
+ # Mask getters
235
+ #################################################################################
236
+
237
+
238
+ def get_name_mask(
239
+ source_names: np.ndarray, query_names: str, source_resname: str | None = None
240
+ ):
241
+ """
242
+ Args:
243
+ source_names: list of all names to match in current token
244
+ query_string: specifier of names to get:
245
+ "ALL" - All atom names in token are matched
246
+ "BKBN - Only backbone atoms (not CB)
247
+ "TIP" - 2 farthest atoms from the backbone are fixed with any
248
+ additional atoms that automatically constrain geometries
249
+ (e.g. 4 atoms for carboxylates/amides). See `constants.py`.
250
+ Comma-separated string - e.g. "N,CA,C,O,CB" for exact queries
251
+ List of names - e.g. ["N", "CA", "C", "O"] for exact queries
252
+ source_resname: residue name is required when specifying just to grab the names for a "TIP"
253
+
254
+ Raises error if not all exact atom names are found and unique
255
+
256
+ Returns:
257
+ mask of atoms corresponding to token
258
+ """
259
+ if isinstance(query_names, list):
260
+ names = query_names
261
+ elif isinstance(query_names, str):
262
+ if query_names.upper() == "ALL":
263
+ return np.ones(source_names.shape[0], dtype=bool)
264
+ elif query_names.upper() == "BKBN":
265
+ names = ["N", "CA", "C", "O"]
266
+ elif query_names.upper() == "TIP":
267
+ if not exists(source_resname):
268
+ raise ComponentValidationError(
269
+ "TIP selection requires a residue name.",
270
+ component=str(source_resname),
271
+ )
272
+ names = TIP_BY_RESTYPE[source_resname]
273
+ if not exists(names):
274
+ raise ComponentValidationError(
275
+ "Residue does not define TIP atoms; use ALL, BKBN, or explicit names.",
276
+ component=str(source_resname),
277
+ )
278
+ elif query_names == "":
279
+ names = []
280
+ else:
281
+ names = query_names.split(",")
282
+ else:
283
+ raise ComponentValidationError(
284
+ "query_names must be a string or list of strings.",
285
+ details={"got_type": str(type(query_names))},
286
+ )
287
+
288
+ if any(n == "" for n in names):
289
+ raise ComponentValidationError(
290
+ f"Empty atom name found in selection '{query_names}'.",
291
+ component=str(source_resname),
292
+ )
293
+ mask = np.isin(source_names, names)
294
+
295
+ if len(names) == 0:
296
+ return mask
297
+
298
+ if not len(set(names)) == len(names):
299
+ raise ComponentValidationError(
300
+ f"Atom names in '{query_names}' must be unique.",
301
+ details={"duplicates": names},
302
+ )
303
+ if not mask.any():
304
+ raise ComponentValidationError(
305
+ f"Could not find requested atoms '{query_names}' in atom array.",
306
+ details={"source_names": np.asarray(source_names).tolist()},
307
+ )
308
+ if mask.sum() != len(names):
309
+ global_logger.warning(
310
+ "Not all atoms found in atom array. Are you expecting multiple residues/ligands with the same names? "
311
+ + "If not, check your input pdb file. "
312
+ + "Atom array requested to contain names {}. Got: {}. Requested {}".format(
313
+ query_names,
314
+ np.asarray(source_names).tolist(),
315
+ np.asarray(names).tolist(),
316
+ )
317
+ )
318
+ if mask.sum() % len(names) != 0:
319
+ # for the case where source_names are originated from multiple residues with the same names
320
+ # (e.g. two ORO ligands in the input pdb: {ligand: "ORO", fixed_atoms: {ORO:"N3,C2,C4,N1"}})
321
+ raise ComponentValidationError(
322
+ "Number of atoms must be a multiple of the requested names.",
323
+ details={
324
+ "query": query_names,
325
+ "source_names": np.asarray(source_names).tolist(),
326
+ "requested": np.asarray(names).tolist(),
327
+ },
328
+ )
329
+
330
+ return mask
331
+
332
+
333
def fetch_mask_from_idx(contig_str, *, atom_array):
    """
    Select the atoms of one residue addressed as chain + residue index.

    contig_str: e.g. A11
    atom_array: object exposing `chain_id` and `res_id` arrays
    returns:
        mask of atoms within contig (e.g. residue 11 in chain A)
    raises:
        ComponentValidationError if `contig_str` is malformed or no atom matches
    """
    chain, res_id = split_contig(contig_str)
    on_chain = atom_array.chain_id == chain
    at_residue = atom_array.res_id == res_id
    mask = on_chain & at_residue
    if np.any(mask):
        return mask
    raise ComponentValidationError(
        f"Residue {chain}{res_id} not found in atom array.",
        component=f"{chain}{res_id}",
    )
347
+
348
+
349
def fetch_mask_from_name(name, *, atom_array):
    """
    Select every atom whose residue name equals `name`.

    name: LIG_NAME
    atom_array: object exposing a `res_name` array
    returns:
        mask of atoms whose `res_name` matches
    raises:
        ComponentValidationError if nothing matches; its details list the
        residue names present that are not in `_aa_like_res_names`
    """
    mask = atom_array.res_name == name
    if np.any(mask):
        return mask

    # Nothing matched: list the non-amino-acid-like residue names that are
    # present to help the caller spot a typo.
    is_aa_like = np.isin(atom_array.res_name, _aa_like_res_names)
    non_protein_res_names = np.unique(atom_array.res_name[~is_aa_like])
    raise ComponentValidationError(
        "Component not found in input atom array.",
        component=name,
        details={"available_non_protein": non_protein_res_names.tolist()},
    )
366
+
367
+
368
def fetch_mask_from_component(component, *, atom_array):
    """
    Catch-all lookup of a component given either as a residue name or as a contig.

    The residue-name lookup is tried first; only if it raises
    ComponentValidationError is `component` interpreted as a contig.

    component: A11 or LIG_NAME
    returns:
        mask of atoms corresponding to component
    raises:
        ComponentValidationError from the contig lookup when both attempts fail
    """
    try:
        return fetch_mask_from_name(component, atom_array=atom_array)
    except ComponentValidationError:
        return fetch_mask_from_idx(component, atom_array=atom_array)
380
+
381
+
382
def unravel_components(
    v: str, *, atom_array: AtomArray | None = None, allow_multiple_matches: bool = False
) -> List[str]:
    """Safely unravel components from a string input.

    Args:
        v: Either a contig expression (anything containing ',' or '-'), which is
            expanded with `get_design_pattern_with_constraints`, or a single
            component (residue name or 'ChainIDResID') that is looked up in
            `atom_array`.
        atom_array: Structure used to resolve a single component; only accessed
            on the single-component path.
        allow_multiple_matches: If True, a component that matches several
            residues is returned as all of its matches (with a warning) rather
            than raising.

    Returns:
        List of components. For a contig expression this is the expanded
        pattern, which can also hold integer free-segment lengths. For a single
        component it is the canonical 'ChainIDResID' string(s), in order of
        first appearance in `atom_array`.

    Raises:
        ComponentValidationError: if the component cannot be found, or if it
            matches several residues and `allow_multiple_matches` is False.
    """
    if "," in v or "-" in v:
        return list(get_design_pattern_with_constraints(v))

    # Safely canonicalize to single component
    mask = fetch_mask_from_component(v, atom_array=atom_array)
    if not mask.sum() > 0:
        return []

    res_ids, chain_ids = atom_array.res_id[mask], atom_array.chain_id[mask]

    if len(set(zip(chain_ids, res_ids))) != 1:
        if not allow_multiple_matches:
            raise ComponentValidationError(
                f"Component '{v}' maps to multiple residues.",
                component=v,
            )
        global_logger.warning(
            f"Component '{v}' maps to multiple residues. If you are using Symmetry this is OK."
        )
        # dict.fromkeys removes repeats while keeping first-seen order, so the
        # result is reproducible across processes (set order for strings is not).
        return list(dict.fromkeys(f"{c}{r}" for c, r in zip(chain_ids, res_ids)))

    component = f"{chain_ids[0]}{res_ids[0]}"
    global_logger.debug("Canonicalized component string: %s -> %s", v, component)
    return [component]