xbarray 0.0.1a7__py3-none-any.whl → 0.0.1a10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xbarray might be problematic; see the registry's advisory page for details.

Files changed (42)
  1. xbarray/__init__.py +1 -1
  2. xbarray/backends/_cls_base.py +9 -0
  3. xbarray/{_src/implementations → backends/_implementations}/_common/implementations.py +32 -7
  4. xbarray/{_src/implementations → backends/_implementations}/jax/_extra.py +4 -1
  5. xbarray/{_src/implementations → backends/_implementations}/jax/random.py +1 -1
  6. xbarray/{_src/implementations → backends/_implementations}/numpy/_extra.py +6 -2
  7. xbarray/{_src/implementations → backends/_implementations}/pytorch/_extra.py +5 -2
  8. xbarray/{base → backends}/base.py +11 -1
  9. xbarray/backends/jax.py +19 -0
  10. xbarray/backends/numpy.py +19 -0
  11. xbarray/backends/pytorch.py +22 -0
  12. xbarray/jax.py +4 -19
  13. xbarray/numpy.py +4 -19
  14. xbarray/pytorch.py +4 -19
  15. xbarray/transformations/pointcloud/__init__.py +1 -0
  16. xbarray/transformations/pointcloud/base.py +185 -0
  17. xbarray/transformations/pointcloud/jax.py +15 -0
  18. xbarray/transformations/pointcloud/numpy.py +15 -0
  19. xbarray/transformations/pointcloud/pytorch.py +15 -0
  20. xbarray/transformations/rotation_conversions/__init__.py +1 -0
  21. xbarray/transformations/rotation_conversions/base.py +712 -0
  22. xbarray/transformations/rotation_conversions/jax.py +41 -0
  23. xbarray/transformations/rotation_conversions/numpy.py +41 -0
  24. xbarray/transformations/rotation_conversions/pytorch.py +41 -0
  25. {xbarray-0.0.1a7.dist-info → xbarray-0.0.1a10.dist-info}/METADATA +6 -1
  26. xbarray-0.0.1a10.dist-info/RECORD +51 -0
  27. xbarray/_src/serialization/__init__.py +0 -1
  28. xbarray/_src/serialization/serialization_map.py +0 -31
  29. xbarray/base/__init__.py +0 -1
  30. xbarray/cls_impl/cls_base.py +0 -10
  31. xbarray-0.0.1a7.dist-info/RECORD +0 -41
  32. /xbarray/{_src/implementations → backends/_implementations}/jax/__init__.py +0 -0
  33. /xbarray/{_src/implementations → backends/_implementations}/jax/_typing.py +0 -0
  34. /xbarray/{_src/implementations → backends/_implementations}/numpy/__init__.py +0 -0
  35. /xbarray/{_src/implementations → backends/_implementations}/numpy/_typing.py +0 -0
  36. /xbarray/{_src/implementations → backends/_implementations}/numpy/random.py +0 -0
  37. /xbarray/{_src/implementations → backends/_implementations}/pytorch/__init__.py +0 -0
  38. /xbarray/{_src/implementations → backends/_implementations}/pytorch/_typing.py +0 -0
  39. /xbarray/{_src/implementations → backends/_implementations}/pytorch/random.py +0 -0
  40. {xbarray-0.0.1a7.dist-info → xbarray-0.0.1a10.dist-info}/WHEEL +0 -0
  41. {xbarray-0.0.1a7.dist-info → xbarray-0.0.1a10.dist-info}/licenses/LICENSE +0 -0
  42. {xbarray-0.0.1a7.dist-info → xbarray-0.0.1a10.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,712 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # ---------- Note ----------
8
+ # This file is pulled from PyTorch3D (https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py)
9
+ # with minor modifications for cross-backend compatibility.
10
+ # Please refer to the original file for the full license and copyright notice.
11
+ # Please see https://github.com/facebookresearch/pytorch3d/issues/2002 for some issues involving axis angle rotations
12
+ # --------------------------
13
+
14
+ from typing import Optional
15
+ from xbarray.backends.base import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
16
+
17
# Public API of this module, grouped by the representation being converted.
__all__ = [
    # quaternion <-> rotation matrix
    "quaternion_to_matrix",
    "matrix_to_quaternion",
    # Euler angles <-> rotation matrix
    "euler_angles_to_matrix",
    "matrix_to_euler_angles",
    # random sampling
    "random_quaternions",
    "random_rotations",
    "random_rotation",
    # quaternion algebra
    "standardize_quaternion",
    "quaternion_multiply",
    "quaternion_invert",
    "quaternion_apply",
    # axis/angle conversions
    "axis_angle_to_matrix",
    "matrix_to_axis_angle",
    "axis_angle_to_quaternion",
    "quaternion_to_axis_angle",
    # 6D (Zhou et al.) representation
    "rotation_6d_to_matrix",
    "matrix_to_rotation_6d",
]
36
+
37
def quaternion_to_matrix(
    backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
    quaternions: BArrayType,
) -> BArrayType:
    """
    Convert rotations given as quaternions to rotation matrices.

    The input does not need to be normalized: every entry is scaled by
    ``2 / |q|^2``, so any non-zero multiple of a unit quaternion produces
    the same matrix.

    Args:
        backend: The backend to use for the computation.
        quaternions: quaternions with real part first,
            as tensor of shape (..., 4).

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    r, i, j, k = quaternions[..., 0], quaternions[..., 1], quaternions[..., 2], quaternions[..., 3]
    # Reduce through the backend (as the rest of this module does) instead of
    # relying on an array ``.sum`` method.
    two_s = 2.0 / backend.sum(quaternions * quaternions, axis=-1)

    # Row-major entries of the 3x3 matrix, stacked on a new trailing axis.
    o = backend.stack(
        (
            1 - two_s * (j * j + k * k),
            two_s * (i * j - k * r),
            two_s * (i * k + j * r),
            two_s * (i * j + k * r),
            1 - two_s * (i * i + k * k),
            two_s * (j * k - i * r),
            two_s * (i * k - j * r),
            two_s * (j * k + i * r),
            1 - two_s * (i * i + j * j),
        ),
        axis=-1,
    )
    return backend.reshape(o, tuple(quaternions.shape[:-1]) + (3, 3))
72
+
73
+
74
+ def _copysign(
75
+ backend : ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
76
+ a: BArrayType, b: BArrayType
77
+ ) -> BArrayType:
78
+ """
79
+ Return a tensor where each element has the absolute value taken from the,
80
+ corresponding element of a, with sign taken from the corresponding
81
+ element of b. This is like the standard copysign floating-point operation,
82
+ but is not careful about negative 0 and NaN.
83
+
84
+ Args:
85
+ backend: The backend to use for the computation.
86
+ a: source tensor.
87
+ b: tensor whose signs will be used, of the same shape as a.
88
+
89
+ Returns:
90
+ Tensor of the same shape as a with the signs of b.
91
+ """
92
+ signs_differ = (a < 0) != (b < 0)
93
+ return backend.where(signs_differ, -a, a)
94
+
95
+ def _sqrt_positive_part(backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType], x: BArrayType) -> BArrayType:
96
+ """
97
+ Returns backend.sqrt(backend.max(0, x))
98
+ but with a zero subgradient where x is 0.
99
+ """
100
+ positive_mask = x > 0
101
+ ret = backend.where(positive_mask, backend.sqrt(x), ret)
102
+ return ret
103
+
104
+
105
def matrix_to_quaternion(backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType], matrix: BArrayType) -> BArrayType:
    """
    Convert rotations given as rotation matrices to quaternions.

    Four candidate quaternions are computed, each by dividing by a different
    component magnitude. The best-conditioned one (largest divisor) is kept,
    and the result is standardized to a non-negative real part.

    Args:
        backend: The backend to use for the computation.
        matrix: Rotation matrices as tensor of shape (..., 3, 3).

    Returns:
        quaternions with real part first, as tensor of shape (..., 4).

    Raises:
        ValueError: if the last two dimensions of ``matrix`` are not (3, 3).
    """
    if matrix.shape[-1] != 3 or matrix.shape[-2] != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")

    batch_dim = tuple(matrix.shape[:-2])
    m00, m01, m02 = matrix[..., 0, 0], matrix[..., 0, 1], matrix[..., 0, 2]
    m10, m11, m12 = matrix[..., 1, 0], matrix[..., 1, 1], matrix[..., 1, 2]
    m20, m21, m22 = matrix[..., 2, 0], matrix[..., 2, 1], matrix[..., 2, 2]

    # For a unit quaternion (r, i, j, k), these are 2|r|, 2|i|, 2|j|, 2|k|.
    q_abs = _sqrt_positive_part(
        backend,
        backend.stack(
            [
                1.0 + m00 + m11 + m22,
                1.0 + m00 - m11 - m22,
                1.0 - m00 + m11 - m22,
                1.0 - m00 - m11 + m22,
            ],
            axis=-1,
        ),
    )

    # Row n is the desired quaternion multiplied by 4 * (r, i, j, k)[n].
    quat_by_rijk = backend.stack(
        [
            backend.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], axis=-1),
            backend.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], axis=-1),
            backend.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], axis=-1),
            backend.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], axis=-1),
        ],
        axis=-2,
    )

    # Floor the divisor at 0.1; the exact level is unimportant because a
    # candidate with a small q_abs is never selected below. The floor is an
    # array (not a Python scalar) because e.g. torch.maximum rejects scalars.
    floor = backend.full_like(q_abs, 0.1)
    quat_candidates = quat_by_rijk / (2.0 * backend.maximum(q_abs, floor)[..., None])

    # Up to sign, all candidates agree barring numerical error; pick the
    # best-conditioned one (largest denominator).
    indices = backend.argmax(q_abs, axis=-1, keepdims=True)
    gather_indices = backend.broadcast_to(indices[..., None], batch_dim + (1, 4))
    out = backend.take_along_axis(quat_candidates, gather_indices, axis=-2)[..., 0, :]
    return standardize_quaternion(backend, out)
165
+
166
+
167
+ def _axis_angle_rotation(
168
+ backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
169
+ axis: str, angle: BArrayType
170
+ ) -> BArrayType:
171
+ """
172
+ Return the rotation matrices for one of the rotations about an axis
173
+ of which Euler angles describe, for each value of the angle given.
174
+
175
+ Args:
176
+ backend: The backend to use for the computation.
177
+ axis: Axis label "X" or "Y or "Z".
178
+ angle: any shape tensor of Euler angles in radians
179
+
180
+ Returns:
181
+ Rotation matrices as tensor of shape (..., 3, 3).
182
+ """
183
+
184
+ cos = backend.cos(angle)
185
+ sin = backend.sin(angle)
186
+ one = backend.ones_like(angle)
187
+ zero = backend.zeros_like(angle)
188
+
189
+ if axis == "X":
190
+ R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
191
+ elif axis == "Y":
192
+ R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
193
+ elif axis == "Z":
194
+ R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
195
+ else:
196
+ raise ValueError("letter must be either X, Y or Z.")
197
+
198
+ return backend.reshape(
199
+ backend.stack(R_flat, axis=-1),
200
+ angle.shape + (3, 3)
201
+ )
202
+
203
def euler_angles_to_matrix(
    backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
    euler_angles: BArrayType, convention: str
) -> BArrayType:
    """
    Convert rotations given as Euler angles in radians to rotation matrices.

    Args:
        backend: The backend to use for the computation.
        euler_angles: Euler angles in radians as tensor of shape (..., 3).
        convention: Convention string of three uppercase letters from
            {"X", "Y", and "Z"}.

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).

    Raises:
        ValueError: on a malformed ``euler_angles`` shape or ``convention``.
    """
    # ``len(shape)`` works for every backend; ``.dim()`` exists only on torch tensors.
    if len(euler_angles.shape) == 0 or euler_angles.shape[-1] != 3:
        raise ValueError("Invalid input euler angles.")
    if len(convention) != 3:
        raise ValueError("Convention must have 3 letters.")
    if convention[1] in (convention[0], convention[2]):
        raise ValueError(f"Invalid convention {convention}.")
    for letter in convention:
        if letter not in ("X", "Y", "Z"):
            raise ValueError(f"Invalid letter {letter} in convention string.")

    matrices = [
        _axis_angle_rotation(backend, letter, euler_angles[..., idx])
        for idx, letter in enumerate(convention)
    ]
    # Compose the three elementary rotations left to right.
    return backend.matmul(backend.matmul(matrices[0], matrices[1]), matrices[2])
234
+
235
+ def _angle_from_tan(
236
+ backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
237
+ axis: str, other_axis: str,
238
+ data : BArrayType,
239
+ horizontal: bool, tait_bryan: bool
240
+ ) -> BArrayType:
241
+ """
242
+ Extract the first or third Euler angle from the two members of
243
+ the matrix which are positive constant times its sine and cosine.
244
+
245
+ Args:
246
+ backend: The backend to use for the computation.
247
+ axis: Axis label "X" or "Y or "Z" for the angle we are finding.
248
+ other_axis: Axis label "X" or "Y or "Z" for the middle axis in the
249
+ convention.
250
+ data: Rotation matrices as tensor of shape (..., 3, 3).
251
+ horizontal: Whether we are looking for the angle for the third axis,
252
+ which means the relevant entries are in the same row of the
253
+ rotation matrix. If not, they are in the same column.
254
+ tait_bryan: Whether the first and third axes in the convention differ.
255
+
256
+ Returns:
257
+ Euler Angles in radians for each matrix in data as a tensor
258
+ of shape (...).
259
+ """
260
+
261
+ i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis]
262
+ if horizontal:
263
+ i2, i1 = i1, i2
264
+ even = (axis + other_axis) in ["XY", "YZ", "ZX"]
265
+ if horizontal == even:
266
+ return backend.atan2(data[..., i1], data[..., i2])
267
+ if tait_bryan:
268
+ return backend.atan2(-data[..., i2], data[..., i1])
269
+ return backend.atan2(data[..., i2], -data[..., i1])
270
+
271
+
272
+ def _index_from_letter(letter: str) -> int:
273
+ if letter == "X":
274
+ return 0
275
+ if letter == "Y":
276
+ return 1
277
+ if letter == "Z":
278
+ return 2
279
+ raise ValueError("letter must be either X, Y or Z.")
280
+
281
+
282
def matrix_to_euler_angles(
    backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
    matrix: BArrayType, convention: str
) -> BArrayType:
    """
    Convert rotations given as rotation matrices to Euler angles in radians.

    Args:
        backend: The backend to use for the computation.
        matrix: Rotation matrices as tensor of shape (..., 3, 3).
        convention: Convention string of three uppercase letters.

    Returns:
        Euler angles in radians as tensor of shape (..., 3).

    Raises:
        ValueError: on a malformed ``convention`` or matrix shape.
    """
    if len(convention) != 3:
        raise ValueError("Convention must have 3 letters.")
    if convention[1] in (convention[0], convention[2]):
        raise ValueError(f"Invalid convention {convention}.")
    for letter in convention:
        if letter not in ("X", "Y", "Z"):
            raise ValueError(f"Invalid letter {letter} in convention string.")
    # Use ``.shape`` (portable) rather than torch's ``.size(dim)``; on NumPy
    # arrays ``.size`` is an int and calling it raises TypeError.
    if len(matrix.shape) < 2 or matrix.shape[-1] != 3 or matrix.shape[-2] != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
    i0 = _index_from_letter(convention[0])
    i2 = _index_from_letter(convention[2])
    tait_bryan = i0 != i2
    if tait_bryan:
        # Clip guards asin against entries slightly outside [-1, 1] from round-off.
        central_angle = backend.asin(
            backend.clip(matrix[..., i0, i2], -1.0, 1.0)
            * (-1.0 if i0 - i2 in [-1, 2] else 1.0)
        )
    else:
        central_angle = backend.acos(backend.clip(matrix[..., i0, i0], -1.0, 1.0))

    o = (
        # First angle: read from column i2 of each matrix.
        _angle_from_tan(
            backend, convention[0], convention[1], matrix[..., i2], False, tait_bryan
        ),
        central_angle,
        # Third angle: read from row i0 of each matrix.
        _angle_from_tan(
            backend, convention[2], convention[1], matrix[..., i0, :], True, tait_bryan
        ),
    )
    return backend.stack(o, axis=-1)
328
+
329
def random_quaternions(
    backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType], rng: BRNGType,
    n: int, dtype: Optional[BDtypeType] = None, device: Optional[BDeviceType] = None,
) -> BArrayType:
    """
    Sample random rotation quaternions: unit quaternions (versors) whose
    real part is non-negative.

    Each sample is drawn from a 4D standard normal and divided by its norm,
    signed like its real component.

    Args:
        backend: The backend to use for the computation.
        rng: A random number generator of the appropriate type for the backend.
        n: Number of quaternions in a batch to return.
        dtype: Type to return.
        device: Desired device of returned tensor. Default:
            uses the current device for the default tensor type.

    Returns:
        Quaternions as tensor of shape (N, 4).
    """
    samples = backend.random.random_normal((n, 4), rng=rng, dtype=dtype, device=device)
    squared_norms = backend.sum(samples * samples, axis=1)
    signed_norms = _copysign(backend, backend.sqrt(squared_norms), samples[:, 0])
    return samples / signed_norms[:, None]
352
+
353
+
354
def random_rotations(
    backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType], rng: BRNGType,
    n: int, dtype: Optional[BDtypeType] = None, device: Optional[BDeviceType] = None
) -> BArrayType:
    """
    Sample a batch of random 3x3 rotation matrices (via random quaternions).

    Args:
        backend: The backend to use for the computation.
        rng: A random number generator of the appropriate type for the backend.
        n: Number of rotation matrices in a batch to return.
        dtype: Type to return.
        device: Device of returned tensor. Default: if None,
            uses the current device for the default tensor type.

    Returns:
        Rotation matrices as tensor of shape (n, 3, 3).
    """
    return quaternion_to_matrix(
        backend,
        random_quaternions(backend, rng, n, dtype=dtype, device=device),
    )
374
+
375
+
376
def random_rotation(
    backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType], rng: BRNGType,
    dtype: Optional[BDtypeType] = None, device: Optional[BDeviceType] = None
) -> BArrayType:
    """
    Sample one random 3x3 rotation matrix.

    Args:
        backend: The backend to use for the computation.
        rng: A random number generator of the appropriate type for the backend.
        dtype: Type to return
        device: Device of returned tensor. Default: if None,
            uses the current device for the default tensor type

    Returns:
        Rotation matrix as tensor of shape (3, 3).
    """
    batch = random_rotations(backend, rng, 1, dtype, device)
    return batch[0]
394
+
395
+
396
def standardize_quaternion(
    backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
    quaternions: BArrayType) -> BArrayType:
    """
    Put unit quaternions in standard form, i.e. with a non-negative real part.

    q and -q encode the same rotation, so quaternions with a negative real
    part are negated as a whole.

    Args:
        backend: The backend to use for the computation.
        quaternions: Quaternions with real part first,
            as tensor of shape (..., 4).

    Returns:
        Standardized quaternions as tensor of shape (..., 4).
    """
    # Keep a trailing length-1 axis so the mask broadcasts over all 4 components.
    real_is_negative = quaternions[..., 0:1] < 0
    flipped = -quaternions
    return backend.where(real_is_negative, flipped, quaternions)
412
+
413
+
414
def quaternion_raw_multiply(
    backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
    a: BArrayType, b: BArrayType
) -> BArrayType:
    """
    Hamilton product of two quaternions, without any standardization.
    Leading dimensions broadcast under the usual array rules.

    Args:
        backend: The backend to use for the computation.
        a: Quaternions as tensor of shape (..., 4), real part first.
        b: Quaternions as tensor of shape (..., 4), real part first.

    Returns:
        The product a * b, a tensor of quaternions of shape (..., 4).
    """
    aw, ax, ay, az = (a[..., idx] for idx in range(4))
    bw, bx, by, bz = (b[..., idx] for idx in range(4))
    product = (
        aw * bw - ax * bx - ay * by - az * bz,
        aw * bx + ax * bw + ay * bz - az * by,
        aw * by - ax * bz + ay * bw + az * bx,
        aw * bz + ax * by - ay * bx + az * bw,
    )
    return backend.stack(product, axis=-1)
437
+
438
+
439
def quaternion_multiply(
    backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
    a: BArrayType, b: BArrayType
) -> BArrayType:
    """
    Compose two rotations given as quaternions. The result is the standard
    form of their product (the versor with non-negative real part).
    Leading dimensions broadcast under the usual array rules.

    Args:
        backend: The backend to use for the computation.
        a: Quaternions as tensor of shape (..., 4), real part first.
        b: Quaternions as tensor of shape (..., 4), real part first.

    Returns:
        The product of a and b, a tensor of quaternions of shape (..., 4).
    """
    return standardize_quaternion(backend, quaternion_raw_multiply(backend, a, b))
458
+
459
+
460
def quaternion_invert(
    backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
    quaternion: BArrayType
) -> BArrayType:
    """
    Given a quaternion representing rotation, get the quaternion representing
    its inverse. For unit quaternions this is the conjugate.

    Args:
        backend: The backend to use for the computation.
        quaternion: Quaternions as tensor of shape (..., 4), with real part
            first, which must be versors (unit quaternions).

    Returns:
        The inverse, a tensor of quaternions of shape (..., 4), with the
        same dtype as the input.
    """
    # Negate the vector part explicitly rather than multiplying by an integer
    # [1, -1, -1, -1] array. That product promotes float32 input to float64
    # under NumPy's promotion rules, and mixed int/float arithmetic is not
    # specified by the array API standard.
    return backend.stack(
        (
            quaternion[..., 0],
            -quaternion[..., 1],
            -quaternion[..., 2],
            -quaternion[..., 3],
        ),
        axis=-1,
    )
482
+
483
+
484
def quaternion_apply(
    backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
    quaternion: BArrayType, point: BArrayType
) -> BArrayType:
    """
    Apply the rotation given by a quaternion to a 3D point.
    Leading dimensions broadcast under the usual array rules.

    Args:
        backend: The backend to use for the computation.
        quaternion: Tensor of quaternions, real part first, of shape (..., 4).
        point: Tensor of 3D points of shape (..., 3).

    Returns:
        Tensor of rotated points of shape (..., 3).

    Raises:
        ValueError: if the last dimension of ``point`` is not 3.
    """
    if point.shape[-1] != 3:
        raise ValueError(f"Points are not in 3D, {point.shape}.")
    # Embed each point as a pure quaternion (0, x, y, z).
    real_parts = backend.zeros(
        tuple(point.shape[:-1]) + (1,), dtype=point.dtype, device=backend.device(point)
    )
    point_as_quaternion = backend.concat((real_parts, point), axis=-1)
    # Rotated point = q * p * q^-1. The outer product previously omitted the
    # ``backend`` argument, which made every call raise TypeError.
    out = quaternion_raw_multiply(
        backend,
        quaternion_raw_multiply(backend, quaternion, point_as_quaternion),
        quaternion_invert(backend, quaternion),
    )
    return out[..., 1:]
509
+
510
+
511
def axis_angle_to_matrix(
    backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
    axis_angle: BArrayType, fast: bool = False) -> BArrayType:
    """
    Convert rotations given as axis/angle to rotation matrices.

    Args:
        backend: The backend to use for the computation.
        axis_angle: Rotations given as a vector in axis angle form,
            as a tensor of shape (..., 3), where the magnitude is
            the angle turned anticlockwise in radians around the
            vector's direction.
        fast: Whether to use the new faster implementation (based on the
            Rodrigues formula) instead of the original implementation (which
            first converted to a quaternion and then back to a rotation matrix).

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    if not fast:
        return quaternion_to_matrix(backend, axis_angle_to_quaternion(backend, axis_angle))

    shape = tuple(axis_angle.shape)
    device, dtype = backend.device(axis_angle), axis_angle.dtype

    # Shape (..., 1, 1) so the angle broadcasts against (..., 3, 3) matrices.
    angles = backend.linalg.vector_norm(axis_angle, ord=2, axis=-1, keepdims=True)[..., None]

    # Skew-symmetric cross-product matrix K of the axis-angle vector.
    rx, ry, rz = axis_angle[..., 0], axis_angle[..., 1], axis_angle[..., 2]
    zeros = backend.zeros(shape[:-1], dtype=dtype, device=device)
    cross_product_matrix = backend.reshape(
        backend.stack([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], axis=-1),
        shape + (3,),
    )
    cross_product_matrix_sqrd = backend.matmul(cross_product_matrix, cross_product_matrix)

    identity = backend.eye(3, dtype=dtype, device=device)
    angles_sqrd = angles * angles
    # Avoid 0/0 for the zero rotation; 1 - cos(0) == 0, so that term vanishes anyway.
    angles_sqrd = backend.where(angles_sqrd == 0, backend.ones_like(angles_sqrd), angles_sqrd)
    # Rodrigues: R = I + sin(t)/t * K + (1 - cos(t))/t^2 * K^2, where
    # sinc(t / pi) == sin(t) / t for the normalized sinc.
    # The identity's reshape target must be a tuple: the previous
    # ``list + tuple`` concatenation raised TypeError.
    return (
        backend.reshape(identity, (1,) * (len(shape) - 1) + (3, 3))
        + backend.sinc(angles / backend.pi) * cross_product_matrix
        + ((1 - backend.cos(angles)) / angles_sqrd) * cross_product_matrix_sqrd
    )
