rc-foundry 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
@@ -0,0 +1,398 @@
1
+ from typing import Optional
2
+
3
+ import biotite.structure as struc
4
+ import numpy as np
5
+ import torch
6
+ from pydantic import (
7
+ BaseModel,
8
+ ConfigDict,
9
+ Field,
10
+ )
11
+ from rfd3.inference.symmetry.atom_array import (
12
+ FIXED_ENTITY_ID,
13
+ FIXED_TRANSFORM_ID,
14
+ add_2d_entity_annotations,
15
+ add_src_sym_component_annotations,
16
+ add_sym_annotations,
17
+ annotate_unsym_atom_array,
18
+ fix_3D_sym_motif_annotations,
19
+ get_symmetry_unit,
20
+ reannotate_2d_conditions,
21
+ )
22
+ from rfd3.inference.symmetry.checks import (
23
+ check_symmetry_config,
24
+ )
25
+ from rfd3.inference.symmetry.contigs import (
26
+ expand_contig_unsym_motif,
27
+ get_unsym_motif_mask,
28
+ )
29
+ from rfd3.inference.symmetry.frames import (
30
+ get_symmetry_frames_from_atom_array,
31
+ get_symmetry_frames_from_symmetry_id,
32
+ )
33
+ from rfd3.transforms.conditioning_base import get_motif_features
34
+
35
+ from foundry.utils.components import fetch_mask_from_component
36
+ from foundry.utils.ddp import RankedLogger
37
+
38
+ ranked_logger = RankedLogger(__name__, rank_zero_only=True)  # module logger; rank_zero_only=True presumably limits output to rank 0 (see foundry.utils.ddp.RankedLogger)
39
+
40
+
41
class SymmetryConfig(BaseModel):
    """Symmetry settings used when building symmetric atom arrays.

    Only ``id`` is declared as a field. Because the model allows extra keys
    (``extra="allow"``), any additional settings passed by the caller are kept
    on the model and included in ``model_dump()``. This includes
    ``is_symmetric_motif`` and ``is_unsym_motif``, which
    ``make_symmetric_atom_array`` reads from the dumped dict.
    """

    # AM / HE TODO: feel free to flesh this out and add validation as needed
    model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True)

    # Symmetry identifier; None when not supplied.
    id: Optional[str] = Field(default=None)
