rc-foundry 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180) hide show
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
foundry/utils/rigid.py ADDED
@@ -0,0 +1,1460 @@
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Any, Optional, Tuple
17
+
18
+ import numpy as np
19
+ import torch
20
+
21
+
22
def rot_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """
    Multiplies two batches of 3x3 rotation matrices.

    Every output entry is written as explicit elementwise multiply-adds
    (rather than a matmul call) to avoid AMP downcasting.

    Args:
        a: [*, 3, 3] left multiplicand
        b: [*, 3, 3] right multiplicand
    Returns:
        [*, 3, 3] product ab
    """

    def _entry(i: int, j: int) -> torch.Tensor:
        # Dot product of row i of `a` with column j of `b`.
        return (
            a[..., i, 0] * b[..., 0, j]
            + a[..., i, 1] * b[..., 1, j]
            + a[..., i, 2] * b[..., 2, j]
        )

    # Assemble each output row along the last axis, then stack the rows.
    rows = [torch.stack([_entry(i, j) for j in range(3)], dim=-1) for i in range(3)]
    return torch.stack(rows, dim=-2)
77
+
78
+
79
def rot_vec_mul(r: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """
    Applies rotation matrices to vectors.

    The matrix-vector product is written out by hand to avoid AMP
    downcasting.

    Args:
        r: [*, 3, 3] rotation matrices
        t: [*, 3] coordinate tensors
    Returns:
        [*, 3] rotated coordinates
    """
    vx, vy, vz = t[..., 0], t[..., 1], t[..., 2]
    # Output component i is row i of `r` dotted with the input vector.
    rotated = [
        r[..., i, 0] * vx + r[..., i, 1] * vy + r[..., i, 2] * vz for i in range(3)
    ]
    return torch.stack(rotated, dim=-1)
101
+
102
+
103
def identity_rot_mats(
    batch_dims: Tuple[int],
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    requires_grad: bool = True,
) -> torch.Tensor:
    """
    Creates a batch of identity rotation matrices.

    Args:
        batch_dims: The leading (batch) dimensions of the result
        dtype: The torch dtype of the result
        device: The torch device of the result
        requires_grad: Whether the underlying 3x3 identity requires grad
    Returns:
        A [*batch_dims, 3, 3] tensor that is an expanded view of a single
        3x3 identity matrix
    """
    n_batch = len(batch_dims)
    eye = torch.eye(3, dtype=dtype, device=device, requires_grad=requires_grad)

    # Prepend one singleton axis per batch dimension, then broadcast them out.
    broadcastable = eye.view((1,) * n_batch + (3, 3))
    return broadcastable.expand(tuple(batch_dims) + (-1, -1))
114
+
115
+
116
def identity_trans(
    batch_dims: Tuple[int],
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    requires_grad: bool = True,
) -> torch.Tensor:
    """
    Creates a batch of identity (all-zero) translations.

    Args:
        batch_dims: The leading (batch) dimensions of the result
        dtype: The torch dtype of the result
        device: The torch device of the result
        requires_grad: Whether the returned tensor requires grad
    Returns:
        A [*batch_dims, 3] tensor of zeros
    """
    out_shape = (*batch_dims, 3)
    return torch.zeros(
        out_shape, dtype=dtype, device=device, requires_grad=requires_grad
    )
126
+
127
+
128
def identity_quats(
    batch_dims: Tuple[int],
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    requires_grad: bool = True,
) -> torch.Tensor:
    """
    Creates a batch of identity quaternions (1, 0, 0, 0).

    Args:
        batch_dims: The leading (batch) dimensions of the result
        dtype: The torch dtype of the result
        device: The torch device of the result
        requires_grad: Whether the returned tensor requires grad
    Returns:
        A [*batch_dims, 4] tensor whose first component is 1 and whose
        remaining components are 0
    """
    identity = torch.zeros(
        (*batch_dims, 4), dtype=dtype, device=device, requires_grad=requires_grad
    )

    # The tensor may already require grad, so the in-place write of the
    # real component is done with autograd disabled.
    with torch.no_grad():
        identity[..., 0] = 1

    return identity
142
+
143
+
144
# Names of the four quaternion components, in storage order.
_quat_elements = ["a", "b", "c", "d"]
# Every ordered two-letter product ("aa", "ab", ..., "dd"), row-major.
_qtr_keys = [first + second for first in _quat_elements for second in _quat_elements]
# Flat row-major position of each product key within the 4x4 product grid.
_qtr_ind_dict = {key: pos for pos, key in enumerate(_qtr_keys)}


def _to_mat(pairs):
    """
    Builds a 4x4 coefficient grid from (product key, coefficient) pairs.

    Args:
        pairs: An iterable of (key, value) tuples, where key is one of the
            two-letter products in _qtr_keys
    Returns:
        A [4, 4] numpy array holding each value at the grid position of its
        key, and zero elsewhere
    """
    coeffs = np.zeros((4, 4))
    for key, value in pairs:
        row, col = divmod(_qtr_ind_dict[key], 4)
        coeffs[row][col] = value

    return coeffs


# _QTR_MAT[i, j, r, c] is the coefficient of the quaternion product q_i * q_j
# in entry (r, c) of the resulting rotation matrix.
_QTR_MAT = np.zeros((4, 4, 3, 3))
for (_row, _col), _terms in {
    (0, 0): [("aa", 1), ("bb", 1), ("cc", -1), ("dd", -1)],
    (0, 1): [("bc", 2), ("ad", -2)],
    (0, 2): [("bd", 2), ("ac", 2)],
    (1, 0): [("bc", 2), ("ad", 2)],
    (1, 1): [("aa", 1), ("bb", -1), ("cc", 1), ("dd", -1)],
    (1, 2): [("cd", 2), ("ab", -2)],
    (2, 0): [("bd", 2), ("ac", -2)],
    (2, 1): [("cd", 2), ("ab", 2)],
    (2, 2): [("aa", 1), ("bb", -1), ("cc", -1), ("dd", 1)],
}.items():
    _QTR_MAT[..., _row, _col] = _to_mat(_terms)
# Do not leave the loop variables behind as module attributes.
del _row, _col, _terms


def quat_to_rot(quat: torch.Tensor) -> torch.Tensor:
    """
    Converts a quaternion to a rotation matrix.

    Args:
        quat: [*, 4] quaternions
    Returns:
        [*, 3, 3] rotation matrices
    """
    # [*, 4, 4] outer product of each quaternion with itself
    outer = quat.unsqueeze(-1) * quat.unsqueeze(-2)

    # [4, 4, 3, 3] coefficient table on the input's dtype and device
    table = quat.new_tensor(_QTR_MAT, requires_grad=False)

    # [1, ..., 1, 4, 4, 3, 3] so the table broadcasts over the batch dims
    n_batch = outer.dim() - 2
    table = table.view((1,) * n_batch + table.shape)

    # [*, 4, 4, 3, 3] products weighted by their coefficients
    weighted = outer[..., None, None] * table

    # [*, 3, 3] after summing over both product axes
    return torch.sum(weighted, dim=(-3, -4))
192
+
193
+
194
def rot_to_quat(
    rot: torch.Tensor,
):
    """
    Converts rotation matrices to quaternions.

    A symmetric 4x4 matrix is assembled from the rotation entries, and the
    eigenvector belonging to its largest eigenvalue (the last one returned
    by torch.linalg.eigh) is returned as the quaternion.

    Args:
        rot: [*, 3, 3] rotation matrices
    Returns:
        [*, 4] quaternions
    Raises:
        ValueError: If the last two dimensions of rot are not (3, 3)
    """
    if rot.shape[-2:] != (3, 3):
        raise ValueError("Input rotation is incorrectly shaped")

    xx, xy, xz = rot[..., 0, 0], rot[..., 0, 1], rot[..., 0, 2]
    yx, yy, yz = rot[..., 1, 0], rot[..., 1, 1], rot[..., 1, 2]
    zx, zy, zz = rot[..., 2, 0], rot[..., 2, 1], rot[..., 2, 2]

    rows = [
        [xx + yy + zz, zy - yz, xz - zx, yx - xy],
        [zy - yz, xx - yy - zz, xy + yx, xz + zx],
        [xz - zx, xy + yx, yy - xx - zz, yz + zy],
        [yx - xy, xz + zx, yz + zy, zz - xx - yy],
    ]

    # [*, 4, 4]
    k = (1.0 / 3.0) * torch.stack(
        [torch.stack(row, dim=-1) for row in rows], dim=-2
    )

    # eigh returns eigenvalues in ascending order, so the last column is the
    # eigenvector of the largest eigenvalue.
    _, eigvecs = torch.linalg.eigh(k)
    return eigvecs[..., -1]
234
+
235
+
236
# _QUAT_MULTIPLY[i, j, k] is the coefficient of q1_i * q2_j in component k of
# the quaternion product q1 * q2.
_QUAT_MULTIPLY = np.zeros((4, 4, 4))
for _component, _coeffs in enumerate(
    [
        [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, -1]],
        [[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1], [0, 0, -1, 0]],
        [[0, 0, 1, 0], [0, 0, 0, -1], [1, 0, 0, 0], [0, 1, 0, 0]],
        [[0, 0, 0, 1], [0, 0, 1, 0], [0, -1, 0, 0], [1, 0, 0, 0]],
    ]
):
    _QUAT_MULTIPLY[:, :, _component] = _coeffs
