rc-foundry 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180) hide show
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
@@ -0,0 +1,63 @@
1
+ import numpy as np
2
+
3
+ from foundry.utils.components import fetch_mask_from_idx
4
+
5
+
6
def expand_contig_to_resid_from_string(contig_string):
    """
    Expand a single-chain residue range into individual residue names.

    The first character is taken as the chain id; the remainder is split on
    "-" and its first two pieces are read as the inclusive start/end residue
    numbers.

    Arguments:
        contig_string: string of the form "X1-5", e.g. "A3-5"
    Returns:
        list of residue names, e.g. ["A3", "A4", "A5"]; empty if end < start
    """
    chain_id, span = contig_string[0], contig_string[1:]
    bounds = span.split("-")
    first, last = int(bounds[0]), int(bounds[1])
    return [chain_id + str(resnum) for resnum in range(first, last + 1)]
19
+
20
+
21
def expand_contig_unsym_motif(unsym_motif_names):
    """
    Expand range-style motif names into individual residue names.

    Names containing "-" (e.g. "A3-5") are expanded with
    ``expand_contig_to_resid_from_string``; every other name is kept as-is.

    Arguments:
        unsym_motif_names: list of unsym motif names
    Returns:
        a new list: the names without "-" (original order), followed by the
        expansions of the ranged names (original order)
    """
    ranged = [name for name in unsym_motif_names if "-" in name]
    expanded = [
        resid for name in ranged for resid in expand_contig_to_resid_from_string(name)
    ]
    return [name for name in unsym_motif_names if "-" not in name] + expanded
38
+
39
+
40
def get_unsym_motif_mask(atom_array, unsym_motif_names):
    """
    Get a per-atom boolean mask of the unsym motif atoms.

    For each name in ``unsym_motif_names``:
      * atoms whose ``res_name`` equals the name are selected;
      * if the atom array has a ``src_component`` annotation and some atom's
        ``src_component`` equals the name, those atoms are selected too;
      * otherwise, names shaped like "<letter><digits>" (e.g. "A12") are
        resolved through ``fetch_mask_from_idx`` and OR-ed in.

    Arguments:
        atom_array: atom array (needs ``len``, ``res_name`` and
            ``get_annotation_categories``)
        unsym_motif_names: iterable of unsym motif names
    Returns:
        np.ndarray of bool with one entry per atom
    """
    is_unsym_motif = np.zeros(len(atom_array), dtype=bool)
    # The set of annotations does not change inside the loop; look it up once.
    has_src_component = "src_component" in atom_array.get_annotation_categories()

    for name in unsym_motif_names:
        is_unsym_motif = np.logical_or(is_unsym_motif, atom_array.res_name == name)

        if has_src_component:
            # Compute the comparison once and reuse it for the membership test.
            src_match = atom_array.src_component == name
            if np.any(src_match):
                is_unsym_motif = np.logical_or(is_unsym_motif, src_match)
                continue

        # Slice with name[:1] (not name[0]) so an empty name cannot raise IndexError.
        if name[:1].isalpha() and name[1:].isdigit():
            residue_mask = fetch_mask_from_idx(name, atom_array=atom_array)
            is_unsym_motif = np.logical_or(is_unsym_motif, residue_mask)
    return is_unsym_motif
@@ -0,0 +1,355 @@
1
+ import numpy as np
2
+ import torch
3
+
4
+
5
def get_symmetry_frames_from_symmetry_id(symmetry_id):
    """
    Get symmetry frames from a symmetry id.

    Arguments:
        symmetry_id: either a string such as "C3", "D2" or "input_defined"
            (case-insensitive prefix), or a dict whose "id" key holds such a
            string; for "input_defined" the dict must also carry "symmetry_file".
    Returns:
        frames: list of (rotation_matrix, translation_vector) tuples
    Raises:
        ValueError: if the id is missing/unsupported, or the C/D order is not
            an integer >= 1.
    """
    # Get frames from symmetry id
    sym_conf = {}
    if isinstance(symmetry_id, dict):
        sym_conf = symmetry_id
        symmetry_id = symmetry_id.get("id")

    # A missing "id" (None) or a non-string id would otherwise fail later with
    # an opaque AttributeError on .lower().
    if not isinstance(symmetry_id, str) or not symmetry_id:
        raise ValueError(f"Symmetry id {symmetry_id} not supported")

    sym_lower = symmetry_id.lower()
    if sym_lower[0] in ("c", "d"):
        try:
            order = int(symmetry_id[1:])
        except ValueError as err:
            raise ValueError(
                f"Invalid symmetry order in symmetry id {symmetry_id}"
            ) from err
        # Order 0 or negative would silently yield an empty frame list.
        if order < 1:
            raise ValueError(
                f"Symmetry order must be >= 1, got {order} in symmetry id {symmetry_id}"
            )
        if sym_lower[0] == "c":
            frames = get_cyclic_frames(order)
        else:
            frames = get_dihedral_frames(order)
    elif sym_lower == "input_defined":
        assert (
            "symmetry_file" in sym_conf
        ), "symmetry_file is required for input_defined symmetry"
        frames = get_frames_from_file(sym_conf.get("symmetry_file"))
    else:
        raise ValueError(f"Symmetry id {symmetry_id} not supported")

    # Check that the frames are valid rotation matrices
    for R, _ in frames:
        assert is_valid_rotation_matrix(R), f"Frame {R} is not a valid rotation matrix"

    return frames