50
+
51
+
52
def make_symmetric_atom_array(
    asu_atom_array, sym_conf: SymmetryConfig, sm=None, has_2d=False, src_atom_array=None
):
    """
    Build the fully symmetrized atom array from an asymmetric unit (ASU).

    Arguments:
        asu_atom_array: atom array of the asymmetric unit
        sym_conf: symmetry configuration, either a ``SymmetryConfig`` or a plain
            dict with the same keys ("id" key is required; "is_unsym_motif" is an
            optional comma-separated list of motifs that must not be symmetrized)
        sm: optional small molecule names (str, comma separated); these atoms are
            kept once and are not symmetrized
        has_2d: whether to add 2d entity annotations
        src_atom_array: source atom array; required when
            ``sym_conf["is_symmetric_motif"]`` is truthy, because the symmetry
            frames are then derived from it
    Returns:
        symmetrized_atom_array: one transformed copy of the (remaining) ASU per
            symmetry frame, followed by the unsymmetrized atoms (unsym motifs,
            then small molecules), if any
    Raises:
        NotImplementedError: if ``sym_conf["is_symmetric_motif"]`` is falsy
            (asymmetric motif inputs are not supported yet)
    """
    # Accept the pydantic model or an already-dumped dict; everything below
    # (and the helpers it calls) works on a plain dict.
    # TODO: JB: keep as symmetry config for cleaner syntax(?)
    if isinstance(sym_conf, BaseModel):
        sym_conf = sym_conf.model_dump()
    ranked_logger.info(f"Symmetry Configs: {sym_conf}")

    # Making sure that the symmetry config is valid
    check_symmetry_config(
        asu_atom_array,
        sym_conf,
        sm,
        has_dist_cond=has_2d,
        src_atom_array=src_atom_array,
    )
    # Adding utility annotations (_is_sm, _is_unsym_motif, ...) to the asu atom array
    asu_atom_array = _add_util_annotations(asu_atom_array, sym_conf, sm)

    # NB: this will only work for asymmetric motifs at the moment - need to add
    # functionality for symmetric motifs
    if has_2d:
        asu_atom_array = add_2d_entity_annotations(asu_atom_array)

    frames = get_symmetry_frames_from_symmetry_id(sym_conf)

    # If the motif is symmetric, we get the frames instead from the source atom array.
    if sym_conf.get("is_symmetric_motif"):
        assert (
            src_atom_array is not None
        ), "Source atom array must be provided for symmetric motifs"
        frames = get_symmetry_frames_from_atom_array(src_atom_array, frames)
    else:
        raise NotImplementedError(
            "Asymmetric motif inputs are not implemented yet. please symmetrize the motif."
        )

    # Add symmetry annotations to the asu atom array
    asu_atom_array = add_sym_annotations(asu_atom_array, sym_conf)

    # Pull out everything we do not want to symmetrize: 1) unsym motifs, 2) ligands
    asu_atom_array, unsym_atom_array = _split_off_unsym_atoms(
        asu_atom_array, sym_conf, sm
    )

    # One transformed copy of the ASU per symmetry frame
    symmetry_unit_list = [
        get_symmetry_unit(asu_atom_array, transform_id, frame)
        for transform_id, frame in enumerate(frames)
    ]
    # Unsymmetrized atoms are appended exactly once, after all subunits
    # (motifs at the end of the atom array).
    if unsym_atom_array is not None and len(unsym_atom_array) > 0:
        symmetry_unit_list.append(annotate_unsym_atom_array(unsym_atom_array))
    # build the full symmetrized atom array
    symmetrized_atom_array = struc.concatenate(symmetry_unit_list)

    # add 2D conditioning annotations
    if has_2d:
        symmetrized_atom_array = reannotate_2d_conditions(symmetrized_atom_array)

    # set all motifs to not have any symmetrization applied to them
    # TODO: this needs to be adapted to work with 2D cond (in 2D cond, we WANT to
    # apply symmetry to the motifs since they move in space)
    symmetrized_atom_array = fix_3D_sym_motif_annotations(symmetrized_atom_array)

    # This is needed to output correct motif residue mappings in the output json
    symmetrized_atom_array = add_src_sym_component_annotations(symmetrized_atom_array)
    # remove utility annotations
    symmetrized_atom_array = _del_util_annotations(symmetrized_atom_array)
    return symmetrized_atom_array


def _split_off_unsym_atoms(asu_atom_array, sym_conf, sm):
    """
    Separate the atoms that must not be symmetrized from the ASU.

    Uses the ``_is_unsym_motif`` and ``_is_sm`` annotations set by
    ``_add_util_annotations``.

    Returns:
        (asu_atom_array, unsym_atom_array): the ASU with those atoms removed, and
        the removed atoms concatenated (unsym motifs first, then small
        molecules), or None when nothing was removed.
    """
    unsym_atom_arrays = []
    if sym_conf.get("is_unsym_motif"):
        unsym_atom_arrays.append(asu_atom_array[asu_atom_array._is_unsym_motif])
        asu_atom_array = asu_atom_array[~asu_atom_array._is_unsym_motif]
    if sm:
        unsym_atom_arrays.append(asu_atom_array[asu_atom_array._is_sm])
        asu_atom_array = asu_atom_array[~asu_atom_array._is_sm]
    unsym_atom_array = (
        struc.concatenate(unsym_atom_arrays) if unsym_atom_arrays else None
    )
    return asu_atom_array, unsym_atom_array