# Do not leave the loop variables behind as module attributes.
del _component, _coeffs

# Same table restricted to the three vector components of the right operand,
# i.e. the right operand is treated as having no real part.
_QUAT_MULTIPLY_BY_VEC = _QUAT_MULTIPLY[:, 1:, :]


def quat_multiply(quat1, quat2):
    """
    Multiply a quaternion by another quaternion.

    Args:
        quat1: [*, 4] left quaternions
        quat2: [*, 4] right quaternions
    Returns:
        [*, 4] quaternion products quat1 * quat2
    """
    table = quat1.new_tensor(_QUAT_MULTIPLY)
    # Add singleton batch axes so the table broadcasts against the inputs.
    table = table.view((1,) * (quat1.dim() - 1) + table.shape)
    left = quat1[..., :, None, None]
    right = quat2[..., None, :, None]
    return torch.sum(table * left * right, dim=(-3, -2))


def quat_multiply_by_vec(quat, vec):
    """
    Multiply a quaternion by a pure-vector quaternion.

    Args:
        quat: [*, 4] left quaternions
        vec: [*, 3] vector parts of the right (pure-vector) quaternions
    Returns:
        [*, 4] quaternion products
    """
    table = quat.new_tensor(_QUAT_MULTIPLY_BY_VEC)
    # Add singleton batch axes so the table broadcasts against the inputs.
    table = table.view((1,) * (quat.dim() - 1) + table.shape)
    left = quat[..., :, None, None]
    right = vec[..., None, :, None]
    return torch.sum(table * left * right, dim=(-3, -2))
265
+
266
+
267
def invert_rot_mat(rot_mat: torch.Tensor):
    """
    Inverts rotation matrices by transposing them.

    Args:
        rot_mat: [*, 3, 3] rotation matrices
    Returns:
        [*, 3, 3] tensor with the last two dimensions swapped
    """
    transposed = rot_mat.transpose(-2, -1)
    return transposed
269
+
270
+
271
def invert_quat(quat: torch.Tensor):
    """
    Inverts quaternions.

    The inverse is the conjugate (vector part negated) divided by the
    squared norm, so the input does not need to be a unit quaternion.

    Args:
        quat: [*, 4] quaternions
    Returns:
        [*, 4] inverse quaternions
    """
    # Conjugate: keep the real part, flip the sign of the vector part.
    conjugate = torch.cat([quat[..., :1], quat[..., 1:] * -1], dim=-1)
    squared_norm = torch.sum(quat**2, dim=-1, keepdim=True)
    return conjugate / squared_norm
