rc-foundry 0.1.5__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (152)
  1. foundry/inference_engines/checkpoint_registry.py +58 -11
  2. foundry/utils/alignment.py +10 -2
  3. foundry/version.py +2 -2
  4. foundry_cli/download_checkpoints.py +66 -66
  5. {rc_foundry-0.1.5.dist-info → rc_foundry-0.1.7.dist-info}/METADATA +25 -20
  6. rc_foundry-0.1.7.dist-info/RECORD +311 -0
  7. rf3/configs/callbacks/default.yaml +5 -0
  8. rf3/configs/callbacks/dump_validation_structures.yaml +6 -0
  9. rf3/configs/callbacks/metrics_logging.yaml +10 -0
  10. rf3/configs/callbacks/train_logging.yaml +16 -0
  11. rf3/configs/dataloader/default.yaml +15 -0
  12. rf3/configs/datasets/base.yaml +31 -0
  13. rf3/configs/datasets/pdb_and_distillation.yaml +58 -0
  14. rf3/configs/datasets/pdb_only.yaml +17 -0
  15. rf3/configs/datasets/train/disorder_distillation.yaml +48 -0
  16. rf3/configs/datasets/train/domain_distillation.yaml +50 -0
  17. rf3/configs/datasets/train/monomer_distillation.yaml +49 -0
  18. rf3/configs/datasets/train/na_complex_distillation.yaml +50 -0
  19. rf3/configs/datasets/train/pdb/af3_weighted_sampling.yaml +8 -0
  20. rf3/configs/datasets/train/pdb/base.yaml +32 -0
  21. rf3/configs/datasets/train/pdb/plinder.yaml +54 -0
  22. rf3/configs/datasets/train/pdb/train_interface.yaml +51 -0
  23. rf3/configs/datasets/train/pdb/train_pn_unit.yaml +46 -0
  24. rf3/configs/datasets/train/rna_monomer_distillation.yaml +56 -0
  25. rf3/configs/datasets/val/af3_ab_set.yaml +11 -0
  26. rf3/configs/datasets/val/af3_validation.yaml +11 -0
  27. rf3/configs/datasets/val/base.yaml +32 -0
  28. rf3/configs/datasets/val/runs_and_poses.yaml +12 -0
  29. rf3/configs/debug/default.yaml +66 -0
  30. rf3/configs/debug/train_specific_examples.yaml +21 -0
  31. rf3/configs/experiment/pretrained/rf3.yaml +50 -0
  32. rf3/configs/experiment/pretrained/rf3_with_confidence.yaml +13 -0
  33. rf3/configs/experiment/quick-rf3-with-confidence.yaml +15 -0
  34. rf3/configs/experiment/quick-rf3.yaml +61 -0
  35. rf3/configs/hydra/default.yaml +18 -0
  36. rf3/configs/hydra/no_logging.yaml +7 -0
  37. rf3/configs/inference.yaml +7 -0
  38. rf3/configs/inference_engine/base.yaml +23 -0
  39. rf3/configs/inference_engine/rf3.yaml +33 -0
  40. rf3/configs/logger/csv.yaml +6 -0
  41. rf3/configs/logger/default.yaml +3 -0
  42. rf3/configs/logger/wandb.yaml +15 -0
  43. rf3/configs/model/components/ema.yaml +1 -0
  44. rf3/configs/model/components/rf3_net.yaml +177 -0
  45. rf3/configs/model/components/rf3_net_with_confidence_head.yaml +45 -0
  46. rf3/configs/model/optimizers/adam.yaml +5 -0
  47. rf3/configs/model/rf3.yaml +43 -0
  48. rf3/configs/model/rf3_with_confidence.yaml +7 -0
  49. rf3/configs/model/schedulers/af3.yaml +6 -0
  50. rf3/configs/paths/data/default.yaml +43 -0
  51. rf3/configs/paths/default.yaml +21 -0
  52. rf3/configs/train.yaml +42 -0
  53. rf3/configs/trainer/cpu.yaml +6 -0
  54. rf3/configs/trainer/ddp.yaml +5 -0
  55. rf3/configs/trainer/loss/losses/confidence_loss.yaml +29 -0
  56. rf3/configs/trainer/loss/losses/diffusion_loss.yaml +9 -0
  57. rf3/configs/trainer/loss/losses/distogram_loss.yaml +2 -0
  58. rf3/configs/trainer/loss/structure_prediction.yaml +4 -0
  59. rf3/configs/trainer/loss/structure_prediction_with_confidence.yaml +2 -0
  60. rf3/configs/trainer/metrics/structure_prediction.yaml +14 -0
  61. rf3/configs/trainer/rf3.yaml +20 -0
  62. rf3/configs/trainer/rf3_with_confidence.yaml +13 -0
  63. rf3/configs/validate.yaml +45 -0
  64. rfd3/cli.py +10 -4
  65. rfd3/configs/__init__.py +0 -0
  66. rfd3/configs/callbacks/design_callbacks.yaml +10 -0
  67. rfd3/configs/callbacks/metrics_logging.yaml +20 -0
  68. rfd3/configs/callbacks/train_logging.yaml +24 -0
  69. rfd3/configs/dataloader/default.yaml +15 -0
  70. rfd3/configs/dataloader/fast.yaml +11 -0
  71. rfd3/configs/datasets/conditions/dna_condition.yaml +3 -0
  72. rfd3/configs/datasets/conditions/island.yaml +28 -0
  73. rfd3/configs/datasets/conditions/ppi.yaml +2 -0
  74. rfd3/configs/datasets/conditions/sequence_design.yaml +17 -0
  75. rfd3/configs/datasets/conditions/tipatom.yaml +28 -0
  76. rfd3/configs/datasets/conditions/unconditional.yaml +21 -0
  77. rfd3/configs/datasets/design_base.yaml +97 -0
  78. rfd3/configs/datasets/train/pdb/af3_train_interface.yaml +46 -0
  79. rfd3/configs/datasets/train/pdb/af3_train_pn_unit.yaml +42 -0
  80. rfd3/configs/datasets/train/pdb/base.yaml +14 -0
  81. rfd3/configs/datasets/train/pdb/base_no_weights.yaml +19 -0
  82. rfd3/configs/datasets/train/pdb/base_transform_args.yaml +59 -0
  83. rfd3/configs/datasets/train/pdb/na_complex_distillation.yaml +20 -0
  84. rfd3/configs/datasets/train/pdb/pdb_base.yaml +11 -0
  85. rfd3/configs/datasets/train/pdb/rfd3_train_interface.yaml +22 -0
  86. rfd3/configs/datasets/train/pdb/rfd3_train_pn_unit.yaml +23 -0
  87. rfd3/configs/datasets/train/rfd3_monomer_distillation.yaml +38 -0
  88. rfd3/configs/datasets/val/bcov_ppi_easy_medium.yaml +9 -0
  89. rfd3/configs/datasets/val/design_validation_base.yaml +40 -0
  90. rfd3/configs/datasets/val/dna_binder_design5.yaml +9 -0
  91. rfd3/configs/datasets/val/dna_binder_long.yaml +13 -0
  92. rfd3/configs/datasets/val/dna_binder_short.yaml +13 -0
  93. rfd3/configs/datasets/val/indexed.yaml +9 -0
  94. rfd3/configs/datasets/val/mcsa_41.yaml +9 -0
  95. rfd3/configs/datasets/val/mcsa_41_short_rigid.yaml +10 -0
  96. rfd3/configs/datasets/val/ppi_inference.yaml +7 -0
  97. rfd3/configs/datasets/val/sm_binder_hbonds.yaml +13 -0
  98. rfd3/configs/datasets/val/sm_binder_hbonds_short.yaml +15 -0
  99. rfd3/configs/datasets/val/unconditional.yaml +9 -0
  100. rfd3/configs/datasets/val/unconditional_deep.yaml +9 -0
  101. rfd3/configs/datasets/val/unindexed.yaml +8 -0
  102. rfd3/configs/datasets/val/val_examples/bcov_ppi_easy_medium_with_ori.yaml +151 -0
  103. rfd3/configs/datasets/val/val_examples/bcov_ppi_easy_medium_with_ori_spoof_helical_bundle.yaml +7 -0
  104. rfd3/configs/datasets/val/val_examples/bcov_ppi_easy_medium_with_ori_varying_lengths.yaml +28 -0
  105. rfd3/configs/datasets/val/val_examples/bpem_ori_hb.yaml +212 -0
  106. rfd3/configs/debug/default.yaml +64 -0
  107. rfd3/configs/debug/train_specific_examples.yaml +21 -0
  108. rfd3/configs/dev.yaml +9 -0
  109. rfd3/configs/experiment/debug.yaml +14 -0
  110. rfd3/configs/experiment/pretrain.yaml +31 -0
  111. rfd3/configs/experiment/test-uncond.yaml +10 -0
  112. rfd3/configs/experiment/test-unindexed.yaml +21 -0
  113. rfd3/configs/hydra/default.yaml +18 -0
  114. rfd3/configs/hydra/no_logging.yaml +7 -0
  115. rfd3/configs/inference.yaml +9 -0
  116. rfd3/configs/inference_engine/base.yaml +15 -0
  117. rfd3/configs/inference_engine/dev.yaml +20 -0
  118. rfd3/configs/inference_engine/rfdiffusion3.yaml +65 -0
  119. rfd3/configs/logger/csv.yaml +6 -0
  120. rfd3/configs/logger/default.yaml +2 -0
  121. rfd3/configs/logger/wandb.yaml +15 -0
  122. rfd3/configs/model/components/ema.yaml +1 -0
  123. rfd3/configs/model/components/rfd3_net.yaml +131 -0
  124. rfd3/configs/model/optimizers/adam.yaml +5 -0
  125. rfd3/configs/model/rfd3_base.yaml +8 -0
  126. rfd3/configs/model/samplers/edm.yaml +21 -0
  127. rfd3/configs/model/samplers/symmetry.yaml +10 -0
  128. rfd3/configs/model/schedulers/af3.yaml +6 -0
  129. rfd3/configs/paths/data/default.yaml +18 -0
  130. rfd3/configs/paths/default.yaml +22 -0
  131. rfd3/configs/train.yaml +28 -0
  132. rfd3/configs/trainer/cpu.yaml +6 -0
  133. rfd3/configs/trainer/ddp.yaml +5 -0
  134. rfd3/configs/trainer/loss/losses/diffusion_loss.yaml +12 -0
  135. rfd3/configs/trainer/loss/losses/sequence_loss.yaml +3 -0
  136. rfd3/configs/trainer/metrics/design_metrics.yaml +22 -0
  137. rfd3/configs/trainer/rfd3_base.yaml +35 -0
  138. rfd3/configs/validate.yaml +34 -0
  139. rfd3/engine.py +19 -11
  140. rfd3/inference/input_parsing.py +1 -1
  141. rfd3/inference/legacy_input_parsing.py +17 -1
  142. rfd3/inference/parsing.py +1 -0
  143. rfd3/inference/symmetry/atom_array.py +1 -5
  144. rfd3/inference/symmetry/checks.py +53 -28
  145. rfd3/inference/symmetry/frames.py +8 -5
  146. rfd3/inference/symmetry/symmetry_utils.py +38 -60
  147. rfd3/run_inference.py +3 -1
  148. rfd3/utils/inference.py +23 -0
  149. rc_foundry-0.1.5.dist-info/RECORD +0 -180
  150. {rc_foundry-0.1.5.dist-info → rc_foundry-0.1.7.dist-info}/WHEEL +0 -0
  151. {rc_foundry-0.1.5.dist-info → rc_foundry-0.1.7.dist-info}/entry_points.txt +0 -0
  152. {rc_foundry-0.1.5.dist-info → rc_foundry-0.1.7.dist-info}/licenses/LICENSE.md +0 -0
