xbarray 0.0.1a13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- array_api_typing/__init__.py +9 -0
- array_api_typing/typing_2024_12/__init__.py +12 -0
- array_api_typing/typing_2024_12/_api_constant.py +32 -0
- array_api_typing/typing_2024_12/_api_fft_typing.py +717 -0
- array_api_typing/typing_2024_12/_api_linalg_typing.py +897 -0
- array_api_typing/typing_2024_12/_api_return_typing.py +103 -0
- array_api_typing/typing_2024_12/_api_typing.py +5855 -0
- array_api_typing/typing_2024_12/_array_typing.py +1265 -0
- array_api_typing/typing_compat/__init__.py +12 -0
- array_api_typing/typing_compat/_api_typing.py +27 -0
- array_api_typing/typing_compat/_array_typing.py +36 -0
- array_api_typing/typing_extra/__init__.py +12 -0
- array_api_typing/typing_extra/_api_typing.py +651 -0
- array_api_typing/typing_extra/_at.py +87 -0
- xbarray/__init__.py +1 -0
- xbarray/backends/_cls_base.py +9 -0
- xbarray/backends/_implementations/_common/implementations.py +87 -0
- xbarray/backends/_implementations/jax/__init__.py +33 -0
- xbarray/backends/_implementations/jax/_extra.py +127 -0
- xbarray/backends/_implementations/jax/_typing.py +15 -0
- xbarray/backends/_implementations/jax/random.py +115 -0
- xbarray/backends/_implementations/numpy/__init__.py +25 -0
- xbarray/backends/_implementations/numpy/_extra.py +98 -0
- xbarray/backends/_implementations/numpy/_typing.py +14 -0
- xbarray/backends/_implementations/numpy/random.py +105 -0
- xbarray/backends/_implementations/pytorch/__init__.py +26 -0
- xbarray/backends/_implementations/pytorch/_extra.py +135 -0
- xbarray/backends/_implementations/pytorch/_typing.py +13 -0
- xbarray/backends/_implementations/pytorch/random.py +101 -0
- xbarray/backends/base.py +218 -0
- xbarray/backends/jax.py +19 -0
- xbarray/backends/numpy.py +19 -0
- xbarray/backends/pytorch.py +22 -0
- xbarray/jax.py +4 -0
- xbarray/numpy.py +4 -0
- xbarray/pytorch.py +4 -0
- xbarray/transformations/pointcloud/__init__.py +1 -0
- xbarray/transformations/pointcloud/base.py +449 -0
- xbarray/transformations/pointcloud/jax.py +24 -0
- xbarray/transformations/pointcloud/numpy.py +23 -0
- xbarray/transformations/pointcloud/pytorch.py +23 -0
- xbarray/transformations/rotation_conversions/__init__.py +1 -0
- xbarray/transformations/rotation_conversions/base.py +713 -0
- xbarray/transformations/rotation_conversions/jax.py +41 -0
- xbarray/transformations/rotation_conversions/numpy.py +41 -0
- xbarray/transformations/rotation_conversions/pytorch.py +41 -0
- xbarray-0.0.1a13.dist-info/METADATA +20 -0
- xbarray-0.0.1a13.dist-info/RECORD +51 -0
- xbarray-0.0.1a13.dist-info/WHEEL +5 -0
- xbarray-0.0.1a13.dist-info/licenses/LICENSE +21 -0
- xbarray-0.0.1a13.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,713 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
# ---------- Note ----------
|
|
8
|
+
# This file is pulled from PyTorch3D (https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py)
|
|
9
|
+
# with minor modifications for cross-backend compatibility.
|
|
10
|
+
# Please refer to the original file for the full license and copyright notice.
|
|
11
|
+
# Please see https://github.com/facebookresearch/pytorch3d/issues/2002 for some issues involving axis angle rotations
|
|
12
|
+
# --------------------------
|
|
13
|
+
|
|
14
|
+
from typing import Optional
|
|
15
|
+
from xbarray.backends.base import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
|
|
16
|
+
|
|
17
|
+
# Public API of this module; every function listed here takes the compute
# backend as its first argument.
__all__ = [
    "quaternion_to_matrix",
    "matrix_to_quaternion",
    "euler_angles_to_matrix",
    "matrix_to_euler_angles",
    "random_quaternions",
    "random_rotations",
    "random_rotation",
    "standardize_quaternion",
    "quaternion_raw_multiply",
    "quaternion_multiply",
    "quaternion_invert",
    "quaternion_apply",
    "axis_angle_to_matrix",
    "matrix_to_axis_angle",
    "axis_angle_to_quaternion",
    "quaternion_to_axis_angle",
    "rotation_6d_to_matrix",
    "matrix_to_rotation_6d",
]
|
|
36
|
+
|
|
37
|
+
def quaternion_to_matrix(
    backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
    quaternions: BArrayType,
) -> BArrayType:
    """
    Convert rotations given as quaternions to rotation matrices.

    The scale factor ``2 / |q|^2`` makes the result independent of the
    quaternion's norm, so non-unit (but non-zero) quaternions are accepted.

    Args:
        backend: The backend to use for the computation.
        quaternions: quaternions with real part first,
            as tensor of shape (..., 4).

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    r, i, j, k = quaternions[..., 0], quaternions[..., 1], quaternions[..., 2], quaternions[..., 3]
    # Namespace-level ``sum`` instead of the ``.sum()`` array method: the array
    # API standard does not define reduction methods on array objects.
    two_s = 2.0 / backend.sum(quaternions * quaternions, axis=-1)

    o = backend.stack(
        (
            1 - two_s * (j * j + k * k),
            two_s * (i * j - k * r),
            two_s * (i * k + j * r),
            two_s * (i * j + k * r),
            1 - two_s * (i * i + k * k),
            two_s * (j * k - i * r),
            two_s * (i * k - j * r),
            two_s * (j * k + i * r),
            1 - two_s * (i * i + j * j),
        ),
        axis=-1,
    )
    # Nine entries in row-major order -> (..., 3, 3).
    return backend.reshape(o, quaternions.shape[:-1] + (3, 3))
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def _copysign(
    backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
    a: BArrayType, b: BArrayType
) -> BArrayType:
    """
    Give each element of ``a`` the sign of the matching element of ``b``.

    Unlike IEEE ``copysign``, negative zero and NaN are not treated specially:
    an element of ``a`` is negated exactly when one of ``a``/``b`` is strictly
    negative and the other is not.

    Args:
        backend: The backend to use for the computation.
        a: source tensor providing the magnitudes.
        b: tensor providing the signs, of the same shape as a.

    Returns:
        Tensor of the same shape as a with the signs of b.
    """
    a_negative = a < 0
    b_negative = b < 0
    return backend.where(a_negative != b_negative, -a, a)
|
|
94
|
+
|
|
95
|
+
def _sqrt_positive_part(
    backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
    x: BArrayType,
) -> BArrayType:
    """
    Return sqrt(max(0, x)) elementwise, with a zero subgradient where x <= 0.

    Args:
        backend: The backend to use for the computation.
        x: input array.

    Returns:
        Array of the same shape as x: sqrt(x) where x > 0, and 0 elsewhere.
    """
    positive_mask = x > 0
    # Feed sqrt a harmless value (1) at non-positive entries so it never sees a
    # zero or negative argument. This avoids NaN values, and the NaN/inf
    # gradients that would otherwise leak through ``where`` under autodiff.
    safe_x = backend.where(positive_mask, x, backend.ones_like(x))
    return backend.where(positive_mask, backend.sqrt(safe_x), backend.zeros_like(x))
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def matrix_to_quaternion(backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType], matrix: BArrayType) -> BArrayType:
    """
    Convert rotations given as rotation matrices to quaternions.

    Args:
        backend: The backend to use for the computation.
        matrix: Rotation matrices as tensor of shape (..., 3, 3).

    Returns:
        quaternions with real part first, as tensor of shape (..., 4).

    Raises:
        ValueError: if the last two dimensions of ``matrix`` are not (3, 3).
    """
    if matrix.shape[-1] != 3 or matrix.shape[-2] != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")

    batch_dim = matrix.shape[:-2]
    m00, m01, m02 = matrix[..., 0, 0], matrix[..., 0, 1], matrix[..., 0, 2]
    m10, m11, m12 = matrix[..., 1, 0], matrix[..., 1, 1], matrix[..., 1, 2]
    m20, m21, m22 = matrix[..., 2, 0], matrix[..., 2, 1], matrix[..., 2, 2]

    # For a proper rotation matrix these four terms equal 4r^2, 4i^2, 4j^2, 4k^2,
    # so q_abs holds 2|r|, 2|i|, 2|j|, 2|k|.
    diagonal_terms = backend.stack(
        [
            1.0 + m00 + m11 + m22,
            1.0 + m00 - m11 - m22,
            1.0 - m00 + m11 - m22,
            1.0 - m00 - m11 + m22,
        ],
        axis=-1,
    )
    q_abs = _sqrt_positive_part(backend, diagonal_terms)

    # Row c is the desired quaternion multiplied by its c-th component (r, i, j, k).
    row_r = backend.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], axis=-1)
    row_i = backend.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], axis=-1)
    row_j = backend.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], axis=-1)
    row_k = backend.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], axis=-1)
    quat_by_rijk = backend.stack([row_r, row_i, row_j, row_k], axis=-2)

    # The 0.1 floor only prevents division by ~0; a candidate with a small
    # q_abs is never the one selected below.
    # Due to https://github.com/data-apis/array-api-compat/issues/271, cannot use `backend.maximum` here
    denominators = 2.0 * backend.clip(q_abs[..., None], min=0.1)
    quat_candidates = quat_by_rijk / denominators

    # All candidates agree up to sign in exact arithmetic; keep the row with the
    # largest denominator, which is the best conditioned.
    best_row = backend.argmax(q_abs, axis=-1, keepdims=True)
    gather_shape = list(batch_dim) + [1, 4]
    gather_indices = backend.broadcast_to(best_row[..., None], gather_shape)
    picked = backend.take_along_axis(quat_candidates, gather_indices, axis=-2)
    return standardize_quaternion(backend, picked[..., 0, :])
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
def _axis_angle_rotation(
    backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
    axis: str, angle: BArrayType
) -> BArrayType:
    """
    Build rotation matrices about a single coordinate axis, one per angle.

    Args:
        backend: The backend to use for the computation.
        axis: Axis label "X", "Y" or "Z".
        angle: any shape tensor of Euler angles in radians.

    Returns:
        Rotation matrices as tensor of shape angle.shape + (3, 3).

    Raises:
        ValueError: if ``axis`` is not one of "X", "Y", "Z".
    """
    cos_a = backend.cos(angle)
    sin_a = backend.sin(angle)
    ones = backend.ones_like(angle)
    zeros = backend.zeros_like(angle)

    if axis == "X":
        rows = ((ones, zeros, zeros), (zeros, cos_a, -sin_a), (zeros, sin_a, cos_a))
    elif axis == "Y":
        rows = ((cos_a, zeros, sin_a), (zeros, ones, zeros), (-sin_a, zeros, cos_a))
    elif axis == "Z":
        rows = ((cos_a, -sin_a, zeros), (sin_a, cos_a, zeros), (zeros, zeros, ones))
    else:
        raise ValueError("letter must be either X, Y or Z.")

    # Flatten the 3x3 layout row-major, stack on a new last axis, then fold it.
    entries = tuple(entry for row in rows for entry in row)
    return backend.reshape(backend.stack(entries, axis=-1), angle.shape + (3, 3))
|
|
203
|
+
|
|
204
|
+
def euler_angles_to_matrix(
    backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
    euler_angles: BArrayType, convention: str
) -> BArrayType:
    """
    Turn Euler angles (radians) into rotation matrices.

    Args:
        backend: The backend to use for the computation.
        euler_angles: Euler angles in radians as tensor of shape (..., 3).
        convention: Three uppercase letters drawn from "X", "Y", "Z"; the
            middle letter must differ from both neighbours.

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3), the product
        R(convention[0]) @ R(convention[1]) @ R(convention[2]).

    Raises:
        ValueError: on malformed angles or convention strings.
    """
    if euler_angles.ndim == 0 or euler_angles.shape[-1] != 3:
        raise ValueError("Invalid input euler angles.")
    if len(convention) != 3:
        raise ValueError("Convention must have 3 letters.")
    if convention[1] in (convention[0], convention[2]):
        raise ValueError(f"Invalid convention {convention}.")
    invalid = [letter for letter in convention if letter not in ("X", "Y", "Z")]
    if invalid:
        raise ValueError(f"Invalid letter {invalid[0]} in convention string.")

    # Left-to-right product of the three elementary rotations.
    product = None
    for position, letter in enumerate(convention):
        rotation = _axis_angle_rotation(backend, letter, euler_angles[..., position])
        product = rotation if product is None else backend.matmul(product, rotation)
    return product
|
|
235
|
+
|
|
236
|
+
def _angle_from_tan(
    backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
    axis: str, other_axis: str,
    data: BArrayType,
    horizontal: bool, tait_bryan: bool
) -> BArrayType:
    """
    Recover the first or third Euler angle from the two matrix entries that are
    a positive constant times its sine and cosine.

    Args:
        backend: The backend to use for the computation.
        axis: Axis label "X", "Y" or "Z" for the angle being recovered.
        other_axis: Axis label "X", "Y" or "Z" of the middle axis of the
            convention.
        data: The row or column (last dimension of size 3) of the rotation
            matrix holding the relevant entries.
        horizontal: True when recovering the third angle, whose entries lie in
            one row of the matrix; False when they lie in one column.
        tait_bryan: Whether the first and third axes of the convention differ.

    Returns:
        Euler angles in radians, as a tensor of shape (...).
    """
    first, second = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis]
    if horizontal:
        first, second = second, first
    is_even = (axis + other_axis) in ("XY", "YZ", "ZX")

    if horizontal == is_even:
        numerator, denominator = data[..., first], data[..., second]
    elif tait_bryan:
        numerator, denominator = -data[..., second], data[..., first]
    else:
        numerator, denominator = data[..., second], -data[..., first]
    return backend.atan2(numerator, denominator)
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
def _index_from_letter(letter: str) -> int:
|
|
274
|
+
if letter == "X":
|
|
275
|
+
return 0
|
|
276
|
+
if letter == "Y":
|
|
277
|
+
return 1
|
|
278
|
+
if letter == "Z":
|
|
279
|
+
return 2
|
|
280
|
+
raise ValueError("letter must be either X, Y or Z.")
|
|
281
|
+
|
|
282
|
+
|
|
283
|
+
def matrix_to_euler_angles(
    backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
    matrix: BArrayType, convention: str
) -> BArrayType:
    """
    Convert rotations given as rotation matrices to Euler angles in radians.

    Args:
        backend: The backend to use for the computation.
        matrix: Rotation matrices as tensor of shape (..., 3, 3).
        convention: Convention string of three uppercase letters from
            {"X", "Y", "Z"}; the middle letter must differ from both others.

    Returns:
        Euler angles in radians as tensor of shape (..., 3).

    Raises:
        ValueError: if the convention is malformed or ``matrix`` is not (..., 3, 3).
    """
    if len(convention) != 3:
        raise ValueError("Convention must have 3 letters.")
    if convention[1] in (convention[0], convention[2]):
        raise ValueError(f"Invalid convention {convention}.")
    for letter in convention:
        if letter not in ("X", "Y", "Z"):
            raise ValueError(f"Invalid letter {letter} in convention string.")
    # Use ``.shape`` rather than the torch-only ``.size(dim)`` method: on NumPy
    # and JAX arrays ``size`` is an integer attribute and cannot be called.
    if matrix.shape[-1] != 3 or matrix.shape[-2] != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
    i0 = _index_from_letter(convention[0])
    i2 = _index_from_letter(convention[2])
    tait_bryan = i0 != i2
    if tait_bryan:
        # Clipping keeps asin's argument in [-1, 1] despite round-off.
        central_angle = backend.asin(
            backend.clip(matrix[..., i0, i2], -1.0, 1.0)
            * (-1.0 if i0 - i2 in [-1, 2] else 1.0)
        )
    else:
        # Clipping keeps acos's argument in [-1, 1] despite round-off.
        central_angle = backend.acos(backend.clip(matrix[..., i0, i0], -1.0, 1.0))

    o = (
        _angle_from_tan(
            backend,
            convention[0], convention[1], matrix[..., i2], False, tait_bryan
        ),
        central_angle,
        _angle_from_tan(
            backend,
            convention[2], convention[1], matrix[..., i0, :], True, tait_bryan
        ),
    )
    # ``axis`` is keyword-only for ``stack`` in the array API standard.
    return backend.stack(o, axis=-1)
|
|
329
|
+
|
|
330
|
+
def random_quaternions(
    backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType], rng: BRNGType,
    n: int, dtype: Optional[BDtypeType] = None, device: Optional[BDeviceType] = None,
) -> BArrayType:
    """
    Draw random rotation quaternions: unit quaternions (versors) whose real
    part is non-negative.

    Args:
        backend: The backend to use for the computation.
        rng: A random number generator of the appropriate type for the backend.
        n: Number of quaternions in the returned batch.
        dtype: Type to return.
        device: Desired device of the returned tensor; None uses the backend's
            default device.

    Returns:
        Quaternions as tensor of shape (N, 4).
    """
    samples = backend.random.random_normal((n, 4), rng=rng, dtype=dtype, device=device)
    squared_norms = backend.sum(samples * samples, axis=1)
    # Dividing by the norm signed like the real part both normalizes each
    # quaternion and makes its real part non-negative.
    signed_norms = _copysign(backend, backend.sqrt(squared_norms), samples[:, 0])
    return samples / signed_norms[:, None]
|
|
353
|
+
|
|
354
|
+
|
|
355
|
+
def random_rotations(
    backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType], rng: BRNGType,
    n: int, dtype: Optional[BDtypeType] = None, device: Optional[BDeviceType] = None
) -> BArrayType:
    """
    Draw a batch of random 3x3 rotation matrices.

    Args:
        backend: The backend to use for the computation.
        rng: A random number generator of the appropriate type for the backend.
        n: Number of rotation matrices in the returned batch.
        dtype: Type to return.
        device: Device of the returned tensor; None uses the backend's default
            device.

    Returns:
        Rotation matrices as tensor of shape (n, 3, 3).
    """
    return quaternion_to_matrix(
        backend, random_quaternions(backend, rng, n, dtype=dtype, device=device)
    )
|
|
375
|
+
|
|
376
|
+
|
|
377
|
+
def random_rotation(
    backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType], rng: BRNGType,
    dtype: Optional[BDtypeType] = None, device: Optional[BDeviceType] = None
) -> BArrayType:
    """
    Draw one random 3x3 rotation matrix.

    Args:
        backend: The backend to use for the computation.
        rng: A random number generator of the appropriate type for the backend.
        dtype: Type to return.
        device: Device of the returned tensor; None uses the backend's default
            device.

    Returns:
        Rotation matrix as tensor of shape (3, 3).
    """
    batch_of_one = random_rotations(backend, rng, 1, dtype=dtype, device=device)
    return batch_of_one[0]
|
|
395
|
+
|
|
396
|
+
|
|
397
|
+
def standardize_quaternion(
    backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
    quaternions: BArrayType) -> BArrayType:
    """
    Flip the sign of every quaternion whose real part is negative, so the
    real part of the result is non-negative (q and -q encode the same rotation).

    Args:
        backend: The backend to use for the computation.
        quaternions: Quaternions with real part first,
            as tensor of shape (..., 4).

    Returns:
        Standardized quaternions as tensor of shape (..., 4).
    """
    # Keep a trailing length-1 axis so the mask broadcasts over all 4 components.
    real_is_negative = quaternions[..., 0:1] < 0
    return backend.where(real_is_negative, -quaternions, quaternions)
|
|
413
|
+
|
|
414
|
+
|
|
415
|
+
def quaternion_raw_multiply(
    backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
    a: BArrayType, b: BArrayType
) -> BArrayType:
    """
    Hamilton product of two quaternions, without any sign standardization.
    Standard broadcasting rules apply between ``a`` and ``b``.

    Args:
        backend: The backend to use for the computation.
        a: Quaternions as tensor of shape (..., 4), real part first.
        b: Quaternions as tensor of shape (..., 4), real part first.

    Returns:
        The product a * b, a tensor of quaternions of shape (..., 4).
    """
    w1, x1, y1, z1 = (a[..., c] for c in range(4))
    w2, x2, y2, z2 = (b[..., c] for c in range(4))
    return backend.stack(
        (
            w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
            w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
            w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,
            w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2,
        ),
        axis=-1,
    )
|
|
438
|
+
|
|
439
|
+
|
|
440
|
+
def quaternion_multiply(
    backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
    a: BArrayType, b: BArrayType
) -> BArrayType:
    """
    Compose two rotations given as quaternions, returning the product in
    standard form (non-negative real part).
    Standard broadcasting rules apply between ``a`` and ``b``.

    Args:
        backend: The backend to use for the computation.
        a: Quaternions as tensor of shape (..., 4), real part first.
        b: Quaternions as tensor of shape (..., 4), real part first.

    Returns:
        The standardized product of a and b, of shape (..., 4).
    """
    return standardize_quaternion(backend, quaternion_raw_multiply(backend, a, b))
|
|
459
|
+
|
|
460
|
+
|
|
461
|
+
def quaternion_invert(
    backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
    quaternion: BArrayType
) -> BArrayType:
    """
    Given a quaternion representing rotation, get the quaternion representing
    its inverse.

    The conjugate is returned (real part kept, imaginary parts negated), which
    equals the inverse only for unit quaternions.

    Args:
        backend: The backend to use for the computation.
        quaternion: Quaternions as tensor of shape (..., 4), with real part
            first, which must be versors (unit quaternions).

    Returns:
        The inverse, a tensor of quaternions of shape (..., 4), with the same
        dtype as ``quaternion``.
    """
    # Build the sign pattern in the input's dtype. An integer array would
    # trigger mixed-kind type promotion, which the array API leaves to the
    # implementation (NumPy turns float32 * int64 into float64).
    # The (4,) pattern broadcasts over any leading batch dimensions.
    scaling = backend.asarray(
        [1, -1, -1, -1], dtype=quaternion.dtype, device=backend.device(quaternion)
    )
    return quaternion * scaling
|
|
483
|
+
|
|
484
|
+
|
|
485
|
+
def quaternion_apply(
    backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
    quaternion: BArrayType, point: BArrayType
) -> BArrayType:
    """
    Apply the rotation given by a quaternion to a 3D point.
    Standard broadcasting rules apply between ``quaternion`` and ``point``.

    Args:
        backend: The backend to use for the computation.
        quaternion: Tensor of quaternions, real part first, of shape (..., 4).
        point: Tensor of 3D points of shape (..., 3).

    Returns:
        Tensor of rotated points of shape (..., 3).

    Raises:
        ValueError: if the last dimension of ``point`` is not 3.
    """
    if point.shape[-1] != 3:
        raise ValueError(f"Points are not in 3D, {point.shape}.")
    # Embed each point as a pure quaternion (0, x, y, z).
    real_parts = backend.zeros(point.shape[:-1] + (1,), dtype=point.dtype, device=backend.device(point))
    point_as_quaternion = backend.concat((real_parts, point), axis=-1)
    # Rotate via q * p * q^-1. Both multiplications take the backend as their
    # first argument.
    out = quaternion_raw_multiply(
        backend,
        quaternion_raw_multiply(backend, quaternion, point_as_quaternion),
        quaternion_invert(backend, quaternion),
    )
    return out[..., 1:]
|
|
510
|
+
|
|
511
|
+
|
|
512
|
+
def axis_angle_to_matrix(
    backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
    axis_angle: BArrayType, fast: bool = False) -> BArrayType:
    """
    Convert rotations given as axis/angle to rotation matrices.

    Args:
        backend: The backend to use for the computation.
        axis_angle: Rotations given as a vector in axis angle form,
            as a tensor of shape (..., 3), where the magnitude is
            the angle turned anticlockwise in radians around the
            vector's direction.
        fast: Whether to use the new faster implementation (based on the
            Rodrigues formula) instead of the original implementation (which
            first converted to a quaternion and then back to a rotation matrix).

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    if not fast:
        return quaternion_to_matrix(backend, axis_angle_to_quaternion(backend, axis_angle))

    shape = axis_angle.shape
    device, dtype = backend.device(axis_angle), axis_angle.dtype

    # Rotation angles with two trailing singleton axes, shape (..., 1, 1), so
    # they broadcast against the (..., 3, 3) matrices below.
    angles = backend.linalg.vector_norm(axis_angle, ord=2, axis=-1, keepdims=True)[..., backend.newaxis]

    # Skew-symmetric cross-product matrix K of the (unnormalized) rotation vector.
    rx, ry, rz = axis_angle[..., 0], axis_angle[..., 1], axis_angle[..., 2]
    zeros = backend.zeros(shape[:-1], dtype=dtype, device=device)
    cross_product_matrix = backend.reshape(
        backend.stack([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], axis=-1),
        shape + (3,),
    )
    cross_product_matrix_sqrd = cross_product_matrix @ cross_product_matrix

    identity = backend.eye(3, dtype=dtype, device=device)
    angles_sqrd = angles * angles
    # For a zero rotation K^2 is zero, so replacing 0 by 1 only avoids 0/0.
    angles_sqrd = backend.where(angles_sqrd == 0, 1, angles_sqrd)
    # Rodrigues' formula for r = theta * n:
    #   R = I + sin(theta)/theta * K + (1 - cos(theta))/theta^2 * K^2
    # The division by pi assumes backend.sinc is the normalized sinc
    # sin(pi x)/(pi x). The (3, 3) identity broadcasts against the batch.
    return (
        identity
        + backend.sinc(angles / backend.pi) * cross_product_matrix
        + ((1 - backend.cos(angles)) / angles_sqrd) * cross_product_matrix_sqrd
    )