276
+
277
+
278
+ class Rotation:
279
+ """
280
+ A 3D rotation. Depending on how the object is initialized, the
281
+ rotation is represented by either a rotation matrix or a
282
+ quaternion, though both formats are made available by helper functions.
283
+ To simplify gradient computation, the underlying format of the
284
+ rotation cannot be changed in-place. Like Rigid, the class is designed
285
+ to mimic the behavior of a torch Tensor, almost as if each Rotation
286
+ object were a tensor of rotations, in one format or another.
287
+ """
288
+
289
+ def __init__(
290
+ self,
291
+ rot_mats: Optional[torch.Tensor] = None,
292
+ quats: Optional[torch.Tensor] = None,
293
+ normalize_quats: bool = True,
294
+ ):
295
+ """
296
+ Args:
297
+ rot_mats:
298
+ A [*, 3, 3] rotation matrix tensor. Mutually exclusive with
299
+ quats
300
+ quats:
301
+ A [*, 4] quaternion. Mutually exclusive with rot_mats. If
302
+ normalize_quats is not True, must be a unit quaternion
303
+ normalize_quats:
304
+ If quats is specified, whether to normalize quats
305
+ """
306
+ if (rot_mats is None and quats is None) or (
307
+ rot_mats is not None and quats is not None
308
+ ):
309
+ raise ValueError("Exactly one input argument must be specified")
310
+
311
+ if (rot_mats is not None and rot_mats.shape[-2:] != (3, 3)) or (
312
+ quats is not None and quats.shape[-1] != 4
313
+ ):
314
+ raise ValueError("Incorrectly shaped rotation matrix or quaternion")
315
+
316
+ # Force full-precision
317
+ if quats is not None:
318
+ quats = quats.type(torch.float32)
319
+ if rot_mats is not None:
320
+ rot_mats = rot_mats.type(torch.float32)
321
+
322
+ if quats is not None and normalize_quats:
323
+ quats = quats / torch.linalg.norm(quats, dim=-1, keepdim=True)
324
+
325
+ self._rot_mats = rot_mats
326
+ self._quats = quats
327
+
328
+ @staticmethod
329
+ def identity(
330
+ shape,
331
+ dtype: Optional[torch.dtype] = None,
332
+ device: Optional[torch.device] = None,
333
+ requires_grad: bool = True,
334
+ fmt: str = "quat",
335
+ ):
336
+ """
337
+ Returns an identity Rotation.
338
+
339
+ Args:
340
+ shape:
341
+ The "shape" of the resulting Rotation object. See documentation
342
+ for the shape property
343
+ dtype:
344
+ The torch dtype for the rotation
345
+ device:
346
+ The torch device for the new rotation
347
+ requires_grad:
348
+ Whether the underlying tensors in the new rotation object
349
+ should require gradient computation
350
+ fmt:
351
+ One of "quat" or "rot_mat". Determines the underlying format
352
+ of the new object's rotation
353
+ Returns:
354
+ A new identity rotation
355
+ """
356
+ if fmt == "rot_mat":
357
+ rot_mats = identity_rot_mats(
358
+ shape,
359
+ dtype,
360
+ device,
361
+ requires_grad,
362
+ )
363
+ return Rotation(rot_mats=rot_mats, quats=None)
364
+ elif fmt == "quat":
365
+ quats = identity_quats(shape, dtype, device, requires_grad)
366
+ return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
367
+ else:
368
+ raise ValueError(f"Invalid format: f{fmt}")
369
+
370
+ # Magic methods
371
+
372
+ def __getitem__(self, index: Any):
373
+ """
374
+ Allows torch-style indexing over the virtual shape of the rotation
375
+ object. See documentation for the shape property.
376
+
377
+ Args:
378
+ index:
379
+ A torch index. E.g. (1, 3, 2), or (slice(None,))
380
+ Returns:
381
+ The indexed rotation
382
+ """
383
+ if not isinstance(index, tuple):
384
+ index = (index,)
385
+
386
+ if self._rot_mats is not None:
387
+ rot_mats = self._rot_mats[index + (slice(None), slice(None))]
388
+ return Rotation(rot_mats=rot_mats)
389
+ elif self._quats is not None:
390
+ quats = self._quats[index + (slice(None),)]
391
+ return Rotation(quats=quats, normalize_quats=False)
392
+ else:
393
+ raise ValueError("Both rotations are None")
394
+
395
+ def __setitem__(self, index: Any, new: Any):
396
+ if not isinstance(index, tuple):
397
+ index = (index,)
398
+
399
+ if self._rot_mats is not None:
400
+ self._rot_mats[index + (slice(None), slice(None))] = new.get_rot_mats()
401
+ elif self._quats is not None:
402
+ self._quats[index + (slice(None),)] = new.get_quats()
403
+ else:
404
+ raise ValueError("Both rotations are None")
405
+
406
+ def __mul__(
407
+ self,
408
+ right: torch.Tensor,
409
+ ):
410
+ """
411
+ Pointwise left multiplication of the rotation with a tensor. Can be
412
+ used to e.g. mask the Rotation.
413
+
414
+ Args:
415
+ right:
416
+ The tensor multiplicand
417
+ Returns:
418
+ The product
419
+ """
420
+ if not (isinstance(right, torch.Tensor)):
421
+ raise TypeError("The other multiplicand must be a Tensor")
422
+
423
+ if self._rot_mats is not None:
424
+ rot_mats = self._rot_mats * right[..., None, None]
425
+ return Rotation(rot_mats=rot_mats, quats=None)
426
+ elif self._quats is not None:
427
+ quats = self._quats * right[..., None]
428
+ return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
429
+ else:
430
+ raise ValueError("Both rotations are None")
431
+
432
+ def __rmul__(
433
+ self,
434
+ left: torch.Tensor,
435
+ ):
436
+ """
437
+ Reverse pointwise multiplication of the rotation with a tensor.
438
+
439
+ Args:
440
+ left:
441
+ The left multiplicand
442
+ Returns:
443
+ The product
444
+ """
445
+ return self.__mul__(left)
446
+
447
+ # Properties
448
+
449
+ @property
450
+ def shape(self) -> torch.Size:
451
+ """
452
+ Returns the virtual shape of the rotation object. This shape is
453
+ defined as the batch dimensions of the underlying rotation matrix
454
+ or quaternion. If the Rotation was initialized with a [10, 3, 3]
455
+ rotation matrix tensor, for example, the resulting shape would be
456
+ [10].
457
+
458
+ Returns:
459
+ The virtual shape of the rotation object
460
+ """
461
+ s = None
462
+ if self._quats is not None:
463
+ s = self._quats.shape[:-1]
464
+ else:
465
+ s = self._rot_mats.shape[:-2]
466
+
467
+ return s
468
+
469
+ @property
470
+ def dtype(self) -> torch.dtype:
471
+ """
472
+ Returns the dtype of the underlying rotation.
473
+
474
+ Returns:
475
+ The dtype of the underlying rotation
476
+ """
477
+ if self._rot_mats is not None:
478
+ return self._rot_mats.dtype
479
+ elif self._quats is not None:
480
+ return self._quats.dtype
481
+ else:
482
+ raise ValueError("Both rotations are None")
483
+
484
+ @property
485
+ def device(self) -> torch.device:
486
+ """
487
+ The device of the underlying rotation
488
+
489
+ Returns:
490
+ The device of the underlying rotation
491
+ """
492
+ if self._rot_mats is not None:
493
+ return self._rot_mats.device
494
+ elif self._quats is not None:
495
+ return self._quats.device
496
+ else:
497
+ raise ValueError("Both rotations are None")
498
+
499
+ @property
500
+ def requires_grad(self) -> bool:
501
+ """
502
+ Returns the requires_grad property of the underlying rotation
503
+
504
+ Returns:
505
+ The requires_grad property of the underlying tensor
506
+ """
507
+ if self._rot_mats is not None:
508
+ return self._rot_mats.requires_grad
509
+ elif self._quats is not None:
510
+ return self._quats.requires_grad
511
+ else:
512
+ raise ValueError("Both rotations are None")
513
+
514
+ def get_rot_mats(self) -> torch.Tensor:
515
+ """
516
+ Returns the underlying rotation as a rotation matrix tensor.
517
+
518
+ Returns:
519
+ The rotation as a rotation matrix tensor
520
+ """
521
+ rot_mats = self._rot_mats
522
+ if rot_mats is None:
523
+ if self._quats is None:
524
+ raise ValueError("Both rotations are None")
525
+ else:
526
+ rot_mats = quat_to_rot(self._quats)
527
+
528
+ return rot_mats
529
+
530
+ def get_quats(self) -> torch.Tensor:
531
+ """
532
+ Returns the underlying rotation as a quaternion tensor.
533
+
534
+ Depending on whether the Rotation was initialized with a
535
+ quaternion, this function may call torch.linalg.eigh.
536
+
537
+ Returns:
538
+ The rotation as a quaternion tensor.
539
+ """
540
+ quats = self._quats
541
+ if quats is None:
542
+ if self._rot_mats is None:
543
+ raise ValueError("Both rotations are None")
544
+ else:
545
+ quats = rot_to_quat(self._rot_mats)
546
+
547
+ return quats
548
+
549
+ def get_cur_rot(self) -> torch.Tensor:
550
+ """
551
+ Return the underlying rotation in its current form
552
+
553
+ Returns:
554
+ The stored rotation
555
+ """
556
+ if self._rot_mats is not None:
557
+ return self._rot_mats
558
+ elif self._quats is not None:
559
+ return self._quats
560
+ else:
561
+ raise ValueError("Both rotations are None")
562
+
563
+ def get_rotvec(self, eps=1e-4) -> torch.Tensor:
564
+ """
565
+ Return the underlying axis-angle rotation vector.
566
+
567
+ Follow's scipy's implementation:
568
+ https://github.com/scipy/scipy/blob/HEAD/scipy/spatial/transform/_rotation.pyx#L1385-L1402
569
+
570
+ Returns:
571
+ The stored rotation as a axis-angle vector.
572
+ """
573
+ quat = self.get_quats()
574
+ # w > 0 to ensure 0 <= angle <= pi
575
+ flip = (quat[..., :1] < 0).float()
576
+ quat = (-1 * quat) * flip + (1 - flip) * quat
577
+
578
+ angle = 2 * torch.atan2(torch.linalg.norm(quat[..., 1:], dim=-1), quat[..., 0])
579
+
580
+ angle2 = angle * angle
581
+ small_angle_scales = 2 + angle2 / 12 + 7 * angle2 * angle2 / 2880
582
+ large_angle_scales = angle / torch.sin(angle / 2 + eps)
583
+
584
+ small_angles = (angle <= 1e-3).float()
585
+ rot_vec_scale = (
586
+ small_angle_scales * small_angles + (1 - small_angles) * large_angle_scales
587
+ )
588
+ rot_vec = rot_vec_scale[..., None] * quat[..., 1:]
589
+ return rot_vec
590
+
591
+ # Rotation functions
592
+
593
+ def compose_q_update_vec(
594
+ self,
595
+ q_update_vec: torch.Tensor,
596
+ normalize_quats: bool = True,
597
+ update_mask: torch.Tensor = None,
598
+ ):
599
+ """
600
+ Returns a new quaternion Rotation after updating the current
601
+ object's underlying rotation with a quaternion update, formatted
602
+ as a [*, 3] tensor whose final three columns represent x, y, z such
603
+ that (1, x, y, z) is the desired (not necessarily unit) quaternion
604
+ update.
605
+
606
+ Args:
607
+ q_update_vec:
608
+ A [*, 3] quaternion update tensor
609
+ normalize_quats:
610
+ Whether to normalize the output quaternion
611
+ Returns:
612
+ An updated Rotation
613
+ """
614
+ quats = self.get_quats()
615
+ quat_update = quat_multiply_by_vec(quats, q_update_vec)
616
+ if update_mask is not None:
617
+ quat_update = quat_update * update_mask
618
+ new_quats = quats + quat_update
619
+ return Rotation(
620
+ rot_mats=None,
621
+ quats=new_quats,
622
+ normalize_quats=normalize_quats,
623
+ )
624
+
625
+ def compose_r(self, r):
626
+ """
627
+ Compose the rotation matrices of the current Rotation object with
628
+ those of another.
629
+
630
+ Args:
631
+ r:
632
+ An update rotation object
633
+ Returns:
634
+ An updated rotation object
635
+ """
636
+ r1 = self.get_rot_mats()
637
+ r2 = r.get_rot_mats()
638
+ new_rot_mats = rot_matmul(r1, r2)
639
+ return Rotation(rot_mats=new_rot_mats, quats=None)
640
+
641
+ def compose_q(self, r, normalize_quats: bool = True):
642
+ """
643
+ Compose the quaternions of the current Rotation object with those
644
+ of another.
645
+
646
+ Depending on whether either Rotation was initialized with
647
+ quaternions, this function may call torch.linalg.eigh.
648
+
649
+ Args:
650
+ r:
651
+ An update rotation object
652
+ Returns:
653
+ An updated rotation object
654
+ """
655
+ q1 = self.get_quats()
656
+ q2 = r.get_quats()
657
+ new_quats = quat_multiply(q1, q2)
658
+ return Rotation(rot_mats=None, quats=new_quats, normalize_quats=normalize_quats)
659
+
660
+ def apply(self, pts: torch.Tensor) -> torch.Tensor:
661
+ """
662
+ Apply the current Rotation as a rotation matrix to a set of 3D
663
+ coordinates.
664
+
665
+ Args:
666
+ pts:
667
+ A [*, 3] set of points
668
+ Returns:
669
+ [*, 3] rotated points
670
+ """
671
+ rot_mats = self.get_rot_mats()
672
+ return rot_vec_mul(rot_mats, pts)
673
+
674
+ def invert_apply(self, pts: torch.Tensor) -> torch.Tensor:
675
+ """
676
+ The inverse of the apply() method.
677
+
678
+ Args:
679
+ pts:
680
+ A [*, 3] set of points
681
+ Returns:
682
+ [*, 3] inverse-rotated points
683
+ """
684
+ rot_mats = self.get_rot_mats()
685
+ inv_rot_mats = invert_rot_mat(rot_mats)
686
+ return rot_vec_mul(inv_rot_mats, pts)
687
+
688
+ def invert(self):
689
+ """
690
+ Returns the inverse of the current Rotation.
691
+
692
+ Returns:
693
+ The inverse of the current Rotation
694
+ """
695
+ if self._rot_mats is not None:
696
+ return Rotation(rot_mats=invert_rot_mat(self._rot_mats), quats=None)
697
+ elif self._quats is not None:
698
+ return Rotation(
699
+ rot_mats=None,
700
+ quats=invert_quat(self._quats),
701
+ normalize_quats=False,
702
+ )
703
+ else:
704
+ raise ValueError("Both rotations are None")
705
+
706
+ # "Tensor" stuff
707
+
708
+ def unsqueeze(
709
+ self,
710
+ dim: int,
711
+ ):
712
+ """
713
+ Analogous to torch.unsqueeze. The dimension is relative to the
714
+ shape of the Rotation object.
715
+
716
+ Args:
717
+ dim: A positive or negative dimension index.
718
+ Returns:
719
+ The unsqueezed Rotation.
720
+ """
721
+ if dim >= len(self.shape):
722
+ raise ValueError("Invalid dimension")
723
+
724
+ if self._rot_mats is not None:
725
+ rot_mats = self._rot_mats.unsqueeze(dim if dim >= 0 else dim - 2)
726
+ return Rotation(rot_mats=rot_mats, quats=None)
727
+ elif self._quats is not None:
728
+ quats = self._quats.unsqueeze(dim if dim >= 0 else dim - 1)
729
+ return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
730
+ else:
731
+ raise ValueError("Both rotations are None")
732
+
733
+ @staticmethod
734
+ def cat(
735
+ rs,
736
+ dim: int,
737
+ ):
738
+ """
739
+ Concatenates rotations along one of the batch dimensions. Analogous
740
+ to torch.cat().
741
+
742
+ Note that the output of this operation is always a rotation matrix,
743
+ regardless of the format of input rotations.
744
+
745
+ Args:
746
+ rs:
747
+ A list of rotation objects
748
+ dim:
749
+ The dimension along which the rotations should be
750
+ concatenated
751
+ Returns:
752
+ A concatenated Rotation object in rotation matrix format
753
+ """
754
+ rot_mats = [r.get_rot_mats() for r in rs]
755
+ rot_mats = torch.cat(rot_mats, dim=dim if dim >= 0 else dim - 2)
756
+
757
+ return Rotation(rot_mats=rot_mats, quats=None)
758
+
759
+ def map_tensor_fn(self, fn):
760
+ """
761
+ Apply a Tensor -> Tensor function to underlying rotation tensors,
762
+ mapping over the rotation dimension(s). Can be used e.g. to sum out
763
+ a one-hot batch dimension.
764
+
765
+ Args:
766
+ fn:
767
+ A Tensor -> Tensor function to be mapped over the Rotation
768
+ Returns:
769
+ The transformed Rotation object
770
+ """
771
+ if self._rot_mats is not None:
772
+ rot_mats = self._rot_mats.view(self._rot_mats.shape[:-2] + (9,))
773
+ rot_mats = torch.stack(
774
+ list(map(fn, torch.unbind(rot_mats, dim=-1))), dim=-1
775
+ )
776
+ rot_mats = rot_mats.view(rot_mats.shape[:-1] + (3, 3))
777
+ return Rotation(rot_mats=rot_mats, quats=None)
778
+ elif self._quats is not None:
779
+ quats = torch.stack(
780
+ list(map(fn, torch.unbind(self._quats, dim=-1))), dim=-1
781
+ )
782
+ return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
783
+ else:
784
+ raise ValueError("Both rotations are None")
785
+
786
+ def cuda(self):
787
+ """
788
+ Analogous to the cuda() method of torch Tensors
789
+
790
+ Returns:
791
+ A copy of the Rotation in CUDA memory
792
+ """
793
+ if self._rot_mats is not None:
794
+ return Rotation(rot_mats=self._rot_mats.cuda(), quats=None)
795
+ elif self._quats is not None:
796
+ return Rotation(
797
+ rot_mats=None, quats=self._quats.cuda(), normalize_quats=False
798
+ )
799
+ else:
800
+ raise ValueError("Both rotations are None")
801
+
802
def to(self, device: Optional[torch.device], dtype: Optional[torch.dtype]):
    """
    Counterpart of torch.Tensor.to() for Rotation objects.

    Args:
        device:
            The target torch device
        dtype:
            The target torch dtype
    Returns:
        A copy of the Rotation, in the same storage format, built from
        the underlying tensor moved/cast with the given device and dtype
    """
    mats, quats = self._rot_mats, self._quats

    if mats is None and quats is None:
        raise ValueError("Both rotations are None")

    if mats is not None:
        moved = mats.to(device=device, dtype=dtype)
        return Rotation(rot_mats=moved, quats=None)

    moved = quats.to(device=device, dtype=dtype)
    return Rotation(rot_mats=None, quats=moved, normalize_quats=False)
