rc-foundry 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
|
@@ -0,0 +1,415 @@
|
|
|
1
|
+
import random
|
|
2
|
+
import re
|
|
3
|
+
from typing import List
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
from atomworks.ml.encoding_definitions import AF3SequenceEncoding
|
|
7
|
+
from biotite.structure import AtomArray
|
|
8
|
+
from rfd3.constants import (
|
|
9
|
+
TIP_BY_RESTYPE,
|
|
10
|
+
)
|
|
11
|
+
|
|
12
|
+
from foundry.common import exists
|
|
13
|
+
from foundry.utils.ddp import RankedLogger
|
|
14
|
+
|
|
15
|
+
# Module-level logger. ``rank_zero_only=False`` is passed, so logging is presumably not
# restricted to the rank-zero process (see RankedLogger in foundry.utils.ddp).
global_logger = RankedLogger(__name__, rank_zero_only=False)

# AF3 residue vocabulary. The amino-acid-like residue names are used when reporting which
# non-protein residue names are available in an atom array.
sequence_encoding = AF3SequenceEncoding()
_aa_like_mask = sequence_encoding.is_aa_like
_aa_like_res_names = sequence_encoding.all_res_names[_aa_like_mask]
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
#################################################################################
|
|
21
|
+
# Component / contig parsing
|
|
22
|
+
#################################################################################
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class ComponentValidationError(ValueError):
|
|
26
|
+
"""Raised when contig/component inputs cannot be parsed or validated."""
|
|
27
|
+
|
|
28
|
+
def __init__(
|
|
29
|
+
self,
|
|
30
|
+
message: str,
|
|
31
|
+
*,
|
|
32
|
+
component: str | None = None,
|
|
33
|
+
details: dict | None = None,
|
|
34
|
+
):
|
|
35
|
+
self.component = component
|
|
36
|
+
self.details = details or {}
|
|
37
|
+
prefix = f"[component={component}] " if component else ""
|
|
38
|
+
suffix = f" Details: {self.details}" if self.details else ""
|
|
39
|
+
super().__init__(f"{prefix}{message}{suffix}")
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class ComponentStr(str):
    """A single component identifier such as ``"A1"`` (residue) or ``"B12"``.

    Formerly called ``contig_string``.
    """

    def split_component(v):
        """Return ``[chain_id, res_id]`` for this component (delegates to ``split_contig``)."""
        # ``v`` is the instance itself (it stands in for the conventional ``self``).
        chain_and_res = split_contig(v)
        return chain_and_res
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def split_contig(x):
    """Split a single-residue component string into its chain ID and residue index.

    Args:
        x: component of the form ``<ChainID><ResID>``, e.g. ``"A20"``. The chain ID is
            the first element only; the remainder must parse with ``int``.

    Returns:
        ``[chain, idx]`` where ``chain`` is ``str(x[0])`` and ``idx`` is a non-negative int.

    Raises:
        ComponentValidationError: if ``x`` cannot be split/parsed (format error), or if
            the parsed residue index is negative.
    """
    try:
        chain = str(x[0])
        idx = int(x[1:])
    except Exception as e:
        # Any failure here (empty input, non-sliceable input, non-integer index) is
        # reported uniformly as a format error.
        raise ComponentValidationError(
            f"Invalid contig format: '{x}'. Expected format is 'ChainIDResID' (e.g. 'A20').",
            component=str(x),
        ) from e

    # Checked outside the ``try`` so this specific message is not swallowed by the
    # generic format error above.
    if idx < 0:
        raise ComponentValidationError(
            "Residue index must be a non-negative integer.", component=str(x)
        )
    return [chain, idx]
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def extract_pn_unit_info(contig):
    """Parse a motif segment such as ``"A20-21"`` or ``"A20"`` into its parts.

    Args:
        contig: segment made of a single-letter chain / pn_unit ID, a start index and an
            optional ``-end`` index. Surrounding whitespace is ignored; anything else
            (e.g. trailing characters) makes the segment invalid.

    Returns:
        Tuple ``(pn_unit_id, start, end)``; ``end == start`` when no range is given.

    Raises:
        ComponentValidationError: if the segment does not match the format exactly, or
            if ``end < start``.
    """
    pattern = r"([A-Za-z])(\d+)(?:-(\d+))?"

    # fullmatch (not match): a prefix match would silently drop trailing text such as
    # "A20-21x" or "A20-21/0".
    match = re.fullmatch(pattern, str(contig).strip())
    if not match:
        raise ComponentValidationError(
            "Invalid contig format. Expected 'ChainIDStart-Stop' or 'ChainIDIdx'.",
            component=contig,
        )

    pn_unit_id = match.group(1)
    start = int(match.group(2))
    end = int(match.group(3)) if match.group(3) else start

    # A reversed range would otherwise expand to no residues (or a negative count).
    if end < start:
        raise ComponentValidationError(
            "Invalid contig range: end index is smaller than start index.",
            component=str(contig),
            details={"start": start, "end": end},
        )
    return pn_unit_id, start, end
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def _parse_int_range(spec):
    """Parse ``"lo-hi"`` or ``"n"`` (a str, or an int such as a config value) into ``(lo, hi)``.

    Raises ValueError if a bound is not an integer or the text contains more than one ``-``.
    """
    text = str(spec)
    if "-" in text:
        lo, hi = map(int, text.split("-"))
    else:
        lo = hi = int(text)
    return lo, hi