|
|
554
|
+
|
|
555
|
+
|
|
556
|
+
def matrix_to_axis_angle(
    backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
    matrix: BArrayType, fast: bool = False
) -> BArrayType:
    """
    Convert rotations given as rotation matrices to axis/angle.

    Args:
        backend: The backend to use for the computation.
        matrix: Rotation matrices as tensor of shape (..., 3, 3).
        fast: Whether to read the axis and angle directly from the matrix
            (trace and antisymmetric part) instead of converting through a
            quaternion with matrix_to_quaternion and quaternion_to_axis_angle.

    Returns:
        Rotations given as a vector in axis angle form, as a tensor
            of shape (..., 3), where the magnitude is the angle
            turned anticlockwise in radians around the vector's
            direction.

    Raises:
        ValueError: if ``fast`` is set and ``matrix`` is not (..., 3, 3).
    """
    if not fast:
        return quaternion_to_axis_angle(backend, matrix_to_quaternion(backend, matrix))

    if matrix.shape[-1] != 3 or matrix.shape[-2] != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")

    # R - R^T = 2 sin(theta) [n]_x, so omega = 2 sin(theta) n.
    omegas = backend.stack(
        [
            matrix[..., 2, 1] - matrix[..., 1, 2],
            matrix[..., 0, 2] - matrix[..., 2, 0],
            matrix[..., 1, 0] - matrix[..., 0, 1],
        ],
        axis=-1,
    )
    norms = backend.linalg.vector_norm(omegas, ord=2, axis=-1, keepdims=True)
    traces = backend.sum(backend.linalg.diagonal(matrix), axis=-1, keepdims=True)
    # |omega| = 2 sin(theta) and trace - 1 = 2 cos(theta).
    angles = backend.atan2(norms, traces - 1)

    # isclose is given array operands (not Python scalars); torch.isclose,
    # for one, rejects a non-tensor second argument.
    zeros = backend.zeros(3, dtype=matrix.dtype, device=backend.device(matrix))
    omegas = backend.where(backend.isclose(angles, backend.zeros_like(angles)), zeros, omegas)

    near_pi = backend.isclose(angles, backend.full_like(angles, backend.pi))[..., 0]
    not_near_pi = backend.logical_not(near_pi)

    # Away from pi: theta * n = 0.5 * omega / sinc(theta / pi).
    axis_angles = backend.empty_like(omegas)
    axis_angles = backend.at(axis_angles, not_near_pi).set(
        0.5 * omegas[not_near_pi] / backend.sinc(angles[not_near_pi] / backend.pi)
    )

    # this derives from: nnT = (R + 1) / 2
    n = 0.5 * (
        matrix[near_pi][..., 0, :]
        + backend.eye(1, 3, dtype=matrix.dtype, device=backend.device(matrix))
    )
    # Each near-pi row is normalized on its own. pytorch3d's torch.norm(n)
    # instead takes one norm over all selected rows.
    # NOTE(review): recovering n from the first row of nnT is sign-ambiguous
    # and degenerates when n_x == 0 (see pytorch3d issue #2002).
    # Functional update, not in-place __setitem__, so immutable arrays (JAX) work.
    axis_angles = backend.at(axis_angles, near_pi).set(
        angles[near_pi] * n / backend.linalg.vector_norm(n, axis=-1, keepdims=True, ord=2)
    )

    return axis_angles