827
+
828
def detach(self):
    """
    Builds a copy of the Rotation around a detached version of its
    underlying tensor, cutting it off from the torch autograd graph.

    Returns:
        A copy of the Rotation, in the same storage format, whose
        underlying Tensor has been detached from its torch graph
    """
    mats, quats = self._rot_mats, self._quats

    if mats is None and quats is None:
        raise ValueError("Both rotations are None")

    if mats is not None:
        return Rotation(rot_mats=mats.detach(), quats=None)

    return Rotation(
        rot_mats=None,
        quats=quats.detach(),
        normalize_quats=False,
    )
847
+
848
+
849
class Rigid:
    """
    A class representing a rigid transformation. Little more than a wrapper
    around two objects: a Rotation object and a [*, 3] translation tensor.
    Designed to behave approximately like a single torch tensor with the
    shape of the shared batch dimensions of its component parts.
    """
856
+
857
def __init__(
    self,
    rots: Optional[Rotation],
    trans: Optional[torch.Tensor],
):
    """
    Builds a rigid transformation from a rotation and a translation.
    Whichever of the two is omitted is replaced by an identity of
    matching batch shape, dtype, device and requires_grad setting.

    Args:
        rots: A Rotation object with batch shape [*], or None
        trans: A corresponding [*, 3] translation tensor, or None
    Raises:
        ValueError: If both arguments are None, or if the batch shapes
            or devices of the rotation and translation disagree
    """
    # (we need device, dtype, etc. from at least one input)
    if rots is None and trans is None:
        raise ValueError("At least one input argument must be specified")

    if rots is None:
        # Only the translation was given: mirror its properties.
        rots = Rotation.identity(
            trans.shape[:-1],
            trans.dtype,
            trans.device,
            trans.requires_grad,
        )
    elif trans is None:
        # Only the rotation was given: mirror its properties.
        trans = identity_trans(
            rots.shape,
            rots.dtype,
            rots.device,
            rots.requires_grad,
        )

    if not (rots.shape == trans.shape[:-1] and rots.device == trans.device):
        raise ValueError("Rots and trans incompatible")

    self._rots = rots
    # Force full precision. Happens to the rotations automatically.
    self._trans = trans.type(torch.float32)