553
+
554
+
555
def matrix_to_axis_angle(
    backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
    matrix: BArrayType, fast: bool = False
) -> BArrayType:
    """
    Convert rotations given as rotation matrices to axis/angle.

    Args:
        backend: The backend to use for the computation.
        matrix: Rotation matrices as tensor of shape (..., 3, 3).
        fast: Whether to use the new faster implementation (based on the
            Rodrigues formula) instead of the original implementation (which
            first converts to a quaternion and then to axis/angle).

    Returns:
        Rotations given as a vector in axis angle form, as a tensor
        of shape (..., 3), where the magnitude is the angle
        turned anticlockwise in radians around the vector's
        direction.

    Raises:
        ValueError: in the fast path, if the last two dimensions of
            ``matrix`` are not (3, 3).
    """
    if not fast:
        return quaternion_to_axis_angle(backend, matrix_to_quaternion(backend, matrix))

    if matrix.shape[-1] != 3 or matrix.shape[-2] != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")

    dtype, device = matrix.dtype, backend.device(matrix)

    # Vector form of the skew-symmetric part R - R^T.
    omegas = backend.stack(
        [
            matrix[..., 2, 1] - matrix[..., 1, 2],
            matrix[..., 0, 2] - matrix[..., 2, 0],
            matrix[..., 1, 0] - matrix[..., 0, 1],
        ],
        axis=-1,
    )
    norms = backend.linalg.vector_norm(omegas, ord=2, axis=-1, keepdims=True)
    traces = backend.sum(backend.linalg.diagonal(matrix), axis=-1, keepdims=True)
    angles = backend.atan2(norms, traces - 1)  # shape (..., 1)

    zeros = backend.zeros(3, dtype=dtype, device=device)
    omegas = backend.where(backend.isclose(angles, 0), zeros, omegas)

    # Shape (..., 1): broadcasts over the 3 output components below. Both
    # branches are computed for every matrix and merged with ``where``, instead
    # of boolean-mask indexing plus in-place assignment. That combination fails
    # on immutable (JAX) arrays and mis-broadcasts unbatched input.
    near_pi = backend.isclose(angles, backend.pi)

    # Generic branch: axis * angle = 0.5 * omega / sinc(angle / pi). On near-pi
    # lanes (where this branch is discarded) the divisor is replaced by 1 to
    # keep values and gradients finite.
    sinc = backend.sinc(angles / backend.pi)
    safe_sinc = backend.where(near_pi, backend.ones_like(sinc), sinc)
    generic = 0.5 * omegas / safe_sinc

    # Near-pi branch, derived from n n^T = (R + I) / 2: the first row of (R + I) / 2
    # is n_x * n. Normalizing it gives the axis up to sign, which is irrelevant
    # for a rotation by pi. NOTE(review): this degenerates when n_x is ~0
    # (see the PyTorch3D issue referenced in the file header).
    n = 0.5 * (matrix[..., 0, :] + backend.eye(3, dtype=dtype, device=device)[0])
    near_pi_result = angles * n / backend.linalg.vector_norm(n, ord=2, axis=-1, keepdims=True)

    return backend.where(near_pi, near_pi_result, generic)