|
|
614
|
+
|
|
615
|
+
|
|
616
|
+
def axis_angle_to_quaternion(
    backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
    axis_angle: BArrayType
) -> BArrayType:
    """
    Turn axis/angle rotation vectors into quaternions.

    Args:
        backend: The backend to use for the computation.
        axis_angle: Rotation vectors of shape (..., 3) whose direction is the
            rotation axis and whose magnitude is the anticlockwise angle in
            radians.

    Returns:
        quaternions with real part first, as tensor of shape (..., 4).
    """
    angles = backend.linalg.vector_norm(axis_angle, ord=2, axis=-1, keepdims=True)
    half_angles = angles * 0.5
    # sin(theta / 2) / theta written as 0.5 * sinc(theta / (2 pi)) so a zero
    # rotation vector needs no special case.
    scale = 0.5 * backend.sinc(half_angles / backend.pi)
    real_part = backend.cos(half_angles)
    imaginary_part = axis_angle * scale
    return backend.concat([real_part, imaginary_part], axis=-1)
|
|
638
|
+
|
|
639
|
+
|
|
640
|
+
def quaternion_to_axis_angle(
    backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
    quaternions: BArrayType
) -> BArrayType:
    """
    Turn quaternions into axis/angle rotation vectors.

    Args:
        backend: The backend to use for the computation.
        quaternions: quaternions with real part first,
            as tensor of shape (..., 4).

    Returns:
        Rotation vectors of shape (..., 3) whose direction is the rotation
        axis and whose magnitude is the anticlockwise angle in radians.
    """
    imaginary = quaternions[..., 1:]
    norms = backend.linalg.vector_norm(imaginary, ord=2, axis=-1, keepdims=True)
    half_angles = backend.atan2(norms, quaternions[..., :1])
    # sin(half) / (2 * half) == sin(theta / 2) / theta.
    scale = 0.5 * backend.sinc(half_angles / backend.pi)
    # half_angles lies in [0, pi]. The scale is non-zero except at exactly pi,
    # i.e. a zero imaginary part with a negative real part, where this is 0/0.
    return imaginary / scale