906
+
907
@staticmethod
def identity(
    shape: Tuple[int],
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    requires_grad: bool = True,
    fmt: str = "quat",
):
    """
    Builds an identity transformation: identity rotations paired with
    identity translations.

    Args:
        shape:
            The desired batch shape
        dtype:
            The dtype of both internal tensors
        device:
            The device of both internal tensors
        requires_grad:
            Whether grad should be enabled for the internal tensors
        fmt:
            Forwarded to Rotation.identity as its ``fmt`` argument
    Returns:
        The identity transformation
    """
    identity_rots = Rotation.identity(shape, dtype, device, requires_grad, fmt=fmt)
    identity_translation = identity_trans(shape, dtype, device, requires_grad)
    return Rigid(identity_rots, identity_translation)
934
+
935
def __getitem__(
    self,
    index: Any,
):
    """
    Indexes the transformation with PyTorch-style indices. The index
    only addresses the batch dimensions shared by the rotation and the
    translation; the trailing coordinate dimension of the translation
    is always kept whole.

    E.g.::

        r = Rotation(rot_mats=torch.rand(10, 10, 3, 3), quats=None)
        t = Rigid(r, torch.rand(10, 10, 3))
        indexed = t[3, 4:6]
        assert indexed.shape == (2,)
        assert indexed.get_rots().shape == (2,)
        assert indexed.get_trans().shape == (2, 3)

    Args:
        index: A standard torch tensor index. E.g. 8, (10, None, 3),
            or (3, slice(0, 1, None))
    Returns:
        The indexed Rigid
    """
    batch_index = index if isinstance(index, tuple) else (index,)

    picked_rots = self._rots[batch_index]
    # Append a full slice so the xyz axis of the translation is untouched.
    picked_trans = self._trans[batch_index + (slice(None),)]

    return Rigid(picked_rots, picked_trans)
