rc_foundry-0.1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
rfd3/inference/parsing.py
@@ -0,0 +1,165 @@
+ from typing import Any, Dict, List, Optional, Union
+
+ import numpy as np
+ from biotite.structure import AtomArray, get_residue_starts
+ from pydantic import (
+     BaseModel,
+     ConfigDict,
+     Field,
+     model_serializer,
+     model_validator,
+ )
+
+ from foundry.utils.components import (
+     ComponentStr,
+     fetch_mask_from_idx,
+     get_name_mask,
+     split_contig,
+     unravel_components,
+ )
+
+ # ============================================================================
+ # Input Specification & Validation
+ # ============================================================================
+
+
+ class InputSelection(BaseModel):
+     model_config = ConfigDict(
+         arbitrary_types_allowed=True,
+         str_strip_whitespace=True,
+         str_min_length=1,
+     )
+     data: Dict[ComponentStr | str, List[str]] = Field(
+         ..., description="Validated selection dictionary", exclude=True
+     )
+     raw: Any = Field(..., description="Original input value")
+     mask: np.ndarray[np.bool_] = Field(
+         ..., description="Boolean mask over atom array", exclude=True
+     )
+     tokens: Optional[Dict[ComponentStr | str, AtomArray]] = Field(
+         ..., description="Selected atom arrays per component", exclude=True
+     )
+
+     @classmethod
+     def from_any(
+         cls, v: Union[str, bool, dict, None], atom_array: AtomArray
+     ) -> Optional["InputSelection"]:
+         """Create InputSelection from various input types."""
+         if v is None:
+             return None
+         data, mask, _ = from_any_(v=v, atom_array=atom_array)
+         return cls(
+             raw=v,
+             data=data,
+             mask=mask,
+             tokens=None,
+         )
+
+     @model_validator(mode="after")
+     def check_keys(self):
+         # lightweight validation that all keys have contig format (are splittable indices)
+         assert all([split_contig(k) for k in self.data.keys()])
+         return self
+
+     # Wrapper functionality as dict-like
+     def __getitem__(self, key: str) -> List[str]:
+         """Allow dict-like access."""
+         return self.data[key]
+
+     def items(self):
+         return self.data.items()
+
+     def keys(self):
+         return self.data.keys()
+
+     def values(self):
+         return self.data.values()
+
+     def get(self, *args, **kwargs):
+         return self.data.get(*args, **kwargs)
+
+     # Serialization & repr
+     def __repr__(self) -> str:
+         num_atoms = self.mask.sum() if hasattr(self.mask, "sum") else 0
+         num_tokens = len(self.tokens) if self.tokens else 0
+         return (
+             f"InputSelection(raw={self.raw!r}, atoms={num_atoms}, tokens={num_tokens})"
+         )
+
+     @model_serializer
+     def serialize_raw(self) -> Any:
+         return self.raw
+
+     # Listed as separate methods for future changes to parsing.
+     def get_mask(self):
+         return self.mask
+
+     def get_tokens(self, aa):
+         _, _, tokens = from_any_(v=self.raw, atom_array=aa)
+         return tokens
+
+
+ def from_any_(v: Any, atom_array: AtomArray):
+     data_norm = canonicalize_(v, atom_array)
+
+     # Canonicalize dictionaries to SelectionDict (I.e. convert "ALL" / "TIP" -> concrete atom names)
+     data_split = {}
+     mask = np.array([False] * len(atom_array))
+     tokens = {}
+     for idx, atm_names in data_norm.items():
+         # Find atom array subset
+         comp_mask = fetch_mask_from_idx(idx, atom_array=atom_array)
+         token = atom_array[comp_mask]
+
+         comp_mask_subset = get_name_mask(
+             token.atom_name, atm_names, token.res_name[0]
+         )  # [N_atoms_in_token,]
+
+         # Split to atom names
+         data_split[idx] = token.atom_name[comp_mask_subset].tolist()
+
+         # Update mask & token dictionary
+         mask[comp_mask] = comp_mask_subset
+         tokens[idx] = token[comp_mask_subset]
+
+     return (data_split, mask, tokens)
+
+
+ def canonicalize_(v, atom_array: AtomArray):
+     # Canonicalize inputs to dictionaries of strings:
+     #   "A11-12" -> {"A11": "N,CA,C,...", "A12": "N,CA,C,..."}
+     #   True     -> {"A1": "ALL", "A2": "ALL", ...}
+     #   False    -> {"A1": "", "A2": "", ...}
+     #   "LIG"    -> {"B1": "ALL", "C1": "ALL"} (for two ligands named LIG)
+     data = {}
+     if isinstance(v, str):
+         for component in unravel_components(
+             v, atom_array=atom_array, allow_multiple_matches=True
+         ):
+             if (
+                 isinstance(component, str) and component[0].isalpha()
+             ):  # filter on valid chain IDs
+                 data[component] = "ALL"
+
+     elif isinstance(v, bool):
+         starts = get_residue_starts(atom_array, add_exclusive_stop=True)
+         for start, stop in zip(starts[:-1], starts[1:]):
+             token = atom_array[start:stop]
+             # All atoms selected for every token or None
+             data[f"{token.chain_id[0]}{token.res_id[0]}"] = "ALL" if v else ""
+
+     elif isinstance(v, dict):
+         # Ensure all values of dictionaries are strings
+         data = {}
+         for k, vv in v.items():
+             for component in unravel_components(
+                 k, atom_array=atom_array, allow_multiple_matches=True
+             ):
+                 if isinstance(vv, list):
+                     data[component] = ",".join(vv)
+                 else:
+                     data[component] = vv
+     else:
+         raise ValueError(f"Cannot convert {type(v)} to InputSelection")
+
+     return data
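
For orientation, the canonicalization rules documented in canonicalize_ above can be exercised directly. The following is a minimal sketch, not code shipped with the package: it assumes this hunk is rfd3/inference/parsing.py (item 131, +165), that rc-foundry and biotite are installed, and that the foundry.utils.components helpers resolve "ALL" to every atom name of a residue.

# Hypothetical usage sketch (not part of the package).
import numpy as np
from biotite.structure import AtomArray

from rfd3.inference.parsing import InputSelection, canonicalize_

# Two glycine residues on chain A, backbone atoms only
atoms = AtomArray(8)
atoms.coord = np.zeros((8, 3), dtype=np.float32)
atoms.chain_id = np.array(["A"] * 8)
atoms.res_id = np.array([1, 1, 1, 1, 2, 2, 2, 2])
atoms.res_name = np.array(["GLY"] * 8)
atoms.atom_name = np.array(["N", "CA", "C", "O"] * 2)

# Boolean input: every residue maps to "ALL", per the comment block in canonicalize_
print(canonicalize_(True, atoms))  # {'A1': 'ALL', 'A2': 'ALL'}

# Full parsing additionally resolves "ALL" to concrete atom names and builds the mask
# (resolution behavior of the foundry helpers is assumed, not shown in this diff)
selection = InputSelection.from_any(True, atoms)
print(selection)        # e.g. InputSelection(raw=True, atoms=8, tokens=0)
print(selection["A1"])  # e.g. ['N', 'CA', 'C', 'O']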
rfd3/inference/symmetry/atom_array.py
@@ -0,0 +1,298 @@
+ import numpy as np
+ from rfd3.inference.symmetry.frames import (
+     decompose_symmetry_frame,
+     get_symmetry_frames_from_symmetry_id,
+ )
+
+ from foundry.utils.ddp import RankedLogger
+
+ FIXED_TRANSFORM_ID = -1
+ FIXED_ENTITY_ID = -1
+ ranked_logger = RankedLogger(__name__, rank_zero_only=True)
+
+
+ ########################################################
+ # Symmetry annotations
+ ########################################################
+
+
+ def add_sym_annotations(atom_array, sym_conf):
+     """
+     Add symmetry base annotations to an atom array.
+     Arguments:
+         atom_array: atom array of symmetry subunit
+         sym_conf: symmetry configuration (dict, "id" key is required)
+     """
+     n = atom_array.shape[0]
+     # which is the asymmetric unit? At this point, we annotate everything as the asu
+     is_asu = np.full(n, True, dtype=np.bool_)
+     atom_array.set_annotation("is_sym_asu", is_asu)
+     # symmetry_id
+     symmetry_ids = np.full(n, sym_conf.get("id"), dtype="U6")
+     atom_array.set_annotation("symmetry_id", symmetry_ids)
+     return atom_array
+
+
+ def add_sym_annotations_to_fixed_motif(atom_array):
+     """
+     Add symmetry annotations to a motif atom array.
+     Arguments:
+         atom_array: atom array of symmetry subunit
+     """
+     n = atom_array.shape[0]
+
+     # setting the identity transform
+     Ori, X, Y = decompose_symmetry_frame((np.eye(3), np.zeros(3)))
+     Oris = np.full(n, Ori)
+     Xs = np.full(n, X)
+     Ys = np.full(n, Y)
+     atom_array.set_annotation("sym_transform_Ori", Oris)
+     atom_array.set_annotation("sym_transform_X", Xs)
+     atom_array.set_annotation("sym_transform_Y", Ys)
+
+     transform_ids = np.full(n, FIXED_TRANSFORM_ID, dtype=np.int32)
+     atom_array.set_annotation("sym_transform_id", transform_ids)
+     entity_ids = np.full(n, FIXED_ENTITY_ID, dtype=np.int32)
+     atom_array.set_annotation("sym_entity_id", entity_ids)
+     # make sure that the motif is not the asu
+     is_sym_asu = np.full(n, False, dtype=np.bool_)
+     atom_array.set_annotation("is_sym_asu", is_sym_asu)
+     return atom_array
+
+
+ def add_src_sym_component_annotations(atom_array):
+     """
+     Add src_sym_component annotations to an atom array.
+     This is used to correctly map the original motif id to diffused unindexed motifs.
+     Arguments:
+         atom_array: atom array with src_component annotated
+     """
+     if "src_component" not in atom_array.get_annotation_categories():
+         return atom_array
+
+     src_sym_component = atom_array.src_component.copy()
+     src_tokens = np.unique(atom_array.src_component)
+
+     for src_token in src_tokens:
+         # Skip non-alphabetic tokens
+         if len(src_token) == 0:
+             continue
+         if not src_token[0].isalpha():
+             continue
+
+         # Get block of atoms with this src token
+         src_block_mask = atom_array.src_component == src_token
+         src_block = atom_array[src_block_mask]
+
+         # Skip if not all unindexed motif atoms
+         if not src_block.is_motif_atom_unindexed.all():
+             continue
+
+         # Update src component with chain ID prefix
+         for chain_id in np.unique(src_block.chain_id):
+             chain_mask = src_block.chain_id == chain_id
+             src_block.src_component[chain_mask] = chain_id + src_token[1:]
+
+         src_sym_component[src_block_mask] = src_block.src_component
+
+     atom_array.set_annotation("src_sym_component", src_sym_component)
+     return atom_array
+
+
+ def fix_3D_sym_motif_annotations(atom_array):
+     """
+     Add fixed motif annotations to the 3D NON-indexed motifs (only unindexed and ligands).
+     Since indexed motifs are contiguously connected to generative residues,
+     they should NOT be fixed; instead they get symmetrized at each step.
+     Arguments:
+         atom_array: atom array
+     """
+     # fixed_motif_mask = atom_array.is_motif_atom_with_fixed_coord == 1
+     fixed_motif_mask = atom_array._is_motif & ~atom_array._is_indexed_motif
+     fixed_motif_array = atom_array[fixed_motif_mask].copy()
+     fixed_motif_array = add_sym_annotations_to_fixed_motif(fixed_motif_array)
+     atom_array[fixed_motif_mask] = fixed_motif_array
+     return atom_array
+
+
+ def add_sym_transform_annotations(atom_array, transform_id, frame, is_asu=False):
+     """
+     Add symmetry annotations to an atom array.
+     Arguments:
+         atom_array: atom array of symmetry subunit
+         transform_id: index of the transform frame
+         frame: symmetry frame (R, T)
+         is_asu: whether this is the asymmetric unit
+     Returns:
+         atom_array: atom array with symmetry annotations
+     """
+     Ori, X, Y = decompose_symmetry_frame(frame)
+     n = atom_array.shape[0]
+
+     # symmetry transform (decomposed into Ori, X, Y)
+     Oris = np.full(n, Ori)
+     Xs = np.full(n, X)
+     Ys = np.full(n, Y)
+     atom_array.set_annotation("sym_transform_Ori", Oris)
+     atom_array.set_annotation("sym_transform_X", Xs)
+     atom_array.set_annotation("sym_transform_Y", Ys)
+
+     # symmetry transform id
+     transform_ids = np.full(n, transform_id, dtype=np.int32)
+     atom_array.set_annotation("sym_transform_id", transform_ids)
+
+     # entity ids - this will help keep track of different multiplicities
+     # if there are small molecules, they will have different entity ids from the protein atoms
+     unique_chain_ids = np.unique(atom_array.chain_id).tolist()
+     unique_chain_ids.sort()
+     entity_ids = np.array([unique_chain_ids.index(id) for id in atom_array.chain_id])
+     atom_array.set_annotation("sym_entity_id", entity_ids)
+
+     is_sym_asu = np.full(n, is_asu, dtype=np.bool_)
+     atom_array.set_annotation("is_sym_asu", is_sym_asu)
+
+     return atom_array
+
+
+ def apply_symmetry_to_atomarray_coord(atom_array, frame):
+     """
+     Apply symmetry to the atom array coordinates.
+     Arguments:
+         atom_array: atom array
+         frame: symmetry frame (R, T)
+     """
+     R, T = frame
+     atom_array.coord = atom_array.coord @ R.T
+     atom_array.coord += T  # T should be 0 for most symmetry cases
+     return atom_array
+
+
+ ########################################################
+ # Motif functions
+ ########################################################
+
+
+ def annotate_unsym_atom_array(atom_array):
+     """
+     Annotate the unsym motif and return it.
+     Arguments:
+         atom_array: atom array
+     """
+     unsym_atom_array = atom_array.copy()
+     unsym_atom_array._is_asu = np.full(unsym_atom_array.shape[0], False, dtype=np.bool_)
+     unsym_atom_array.is_sym_asu = unsym_atom_array._is_asu
+     unsym_atom_array = reset_chain_ids(
+         unsym_atom_array, start_id="a"
+     )  # give it a lowercase chain id to avoid confusion with symmetry units
+     unsym_atom_array = add_sym_annotations_to_fixed_motif(unsym_atom_array)
+     return unsym_atom_array
+
+
+ ########################################################
+ # 2D conditioning functions
+ ########################################################
+
+
+ def add_2d_entity_annotations(atom_array):
+     entity_ids = np.zeros(atom_array.shape[0], dtype=np.int32)
+     categories = get_2d_annotation_categories(atom_array)
+     entity_id = 1
+     for i, anno in enumerate(categories):
+         entity_id = i + 1
+         entity_ids[atom_array.get_annotation(anno) == 1] = entity_id
+     atom_array.set_annotation("_2d_entity_id", entity_ids)
+     return atom_array
+
+
+ def reannotate_2d_entity_ids(atom_array, transform_id):
+     if "_2d_entity_id" not in atom_array.get_annotation_categories():
+         return atom_array
+     _2d_annos = get_2d_annotation_categories(atom_array)
+     frames = get_symmetry_frames_from_symmetry_id(atom_array.symmetry_id[0])
+     # NOTE: assumes either the 2D condition was specified within a subunit or all active sites were explicitly specified
+     max_entity_id = max(len(_2d_annos), len(frames))
+     mask = atom_array.get_annotation("_2d_entity_id") != 0
+     atom_array._2d_entity_id[mask] = (
+         (atom_array._2d_entity_id[mask] + transform_id - 1) % max_entity_id
+     ) + 1
+     return atom_array
+
+
+ def get_2d_annotation_categories(atom_array):
+     categories = []
+     for anno in atom_array.get_annotation_categories():
+         if "2d_condition" in anno:
+             categories.append(anno)
+     categories.sort()  # sort to make sure that categories are in ascending order
+     return categories
+
+
+ def reannotate_2d_conditions(atom_array):
+     entity_ids_anno = atom_array.get_annotation("_2d_entity_id")
+     entity_ids = [d for d in np.unique(entity_ids_anno) if d != 0]
+     categories = get_2d_annotation_categories(atom_array)
+     diff = len(entity_ids) - len(categories)
+     if diff > 0:
+         for i in range(len(categories), len(categories) + diff):
+             categories.append(f"{categories[0]}_{i}")
+     for d, anno in zip(entity_ids, categories):
+         atom_array.set_annotation(anno, entity_ids_anno == d)
+     atom_array.del_annotation("_2d_entity_id")
+     return atom_array
+
+
+ ########################################################
+ # Utility functions
+ ########################################################
+
+
+ def reset_chain_ids(atom_array, start_id):
+     """
+     Reset the chain ids and pn_unit_iids of an atom array to start from the given id.
+     Arguments:
+         atom_array: atom array with chain_ids and pn_unit_iids annotated
+         start_id: single character the new chain ids start from
+     """
+     chain_ids = np.unique(atom_array.chain_id)
+     new_chain_range = range(ord(start_id), ord(start_id) + len(chain_ids))
+     for new_id, old_id in zip(new_chain_range, chain_ids):
+         atom_array.chain_id[atom_array.chain_id == old_id] = chr(new_id)
+     atom_array.pn_unit_iid = atom_array.chain_id
+     return atom_array
+
+
+ def reannotate_chain_ids(atom_array, offset, multiplier=0):
+     """
+     Reannotate the chain ids and pn_unit_iids of an atom array.
+     Arguments:
+         atom_array: protein atom array with chain_ids and pn_unit_iids annotated
+         offset: offset added to the chain ids (scaled by multiplier)
+         multiplier: factor the offset is multiplied by (e.g. the transform index)
+     """
+     chain_ids_int = (
+         np.array([ord(c) for c in atom_array.chain_id]) + offset * multiplier
+     )
+     chain_ids = np.array([chr(id) for id in chain_ids_int], dtype=str)
+     atom_array.chain_id = chain_ids
+     atom_array.pn_unit_iid = chain_ids
+     return atom_array
+
+
+ def get_symmetry_unit(asu_atom_array, transform_id, frame):
+     """
+     Annotate the ASU protein atom array and return it for each symmetry unit.
+     Arguments:
+         asu_atom_array: atom array of the asymmetric unit, annotated with symmetry_id
+         transform_id: index of the symmetry unit
+         frame: symmetry frame
+     """
+     num_prot_chains = len(np.unique(asu_atom_array.chain_id))
+     symmetry_unit = asu_atom_array.copy()
+     symmetry_unit = reannotate_chain_ids(symmetry_unit, num_prot_chains, transform_id)
+     symmetry_unit = reannotate_2d_entity_ids(symmetry_unit, transform_id)
+     symmetry_unit = add_sym_transform_annotations(
+         symmetry_unit, transform_id, frame, is_asu=(transform_id == 0)
+     )
+     # apply symmetry to indexed motifs
+     # at this point, the diffused coordinates are at the origin / have no xyz
+     symmetry_unit = apply_symmetry_to_atomarray_coord(symmetry_unit, frame)
+     return symmetry_unit
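
Taken together, these helpers implement a simple pipeline: annotate the asymmetric unit, then stamp and transform one copy per symmetry frame via get_symmetry_unit. The sketch below is an illustration rather than package code; it assumes the input AtomArray carries the standard biotite annotations used above and that get_symmetry_frames_from_symmetry_id("C2") returns two (R, T) frames, as implied by apply_symmetry_to_atomarray_coord.

# Hypothetical assembly sketch (not shipped with the package).
import biotite.structure as struc

from rfd3.inference.symmetry.atom_array import add_sym_annotations, get_symmetry_unit
from rfd3.inference.symmetry.frames import get_symmetry_frames_from_symmetry_id


def build_assembly(asu_atom_array, symmetry_id="C2"):
    """Stack one transformed copy of the ASU per symmetry frame."""
    # Annotate symmetry_id / is_sym_asu on the input subunit
    asu = add_sym_annotations(asu_atom_array, {"id": symmetry_id})
    frames = get_symmetry_frames_from_symmetry_id(symmetry_id)
    units = [
        # transform_id == 0 is flagged as the asymmetric unit by get_symmetry_unit
        get_symmetry_unit(asu, transform_id, frame)
        for transform_id, frame in enumerate(frames)
    ]
    return struc.concatenate(units)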
rfd3/inference/symmetry/checks.py
@@ -0,0 +1,241 @@
+ import numpy as np
+ from rfd3.inference.symmetry.contigs import expand_contig_unsym_motif
+ from rfd3.transforms.conditioning_base import get_motif_features
+
+ from foundry.utils.ddp import RankedLogger
+
+ MIN_ATOMS_ALIGN = 100
+ MAX_TRANSFORMS = 10
+ RMSD_CUT = 1.0  # Angstroms
+
+ ranked_logger = RankedLogger(__name__, rank_zero_only=True)
+
+
+ def check_symmetry_config(
+     atom_array, sym_conf, sm, has_dist_cond, src_atom_array=None, partial=False
+ ):
+     """
+     Check if the symmetry configuration is valid. Add all basic checks here.
+     """
+
+     assert sym_conf.get("id"), "symmetry_id is required. e.g. {'id': 'C2'}"
+     # if unsym motif is provided, check that each motif name is in the atom array
+     if sym_conf.get("is_unsym_motif"):
+         assert (
+             src_atom_array is not None
+         ), "Source atom array must be provided for symmetric motifs"
+         unsym_motif_names = sym_conf["is_unsym_motif"].split(",")
+         unsym_motif_names = expand_contig_unsym_motif(unsym_motif_names)
+         for n in unsym_motif_names:
+             if (sm and n not in sm.split(",")) and (n not in atom_array.src_component):
+                 raise ValueError(f"Unsym motif {n} not found in atom_array")
+     if (
+         get_motif_features(atom_array)["is_motif_token"].any()
+         and not sym_conf.get("is_symmetric_motif")
+         and not has_dist_cond
+     ):
+         raise ValueError(
+             "Asymmetric motif inputs should be distance constrained. "
+             "Use atomwise_fixed_dist to constrain the distance between the motif atoms."
+         )
+     # else: if unconditional symmetry, no need to have symmetric input motif
+
+     if partial and not sym_conf.get("is_symmetric_motif"):
+         raise ValueError(
+             "Partial diffusion with symmetry is only supported for symmetric inputs."
+         )
+
+
+ def check_atom_array_is_symmetric(atom_array):
+     """
+     Check if the atom array is symmetric. This does NOT check that the atom array's symmetry matches that of the symmetry_id.
+     Arguments:
+         atom_array: atom array to check
+     Returns:
+         bool: True if the atom array is symmetric, False otherwise
+     """
+     # TODO: Implement something like this https://github.com/baker-laboratory/ipd/blob/main/ipd/sym/sym_detect.py#L303
+     # and maybe this https://github.com/baker-laboratory/ipd/blob/main/ipd/sym/sym_detect.py#L231
+
+     import biotite.structure as struc
+     from rfd3.inference.symmetry.atom_array import (
+         apply_symmetry_to_atomarray_coord,
+     )
+     from rfd3.inference.symmetry.frames import (
+         get_symmetry_frames_from_symmetry_id,
+     )
+
+     # remove hetero atoms
+     atom_array = atom_array[~atom_array.hetero]
+     if len(atom_array) == 0:
+         ranked_logger.info("Atom array has no protein chains. Please check your input.")
+         return False
+
+     chains = np.unique(atom_array.chain_id)
+     asu_mask = atom_array.chain_id == chains[0]
+     asu_atoms = atom_array[asu_mask].copy()
+
+     # Check that all chains have the same number of atoms
+     for chain in chains[1:]:
+         chain_mask = atom_array.chain_id == chain
+         if len(asu_atoms) != len(atom_array[chain_mask]):
+             ranked_logger.info(
+                 f"Atom array has different number of atoms in chain {chain}. {len(asu_atoms)} != {len(atom_array[chain_mask])}"
+             )
+             return False
+
+     # Check that all chains have the same atoms
+     for chain in chains[1:]:
+         chain_mask = atom_array.chain_id == chain
+         for i in range(len(asu_atoms)):
+             if asu_atoms.atom_name[i] != atom_array[chain_mask].atom_name[i]:
+                 ranked_logger.info(
+                     f"Atom array has different atoms in chain {chain}. {asu_atoms.atom_name[i]} != {atom_array[chain_mask].atom_name[i]}"
+                 )
+                 return False
+
+     # Check that the atom array aligns with the standard symmetry frames
+     standard_frames = get_symmetry_frames_from_symmetry_id(atom_array.symmetry_id[0])
+     standard_atom_array = []
+     for frame in standard_frames:
+         symmed_atoms = apply_symmetry_to_atomarray_coord(asu_atoms, frame)
+         standard_atom_array.append(symmed_atoms)
+     standard_atom_array = struc.concatenate(standard_atom_array)
+
+     R_standard_obtained = find_optimal_rotation(
+         standard_atom_array.coord, atom_array.coord
+     )
+
+     if R_standard_obtained is None:
+         ranked_logger.info(
+             "Atom array does not align with the standard symmetry frames."
+         )
+         return False
+
+     return True
+
+
+ def find_optimal_rotation(coords1, coords2, max_points=1000):
+     """
+     Find the optimal rotation matrix between two sets of coordinates using the Kabsch algorithm.
+
+     Args:
+         coords1: reference coordinates (N, 3)
+         coords2: target coordinates (N, 3)
+         max_points: maximum number of points to use for efficiency
+
+     Returns:
+         rotation_matrix: 3x3 rotation matrix or None if failed
+     """
+     if len(coords1) > max_points:
+         indices = np.random.choice(len(coords1), max_points, replace=False)
+         coords1 = coords1[indices]
+         coords2 = coords2[indices]
+
+     # Ensure same number of points
+     min_len = min(len(coords1), len(coords2))
+     coords1 = coords1[:min_len]
+     coords2 = coords2[:min_len]
+     if min_len < 3:
+         return None
+
+     # Kabsch algorithm
+     try:
+         centroid1 = np.mean(coords1, axis=0)
+         centroid2 = np.mean(coords2, axis=0)
+         coords1_centered = coords1 - centroid1
+         coords2_centered = coords2 - centroid2
+
+         # Compute covariance matrix
+         H = coords1_centered.T @ coords2_centered
+
+         U, S, Vt = np.linalg.svd(H)
+         R = Vt.T @ U.T
+         # Ensure proper rotation matrix
+         if np.linalg.det(R) < 0:
+             Vt[-1, :] *= -1
+             R = Vt.T @ U.T
+         return R
+
+     except Exception as e:
+         print(f"Error in rotation calculation: {e}")
+         return None
+
+
+ def check_input_frames_match_symmetry_frames(computed_frames, original_frames) -> None:
+     """
+     Check that the frames computed from the input match the frames of the symmetry_id.
+     Arguments:
+         computed_frames: list of computed frames
+         original_frames: list of original frames
+     """
+     assert len(computed_frames) == len(
+         original_frames
+     ), "Number of computed frames does not match number of original frames"
+
+
+ def check_valid_multiplicity(nids_by_entity) -> None:
+     """
+     Check if the multiplicity is valid.
+     Arguments:
+         nids_by_entity: dict mapping entity to ids
+     """
+     # get multiplicities of subunits
+     multiplicity = min([len(i) for i in nids_by_entity.values()])
+     if multiplicity == 1:  # no possible symmetry
+         raise ValueError(
+             "Input has no possible symmetry. If asymmetric motif, please use 2D conditioning inference instead."
+         )
+
+     # Check that the input is not asymmetric
+     multiplicity_good = [len(i) % multiplicity == 0 for i in nids_by_entity.values()]
+     if not all(multiplicity_good):
+         raise ValueError("Invalid multiplicities of subunits. Please check your input.")
+
+
+ def check_valid_subunit_size(nids_by_entity, pn_unit_id) -> None:
+     """
+     Check that the subunits in the input are of the same size.
+     Arguments:
+         nids_by_entity: dict mapping entity to ids
+         pn_unit_id: per-atom array of pn_unit ids
+     """
+     for i, js in nids_by_entity.items():
+         for j in js[1:]:
+             if (pn_unit_id == js[0]).sum() != (pn_unit_id == j).sum():
+                 raise ValueError("Size mismatch in the input. Please check your file.")
+
+
+ def check_min_atoms_to_align(natm_per_unique, reference_entity) -> None:
+     """
+     Check that we have enough atoms to align.
+     Arguments:
+         natm_per_unique: dict mapping entity to number of atoms
+         reference_entity: entity used as the alignment reference
+     """
+     if natm_per_unique[reference_entity] < MIN_ATOMS_ALIGN:
+         raise ValueError("Not enough atoms to align. Please check your input.")
+
+
+ def check_max_transforms(chains_to_consider) -> None:
+     """
+     Check that we are not exceeding the max number of transforms (MAX_TRANSFORMS).
+     Arguments:
+         chains_to_consider: list of chains to consider
+     """
+     if len(chains_to_consider) > MAX_TRANSFORMS:
+         raise ValueError(
+             "Number of transforms exceeds the max number of transforms (10)"
+         )
+
+
+ def check_max_rmsds(rmsds) -> None:
+     """
+     Check that the RMSD between the reference molecule and the other molecules is not too big.
+     Arguments:
+         rmsds: dict mapping chain to RMSD
+     """
+     if max(rmsds.values()) > RMSD_CUT:
+         ranked_logger.warning(
+             f"RMSD between the reference molecule and the other molecules is too big ({max(rmsds.values())} > {RMSD_CUT}). Please provide a symmetric input PDB file."
+         )
+         # raise ValueError(f"RMSD between the reference molecule and the other molecules is too big ({max(rmsds.values())} > {RMSD_CUT}). Please provide a symmetric input PDB file.")
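
find_optimal_rotation is a self-contained Kabsch solver, so it can be sanity-checked without any structure data. The snippet below is a sketch, not a test shipped with the package; it only assumes NumPy and the import path of this hunk (rfd3/inference/symmetry/checks.py).

# Hypothetical sanity check: a known rotation should be recovered by Kabsch.
import numpy as np

from rfd3.inference.symmetry.checks import find_optimal_rotation

rng = np.random.default_rng(0)
coords1 = rng.normal(size=(50, 3))  # reference point cloud

# Rotate the cloud by 90 degrees about the z axis
theta = np.pi / 2
R_true = np.array(
    [
        [np.cos(theta), -np.sin(theta), 0.0],
        [np.sin(theta), np.cos(theta), 0.0],
        [0.0, 0.0, 1.0],
    ]
)
coords2 = coords1 @ R_true.T

# The returned R satisfies coords1_centered @ R.T ~= coords2_centered
R_est = find_optimal_rotation(coords1, coords2)
assert np.allclose(R_est, R_true, atol=1e-6)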