|
|
664
|
+
|
|
665
|
+
|
|
666
|
+
def rotation_6d_to_matrix(
    backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
    d6: BArrayType
) -> BArrayType:
    """
    Converts 6D rotation representation by Zhou et al. [1] to rotation matrix
    using Gram--Schmidt orthogonalization per Section B of [1].

    Args:
        backend: The backend to use for the computation.
        d6: 6D rotation representation, of size (*, 6)

    Returns:
        batch of rotation matrices of size (*, 3, 3)

    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
    On the Continuity of Rotation Representations in Neural Networks.
    IEEE Conference on Computer Vision and Pattern Recognition, 2019.
    Retrieved from http://arxiv.org/abs/1812.07035
    """

    def _normalize(v: BArrayType) -> BArrayType:
        # Unit vectors along the last axis. The norm is clamped at 1e-12, as in
        # torch.nn.functional.normalize used by the pytorch3d original, so
        # degenerate (near-)zero inputs stay finite instead of becoming NaN.
        norm = backend.linalg.vector_norm(v, ord=2, axis=-1, keepdims=True)
        return v / backend.clip(norm, min=1e-12)

    a1, a2 = d6[..., :3], d6[..., 3:]
    b1 = _normalize(a1)
    # Remove the component of a2 along b1, then normalize (Gram-Schmidt).
    b2 = _normalize(a2 - backend.sum(b1 * a2, axis=-1, keepdims=True) * b1)
    b3 = backend.linalg.cross(b1, b2, axis=-1)
    # b1, b2, b3 become the rows of the rotation matrix.
    return backend.stack((b1, b2, b3), axis=-2)
|
|
692
|
+
|
|
693
|
+
def matrix_to_rotation_6d(
    backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
    matrix: BArrayType
) -> BArrayType:
    """
    Encode rotation matrices in the 6D representation of Zhou et al. [1] by
    keeping only the first two rows. The 6D representation is not unique.

    Args:
        backend: The backend to use for the computation.
        matrix: batch of rotation matrices of size (*, 3, 3)

    Returns:
        6D rotation representation, of size (*, 6): row 0 followed by row 1.

    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
    On the Continuity of Rotation Representations in Neural Networks.
    IEEE Conference on Computer Vision and Pattern Recognition, 2019.
    Retrieved from http://arxiv.org/abs/1812.07035
    """
    leading_shape = matrix.shape[:-2]
    first_two_rows = matrix[..., :2, :]
    return backend.reshape(first_two_rows, leading_shape + (6,))
|