def get_design_pattern_with_constraints(contig, length=None):
    """Expand a contig string into motif residues, sampled free-residue counts and breaks.

    The contig is a comma-separated list of parts:
      * motif segments containing a chain letter, e.g. ``"A20-21"`` or ``"A25"``, expanded
        into one ``"<chain><index>"`` string per residue;
      * ``"/0"`` chain breaks, passed through unchanged;
      * free-length segments ``"lo-hi"`` or ``"n"``, replaced by a randomly sampled int.

    e.g. ``'1-5,A20-21,1-5,A25-25,1-5,A30-30,/0,1-5'`` with ``length="10-10"`` may become
    ``[1, 'A20', 'A21', 2, 'A25', 1, 'A30', '/0', 2]`` (4 motif + 6 free residues = 10).
    Integers represent the number of free residues to put at that position.

    Args:
        contig: contig string as above; whitespace around each part is ignored.
        length: optional total-length constraint (motif + free residues) as ``"min-max"``,
            ``"n"`` or an int. When ``None`` the total is only capped at 9999.

    Returns:
        List of motif residue strings, ``"/0"`` markers and int free-residue counts, in
        contig order.

    Raises:
        ComponentValidationError: if no choice of free-segment lengths satisfies the
            constraints (including a free segment whose min exceeds its max), or a motif
            segment is malformed.
        ValueError: if a free segment or ``length`` is not an int or ``"int-int"`` string.
    """
    # One code per comma-separated part, in order, used to re-interleave parts at the end:
    #   1 -> motif segment, 2 -> "/0" chain break, 0 -> free-length segment.
    pos_to_put_motif = []
    fixed_parts = []  # (pn_unit_id, start, end) for each motif segment
    variable_ranges = []  # (min, max) for each free-length segment

    for raw_part in contig.split(","):
        part = raw_part.strip()
        if any(c.isalpha() for c in part):  # parts containing letters are motif segments
            fixed_parts.append(extract_pn_unit_info(part))
            pos_to_put_motif.append(1)
        elif part == "/0":
            pos_to_put_motif.append(2)
        else:
            variable_ranges.append(_parse_int_range(part))
            pos_to_put_motif.append(0)

    # The length constraint covers motif + free residues; turn it into a budget for
    # the free residues only.
    num_motif_residues = sum(end - start + 1 for _, start, end in fixed_parts)
    if length is None:
        length_min, length_max = 0, 9999
    else:
        length_min, length_max = _parse_int_range(length)
    remaining_min = length_min - num_motif_residues
    remaining_max = length_max - num_motif_residues

    num_free_residues = []
    for i, (min_value, max_value) in enumerate(variable_ranges):
        later = variable_ranges[i + 1 :]
        # Restrict this segment so that the later segments can still absorb whatever
        # budget remains (all of them at their max / all of them at their min).
        valid_min = max(min_value, remaining_min - sum(hi for _, hi in later))
        valid_max = min(max_value, remaining_max - sum(lo for lo, _ in later))
        if valid_min > valid_max:
            raise ComponentValidationError(
                "No valid selections possible with the given constraints.",
                details={"segment": f"{min_value}-{max_value}", "length": length},
            )
        selected = random.randint(valid_min, valid_max)
        num_free_residues.append(selected)
        remaining_min -= selected
        remaining_max -= selected

    motifs = iter(fixed_parts)
    free_counts = iter(num_free_residues)
    pattern = []
    for code in pos_to_put_motif:
        if code == 1:
            pn_unit_id, start, end = next(motifs)
            pattern.extend(f"{pn_unit_id}{index}" for index in range(start, end + 1))
        elif code == 0:
            pattern.append(next(free_counts))
        else:
            pattern.append("/0")
    return pattern
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
def get_motif_components_and_breaks(unindexed_contig, index_all=False):
    """Split an unindexed contig into per-residue components and motif break flags.

    Breaks let you state where the motif is discontinuous, so that e.g. residues are not
    glued together by the model. Used for parsing unindexed inputs.

    e.g.:
        contig="A14,A15,A16"  -> components=[A14, A15, A16] breaks=[True, True, True]
        contig="A14-15,A16"   -> components=[A14, A15, A16] breaks=[True, False, True]

    Non-motif parts ("/0" and plain numbers) are kept as components with a ``None`` break.
    The first break is always forced to ``True``.

    Args:
        unindexed_contig: contig string for unindexed tokens; see above for how positional
            encodings between contigs can be selectively leaked.
        index_all: if True, every non-``None`` break becomes ``False`` (no breaks), allowing
            full indexing of concatenated tokens.

    Returns:
        ``(components, breaks)`` lists of equal length.

    Raises:
        ComponentValidationError: for a non-motif part containing ``-`` (a free-length
            range), which is not supported here.
    """
    components = []
    breaks = []

    for part in unindexed_contig.split(","):
        has_letters = any(ch.isalpha() for ch in part)
        if not has_letters:
            if part != "/0" and "-" in part:
                raise ComponentValidationError(
                    "Partial unindexing without fixed length is not supported.",
                    component=part,
                )
            # Chain breaks and fixed free-length tokens carry no break flag.
            components.append(part)
            breaks.append(None)
            continue

        # Motif segment: A11 | A11-12 | A11-11. The first residue of the segment starts a
        # new (broken) stretch; the rest are contiguous with their predecessor.
        pn_unit_id, first, last = extract_pn_unit_info(part)
        residues = [f"{pn_unit_id}{i}" for i in range(first, last + 1)]
        components.extend(residues)
        breaks.extend(pos == 0 for pos in range(len(residues)))

    breaks[0] = True  # Decouple unindexed region from global index
    if index_all:
        global_logger.info("Unindexing all residues")
        breaks = [None if flag is None else False for flag in breaks]
    return components, breaks
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
#################################################################################
|
|
234
|
+
# Mask getters
|
|
235
|
+
#################################################################################
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
def get_name_mask(
|
|
239
|
+
source_names: np.ndarray, query_names: str, source_resname: str | None = None
|
|
240
|
+
):
|
|
241
|
+
"""
|
|
242
|
+
Args:
|
|
243
|
+
source_names: list of all names to match in current token
|
|
244
|
+
query_string: specifier of names to get:
|
|
245
|
+
"ALL" - All atom names in token are matched
|
|
246
|
+
"BKBN - Only backbone atoms (not CB)
|
|
247
|
+
"TIP" - 2 farthest atoms from the backbone are fixed with any
|
|
248
|
+
additional atoms that automatically constrain geometries
|
|
249
|
+
(e.g. 4 atoms for carboxylates/amides). See `constants.py`.
|
|
250
|
+
Comma-separated string - e.g. "N,CA,C,O,CB" for exact queries
|
|
251
|
+
List of names - e.g. ["N", "CA", "C", "O"] for exact queries
|
|
252
|
+
source_resname: residue name is required when specifying just to grab the names for a "TIP"
|
|
253
|
+
|
|
254
|
+
Raises error if not all exact atom names are found and unique
|
|
255
|
+
|
|
256
|
+
Returns:
|
|
257
|
+
mask of atoms corresponding to token
|
|
258
|
+
"""
|
|
259
|
+
if isinstance(query_names, list):
|
|
260
|
+
names = query_names
|
|
261
|
+
elif isinstance(query_names, str):
|
|
262
|
+
if query_names.upper() == "ALL":
|
|
263
|
+
return np.ones(source_names.shape[0], dtype=bool)
|
|
264
|
+
elif query_names.upper() == "BKBN":
|
|
265
|
+
names = ["N", "CA", "C", "O"]
|
|
266
|
+
elif query_names.upper() == "TIP":
|
|
267
|
+
if not exists(source_resname):
|
|
268
|
+
raise ComponentValidationError(
|
|
269
|
+
"TIP selection requires a residue name.",
|
|
270
|
+
component=str(source_resname),
|
|
271
|
+
)
|
|
272
|
+
names = TIP_BY_RESTYPE[source_resname]
|
|
273
|
+
if not exists(names):
|
|
274
|
+
raise ComponentValidationError(
|
|
275
|
+
"Residue does not define TIP atoms; use ALL, BKBN, or explicit names.",
|
|
276
|
+
component=str(source_resname),
|
|
277
|
+
)
|
|
278
|
+
elif query_names == "":
|
|
279
|
+
names = []
|
|
280
|
+
else:
|
|
281
|
+
names = query_names.split(",")
|
|
282
|
+
else:
|
|
283
|
+
raise ComponentValidationError(
|
|
284
|
+
"query_names must be a string or list of strings.",
|
|
285
|
+
details={"got_type": str(type(query_names))},
|
|
286
|
+
)
|
|
287
|
+
|
|
288
|
+
if any(n == "" for n in names):
|
|
289
|
+
raise ComponentValidationError(
|
|
290
|
+
f"Empty atom name found in selection '{query_names}'.",
|
|
291
|
+
component=str(source_resname),
|
|
292
|
+
)
|
|
293
|
+
mask = np.isin(source_names, names)
|
|
294
|
+
|
|
295
|
+
if len(names) == 0:
|
|
296
|
+
return mask
|
|
297
|
+
|
|
298
|
+
if not len(set(names)) == len(names):
|
|
299
|
+
raise ComponentValidationError(
|
|
300
|
+
f"Atom names in '{query_names}' must be unique.",
|
|
301
|
+
details={"duplicates": names},
|
|
302
|
+
)
|
|
303
|
+
if not mask.any():
|
|
304
|
+
raise ComponentValidationError(
|
|
305
|
+
f"Could not find requested atoms '{query_names}' in atom array.",
|
|
306
|
+
details={"source_names": np.asarray(source_names).tolist()},
|
|
307
|
+
)
|
|
308
|
+
if mask.sum() != len(names):
|
|
309
|
+
global_logger.warning(
|
|
310
|
+
"Not all atoms found in atom array. Are you expecting multiple residues/ligands with the same names? "
|
|
311
|
+
+ "If not, check your input pdb file. "
|
|
312
|
+
+ "Atom array requested to contain names {}. Got: {}. Requested {}".format(
|
|
313
|
+
query_names,
|
|
314
|
+
np.asarray(source_names).tolist(),
|
|
315
|
+
np.asarray(names).tolist(),
|
|
316
|
+
)
|
|
317
|
+
)
|
|
318
|
+
if mask.sum() % len(names) != 0:
|
|
319
|
+
# for the case where source_names are originated from multiple residues with the same names
|
|
320
|
+
# (e.g. two ORO ligands in the input pdb: {ligand: "ORO", fixed_atoms: {ORO:"N3,C2,C4,N1"}})
|
|
321
|
+
raise ComponentValidationError(
|
|
322
|
+
"Number of atoms must be a multiple of the requested names.",
|
|
323
|
+
details={
|
|
324
|
+
"query": query_names,
|
|
325
|
+
"source_names": np.asarray(source_names).tolist(),
|
|
326
|
+
"requested": np.asarray(names).tolist(),
|
|
327
|
+
},
|
|
328
|
+
)
|
|
329
|
+
|
|
330
|
+
return mask
|
|
331
|
+
|
|
332
|
+
|
|
333
|
+
def fetch_mask_from_idx(contig_str, *, atom_array):
    """Select all atoms of the residue named by a chain+index contig.

    Args:
        contig_str: component such as ``"A11"`` (residue 11 of chain A).
        atom_array: structure whose ``chain_id`` and ``res_id`` annotations are compared.

    Returns:
        Boolean mask of atoms with that chain ID and residue ID.

    Raises:
        ComponentValidationError: if ``contig_str`` is malformed or no atom matches.
    """
    chain, res_id = split_contig(contig_str)
    in_chain = atom_array.chain_id == chain
    at_residue = atom_array.res_id == res_id
    selection = in_chain & at_residue
    if np.any(selection):
        return selection
    raise ComponentValidationError(
        f"Residue {chain}{res_id} not found in atom array.",
        component=f"{chain}{res_id}",
    )