966
+
967
def __mul__(
    self,
    right: torch.Tensor,
):
    """
    Pointwise multiplication of the transformation with a tensor, e.g.
    to mask the Rigid. The rotation is multiplied by the tensor directly
    and the translation by the tensor with one trailing axis appended.

    Args:
        right:
            The tensor multiplicand
    Returns:
        The product
    Raises:
        TypeError: If ``right`` is not a torch.Tensor
    """
    if isinstance(right, torch.Tensor):
        scaled_rots = self._rots * right
        scaled_trans = self._trans * right[..., None]
        return Rigid(scaled_rots, scaled_trans)

    raise TypeError("The other multiplicand must be a Tensor")
988
+
989
def __rmul__(
    self,
    left: torch.Tensor,
):
    """
    Reflected pointwise multiplication (``tensor * rigid``). Delegates
    to ``__mul__``, so the result is the same as ``rigid * tensor``.

    Args:
        left:
            The left multiplicand
    Returns:
        The product
    """
    product = self.__mul__(left)
    return product
1004
+
1005
@property
def shape(self) -> torch.Size:
    """
    The shape of the batch dimensions shared by the rotation and the
    translation, i.e. the translation's shape without its last axis.

    Returns:
        The shape of the transformation
    """
    return self._trans.shape[:-1]
1016
+
1017
@property
def device(self) -> torch.device:
    """
    The device of the Rigid, read from its translation tensor.

    Returns:
        The device on which the Rigid's tensors are located
    """
    translation = self._trans
    return translation.device
1026
+
1027
def get_rots(self) -> Rotation:
    """
    Accessor for the rotation component.

    Returns:
        The stored Rotation object (not a copy)
    """
    rotation = self._rots
    return rotation
1035
+
1036
def get_trans(self) -> torch.Tensor:
    """
    Accessor for the translation component.

    Returns:
        The stored translation tensor (not a copy)
    """
    translation = self._trans
    return translation
1044
+
1045
def compose_q_update_vec(
    self,
    q_update_vec: torch.Tensor,
    update_mask: torch.Tensor = None,
):
    """
    Composes the transformation with an update vector of shape [*, 6].
    The first 3 columns are the x, y, and z values of a quaternion of
    form (1, x, y, z); the last 3 columns are a 3D translation, which is
    rotated by the current rotation before being added to the current
    translation.

    Args:
        q_update_vec: The [*, 6] update vector.
        update_mask: Optional tensor. It is forwarded to the rotation
            update and multiplied into the rotated translation update.
            Presumably a mask broadcastable to [*, 3] -- confirm against
            callers.
    Returns:
        The composed transformation.
    """
    quat_part = q_update_vec[..., :3]
    trans_part = q_update_vec[..., 3:]

    updated_rots = self._rots.compose_q_update_vec(
        quat_part, update_mask=update_mask
    )

    # The translation update is expressed in the local frame.
    delta = self._rots.apply(trans_part)
    if update_mask is not None:
        delta = delta * update_mask

    return Rigid(updated_rots, self._trans + delta)
1070
+
1071
def compose_tran_update_vec(
    self,
    t_vec: torch.Tensor,
    update_mask: torch.Tensor = None,
):
    """
    Composes the transformation with a translation-only update vector
    of shape [*, 3]. The update is rotated by the current rotation and
    added to the current translation; the rotation itself is unchanged.

    Args:
        t_vec: The [*, 3] translation update vector.
        update_mask: Optional tensor multiplied into the rotated
            translation update. Presumably a mask broadcastable to
            [*, 3] -- confirm against callers.
    Returns:
        The composed transformation.
    """
    delta = self._rots.apply(t_vec)
    if update_mask is not None:
        delta = delta * update_mask

    return Rigid(self._rots, self._trans + delta)
1091
+
1092
def compose(
    self,
    r,
):
    """
    Composes the current rigid object with another one. The resulting
    rotation is this rotation composed with ``r``'s rotation, and the
    resulting translation is ``r``'s translation rotated by this
    rotation plus this translation.

    Args:
        r:
            Another Rigid object
    Returns:
        The composition of the two transformations
    """
    other_rots, other_trans = r._rots, r._trans

    composed_rots = self._rots.compose_r(other_rots)
    composed_trans = self._rots.apply(other_trans) + self._trans

    return Rigid(composed_rots, composed_trans)
1108
+
1109
def compose_r(self, rot, order="right"):
    """
    Composes only the rotation of the current rigid object with another
    rotation. The translation is carried over unchanged.

    Args:
        rot:
            A Rotation object
        order:
            Order in which to perform rotation multiplication:
            "right" computes ``self_rots.compose_r(rot)``, "left"
            computes ``rot.compose_r(self_rots)``.
    Returns:
        A Rigid with the composed rotation and the original translation
    Raises:
        ValueError: If ``order`` is neither "right" nor "left"
    """
    if order == "right":
        composed = self._rots.compose_r(rot)
    else:
        if order != "left":
            raise ValueError(f"Unrecognized multiplication order: {order}")
        composed = rot.compose_r(self._rots)

    return Rigid(composed, self._trans)
1128
+
1129
def apply(
    self,
    pts: torch.Tensor,
) -> torch.Tensor:
    """
    Applies the transformation to a coordinate tensor: the points are
    rotated first and then translated.

    Args:
        pts: A [*, 3] coordinate tensor.
    Returns:
        The transformed points.
    """
    return self._rots.apply(pts) + self._trans