613
+
614
+
615
+ def axis_angle_to_quaternion(
616
+ backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
617
+ axis_angle: BArrayType
618
+ ) -> BArrayType:
619
+ """
620
+ Convert rotations given as axis/angle to quaternions.
621
+
622
+ Args:
623
+ backend: The backend to use for the computation.
624
+ axis_angle: Rotations given as a vector in axis angle form,
625
+ as a tensor of shape (..., 3), where the magnitude is
626
+ the angle turned anticlockwise in radians around the
627
+ vector's direction.
628
+
629
+ Returns:
630
+ quaternions with real part first, as tensor of shape (..., 4).
631
+ """
632
+ angles = backend.linalg.vector_norm(axis_angle, ord=2, axis=-1, keepdims=True)
633
+ sin_half_angles_over_angles = 0.5 * backend.sinc(angles * 0.5 / backend.pi)
634
+ return backend.concat(
635
+ [backend.cos(angles * 0.5), axis_angle * sin_half_angles_over_angles], axis=-1
636
+ )
637
+
638
+
639
+ def quaternion_to_axis_angle(
640
+ backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
641
+ quaternions: BArrayType
642
+ ) -> BArrayType:
643
+ """
644
+ Convert rotations given as quaternions to axis/angle.
645
+
646
+ Args:
647
+ backend: The backend to use for the computation.
648
+ quaternions: quaternions with real part first,
649
+ as tensor of shape (..., 4).
650
+
651
+ Returns:
652
+ Rotations given as a vector in axis angle form, as a tensor
653
+ of shape (..., 3), where the magnitude is the angle
654
+ turned anticlockwise in radians around the vector's
655
+ direction.
656
+ """
657
+ norms = backend.linalg.vector_norm(quaternions[..., 1:], ord=2, axis=-1, keepdims=True)
658
+ half_angles = backend.atan2(norms, quaternions[..., :1])
659
+ sin_half_angles_over_angles = 0.5 * backend.sinc(half_angles / backend.pi)
660
+ # angles/2 are between [-pi/2, pi/2], thus sin_half_angles_over_angles
661
+ # can't be zero
662
+ return quaternions[..., 1:] / sin_half_angles_over_angles
663
+
664
+
665
+ def rotation_6d_to_matrix(
666
+ backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
667
+ d6: BArrayType
668
+ ) -> BArrayType:
669
+ """
670
+ Converts 6D rotation representation by Zhou et al. [1] to rotation matrix
671
+ using Gram--Schmidt orthogonalization per Section B of [1].
672
+ Args:
673
+ backend: The backend to use for the computation.
674
+ d6: 6D rotation representation, of size (*, 6)
675
+
676
+ Returns:
677
+ batch of rotation matrices of size (*, 3, 3)
678
+
679
+ [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
680
+ On the Continuity of Rotation Representations in Neural Networks.
681
+ IEEE Conference on Computer Vision and Pattern Recognition, 2019.
682
+ Retrieved from http://arxiv.org/abs/1812.07035
683
+ """
684
+
685
+ a1, a2 = d6[..., :3], d6[..., 3:]
686
+ b1 = a1 / backend.linalg.vector_norm(a1, ord=2, axis=-1, keepdims=True)
687
+ b2 = a2 - backend.sum(b1 * a2, axis=-1, keepdims=True) * b1
688
+ b2 = b2 / backend.linalg.vector_norm(b2, ord=2, axis=-1, keepdims=True)
689
+ b3 = backend.linalg.cross(b1, b2, axis=-1)
690
+ return backend.stack((b1, b2, b3), axis=-2)
691
+
692
def matrix_to_rotation_6d(
    backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
    matrix: BArrayType
) -> BArrayType:
    """
    Map rotation matrices to the 6D representation of Zhou et al. [1]
    by keeping the first two rows and dropping the last. The 6D
    representation is not unique.

    Args:
        backend: The backend to use for the computation.
        matrix: batch of rotation matrices of size (*, 3, 3)

    Returns:
        6D rotation representation, of size (*, 6)

    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
    On the Continuity of Rotation Representations in Neural Networks.
    IEEE Conference on Computer Vision and Pattern Recognition, 2019.
    Retrieved from http://arxiv.org/abs/1812.07035
    """
    first_two_rows = matrix[..., :2, :]
    return backend.reshape(first_two_rows, matrix.shape[:-2] + (6,))