39
+
40
+
41
def get_symmetry_frames_from_atom_array(src_atom_array, input_frames):
    """
    Get symmetry frames from an atom array. Adapted from code from FD

    The frames are inferred from the structure itself. After filtering to
    ``chain_type == 6``, one copy of a reference entity is superimposed
    (Kabsch, via ``_align``) onto every copy of that entity. The resulting
    rotations are checked against ``input_frames``.

    Arguments:
        src_atom_array: atom array with coordinates and chain/residue information
        input_frames: list of (rotation_matrix, translation_vector) tuples
    Returns:
        computed_frames: list of (rotation_matrix, translation_vector) tuples (updated)
            Only rotations are inferred; every translation is np.array([0, 0, 0]).
    """
    # import within the function to avoid circular import
    from rfd3.inference.symmetry.checks import (
        check_input_frames_match_symmetry_frames,
        check_max_rmsds,
        check_max_transforms,
        check_min_atoms_to_align,
        check_valid_multiplicity,
        check_valid_subunit_size,
    )

    # remove non-protein elements
    # NOTE(review): assumes chain_type == 6 encodes protein chains; confirm against the chain-type enum.
    src_atom_array = src_atom_array[src_atom_array.chain_type == 6]

    # get entities and ids from the src atom array
    pn_unit_ent = src_atom_array.get_annotation("pn_unit_entity")
    pn_unit_id = src_atom_array.get_annotation("pn_unit_iid")
    unique_entities = np.unique(pn_unit_ent)
    # entity -> sorted unique pn_unit ids (the copies of that entity)
    nids_by_entity = {
        i: np.unique(pn_unit_id[pn_unit_ent == i]) for i in unique_entities
    }

    # get coordinates
    coords = src_atom_array.coord

    # get/check multiplicities of subunits
    check_valid_multiplicity(nids_by_entity)

    # Symmetry order is taken as the smallest copy count of any entity;
    # n_per_asu is how many copies of each entity belong to one asymmetric unit.
    multiplicity = min([len(i) for i in nids_by_entity.values()])
    n_per_asu = {i: len(j) // multiplicity for i, j in nids_by_entity.items()}

    # check that the subunits in the input are of the same size
    check_valid_subunit_size(nids_by_entity, pn_unit_id)

    # align the largest set of entities
    # Candidates are entities with exactly one copy per ASU; the reference is the
    # candidate whose first copy has the most atoms.
    # NOTE(review): if no entity has n_per_asu == 1, natm_per_unique is empty and
    # max() raises ValueError before check_min_atoms_to_align runs.
    natm_per_unique = {
        i: (pn_unit_id == nids_by_entity[i][0]).sum()
        for i in unique_entities
        if n_per_asu[i] == 1
    }
    reference_entity = max(natm_per_unique, key=natm_per_unique.get)

    # check that we have enough atoms to align
    check_min_atoms_to_align(natm_per_unique, reference_entity)

    # chains for the alignment (will generate complete set of frames)
    chains_to_consider = nids_by_entity[reference_entity]
    reference_molecule = nids_by_entity[reference_entity][0]

    # check that we are not exceeding the max number of transforms
    check_max_transforms(chains_to_consider)

    # align reference molecule to all others
    # _align(fixed=copy i, moving=reference) returns (mean_moving, R, mean_fixed), so
    # each R rotates the reference copy onto copy i (identity for the reference itself).
    xforms = {
        i: _align(coords[pn_unit_id == i], coords[pn_unit_id == reference_molecule])
        for i in chains_to_consider
    }
    rmsds = {
        i: _rms(coords[pn_unit_id == i], coords[pn_unit_id == reference_molecule], *j)
        for i, j in xforms.items()
    }

    # check that there is not too big of a RMSD difference between subunits
    check_max_rmsds(rmsds)

    # check that the frames are valid rotation matrices
    Rs = [R for _, R, _ in xforms.values()]
    for R in Rs:
        assert is_valid_rotation_matrix(R), f"Computed frame {R} is not a valid rotation matrix"
    # The alignment translations are discarded and only rotations are kept.
    # NOTE(review): presumably the input is expected to be centred on the symmetry axes; confirm.
    computed_frames = [(R, np.array([0, 0, 0])) for R in Rs]

    # check that the computed frames match the input frames
    check_input_frames_match_symmetry_frames(computed_frames, input_frames)

    return computed_frames
126
+
127
+
128
+ def _align(X_fixed, X_moving):
129
+ """
130
+ Align two sets of coordinates using Kabsch algorithm.
131
+ Arguments:
132
+ X_fixed: fixed coordinates
133
+ X_moving: moving coordinates
134
+ Returns:
135
+ u_X_moving: mean of the moving coordinates
136
+ R: rotation matrix
137
+ u_X_fixed: mean of the fixed coordinates
138
+ """
139
+ is_torch = isinstance(X_fixed, torch.Tensor)
140
+
141
+ def _mean_along_dim(X, dim):
142
+ if is_torch:
143
+ return X.mean(dim=dim)
144
+ else:
145
+ return X.mean(axis=dim)
146
+
147
+ assert X_fixed.shape == X_moving.shape
148
+
149
+ if X_fixed.ndim == 2:
150
+ X_fixed = X_fixed[None, ...]
151
+ X_moving = X_moving[None, ...]
152
+ B = X_fixed.shape[0]
153
+
154
+ if is_torch:
155
+ mask = (~torch.isnan(X_fixed) & ~torch.isnan(X_moving)).all(dim=-1).all(dim=0)
156
+ else:
157
+ mask = (~np.isnan(X_fixed) & ~np.isnan(X_moving)).all(axis=-1).all(axis=0)
158
+
159
+ X_fixed = X_fixed[:, mask]
160
+ X_moving = X_moving[:, mask]
161
+
162
+ u_X_fixed = _mean_along_dim(X_fixed, dim=-2)
163
+ u_X_moving = _mean_along_dim(X_moving, dim=-2)
164
+
165
+ X_fixed_centered = X_fixed - u_X_fixed[..., None, :]
166
+ X_moving_centered = X_moving - u_X_moving[..., None, :]
167
+
168
+ if is_torch:
169
+ C = torch.einsum("...ji,...jk->...ik", X_fixed_centered, X_moving_centered)
170
+ U, S, V = torch.linalg.svd(C, full_matrices=False)
171
+ else:
172
+ C = np.einsum("...ji,...jk->...ik", X_fixed_centered, X_moving_centered)
173
+ U, S, V = np.linalg.svd(C, full_matrices=False)
174
+
175
+ R = U @ V
176
+ if is_torch:
177
+ F = torch.eye(3, 3, device=R.device).expand(B, 3, 3).clone()
178
+ F[..., -1, -1] = torch.sign(torch.linalg.det(R))
179
+ else:
180
+ F = np.broadcast_to(np.eye(3, 3), (B, 3, 3)).copy()
181
+ F[..., -1, -1] = np.sign(np.linalg.det(R))
182
+ R = U @ F @ V
183
+
184
+ if R.shape[0] == 1:
185
+ return u_X_moving[0], R[0], u_X_fixed[0]
186
+
187
+ return u_X_moving, R, u_X_fixed
188
+
189
+
190
+ def _rms(X_fixed, X_moving, t_pre, R, t_post):
191
+ """
192
+ Calculate the RMSD between two sets of coordinates.
193
+ Arguments:
194
+ X_fixed: fixed coordinates
195
+ X_moving: moving coordinates
196
+ t_pre: pre-rotation translation
197
+ R: rotation matrix
198
+ t_post: post-rotation translation
199
+ Returns:
200
+ rms: RMSD
201
+ """
202
+ mask = (~np.isnan(X_fixed) & ~np.isnan(X_moving)).all(axis=-1)
203
+ X_fixed = X_fixed[mask]
204
+ X_moving = X_moving[mask]
205
+
206
+ X_moving_aln = np.einsum("ij,bj->bi", R, (X_moving - t_pre[None])) + t_post[None]
207
+ rms = np.sqrt(np.sum(np.square(X_moving_aln - X_fixed)) / X_moving_aln.shape[-2])
208
+ return rms
209
+
210
+
211
def is_valid_rotation_matrix(R, atol=1e-6):
    """
    Check whether a matrix is a proper 3x3 rotation matrix.

    A proper rotation is orthogonal (R @ R.T == I) and has determinant +1.
    The determinant test is needed because reflections are also orthogonal
    but have determinant -1.

    Arguments:
        R: rotation matrix (array-like, expected shape (3, 3))
        atol: absolute tolerance for both checks (default 1e-6, as before)
    Returns:
        bool: True if R is a valid rotation matrix, False otherwise
    """
    R = np.asarray(R)
    if R.shape != (3, 3):
        return False
    if not np.allclose(R @ R.T, np.eye(3), atol=atol):
        return False
    return bool(np.isclose(np.linalg.det(R), 1.0, atol=atol))
221
+
222
+
223
def get_cyclic_frames(order):
    """
    Build the frames of the cyclic group C_order about the z-axis.

    Arguments:
        order: number of subunits
    Returns:
        frames: list of ``order`` (rotation_matrix, translation_vector) tuples.
            Frame i rotates by 2*pi*i/order about z; the translation is always
            the integer vector np.array([0, 0, 0]). Empty if order <= 0.
    """

    def _frame(step):
        theta = 2 * np.pi * step / order
        cos_t, sin_t = np.cos(theta), np.sin(theta)
        rotation = np.array(
            [
                [cos_t, -sin_t, 0],
                [sin_t, cos_t, 0],
                [0, 0, 1],
            ]
        )
        return rotation, np.array([0, 0, 0])

    return [_frame(step) for step in range(order)]
245
+
246
+
247
def get_dihedral_frames(order):
    """
    Build the 2 * order frames of the dihedral group D_order.

    The group consists of the ``order`` rotations about z by 2*pi*i/order,
    plus ``order`` 180-degree rotations ("flips") about axes in the xy-plane.
    Frames are interleaved as [R_0, R_0 @ F, R_1, R_1 @ F, ...]. Here F is the
    180-degree rotation about the in-plane axis at angle pi/order. R_i @ F is
    then the flip about the axis at pi/order + pi*i/order, so the flips cover
    ``order`` distinct axes.

    Arguments:
        order: order of the rotational subgroup, i.e. half the number of
            subunits (each rotation contributes two frames)
    Returns:
        frames: list of 2 * order (rotation_matrix, translation_vector) tuples;
            translations are the integer vector np.array([0, 0, 0])
    Raises:
        ValueError: if order < 1
    """
    if order < 1:
        raise ValueError(f"Dihedral order must be >= 1, got {order}")

    # Use a single flip axis for all rotations. Deriving the axis from each
    # rotation angle and then composing with that same rotation double-counts
    # the rotation. For orders divisible by 3, that made all flips coincide.
    phi = np.pi / order
    u = np.array([np.cos(phi), np.sin(phi), 0])
    flip = -np.eye(3) + 2 * np.outer(u, u)

    frames = []
    for i in range(order):
        angle = 2 * np.pi * i / order
        R = np.array(
            [
                [np.cos(angle), -np.sin(angle), 0],
                [np.sin(angle), np.cos(angle), 0],
                [0, 0, 1],
            ]
        )

        # add both frames for the dihedral
        frames.append((R, np.array([0, 0, 0])))
        frames.append((R @ flip, np.array([0, 0, 0])))

    return frames
278
+
279
+
280
def get_frames_from_file(file_path):
    """
    Load symmetry frames from a file (placeholder).

    Arguments:
        file_path: path to the symmetry definition file (currently unused)
    Raises:
        NotImplementedError: always; input-defined symmetry is not supported yet.
    """
    message = "Input defined symmetry not implemented"
    raise NotImplementedError(message)
282
+
283
+
284
+ ###################################
285
+ # Kinematics
286
+ ###################################
287
+
288
+
289
+ # fd - two routines that convert between:
290
+ # a) a "virtual frame" consisting of three atoms; and
291
+ # b) a translation and rotation
292
+ # uses Gram-Schmidt orthogonalization, handles stacked/unstacked
293
+ # supports np and torch inputs
294
def RTs_to_framecoords(Rs, ts, sig=1.0):
    """
    Convert rotation(s) and translation(s) into three "virtual frame" points.

    Ori is the translation itself. X and Y are Ori displaced by ``sig`` along
    the normalized first and second rows of the rotation matrix.

    Arguments:
        Rs: rotation matrix / matrices, shape (..., 3, 3); numpy array or torch tensor
        ts: translation(s), broadcastable against a row of Rs; numpy array or torch tensor
        sig: distance of X and Y from Ori
    Returns:
        (Ori, X, Y) as torch tensors
    """
    # Convert each input on its own so mixed numpy/torch inputs work
    # (previously only Rs was checked, and ts was converted unconditionally).
    if isinstance(Rs, np.ndarray):
        Rs = torch.from_numpy(Rs)
    if isinstance(ts, np.ndarray):
        ts = torch.from_numpy(ts)
    Ori = ts
    # The 1e-6 guards against division by zero for degenerate rows.
    X = Ori + sig * Rs[..., 0, :] / (
        torch.norm(Rs[..., 0, :], dim=-1, keepdim=True) + 1e-6
    )
    Y = Ori + sig * Rs[..., 1, :] / (
        torch.norm(Rs[..., 1, :], dim=-1, keepdim=True) + 1e-6
    )
    return Ori, X, Y
306
+
307
+
308
+ # framecoords_to_RTs is used in the loss and expects torch inputs
309
+ # (and must support backpropagation)
310
def framecoords_to_RTs(Ori, X, Y, eps=1e-6):
    """
    Recover rotation(s) and translation(s) from three "virtual frame" points.

    Gram-Schmidt is applied to (X - Ori) and (Y - Ori) to get the first two
    rows of R; the third row is their cross product. ``eps`` is both added to
    the norms and used as a small fixed offset on the unnormalized vectors,
    keeping the result finite (and differentiable) for degenerate inputs.
    Only torch operations are used, so gradients flow through.

    Arguments:
        Ori: frame origin(s), torch tensor (..., 3)
        X: first frame point(s), torch tensor (..., 3)
        Y: second frame point(s), torch tensor (..., 3)
        eps: stabilizing epsilon
    Returns:
        R: rotation matrices, shape (..., 3, 3) (rows are the frame axes)
        T: translations (Ori itself)
    """

    def _nudged_unit(vec, nudge):
        offset = torch.tensor(nudge, device=vec.device)
        return (vec + offset) / (torch.linalg.norm(vec, axis=-1, keepdims=True) + eps)

    e1 = _nudged_unit(X - Ori, [eps, 0, 0])

    y_rel = Y - Ori
    y_perp = y_rel - torch.sum(y_rel * e1, axis=-1, keepdims=True) * e1
    e2 = _nudged_unit(y_perp, [0, eps, 0])

    e3 = torch.cross(e1, e2, dim=-1)

    rotation = torch.stack([e1, e2, e3], axis=-2)  # shape (..., 3, 3)
    return rotation, Ori
330
+
331
+
332
def pack_vector(v: np.ndarray) -> np.ndarray:
    """
    Wrap a 3-vector in a one-record structured array.

    Arguments:
        v: 1-D array of shape (3,) and arbitrary dtype
    Returns:
        structured array of shape (1,) with a single field "x" of shape (3,)
        and v's dtype, holding a copy of v
    """
    record = np.dtype([("x", v.dtype, (3,))])
    packed = np.zeros(shape=(1,), dtype=record)
    packed["x"][0, ...] = v
    return packed
341
+
342
+
343
def unpack_vector(a: np.ndarray) -> np.ndarray:
    """
    Read the "x" field back out of a structured array built by ``pack_vector``.

    Arguments:
        a: structured array of shape (1,) with an "x" field
    Returns:
        the "x" field as an array of shape (1, 3); index ``[0]`` to get the
        original 3-vector
    """
    field_name = "x"
    return a[field_name]
349
+
350
+
351
def decompose_symmetry_frame(frame):
    """
    Turn an (R, T) symmetry frame into three packed frame points.

    Arguments:
        frame: (rotation, translation) tuple; presumably numpy arrays of
            shapes (3, 3) and (3,), since each point goes through
            ``pack_vector`` (TODO confirm with callers)
    Returns:
        (Ori, X, Y): each a one-record structured array from ``pack_vector``
    """
    rotation, translation = frame
    origin, x_point, y_point = RTs_to_framecoords(rotation, translation)
    return tuple(pack_vector(point.numpy()) for point in (origin, x_point, y_point))