rc-foundry 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
foundry/utils/rigid.py
ADDED
|
@@ -0,0 +1,1460 @@
|
|
|
1
|
+
# Copyright 2021 AlQuraishi Laboratory
|
|
2
|
+
# Copyright 2021 DeepMind Technologies Limited
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from typing import Any, Optional, Tuple
|
|
17
|
+
|
|
18
|
+
import numpy as np
|
|
19
|
+
import torch
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def rot_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """
    Multiply two batches of 3x3 rotation matrices.

    The product is spelled out entry-by-entry (rather than using ``@``) so
    that AMP cannot downcast the intermediate products.

    Args:
        a: [*, 3, 3] left multiplicand
        b: [*, 3, 3] right multiplicand
    Returns:
        The product ab, shape [*, 3, 3]
    """

    def _row(i: int) -> torch.Tensor:
        # Entry (i, j) of the product: sum_k a[..., i, k] * b[..., k, j],
        # accumulated in the same left-to-right order as the original code.
        return torch.stack(
            [
                a[..., i, 0] * b[..., 0, j]
                + a[..., i, 1] * b[..., 1, j]
                + a[..., i, 2] * b[..., 2, j]
                for j in range(3)
            ],
            dim=-1,
        )

    return torch.stack([_row(0), _row(1), _row(2)], dim=-2)
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def rot_vec_mul(r: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """
    Apply a rotation matrix to a vector.

    The matrix-vector product is written out by hand so that AMP cannot
    downcast the intermediate products.

    Args:
        r: [*, 3, 3] rotation matrices
        t: [*, 3] coordinate tensors
    Returns:
        [*, 3] rotated coordinates
    """
    tx = t[..., 0]
    ty = t[..., 1]
    tz = t[..., 2]
    components = [
        r[..., i, 0] * tx + r[..., i, 1] * ty + r[..., i, 2] * tz
        for i in range(3)
    ]
    return torch.stack(components, dim=-1)
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def identity_rot_mats(
    batch_dims: Tuple[int],
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    requires_grad: bool = True,
) -> torch.Tensor:
    """
    Build identity rotation matrices broadcast over the given batch dims.

    Args:
        batch_dims: leading batch dimensions of the result
        dtype: torch dtype of the result
        device: torch device of the result
        requires_grad: whether the result requires gradients
    Returns:
        [*batch_dims, 3, 3] identity matrices (expanded views of one eye)
    """
    eye = torch.eye(3, dtype=dtype, device=device, requires_grad=requires_grad)
    lead = (1,) * len(batch_dims)
    # expand() produces a broadcast view rather than copying the identity
    return eye.view(*lead, 3, 3).expand(*batch_dims, -1, -1)
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def identity_trans(
    batch_dims: Tuple[int],
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    requires_grad: bool = True,
) -> torch.Tensor:
    """
    Build zero (identity) translation vectors for the given batch dims.

    Args:
        batch_dims: leading batch dimensions of the result
        dtype: torch dtype of the result
        device: torch device of the result
        requires_grad: whether the result requires gradients
    Returns:
        [*batch_dims, 3] zero translations
    """
    return torch.zeros(
        (*batch_dims, 3), dtype=dtype, device=device, requires_grad=requires_grad
    )
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def identity_quats(
    batch_dims: Tuple[int],
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    requires_grad: bool = True,
) -> torch.Tensor:
    """
    Build identity quaternions (1, 0, 0, 0) for the given batch dims.

    Args:
        batch_dims: leading batch dimensions of the result
        dtype: torch dtype of the result
        device: torch device of the result
        requires_grad: whether the result requires gradients
    Returns:
        [*batch_dims, 4] identity quaternions
    """
    quats = torch.zeros(
        (*batch_dims, 4), dtype=dtype, device=device, requires_grad=requires_grad
    )
    # In-place write of the scalar component must not be tracked by autograd
    with torch.no_grad():
        quats[..., 0] = 1
    return quats
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
# Quaternion component labels; the 16 ordered pairs ("aa", "ab", ...) index
# the entries of the 4x4 outer product of a quaternion with itself.
_quat_elements = ["a", "b", "c", "d"]
_qtr_keys = [l1 + l2 for l1 in _quat_elements for l2 in _quat_elements]
# Maps a pair key like "bc" to its flat index in the 4x4 outer-product matrix.
_qtr_ind_dict = {key: ind for ind, key in enumerate(_qtr_keys)}
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
def _to_mat(pairs):
    """
    Build a 4x4 coefficient matrix from (key, value) pairs.

    Each key (e.g. "bc") names one entry of the quaternion outer product via
    ``_qtr_ind_dict``; the corresponding value is written into that entry.
    """
    mat = np.zeros((4, 4))
    for key, value in pairs:
        flat = _qtr_ind_dict[key]
        mat[flat // 4][flat % 4] = value
    return mat
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
# _QTR_MAT[i, j, r, c] is the coefficient of the product q_i * q_j in entry
# (r, c) of the rotation matrix; quat_to_rot contracts a quaternion's outer
# product against this table to produce the rotation matrix.
_QTR_MAT = np.zeros((4, 4, 3, 3))
_QTR_MAT[..., 0, 0] = _to_mat([("aa", 1), ("bb", 1), ("cc", -1), ("dd", -1)])
_QTR_MAT[..., 0, 1] = _to_mat([("bc", 2), ("ad", -2)])
_QTR_MAT[..., 0, 2] = _to_mat([("bd", 2), ("ac", 2)])
_QTR_MAT[..., 1, 0] = _to_mat([("bc", 2), ("ad", 2)])
_QTR_MAT[..., 1, 1] = _to_mat([("aa", 1), ("bb", -1), ("cc", 1), ("dd", -1)])
_QTR_MAT[..., 1, 2] = _to_mat([("cd", 2), ("ab", -2)])
_QTR_MAT[..., 2, 0] = _to_mat([("bd", 2), ("ac", -2)])
_QTR_MAT[..., 2, 1] = _to_mat([("cd", 2), ("ab", 2)])
_QTR_MAT[..., 2, 2] = _to_mat([("aa", 1), ("bb", -1), ("cc", -1), ("dd", 1)])
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
def quat_to_rot(quat: torch.Tensor) -> torch.Tensor:
    """
    Converts a quaternion to a rotation matrix.

    Args:
        quat: [*, 4] quaternions
    Returns:
        [*, 3, 3] rotation matrices
    """
    # Outer product q q^T: [*, 4, 4]
    outer = quat[..., None] * quat[..., None, :]

    # Coefficient table on the input's device/dtype: [4, 4, 3, 3]
    table = quat.new_tensor(_QTR_MAT, requires_grad=False)

    # Broadcast the table over the batch dims: [*, 4, 4, 3, 3]
    table = table.view((1,) * len(outer.shape[:-2]) + table.shape)
    weighted = outer[..., None, None] * table

    # Contract both quaternion dimensions: [*, 3, 3]
    return torch.sum(weighted, dim=(-3, -4))
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
def rot_to_quat(
    rot: torch.Tensor,
):
    """
    Converts rotation matrices to quaternions.

    Builds the symmetric 4x4 matrix K from the rotation's entries and returns
    the eigenvector belonging to its largest eigenvalue (torch.linalg.eigh
    sorts eigenvalues in ascending order, hence index -1). Note that eigh
    does not fix the eigenvector's sign, so the returned quaternion's sign
    is arbitrary.

    Args:
        rot: [*, 3, 3] rotation matrices
    Returns:
        [*, 4] quaternions
    Raises:
        ValueError: if the trailing dimensions are not (3, 3)
    """
    if rot.shape[-2:] != (3, 3):
        raise ValueError("Input rotation is incorrectly shaped")

    xx, xy, xz = rot[..., 0, 0], rot[..., 0, 1], rot[..., 0, 2]
    yx, yy, yz = rot[..., 1, 0], rot[..., 1, 1], rot[..., 1, 2]
    zx, zy, zz = rot[..., 2, 0], rot[..., 2, 1], rot[..., 2, 2]

    rows = [
        [xx + yy + zz, zy - yz, xz - zx, yx - xy],
        [zy - yz, xx - yy - zz, xy + yx, xz + zx],
        [xz - zx, xy + yx, yy - xx - zz, yz + zy],
        [yx - xy, xz + zx, yz + zy, zz - xx - yy],
    ]
    k = (1.0 / 3.0) * torch.stack(
        [torch.stack(row, dim=-1) for row in rows], dim=-2
    )

    _, vectors = torch.linalg.eigh(k)
    return vectors[..., -1]
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
# _QUAT_MULTIPLY[i, j, k] is the coefficient of q1_i * q2_j in component k
# of the quaternion product q1 * q2 (see quat_multiply).
_QUAT_MULTIPLY = np.zeros((4, 4, 4))
_QUAT_MULTIPLY[:, :, 0] = [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, -1]]

_QUAT_MULTIPLY[:, :, 1] = [[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1], [0, 0, -1, 0]]

_QUAT_MULTIPLY[:, :, 2] = [[0, 0, 1, 0], [0, 0, 0, -1], [1, 0, 0, 0], [0, 1, 0, 0]]

_QUAT_MULTIPLY[:, :, 3] = [[0, 0, 0, 1], [0, 0, 1, 0], [0, -1, 0, 0], [1, 0, 0, 0]]

# Drops the scalar component of the second factor; used when the second
# factor is a pure-vector quaternion (see quat_multiply_by_vec).
_QUAT_MULTIPLY_BY_VEC = _QUAT_MULTIPLY[:, 1:, :]
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
def quat_multiply(quat1, quat2):
    """Multiply a quaternion by another quaternion.

    Contracts the [4, 4, 4] coefficient table _QUAT_MULTIPLY against both
    inputs; the surviving last dimension is the output component axis.
    """
    table = quat1.new_tensor(_QUAT_MULTIPLY)
    table = table.view((1,) * len(quat1.shape[:-1]) + table.shape)
    weighted = table * quat1[..., :, None, None] * quat2[..., None, :, None]
    return torch.sum(weighted, dim=(-3, -2))
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
def quat_multiply_by_vec(quat, vec):
    """Multiply a quaternion by a pure-vector quaternion.

    Same contraction as quat_multiply, but with the [4, 3, 4] table
    _QUAT_MULTIPLY_BY_VEC, since the second factor has no scalar component.
    """
    table = quat.new_tensor(_QUAT_MULTIPLY_BY_VEC)
    table = table.view((1,) * len(quat.shape[:-1]) + table.shape)
    weighted = table * quat[..., :, None, None] * vec[..., None, :, None]
    return torch.sum(weighted, dim=(-3, -2))
|
|
265
|
+
|
|
266
|
+
|
|
267
|
+
def invert_rot_mat(rot_mat: torch.Tensor):
    """Invert a rotation matrix by transposing its last two dimensions."""
    return torch.transpose(rot_mat, -1, -2)
|
|
269
|
+
|
|
270
|
+
|
|
271
|
+
def invert_quat(quat: torch.Tensor):
    """Invert quaternions: the conjugate divided by the squared norm.

    Works for non-unit quaternions; for unit quaternions this reduces to
    the conjugate.
    """
    conj = quat.clone()
    conj[..., 1:] = conj[..., 1:] * -1
    return conj / torch.sum(quat**2, dim=-1, keepdim=True)
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
class Rotation:
|
|
279
|
+
"""
|
|
280
|
+
A 3D rotation. Depending on how the object is initialized, the
|
|
281
|
+
rotation is represented by either a rotation matrix or a
|
|
282
|
+
quaternion, though both formats are made available by helper functions.
|
|
283
|
+
To simplify gradient computation, the underlying format of the
|
|
284
|
+
rotation cannot be changed in-place. Like Rigid, the class is designed
|
|
285
|
+
to mimic the behavior of a torch Tensor, almost as if each Rotation
|
|
286
|
+
object were a tensor of rotations, in one format or another.
|
|
287
|
+
"""
|
|
288
|
+
|
|
289
|
+
    def __init__(
        self,
        rot_mats: Optional[torch.Tensor] = None,
        quats: Optional[torch.Tensor] = None,
        normalize_quats: bool = True,
    ):
        """
        Initialize from exactly one of a rotation-matrix or quaternion tensor.

        Args:
            rot_mats:
                A [*, 3, 3] rotation matrix tensor. Mutually exclusive with
                quats
            quats:
                A [*, 4] quaternion. Mutually exclusive with rot_mats. If
                normalize_quats is not True, must be a unit quaternion
            normalize_quats:
                If quats is specified, whether to normalize quats
        Raises:
            ValueError: if neither or both inputs are given, or either is
                incorrectly shaped
        """
        # Exactly one representation must be supplied
        if (rot_mats is None and quats is None) or (
            rot_mats is not None and quats is not None
        ):
            raise ValueError("Exactly one input argument must be specified")

        if (rot_mats is not None and rot_mats.shape[-2:] != (3, 3)) or (
            quats is not None and quats.shape[-1] != 4
        ):
            raise ValueError("Incorrectly shaped rotation matrix or quaternion")

        # Force full-precision: rotations are kept in float32 regardless of
        # the caller's dtype (e.g. under AMP)
        if quats is not None:
            quats = quats.type(torch.float32)
        if rot_mats is not None:
            rot_mats = rot_mats.type(torch.float32)

        # Normalize after the float32 cast so the unit constraint holds in
        # the stored precision
        if quats is not None and normalize_quats:
            quats = quats / torch.linalg.norm(quats, dim=-1, keepdim=True)

        self._rot_mats = rot_mats
        self._quats = quats
|
|
327
|
+
|
|
328
|
+
@staticmethod
|
|
329
|
+
def identity(
|
|
330
|
+
shape,
|
|
331
|
+
dtype: Optional[torch.dtype] = None,
|
|
332
|
+
device: Optional[torch.device] = None,
|
|
333
|
+
requires_grad: bool = True,
|
|
334
|
+
fmt: str = "quat",
|
|
335
|
+
):
|
|
336
|
+
"""
|
|
337
|
+
Returns an identity Rotation.
|
|
338
|
+
|
|
339
|
+
Args:
|
|
340
|
+
shape:
|
|
341
|
+
The "shape" of the resulting Rotation object. See documentation
|
|
342
|
+
for the shape property
|
|
343
|
+
dtype:
|
|
344
|
+
The torch dtype for the rotation
|
|
345
|
+
device:
|
|
346
|
+
The torch device for the new rotation
|
|
347
|
+
requires_grad:
|
|
348
|
+
Whether the underlying tensors in the new rotation object
|
|
349
|
+
should require gradient computation
|
|
350
|
+
fmt:
|
|
351
|
+
One of "quat" or "rot_mat". Determines the underlying format
|
|
352
|
+
of the new object's rotation
|
|
353
|
+
Returns:
|
|
354
|
+
A new identity rotation
|
|
355
|
+
"""
|
|
356
|
+
if fmt == "rot_mat":
|
|
357
|
+
rot_mats = identity_rot_mats(
|
|
358
|
+
shape,
|
|
359
|
+
dtype,
|
|
360
|
+
device,
|
|
361
|
+
requires_grad,
|
|
362
|
+
)
|
|
363
|
+
return Rotation(rot_mats=rot_mats, quats=None)
|
|
364
|
+
elif fmt == "quat":
|
|
365
|
+
quats = identity_quats(shape, dtype, device, requires_grad)
|
|
366
|
+
return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
|
|
367
|
+
else:
|
|
368
|
+
raise ValueError(f"Invalid format: f{fmt}")
|
|
369
|
+
|
|
370
|
+
# Magic methods
|
|
371
|
+
|
|
372
|
+
    def __getitem__(self, index: Any):
        """
        Allows torch-style indexing over the virtual shape of the rotation
        object. See documentation for the shape property.

        Args:
            index:
                A torch index. E.g. (1, 3, 2), or (slice(None,))
        Returns:
            The indexed rotation
        Raises:
            ValueError: if neither representation is stored
        """
        if not isinstance(index, tuple):
            index = (index,)

        if self._rot_mats is not None:
            # Append full slices so the trailing 3x3 matrix dims are untouched
            rot_mats = self._rot_mats[index + (slice(None), slice(None))]
            return Rotation(rot_mats=rot_mats)
        elif self._quats is not None:
            # Append a full slice so the trailing quaternion dim is untouched;
            # slicing does not change quaternion values, so skip normalization
            quats = self._quats[index + (slice(None),)]
            return Rotation(quats=quats, normalize_quats=False)
        else:
            raise ValueError("Both rotations are None")
|
|
394
|
+
|
|
395
|
+
    def __setitem__(self, index: Any, new: Any):
        """
        In-place assignment of a Rotation into a slice of this one.

        The assigned rotation is converted to this object's underlying
        format (matrix or quaternion) before being written.

        Args:
            index:
                A torch index over the virtual shape
            new:
                A Rotation providing the values to write
        Raises:
            ValueError: if neither representation is stored
        """
        if not isinstance(index, tuple):
            index = (index,)

        if self._rot_mats is not None:
            self._rot_mats[index + (slice(None), slice(None))] = new.get_rot_mats()
        elif self._quats is not None:
            self._quats[index + (slice(None),)] = new.get_quats()
        else:
            raise ValueError("Both rotations are None")
|
|
405
|
+
|
|
406
|
+
def __mul__(
|
|
407
|
+
self,
|
|
408
|
+
right: torch.Tensor,
|
|
409
|
+
):
|
|
410
|
+
"""
|
|
411
|
+
Pointwise left multiplication of the rotation with a tensor. Can be
|
|
412
|
+
used to e.g. mask the Rotation.
|
|
413
|
+
|
|
414
|
+
Args:
|
|
415
|
+
right:
|
|
416
|
+
The tensor multiplicand
|
|
417
|
+
Returns:
|
|
418
|
+
The product
|
|
419
|
+
"""
|
|
420
|
+
if not (isinstance(right, torch.Tensor)):
|
|
421
|
+
raise TypeError("The other multiplicand must be a Tensor")
|
|
422
|
+
|
|
423
|
+
if self._rot_mats is not None:
|
|
424
|
+
rot_mats = self._rot_mats * right[..., None, None]
|
|
425
|
+
return Rotation(rot_mats=rot_mats, quats=None)
|
|
426
|
+
elif self._quats is not None:
|
|
427
|
+
quats = self._quats * right[..., None]
|
|
428
|
+
return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
|
|
429
|
+
else:
|
|
430
|
+
raise ValueError("Both rotations are None")
|
|
431
|
+
|
|
432
|
+
def __rmul__(
|
|
433
|
+
self,
|
|
434
|
+
left: torch.Tensor,
|
|
435
|
+
):
|
|
436
|
+
"""
|
|
437
|
+
Reverse pointwise multiplication of the rotation with a tensor.
|
|
438
|
+
|
|
439
|
+
Args:
|
|
440
|
+
left:
|
|
441
|
+
The left multiplicand
|
|
442
|
+
Returns:
|
|
443
|
+
The product
|
|
444
|
+
"""
|
|
445
|
+
return self.__mul__(left)
|
|
446
|
+
|
|
447
|
+
# Properties
|
|
448
|
+
|
|
449
|
+
@property
|
|
450
|
+
def shape(self) -> torch.Size:
|
|
451
|
+
"""
|
|
452
|
+
Returns the virtual shape of the rotation object. This shape is
|
|
453
|
+
defined as the batch dimensions of the underlying rotation matrix
|
|
454
|
+
or quaternion. If the Rotation was initialized with a [10, 3, 3]
|
|
455
|
+
rotation matrix tensor, for example, the resulting shape would be
|
|
456
|
+
[10].
|
|
457
|
+
|
|
458
|
+
Returns:
|
|
459
|
+
The virtual shape of the rotation object
|
|
460
|
+
"""
|
|
461
|
+
s = None
|
|
462
|
+
if self._quats is not None:
|
|
463
|
+
s = self._quats.shape[:-1]
|
|
464
|
+
else:
|
|
465
|
+
s = self._rot_mats.shape[:-2]
|
|
466
|
+
|
|
467
|
+
return s
|
|
468
|
+
|
|
469
|
+
@property
|
|
470
|
+
def dtype(self) -> torch.dtype:
|
|
471
|
+
"""
|
|
472
|
+
Returns the dtype of the underlying rotation.
|
|
473
|
+
|
|
474
|
+
Returns:
|
|
475
|
+
The dtype of the underlying rotation
|
|
476
|
+
"""
|
|
477
|
+
if self._rot_mats is not None:
|
|
478
|
+
return self._rot_mats.dtype
|
|
479
|
+
elif self._quats is not None:
|
|
480
|
+
return self._quats.dtype
|
|
481
|
+
else:
|
|
482
|
+
raise ValueError("Both rotations are None")
|
|
483
|
+
|
|
484
|
+
@property
|
|
485
|
+
def device(self) -> torch.device:
|
|
486
|
+
"""
|
|
487
|
+
The device of the underlying rotation
|
|
488
|
+
|
|
489
|
+
Returns:
|
|
490
|
+
The device of the underlying rotation
|
|
491
|
+
"""
|
|
492
|
+
if self._rot_mats is not None:
|
|
493
|
+
return self._rot_mats.device
|
|
494
|
+
elif self._quats is not None:
|
|
495
|
+
return self._quats.device
|
|
496
|
+
else:
|
|
497
|
+
raise ValueError("Both rotations are None")
|
|
498
|
+
|
|
499
|
+
@property
|
|
500
|
+
def requires_grad(self) -> bool:
|
|
501
|
+
"""
|
|
502
|
+
Returns the requires_grad property of the underlying rotation
|
|
503
|
+
|
|
504
|
+
Returns:
|
|
505
|
+
The requires_grad property of the underlying tensor
|
|
506
|
+
"""
|
|
507
|
+
if self._rot_mats is not None:
|
|
508
|
+
return self._rot_mats.requires_grad
|
|
509
|
+
elif self._quats is not None:
|
|
510
|
+
return self._quats.requires_grad
|
|
511
|
+
else:
|
|
512
|
+
raise ValueError("Both rotations are None")
|
|
513
|
+
|
|
514
|
+
def get_rot_mats(self) -> torch.Tensor:
|
|
515
|
+
"""
|
|
516
|
+
Returns the underlying rotation as a rotation matrix tensor.
|
|
517
|
+
|
|
518
|
+
Returns:
|
|
519
|
+
The rotation as a rotation matrix tensor
|
|
520
|
+
"""
|
|
521
|
+
rot_mats = self._rot_mats
|
|
522
|
+
if rot_mats is None:
|
|
523
|
+
if self._quats is None:
|
|
524
|
+
raise ValueError("Both rotations are None")
|
|
525
|
+
else:
|
|
526
|
+
rot_mats = quat_to_rot(self._quats)
|
|
527
|
+
|
|
528
|
+
return rot_mats
|
|
529
|
+
|
|
530
|
+
def get_quats(self) -> torch.Tensor:
|
|
531
|
+
"""
|
|
532
|
+
Returns the underlying rotation as a quaternion tensor.
|
|
533
|
+
|
|
534
|
+
Depending on whether the Rotation was initialized with a
|
|
535
|
+
quaternion, this function may call torch.linalg.eigh.
|
|
536
|
+
|
|
537
|
+
Returns:
|
|
538
|
+
The rotation as a quaternion tensor.
|
|
539
|
+
"""
|
|
540
|
+
quats = self._quats
|
|
541
|
+
if quats is None:
|
|
542
|
+
if self._rot_mats is None:
|
|
543
|
+
raise ValueError("Both rotations are None")
|
|
544
|
+
else:
|
|
545
|
+
quats = rot_to_quat(self._rot_mats)
|
|
546
|
+
|
|
547
|
+
return quats
|
|
548
|
+
|
|
549
|
+
def get_cur_rot(self) -> torch.Tensor:
|
|
550
|
+
"""
|
|
551
|
+
Return the underlying rotation in its current form
|
|
552
|
+
|
|
553
|
+
Returns:
|
|
554
|
+
The stored rotation
|
|
555
|
+
"""
|
|
556
|
+
if self._rot_mats is not None:
|
|
557
|
+
return self._rot_mats
|
|
558
|
+
elif self._quats is not None:
|
|
559
|
+
return self._quats
|
|
560
|
+
else:
|
|
561
|
+
raise ValueError("Both rotations are None")
|
|
562
|
+
|
|
563
|
+
    def get_rotvec(self, eps: float = 1e-4) -> torch.Tensor:
        """
        Return the underlying axis-angle rotation vector.

        Follow's scipy's implementation:
        https://github.com/scipy/scipy/blob/HEAD/scipy/spatial/transform/_rotation.pyx#L1385-L1402

        Args:
            eps: offset added inside the sine in the large-angle branch to
                avoid division by zero
        Returns:
            The stored rotation as a axis-angle vector.
        """
        quat = self.get_quats()
        # w > 0 to ensure 0 <= angle <= pi
        flip = (quat[..., :1] < 0).float()
        quat = (-1 * quat) * flip + (1 - flip) * quat

        # angle = 2 * atan2(|v|, w), with v the vector part and w the scalar
        angle = 2 * torch.atan2(torch.linalg.norm(quat[..., 1:], dim=-1), quat[..., 0])

        # Near zero, angle / sin(angle / 2) is replaced by its series
        # expansion for numerical stability (as in the scipy source above)
        angle2 = angle * angle
        small_angle_scales = 2 + angle2 / 12 + 7 * angle2 * angle2 / 2880
        large_angle_scales = angle / torch.sin(angle / 2 + eps)

        # Blend the two branches with a 0/1 mask at the 1e-3 threshold
        small_angles = (angle <= 1e-3).float()
        rot_vec_scale = (
            small_angle_scales * small_angles + (1 - small_angles) * large_angle_scales
        )
        rot_vec = rot_vec_scale[..., None] * quat[..., 1:]
        return rot_vec
|
|
590
|
+
|
|
591
|
+
# Rotation functions
|
|
592
|
+
|
|
593
|
+
    def compose_q_update_vec(
        self,
        q_update_vec: torch.Tensor,
        normalize_quats: bool = True,
        update_mask: Optional[torch.Tensor] = None,
    ):
        """
        Returns a new quaternion Rotation after updating the current
        object's underlying rotation with a quaternion update, formatted
        as a [*, 3] tensor whose final three columns represent x, y, z such
        that (1, x, y, z) is the desired (not necessarily unit) quaternion
        update.

        Args:
            q_update_vec:
                A [*, 3] quaternion update tensor
            normalize_quats:
                Whether to normalize the output quaternion
            update_mask:
                Optional multiplicative mask applied to the quaternion
                update before it is added; where the mask is zero the
                rotation is left unchanged
        Returns:
            An updated Rotation
        """
        quats = self.get_quats()
        # q * (0, x, y, z): the additive part of q * (1, x, y, z) = q + q*(0,x,y,z)
        quat_update = quat_multiply_by_vec(quats, q_update_vec)
        if update_mask is not None:
            quat_update = quat_update * update_mask
        new_quats = quats + quat_update
        return Rotation(
            rot_mats=None,
            quats=new_quats,
            normalize_quats=normalize_quats,
        )
|
|
624
|
+
|
|
625
|
+
def compose_r(self, r):
|
|
626
|
+
"""
|
|
627
|
+
Compose the rotation matrices of the current Rotation object with
|
|
628
|
+
those of another.
|
|
629
|
+
|
|
630
|
+
Args:
|
|
631
|
+
r:
|
|
632
|
+
An update rotation object
|
|
633
|
+
Returns:
|
|
634
|
+
An updated rotation object
|
|
635
|
+
"""
|
|
636
|
+
r1 = self.get_rot_mats()
|
|
637
|
+
r2 = r.get_rot_mats()
|
|
638
|
+
new_rot_mats = rot_matmul(r1, r2)
|
|
639
|
+
return Rotation(rot_mats=new_rot_mats, quats=None)
|
|
640
|
+
|
|
641
|
+
def compose_q(self, r, normalize_quats: bool = True):
|
|
642
|
+
"""
|
|
643
|
+
Compose the quaternions of the current Rotation object with those
|
|
644
|
+
of another.
|
|
645
|
+
|
|
646
|
+
Depending on whether either Rotation was initialized with
|
|
647
|
+
quaternions, this function may call torch.linalg.eigh.
|
|
648
|
+
|
|
649
|
+
Args:
|
|
650
|
+
r:
|
|
651
|
+
An update rotation object
|
|
652
|
+
Returns:
|
|
653
|
+
An updated rotation object
|
|
654
|
+
"""
|
|
655
|
+
q1 = self.get_quats()
|
|
656
|
+
q2 = r.get_quats()
|
|
657
|
+
new_quats = quat_multiply(q1, q2)
|
|
658
|
+
return Rotation(rot_mats=None, quats=new_quats, normalize_quats=normalize_quats)
|
|
659
|
+
|
|
660
|
+
def apply(self, pts: torch.Tensor) -> torch.Tensor:
|
|
661
|
+
"""
|
|
662
|
+
Apply the current Rotation as a rotation matrix to a set of 3D
|
|
663
|
+
coordinates.
|
|
664
|
+
|
|
665
|
+
Args:
|
|
666
|
+
pts:
|
|
667
|
+
A [*, 3] set of points
|
|
668
|
+
Returns:
|
|
669
|
+
[*, 3] rotated points
|
|
670
|
+
"""
|
|
671
|
+
rot_mats = self.get_rot_mats()
|
|
672
|
+
return rot_vec_mul(rot_mats, pts)
|
|
673
|
+
|
|
674
|
+
def invert_apply(self, pts: torch.Tensor) -> torch.Tensor:
|
|
675
|
+
"""
|
|
676
|
+
The inverse of the apply() method.
|
|
677
|
+
|
|
678
|
+
Args:
|
|
679
|
+
pts:
|
|
680
|
+
A [*, 3] set of points
|
|
681
|
+
Returns:
|
|
682
|
+
[*, 3] inverse-rotated points
|
|
683
|
+
"""
|
|
684
|
+
rot_mats = self.get_rot_mats()
|
|
685
|
+
inv_rot_mats = invert_rot_mat(rot_mats)
|
|
686
|
+
return rot_vec_mul(inv_rot_mats, pts)
|
|
687
|
+
|
|
688
|
+
def invert(self):
|
|
689
|
+
"""
|
|
690
|
+
Returns the inverse of the current Rotation.
|
|
691
|
+
|
|
692
|
+
Returns:
|
|
693
|
+
The inverse of the current Rotation
|
|
694
|
+
"""
|
|
695
|
+
if self._rot_mats is not None:
|
|
696
|
+
return Rotation(rot_mats=invert_rot_mat(self._rot_mats), quats=None)
|
|
697
|
+
elif self._quats is not None:
|
|
698
|
+
return Rotation(
|
|
699
|
+
rot_mats=None,
|
|
700
|
+
quats=invert_quat(self._quats),
|
|
701
|
+
normalize_quats=False,
|
|
702
|
+
)
|
|
703
|
+
else:
|
|
704
|
+
raise ValueError("Both rotations are None")
|
|
705
|
+
|
|
706
|
+
# "Tensor" stuff
|
|
707
|
+
|
|
708
|
+
def unsqueeze(
|
|
709
|
+
self,
|
|
710
|
+
dim: int,
|
|
711
|
+
):
|
|
712
|
+
"""
|
|
713
|
+
Analogous to torch.unsqueeze. The dimension is relative to the
|
|
714
|
+
shape of the Rotation object.
|
|
715
|
+
|
|
716
|
+
Args:
|
|
717
|
+
dim: A positive or negative dimension index.
|
|
718
|
+
Returns:
|
|
719
|
+
The unsqueezed Rotation.
|
|
720
|
+
"""
|
|
721
|
+
if dim >= len(self.shape):
|
|
722
|
+
raise ValueError("Invalid dimension")
|
|
723
|
+
|
|
724
|
+
if self._rot_mats is not None:
|
|
725
|
+
rot_mats = self._rot_mats.unsqueeze(dim if dim >= 0 else dim - 2)
|
|
726
|
+
return Rotation(rot_mats=rot_mats, quats=None)
|
|
727
|
+
elif self._quats is not None:
|
|
728
|
+
quats = self._quats.unsqueeze(dim if dim >= 0 else dim - 1)
|
|
729
|
+
return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
|
|
730
|
+
else:
|
|
731
|
+
raise ValueError("Both rotations are None")
|
|
732
|
+
|
|
733
|
+
@staticmethod
|
|
734
|
+
def cat(
|
|
735
|
+
rs,
|
|
736
|
+
dim: int,
|
|
737
|
+
):
|
|
738
|
+
"""
|
|
739
|
+
Concatenates rotations along one of the batch dimensions. Analogous
|
|
740
|
+
to torch.cat().
|
|
741
|
+
|
|
742
|
+
Note that the output of this operation is always a rotation matrix,
|
|
743
|
+
regardless of the format of input rotations.
|
|
744
|
+
|
|
745
|
+
Args:
|
|
746
|
+
rs:
|
|
747
|
+
A list of rotation objects
|
|
748
|
+
dim:
|
|
749
|
+
The dimension along which the rotations should be
|
|
750
|
+
concatenated
|
|
751
|
+
Returns:
|
|
752
|
+
A concatenated Rotation object in rotation matrix format
|
|
753
|
+
"""
|
|
754
|
+
rot_mats = [r.get_rot_mats() for r in rs]
|
|
755
|
+
rot_mats = torch.cat(rot_mats, dim=dim if dim >= 0 else dim - 2)
|
|
756
|
+
|
|
757
|
+
return Rotation(rot_mats=rot_mats, quats=None)
|
|
758
|
+
|
|
759
|
+
def map_tensor_fn(self, fn):
|
|
760
|
+
"""
|
|
761
|
+
Apply a Tensor -> Tensor function to underlying rotation tensors,
|
|
762
|
+
mapping over the rotation dimension(s). Can be used e.g. to sum out
|
|
763
|
+
a one-hot batch dimension.
|
|
764
|
+
|
|
765
|
+
Args:
|
|
766
|
+
fn:
|
|
767
|
+
A Tensor -> Tensor function to be mapped over the Rotation
|
|
768
|
+
Returns:
|
|
769
|
+
The transformed Rotation object
|
|
770
|
+
"""
|
|
771
|
+
if self._rot_mats is not None:
|
|
772
|
+
rot_mats = self._rot_mats.view(self._rot_mats.shape[:-2] + (9,))
|
|
773
|
+
rot_mats = torch.stack(
|
|
774
|
+
list(map(fn, torch.unbind(rot_mats, dim=-1))), dim=-1
|
|
775
|
+
)
|
|
776
|
+
rot_mats = rot_mats.view(rot_mats.shape[:-1] + (3, 3))
|
|
777
|
+
return Rotation(rot_mats=rot_mats, quats=None)
|
|
778
|
+
elif self._quats is not None:
|
|
779
|
+
quats = torch.stack(
|
|
780
|
+
list(map(fn, torch.unbind(self._quats, dim=-1))), dim=-1
|
|
781
|
+
)
|
|
782
|
+
return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
|
|
783
|
+
else:
|
|
784
|
+
raise ValueError("Both rotations are None")
|
|
785
|
+
|
|
786
|
+
def cuda(self):
|
|
787
|
+
"""
|
|
788
|
+
Analogous to the cuda() method of torch Tensors
|
|
789
|
+
|
|
790
|
+
Returns:
|
|
791
|
+
A copy of the Rotation in CUDA memory
|
|
792
|
+
"""
|
|
793
|
+
if self._rot_mats is not None:
|
|
794
|
+
return Rotation(rot_mats=self._rot_mats.cuda(), quats=None)
|
|
795
|
+
elif self._quats is not None:
|
|
796
|
+
return Rotation(
|
|
797
|
+
rot_mats=None, quats=self._quats.cuda(), normalize_quats=False
|
|
798
|
+
)
|
|
799
|
+
else:
|
|
800
|
+
raise ValueError("Both rotations are None")
|
|
801
|
+
|
|
802
|
+
def to(self, device: Optional[torch.device], dtype: Optional[torch.dtype]):
|
|
803
|
+
"""
|
|
804
|
+
Analogous to the to() method of torch Tensors
|
|
805
|
+
|
|
806
|
+
Args:
|
|
807
|
+
device:
|
|
808
|
+
A torch device
|
|
809
|
+
dtype:
|
|
810
|
+
A torch dtype
|
|
811
|
+
Returns:
|
|
812
|
+
A copy of the Rotation using the new device and dtype
|
|
813
|
+
"""
|
|
814
|
+
if self._rot_mats is not None:
|
|
815
|
+
return Rotation(
|
|
816
|
+
rot_mats=self._rot_mats.to(device=device, dtype=dtype),
|
|
817
|
+
quats=None,
|
|
818
|
+
)
|
|
819
|
+
elif self._quats is not None:
|
|
820
|
+
return Rotation(
|
|
821
|
+
rot_mats=None,
|
|
822
|
+
quats=self._quats.to(device=device, dtype=dtype),
|
|
823
|
+
normalize_quats=False,
|
|
824
|
+
)
|
|
825
|
+
else:
|
|
826
|
+
raise ValueError("Both rotations are None")
|
|
827
|
+
|
|
828
|
+
def detach(self):
|
|
829
|
+
"""
|
|
830
|
+
Returns a copy of the Rotation whose underlying Tensor has been
|
|
831
|
+
detached from its torch graph.
|
|
832
|
+
|
|
833
|
+
Returns:
|
|
834
|
+
A copy of the Rotation whose underlying Tensor has been detached
|
|
835
|
+
from its torch graph
|
|
836
|
+
"""
|
|
837
|
+
if self._rot_mats is not None:
|
|
838
|
+
return Rotation(rot_mats=self._rot_mats.detach(), quats=None)
|
|
839
|
+
elif self._quats is not None:
|
|
840
|
+
return Rotation(
|
|
841
|
+
rot_mats=None,
|
|
842
|
+
quats=self._quats.detach(),
|
|
843
|
+
normalize_quats=False,
|
|
844
|
+
)
|
|
845
|
+
else:
|
|
846
|
+
raise ValueError("Both rotations are None")
|
|
847
|
+
|
|
848
|
+
|
|
849
|
+
class Rigid:
    """
    A class representing a rigid transformation. Little more than a wrapper
    around two objects: a Rotation object and a [*, 3] translation tensor.

    Designed to behave approximately like a single torch tensor with the
    shape of the shared batch dimensions of its component parts.
    """
|
|
856
|
+
|
|
857
|
+
    def __init__(
        self,
        rots: Optional[Rotation],
        trans: Optional[torch.Tensor],
    ):
        """
        Construct a rigid transformation from a rotation and a translation.

        Exactly one of the two arguments may be None; the missing part is
        filled in with an identity rotation/translation matching the other.

        Args:
            rots: A Rotation object, or None for an identity rotation
            trans: A corresponding [*, 3] translation tensor, or None for a
                zero translation
        Raises:
            ValueError: If both inputs are None, or if their batch shapes
                or devices disagree.
        """
        # (we need device, dtype, etc. from at least one input)

        batch_dims, dtype, device, requires_grad = None, None, None, None
        if trans is not None:
            # Batch shape excludes the trailing xyz dimension.
            batch_dims = trans.shape[:-1]
            dtype = trans.dtype
            device = trans.device
            requires_grad = trans.requires_grad
        elif rots is not None:
            batch_dims = rots.shape
            dtype = rots.dtype
            device = rots.device
            requires_grad = rots.requires_grad
        else:
            raise ValueError("At least one input argument must be specified")

        # Fill in whichever half is missing with an identity of matching
        # shape/dtype/device.
        if rots is None:
            rots = Rotation.identity(
                batch_dims,
                dtype,
                device,
                requires_grad,
            )
        elif trans is None:
            trans = identity_trans(
                batch_dims,
                dtype,
                device,
                requires_grad,
            )

        if (rots.shape != trans.shape[:-1]) or (rots.device != trans.device):
            raise ValueError("Rots and trans incompatible")

        # Force full precision. Happens to the rotations automatically.
        trans = trans.type(torch.float32)

        self._rots = rots       # Rotation component
        self._trans = trans     # [*, 3] translation component (fp32)
|
|
906
|
+
|
|
907
|
+
@staticmethod
|
|
908
|
+
def identity(
|
|
909
|
+
shape: Tuple[int],
|
|
910
|
+
dtype: Optional[torch.dtype] = None,
|
|
911
|
+
device: Optional[torch.device] = None,
|
|
912
|
+
requires_grad: bool = True,
|
|
913
|
+
fmt: str = "quat",
|
|
914
|
+
):
|
|
915
|
+
"""
|
|
916
|
+
Constructs an identity transformation.
|
|
917
|
+
|
|
918
|
+
Args:
|
|
919
|
+
shape:
|
|
920
|
+
The desired shape
|
|
921
|
+
dtype:
|
|
922
|
+
The dtype of both internal tensors
|
|
923
|
+
device:
|
|
924
|
+
The device of both internal tensors
|
|
925
|
+
requires_grad:
|
|
926
|
+
Whether grad should be enabled for the internal tensors
|
|
927
|
+
Returns:
|
|
928
|
+
The identity transformation
|
|
929
|
+
"""
|
|
930
|
+
return Rigid(
|
|
931
|
+
Rotation.identity(shape, dtype, device, requires_grad, fmt=fmt),
|
|
932
|
+
identity_trans(shape, dtype, device, requires_grad),
|
|
933
|
+
)
|
|
934
|
+
|
|
935
|
+
def __getitem__(
|
|
936
|
+
self,
|
|
937
|
+
index: Any,
|
|
938
|
+
):
|
|
939
|
+
"""
|
|
940
|
+
Indexes the affine transformation with PyTorch-style indices.
|
|
941
|
+
The index is applied to the shared dimensions of both the rotation
|
|
942
|
+
and the translation.
|
|
943
|
+
|
|
944
|
+
E.g.::
|
|
945
|
+
|
|
946
|
+
r = Rotation(rot_mats=torch.rand(10, 10, 3, 3), quats=None)
|
|
947
|
+
t = Rigid(r, torch.rand(10, 10, 3))
|
|
948
|
+
indexed = t[3, 4:6]
|
|
949
|
+
assert indexed.shape == (2,)
|
|
950
|
+
assert indexed.get_rots().shape == (2,)
|
|
951
|
+
assert indexed.get_trans().shape == (2, 3)
|
|
952
|
+
|
|
953
|
+
Args:
|
|
954
|
+
index: A standard torch tensor index. E.g. 8, (10, None, 3),
|
|
955
|
+
or (3, slice(0, 1, None))
|
|
956
|
+
Returns:
|
|
957
|
+
The indexed tensor
|
|
958
|
+
"""
|
|
959
|
+
if not isinstance(index, tuple):
|
|
960
|
+
index = (index,)
|
|
961
|
+
|
|
962
|
+
return Rigid(
|
|
963
|
+
self._rots[index],
|
|
964
|
+
self._trans[index + (slice(None),)],
|
|
965
|
+
)
|
|
966
|
+
|
|
967
|
+
def __mul__(
|
|
968
|
+
self,
|
|
969
|
+
right: torch.Tensor,
|
|
970
|
+
):
|
|
971
|
+
"""
|
|
972
|
+
Pointwise left multiplication of the transformation with a tensor.
|
|
973
|
+
Can be used to e.g. mask the Rigid.
|
|
974
|
+
|
|
975
|
+
Args:
|
|
976
|
+
right:
|
|
977
|
+
The tensor multiplicand
|
|
978
|
+
Returns:
|
|
979
|
+
The product
|
|
980
|
+
"""
|
|
981
|
+
if not (isinstance(right, torch.Tensor)):
|
|
982
|
+
raise TypeError("The other multiplicand must be a Tensor")
|
|
983
|
+
|
|
984
|
+
new_rots = self._rots * right
|
|
985
|
+
new_trans = self._trans * right[..., None]
|
|
986
|
+
|
|
987
|
+
return Rigid(new_rots, new_trans)
|
|
988
|
+
|
|
989
|
+
def __rmul__(
|
|
990
|
+
self,
|
|
991
|
+
left: torch.Tensor,
|
|
992
|
+
):
|
|
993
|
+
"""
|
|
994
|
+
Reverse pointwise multiplication of the transformation with a
|
|
995
|
+
tensor.
|
|
996
|
+
|
|
997
|
+
Args:
|
|
998
|
+
left:
|
|
999
|
+
The left multiplicand
|
|
1000
|
+
Returns:
|
|
1001
|
+
The product
|
|
1002
|
+
"""
|
|
1003
|
+
return self.__mul__(left)
|
|
1004
|
+
|
|
1005
|
+
@property
|
|
1006
|
+
def shape(self) -> torch.Size:
|
|
1007
|
+
"""
|
|
1008
|
+
Returns the shape of the shared dimensions of the rotation and
|
|
1009
|
+
the translation.
|
|
1010
|
+
|
|
1011
|
+
Returns:
|
|
1012
|
+
The shape of the transformation
|
|
1013
|
+
"""
|
|
1014
|
+
s = self._trans.shape[:-1]
|
|
1015
|
+
return s
|
|
1016
|
+
|
|
1017
|
+
@property
|
|
1018
|
+
def device(self) -> torch.device:
|
|
1019
|
+
"""
|
|
1020
|
+
Returns the device on which the Rigid's tensors are located.
|
|
1021
|
+
|
|
1022
|
+
Returns:
|
|
1023
|
+
The device on which the Rigid's tensors are located
|
|
1024
|
+
"""
|
|
1025
|
+
return self._trans.device
|
|
1026
|
+
|
|
1027
|
+
def get_rots(self) -> Rotation:
|
|
1028
|
+
"""
|
|
1029
|
+
Getter for the rotation.
|
|
1030
|
+
|
|
1031
|
+
Returns:
|
|
1032
|
+
The rotation object
|
|
1033
|
+
"""
|
|
1034
|
+
return self._rots
|
|
1035
|
+
|
|
1036
|
+
def get_trans(self) -> torch.Tensor:
|
|
1037
|
+
"""
|
|
1038
|
+
Getter for the translation.
|
|
1039
|
+
|
|
1040
|
+
Returns:
|
|
1041
|
+
The stored translation
|
|
1042
|
+
"""
|
|
1043
|
+
return self._trans
|
|
1044
|
+
|
|
1045
|
+
def compose_q_update_vec(
|
|
1046
|
+
self,
|
|
1047
|
+
q_update_vec: torch.Tensor,
|
|
1048
|
+
update_mask: torch.Tensor = None,
|
|
1049
|
+
):
|
|
1050
|
+
"""
|
|
1051
|
+
Composes the transformation with a quaternion update vector of
|
|
1052
|
+
shape [*, 6], where the final 6 columns represent the x, y, and
|
|
1053
|
+
z values of a quaternion of form (1, x, y, z) followed by a 3D
|
|
1054
|
+
translation.
|
|
1055
|
+
|
|
1056
|
+
Args:
|
|
1057
|
+
q_vec: The quaternion update vector.
|
|
1058
|
+
Returns:
|
|
1059
|
+
The composed transformation.
|
|
1060
|
+
"""
|
|
1061
|
+
q_vec, t_vec = q_update_vec[..., :3], q_update_vec[..., 3:]
|
|
1062
|
+
new_rots = self._rots.compose_q_update_vec(q_vec, update_mask=update_mask)
|
|
1063
|
+
|
|
1064
|
+
trans_update = self._rots.apply(t_vec)
|
|
1065
|
+
if update_mask is not None:
|
|
1066
|
+
trans_update = trans_update * update_mask
|
|
1067
|
+
new_translation = self._trans + trans_update
|
|
1068
|
+
|
|
1069
|
+
return Rigid(new_rots, new_translation)
|
|
1070
|
+
|
|
1071
|
+
def compose_tran_update_vec(
|
|
1072
|
+
self,
|
|
1073
|
+
t_vec: torch.Tensor,
|
|
1074
|
+
update_mask: torch.Tensor = None,
|
|
1075
|
+
):
|
|
1076
|
+
"""
|
|
1077
|
+
Composes the transformation with a quaternion update vector of
|
|
1078
|
+
shape [*, 3], where columns represent a 3D translation.
|
|
1079
|
+
|
|
1080
|
+
Args:
|
|
1081
|
+
q_vec: The quaternion update vector.
|
|
1082
|
+
Returns:
|
|
1083
|
+
The composed transformation.
|
|
1084
|
+
"""
|
|
1085
|
+
trans_update = self._rots.apply(t_vec)
|
|
1086
|
+
if update_mask is not None:
|
|
1087
|
+
trans_update = trans_update * update_mask
|
|
1088
|
+
new_translation = self._trans + trans_update
|
|
1089
|
+
|
|
1090
|
+
return Rigid(self._rots, new_translation)
|
|
1091
|
+
|
|
1092
|
+
def compose(
|
|
1093
|
+
self,
|
|
1094
|
+
r,
|
|
1095
|
+
):
|
|
1096
|
+
"""
|
|
1097
|
+
Composes the current rigid object with another.
|
|
1098
|
+
|
|
1099
|
+
Args:
|
|
1100
|
+
r:
|
|
1101
|
+
Another Rigid object
|
|
1102
|
+
Returns:
|
|
1103
|
+
The composition of the two transformations
|
|
1104
|
+
"""
|
|
1105
|
+
new_rot = self._rots.compose_r(r._rots)
|
|
1106
|
+
new_trans = self._rots.apply(r._trans) + self._trans
|
|
1107
|
+
return Rigid(new_rot, new_trans)
|
|
1108
|
+
|
|
1109
|
+
def compose_r(self, rot, order="right"):
|
|
1110
|
+
"""
|
|
1111
|
+
Composes the current rigid object with another.
|
|
1112
|
+
|
|
1113
|
+
Args:
|
|
1114
|
+
r:
|
|
1115
|
+
Another Rigid object
|
|
1116
|
+
order:
|
|
1117
|
+
Order in which to perform rotation multiplication.
|
|
1118
|
+
Returns:
|
|
1119
|
+
The composition of the two transformations
|
|
1120
|
+
"""
|
|
1121
|
+
if order == "right":
|
|
1122
|
+
new_rot = self._rots.compose_r(rot)
|
|
1123
|
+
elif order == "left":
|
|
1124
|
+
new_rot = rot.compose_r(self._rots)
|
|
1125
|
+
else:
|
|
1126
|
+
raise ValueError(f"Unrecognized multiplication order: {order}")
|
|
1127
|
+
return Rigid(new_rot, self._trans)
|
|
1128
|
+
|
|
1129
|
+
def apply(
|
|
1130
|
+
self,
|
|
1131
|
+
pts: torch.Tensor,
|
|
1132
|
+
) -> torch.Tensor:
|
|
1133
|
+
"""
|
|
1134
|
+
Applies the transformation to a coordinate tensor.
|
|
1135
|
+
|
|
1136
|
+
Args:
|
|
1137
|
+
pts: A [*, 3] coordinate tensor.
|
|
1138
|
+
Returns:
|
|
1139
|
+
The transformed points.
|
|
1140
|
+
"""
|
|
1141
|
+
rotated = self._rots.apply(pts)
|
|
1142
|
+
return rotated + self._trans
|
|
1143
|
+
|
|
1144
|
+
def invert_apply(self, pts: torch.Tensor) -> torch.Tensor:
|
|
1145
|
+
"""
|
|
1146
|
+
Applies the inverse of the transformation to a coordinate tensor.
|
|
1147
|
+
|
|
1148
|
+
Args:
|
|
1149
|
+
pts: A [*, 3] coordinate tensor
|
|
1150
|
+
Returns:
|
|
1151
|
+
The transformed points.
|
|
1152
|
+
"""
|
|
1153
|
+
pts = pts - self._trans
|
|
1154
|
+
return self._rots.invert_apply(pts)
|
|
1155
|
+
|
|
1156
|
+
def invert(self):
|
|
1157
|
+
"""
|
|
1158
|
+
Inverts the transformation.
|
|
1159
|
+
|
|
1160
|
+
Returns:
|
|
1161
|
+
The inverse transformation.
|
|
1162
|
+
"""
|
|
1163
|
+
rot_inv = self._rots.invert()
|
|
1164
|
+
trn_inv = rot_inv.apply(self._trans)
|
|
1165
|
+
|
|
1166
|
+
return Rigid(rot_inv, -1 * trn_inv)
|
|
1167
|
+
|
|
1168
|
+
def map_tensor_fn(self, fn):
|
|
1169
|
+
"""
|
|
1170
|
+
Apply a Tensor -> Tensor function to underlying translation and
|
|
1171
|
+
rotation tensors, mapping over the translation/rotation dimensions
|
|
1172
|
+
respectively.
|
|
1173
|
+
|
|
1174
|
+
Args:
|
|
1175
|
+
fn:
|
|
1176
|
+
A Tensor -> Tensor function to be mapped over the Rigid
|
|
1177
|
+
Returns:
|
|
1178
|
+
The transformed Rigid object
|
|
1179
|
+
"""
|
|
1180
|
+
new_rots = self._rots.map_tensor_fn(fn)
|
|
1181
|
+
new_trans = torch.stack(
|
|
1182
|
+
list(map(fn, torch.unbind(self._trans, dim=-1))), dim=-1
|
|
1183
|
+
)
|
|
1184
|
+
|
|
1185
|
+
return Rigid(new_rots, new_trans)
|
|
1186
|
+
|
|
1187
|
+
def to_tensor_4x4(self) -> torch.Tensor:
|
|
1188
|
+
"""
|
|
1189
|
+
Converts a transformation to a homogenous transformation tensor.
|
|
1190
|
+
|
|
1191
|
+
Returns:
|
|
1192
|
+
A [*, 4, 4] homogenous transformation tensor
|
|
1193
|
+
"""
|
|
1194
|
+
tensor = self._trans.new_zeros((*self.shape, 4, 4))
|
|
1195
|
+
tensor[..., :3, :3] = self._rots.get_rot_mats()
|
|
1196
|
+
tensor[..., :3, 3] = self._trans
|
|
1197
|
+
tensor[..., 3, 3] = 1
|
|
1198
|
+
return tensor
|
|
1199
|
+
|
|
1200
|
+
@staticmethod
|
|
1201
|
+
def from_tensor_4x4(t: torch.Tensor):
|
|
1202
|
+
"""
|
|
1203
|
+
Constructs a transformation from a homogenous transformation
|
|
1204
|
+
tensor.
|
|
1205
|
+
|
|
1206
|
+
Args:
|
|
1207
|
+
t: [*, 4, 4] homogenous transformation tensor
|
|
1208
|
+
Returns:
|
|
1209
|
+
T object with shape [*]
|
|
1210
|
+
"""
|
|
1211
|
+
if t.shape[-2:] != (4, 4):
|
|
1212
|
+
raise ValueError("Incorrectly shaped input tensor")
|
|
1213
|
+
|
|
1214
|
+
rots = Rotation(rot_mats=t[..., :3, :3], quats=None)
|
|
1215
|
+
trans = t[..., :3, 3]
|
|
1216
|
+
|
|
1217
|
+
return Rigid(rots, trans)
|
|
1218
|
+
|
|
1219
|
+
def to_tensor_7(self) -> torch.Tensor:
|
|
1220
|
+
"""
|
|
1221
|
+
Converts a transformation to a tensor with 7 final columns, four
|
|
1222
|
+
for the quaternion followed by three for the translation.
|
|
1223
|
+
|
|
1224
|
+
Returns:
|
|
1225
|
+
A [*, 7] tensor representation of the transformation
|
|
1226
|
+
"""
|
|
1227
|
+
tensor = self._trans.new_zeros((*self.shape, 7))
|
|
1228
|
+
tensor[..., :4] = self._rots.get_quats()
|
|
1229
|
+
tensor[..., 4:] = self._trans
|
|
1230
|
+
|
|
1231
|
+
return tensor
|
|
1232
|
+
|
|
1233
|
+
    @staticmethod
    def from_tensor_7(
        t: torch.Tensor,
        normalize_quats: bool = False,
    ):
        """
        Constructs a transformation from a [*, 7] tensor whose final
        columns are a quaternion followed by a 3D translation (the inverse
        of to_tensor_7()).

        Args:
            t: A [*, 7] tensor
            normalize_quats: Whether to normalize the quaternion part
        Returns:
            A transformation object of shape [*]
        Raises:
            ValueError: If the final dimension of t is not 7.
        """
        if t.shape[-1] != 7:
            raise ValueError("Incorrectly shaped input tensor")

        quats, trans = t[..., :4], t[..., 4:]

        rots = Rotation(rot_mats=None, quats=quats, normalize_quats=normalize_quats)

        return Rigid(rots, trans)
|
|
1246
|
+
|
|
1247
|
+
    @staticmethod
    def from_3_points(
        p_neg_x_axis: torch.Tensor,
        origin: torch.Tensor,
        p_xy_plane: torch.Tensor,
        eps: float = 1e-4,
    ):
        """
        Implements algorithm 21. Constructs transformations from sets of 3
        points using the Gram-Schmidt algorithm.

        Args:
            p_neg_x_axis: [*, 3] coordinates
            origin: [*, 3] coordinates used as frame origins
            p_xy_plane: [*, 3] coordinates
            eps: Small epsilon value added before square roots to avoid
                division by zero for degenerate inputs
        Returns:
            A transformation object of shape [*]
        """
        # Work component-wise: each of these becomes a 3-element list of
        # [*] tensors (x, y, z components).
        p_neg_x_axis = torch.unbind(p_neg_x_axis, dim=-1)
        origin = torch.unbind(origin, dim=-1)
        p_xy_plane = torch.unbind(p_xy_plane, dim=-1)

        # e0 points from p_neg_x_axis toward the origin; e1 from the origin
        # toward p_xy_plane.
        e0 = [c1 - c2 for c1, c2 in zip(origin, p_neg_x_axis)]
        e1 = [c1 - c2 for c1, c2 in zip(p_xy_plane, origin)]

        # Normalize e0.
        denom = torch.sqrt(sum((c * c for c in e0)) + eps)
        e0 = [c / denom for c in e0]
        # Gram-Schmidt: subtract e0's component from e1, then normalize.
        dot = sum((c1 * c2 for c1, c2 in zip(e0, e1)))
        e1 = [c2 - c1 * dot for c1, c2 in zip(e0, e1)]
        denom = torch.sqrt(sum((c * c for c in e1)) + eps)
        e1 = [c / denom for c in e1]
        # e2 = e0 x e1 completes the right-handed orthonormal frame.
        e2 = [
            e0[1] * e1[2] - e0[2] * e1[1],
            e0[2] * e1[0] - e0[0] * e1[2],
            e0[0] * e1[1] - e0[1] * e1[0],
        ]

        # Interleave so that e0, e1, e2 become the columns of a [*, 3, 3]
        # rotation matrix.
        rots = torch.stack([c for tup in zip(e0, e1, e2) for c in tup], dim=-1)
        rots = rots.reshape(rots.shape[:-1] + (3, 3))

        rot_obj = Rotation(rot_mats=rots, quats=None)

        return Rigid(rot_obj, torch.stack(origin, dim=-1))
|
|
1291
|
+
|
|
1292
|
+
def unsqueeze(
|
|
1293
|
+
self,
|
|
1294
|
+
dim: int,
|
|
1295
|
+
):
|
|
1296
|
+
"""
|
|
1297
|
+
Analogous to torch.unsqueeze. The dimension is relative to the
|
|
1298
|
+
shared dimensions of the rotation/translation.
|
|
1299
|
+
|
|
1300
|
+
Args:
|
|
1301
|
+
dim: A positive or negative dimension index.
|
|
1302
|
+
Returns:
|
|
1303
|
+
The unsqueezed transformation.
|
|
1304
|
+
"""
|
|
1305
|
+
if dim >= len(self.shape):
|
|
1306
|
+
raise ValueError("Invalid dimension")
|
|
1307
|
+
rots = self._rots.unsqueeze(dim)
|
|
1308
|
+
trans = self._trans.unsqueeze(dim if dim >= 0 else dim - 1)
|
|
1309
|
+
|
|
1310
|
+
return Rigid(rots, trans)
|
|
1311
|
+
|
|
1312
|
+
@staticmethod
|
|
1313
|
+
def cat(
|
|
1314
|
+
ts,
|
|
1315
|
+
dim: int,
|
|
1316
|
+
):
|
|
1317
|
+
"""
|
|
1318
|
+
Concatenates transformations along a new dimension.
|
|
1319
|
+
|
|
1320
|
+
Args:
|
|
1321
|
+
ts:
|
|
1322
|
+
A list of T objects
|
|
1323
|
+
dim:
|
|
1324
|
+
The dimension along which the transformations should be
|
|
1325
|
+
concatenated
|
|
1326
|
+
Returns:
|
|
1327
|
+
A concatenated transformation object
|
|
1328
|
+
"""
|
|
1329
|
+
rots = Rotation.cat([t._rots for t in ts], dim)
|
|
1330
|
+
trans = torch.cat([t._trans for t in ts], dim=dim if dim >= 0 else dim - 1)
|
|
1331
|
+
|
|
1332
|
+
return Rigid(rots, trans)
|
|
1333
|
+
|
|
1334
|
+
def apply_rot_fn(self, fn):
|
|
1335
|
+
"""
|
|
1336
|
+
Applies a Rotation -> Rotation function to the stored rotation
|
|
1337
|
+
object.
|
|
1338
|
+
|
|
1339
|
+
Args:
|
|
1340
|
+
fn: A function of type Rotation -> Rotation
|
|
1341
|
+
Returns:
|
|
1342
|
+
A transformation object with a transformed rotation.
|
|
1343
|
+
"""
|
|
1344
|
+
return Rigid(fn(self._rots), self._trans)
|
|
1345
|
+
|
|
1346
|
+
def apply_trans_fn(self, fn):
|
|
1347
|
+
"""
|
|
1348
|
+
Applies a Tensor -> Tensor function to the stored translation.
|
|
1349
|
+
|
|
1350
|
+
Args:
|
|
1351
|
+
fn:
|
|
1352
|
+
A function of type Tensor -> Tensor to be applied to the
|
|
1353
|
+
translation
|
|
1354
|
+
Returns:
|
|
1355
|
+
A transformation object with a transformed translation.
|
|
1356
|
+
"""
|
|
1357
|
+
return Rigid(self._rots, fn(self._trans))
|
|
1358
|
+
|
|
1359
|
+
def scale_translation(self, trans_scale_factor: float):
|
|
1360
|
+
"""
|
|
1361
|
+
Scales the translation by a constant factor.
|
|
1362
|
+
|
|
1363
|
+
Args:
|
|
1364
|
+
trans_scale_factor:
|
|
1365
|
+
The constant factor
|
|
1366
|
+
Returns:
|
|
1367
|
+
A transformation object with a scaled translation.
|
|
1368
|
+
"""
|
|
1369
|
+
fn = lambda t: t * trans_scale_factor # noqa: E731
|
|
1370
|
+
return self.apply_trans_fn(fn)
|
|
1371
|
+
|
|
1372
|
+
def stop_rot_gradient(self):
|
|
1373
|
+
"""
|
|
1374
|
+
Detaches the underlying rotation object
|
|
1375
|
+
|
|
1376
|
+
Returns:
|
|
1377
|
+
A transformation object with detached rotations
|
|
1378
|
+
"""
|
|
1379
|
+
fn = lambda r: r.detach() # noqa: E731
|
|
1380
|
+
return self.apply_rot_fn(fn)
|
|
1381
|
+
|
|
1382
|
+
@staticmethod
|
|
1383
|
+
def make_transform_from_reference(n_xyz, ca_xyz, c_xyz, eps=1e-20):
|
|
1384
|
+
"""
|
|
1385
|
+
Returns a transformation object from reference coordinates.
|
|
1386
|
+
|
|
1387
|
+
Note that this method does not take care of symmetries. If you
|
|
1388
|
+
provide the atom positions in the non-standard way, the N atom will
|
|
1389
|
+
end up not at [-0.527250, 1.359329, 0.0] but instead at
|
|
1390
|
+
[-0.527250, -1.359329, 0.0]. You need to take care of such cases in
|
|
1391
|
+
your code.
|
|
1392
|
+
|
|
1393
|
+
Args:
|
|
1394
|
+
n_xyz: A [*, 3] tensor of nitrogen xyz coordinates.
|
|
1395
|
+
ca_xyz: A [*, 3] tensor of carbon alpha xyz coordinates.
|
|
1396
|
+
c_xyz: A [*, 3] tensor of carbon xyz coordinates.
|
|
1397
|
+
Returns:
|
|
1398
|
+
A transformation object. After applying the translation and
|
|
1399
|
+
rotation to the reference backbone, the coordinates will
|
|
1400
|
+
approximately equal to the input coordinates.
|
|
1401
|
+
"""
|
|
1402
|
+
translation = -1 * ca_xyz
|
|
1403
|
+
n_xyz = n_xyz + translation
|
|
1404
|
+
c_xyz = c_xyz + translation
|
|
1405
|
+
|
|
1406
|
+
c_x, c_y, c_z = [c_xyz[..., i] for i in range(3)]
|
|
1407
|
+
norm = torch.sqrt(eps + c_x**2 + c_y**2)
|
|
1408
|
+
sin_c1 = -c_y / norm
|
|
1409
|
+
cos_c1 = c_x / norm
|
|
1410
|
+
|
|
1411
|
+
c1_rots = sin_c1.new_zeros((*sin_c1.shape, 3, 3))
|
|
1412
|
+
c1_rots[..., 0, 0] = cos_c1
|
|
1413
|
+
c1_rots[..., 0, 1] = -1 * sin_c1
|
|
1414
|
+
c1_rots[..., 1, 0] = sin_c1
|
|
1415
|
+
c1_rots[..., 1, 1] = cos_c1
|
|
1416
|
+
c1_rots[..., 2, 2] = 1
|
|
1417
|
+
|
|
1418
|
+
norm = torch.sqrt(eps + c_x**2 + c_y**2 + c_z**2)
|
|
1419
|
+
sin_c2 = c_z / norm
|
|
1420
|
+
cos_c2 = torch.sqrt(c_x**2 + c_y**2) / norm
|
|
1421
|
+
|
|
1422
|
+
c2_rots = sin_c2.new_zeros((*sin_c2.shape, 3, 3))
|
|
1423
|
+
c2_rots[..., 0, 0] = cos_c2
|
|
1424
|
+
c2_rots[..., 0, 2] = sin_c2
|
|
1425
|
+
c2_rots[..., 1, 1] = 1
|
|
1426
|
+
c1_rots[..., 2, 0] = -1 * sin_c2
|
|
1427
|
+
c1_rots[..., 2, 2] = cos_c2
|
|
1428
|
+
|
|
1429
|
+
c_rots = rot_matmul(c2_rots, c1_rots)
|
|
1430
|
+
n_xyz = rot_vec_mul(c_rots, n_xyz)
|
|
1431
|
+
|
|
1432
|
+
_, n_y, n_z = [n_xyz[..., i] for i in range(3)]
|
|
1433
|
+
norm = torch.sqrt(eps + n_y**2 + n_z**2)
|
|
1434
|
+
sin_n = -n_z / norm
|
|
1435
|
+
cos_n = n_y / norm
|
|
1436
|
+
|
|
1437
|
+
n_rots = sin_c2.new_zeros((*sin_c2.shape, 3, 3))
|
|
1438
|
+
n_rots[..., 0, 0] = 1
|
|
1439
|
+
n_rots[..., 1, 1] = cos_n
|
|
1440
|
+
n_rots[..., 1, 2] = -1 * sin_n
|
|
1441
|
+
n_rots[..., 2, 1] = sin_n
|
|
1442
|
+
n_rots[..., 2, 2] = cos_n
|
|
1443
|
+
|
|
1444
|
+
rots = rot_matmul(n_rots, c_rots)
|
|
1445
|
+
|
|
1446
|
+
rots = rots.transpose(-1, -2)
|
|
1447
|
+
translation = -1 * translation
|
|
1448
|
+
|
|
1449
|
+
rot_obj = Rotation(rot_mats=rots, quats=None)
|
|
1450
|
+
|
|
1451
|
+
return Rigid(rot_obj, translation)
|
|
1452
|
+
|
|
1453
|
+
def cuda(self):
|
|
1454
|
+
"""
|
|
1455
|
+
Moves the transformation object to GPU memory
|
|
1456
|
+
|
|
1457
|
+
Returns:
|
|
1458
|
+
A version of the transformation on GPU
|
|
1459
|
+
"""
|
|
1460
|
+
return Rigid(self._rots.cuda(), self._trans.cuda())
|