quatorch-0.1.0a0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,140 @@
1
+ Metadata-Version: 2.4
2
+ Name: quatorch
3
+ Version: 0.1.0a0
4
+ Summary: Quaternion operations in pure PyTorch
5
+ Keywords: quaternion,rotation,3d,orientation,pytorch
6
+ Author: Lucas N. Egidio
7
+ Author-email: Lucas N. Egidio <lucasegidio1@gmail.com>
8
+ License-Expression: MIT
9
+ Classifier: Development Status :: 4 - Beta
10
+ Classifier: Intended Audience :: Science/Research
11
+ Classifier: Topic :: Scientific/Engineering :: Mathematics
12
+ Classifier: Topic :: Scientific/Engineering :: Physics
13
+ Classifier: Programming Language :: Python :: 3
14
+ Requires-Dist: torch>=2.8.0
15
+ Requires-Dist: torch ; extra == 'cpu'
16
+ Maintainer: Lucas N. Egidio
17
+ Maintainer-email: Lucas N. Egidio <lucasegidio1@gmail.com>
18
+ Requires-Python: >=3.10
19
+ Project-URL: Documentation, https://github.com/egidioln/QuaTorch/
20
+ Project-URL: Homepage, https://github.com/egidioln/QuaTorch/
21
+ Project-URL: Issues, https://github.com/egidioln/QuaTorch/issues
22
+ Project-URL: Repository, https://github.com/egidioln/QuaTorch.git
23
+ Provides-Extra: cpu
24
+ Description-Content-Type: text/markdown
25
+
26
+
27
+ <img src="https://raw.githubusercontent.com/egidioln/QuaTorch/refs/heads/main/docs/source/_static/logo.svg" width="400px">
28
+
29
+ ***Quaternions in PyTorch***
30
+ # QuaTorch
31
+ <!--[![Release](https://img.shields.io/github/v/release/actions/deploy-pages?label=Release&logo=github)](https://github.com/actions/deploy-pages/releases/latest) -->
32
+ [![cov](https://raw.githubusercontent.com/egidioln/QuaTorch/refs/heads/gh-pages/badges/coverage.svg)](https://github.com/egidioln/quatorch/actions)
33
+ [![Tests](https://img.shields.io/github/actions/workflow/status/egidioln/QuaTorch/pytest.yml?label=Tests&logo=github)](https://github.com/egidioln/QuaTorch/actions/workflows/pytest.yml)
34
+ [![Docs](https://img.shields.io/github/actions/workflow/status/egidioln/QuaTorch/docs.yml?label=Docs&logo=github)](https://egidioln.github.io/QuaTorch/)
35
+
36
+ The package `quatorch` provides `Quaternion`, a `torch.Tensor` subclass that represents a [Quaternion](https://en.wikipedia.org/wiki/Quaternion). It implements common quaternion-algebra operations such as multiplication,
37
+ conjugation, inversion, normalization, log, and exp. It also supports conversion to and from rotation matrices and the axis-angle representation. Convenience utilities are included as well, such as spherical linear interpolation ([slerp](https://en.wikipedia.org/wiki/Slerp)) and 3D vector rotation.
38
+
39
+ ## Highlights
40
+
41
+ - Quaternion type: `quatorch.Quaternion` (subclass of `torch.Tensor`).
42
+ - Element-wise and algebraic ops implemented: `+`, `-`, `*` (quaternion product and scalar mul),
43
+ `abs` (norm), `conjugate`, `inverse`, `normalize`, `to_rotation_matrix`, and more.
44
+ - Utilities: `from_rotation_matrix`, `from_axis_angle`, `to_axis_angle`, `rotate_vector`, `slerp`,
45
+ `log`, `exp`, and `pow`.
46
+
47
+ ## Installation
48
+
49
+ This project targets Python 3.10+ and requires PyTorch 2.8.0 or newer. Install via pip (recommended):
50
+
51
+ ```bash
52
+ pip install quatorch
53
+ ```
54
+
55
+ Or install in editable (development) mode:
56
+
57
+ ```bash
58
+ git clone https://github.com/egidioln/QuaTorch.git
59
+ cd QuaTorch
60
+ pip install -e .
61
+ ```
62
+
63
+ See `pyproject.toml` for dependency details (dev deps include `pytest`).
64
+
65
+ ## Quick start
66
+
67
+ Basic usage examples using PyTorch tensors and `Quaternion`:
68
+
69
+ ```python
70
+ import torch
71
+ from quatorch.quaternion import Quaternion
72
+
73
+ # Create a quaternion from four scalars (W, X, Y, Z)
74
+ q = Quaternion(1.0, 0.0, 0.0, 0.0)
75
+
76
+ # Or from a tensor of shape (..., 4)
77
+ q2 = Quaternion(torch.tensor([0.9239, 0.3827, 0.0, 0.0])) # 45° around X
78
+
79
+ # Normalize
80
+ q2 = q2.normalize()
81
+
82
+ # Quaternion multiplication (rotation composition)
83
+ q3 = q * q2
84
+
85
+ # Rotate a vector
86
+ v = torch.tensor([1.0, 0.0, 0.0])
87
+ v_rot = q2.rotate_vector(v)
88
+
89
+ # Convert to rotation matrix
90
+ R = q2.to_rotation_matrix()
91
+
92
+ # Slerp between quaternions
93
+ t = 0.5
94
+ q_mid = q.slerp(q2, t)
95
+ ```
96
+
97
+ ## API notes
98
+
99
+ - Construction:
100
+ - `Quaternion(data: torch.Tensor)` where `data.shape[-1] == 4`.
101
+ - `Quaternion(w, x, y, z)` accepts scalars or tensors broadcastable to the same shape.
102
+
103
+ - Shape requirements:
104
+ - The last dimension must be size 4 for quaternion tensors (W, X, Y, Z).
105
+
106
+ - Interoperability:
107
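+ - The class overrides `torch.add`, `torch.sub`, `torch.mul`, `torch.abs`, `torch.log`, `torch.exp`, and `torch.pow` via `__torch_function__`, so
108
+ these apply quaternion semantics to `Quaternion` objects; other `torch.*` functions fall back to the default `torch.Tensor` behavior.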
109
+
110
+ See the source in `src/quatorch/quaternion.py` for the full implementation and more
111
+ helper methods.
112
+
113
+ ## Running tests
114
+
115
+ This repository includes unit tests using `pytest` under `test/unit_tests`.
116
+
117
+ From the project root, run:
118
+
119
+ ```bash
120
+ uv run --with=. pytest
121
+ ```
122
+
123
+ ## Contributing
124
+
125
+ Contributions are welcome. A few ideas:
126
+
127
+ - Add more conversions and higher-level utilities (e.g., batch rotation helpers).
128
+ - Improve numeric stability and add property-based tests.
129
+ - Add docs and usage notebooks / examples.
130
+
131
+ Please open issues or pull requests on the repository.
132
+
133
+ ## License
134
+
135
+ MIT. See [LICENSE.md](./LICENSE.md) for the full license text.
136
+
137
+ ## Contact
138
+
139
+ Maintainer: Lucas N. Egidio <lucasegidio1@gmail.com>
140
+
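The quick start composes rotations with `*` and rotates vectors with `rotate_vector`, which embeds v as the pure quaternion (0, v) and computes q v q*. As a dependency-free sketch of that arithmetic (not quatorch's API), the plain-Python version below uses tuples in place of tensors. The helper names `hamilton`, `conjugate`, and `rotate` are hypothetical; the product formula is the same one `Quaternion.mul` uses. Like `rotate_vector`, it assumes q is a unit quaternion.

```python
import math

Quat = tuple[float, float, float, float]  # (w, x, y, z), same order as quatorch


def hamilton(p: Quat, q: Quat) -> Quat:
    """Hamilton product p * q (same component formula as Quaternion.mul)."""
    w1, x1, y1, z1 = p
    w2, x2, y2, z2 = q
    return (
        w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
        w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
        w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,
        w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2,
    )


def conjugate(q: Quat) -> Quat:
    w, x, y, z = q
    return (w, -x, -y, -z)


def rotate(q: Quat, v: tuple[float, float, float]) -> tuple[float, float, float]:
    """Rotate v by the unit quaternion q via q * (0, v) * conj(q)."""
    _, x, y, z = hamilton(hamilton(q, (0.0, *v)), conjugate(q))
    return (x, y, z)


# 90 degrees about Z: (cos 45°, 0, 0, sin 45°)
half = math.pi / 4
q_z90 = (math.cos(half), 0.0, 0.0, math.sin(half))
print(rotate(q_z90, (1.0, 0.0, 0.0)))  # approximately (0.0, 1.0, 0.0)
```

Because the product is not commutative (i·j = k but j·i = -k), a composed quaternion `q * q2` applies `q2` first and then `q` when it is used to rotate vectors.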
@@ -0,0 +1,115 @@
1
+
2
+ <img src="https://raw.githubusercontent.com/egidioln/QuaTorch/refs/heads/main/docs/source/_static/logo.svg" width="400px">
3
+
4
+ ***Quaternions in PyTorch***
5
+ # QuaTorch
6
+ <!--[![Release](https://img.shields.io/github/v/release/actions/deploy-pages?label=Release&logo=github)](https://github.com/actions/deploy-pages/releases/latest) -->
7
+ [![cov](https://raw.githubusercontent.com/egidioln/QuaTorch/refs/heads/gh-pages/badges/coverage.svg)](https://github.com/egidioln/quatorch/actions)
8
+ [![Tests](https://img.shields.io/github/actions/workflow/status/egidioln/QuaTorch/pytest.yml?label=Tests&logo=github)](https://github.com/egidioln/QuaTorch/actions/workflows/pytest.yml)
9
+ [![Docs](https://img.shields.io/github/actions/workflow/status/egidioln/QuaTorch/docs.yml?label=Docs&logo=github)](https://egidioln.github.io/QuaTorch/)
10
+
11
+ The package `quatorch` provides `Quaternion`, a `torch.Tensor` subclass that represents a [Quaternion](https://en.wikipedia.org/wiki/Quaternion). It implements common quaternion-algebra operations such as multiplication,
12
+ conjugation, inversion, normalization, log, and exp. It also supports conversion to and from rotation matrices and the axis-angle representation. Convenience utilities are included as well, such as spherical linear interpolation ([slerp](https://en.wikipedia.org/wiki/Slerp)) and 3D vector rotation.
13
+
14
+ ## Highlights
15
+
16
+ - Quaternion type: `quatorch.Quaternion` (subclass of `torch.Tensor`).
17
+ - Element-wise and algebraic ops implemented: `+`, `-`, `*` (quaternion product and scalar mul),
18
+ `abs` (norm), `conjugate`, `inverse`, `normalize`, `to_rotation_matrix`, and more.
19
+ - Utilities: `from_rotation_matrix`, `from_axis_angle`, `to_axis_angle`, `rotate_vector`, `slerp`,
20
+ `log`, `exp`, and `pow`.
21
+
22
+ ## Installation
23
+
24
+ This project targets Python 3.10+ and requires PyTorch 2.8.0 or newer. Install via pip (recommended):
25
+
26
+ ```bash
27
+ pip install quatorch
28
+ ```
29
+
30
+ Or install in editable (development) mode:
31
+
32
+ ```bash
33
+ git clone https://github.com/egidioln/QuaTorch.git
34
+ cd QuaTorch
35
+ pip install -e .
36
+ ```
37
+
38
+ See `pyproject.toml` for dependency details (dev deps include `pytest`).
39
+
40
+ ## Quick start
41
+
42
+ Basic usage examples using PyTorch tensors and `Quaternion`:
43
+
44
+ ```python
45
+ import torch
46
+ from quatorch.quaternion import Quaternion
47
+
48
+ # Create a quaternion from four scalars (W, X, Y, Z)
49
+ q = Quaternion(1.0, 0.0, 0.0, 0.0)
50
+
51
+ # Or from a tensor of shape (..., 4)
52
+ q2 = Quaternion(torch.tensor([0.9239, 0.3827, 0.0, 0.0])) # 45° around X
53
+
54
+ # Normalize
55
+ q2 = q2.normalize()
56
+
57
+ # Quaternion multiplication (rotation composition)
58
+ q3 = q * q2
59
+
60
+ # Rotate a vector
61
+ v = torch.tensor([1.0, 0.0, 0.0])
62
+ v_rot = q2.rotate_vector(v)
63
+
64
+ # Convert to rotation matrix
65
+ R = q2.to_rotation_matrix()
66
+
67
+ # Slerp between quaternions
68
+ t = 0.5
69
+ q_mid = q.slerp(q2, t)
70
+ ```
71
+
72
+ ## API notes
73
+
74
+ - Construction:
75
+ - `Quaternion(data: torch.Tensor)` where `data.shape[-1] == 4`.
76
+ - `Quaternion(w, x, y, z)` accepts scalars or tensors broadcastable to the same shape.
77
+
78
+ - Shape requirements:
79
+ - The last dimension must be size 4 for quaternion tensors (W, X, Y, Z).
80
+
81
+ - Interoperability:
82
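+ - The class overrides `torch.add`, `torch.sub`, `torch.mul`, `torch.abs`, `torch.log`, `torch.exp`, and `torch.pow` via `__torch_function__`, so
83
+ these apply quaternion semantics to `Quaternion` objects; other `torch.*` functions fall back to the default `torch.Tensor` behavior.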
84
+
85
+ See the source in `src/quatorch/quaternion.py` for the full implementation and more
86
+ helper methods.
87
+
88
+ ## Running tests
89
+
90
+ This repository includes unit tests using `pytest` under `test/unit_tests`.
91
+
92
+ From the project root, run:
93
+
94
+ ```bash
95
+ uv run --with=. pytest
96
+ ```
97
+
98
+ ## Contributing
99
+
100
+ Contributions are welcome. A few ideas:
101
+
102
+ - Add more conversions and higher-level utilities (e.g., batch rotation helpers).
103
+ - Improve numeric stability and add property-based tests.
104
+ - Add docs and usage notebooks / examples.
105
+
106
+ Please open issues or pull requests on the repository.
107
+
108
+ ## License
109
+
110
+ MIT. See [LICENSE.md](./LICENSE.md) for the full license text.
111
+
112
+ ## Contact
113
+
114
+ Maintainer: Lucas N. Egidio <lucasegidio1@gmail.com>
115
+
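`slerp` in `quaternion.py` evaluates `q0 * (q0.inverse() * q1) ** t`, and `pow` computes the power as exp(t · log q). The plain-Python sketch below mirrors that chain on (w, x, y, z) tuples so each step can be checked in isolation. The helper names are hypothetical, and `qlog` uses the general quaternion logarithm ln q = (ln‖q‖, v/‖v‖ · atan2(‖v‖, w)), which includes the ln‖q‖ real part; for unit quaternions that term is zero.

```python
import math

Quat = tuple[float, float, float, float]  # (w, x, y, z)


def mul(p: Quat, q: Quat) -> Quat:
    """Hamilton product p * q."""
    w1, x1, y1, z1 = p
    w2, x2, y2, z2 = q
    return (
        w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
        w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
        w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,
        w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2,
    )


def inverse(q: Quat) -> Quat:
    """conj(q) / |q|^2."""
    w, x, y, z = q
    n2 = w * w + x * x + y * y + z * z
    return (w / n2, -x / n2, -y / n2, -z / n2)


def qlog(q: Quat) -> Quat:
    """General log: (ln|q|, v/|v| * atan2(|v|, w))."""
    w, x, y, z = q
    v_norm = math.sqrt(x * x + y * y + z * z)
    q_norm = math.sqrt(w * w + v_norm * v_norm)
    if v_norm < 1e-12:
        # Real quaternion: vector part taken as zero (ambiguous when w < 0).
        return (math.log(q_norm), 0.0, 0.0, 0.0)
    coeff = math.atan2(v_norm, w) / v_norm
    return (math.log(q_norm), x * coeff, y * coeff, z * coeff)


def qexp(q: Quat) -> Quat:
    """exp(w + v) = e^w * (cos|v|, v/|v| * sin|v|)."""
    w, x, y, z = q
    v_norm = math.sqrt(x * x + y * y + z * z)
    ew = math.exp(w)
    coeff = 1.0 if v_norm < 1e-12 else math.sin(v_norm) / v_norm
    return (ew * math.cos(v_norm), ew * x * coeff, ew * y * coeff, ew * z * coeff)


def qpow(q: Quat, t: float) -> Quat:
    """q ** t computed as exp(t * log q)."""
    return qexp(tuple(t * c for c in qlog(q)))


def slerp(q0: Quat, q1: Quat, t: float) -> Quat:
    """q0 * (q0^-1 * q1) ** t, the same expression as Quaternion.slerp."""
    return mul(q0, qpow(mul(inverse(q0), q1), t))


q_z90 = (math.cos(math.pi / 4), 0.0, 0.0, math.sin(math.pi / 4))
print(slerp((1.0, 0.0, 0.0, 0.0), q_z90, 0.5))  # close to a 45° rotation about Z
```

As in the library code, no shortest-arc adjustment is made. Since q1 and -q1 encode the same rotation, a common convention is to negate q1 when the 4D dot product q0 · q1 is negative, so that the interpolation follows the shorter arc.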
@@ -0,0 +1,82 @@
1
+ [project]
2
+ name = "quatorch"
3
+ version = "0.1.0-alpha"
4
+ description = "Quaternion operations in pure PyTorch"
5
+ keywords = ["quaternion", "rotation", "3d", "orientation", "pytorch"]
6
+ readme = "README.md"
7
+ license = "MIT"
8
+ requires-python = ">=3.10"
9
+ dependencies = [
10
+ "torch>=2.8.0",
11
+ ]
12
+
13
+ authors = [
14
+ {name = "Lucas N. Egidio", email = "lucasegidio1@gmail.com"},
15
+ ]
16
+ maintainers = [
17
+ {name = "Lucas N. Egidio", email = "lucasegidio1@gmail.com"},
18
+ ]
19
+
20
+ classifiers = [
21
+ # TODO change to 5 - Production/Stable
22
+ "Development Status :: 4 - Beta",
23
+ "Intended Audience :: Science/Research",
24
+
25
+ # TODO verify "License :: OSI Approved :: MIT License",
26
+ "Topic :: Scientific/Engineering :: Mathematics",
27
+ "Topic :: Scientific/Engineering :: Physics",
28
+
29
+ # Specify the Python versions you support here.
30
+ "Programming Language :: Python :: 3",
31
+ ]
32
+
33
+
34
+ [project.optional-dependencies]
35
+ cpu = [
36
+ "torch"
37
+ ]
38
+
39
+ [tool.uv.sources]
40
+ torch = [
41
+ { index = "pytorch-cpu", extra = "cpu" },
42
+ ]
43
+
44
+ [[tool.uv.index]]
45
+ name = "pytorch-cpu"
46
+ url = "https://download.pytorch.org/whl/cpu"
47
+ explicit = true
48
+
49
+ [project.urls]
50
+ Homepage = "https://github.com/egidioln/QuaTorch/"
51
+ Documentation = "https://github.com/egidioln/QuaTorch/"
52
+ Repository = "https://github.com/egidioln/QuaTorch.git"
53
+ Issues = "https://github.com/egidioln/QuaTorch/issues"
54
+ # Changelog = "https://github.com/me/spam/blob/master/CHANGELOG.md"
55
+
56
+
57
+ [dependency-groups]
58
+ dev = [
59
+ "furo>=2025.7.19",
60
+ "myst-parser>=4.0.1",
61
+ "pytest>=8.4.2",
62
+ "pytest-cov>=7.0.0",
63
+ "ruff>=0.13.0",
64
+ "sphinx>=8.1.3",
65
+ ]
66
+
67
+
68
+ [build-system]
69
+ requires = ["uv_build >= 0.8.17, <0.9.0"]
70
+ build-backend = "uv_build"
71
+
72
+ [tool.pytest.ini_options]
73
+ testpaths = ["test"]
74
+ # Optional: set python_files pattern to include test files under unit_tests
75
+ python_files = ["test_*.py"]
76
+ addopts = "--cov=quatorch --cov-report xml:coverage.xml --cov-report html:htmlcov --cov-report json"
77
+
78
+ [tool.coverage.report]
79
+ exclude_also = [
80
+ 'raise ValueError',
81
+ 'raise TypeError'
82
+ ]
@@ -0,0 +1,5 @@
1
+ __version__ = "0.1.0a0"
2
+ from .quaternion import Quaternion
3
+
4
+
5
+ __all__ = ["Quaternion"]
@@ -0,0 +1,309 @@
1
+ import functools
2
+ from typing import Any, overload
3
+
4
+ import torch
5
+
6
+ HANDLED_FUNCTIONS = {}
7
+
8
+
9
+ def CHECK_OPERAND_SHAPE(other: Any, scalar_allowed: bool = True):
10
+ if torch.is_tensor(other) and other.shape[-1] not in (1, 4):
11
+ raise ValueError(
12
+ "The last dimension must be of size 4 to represent a quaternion (WXYZ) or size 1 to represent a real scalar."
13
+ )
14
+ if torch.is_tensor(other) and other.dtype in [torch.complex64, torch.complex128]:
15
+ raise TypeError("Cannot operate between quaternion and complex tensors.")
16
+ if not scalar_allowed and isinstance(other, (int, float, complex)):
17
+ raise TypeError("Operand must not be a scalar.")
18
+
19
+
20
+ def implements(torch_function):
21
+ """Register a torch function override for Quaternion."""
22
+
23
+ @functools.wraps(torch_function)
24
+ def decorator(func):
25
+ HANDLED_FUNCTIONS[torch_function] = func
26
+ return func
27
+
28
+ return decorator
29
+
30
+
31
+ class Quaternion(torch.Tensor):
32
+ r"""A `torch.Tensor` subclass representing quaternions.
33
+ Quaternions are represented as tensors of shape (..., 4), where the last dimension
34
+ corresponds to the components (W, X, Y, Z).
35
+ Args:
36
+
37
+ w: The real part of the quaternion.
38
+ x: The first imaginary part of the quaternion.
39
+ y: The second imaginary part of the quaternion.
40
+ z: The third imaginary part of the quaternion.
41
+ or
42
+ data: A tensor of shape (..., 4) representing the quaternion components
43
+ in the order (W, X, Y, Z).
44
+
45
+ """
46
+
47
+ def __new__(cls, *args, **kwargs):
48
+ if "data" in kwargs or (len(args) == 1 and isinstance(args[0], torch.Tensor)):
49
+ return super().__new__(cls, *args, **kwargs)
50
+ if len(args) == 4:
51
+ tensors = tuple(torch.as_tensor(arg) for arg in args)
52
+ try:
53
+ data = torch.stack(torch.broadcast_tensors(*tensors), dim=-1)
54
+ except RuntimeError as e:
55
+ raise ValueError("All inputs must be broadcastable to a common shape.") from e
56
+ return super().__new__(cls, data)
57
+ if all(_ in kwargs for _ in "wxyz"):
58
+ return super().__new__(
59
+ cls,
60
+ torch.stack(
61
+ torch.broadcast_tensors(*(torch.as_tensor(kwargs[k]) for k in "wxyz")), dim=-1
62
+ ),
63
+ )
64
+ raise ValueError("Invalid arguments for Quaternion initialization.")
65
+
66
+ @overload
67
+ def __init__(self, data: torch.Tensor, *args, **kwargs): ...
68
+
69
+ @overload
70
+ def __init__(
71
+ self,
72
+ w: float | torch.Tensor,
73
+ x: float | torch.Tensor,
74
+ y: float | torch.Tensor,
75
+ z: float | torch.Tensor,
76
+ **kwargs,
77
+ ): ...
78
+
79
+ def __init__(self, *args, **kwargs):
80
+ super().__init__()
81
+ if self.shape[-1] != 4:
82
+ raise ValueError(
83
+ "The last dimension must be of size 4 to represent a quaternion (WXYZ)."
84
+ )
85
+
86
+ def to_wxyz(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
87
+ return self.as_subclass(torch.Tensor).unbind(-1)
88
+
89
+ @classmethod
90
+ def __torch_function__(cls, func, types, args=(), kwargs=None):
91
+ if kwargs is None:
92
+ kwargs = {}
93
+ if func in HANDLED_FUNCTIONS:
94
+ return HANDLED_FUNCTIONS[func](*args, **kwargs)
95
+ return super().__torch_function__(func, types, args, kwargs)
96
+
97
+ @implements(torch.add)
98
+ def add(self, other):
99
+ CHECK_OPERAND_SHAPE(other, scalar_allowed=False)
100
+
101
+ return Quaternion(super().add(other))
102
+
103
+ @implements(torch.mul)
104
+ def mul(self, other):
105
+ CHECK_OPERAND_SHAPE(other, scalar_allowed=True)
106
+
107
+ if isinstance(other, (int, float)):
108
+ return Quaternion(super().mul(other))
109
+
110
+ w1, x1, y1, z1 = self.to_wxyz()
111
+ if other.shape[-1] == 4:
112
+ w2, x2, y2, z2 = Quaternion(other).to_wxyz()
113
+ else: # scalar
114
+ w2 = other[..., 0]
115
+ x2 = 0
116
+ y2 = 0
117
+ z2 = 0
118
+
119
+ w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2
120
+ x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2
121
+ y = w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2
122
+ z = w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2
123
+
124
+ return Quaternion(torch.stack([w, x, y, z], dim=-1))
125
+
126
+ @implements(torch.sub)
127
+ def sub(self, other):
128
+ CHECK_OPERAND_SHAPE(other, scalar_allowed=False)
129
+
130
+ return Quaternion(super().sub(other))
131
+
132
+ @implements(torch.abs)
133
+ def abs(self):
134
+ w, x, y, z = self.to_wxyz()
135
+ return torch.sqrt(w**2 + x**2 + y**2 + z**2)
136
+
137
+ def __add__(self, other):
138
+ return self.add(other)
139
+
140
+ def __sub__(self, other):
141
+ return self.sub(other)
142
+
143
+ def __mul__(self, other):
144
+ return self.mul(other)
145
+
146
+ def __pow__(self, other):
147
+ return self.pow(other)
148
+
149
+ def __truediv__(self, other):
150
+ if isinstance(other, Quaternion):
151
+ return self.mul(other.inverse())
152
+ return self.div(other)
153
+
154
+ def conjugate(self):
155
+ w, x, y, z = self.to_wxyz()
156
+ return Quaternion(torch.stack([w, -x, -y, -z], dim=-1))
157
+
158
+ def inverse(self):
159
+ norm_sq = self.abs() ** 2
160
+ conj = self.conjugate()
161
+ return Quaternion(conj / norm_sq.unsqueeze(-1))
162
+
163
+ def normalize(self):
164
+ norm = self.abs()
165
+ return Quaternion(self / norm.unsqueeze(-1))
166
+
167
+ def to_rotation_matrix(self):
168
+ w, x, y, z = self.to_wxyz()
169
+ leading_dims = self.shape[:-1]
170
+ rotation_matrix = torch.empty(
171
+ *leading_dims, 3, 3, device=self.device, dtype=self.dtype
172
+ )
173
+
174
+ rotation_matrix[..., 0, 0] = 1 - 2 * (y**2 + z**2)
175
+ rotation_matrix[..., 0, 1] = 2 * (x * y - z * w)
176
+ rotation_matrix[..., 0, 2] = 2 * (x * z + y * w)
177
+
178
+ rotation_matrix[..., 1, 0] = 2 * (x * y + z * w)
179
+ rotation_matrix[..., 1, 1] = 1 - 2 * (x**2 + z**2)
180
+ rotation_matrix[..., 1, 2] = 2 * (y * z - x * w)
181
+
182
+ rotation_matrix[..., 2, 0] = 2 * (x * z - y * w)
183
+ rotation_matrix[..., 2, 1] = 2 * (y * z + x * w)
184
+ rotation_matrix[..., 2, 2] = 1 - 2 * (x**2 + y**2)
185
+
186
+ return rotation_matrix
187
+
188
+ @staticmethod
189
+ def from_rotation_matrix(R: torch.Tensor):
190
+ if R.shape[-2:] != (3, 3):
191
+ raise ValueError("Input rotation matrix must have shape (..., 3, 3)")
192
+ B = R.shape[:-2]
193
+ R = R.reshape(-1, 3, 3)
194
+
195
+ trace = R[..., 0, 0] + R[..., 1, 1] + R[..., 2, 2]
196
+ w = torch.sqrt(1.0 + trace) / 2.0  # assumes trace > -1 (angle below 180°); unstable as w -> 0
197
+ x = (R[..., 2, 1] - R[..., 1, 2]) / (4.0 * w)
198
+ y = (R[..., 0, 2] - R[..., 2, 0]) / (4.0 * w)
199
+ z = (R[..., 1, 0] - R[..., 0, 1]) / (4.0 * w)
200
+
201
+ q = torch.stack([w, x, y, z], dim=-1)
202
+ q = q.reshape(*B, 4)
203
+ return Quaternion(q)
204
+
205
+ @staticmethod
206
+ def from_axis_angle(axis: torch.Tensor, angle: torch.Tensor):
207
+ if axis.shape[-1] != 3:
208
+ raise ValueError("Axis must have shape (..., 3)")
209
+ if axis.dim() != angle.dim() + 1:
210
+ raise ValueError(
211
+ "Axis (..., 3) and angle (...) must have the same number of leading dimensions"
212
+ )
213
+ if axis.shape[:-1] != angle.shape:
214
+ raise ValueError("Axis and angle must have compatible shapes")
215
+
216
+ half_angle = angle / 2.0
217
+ sin_half_angle = torch.sin(half_angle)
218
+ cos_half_angle = torch.cos(half_angle)
219
+
220
+ axis = axis / torch.norm(axis, dim=-1, keepdim=True)
221
+
222
+ w = cos_half_angle
223
+ x = axis[..., 0] * sin_half_angle
224
+ y = axis[..., 1] * sin_half_angle
225
+ z = axis[..., 2] * sin_half_angle
226
+
227
+ q = torch.stack([w, x, y, z], dim=-1)
228
+ return Quaternion(q)
229
+
230
+ def to_axis_angle(self):
231
+ w, x, y, z = self.to_wxyz()
232
+ angle = 2 * torch.acos(w.clamp(-1.0, 1.0))  # clamp guards against rounding past ±1
233
+ s = torch.sqrt((1 - w**2).clamp(min=0.0))
234
+ s = torch.where(s < 1e-8, torch.tensor(1e-8, device=s.device, dtype=s.dtype), s)
235
+ axis = torch.stack([x / s, y / s, z / s], dim=-1)
236
+ return axis, angle
237
+
238
+ def rotate_vector(self, v: torch.Tensor):
239
+ if v.shape[-1] != 3:
240
+ raise ValueError("Input vector must have shape (..., 3)")
241
+
242
+ v_quat = Quaternion(
243
+ torch.zeros_like(v[..., 0]), v[..., 0], v[..., 1], v[..., 2]
244
+ )
245
+
246
+ rotated_v_quat = (self * v_quat * self.conjugate()).as_subclass(torch.Tensor)
247
+ return torch.stack(
248
+ [rotated_v_quat[..., 1], rotated_v_quat[..., 2], rotated_v_quat[..., 3]],
249
+ dim=-1,
250
+ )
251
+
252
+ def slerp(self, other: "Quaternion", t: float | torch.Tensor):
253
+ CHECK_OPERAND_SHAPE(other, scalar_allowed=False)
254
+ if self.shape != other.shape:
255
+ raise ValueError("Quaternions must have the same shape for slerp.")
256
+ if isinstance(t, (int, float)):
257
+ t = torch.tensor(t, device=self.device, dtype=self.dtype)
258
+
259
+ return self * (self.inverse() * other) ** t
260
+
261
+ @implements(torch.log)
262
+ def log(self):
263
+ w, x, y, z = self.to_wxyz()
264
+ v_norm = torch.sqrt(x**2 + y**2 + z**2)
265
+ v_norm = torch.where(
266
+ v_norm < 1e-8,
267
+ torch.tensor(1e-8, device=v_norm.device, dtype=v_norm.dtype),
268
+ v_norm,
269
+ )
270
+ theta = torch.atan2(v_norm, w)
271
+ coeff = theta / v_norm
272
+ return Quaternion(
273
+ torch.stack([torch.log(self.abs()), x * coeff, y * coeff, z * coeff], dim=-1)
274
+ )
275
+
276
+ @implements(torch.exp)
277
+ def exp(self):
278
+ w, x, y, z = self.to_wxyz()
279
+ v_norm = torch.sqrt(x**2 + y**2 + z**2)
280
+ exp_w = torch.exp(w)
281
+ cos_v_norm = torch.cos(v_norm)
282
+ sin_v_norm = torch.sin(v_norm)
283
+ coeff = torch.where(
284
+ v_norm < 1e-8,
285
+ torch.tensor(1.0, device=v_norm.device, dtype=v_norm.dtype),
286
+ sin_v_norm / v_norm,
287
+ )
288
+ return Quaternion(
289
+ torch.stack(
290
+ [
291
+ exp_w * cos_v_norm,
292
+ exp_w * x * coeff,
293
+ exp_w * y * coeff,
294
+ exp_w * z * coeff,
295
+ ],
296
+ dim=-1,
297
+ )
298
+ )
299
+
300
+ @implements(torch.pow)
301
+ def pow(self, exponent: float | torch.Tensor):
302
+ if isinstance(exponent, (int, float)):
303
+ exponent = torch.tensor(exponent, device=self.device, dtype=self.dtype)
304
+ if exponent.dim() == 0:
305
+ exponent = exponent.unsqueeze(0)
306
+
307
+ log_q = self.log()
308
+ scaled_log_q = Quaternion(log_q * exponent)
309
+ return scaled_log_q.exp()
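`from_rotation_matrix` recovers w as √(1 + tr R)/2 and divides the other components by 4w, so it breaks down for rotations near 180°, where tr R approaches -1 and w approaches 0. For the exact 180° rotation about X, R = diag(1, -1, -1), the trace is -1 and the formula yields 0/0, i.e. NaN. A common remedy, swapped in here and not what quatorch implements, is Shepperd's method: pick whichever of w, x, y, z has the largest magnitude, compute it from the diagonal, and derive the rest from off-diagonal sums and differences. The plain-Python sketch below uses hypothetical names (`matrix_from_quat`, `quat_from_matrix`) and nested lists in place of tensors; `matrix_from_quat` uses the same entries as `to_rotation_matrix`.

```python
import math

Quat = tuple[float, float, float, float]  # (w, x, y, z)
Mat3 = list[list[float]]


def matrix_from_quat(q: Quat) -> Mat3:
    """Same entries as Quaternion.to_rotation_matrix (unit q assumed)."""
    w, x, y, z = q
    return [
        [1 - 2 * (y * y + z * z), 2 * (x * y - z * w), 2 * (x * z + y * w)],
        [2 * (x * y + z * w), 1 - 2 * (x * x + z * z), 2 * (y * z - x * w)],
        [2 * (x * z - y * w), 2 * (y * z + x * w), 1 - 2 * (x * x + y * y)],
    ]


def quat_from_matrix(R: Mat3) -> Quat:
    """Shepperd-style inversion: branch on the largest of |w|, |x|, |y|, |z|."""
    trace = R[0][0] + R[1][1] + R[2][2]
    # Each candidate equals 4 * component^2 for a unit quaternion.
    candidates = [
        1 + trace,                        # 4 w^2
        1 + R[0][0] - R[1][1] - R[2][2],  # 4 x^2
        1 - R[0][0] + R[1][1] - R[2][2],  # 4 y^2
        1 - R[0][0] - R[1][1] + R[2][2],  # 4 z^2
    ]
    k = max(range(4), key=lambda i: candidates[i])
    s = 2.0 * math.sqrt(candidates[k])  # s = 4 * |component k|
    if k == 0:
        w = s / 4
        x = (R[2][1] - R[1][2]) / s
        y = (R[0][2] - R[2][0]) / s
        z = (R[1][0] - R[0][1]) / s
    elif k == 1:
        x = s / 4
        w = (R[2][1] - R[1][2]) / s
        y = (R[0][1] + R[1][0]) / s
        z = (R[0][2] + R[2][0]) / s
    elif k == 2:
        y = s / 4
        w = (R[0][2] - R[2][0]) / s
        x = (R[0][1] + R[1][0]) / s
        z = (R[1][2] + R[2][1]) / s
    else:
        z = s / 4
        w = (R[1][0] - R[0][1]) / s
        x = (R[0][2] + R[2][0]) / s
        y = (R[1][2] + R[2][1]) / s
    # q and -q are the same rotation; prefer w >= 0 as a canonical sign.
    if w < 0:
        w, x, y, z = -w, -x, -y, -z
    return (w, x, y, z)


# 180° about X: the trace is exactly -1, yet this recovers q without NaN.
print(quat_from_matrix(matrix_from_quat((0.0, 1.0, 0.0, 0.0))))  # (0.0, 1.0, 0.0, 0.0)
```

For a unit quaternion the four candidates sum to 4, so the largest is at least 1 and the divisor `s` is at least 2; no branch ever divides by a value near zero.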