quatorch 0.1.0a0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
|
@@ -0,0 +1,140 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: quatorch
|
|
3
|
+
Version: 0.1.0a0
|
|
4
|
+
Summary: Quaternion operations in pure PyTorch
|
|
5
|
+
Keywords: quaternion,rotation,3d,orientation,pytorch
|
|
6
|
+
Author: Lucas N. Egidio
|
|
7
|
+
Author-email: Lucas N. Egidio <lucasegidio1@gmail.com>
|
|
8
|
+
License-Expression: MIT
|
|
9
|
+
Classifier: Development Status :: 4 - Beta
|
|
10
|
+
Classifier: Intended Audience :: Science/Research
|
|
11
|
+
Classifier: Topic :: Scientific/Engineering :: Mathematics
|
|
12
|
+
Classifier: Topic :: Scientific/Engineering :: Physics
|
|
13
|
+
Classifier: Programming Language :: Python :: 3
|
|
14
|
+
Requires-Dist: torch>=2.8.0
|
|
15
|
+
Requires-Dist: torch ; extra == 'cpu'
|
|
16
|
+
Maintainer: Lucas N. Egidio
|
|
17
|
+
Maintainer-email: Lucas N. Egidio <lucasegidio1@gmail.com>
|
|
18
|
+
Requires-Python: >=3.10
|
|
19
|
+
Project-URL: Documentation, https://github.com/egidioln/QuaTorch/
|
|
20
|
+
Project-URL: Homepage, https://github.com/egidioln/QuaTorch/
|
|
21
|
+
Project-URL: Issues, https://github.com/egidioln/QuaTorch/issues
|
|
22
|
+
Project-URL: Repository, https://github.com/egidioln/QuaTorch.git
|
|
23
|
+
Provides-Extra: cpu
|
|
24
|
+
Description-Content-Type: text/markdown
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
<img src="https://raw.githubusercontent.com/egidioln/QuaTorch/refs/heads/main/docs/source/_static/logo.svg" width="400px">
|
|
28
|
+
|
|
29
|
+
***Quaternions in PyTorch***
|
|
30
|
+
# QuaTorch
|
|
31
|
+
<!--[](https://github.com/actions/deploy-pages/releases/latest) -->
|
|
32
|
+
[](https://github.com/egidioln/quatorch/actions)
|
|
33
|
+
[](https://github.com/egidioln/QuaTorch/actions/workflows/pytest.yml)
|
|
34
|
+
[](https://egidioln.github.io/QuaTorch/)
|
|
35
|
+
|
|
36
|
+
The package `quatorch` provides `Quaternion`, a `torch.Tensor` subclass that represents a [Quaternion](https://en.wikipedia.org/wiki/Quaternion). It implements common quaternion-algebra operations such as multiplication,
|
|
37
|
+
conjugation, inversion, normalization, log, exp, etc. It also supports conversion to/from rotation matrices and the axis-angle representation, and provides convenient utilities such as spherical linear interpolation ([slerp](https://en.wikipedia.org/wiki/Slerp)) and 3D vector rotation.
|
|
38
|
+
|
|
39
|
+
## Highlights
|
|
40
|
+
|
|
41
|
+
- Quaternion type: `quatorch.Quaternion` (subclass of `torch.Tensor`).
|
|
42
|
+
- Element-wise and algebraic ops implemented: `+`, `-`, `*` (quaternion product and scalar mul),
|
|
43
|
+
`abs` (norm), `conjugate`, `inverse`, `normalize`, `to_rotation_matrix`, and more.
|
|
44
|
+
- Utilities: `from_rotation_matrix`, `from_axis_angle`, `to_axis_angle`, `rotate_vector`, `slerp`,
|
|
45
|
+
`log`, `exp`, and `pow`.
|
|
46
|
+
|
|
47
|
+
## Installation
|
|
48
|
+
|
|
49
|
+
This project targets Python 3.10+ and requires PyTorch. Install via pip (recommended):
|
|
50
|
+
|
|
51
|
+
```bash
|
|
52
|
+
pip install quatorch
|
|
53
|
+
```
|
|
54
|
+
|
|
55
|
+
Or install in editable (development) mode:
|
|
56
|
+
|
|
57
|
+
```bash
|
|
58
|
+
git clone https://github.com/egidioln/QuaTorch.git
|
|
59
|
+
cd QuaTorch
|
|
60
|
+
pip install -e .
|
|
61
|
+
```
|
|
62
|
+
|
|
63
|
+
See `pyproject.toml` for dependency details (dev deps include `pytest`).
|
|
64
|
+
|
|
65
|
+
## Quick start
|
|
66
|
+
|
|
67
|
+
Basic usage examples using PyTorch tensors and `Quaternion`:
|
|
68
|
+
|
|
69
|
+
```python
|
|
70
|
+
import torch
|
|
71
|
+
from quatorch.quaternion import Quaternion
|
|
72
|
+
|
|
73
|
+
# Create a quaternion from four scalars (W, X, Y, Z)
|
|
74
|
+
q = Quaternion(1.0, 0.0, 0.0, 0.0)
|
|
75
|
+
|
|
76
|
+
# Or from a tensor of shape (..., 4)
|
|
77
|
+
q2 = Quaternion(torch.tensor([0.9239, 0.3827, 0.0, 0.0])) # 45° around X
|
|
78
|
+
|
|
79
|
+
# Normalize
|
|
80
|
+
q2 = q2.normalize()
|
|
81
|
+
|
|
82
|
+
# Quaternion multiplication (rotation composition)
|
|
83
|
+
q3 = q * q2
|
|
84
|
+
|
|
85
|
+
# Rotate a vector
|
|
86
|
+
v = torch.tensor([1.0, 0.0, 0.0])
|
|
87
|
+
v_rot = q2.rotate_vector(v)
|
|
88
|
+
|
|
89
|
+
# Convert to rotation matrix
|
|
90
|
+
R = q2.to_rotation_matrix()
|
|
91
|
+
|
|
92
|
+
# Slerp between quaternions
|
|
93
|
+
t = 0.5
|
|
94
|
+
q_mid = q.slerp(q2, t)
|
|
95
|
+
```
|
|
96
|
+
|
|
97
|
+
## API notes
|
|
98
|
+
|
|
99
|
+
- Construction:
|
|
100
|
+
- `Quaternion(data: torch.Tensor)` where `data.shape[-1] == 4`.
|
|
101
|
+
- `Quaternion(w, x, y, z)` accepts scalars or tensors broadcastable to the same shape.
|
|
102
|
+
|
|
103
|
+
- Shape requirements:
|
|
104
|
+
- The last dimension must be size 4 for quaternion tensors (W, X, Y, Z).
|
|
105
|
+
|
|
106
|
+
- Interoperability:
|
|
107
|
+
- The class implements several `torch.*` functions via a small dispatcher so
|
|
108
|
+
many PyTorch APIs behave sensibly with `Quaternion` objects.
|
|
109
|
+
|
|
110
|
+
See the source in `src/quatorch/quaternion.py` for the full implementation and more
|
|
111
|
+
helper methods.
|
|
112
|
+
|
|
113
|
+
## Running tests
|
|
114
|
+
|
|
115
|
+
This repository includes unit tests using `pytest` under `test/unit_tests`.
|
|
116
|
+
|
|
117
|
+
From the project root, run:
|
|
118
|
+
|
|
119
|
+
```bash
|
|
120
|
+
uv run --with=. pytest
|
|
121
|
+
```
|
|
122
|
+
|
|
123
|
+
## Contributing
|
|
124
|
+
|
|
125
|
+
Contributions are welcome. A few ideas:
|
|
126
|
+
|
|
127
|
+
- Add more conversions and higher-level utilities (e.g., batch rotation helpers).
|
|
128
|
+
- Improve numeric stability and add property-based tests.
|
|
129
|
+
- Add docs and usage notebooks / examples.
|
|
130
|
+
|
|
131
|
+
Please open issues or pull requests on the repository.
|
|
132
|
+
|
|
133
|
+
## License
|
|
134
|
+
|
|
135
|
+
MIT — see [LICENSE.md](./LICENSE.md) for details, including author/maintainer information.
|
|
136
|
+
|
|
137
|
+
## Contact
|
|
138
|
+
|
|
139
|
+
Maintainer: Lucas N. Egidio <lucasegidio1@gmail.com>
|
|
140
|
+
|
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
|
|
2
|
+
<img src="https://raw.githubusercontent.com/egidioln/QuaTorch/refs/heads/main/docs/source/_static/logo.svg" width="400px">
|
|
3
|
+
|
|
4
|
+
***Quaternions in PyTorch***
|
|
5
|
+
# QuaTorch
|
|
6
|
+
<!--[](https://github.com/actions/deploy-pages/releases/latest) -->
|
|
7
|
+
[](https://github.com/egidioln/quatorch/actions)
|
|
8
|
+
[](https://github.com/egidioln/QuaTorch/actions/workflows/pytest.yml)
|
|
9
|
+
[](https://egidioln.github.io/QuaTorch/)
|
|
10
|
+
|
|
11
|
+
The package `quatorch` provides `Quaternion`, a `torch.Tensor` subclass that represents a [Quaternion](https://en.wikipedia.org/wiki/Quaternion). It implements common quaternion-algebra operations such as multiplication,
|
|
12
|
+
conjugation, inversion, normalization, log, exp, etc. It also supports conversion to/from rotation matrices and the axis-angle representation, and provides convenient utilities such as spherical linear interpolation ([slerp](https://en.wikipedia.org/wiki/Slerp)) and 3D vector rotation.
|
|
13
|
+
|
|
14
|
+
## Highlights
|
|
15
|
+
|
|
16
|
+
- Quaternion type: `quatorch.Quaternion` (subclass of `torch.Tensor`).
|
|
17
|
+
- Element-wise and algebraic ops implemented: `+`, `-`, `*` (quaternion product and scalar mul),
|
|
18
|
+
`abs` (norm), `conjugate`, `inverse`, `normalize`, `to_rotation_matrix`, and more.
|
|
19
|
+
- Utilities: `from_rotation_matrix`, `from_axis_angle`, `to_axis_angle`, `rotate_vector`, `slerp`,
|
|
20
|
+
`log`, `exp`, and `pow`.
|
|
21
|
+
|
|
22
|
+
## Installation
|
|
23
|
+
|
|
24
|
+
This project targets Python 3.10+ and requires PyTorch. Install via pip (recommended):
|
|
25
|
+
|
|
26
|
+
```bash
|
|
27
|
+
pip install quatorch
|
|
28
|
+
```
|
|
29
|
+
|
|
30
|
+
Or install in editable (development) mode:
|
|
31
|
+
|
|
32
|
+
```bash
|
|
33
|
+
git clone https://github.com/egidioln/QuaTorch.git
|
|
34
|
+
cd QuaTorch
|
|
35
|
+
pip install -e .
|
|
36
|
+
```
|
|
37
|
+
|
|
38
|
+
See `pyproject.toml` for dependency details (dev deps include `pytest`).
|
|
39
|
+
|
|
40
|
+
## Quick start
|
|
41
|
+
|
|
42
|
+
Basic usage examples using PyTorch tensors and `Quaternion`:
|
|
43
|
+
|
|
44
|
+
```python
|
|
45
|
+
import torch
|
|
46
|
+
from quatorch.quaternion import Quaternion
|
|
47
|
+
|
|
48
|
+
# Create a quaternion from four scalars (W, X, Y, Z)
|
|
49
|
+
q = Quaternion(1.0, 0.0, 0.0, 0.0)
|
|
50
|
+
|
|
51
|
+
# Or from a tensor of shape (..., 4)
|
|
52
|
+
q2 = Quaternion(torch.tensor([0.9239, 0.3827, 0.0, 0.0])) # 45° around X
|
|
53
|
+
|
|
54
|
+
# Normalize
|
|
55
|
+
q2 = q2.normalize()
|
|
56
|
+
|
|
57
|
+
# Quaternion multiplication (rotation composition)
|
|
58
|
+
q3 = q * q2
|
|
59
|
+
|
|
60
|
+
# Rotate a vector
|
|
61
|
+
v = torch.tensor([1.0, 0.0, 0.0])
|
|
62
|
+
v_rot = q2.rotate_vector(v)
|
|
63
|
+
|
|
64
|
+
# Convert to rotation matrix
|
|
65
|
+
R = q2.to_rotation_matrix()
|
|
66
|
+
|
|
67
|
+
# Slerp between quaternions
|
|
68
|
+
t = 0.5
|
|
69
|
+
q_mid = q.slerp(q2, t)
|
|
70
|
+
```
|
|
71
|
+
|
|
72
|
+
## API notes
|
|
73
|
+
|
|
74
|
+
- Construction:
|
|
75
|
+
- `Quaternion(data: torch.Tensor)` where `data.shape[-1] == 4`.
|
|
76
|
+
- `Quaternion(w, x, y, z)` accepts scalars or tensors broadcastable to the same shape.
|
|
77
|
+
|
|
78
|
+
- Shape requirements:
|
|
79
|
+
- The last dimension must be size 4 for quaternion tensors (W, X, Y, Z).
|
|
80
|
+
|
|
81
|
+
- Interoperability:
|
|
82
|
+
- The class implements several `torch.*` functions via a small dispatcher so
|
|
83
|
+
many PyTorch APIs behave sensibly with `Quaternion` objects.
|
|
84
|
+
|
|
85
|
+
See the source in `src/quatorch/quaternion.py` for the full implementation and more
|
|
86
|
+
helper methods.
|
|
87
|
+
|
|
88
|
+
## Running tests
|
|
89
|
+
|
|
90
|
+
This repository includes unit tests using `pytest` under `test/unit_tests`.
|
|
91
|
+
|
|
92
|
+
From the project root, run:
|
|
93
|
+
|
|
94
|
+
```bash
|
|
95
|
+
uv run --with=. pytest
|
|
96
|
+
```
|
|
97
|
+
|
|
98
|
+
## Contributing
|
|
99
|
+
|
|
100
|
+
Contributions are welcome. A few ideas:
|
|
101
|
+
|
|
102
|
+
- Add more conversions and higher-level utilities (e.g., batch rotation helpers).
|
|
103
|
+
- Improve numeric stability and add property-based tests.
|
|
104
|
+
- Add docs and usage notebooks / examples.
|
|
105
|
+
|
|
106
|
+
Please open issues or pull requests on the repository.
|
|
107
|
+
|
|
108
|
+
## License
|
|
109
|
+
|
|
110
|
+
MIT — see [LICENSE.md](./LICENSE.md) for details, including author/maintainer information.
|
|
111
|
+
|
|
112
|
+
## Contact
|
|
113
|
+
|
|
114
|
+
Maintainer: Lucas N. Egidio <lucasegidio1@gmail.com>
|
|
115
|
+
|
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "quatorch"
|
|
3
|
+
version = "0.1.0-alpha"
|
|
4
|
+
description = "Quaternion operations in pure PyTorch"
|
|
5
|
+
keywords = ["quaternion", "rotation", "3d", "orientation", "pytorch"]
|
|
6
|
+
readme = "README.md"
|
|
7
|
+
license = "MIT"
|
|
8
|
+
requires-python = ">=3.10"
|
|
9
|
+
dependencies = [
|
|
10
|
+
"torch>=2.8.0",
|
|
11
|
+
]
|
|
12
|
+
|
|
13
|
+
authors = [
|
|
14
|
+
{name = "Lucas N. Egidio", email = "lucasegidio1@gmail.com"},
|
|
15
|
+
]
|
|
16
|
+
maintainers = [
|
|
17
|
+
{name = "Lucas N. Egidio", email = "lucasegidio1@gmail.com"},
|
|
18
|
+
]
|
|
19
|
+
|
|
20
|
+
classifiers = [
|
|
21
|
+
# TODO change to 5 - Production/Stable
|
|
22
|
+
"Development Status :: 4 - Beta",
|
|
23
|
+
"Intended Audience :: Science/Research",
|
|
24
|
+
|
|
25
|
+
# TODO verify "License :: OSI Approved :: MIT License",
|
|
26
|
+
"Topic :: Scientific/Engineering :: Mathematics",
|
|
27
|
+
"Topic :: Scientific/Engineering :: Physics",
|
|
28
|
+
|
|
29
|
+
# Specify the Python versions you support here.
|
|
30
|
+
"Programming Language :: Python :: 3",
|
|
31
|
+
]
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
[project.optional-dependencies]
|
|
35
|
+
cpu = [
|
|
36
|
+
"torch"
|
|
37
|
+
]
|
|
38
|
+
|
|
39
|
+
[tool.uv.sources]
|
|
40
|
+
torch = [
|
|
41
|
+
{ index = "pytorch-cpu", extra = "cpu" },
|
|
42
|
+
]
|
|
43
|
+
|
|
44
|
+
[[tool.uv.index]]
|
|
45
|
+
name = "pytorch-cpu"
|
|
46
|
+
url = "https://download.pytorch.org/whl/cpu"
|
|
47
|
+
explicit = true
|
|
48
|
+
|
|
49
|
+
[project.urls]
|
|
50
|
+
Homepage = "https://github.com/egidioln/QuaTorch/"
|
|
51
|
+
Documentation = "https://github.com/egidioln/QuaTorch/"
|
|
52
|
+
Repository = "https://github.com/egidioln/QuaTorch.git"
|
|
53
|
+
Issues = "https://github.com/egidioln/QuaTorch/issues"
|
|
54
|
+
# Changelog = "https://github.com/me/spam/blob/master/CHANGELOG.md"
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
[dependency-groups]
|
|
58
|
+
dev = [
|
|
59
|
+
"furo>=2025.7.19",
|
|
60
|
+
"myst-parser>=4.0.1",
|
|
61
|
+
"pytest>=8.4.2",
|
|
62
|
+
"pytest-cov>=7.0.0",
|
|
63
|
+
"ruff>=0.13.0",
|
|
64
|
+
"sphinx>=8.1.3",
|
|
65
|
+
]
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
[build-system]
|
|
69
|
+
requires = ["uv_build >= 0.8.17, <0.9.0"]
|
|
70
|
+
build-backend = "uv_build"
|
|
71
|
+
|
|
72
|
+
[tool.pytest.ini_options]
|
|
73
|
+
testpaths = ["test"]
|
|
74
|
+
# Optional: set python_files pattern to include test files under unit_tests
|
|
75
|
+
python_files = ["test_*.py"]
|
|
76
|
+
addopts = "--cov=quatorch --cov-report xml:coverage.xml --cov-report html:htmlcov --cov-report json"
|
|
77
|
+
|
|
78
|
+
[tool.coverage.report]
|
|
79
|
+
exclude_also = [
|
|
80
|
+
'raise ValueError',
|
|
81
|
+
'raise TypeError'
|
|
82
|
+
]
|
|
@@ -0,0 +1,309 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
from typing import Any, overload
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
|
|
6
|
+
# Registry filled by the ``implements`` decorator: maps a torch function (e.g.
# ``torch.add``) to the Quaternion implementation that
# ``Quaternion.__torch_function__`` dispatches to.
HANDLED_FUNCTIONS: dict[Any, Any] = dict()
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def CHECK_OPERAND_SHAPE(other: Any, scalar_allowed: bool = True):
    """Validate the right-hand operand of a quaternion operation.

    Args:
        other: The operand. A tensor with at least one dimension must have a last
            dimension of size 4 (a quaternion in WXYZ order) or size 1 (a real
            scalar). A 0-dimensional tensor is treated like a Python real scalar.
        scalar_allowed: If False, Python ``int``/``float`` values and
            0-dimensional tensors are rejected.

    Raises:
        ValueError: If a tensor operand with at least one dimension has a last
            dimension that is neither 1 nor 4.
        TypeError: If the operand is complex (a complex tensor or a Python
            ``complex``), or if it is a scalar while ``scalar_allowed`` is False.
    """
    is_tensor = torch.is_tensor(other)
    # ``other.shape[-1]`` raises IndexError on a 0-d tensor, so such tensors are
    # handled as real scalars instead of being shape-checked.
    is_scalar_tensor = is_tensor and other.dim() == 0

    if is_tensor and not is_scalar_tensor and other.shape[-1] not in (1, 4):
        raise ValueError(
            "The last dimension must be of size 4 to represent a quaternion (WXYZ) or size 1 to represent a real scalar."
        )
    # ``is_complex`` covers every complex dtype (complex32/64/128).
    if is_tensor and other.is_complex():
        raise TypeError("Cannot operate between quaternion and complex tensors.")
    # A Python complex would otherwise pass and fail later inside the operation.
    if isinstance(other, complex):
        raise TypeError("Cannot operate between quaternion and complex numbers.")
    if not scalar_allowed and (is_scalar_tensor or isinstance(other, (int, float))):
        raise TypeError("Operand must not be a scalar.")
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def implements(torch_function):
    """Build a decorator registering a Quaternion override for ``torch_function``.

    The decorated function is stored in ``HANDLED_FUNCTIONS[torch_function]`` and
    returned unchanged, so ``Quaternion.__torch_function__`` can dispatch calls of
    ``torch_function`` to it.
    """

    def register(func):
        HANDLED_FUNCTIONS[torch_function] = func
        return func

    # The returned decorator carries ``torch_function``'s metadata (name, doc,
    # ``__wrapped__``).
    return functools.wraps(torch_function)(register)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class Quaternion(torch.Tensor):
    r"""A `torch.Tensor` subclass representing quaternions.

    Quaternions are represented as tensors of shape (..., 4), where the last dimension
    corresponds to the components (W, X, Y, Z).

    Args:

        w: The real part of the quaternion.
        x: The first imaginary part of the quaternion.
        y: The second imaginary part of the quaternion.
        z: The third imaginary part of the quaternion.
        or
        data: A tensor of shape (..., 4) representing the quaternion components
            in the order (W, X, Y, Z).

    The components ``w, x, y, z`` (positional or keyword) may be Python scalars or
    tensors; they are broadcast to a common shape and stacked along a new last
    dimension.
    """

    def __new__(cls, *args, **kwargs):
        # A single tensor (positional or ``data=``) goes straight to
        # ``torch.Tensor.__new__``; ``__init__`` validates its last dimension.
        if "data" in kwargs or (len(args) == 1 and isinstance(args[0], torch.Tensor)):
            return super().__new__(cls, *args, **kwargs)
        if len(args) == 4:
            components = args
        elif all(key in kwargs for key in "wxyz"):
            components = tuple(kwargs[key] for key in "wxyz")
        else:
            raise ValueError("Invalid arguments for Quaternion initialization.")
        return super().__new__(cls, Quaternion._stack_components(components))

    @staticmethod
    def _stack_components(components) -> torch.Tensor:
        """Stack four W, X, Y, Z components into one tensor of shape (..., 4).

        Non-tensor components are converted with ``torch.as_tensor`` on the device of
        the first tensor component (if any); all four are then broadcast to a common
        shape before stacking.

        Raises:
            ValueError: If the components cannot be broadcast/stacked together.
        """
        device = next(
            (c.device for c in components if isinstance(c, torch.Tensor)), None
        )
        tensors = [
            c if isinstance(c, torch.Tensor) else torch.as_tensor(c, device=device)
            for c in components
        ]
        try:
            return torch.stack(torch.broadcast_tensors(*tensors), dim=-1)
        except RuntimeError as e:
            raise ValueError("All input tensors must have the same shape.") from e

    @overload
    def __init__(self, data: torch.Tensor, *args, **kwargs): ...

    @overload
    def __init__(
        self,
        w: float | torch.Tensor,
        x: float | torch.Tensor,
        y: float | torch.Tensor,
        z: float | torch.Tensor,
        **kwargs,
    ): ...

    def __init__(self, *args, **kwargs):
        """Validate that the constructed tensor has a last dimension of size 4.

        Raises:
            ValueError: If the tensor is 0-dimensional or its last dimension is not 4.
        """
        super().__init__()
        # Check the rank first: ``self.shape[-1]`` raises IndexError on 0-d tensors.
        if self.dim() == 0 or self.shape[-1] != 4:
            raise ValueError(
                "The last dimension must be of size 4 to represent a quaternion (WXYZ)."
            )

    def to_wxyz(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return the W, X, Y, Z components as plain tensors of shape (...)."""
        return self.as_subclass(torch.Tensor).unbind(-1)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Route functions registered in ``HANDLED_FUNCTIONS`` to their Quaternion
        implementation; everything else uses the default ``torch.Tensor`` handling."""
        if kwargs is None:
            kwargs = {}
        if func in HANDLED_FUNCTIONS:
            return HANDLED_FUNCTIONS[func](*args, **kwargs)
        return super().__torch_function__(func, types, args, kwargs)

    @staticmethod
    def _embed_real(other):
        """Pad a (..., 1) real tensor to the pure-real quaternion (r, 0, 0, 0).

        Any other operand is returned unchanged.
        """
        if torch.is_tensor(other) and other.dim() > 0 and other.shape[-1] == 1:
            real = other.as_subclass(torch.Tensor)
            zeros = real.new_zeros((*real.shape[:-1], 3))
            return torch.cat([real, zeros], dim=-1)
        return other

    @implements(torch.add)
    def add(self, other):
        """Quaternion addition.

        ``other`` is a quaternion tensor (..., 4) or a real tensor (..., 1); a real
        tensor is added to the W component only. Python scalars are rejected.
        """
        CHECK_OPERAND_SHAPE(other, scalar_allowed=False)

        # Plain broadcasting of a (..., 1) tensor would shift all four components.
        return Quaternion(super().add(Quaternion._embed_real(other)))

    @implements(torch.mul)
    def mul(self, other):
        """Quaternion (Hamilton) product, or scaling by a real scalar.

        ``other`` may be a Python int/float, a 0-d tensor, a (..., 1) real tensor
        (all of which scale every component), or a (..., 4) quaternion tensor.
        """
        CHECK_OPERAND_SHAPE(other, scalar_allowed=True)

        if isinstance(other, (int, float)):
            return Quaternion(super().mul(other))

        # Real scalars multiply every component; doing it elementwise also avoids
        # spurious inf * 0 = NaN terms from the full product formula.
        if other.dim() == 0 or other.shape[-1] == 1:
            return Quaternion(
                self.as_subclass(torch.Tensor) * other.as_subclass(torch.Tensor)
            )

        w1, x1, y1, z1 = self.to_wxyz()
        w2, x2, y2, z2 = Quaternion(other).to_wxyz()

        # Hamilton product: (w1 w2 - v1.v2, w1 v2 + w2 v1 + v1 x v2)
        w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2
        x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2
        y = w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2
        z = w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2

        return Quaternion(torch.stack([w, x, y, z], dim=-1))

    @implements(torch.sub)
    def sub(self, other):
        """Quaternion subtraction.

        ``other`` is a quaternion tensor (..., 4) or a real tensor (..., 1); a real
        tensor is subtracted from the W component only. Python scalars are rejected.
        """
        CHECK_OPERAND_SHAPE(other, scalar_allowed=False)

        return Quaternion(super().sub(Quaternion._embed_real(other)))

    @implements(torch.abs)
    def abs(self):
        """Return the norm |q| as a plain tensor of shape (...)."""
        w, x, y, z = self.to_wxyz()
        return torch.sqrt(w**2 + x**2 + y**2 + z**2)

    def __add__(self, other):
        return self.add(other)

    def __sub__(self, other):
        return self.sub(other)

    def __mul__(self, other):
        return self.mul(other)

    def __pow__(self, other):
        return self.pow(other)

    def __truediv__(self, other):
        # Dividing by a Quaternion multiplies by its inverse on the right; any
        # other divisor is applied elementwise.
        if isinstance(other, Quaternion):
            return self.mul(other.inverse())
        return self.div(other)

    def conjugate(self):
        """Return the conjugate (w, -x, -y, -z)."""
        w, x, y, z = self.to_wxyz()
        return Quaternion(torch.stack([w, -x, -y, -z], dim=-1))

    def inverse(self):
        """Return the multiplicative inverse conj(q) / |q|^2."""
        norm_sq = self.abs() ** 2
        conj = self.conjugate()
        return Quaternion(conj / norm_sq.unsqueeze(-1))

    def normalize(self):
        """Return q / |q| (a zero quaternion yields NaN components)."""
        norm = self.abs()
        return Quaternion(self / norm.unsqueeze(-1))

    def to_rotation_matrix(self):
        """Return the (..., 3, 3) rotation matrix of this quaternion.

        The formula assumes a unit quaternion; the input is not normalized here.
        """
        w, x, y, z = self.to_wxyz()
        leading_dims = self.shape[:-1]
        rotation_matrix = torch.empty(
            *leading_dims, 3, 3, device=self.device, dtype=self.dtype
        )

        rotation_matrix[..., 0, 0] = 1 - 2 * (y**2 + z**2)
        rotation_matrix[..., 0, 1] = 2 * (x * y - z * w)
        rotation_matrix[..., 0, 2] = 2 * (x * z + y * w)

        rotation_matrix[..., 1, 0] = 2 * (x * y + z * w)
        rotation_matrix[..., 1, 1] = 1 - 2 * (x**2 + z**2)
        rotation_matrix[..., 1, 2] = 2 * (y * z - x * w)

        rotation_matrix[..., 2, 0] = 2 * (x * z - y * w)
        rotation_matrix[..., 2, 1] = 2 * (y * z + x * w)
        rotation_matrix[..., 2, 2] = 1 - 2 * (x**2 + y**2)

        return rotation_matrix

    @staticmethod
    def from_rotation_matrix(R: torch.Tensor):
        """Convert rotation matrices of shape (..., 3, 3) to quaternions (..., 4).

        Uses Shepperd's method: q is recovered from whichever of |w|, |x|, |y|, |z|
        is largest, which avoids the division by ~0 the trace-only formula hits for
        rotations near 180 degrees. The result has a non-negative real part W.

        Raises:
            ValueError: If ``R`` does not have shape (..., 3, 3).
        """
        if R.shape[-2:] != (3, 3):
            raise ValueError("Input rotation matrix must have shape (..., 3, 3)")

        m00, m01, m02 = R[..., 0, 0], R[..., 0, 1], R[..., 0, 2]
        m10, m11, m12 = R[..., 1, 0], R[..., 1, 1], R[..., 1, 2]
        m20, m21, m22 = R[..., 2, 0], R[..., 2, 1], R[..., 2, 2]

        # For an orthonormal R these are 4w^2, 4x^2, 4y^2, 4z^2.
        four_sq = torch.stack(
            [
                1.0 + m00 + m11 + m22,
                1.0 + m00 - m11 - m22,
                1.0 - m00 + m11 - m22,
                1.0 - m00 - m11 + m22,
            ],
            dim=-1,
        )
        # 2|w|, 2|x|, 2|y|, 2|z|; the inner where keeps sqrt away from
        # non-positive inputs so gradients stay finite.
        positive = four_sq > 0
        two_abs = torch.where(
            positive,
            torch.sqrt(torch.where(positive, four_sq, torch.ones_like(four_sq))),
            torch.zeros_like(four_sq),
        )

        # Row i holds 4 * q_i * (w, x, y, z).
        candidates = torch.stack(
            [
                torch.stack(
                    [two_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1
                ),
                torch.stack(
                    [m21 - m12, two_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1
                ),
                torch.stack(
                    [m02 - m20, m10 + m01, two_abs[..., 2] ** 2, m12 + m21], dim=-1
                ),
                torch.stack(
                    [m10 - m01, m02 + m20, m12 + m21, two_abs[..., 3] ** 2], dim=-1
                ),
            ],
            dim=-2,
        )
        # Dividing row i by 4|q_i| gives sign(q_i) * q; the 0.1 floor only affects
        # rows that are not selected below.
        candidates = candidates / (2.0 * two_abs.clamp(min=0.1)).unsqueeze(-1)

        best = two_abs.argmax(dim=-1)
        index = best[..., None, None].expand(*best.shape, 1, 4)
        q = candidates.gather(-2, index).squeeze(-2)

        # Canonicalize to a non-negative real part.
        q = torch.where(q[..., :1] < 0, -q, q)
        return Quaternion(q)

    @staticmethod
    def from_axis_angle(axis: torch.Tensor, angle: torch.Tensor):
        """Build quaternions from rotation axes (..., 3) and angles (...).

        The axis is normalized here; a zero-length axis yields NaN components.

        Raises:
            ValueError: If the shapes of ``axis`` and ``angle`` are incompatible.
        """
        if axis.shape[-1] != 3:
            raise ValueError("Axis must have shape (..., 3)")
        if axis.dim() != angle.dim() + 1:
            raise ValueError(
                "Axis (..., 3) and angle (...) must have the same number of leading dimensions"
            )
        if axis.shape[:-1] != angle.shape:
            raise ValueError("Axis and angle must have compatible shapes")

        half_angle = angle / 2.0
        sin_half_angle = torch.sin(half_angle)
        cos_half_angle = torch.cos(half_angle)

        axis = axis / torch.norm(axis, dim=-1, keepdim=True)

        w = cos_half_angle
        x = axis[..., 0] * sin_half_angle
        y = axis[..., 1] * sin_half_angle
        z = axis[..., 2] * sin_half_angle

        q = torch.stack([w, x, y, z], dim=-1)
        return Quaternion(q)

    def to_axis_angle(self):
        """Return ``(axis, angle)`` with axis of shape (..., 3) and angle of shape (...).

        ``angle = 2 * atan2(|v|, w)`` and ``axis = v / |v|`` with ``v = (x, y, z)``;
        ``|v|`` is floored at 1e-8, so a pure-real quaternion gives a zero axis.
        """
        w, x, y, z = self.to_wxyz()
        v_norm = torch.sqrt(x**2 + y**2 + z**2)
        # atan2 stays finite when w drifts slightly past +/-1 (where acos is NaN)
        # and is also correct for non-unit quaternions.
        angle = 2 * torch.atan2(v_norm, w)
        s = torch.clamp(v_norm, min=1e-8)
        axis = torch.stack([x / s, y / s, z / s], dim=-1)
        return axis, angle

    def rotate_vector(self, v: torch.Tensor):
        """Rotate vectors ``v`` of shape (..., 3) by computing q * (0, v) * conj(q).

        Returns a plain tensor of shape (..., 3).

        Raises:
            ValueError: If ``v`` does not have shape (..., 3).
        """
        if v.shape[-1] != 3:
            raise ValueError("Input vector must have shape (..., 3)")

        v_quat = Quaternion(
            torch.zeros_like(v[..., 0]), v[..., 0], v[..., 1], v[..., 2]
        )

        rotated_v_quat = self * v_quat * self.conjugate()
        # Extract components as plain tensors: a Quaternion with a last dimension
        # of 3 would fail its own shape checks in later operations.
        _, x, y, z = rotated_v_quat.to_wxyz()
        return torch.stack([x, y, z], dim=-1)

    def slerp(self, other: "Quaternion", t: float | torch.Tensor):
        """Spherical linear interpolation ``self * (self^-1 * other) ** t``.

        ``other`` is interpolated toward as given; no sign flip is applied to take the
        shorter path between rotations.

        Raises:
            ValueError: If the quaternions have different shapes.
        """
        CHECK_OPERAND_SHAPE(other, scalar_allowed=False)
        if self.shape != other.shape:
            raise ValueError("Quaternions must have the same shape for slerp.")
        if isinstance(t, (int, float)):
            t = torch.tensor(t, device=self.device, dtype=self.dtype)

        return self * (self.inverse() * other) ** t

    @implements(torch.log)
    def log(self):
        """Quaternion logarithm ``(ln|q|, v / |v| * atan2(|v|, w))`` with ``v = (x, y, z)``.

        ``|v|`` is floored at 1e-8 in the vector part, so a pure-real quaternion
        gets a zero vector part.
        """
        w, x, y, z = self.to_wxyz()
        v_norm_sq = x**2 + y**2 + z**2
        # Floor before the sqrt (|v| >= 1e-8) so pure-real quaternions also get a
        # finite gradient.
        v_norm = torch.sqrt(torch.clamp(v_norm_sq, min=1e-16))
        theta = torch.atan2(v_norm, w)
        coeff = theta / v_norm
        # The real part is ln|q|; it is zero only for unit quaternions.
        real = 0.5 * torch.log(w**2 + v_norm_sq)
        return Quaternion(
            torch.stack([real, x * coeff, y * coeff, z * coeff], dim=-1)
        )

    @implements(torch.exp)
    def exp(self):
        """Quaternion exponential ``e^w * (cos|v|, v / |v| * sin|v|)`` with ``v = (x, y, z)``.

        The factor ``sin|v| / |v|`` is taken as 1 when ``|v| < 1e-8``.
        """
        w, x, y, z = self.to_wxyz()
        v_norm_sq = x**2 + y**2 + z**2
        # Flooring |v| before the sqrt keeps sin(|v|)/|v| and the sqrt gradient
        # finite for pure-real quaternions.
        v_norm = torch.sqrt(torch.clamp(v_norm_sq, min=1e-16))
        exp_w = torch.exp(w)
        cos_v_norm = torch.cos(v_norm)
        coeff = torch.where(
            v_norm_sq < 1e-16,
            torch.ones_like(v_norm),
            torch.sin(v_norm) / v_norm,
        )
        return Quaternion(
            torch.stack(
                [
                    exp_w * cos_v_norm,
                    exp_w * x * coeff,
                    exp_w * y * coeff,
                    exp_w * z * coeff,
                ],
                dim=-1,
            )
        )

    @implements(torch.pow)
    def pow(self, exponent: float | torch.Tensor):
        """Real power ``exp(exponent * log(q))``.

        A tensor exponent is combined with ``log(q)`` through ``*``, so a
        per-quaternion exponent needs a trailing dimension of size 1.
        """
        if isinstance(exponent, (int, float)):
            exponent = torch.tensor(exponent, device=self.device, dtype=self.dtype)
        if exponent.dim() == 0:
            exponent = exponent.unsqueeze(0)

        log_q = self.log()
        scaled_log_q = Quaternion(log_q * exponent)
        return scaled_log_q.exp()
|