143
+
144
+
145
def make_symmetric_atom_array_for_partial_diffusion(atom_array, sym_conf):
    """
    Add symmetry annotations to an already-symmetric atom array for partial diffusion.

    No subunits are rebuilt and coordinates are not changed. The input must
    already be symmetric, which is asserted via ``check_atom_array_is_symmetric``.
    Chains are taken in sorted chain-id order (``np.unique``). Chain ``i`` gets
    transform ``i % len(frames)`` of the frames for ``sym_conf``. The first chain
    is marked as the asymmetric unit (``is_sym_asu``).

    Arguments:
        atom_array: symmetric atom array (all subunits present)
        sym_conf: symmetry configuration, a dict or ``SymmetryConfig``
            ("id" key is required)
    Returns:
        atom_array: the same atoms with the annotations symmetry_id,
            sym_transform_Ori/X/Y, sym_transform_id, sym_entity_id and is_sym_asu
    """
    # TODO: clean up this function
    from rfd3.inference.symmetry.checks import (
        check_atom_array_is_symmetric,
    )
    from rfd3.inference.symmetry.frames import (
        decompose_symmetry_frame,
    )

    # Accept the pydantic model as well as a plain dict (``.get`` is used below).
    if isinstance(sym_conf, BaseModel):
        sym_conf = sym_conf.model_dump()

    # For partial diffusion with symmetric inputs, preserve exact positioning
    ranked_logger.info(
        "Partial diffusion with symmetry - preserving exact input coordinates"
    )
    ranked_logger.info("SKIPPING symmetry reconstruction to preserve input structure")

    check_symmetry_config(
        atom_array,
        sym_conf,
        sm=None,
        has_dist_cond=False,
        src_atom_array=None,
        partial=True,
    )

    atom_array = add_sym_annotations(atom_array, sym_conf)
    assert check_atom_array_is_symmetric(atom_array), "Atom array is not symmetric"

    n = atom_array.shape[0]
    # chain_of_atom[k] is the position in chain_ids of atom k's chain
    chain_ids, chain_of_atom = np.unique(atom_array.chain_id, return_inverse=True)
    frames = get_symmetry_frames_from_symmetry_id(sym_conf)

    # Add symmetry ID
    atom_array.set_annotation(
        "symmetry_id", np.full(n, sym_conf.get("id"), dtype="U6")
    )

    # Cycle the chains through the available frames, then decompose each
    # chain's frame into its packed values once.
    chain_transform_ids = [i % len(frames) for i in range(len(chain_ids))]
    ori_per_chain, x_per_chain, y_per_chain = [], [], []
    for transform_id in chain_transform_ids:
        ori, x, y = decompose_symmetry_frame(frames[transform_id])
        ori_per_chain.append(ori)
        x_per_chain.append(x)
        y_per_chain.append(y)

    # Gather per-atom values from per-chain arrays. np.array chooses a dtype that
    # fits every chain's value, instead of fixing it from the first chain
    # (np.full would truncate longer strings or cast floats from later chains).
    # This also works when there are no chains at all.
    sym_transform_Ori = np.array(ori_per_chain)[chain_of_atom]
    sym_transform_X = np.array(x_per_chain)[chain_of_atom]
    sym_transform_Y = np.array(y_per_chain)[chain_of_atom]
    symmetry_transform_id = np.asarray(chain_transform_ids, dtype=np.int32)[
        chain_of_atom
    ]
    # All chains are treated as a single symmetric entity
    symmetry_entity_id = np.zeros(n, dtype=np.int32)
    # The first chain (in sorted chain-id order) is the asymmetric unit
    is_asu = chain_of_atom == 0

    # Set all annotations
    atom_array.set_annotation("sym_transform_Ori", sym_transform_Ori)
    atom_array.set_annotation("sym_transform_X", sym_transform_X)
    atom_array.set_annotation("sym_transform_Y", sym_transform_Y)
    atom_array.set_annotation("sym_transform_id", symmetry_transform_id)
    atom_array.set_annotation("sym_entity_id", symmetry_entity_id)
    atom_array.set_annotation("is_sym_asu", is_asu)

    ranked_logger.info(
        f"Added full symmetry annotations to {len(chain_ids)} existing chains WITHOUT changing coordinates"
    )

    return atom_array