|
|
347
|
+
|
|
348
|
+
|
|
349
|
+
def fetch_mask_from_name(name, *, atom_array):
    """Select all atoms whose residue name equals ``name`` (e.g. a ligand code).

    Args:
        name: residue name to match against ``atom_array.res_name``.
        atom_array: structure providing the ``res_name`` annotation.

    Returns:
        Boolean mask of matching atoms.

    Raises:
        ComponentValidationError: if nothing matches. The error details list the residue
            names present that are not amino-acid-like, to help pick a valid name.
    """
    res_names = atom_array.res_name
    selection = res_names == name
    if np.any(selection):
        return selection

    is_aa_like = np.isin(res_names, _aa_like_res_names)
    available = np.unique(res_names[~is_aa_like])
    raise ComponentValidationError(
        "Component not found in input atom array.",
        component=name,
        details={"available_non_protein": available.tolist()},
    )
|
|
366
|
+
|
|
367
|
+
|
|
368
|
+
def fetch_mask_from_component(component, *, atom_array):
    """Catch-all lookup of a component by residue name or by chain+index contig.

    Args:
        component: either a residue name such as ``"LIG_NAME"`` or a contig such as ``"A11"``.
        atom_array: structure to select from.

    Returns:
        Boolean mask of the atoms belonging to the component. The residue-name
        interpretation is tried first.

    Raises:
        ComponentValidationError: if the component matches neither interpretation (the
            error comes from the chain+index lookup).
    """
    try:
        return fetch_mask_from_name(component, atom_array=atom_array)
    except ComponentValidationError:
        # Not a residue name present in the structure: interpret it as e.g. "A11".
        return fetch_mask_from_idx(component, atom_array=atom_array)