1143
+
1144
def invert_apply(self, pts: torch.Tensor) -> torch.Tensor:
    """
    Applies the inverse of the transformation to a coordinate tensor:
    the translation is subtracted first and the inverse rotation is
    applied afterwards.

    Args:
        pts: A [*, 3] coordinate tensor
    Returns:
        The transformed points.
    """
    shifted = pts - self._trans
    return self._rots.invert_apply(shifted)
1155
+
1156
def invert(self):
    """
    Inverts the transformation. The inverse rotation is the inverted
    rotation; the inverse translation is the negated original
    translation rotated by that inverse rotation.

    Returns:
        The inverse transformation.
    """
    inverse_rots = self._rots.invert()
    rotated_trans = inverse_rots.apply(self._trans)

    return Rigid(inverse_rots, -1 * rotated_trans)
1167
+
1168
def map_tensor_fn(self, fn):
    """
    Apply a Tensor -> Tensor function to the underlying translation and
    rotation tensors. The rotation is mapped through its own
    map_tensor_fn; the translation is split along its last axis, the
    function is applied to each coordinate, and the results are
    re-stacked along the last axis.

    Args:
        fn:
            A Tensor -> Tensor function to be mapped over the Rigid
    Returns:
        The transformed Rigid object
    """
    mapped_rots = self._rots.map_tensor_fn(fn)

    mapped_coords = [fn(coord) for coord in torch.unbind(self._trans, dim=-1)]
    mapped_trans = torch.stack(mapped_coords, dim=-1)

    return Rigid(mapped_rots, mapped_trans)
1186
+
1187
def to_tensor_4x4(self) -> torch.Tensor:
    """
    Converts a transformation to a homogeneous transformation tensor:
    the rotation matrix fills the upper-left 3x3 block, the translation
    fills the first three rows of the last column, and the bottom-right
    entry is 1.

    Returns:
        A [*, 4, 4] homogeneous transformation tensor
    """
    homogeneous = self._trans.new_zeros((*self.shape, 4, 4))

    # The three writes target disjoint regions of the matrix.
    homogeneous[..., 3, 3] = 1
    homogeneous[..., :3, 3] = self._trans
    homogeneous[..., :3, :3] = self._rots.get_rot_mats()

    return homogeneous
1199
+
1200
@staticmethod
def from_tensor_4x4(t: torch.Tensor):
    """
    Builds a transformation from a homogeneous transformation tensor,
    reading the rotation from the upper-left 3x3 block and the
    translation from the first three rows of the last column.

    Args:
        t: [*, 4, 4] homogeneous transformation tensor
    Returns:
        Rigid object with shape [*]
    Raises:
        ValueError: If the last two dimensions of ``t`` are not (4, 4)
    """
    if t.shape[-2:] != (4, 4):
        raise ValueError("Incorrectly shaped input tensor")

    rotation_block = t[..., :3, :3]
    translation_column = t[..., :3, 3]

    return Rigid(
        Rotation(rot_mats=rotation_block, quats=None),
        translation_column,
    )
1218
+
1219
def to_tensor_7(self) -> torch.Tensor:
    """
    Converts a transformation to a tensor with 7 final columns, four
    for the quaternion followed by three for the translation.

    Returns:
        A [*, 7] tensor representation of the transformation
    """
    packed = self._trans.new_zeros((*self.shape, 7))

    # The two writes target disjoint column ranges.
    packed[..., 4:] = self._trans
    packed[..., :4] = self._rots.get_quats()

    return packed
1232
+
1233
@staticmethod
def from_tensor_7(
    t: torch.Tensor,
    normalize_quats: bool = False,
):
    """
    Builds a transformation from a tensor with 7 final columns, four
    for the quaternion followed by three for the translation.

    Args:
        t: [*, 7] tensor representation of the transformation
        normalize_quats: Forwarded to the Rotation constructor
    Returns:
        Rigid object with shape [*]
    Raises:
        ValueError: If the last dimension of ``t`` is not 7
    """
    if t.shape[-1] != 7:
        raise ValueError("Incorrectly shaped input tensor")

    quat_columns = t[..., :4]
    trans_columns = t[..., 4:]

    rotation = Rotation(
        rot_mats=None, quats=quat_columns, normalize_quats=normalize_quats
    )
    return Rigid(rotation, trans_columns)
1246
+
1247
@staticmethod
def from_3_points(
    p_neg_x_axis: torch.Tensor,
    origin: torch.Tensor,
    p_xy_plane: torch.Tensor,
    eps: float = 1e-4,
):
    """
    Implements algorithm 21. Constructs transformations from sets of 3
    points using the Gram-Schmidt algorithm.

    The first basis vector points from ``p_neg_x_axis`` to ``origin``,
    the second is the part of ``p_xy_plane - origin`` orthogonal to the
    first, and the third is their cross product. The three vectors form
    the columns of the rotation matrix.

    Args:
        p_neg_x_axis: [*, 3] coordinates
        origin: [*, 3] coordinates used as frame origins
        p_xy_plane: [*, 3] coordinates
        eps: Small epsilon value added under the square root of each
            normalization
    Returns:
        A transformation object of shape [*]
    """
    # All arithmetic is done per coordinate on lists of [*] tensors.
    neg_x = torch.unbind(p_neg_x_axis, dim=-1)
    orig = torch.unbind(origin, dim=-1)
    xy = torch.unbind(p_xy_plane, dim=-1)

    axis_x = [o - p for o, p in zip(orig, neg_x)]
    axis_y = [p - o for p, o in zip(xy, orig)]

    # Normalize the first axis.
    length_x = torch.sqrt(sum((c * c for c in axis_x)) + eps)
    axis_x = [c / length_x for c in axis_x]

    # Remove the component of the second axis along the first, then
    # normalize it.
    projection = sum((a * b for a, b in zip(axis_x, axis_y)))
    axis_y = [b - a * projection for a, b in zip(axis_x, axis_y)]
    length_y = torch.sqrt(sum((c * c for c in axis_y)) + eps)
    axis_y = [c / length_y for c in axis_y]

    # Third axis: cross product of the first two.
    axis_z = [
        axis_x[1] * axis_y[2] - axis_x[2] * axis_y[1],
        axis_x[2] * axis_y[0] - axis_x[0] * axis_y[2],
        axis_x[0] * axis_y[1] - axis_x[1] * axis_y[0],
    ]

    # Row-major flattening with the axes as matrix columns.
    entries = []
    for row in zip(axis_x, axis_y, axis_z):
        entries.extend(row)
    matrices = torch.stack(entries, dim=-1)
    matrices = matrices.reshape(matrices.shape[:-1] + (3, 3))

    rotation = Rotation(rot_mats=matrices, quats=None)

    return Rigid(rotation, torch.stack(orig, dim=-1))
