rc-foundry 0.1.6__py3-none-any.whl → 0.1.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rf3/cli.py CHANGED
@@ -23,10 +23,19 @@ def fold(
23
23
  configure_minimal_inference_logging()
24
24
 
25
25
  # Find the RF3 configs directory relative to this file
26
- # This file is at: models/rf3/src/rf3/cli.py
27
- # Configs are at: models/rf3/configs/
28
- rf3_package_dir = Path(__file__).parent.parent.parent # Go up to models/rf3/
29
- config_path = str(rf3_package_dir / "configs")
26
+ # In development: models/rf3/src/rf3/cli.py -> models/rf3/configs/
27
+ # When installed: site-packages/rf3/cli.py -> site-packages/rf3/configs/
28
+ rf3_file_dir = Path(__file__).parent
29
+
30
+ # Check if we're in installed mode (configs are sibling to this file)
31
+ # or development mode (configs are ../../../configs)
32
+ if (rf3_file_dir / "configs").exists():
33
+ # Installed mode
34
+ config_path = str(rf3_file_dir / "configs")
35
+ else:
36
+ # Development mode
37
+ rf3_package_dir = rf3_file_dir.parent.parent # Go up to models/rf3/
38
+ config_path = str(rf3_package_dir / "configs")
30
39
 
31
40
  # Get all arguments
32
41
  args = ctx.params.get("args", []) + ctx.args
rf3/inference.py CHANGED
@@ -16,7 +16,9 @@ rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
16
16
 
17
17
  load_dotenv(override=True)
18
18
 
19
- _config_path = os.path.join(os.environ["PROJECT_ROOT"], "models/rf3/configs")
19
+ _config_path = os.path.join(
20
+ os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "configs"
21
+ )
20
22
 
21
23
 
22
24
  @hydra.main(
@@ -7,7 +7,7 @@ dataset:
7
7
  base_dir: ${paths.data.pdb_data_dir}
8
8
  dataset:
9
9
  name: interface
10
- data: ${paths.data.pdb_parquet_dir}/interfaces_df_train.parquet
10
+ data: ${paths.data.pdb_parquet_dir}/interfaces_df.parquet
11
11
  filters:
12
12
  # filters common across all PDB datasets
13
13
  - "deposition_date < '2021-09-30'"
@@ -7,7 +7,7 @@ _target_: rfd3.engine.RFD3InferenceEngine
7
7
 
8
8
  out_dir: ???
9
9
  inputs: ??? # null, json, pdb or
10
- ckpt_path: /projects/ml/aa_design/models/rfd3_latest_cleaned.ckpt
10
+ ckpt_path: rfd3
11
11
  json_keys_subset: null
12
12
  skip_existing: True
13
13
 
@@ -61,5 +61,5 @@ global_prefix: null
61
61
  dump_prediction_metadata_json: True
62
62
  dump_trajectories: False
63
63
  align_trajectory_structures: False
64
- prevalidate_inputs: True
64
+ prevalidate_inputs: False
65
65
  low_memory_mode: False # False for standard mode, True for memory efficient tokenization mode
@@ -4,7 +4,7 @@ defaults:
4
4
 
5
5
  kind: symmetry
6
6
  num_timesteps: 200
7
- gamma_0: 1.0 # 1.0 for SDE sampling
7
+ gamma_0: 0.6 # 1.0 for SDE sampling
8
8
  gamma_min: 1.0
9
9
  gamma_min2: 0.0
10
10
  sym_step_frac: 0.9 # when 0.9, 90% of the trajectory from the start is symmetrized
rfd3/engine.py CHANGED
@@ -21,9 +21,14 @@ from rfd3.constants import SAVED_CONDITIONING_ANNOTATIONS
21
21
  from rfd3.inference.datasets import (
22
22
  assemble_distributed_inference_loader_from_json,
23
23
  )
24
- from rfd3.inference.input_parsing import DesignInputSpecification
24
+ from rfd3.inference.input_parsing import (
25
+ DesignInputSpecification,
26
+ ensure_input_is_abspath,
27
+ )
25
28
  from rfd3.model.inference_sampler import SampleDiffusionConfig
26
- from rfd3.utils.inference import ensure_input_is_abspath
29
+ from rfd3.utils.inference import (
30
+ ensure_inference_sampler_matches_design_spec,
31
+ )
27
32
  from rfd3.utils.io import (
28
33
  CIF_LIKE_EXTENSIONS,
29
34
  build_stack_from_atom_array_and_batched_coords,
@@ -171,6 +176,7 @@ class RFD3InferenceEngine(BaseInferenceEngine):
171
176
  )
172
177
  # save
173
178
  self.specification_overrides = dict(specification or {})
179
+ self.inference_sampler_overrides = dict(inference_sampler or {})
174
180
 
175
181
  # Setup output directories and args
176
182
  self.global_prefix = global_prefix
@@ -210,6 +216,9 @@ class RFD3InferenceEngine(BaseInferenceEngine):
210
216
  inputs=inputs,
211
217
  n_batches=n_batches,
212
218
  )
219
+ ensure_inference_sampler_matches_design_spec(
220
+ design_specifications, self.inference_sampler_overrides
221
+ )
213
222
  # init before
214
223
  self.initialize()
215
224
  outputs = self._run_multi(design_specifications)
@@ -383,6 +392,15 @@ class RFD3InferenceEngine(BaseInferenceEngine):
383
392
  # Based on inputs, construct the specifications to loop through
384
393
  design_specifications = {}
385
394
  for prefix, example_spec in inputs.items():
395
+ # Record task name in the specification
396
+ if isinstance(example_spec, DesignInputSpecification):
397
+ example_spec.extra = example_spec.extra or {}
398
+ example_spec.extra["task_name"] = prefix
399
+ else:
400
+ if "extra" not in example_spec:
401
+ example_spec["extra"] = {}
402
+ example_spec["extra"]["task_name"] = prefix
403
+
386
404
  # ... Create n_batches for example
387
405
  for batch_id in range((n_batches) if exists(n_batches) else 1):
388
406
  # ... Example ID
@@ -524,21 +542,19 @@ def process_input(
524
542
 
525
543
 
526
544
  def _reshape_trajectory(traj, align_structures: bool):
527
- traj = [traj[i] for i in range(len(traj))]
528
- n_steps = len(traj)
545
+ traj = [traj[i] for i in range(len(traj))] # make list of arrays
529
546
  max_frames = 100
530
-
547
+ if len(traj) > max_frames:
548
+ selected_indices = torch.linspace(0, len(traj) - 1, max_frames).long().tolist()
549
+ traj = [traj[i] for i in selected_indices]
531
550
  if align_structures:
532
551
  # ... align the trajectories on the last prediction
533
- for step in range(n_steps - 1):
552
+ for step in range(len(traj) - 1):
534
553
  traj[step] = weighted_rigid_align(
535
- X_L=traj[-1],
536
- X_gt_L=traj[step],
537
- )
554
+ X_L=traj[-1][None],
555
+ X_gt_L=traj[step][None],
556
+ ).squeeze(0)
538
557
  traj = traj[::-1] # reverse to go from noised -> denoised
539
- if n_steps > max_frames:
540
- selected_indices = torch.linspace(0, n_steps - 1, max_frames).long().tolist()
541
- traj = [traj[i] for i in selected_indices]
542
558
 
543
559
  traj = torch.stack(traj).cpu().numpy()
544
560
  return traj
@@ -14,8 +14,8 @@ from atomworks.ml.transforms.base import Compose, Transform
14
14
  from omegaconf import DictConfig, OmegaConf
15
15
  from rfd3.inference.input_parsing import (
16
16
  DesignInputSpecification,
17
+ ensure_input_is_abspath,
17
18
  )
18
- from rfd3.utils.inference import ensure_input_is_abspath
19
19
  from torch.utils.data import (
20
20
  DataLoader,
21
21
  SequentialSampler,
@@ -5,6 +5,7 @@ import os
5
5
  import time
6
6
  import warnings
7
7
  from contextlib import contextmanager
8
+ from os import PathLike
8
9
  from typing import Any, Dict, List, Optional, Union
9
10
 
10
11
  import numpy as np
@@ -696,7 +697,7 @@ class DesignInputSpecification(BaseModel):
696
697
  # Partial diffusion: use COM, keep all coordinates
697
698
  if exists(self.symmetry) and self.symmetry.id:
698
699
  # For symmetric structures, avoid COM centering that would collapse chains
699
- ranked_logger.info(
700
+ logger.info(
700
701
  "Partial diffusion with symmetry: skipping COM centering to preserve chain spacing"
701
702
  )
702
703
  else:
@@ -1121,3 +1122,33 @@ def accumulate_components(
1121
1122
  if atom_array_accum.bonds is None:
1122
1123
  atom_array_accum.bonds = BondList(atom_array_accum.array_length())
1123
1124
  return atom_array_accum
1125
+
1126
+
1127
+ def ensure_input_is_abspath(args: Dict[str, Any], path: PathLike | None):
1128
+ """
1129
+ Ensures the input source is an absolute path if exists, if not it will convert
1130
+
1131
+ args:
1132
+ args: Inference specification for atom array
1133
+ path: None or file to which the input is relative to.
1134
+ """
1135
+ if isinstance(args, str):
1136
+ raise ValueError(
1137
+ "Expected args to be a dictionary, got a string: {}. If you are using an input JSON ensure it contains dictionaries of arguments".format(
1138
+ args
1139
+ )
1140
+ )
1141
+ if "input" not in args or not exists(args["input"]):
1142
+ return args
1143
+ input = str(args["input"])
1144
+ if not os.path.isabs(input):
1145
+ if path is None:
1146
+ raise ValueError(
1147
+ "Input path is relative, but no base path was provided to resolve it against."
1148
+ )
1149
+ input = os.path.abspath(os.path.join(os.path.dirname(str(path)), input))
1150
+ logger.info(
1151
+ f"Input source path is relative, converted to absolute path: {input}"
1152
+ )
1153
+ args["input"] = input
1154
+ return args
@@ -139,13 +139,18 @@ def fetch_motif_residue_(
139
139
  subarray, motif=True, unindexed=False, dtype=int
140
140
  ) # all values init to True (fix all)
141
141
 
142
+ to_unindex = f"{src_chain}{src_resid}" in unindexed_components
143
+ to_index = f"{src_chain}{src_resid}" in components
144
+
142
145
  # Assign is motif atom and sequence
143
146
  if exists(atoms := fixed_atoms.get(f"{src_chain}{src_resid}")):
147
+ # If specified, we set fixed atoms in the residue to be motif atoms
144
148
  atom_mask = get_name_mask(subarray.atom_name, atoms, res_name)
145
149
  subarray.set_annotation("is_motif_atom", atom_mask)
146
150
  # subarray.set_annotation("is_motif_atom_with_fixed_coord", atom_mask) # BUGFIX: uncomment
147
151
 
148
152
  elif redesign_motif_sidechains and res_name in STANDARD_AA:
153
+ # If redesign_motif_sidechains is True, we only make the backbone atoms to be motif atoms
149
154
  n_atoms = subarray.shape[0]
150
155
  diffuse_oxygen = False
151
156
  if n_atoms < 3:
@@ -178,6 +183,18 @@ def fetch_motif_residue_(
178
183
  subarray.set_annotation(
179
184
  "is_motif_atom_with_fixed_seq", np.zeros(subarray.shape[0], dtype=int)
180
185
  )
186
+ elif to_index or to_unindex:
187
+ # If the residue is in the contig or unindexed components,
188
+ # we set all atoms in the residue to be motif atoms
189
+ subarray.set_annotation("is_motif_atom", np.ones(subarray.shape[0], dtype=int))
190
+ else:
191
+ if to_unindex and not (
192
+ unfix_all or f"{src_chain}{src_resid}" in unfix_residues
193
+ ):
194
+ raise ValueError(
195
+ f"{src_chain}{src_resid} is not found in fixed_atoms, contig or unindex contig."
196
+ "Please check your input and contig specification."
197
+ )
181
198
  if unfix_all or f"{src_chain}{src_resid}" in unfix_residues:
182
199
  subarray.set_annotation(
183
200
  "is_motif_atom_with_fixed_coord", np.zeros(subarray.shape[0], dtype=int)
@@ -197,7 +214,6 @@ def fetch_motif_residue_(
197
214
  subarray.set_annotation(
198
215
  "is_flexible_motif_atom", np.zeros(subarray.shape[0], dtype=bool)
199
216
  )
200
- to_unindex = f"{src_chain}{src_resid}" in unindexed_components
201
217
  if to_unindex:
202
218
  subarray.set_annotation(
203
219
  "is_motif_atom_unindexed", subarray.is_motif_atom.copy()
rfd3/inference/parsing.py CHANGED
@@ -117,6 +117,7 @@ def from_any_(v: Any, atom_array: AtomArray):
117
117
 
118
118
  # Split to atom names
119
119
  data_split[idx] = token.atom_name[comp_mask_subset].tolist()
120
+ # TODO: there is a bug where when you select specifc atoms within a ligand, output ligand is fragmented
120
121
 
121
122
  # Update mask & token dictionary
122
123
  mask[comp_mask] = comp_mask_subset
@@ -1,14 +1,74 @@
1
+ import string
2
+
1
3
  import numpy as np
2
4
  from rfd3.inference.symmetry.frames import (
3
5
  decompose_symmetry_frame,
4
6
  get_symmetry_frames_from_symmetry_id,
5
7
  )
6
8
 
7
- from foundry.utils.ddp import RankedLogger
8
-
9
9
  FIXED_TRANSFORM_ID = -1
10
10
  FIXED_ENTITY_ID = -1
11
- ranked_logger = RankedLogger(__name__, rank_zero_only=True)
11
+
12
+ # Alphabet for chain ID generation (uppercase letters only, per wwPDB convention)
13
+ _CHAIN_ALPHABET = string.ascii_uppercase
14
+
15
+
16
+ def index_to_chain_id(index: int) -> str:
17
+ """
18
+ Convert a zero-based index to a chain ID following wwPDB convention.
19
+
20
+ The naming follows the wwPDB-assigned chain ID system:
21
+ - 0-25: A-Z (single letter)
22
+ - 26-701: AA-ZZ (double letter)
23
+ - 702-18277: AAA-ZZZ (triple letter)
24
+ - And so on...
25
+
26
+ This is similar to Excel column naming (A, B, ..., Z, AA, AB, ...).
27
+
28
+ Arguments:
29
+ index: zero-based index (0 -> 'A', 25 -> 'Z', 26 -> 'AA', etc.)
30
+ Returns:
31
+ chain_id: string chain identifier
32
+ """
33
+ if index < 0:
34
+ raise ValueError(f"Chain index must be non-negative, got {index}")
35
+
36
+ result = ""
37
+ remaining = index
38
+
39
+ # Convert to bijective base-26 (like Excel columns)
40
+ while True:
41
+ result = _CHAIN_ALPHABET[remaining % 26] + result
42
+ remaining = remaining // 26 - 1
43
+ if remaining < 0:
44
+ break
45
+
46
+ return result
47
+
48
+
49
+ def chain_id_to_index(chain_id: str) -> int:
50
+ """
51
+ Convert a chain ID back to a zero-based index.
52
+
53
+ Inverse of index_to_chain_id.
54
+
55
+ Arguments:
56
+ chain_id: string chain identifier (e.g., 'A', 'Z', 'AA', 'AB')
57
+ Returns:
58
+ index: zero-based index
59
+ """
60
+ if not chain_id or not all(c in _CHAIN_ALPHABET for c in chain_id):
61
+ raise ValueError(f"Invalid chain ID: {chain_id}")
62
+
63
+ # Offset for all shorter chain IDs (26 + 26^2 + ... + 26^(len-1))
64
+ offset = sum(26**k for k in range(1, len(chain_id)))
65
+
66
+ # Value within the current length group (standard base-26)
67
+ value = 0
68
+ for char in chain_id:
69
+ value = value * 26 + _CHAIN_ALPHABET.index(char)
70
+
71
+ return offset + value
12
72
 
13
73
 
14
74
  ########################################################
@@ -28,7 +88,7 @@ def add_sym_annotations(atom_array, sym_conf):
28
88
  is_asu = np.full(n, True, dtype=np.bool_)
29
89
  atom_array.set_annotation("is_sym_asu", is_asu)
30
90
  # symmetry_id
31
- symmetry_ids = np.full(n, sym_conf.get("id"), dtype="U6")
91
+ symmetry_ids = np.full(n, sym_conf.id, dtype="U6")
32
92
  atom_array.set_annotation("symmetry_id", symmetry_ids)
33
93
  return atom_array
34
94
 
@@ -251,11 +311,13 @@ def reset_chain_ids(atom_array, start_id):
251
311
  Reset the chain ids and pn_unit_iids of an atom array to start from the given id.
252
312
  Arguments:
253
313
  atom_array: atom array with chain_ids and pn_unit_iids annotated
314
+ start_id: starting chain ID (e.g., 'A')
254
315
  """
255
316
  chain_ids = np.unique(atom_array.chain_id)
256
- new_chain_range = range(ord(start_id), ord(start_id) + len(chain_ids))
257
- for new_id, old_id in zip(new_chain_range, chain_ids):
258
- atom_array.chain_id[atom_array.chain_id == old_id] = chr(new_id)
317
+ start_index = chain_id_to_index(start_id)
318
+ for i, old_id in enumerate(chain_ids):
319
+ new_id = index_to_chain_id(start_index + i)
320
+ atom_array.chain_id[atom_array.chain_id == old_id] = new_id
259
321
  atom_array.pn_unit_iid = atom_array.chain_id
260
322
  return atom_array
261
323
 
@@ -263,15 +325,18 @@ def reset_chain_ids(atom_array, start_id):
263
325
  def reannotate_chain_ids(atom_array, offset, multiplier=0):
264
326
  """
265
327
  Reannotate the chain ids and pn_unit_iids of an atom array.
328
+
329
+ Uses wwPDB-style chain IDs (A-Z, AA-ZZ, AAA-ZZZ, ...) to support
330
+ any number of chains.
331
+
266
332
  Arguments:
267
333
  atom_array: protein atom array with chain_ids and pn_unit_iids annotated
268
- offset: offset to add to the chain ids
269
- multiplier: multiplier to add to the chain ids
334
+ offset: offset to add to the chain ids (typically num_chains in ASU)
335
+ multiplier: multiplier for the offset (typically transform index)
270
336
  """
271
- chain_ids_int = (
272
- np.array([ord(c) for c in atom_array.chain_id]) + offset * multiplier
273
- )
274
- chain_ids = np.array([chr(id) for id in chain_ids_int], dtype=str)
337
+ chain_ids_indices = np.array([chain_id_to_index(c) for c in atom_array.chain_id])
338
+ new_indices = chain_ids_indices + offset * multiplier
339
+ chain_ids = np.array([index_to_chain_id(idx) for idx in new_indices], dtype="U4")
275
340
  atom_array.chain_id = chain_ids
276
341
  atom_array.pn_unit_iid = chain_ids
277
342
  return atom_array
@@ -1,10 +1,13 @@
1
1
  import numpy as np
2
- from rfd3.inference.symmetry.contigs import expand_contig_unsym_motif
2
+ from rfd3.inference.symmetry.contigs import (
3
+ expand_contig_unsym_motif,
4
+ get_unsym_motif_mask,
5
+ )
3
6
  from rfd3.transforms.conditioning_base import get_motif_features
4
7
 
5
8
  from foundry.utils.ddp import RankedLogger
6
9
 
7
- MIN_ATOMS_ALIGN = 100
10
+ MIN_ATOMS_ALIGN = 30
8
11
  MAX_TRANSFORMS = 10
9
12
  RMSD_CUT = 1.0 # Angstroms
10
13
 
@@ -18,32 +21,44 @@ def check_symmetry_config(
18
21
  Check if the symmetry configuration is valid. Add all basic checks here.
19
22
  """
20
23
 
21
- assert sym_conf.get("id"), "symmetry_id is required. e.g. {'id': 'C2'}"
24
+ assert sym_conf.id, "symmetry_id is required. e.g. {'id': 'C2'}"
22
25
  # if unsym motif is provided, check that each motif name is in the atom array
23
- if sym_conf.get("is_unsym_motif"):
26
+
27
+ is_motif_atom = get_motif_features(atom_array)["is_motif_atom"]
28
+ is_unsym_motif = np.zeros(atom_array.shape[0], dtype=bool)
29
+
30
+ if not is_motif_atom.any():
31
+ sym_conf.is_symmetric_motif = None
32
+ ranked_logger.warning(
33
+ "No motifs found in atom array. Setting is_symmetric_motif to None."
34
+ )
35
+ return sym_conf
36
+
37
+ if sym_conf.is_unsym_motif:
24
38
  assert (
25
39
  src_atom_array is not None
26
40
  ), "Source atom array must be provided for symmetric motifs"
27
- unsym_motif_names = sym_conf["is_unsym_motif"].split(",")
41
+ unsym_motif_names = sym_conf.is_unsym_motif.split(",")
28
42
  unsym_motif_names = expand_contig_unsym_motif(unsym_motif_names)
43
+ is_unsym_motif = get_unsym_motif_mask(atom_array, unsym_motif_names)
29
44
  for n in unsym_motif_names:
30
45
  if (sm and n not in sm.split(",")) and (n not in atom_array.src_component):
31
46
  raise ValueError(f"Unsym motif {n} not found in atom_array")
47
+
32
48
  if (
33
- get_motif_features(atom_array)["is_motif_token"].any()
34
- and not sym_conf.get("is_symmetric_motif")
49
+ is_motif_atom[~is_unsym_motif].any()
50
+ and not sym_conf.is_symmetric_motif
35
51
  and not has_dist_cond
36
52
  ):
37
53
  raise ValueError(
38
- "Asymmetric motif inputs should be distance constrained. "
39
- "Use atomwise_fixed_dist to constrain the distance between the motif atoms."
54
+ "Asymmetric motif inputs are not supported yet. Please provide a symmetric motif."
40
55
  )
41
- # else: if unconditional symmetry, no need to have symmetric input motif
42
56
 
43
- if partial and not sym_conf.get("is_symmetric_motif"):
57
+ if partial and not sym_conf.is_symmetric_motif:
44
58
  raise ValueError(
45
59
  "Partial diffusion with symmetry is only supported for symmetric inputs."
46
60
  )
61
+ return sym_conf
47
62
 
48
63
 
49
64
  def check_atom_array_is_symmetric(atom_array):
@@ -54,9 +69,6 @@ def check_atom_array_is_symmetric(atom_array):
54
69
  Returns:
55
70
  bool: True if the atom array is symmetric, False otherwise
56
71
  """
57
- # TODO: Implement something like this https://github.com/baker-laboratory/ipd/blob/main/ipd/sym/sym_detect.py#L303
58
- # and maybe this https://github.com/baker-laboratory/ipd/blob/main/ipd/sym/sym_detect.py#L231
59
-
60
72
  import biotite.structure as struc
61
73
  from rfd3.inference.symmetry.atom_array import (
62
74
  apply_symmetry_to_atomarray_coord,
@@ -68,8 +80,10 @@ def check_atom_array_is_symmetric(atom_array):
68
80
  # remove hetero atoms
69
81
  atom_array = atom_array[~atom_array.hetero]
70
82
  if len(atom_array) == 0:
71
- ranked_logger.info("Atom array has no protein chains. Please check your input.")
72
- return False
83
+ ranked_logger.warning(
84
+ "Atom array has no protein chains. Please check your input."
85
+ )
86
+ return True
73
87
 
74
88
  chains = np.unique(atom_array.chain_id)
75
89
  asu_mask = atom_array.chain_id == chains[0]
@@ -162,16 +176,22 @@ def find_optimal_rotation(coords1, coords2, max_points=1000):
162
176
  return None
163
177
 
164
178
 
165
- def check_input_frames_match_symmetry_frames(computed_frames, original_frames) -> None:
179
+ def check_input_frames_match_symmetry_frames(
180
+ computed_frames, original_frames, nids_by_entity
181
+ ) -> None:
166
182
  """
167
183
  Check if the atom array matches the symmetry_id.
168
184
  Arguments:
169
185
  computed_frames: list of computed frames
170
186
  original_frames: list of original frames
171
187
  """
172
- assert len(computed_frames) == len(
173
- original_frames
174
- ), "Number of computed frames does not match number of original frames"
188
+ assert len(computed_frames) == len(original_frames), (
189
+ "Number of computed frames does not match number of original frames.\n"
190
+ f"Computed Frames: {len(computed_frames)}. Original Frames: {len(original_frames)}.\n"
191
+ "If the computed frames are not as expected, please check if you have one-to-one mapping "
192
+ "(size, sequence, folding) of an entity across all chains.\n"
193
+ f"Computed Entity Mapping: {nids_by_entity}."
194
+ )
175
195
 
176
196
 
177
197
  def check_valid_multiplicity(nids_by_entity) -> None:
@@ -184,25 +204,35 @@ def check_valid_multiplicity(nids_by_entity) -> None:
184
204
  multiplicity = min([len(i) for i in nids_by_entity.values()])
185
205
  if multiplicity == 1: # no possible symmetry
186
206
  raise ValueError(
187
- "Input has no possible symmetry. If asymmetric motif, please use 2D conditioning inference instead."
207
+ "Input has no possible symmetry. If asymmetric motif, please use 2D conditioning inference instead.\n"
208
+ "Multiplicity: 1"
188
209
  )
189
210
 
190
211
  # Check that the input is not asymmetric
191
212
  multiplicity_good = [len(i) % multiplicity == 0 for i in nids_by_entity.values()]
192
213
  if not all(multiplicity_good):
193
- raise ValueError("Invalid multiplicities of subunits. Please check your input.")
214
+ raise ValueError(
215
+ "Expected multiplicity does not match for some entities.\n"
216
+ "Please modify your input to have one-to-one mapping (size, sequence, folding) of an entity across all chains.\n"
217
+ f"Expected Multiplicity: {multiplicity}.\n"
218
+ f"Computed Entity Mapping: {nids_by_entity}."
219
+ )
194
220
 
195
221
 
196
222
  def check_valid_subunit_size(nids_by_entity, pn_unit_id) -> None:
197
223
  """
198
224
  Check that the subunits in the input are of the same size.
199
225
  Arguments:
200
- nids_by_entity: dict mapping entity to ids
226
+ nids_by_entity: dict mapping entity to ids. e.g. {0: (['A_1', 'B_1', 'C_1']), 1: (['A_2', 'B_2', 'C_2'])}
227
+ pn_unit_id: array of ids. e.g. ['A_1', 'B_1', 'C_1', 'A_2', 'B_2', 'C_2']
201
228
  """
202
- for i, js in nids_by_entity.items():
203
- for j in js[1:]:
204
- if (pn_unit_id == js[0]).sum() != (pn_unit_id == j).sum():
205
- raise ValueError("Size mismatch in the input. Please check your file.")
229
+ for js in nids_by_entity.values():
230
+ for js_i in js[1:]:
231
+ if (pn_unit_id == js[0]).sum() != (pn_unit_id == js_i).sum():
232
+ raise ValueError(
233
+ f"Size mismatch between chain {js[0]} ({(pn_unit_id == js[0]).sum()} atoms) "
234
+ f"and chain {js_i} ({(pn_unit_id == js_i).sum()} atoms). Please check your input file."
235
+ )
206
236
 
207
237
 
208
238
  def check_min_atoms_to_align(natm_per_unique, reference_entity) -> None:
@@ -212,7 +242,10 @@ def check_min_atoms_to_align(natm_per_unique, reference_entity) -> None:
212
242
  nids_by_entity: dict mapping entity to ids
213
243
  """
214
244
  if natm_per_unique[reference_entity] < MIN_ATOMS_ALIGN:
215
- raise ValueError("Not enough atoms to align. Please check your input.")
245
+ raise ValueError(
246
+ f"Not enough atoms to align < {MIN_ATOMS_ALIGN} atoms."
247
+ f"Please provide a input with at least {MIN_ATOMS_ALIGN} atoms."
248
+ )
216
249
 
217
250
 
218
251
  def check_max_transforms(chains_to_consider) -> None:
@@ -224,7 +257,7 @@ def check_max_transforms(chains_to_consider) -> None:
224
257
  """
225
258
  if len(chains_to_consider) > MAX_TRANSFORMS:
226
259
  raise ValueError(
227
- "Number of transforms exceeds the max number of transforms (10)"
260
+ f"Number of transforms exceeds the max number of transforms ({MAX_TRANSFORMS})."
228
261
  )
229
262
 
230
263