|
|
380
|
+
|
|
381
|
+
|
|
382
|
+
def unravel_components(
    v: str, *, atom_array: AtomArray = None, allow_multiple_matches: bool = False
) -> List[str | int]:
    """Safely unravel components from a string input.

    Two modes:
      * ``v`` contains ``","`` or ``"-"``: it is treated as a contig and expanded with
        ``get_design_pattern_with_constraints`` (no total-length constraint), so free-length
        segments become randomly sampled ints.
      * otherwise ``v`` is a single component (residue name or chain+index) and is
        canonicalized to ``"<chain_id><res_id>"`` using ``atom_array``.

    Args:
        v: contig or component string.
        atom_array: structure used to resolve single components.
        allow_multiple_matches: if True, a component matching several residues returns all
            of them (with a warning) instead of raising.

    Returns:
        List of component strings (and ints for free-length segments in contig mode). In
        the multi-match case the residues are unique and in first-occurrence (atom) order.

    Raises:
        ComponentValidationError: if the component cannot be resolved, or maps to several
            residues while ``allow_multiple_matches`` is False.
    """
    if "," in v or "-" in v:
        return list(get_design_pattern_with_constraints(v))

    # Safely canonicalize to single component
    mask = fetch_mask_from_component(v, atom_array=atom_array)
    if mask.sum() == 0:
        return []

    res_ids, chain_ids = atom_array.res_id[mask], atom_array.chain_id[mask]
    # assert unique resids for component
    if len(set(zip(chain_ids, res_ids))) != 1:
        if not allow_multiple_matches:
            raise ComponentValidationError(
                f"Component '{v}' maps to multiple residues.",
                component=v,
            )
        global_logger.warning(
            f"Component '{v}' maps to multiple residues. If you are using Symmetry this is OK."
        )
        # dict.fromkeys de-duplicates while keeping first-seen order; iterating a set of
        # strings gives an order that changes between runs (hash randomization).
        return list(dict.fromkeys(f"{c}{r}" for c, r in zip(chain_ids, res_ids)))

    component = f"{chain_ids[0]}{res_ids[0]}"
    global_logger.debug("Canonicalized component string: %s -> %s", v, component)
    return [component]
|