1291
+
1292
def unsqueeze(
    self,
    dim: int,
):
    """
    Analogous to torch.unsqueeze. The dimension is relative to the
    shared batch dimensions of the rotation/translation, so a negative
    index is shifted by one for the translation to skip its trailing
    coordinate axis.

    Args:
        dim: A positive or negative dimension index.
    Returns:
        The unsqueezed transformation.
    Raises:
        ValueError: If ``dim`` is not smaller than the number of batch
            dimensions
    """
    if dim >= len(self.shape):
        raise ValueError("Invalid dimension")

    trans_dim = dim - 1 if dim < 0 else dim

    expanded_rots = self._rots.unsqueeze(dim)
    expanded_trans = self._trans.unsqueeze(trans_dim)

    return Rigid(expanded_rots, expanded_trans)
1311
+
1312
@staticmethod
def cat(
    ts,
    dim: int,
):
    """
    Concatenates transformations along one of their batch dimensions.
    A negative ``dim`` is shifted by one for the translations to skip
    their trailing coordinate axis.

    Args:
        ts:
            A list of Rigid objects
        dim:
            The dimension along which the transformations should be
            concatenated
    Returns:
        A concatenated transformation object
    """
    trans_dim = dim - 1 if dim < 0 else dim

    joined_rots = Rotation.cat([t._rots for t in ts], dim)
    joined_trans = torch.cat([t._trans for t in ts], dim=trans_dim)

    return Rigid(joined_rots, joined_trans)
1333
+
1334
def apply_rot_fn(self, fn):
    """
    Runs a Rotation -> Rotation function on the stored rotation object
    and pairs the result with the unchanged translation.

    Args:
        fn: A function of type Rotation -> Rotation
    Returns:
        A transformation object with a transformed rotation.
    """
    transformed_rots = fn(self._rots)
    return Rigid(transformed_rots, self._trans)
1345
+
1346
def apply_trans_fn(self, fn):
    """
    Runs a Tensor -> Tensor function on the stored translation and
    pairs the result with the unchanged rotation.

    Args:
        fn:
            A function of type Tensor -> Tensor to be applied to the
            translation
    Returns:
        A transformation object with a transformed translation.
    """
    transformed_trans = fn(self._trans)
    return Rigid(self._rots, transformed_trans)
1358
+
1359
def scale_translation(self, trans_scale_factor: float):
    """
    Multiplies the translation by a constant factor, leaving the
    rotation untouched.

    Args:
        trans_scale_factor:
            The constant factor
    Returns:
        A transformation object with a scaled translation.
    """

    def _scale(translation):
        return translation * trans_scale_factor

    return self.apply_trans_fn(_scale)
1371
+
1372
def stop_rot_gradient(self):
    """
    Detaches the underlying rotation object, leaving the translation
    untouched.

    Returns:
        A transformation object with detached rotations
    """

    def _detach(rotation):
        return rotation.detach()

    return self.apply_rot_fn(_detach)
1381
+
1382
@staticmethod
def make_transform_from_reference(n_xyz, ca_xyz, c_xyz, eps=1e-20):
    """
    Returns a transformation object from reference coordinates.

    The frame is built in three steps after moving CA to the origin:
    a rotation about z that brings C into the xz-plane, a rotation
    about y that brings C onto the x-axis, and a rotation about x that
    brings N into the xy-plane. The product of the three is transposed
    to obtain the rotation of the returned transformation.

    Note that this method does not take care of symmetries. If you
    provide the atom positions in the non-standard way, the N atom will
    end up not at [-0.527250, 1.359329, 0.0] but instead at
    [-0.527250, -1.359329, 0.0]. You need to take care of such cases in
    your code.

    Args:
        n_xyz: A [*, 3] tensor of nitrogen xyz coordinates.
        ca_xyz: A [*, 3] tensor of carbon alpha xyz coordinates.
        c_xyz: A [*, 3] tensor of carbon xyz coordinates.
        eps: Small value added under each square root to avoid
            division by zero.
    Returns:
        A transformation object. After applying the translation and
        rotation to the reference backbone, the coordinates will
        approximately equal to the input coordinates.
    """
    # Shift everything so that CA sits at the origin.
    translation = -1 * ca_xyz
    n_xyz = n_xyz + translation
    c_xyz = c_xyz + translation

    # Rotation about z that zeroes the y-coordinate of C.
    c_x, c_y, c_z = [c_xyz[..., i] for i in range(3)]
    norm = torch.sqrt(eps + c_x**2 + c_y**2)
    sin_c1 = -c_y / norm
    cos_c1 = c_x / norm

    c1_rots = sin_c1.new_zeros((*sin_c1.shape, 3, 3))
    c1_rots[..., 0, 0] = cos_c1
    c1_rots[..., 0, 1] = -1 * sin_c1
    c1_rots[..., 1, 0] = sin_c1
    c1_rots[..., 1, 1] = cos_c1
    c1_rots[..., 2, 2] = 1

    # Rotation about y that zeroes the z-coordinate of C.
    norm = torch.sqrt(eps + c_x**2 + c_y**2 + c_z**2)
    sin_c2 = c_z / norm
    cos_c2 = torch.sqrt(c_x**2 + c_y**2) / norm

    c2_rots = sin_c2.new_zeros((*sin_c2.shape, 3, 3))
    c2_rots[..., 0, 0] = cos_c2
    c2_rots[..., 0, 2] = sin_c2
    c2_rots[..., 1, 1] = 1
    # The third row belongs to c2_rots. Writing it into c1_rots would
    # leave c2_rots singular and corrupt the first rotation.
    c2_rots[..., 2, 0] = -1 * sin_c2
    c2_rots[..., 2, 2] = cos_c2

    c_rots = rot_matmul(c2_rots, c1_rots)
    n_xyz = rot_vec_mul(c_rots, n_xyz)

    # Rotation about x that zeroes the z-coordinate of N.
    _, n_y, n_z = [n_xyz[..., i] for i in range(3)]
    norm = torch.sqrt(eps + n_y**2 + n_z**2)
    sin_n = -n_z / norm
    cos_n = n_y / norm

    n_rots = sin_c2.new_zeros((*sin_c2.shape, 3, 3))
    n_rots[..., 0, 0] = 1
    n_rots[..., 1, 1] = cos_n
    n_rots[..., 1, 2] = -1 * sin_n
    n_rots[..., 2, 1] = sin_n
    n_rots[..., 2, 2] = cos_n

    rots = rot_matmul(n_rots, c_rots)

    # The product maps the input onto the reference frame; transpose it
    # and undo the shift to get the reference-to-input transformation.
    rots = rots.transpose(-1, -2)
    translation = -1 * translation

    rot_obj = Rotation(rot_mats=rots, quats=None)

    return Rigid(rot_obj, translation)
1452
+
1453
def cuda(self):
    """
    Builds a copy of the transformation whose rotation and translation
    have both been moved to GPU memory.

    Returns:
        A version of the transformation on GPU
    """
    gpu_rots = self._rots.cuda()
    gpu_trans = self._trans.cuda()
    return Rigid(gpu_rots, gpu_trans)