230
+
231
+
232
+ ########################################################
233
+ # Private functions only used in make_symmetric_atom_array
234
+ ########################################################
235
+
236
+
237
def _add_util_annotations(asu_atom_array, sym_conf, sm):
    """
    Add symmetry-specific utility annotations to the asu atom array.

    Sets the boolean annotations ``_is_asu``, ``_is_motif``, ``_is_sm``,
    ``_is_indexed_motif``, ``_is_unindexed_motif`` and ``_is_unsym_motif``.

    Arguments:
        asu_atom_array: atom array of the asymmetric unit; must carry the
            ``is_motif_atom_unindexed`` annotation
        sym_conf: symmetry configuration dict; its optional "is_unsym_motif"
            entry is a comma-separated list of motifs that are not symmetrized
        sm: small molecule names (str, comma separated). Whitespace around
            names and empty entries (e.g. from a trailing comma) are ignored.
    Returns:
        asu_atom_array: the same array with the annotations set
    """
    n = asu_atom_array.shape[0]
    is_motif = get_motif_features(asu_atom_array)["is_motif_atom"].astype(np.bool_)
    is_sm = np.zeros(n, dtype=bool)
    # Currently not used but will be needed for 2D cond
    is_asu = np.ones(n, dtype=bool)
    is_unsym_motif = np.zeros(n, dtype=bool)

    if sm:
        # Tolerate "NAD, FAD" and trailing commas in the user-supplied list.
        sm_names = [name.strip() for name in sm.split(",")]
        sm_names = [name for name in sm_names if name]
        if sm_names:
            is_sm = np.logical_or.reduce(
                [
                    fetch_mask_from_component(name, atom_array=asu_atom_array)
                    for name in sm_names
                ]
            )

    # assign unsym motifs (kept once instead of being copied per subunit)
    if sym_conf.get("is_unsym_motif"):
        unsym_motif_names = expand_contig_unsym_motif(
            sym_conf["is_unsym_motif"].split(",")
        )
        is_unsym_motif = get_unsym_motif_mask(asu_atom_array, unsym_motif_names)

    is_unindexed_motif = asu_atom_array.is_motif_atom_unindexed.astype(np.bool_)
    # Indexed motif: motif atoms that are neither small molecule nor unindexed
    is_indexed_motif = ~is_sm & ~is_unindexed_motif & is_motif

    asu_atom_array.set_annotation("_is_asu", is_asu)
    asu_atom_array.set_annotation("_is_motif", is_motif)
    asu_atom_array.set_annotation("_is_sm", is_sm)
    asu_atom_array.set_annotation("_is_indexed_motif", is_indexed_motif)
    asu_atom_array.set_annotation("_is_unindexed_motif", is_unindexed_motif)
    asu_atom_array.set_annotation("_is_unsym_motif", is_unsym_motif)
    return asu_atom_array
277
+
278
+
279
def _del_util_annotations(aary):
    """
    Remove the temporary symmetry helper annotations from an atom array.

    Deletes, in this order: ``_is_asu``, ``_is_motif``, ``_is_sm``,
    ``_is_indexed_motif``, ``_is_unindexed_motif``, ``_is_unsym_motif``.

    Arguments:
        aary: atom array carrying those annotations
    Returns:
        the same atom array
    """
    helper_annotations = (
        "_is_asu",  # currently not used; will be needed for 2D cond
        "_is_motif",
        "_is_sm",
        "_is_indexed_motif",
        "_is_unindexed_motif",
        "_is_unsym_motif",
    )
    for annotation_name in helper_annotations:
        aary.del_annotation(annotation_name)
    return aary
292
+
293
+
294
+ #########################
295
+ # Symmetrization functions
296
+ #########################
297
+
298
+
299
def center_symmetric_src_atom_array(src_atom_array, protein_chain_type=6):
    """
    Translate the source atom array so that its protein atoms are centered at the origin.

    The centroid (unweighted mean of coordinates) is computed only over atoms
    whose ``chain_type`` equals ``protein_chain_type``. It is then subtracted
    from every atom.

    Arguments:
        src_atom_array: atom array of the source; coordinates are shifted in place
        protein_chain_type: ``chain_type`` value selecting the atoms used for the
            centroid. Defaults to 6, the value the original code used for
            protein (presumably the L-polypeptide chain type; confirm against
            the ChainType enum).
    Returns:
        src_atom_array: the same atom array, centered
    Raises:
        ValueError: if no atom has the selected chain type. The mean would be
            NaN and would silently turn every coordinate into NaN.
    """
    selection = src_atom_array.chain_type == protein_chain_type
    if not np.any(selection):
        raise ValueError(
            f"Cannot center source atom array: no atoms with chain_type == {protein_chain_type}"
        )
    centroid = np.mean(src_atom_array[selection].coord, axis=0)
    src_atom_array.coord -= centroid
    return src_atom_array