@@ -139,13 +139,18 @@ def fetch_motif_residue_(
139
139
  subarray, motif=True, unindexed=False, dtype=int
140
140
  ) # all values init to True (fix all)
141
141
 
142
+ to_unindex = f"{src_chain}{src_resid}" in unindexed_components
143
+ to_index = f"{src_chain}{src_resid}" in components
144
+
142
145
  # Assign is motif atom and sequence
143
146
  if exists(atoms := fixed_atoms.get(f"{src_chain}{src_resid}")):
147
+ # If specified, we set fixed atoms in the residue to be motif atoms
144
148
  atom_mask = get_name_mask(subarray.atom_name, atoms, res_name)
145
149
  subarray.set_annotation("is_motif_atom", atom_mask)
146
150
  # subarray.set_annotation("is_motif_atom_with_fixed_coord", atom_mask) # BUGFIX: uncomment
147
151
 
148
152
  elif redesign_motif_sidechains and res_name in STANDARD_AA:
153
+ # If redesign_motif_sidechains is True, we only make the backbone atoms to be motif atoms
149
154
  n_atoms = subarray.shape[0]
150
155
  diffuse_oxygen = False
151
156
  if n_atoms < 3:
@@ -178,6 +183,18 @@ def fetch_motif_residue_(
178
183
  subarray.set_annotation(
179
184
  "is_motif_atom_with_fixed_seq", np.zeros(subarray.shape[0], dtype=int)
180
185
  )
186
+ elif to_index or to_unindex:
187
+ # If the residue is in the contig or unindexed components,
188
+ # we set all atoms in the residue to be motif atoms
189
+ subarray.set_annotation("is_motif_atom", np.ones(subarray.shape[0], dtype=int))
190
+ else:
191
+ if to_unindex and not (
192
+ unfix_all or f"{src_chain}{src_resid}" in unfix_residues
193
+ ):
194
+ raise ValueError(
195
+ f"{src_chain}{src_resid} is not found in fixed_atoms, contig or unindex contig."
196
+ "Please check your input and contig specification."
197
+ )
181
198
  if unfix_all or f"{src_chain}{src_resid}" in unfix_residues:
182
199
  subarray.set_annotation(
183
200
  "is_motif_atom_with_fixed_coord", np.zeros(subarray.shape[0], dtype=int)
@@ -197,7 +214,6 @@ def fetch_motif_residue_(
197
214
  subarray.set_annotation(
198
215
  "is_flexible_motif_atom", np.zeros(subarray.shape[0], dtype=bool)
199
216
  )
200
- to_unindex = f"{src_chain}{src_resid}" in unindexed_components
201
217
  if to_unindex:
202
218
  subarray.set_annotation(
203
219
  "is_motif_atom_unindexed", subarray.is_motif_atom.copy()
rfd3/inference/parsing.py CHANGED
@@ -117,6 +117,7 @@ def from_any_(v: Any, atom_array: AtomArray):
117
117
 
118
118
  # Split to atom names
119
119
  data_split[idx] = token.atom_name[comp_mask_subset].tolist()
120
+ # TODO: there is a bug where when you select specifc atoms within a ligand, output ligand is fragmented
120
121
 
121
122
  # Update mask & token dictionary
122
123
  mask[comp_mask] = comp_mask_subset
@@ -4,12 +4,8 @@ from rfd3.inference.symmetry.frames import (
4
4
  get_symmetry_frames_from_symmetry_id,
5
5
  )
6
6
 
7
- from foundry.utils.ddp import RankedLogger
8
-
9
7
  FIXED_TRANSFORM_ID = -1
10
8
  FIXED_ENTITY_ID = -1
11
- ranked_logger = RankedLogger(__name__, rank_zero_only=True)
12
-
13
9
 
14
10
  ########################################################
15
11
  # Symmetry annotations
@@ -28,7 +24,7 @@ def add_sym_annotations(atom_array, sym_conf):
28
24
  is_asu = np.full(n, True, dtype=np.bool_)
29
25
  atom_array.set_annotation("is_sym_asu", is_asu)
30
26
  # symmetry_id
31
- symmetry_ids = np.full(n, sym_conf.get("id"), dtype="U6")
27
+ symmetry_ids = np.full(n, sym_conf.id, dtype="U6")
32
28
  atom_array.set_annotation("symmetry_id", symmetry_ids)
33
29
  return atom_array
34
30
 
@@ -1,10 +1,13 @@
1
1
  import numpy as np
2
- from rfd3.inference.symmetry.contigs import expand_contig_unsym_motif
2
+ from rfd3.inference.symmetry.contigs import (
3
+ expand_contig_unsym_motif,
4
+ get_unsym_motif_mask,
5
+ )
3
6
  from rfd3.transforms.conditioning_base import get_motif_features
4
7
 
5
8
  from foundry.utils.ddp import RankedLogger
6
9
 
7
- MIN_ATOMS_ALIGN = 100
10
+ MIN_ATOMS_ALIGN = 30
8
11
  MAX_TRANSFORMS = 10
9
12
  RMSD_CUT = 1.0 # Angstroms
10
13
 
@@ -18,29 +21,33 @@ def check_symmetry_config(
18
21
  Check if the symmetry configuration is valid. Add all basic checks here.
19
22
  """
20
23
 
21
- assert sym_conf.get("id"), "symmetry_id is required. e.g. {'id': 'C2'}"
24
+ assert sym_conf.id, "symmetry_id is required. e.g. {'id': 'C2'}"
22
25
  # if unsym motif is provided, check that each motif name is in the atom array
23
- if sym_conf.get("is_unsym_motif"):
26
+
27
+ is_unsym_motif = np.zeros(atom_array.shape[0], dtype=bool)
28
+ if sym_conf.is_unsym_motif:
24
29
  assert (
25
30
  src_atom_array is not None
26
31
  ), "Source atom array must be provided for symmetric motifs"
27
- unsym_motif_names = sym_conf["is_unsym_motif"].split(",")
32
+ unsym_motif_names = sym_conf.is_unsym_motif.split(",")
28
33
  unsym_motif_names = expand_contig_unsym_motif(unsym_motif_names)
34
+ is_unsym_motif = get_unsym_motif_mask(atom_array, unsym_motif_names)
29
35
  for n in unsym_motif_names:
30
36
  if (sm and n not in sm.split(",")) and (n not in atom_array.src_component):
31
37
  raise ValueError(f"Unsym motif {n} not found in atom_array")
38
+
39
+ is_motif_token = get_motif_features(atom_array)["is_motif_token"]
32
40
  if (
33
- get_motif_features(atom_array)["is_motif_token"].any()
34
- and not sym_conf.get("is_symmetric_motif")
41
+ is_motif_token[~is_unsym_motif].any()
42
+ and not sym_conf.is_symmetric_motif
35
43
  and not has_dist_cond
36
44
  ):
37
45
  raise ValueError(
38
- "Asymmetric motif inputs should be distance constrained. "
46
+ "Asymmetric motif inputs should be distance constrained."
39
47
  "Use atomwise_fixed_dist to constrain the distance between the motif atoms."
40
48
  )
41
- # else: if unconditional symmetry, no need to have symmetric input motif
42
49
 
43
- if partial and not sym_conf.get("is_symmetric_motif"):
50
+ if partial and not sym_conf.is_symmetric_motif:
44
51
  raise ValueError(
45
52
  "Partial diffusion with symmetry is only supported for symmetric inputs."
46
53
  )
@@ -54,9 +61,6 @@ def check_atom_array_is_symmetric(atom_array):
54
61
  Returns:
55
62
  bool: True if the atom array is symmetric, False otherwise
56
63
  """
57
- # TODO: Implement something like this https://github.com/baker-laboratory/ipd/blob/main/ipd/sym/sym_detect.py#L303
58
- # and maybe this https://github.com/baker-laboratory/ipd/blob/main/ipd/sym/sym_detect.py#L231
59
-
60
64
  import biotite.structure as struc
61
65
  from rfd3.inference.symmetry.atom_array import (
62
66
  apply_symmetry_to_atomarray_coord,
@@ -68,8 +72,10 @@ def check_atom_array_is_symmetric(atom_array):
68
72
  # remove hetero atoms
69
73
  atom_array = atom_array[~atom_array.hetero]
70
74
  if len(atom_array) == 0:
71
- ranked_logger.info("Atom array has no protein chains. Please check your input.")
72
- return False
75
+ ranked_logger.warning(
76
+ "Atom array has no protein chains. Please check your input."
77
+ )
78
+ return True
73
79
 
74
80
  chains = np.unique(atom_array.chain_id)
75
81
  asu_mask = atom_array.chain_id == chains[0]
@@ -162,16 +168,22 @@ def find_optimal_rotation(coords1, coords2, max_points=1000):
162
168
  return None
163
169
 
164
170
 
165
- def check_input_frames_match_symmetry_frames(computed_frames, original_frames) -> None:
171
+ def check_input_frames_match_symmetry_frames(
172
+ computed_frames, original_frames, nids_by_entity
173
+ ) -> None:
166
174
  """
167
175
  Check if the atom array matches the symmetry_id.
168
176
  Arguments:
169
177
  computed_frames: list of computed frames
170
178
  original_frames: list of original frames
171
179
  """
172
- assert len(computed_frames) == len(
173
- original_frames
174
- ), "Number of computed frames does not match number of original frames"
180
+ assert len(computed_frames) == len(original_frames), (
181
+ "Number of computed frames does not match number of original frames.\n"
182
+ f"Computed Frames: {len(computed_frames)}. Original Frames: {len(original_frames)}.\n"
183
+ "If the computed frames are not as expected, please check if you have one-to-one mapping "
184
+ "(size, sequence, folding) of an entity across all chains.\n"
185
+ f"Computed Entity Mapping: {nids_by_entity}."
186
+ )
175
187
 
176
188
 
177
189
  def check_valid_multiplicity(nids_by_entity) -> None:
@@ -184,25 +196,35 @@ def check_valid_multiplicity(nids_by_entity) -> None:
184
196
  multiplicity = min([len(i) for i in nids_by_entity.values()])
185
197
  if multiplicity == 1: # no possible symmetry
186
198
  raise ValueError(
187
- "Input has no possible symmetry. If asymmetric motif, please use 2D conditioning inference instead."
199
+ "Input has no possible symmetry. If asymmetric motif, please use 2D conditioning inference instead.\n"
200
+ "Multiplicity: 1"
188
201
  )
189
202
 
190
203
  # Check that the input is not asymmetric
191
204
  multiplicity_good = [len(i) % multiplicity == 0 for i in nids_by_entity.values()]
192
205
  if not all(multiplicity_good):
193
- raise ValueError("Invalid multiplicities of subunits. Please check your input.")
206
+ raise ValueError(
207
+ "Expected multiplicity does not match for some entities.\n"
208
+ "Please modify your input to have one-to-one mapping (size, sequence, folding) of an entity across all chains.\n"
209
+ f"Expected Multiplicity: {multiplicity}.\n"
210
+ f"Computed Entity Mapping: {nids_by_entity}."
211
+ )
194
212
 
195
213
 
196
214
  def check_valid_subunit_size(nids_by_entity, pn_unit_id) -> None:
197
215
  """
198
216
  Check that the subunits in the input are of the same size.
199
217
  Arguments:
200
- nids_by_entity: dict mapping entity to ids
218
+ nids_by_entity: dict mapping entity to ids. e.g. {0: (['A_1', 'B_1', 'C_1']), 1: (['A_2', 'B_2', 'C_2'])}
219
+ pn_unit_id: array of ids. e.g. ['A_1', 'B_1', 'C_1', 'A_2', 'B_2', 'C_2']
201
220
  """
202
- for i, js in nids_by_entity.items():
203
- for j in js[1:]:
204
- if (pn_unit_id == js[0]).sum() != (pn_unit_id == j).sum():
205
- raise ValueError("Size mismatch in the input. Please check your file.")
221
+ for js in nids_by_entity.values():
222
+ for js_i in js[1:]:
223
+ if (pn_unit_id == js[0]).sum() != (pn_unit_id == js_i).sum():
224
+ raise ValueError(
225
+ f"Size mismatch between chain {js[0]} ({(pn_unit_id == js[0]).sum()} atoms) "
226
+ f"and chain {js_i} ({(pn_unit_id == js_i).sum()} atoms). Please check your input file."
227
+ )
206
228
 
207
229
 
208
230
  def check_min_atoms_to_align(natm_per_unique, reference_entity) -> None:
@@ -212,7 +234,10 @@ def check_min_atoms_to_align(natm_per_unique, reference_entity) -> None:
212
234
  nids_by_entity: dict mapping entity to ids
213
235
  """
214
236
  if natm_per_unique[reference_entity] < MIN_ATOMS_ALIGN:
215
- raise ValueError("Not enough atoms to align. Please check your input.")
237
+ raise ValueError(
238
+ f"Not enough atoms to align < {MIN_ATOMS_ALIGN} atoms."
239
+ f"Please provide a input with at least {MIN_ATOMS_ALIGN} atoms."
240
+ )
216
241
 
217
242
 
218
243
  def check_max_transforms(chains_to_consider) -> None:
@@ -224,7 +249,7 @@ def check_max_transforms(chains_to_consider) -> None:
224
249
  """
225
250
  if len(chains_to_consider) > MAX_TRANSFORMS:
226
251
  raise ValueError(
227
- "Number of transforms exceeds the max number of transforms (10)"
252
+ f"Number of transforms exceeds the max number of transforms ({MAX_TRANSFORMS})."
228
253
  )
229
254
 
230
255
 
@@ -10,12 +10,13 @@ def get_symmetry_frames_from_symmetry_id(symmetry_id):
10
10
  Returns:
11
11
  frames: list of rotation matrices
12
12
  """
13
+ from rfd3.inference.symmetry.symmetry_utils import SymmetryConfig
13
14
 
14
15
  # Get frames from symmetry id
15
16
  sym_conf = {}
16
- if isinstance(symmetry_id, dict):
17
+ if isinstance(symmetry_id, SymmetryConfig):
17
18
  sym_conf = symmetry_id
18
- symmetry_id = symmetry_id.get("id")
19
+ symmetry_id = symmetry_id.id
19
20
 
20
21
  if symmetry_id.lower().startswith("c"):
21
22
  order = int(symmetry_id[1:])
@@ -25,9 +26,9 @@ def get_symmetry_frames_from_symmetry_id(symmetry_id):
25
26
  frames = get_dihedral_frames(order)
26
27
  elif symmetry_id.lower() == "input_defined":
27
28
  assert (
28
- "symmetry_file" in sym_conf
29
+ sym_conf.symmetry_file is not None
29
30
  ), "symmetry_file is required for input_defined symmetry"
30
- frames = get_frames_from_file(sym_conf.get("symmetry_file"))
31
+ frames = get_frames_from_file(sym_conf.symmetry_file)
31
32
  else:
32
33
  raise ValueError(f"Symmetry id {symmetry_id} not supported")
33
34
 
@@ -120,7 +121,9 @@ def get_symmetry_frames_from_atom_array(src_atom_array, input_frames):
120
121
  computed_frames = [(R, np.array([0, 0, 0])) for R in Rs]
121
122
 
122
123
  # check that the computed frames match the input frames
123
- check_input_frames_match_symmetry_frames(computed_frames, input_frames)
124
+ check_input_frames_match_symmetry_frames(
125
+ computed_frames, input_frames, nids_by_entity
126
+ )
124
127
 
125
128
  return computed_frames
126
129
 
@@ -39,18 +39,36 @@ ranked_logger = RankedLogger(__name__, rank_zero_only=True)
39
39
 
40
40
 
41
41
  class SymmetryConfig(BaseModel):
42
- # AM / HE TODO: feel free to flesh this out and add validation as needed
43
42
  model_config = ConfigDict(
44
43
  arbitrary_types_allowed=True,
45
44
  extra="allow",
46
45
  )
47
- id: Optional[str] = Field(None)
48
- # is_unsym_motif: Optional[np.ndarray[bool]] = Field(...)
49
- # is_symmetric_motif: bool = Field(...)
46
+ id: Optional[str] = Field(
47
+ None,
48
+ description="Symmetry group ID. e.g. 'C3', 'D2'. Only C and D symmetry types are supported currently.",
49
+ )
50
+ is_unsym_motif: Optional[str] = Field(
51
+ None,
52
+ description="Comma separated list of contig/ligand names that should not be symmetrized such as DNA strands. \
53
+ e.g. 'HEM' or 'Y1-11,Z16-25'",
54
+ )
55
+ is_symmetric_motif: bool = Field(
56
+ True,
57
+ description="If True, the input motifs are expected to be already symmetric and won't be symmetrized. \
58
+ If False, the all input motifs are expected to be ASU and will be symmetrized.",
59
+ )
60
+
61
+
62
+ def convery_sym_conf_to_symmetry_config(sym_conf: dict):
63
+ return SymmetryConfig(**sym_conf)
50
64
 
51
65
 
52
66
  def make_symmetric_atom_array(
53
- asu_atom_array, sym_conf: SymmetryConfig, sm=None, has_2d=False, src_atom_array=None
67
+ asu_atom_array,
68
+ sym_conf: SymmetryConfig | dict,
69
+ sm=None,
70
+ has_dist_cond=False,
71
+ src_atom_array=None,
54
72
  ):
55
73
  """
56
74
  apply symmetry to an atom array.
@@ -58,39 +76,33 @@ def make_symmetric_atom_array(
58
76
  asu_atom_array: atom array of the asymmetric unit
59
77
  sym_conf: symmetry configuration (dict, "id" key is required)
60
78
  sm: optional small molecule names (str, comma separated)
61
- has_2d: whether to add 2d entity annotations
79
+ has_dist_cond: whether to add 2d entity annotations
62
80
  Returns:
63
81
  new_asu_atom_array: atom array with symmetry applied
64
82
  """
65
- sym_conf = (
66
- sym_conf.model_dump()
67
- ) # TODO: JB: remove this line to keep as symmetry config for cleaner syntax(?)
68
- ranked_logger.info(f"Symmetry Configs: {sym_conf}")
83
+ if not isinstance(sym_conf, SymmetryConfig):
84
+ sym_conf = convery_sym_conf_to_symmetry_config(sym_conf)
69
85
 
70
- # Making sure that the symmetry config is valid
71
86
  check_symmetry_config(
72
- asu_atom_array,
73
- sym_conf,
74
- sm,
75
- has_dist_cond=has_2d,
76
- src_atom_array=src_atom_array,
87
+ asu_atom_array, sym_conf, sm, has_dist_cond, src_atom_array=src_atom_array
77
88
  )
78
89
  # Adding utility annotations to the asu atom array
79
90
  asu_atom_array = _add_util_annotations(asu_atom_array, sym_conf, sm)
80
91
 
81
- if has_2d: # NB: this will only work for asymmetric motifs at the moment - need to add functionality for symmetric motifs
92
+ if has_dist_cond: # NB: this will only work for asymmetric motifs at the moment - need to add functionality for symmetric motifs
82
93
  asu_atom_array = add_2d_entity_annotations(asu_atom_array)
83
94
 
84
95
  frames = get_symmetry_frames_from_symmetry_id(sym_conf)
85
96
 
86
97
  # If the motif is symmetric, we get the frames instead from the source atom array.
87
- if sym_conf.get("is_symmetric_motif"):
98
+ if sym_conf.is_symmetric_motif:
88
99
  assert (
89
100
  src_atom_array is not None
90
101
  ), "Source atom array must be provided for symmetric motifs"
91
- # if symmetric motif is provided, get the frames from the src atom array
102
+ # if symmetric motif is provided, get the frames from the src atom array.
92
103
  frames = get_symmetry_frames_from_atom_array(src_atom_array, frames)
93
- else:
104
+ elif (asu_atom_array._is_motif[~asu_atom_array._is_unsym_motif]).any():
105
+ # if the motifs that's not unsym motifs are present.
94
106
  raise NotImplementedError(
95
107
  "Asymmetric motif inputs are not implemented yet. please symmetrize the motif."
96
108
  )
@@ -101,7 +113,7 @@ def make_symmetric_atom_array(
101
113
  # Extracting all things at this moment that we will not want to symmetrize.
102
114
  # This includes: 1) unsym motifs, 2) ligands
103
115
  unsym_atom_arrays = []
104
- if sym_conf.get("is_unsym_motif"):
116
+ if sym_conf.is_unsym_motif:
105
117
  # unsym_motif_atom_array = get_unsym_motif(asu_atom_array, asu_atom_array._is_unsym_motif)
106
118
  # Now remove the unsym motifs from the asu atom array
107
119
  unsym_atom_arrays.append(asu_atom_array[asu_atom_array._is_unsym_motif])
@@ -128,7 +140,7 @@ def make_symmetric_atom_array(
128
140
  symmetrized_atom_array = struc.concatenate(symmetry_unit_list)
129
141
 
130
142
  # add 2D conditioning annotations
131
- if has_2d:
143
+ if has_dist_cond:
132
144
  symmetrized_atom_array = reannotate_2d_conditions(symmetrized_atom_array)
133
145
 
134
146
  # set all motifs to not have any symmetrization applied to them
@@ -183,7 +195,7 @@ def make_symmetric_atom_array_for_partial_diffusion(atom_array, sym_conf):
183
195
  frames = get_symmetry_frames_from_symmetry_id(sym_conf)
184
196
 
185
197
  # Add symmetry ID
186
- symmetry_ids = np.full(n, sym_conf.get("id"), dtype="U6")
198
+ symmetry_ids = np.full(n, sym_conf.id, dtype="U6")
187
199
  atom_array.set_annotation("symmetry_id", symmetry_ids)
188
200
 
189
201
  # Initialize transform annotations (use same format as original system)
@@ -244,7 +256,7 @@ def _add_util_annotations(asu_atom_array, sym_conf, sm):
244
256
  """
245
257
  n = asu_atom_array.shape[0]
246
258
  is_motif = get_motif_features(asu_atom_array)["is_motif_atom"].astype(np.bool_)
247
- is_sm = np.zeros(asu_atom_array.shape[0], dtype=bool)
259
+ is_sm = np.zeros(n, dtype=bool)
248
260
  is_asu = np.ones(n, dtype=bool)
249
261
  is_unsym_motif = np.zeros(n, dtype=bool)
250
262
 
@@ -257,8 +269,8 @@ def _add_util_annotations(asu_atom_array, sym_conf, sm):
257
269
  )
258
270
 
259
271
  # assign unsym motifs
260
- if sym_conf.get("is_unsym_motif"):
261
- unsym_motif_names = sym_conf["is_unsym_motif"].split(",")
272
+ if sym_conf.is_unsym_motif:
273
+ unsym_motif_names = sym_conf.is_unsym_motif.split(",")
262
274
  unsym_motif_names = expand_contig_unsym_motif(unsym_motif_names)
263
275
  is_unsym_motif = get_unsym_motif_mask(asu_atom_array, unsym_motif_names)
264
276
 
@@ -361,38 +373,4 @@ def apply_symmetry_to_xyz_atomwise(X_L, sym_feats, partial_diffusion=False):
361
373
  "blc,cd->bld", asu_xyz, sym_transforms[target_id][0].to(asu_xyz.dtype)
362
374
  ) + sym_transforms[target_id][1].to(asu_xyz.dtype)
363
375
 
364
- # Log inter-chain distances for debugging - use actual chain annotations
365
- if sym_X_L.shape[1] > 100: # Only for large structures
366
- # Use symmetry entity annotations to find different chains
367
- sym_entity_id = sym_feats["sym_entity_id"]
368
- unique_entities = torch.unique(sym_entity_id)
369
-
370
- if len(unique_entities) >= 2:
371
- # Get atoms from first two different entities
372
- entity_0_mask = sym_entity_id == unique_entities[0]
373
- entity_1_mask = sym_entity_id == unique_entities[1]
374
-
375
- if entity_0_mask.sum() > 0 and entity_1_mask.sum() > 0:
376
- entity_0_atoms = sym_X_L[0, entity_0_mask, :]
377
- entity_1_atoms = sym_X_L[0, entity_1_mask, :]
378
-
379
- # Sample subset to avoid memory issues
380
- entity_0_sample = entity_0_atoms[: min(50, entity_0_atoms.shape[0]), :]
381
- entity_1_sample = entity_1_atoms[: min(50, entity_1_atoms.shape[0]), :]
382
-
383
- min_distance = (
384
- torch.cdist(entity_0_sample, entity_1_sample).min().item()
385
- )
386
- ranked_logger.info(
387
- f"Min inter-chain distance after symmetry: {min_distance:.2f} Å"
388
- )
389
-
390
- # Also log the centers of each entity
391
- entity_0_center = entity_0_atoms.mean(dim=0)
392
- entity_1_center = entity_1_atoms.mean(dim=0)
393
- center_distance = torch.norm(entity_0_center - entity_1_center).item()
394
- ranked_logger.info(
395
- f"Distance between chain centers: {center_distance:.2f} Å"
396
- )
397
-
398
376
  return sym_X_L
rfd3/run_inference.py CHANGED
@@ -12,7 +12,9 @@ load_dotenv(override=True)
12
12
 
13
13
  # For pip-installed package, configs should be relative to this file
14
14
  # Adjust this path based on where configs are bundled in the package
15
- _config_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "configs")
15
+ _config_path = os.path.join(
16
+ os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "configs"
17
+ )
16
18
 
17
19
 
18
20
  @hydra.main(
rfd3/utils/inference.py CHANGED
@@ -391,6 +391,29 @@ def ensure_input_is_abspath(args: dict, path: PathLike | None):
391
391
  return args
392
392
 
393
393
 
394
+ def ensure_inference_sampler_matches_design_spec(
395
+ design_spec: dict, inference_sampler: dict | None = None
396
+ ):
397
+ """
398
+ Ensure the inference sampler is set to the correct sampler for the design specification.
399
+ Args:
400
+ design_spec: Design specification dictionary
401
+ inference_sampler: Inference sampler dictionary
402
+ """
403
+ has_symmetry_specification = [
404
+ True if "symmetry" in item.keys() else False for item in design_spec.values()
405
+ ]
406
+ if any(has_symmetry_specification):
407
+ if (
408
+ inference_sampler is None
409
+ or inference_sampler.get("kind", "default") != "symmetry"
410
+ ):
411
+ raise ValueError(
412
+ "You requested for symmetric designs, but inference sampler is not set to symmetry. "
413
+ "Please add inference_sampler.kind='symmetry' to your command."
414
+ )
415
+
416
+
394
417
  #################################################################################
395
418
  # Custom infer_ori functions
396
419
  #################################################################################