314
+
315
+
316
def apply_symmetry_to_xyz_atomwise(X_L, sym_feats, partial_diffusion=False):
    """
    Rebuild every symmetric subunit from its asymmetric unit (ASU).

    For each symmetry entity other than ``FIXED_ENTITY_ID``, the coordinates of
    that entity's ASU atoms (``is_sym_asu``) are mapped through each subunit's
    transform ``(R, t)`` as ``x @ R + t``. The result is written to the atoms of
    that subunit (same entity, matching ``sym_transform_id``). Atoms of the
    fixed-motif entity keep their (possibly re-centered) input coordinates.

    Unless ``partial_diffusion`` is set, the non-fixed atoms are first
    re-centered on their centroid to remove drift. The input tensor is not
    modified; a new tensor is returned.

    Arguments:
        X_L: [B, L, 3] xyz coordinates
        sym_feats: dict with per-atom "sym_entity_id", "sym_transform_id" and
            "is_sym_asu" (bool) tensors of length L, plus "sym_transform", a
            mapping {transform id (str or int): (R, t)} where R is [3, 3] and t
            is added after the rotation
        partial_diffusion: if True, skip the centroid (COM) correction
    Returns:
        sym_X_L: [B, L, 3] xyz coordinates with symmetry applied
    """
    sym_entity_id = sym_feats["sym_entity_id"]
    sym_transform_id = sym_feats["sym_transform_id"]
    is_sym_asu = sym_feats["is_sym_asu"]
    # Atoms that are not part of the fixed motif
    free_mask = ~(sym_entity_id == FIXED_ENTITY_ID)
    # {str|int id: (R, t)} -> {int id: (R, t)}, without the fixed-motif transform
    sym_transforms = {
        int(k): v
        for k, v in sym_feats["sym_transform"].items()
        if int(k) != FIXED_TRANSFORM_ID
    }

    # COM correction (in case there is drift). Performed on a copy so the
    # caller's tensor is not modified in place.
    if not partial_diffusion:
        X_L = X_L.clone()
        free_xyz = X_L[:, free_mask, :]
        X_L[:, free_mask, :] = free_xyz - free_xyz.mean(dim=1, keepdim=True)
    sym_X_L = X_L.clone()

    # Loop through each symmetry entity id, applying each subunit's transform to
    # that entity's ASU. Reads come from X_L and writes go to sym_X_L, so a
    # subunit is never built from an already-overwritten ASU.
    unique_entity_id = torch.unique(sym_entity_id)
    unique_entity_id = unique_entity_id[unique_entity_id != FIXED_ENTITY_ID]
    for entity_id in unique_entity_id:
        entity_id_mask = sym_entity_id == entity_id  # [L]
        # ASU atoms of this entity only
        entity_asu_mask = is_sym_asu & entity_id_mask
        if entity_asu_mask.sum() == 0:
            continue
        asu_xyz = X_L[:, entity_asu_mask, :]  # [B, Lasu, 3]
        for target_id in torch.unique(sym_transform_id[entity_id_mask]).tolist():
            # Atoms of this entity that belong to the subunit with this transform
            this_subunit = entity_id_mask & (sym_transform_id == target_id)
            rotation = sym_transforms[target_id][0].to(asu_xyz.dtype)
            translation = sym_transforms[target_id][1].to(asu_xyz.dtype)
            sym_X_L[:, this_subunit, :] = (
                torch.einsum("blc,cd->bld", asu_xyz, rotation) + translation
            )

    _log_symmetry_distances(sym_X_L, sym_entity_id)
    return sym_X_L


def _log_symmetry_distances(sym_X_L, sym_entity_id, min_atoms=100, max_sample=50):
    """
    Log debugging distances between the first two symmetry entities (batch item 0).

    Only runs when the structure has more than ``min_atoms`` atoms. Atoms are
    grouped by the first two ids from ``torch.unique(sym_entity_id)``, which can
    include the fixed-motif entity, despite the "chain" wording of the messages.
    Only the first ``max_sample`` atoms of each group enter the minimum-distance
    estimate. ``.item()`` forces a host sync on GPU tensors.
    """
    if sym_X_L.shape[1] <= min_atoms:  # Only for large structures
        return
    unique_entities = torch.unique(sym_entity_id)
    if len(unique_entities) < 2:
        return

    entity_0_mask = sym_entity_id == unique_entities[0]
    entity_1_mask = sym_entity_id == unique_entities[1]
    if entity_0_mask.sum() == 0 or entity_1_mask.sum() == 0:
        return

    entity_0_atoms = sym_X_L[0, entity_0_mask, :]
    entity_1_atoms = sym_X_L[0, entity_1_mask, :]

    # Sample a subset to bound the size of the pairwise distance matrix
    min_distance = (
        torch.cdist(entity_0_atoms[:max_sample, :], entity_1_atoms[:max_sample, :])
        .min()
        .item()
    )
    ranked_logger.info(
        f"Min inter-chain distance after symmetry: {min_distance:.2f} Å"
    )

    # Also log the distance between the centers of the two groups
    center_distance = torch.norm(
        entity_0_atoms.mean(dim=0) - entity_1_atoms.mean(dim=0)
    ).item()
    ranked_logger.info(
        f"Distance between chain centers: {center_distance:.2f} Å"
    )