quat-numpy 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quat/__init__.py +102 -0
- quat/_arrayops.py +56 -0
- quat/_checks.py +30 -0
- quat/_serialize.py +81 -0
- quat/algebra.py +90 -0
- quat/collections.py +928 -0
- quat/core.py +706 -0
- quat/interpolate.py +120 -0
- quat/linalg.py +265 -0
- quat/optimized.py +98 -0
- quat/py.typed +0 -0
- quat/random.py +87 -0
- quat/serialization.py +85 -0
- quat/signal.py +165 -0
- quat/utils.py +167 -0
- quat_numpy-0.2.0.dist-info/METADATA +214 -0
- quat_numpy-0.2.0.dist-info/RECORD +20 -0
- quat_numpy-0.2.0.dist-info/WHEEL +5 -0
- quat_numpy-0.2.0.dist-info/licenses/LICENSE +201 -0
- quat_numpy-0.2.0.dist-info/top_level.txt +1 -0
quat/__init__.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
# =========================================================================
|
|
2
|
+
# LLM-generated code. See README.md for full disclosure.
|
|
3
|
+
# =========================================================================
|
|
4
|
+
|
|
5
|
+
"""Quaternion Algebra Library — quat package.
|
|
6
|
+
|
|
7
|
+
Provides:
|
|
8
|
+
Quaternion — single quaternion value
|
|
9
|
+
QuatVector — 1-d collection of quaternions
|
|
10
|
+
QuatMatrix — 2-d quaternion matrix
|
|
11
|
+
QuatTensor — 3-d quaternion tensor
|
|
12
|
+
|
|
13
|
+
quat() — convenience constructor
|
|
14
|
+
dict_to_quat_matrix, dict_to_quat_tensor, labels_to_quat_vector
|
|
15
|
+
|
|
16
|
+
Algebra primitives (from quat.algebra):
|
|
17
|
+
_hamilton, _CONJ, _REAL_LEFT
|
|
18
|
+
|
|
19
|
+
Basis constants (from quat.core):
|
|
20
|
+
_I, _J, _K, _ZERO, _R, _ONE_Q
|
|
21
|
+
|
|
22
|
+
Utilities (from quat.utils):
|
|
23
|
+
to_ndarray, from_ndarray, from_components, broadcast_quat_shapes, stack_quat,
|
|
24
|
+
isnan, isinf, isfinite, isclose
|
|
25
|
+
|
|
26
|
+
Serialization (from quat.serialization):
|
|
27
|
+
to_json, from_json, to_bytes, from_bytes,
|
|
28
|
+
to_scipy_rotation, from_scipy_rotation
|
|
29
|
+
|
|
30
|
+
Optimized (from quat.optimized):
|
|
31
|
+
hamilton_einsum, quat_matmul, conjugate_batch,
|
|
32
|
+
norm_squared_batch, normalize_batch
|
|
33
|
+
|
|
34
|
+
Linear algebra (from quat.linalg):
|
|
35
|
+
svd, rank, condition_number, pseudo_inverse,
|
|
36
|
+
trace, det, norm, solve
|
|
37
|
+
|
|
38
|
+
Random (from quat.random):
|
|
39
|
+
random_quat, random_unit_quat, random_quat_vector,
|
|
40
|
+
random_quat_matrix, random_quat_tensor
|
|
41
|
+
|
|
42
|
+
Interpolation (from quat.interpolate):
|
|
43
|
+
slerp, slerp_vector, squad
|
|
44
|
+
|
|
45
|
+
Signal processing (from quat.signal):
|
|
46
|
+
qfft, iqfft, qfft2, iqfft2,
|
|
47
|
+
qconv, qconv2,
|
|
48
|
+
lowpass, highpass, bandpass, bandstop
|
|
49
|
+
"""
|
|
50
|
+
|
|
51
|
+
__version__ = "0.2.0"
|
|
52
|
+
from quat.algebra import _hamilton, _CONJ, _REAL_LEFT
|
|
53
|
+
from quat.core import Quaternion, quat, _I, _J, _K, _ZERO, _R, _ONE_Q
|
|
54
|
+
from quat.collections import (
|
|
55
|
+
QuatVector, QuatMatrix, QuatTensor,
|
|
56
|
+
dict_to_quat_matrix, dict_to_quat_tensor, labels_to_quat_vector,
|
|
57
|
+
)
|
|
58
|
+
from quat.utils import to_ndarray, from_ndarray, from_components, broadcast_quat_shapes, stack_quat
|
|
59
|
+
from quat.utils import isnan, isinf, isfinite, isclose
|
|
60
|
+
from quat.serialization import (
|
|
61
|
+
to_json, from_json, to_bytes, from_bytes,
|
|
62
|
+
to_scipy_rotation, from_scipy_rotation,
|
|
63
|
+
)
|
|
64
|
+
from quat.optimized import (
|
|
65
|
+
hamilton_einsum, quat_matmul,
|
|
66
|
+
conjugate_batch, norm_squared_batch, normalize_batch,
|
|
67
|
+
)
|
|
68
|
+
from quat.linalg import (
|
|
69
|
+
svd, rank, condition_number, pseudo_inverse,
|
|
70
|
+
trace, det, norm, solve,
|
|
71
|
+
)
|
|
72
|
+
from quat.random import (
|
|
73
|
+
random_quat, random_unit_quat, random_quat_vector,
|
|
74
|
+
random_quat_matrix, random_quat_tensor,
|
|
75
|
+
)
|
|
76
|
+
from quat.interpolate import slerp, slerp_vector, squad
|
|
77
|
+
from quat.signal import (
|
|
78
|
+
qfft, iqfft, qfft2, iqfft2,
|
|
79
|
+
qconv, qconv2,
|
|
80
|
+
lowpass, highpass, bandpass, bandstop,
|
|
81
|
+
)
|
|
82
|
+
|
|
83
|
+
# Public surface of the package: every name below is bound by the imports
# earlier in this module and is what ``from quat import *`` exposes.
__all__ = [
    # Core value and collection types.
    "Quaternion",
    "QuatVector",
    "QuatMatrix",
    "QuatTensor",
    # Convenience constructors.
    "quat",
    "dict_to_quat_matrix",
    "dict_to_quat_tensor",
    "labels_to_quat_vector",
    # Algebra primitives and basis constants (underscore names are
    # exported on purpose).
    "_hamilton",
    "_CONJ",
    "_REAL_LEFT",
    "_I",
    "_J",
    "_K",
    "_ZERO",
    "_R",
    "_ONE_Q",
    # Array conversion utilities.
    "to_ndarray",
    "from_ndarray",
    "from_components",
    "broadcast_quat_shapes",
    "stack_quat",
    # Element-wise predicates.
    "isnan",
    "isinf",
    "isfinite",
    "isclose",
    # Serialization.
    "to_json",
    "from_json",
    "to_bytes",
    "from_bytes",
    "to_scipy_rotation",
    "from_scipy_rotation",
    # Batched kernels.
    "hamilton_einsum",
    "quat_matmul",
    "conjugate_batch",
    "norm_squared_batch",
    "normalize_batch",
    # Linear algebra.
    "svd",
    "rank",
    "condition_number",
    "pseudo_inverse",
    "trace",
    "det",
    "norm",
    "solve",
    # Random generation.
    "random_quat",
    "random_unit_quat",
    "random_quat_vector",
    "random_quat_matrix",
    "random_quat_tensor",
    # Interpolation.
    "slerp",
    "slerp_vector",
    "squad",
    # Signal processing.
    "qfft",
    "iqfft",
    "qfft2",
    "iqfft2",
    "qconv",
    "qconv2",
    "lowpass",
    "highpass",
    "bandpass",
    "bandstop",
    # Package metadata.
    "__version__",
]
|
quat/_arrayops.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
"""NumPy interop helpers for quaternion collection types.
|
|
2
|
+
|
|
3
|
+
Provides shared implementations of ``data``, ``to_array``, ``to_numpy``,
|
|
4
|
+
``__array__``, and ``__array_ufunc__`` so QuatVector, QuatMatrix, and
|
|
5
|
+
QuatTensor share a single code path for these operations.
|
|
6
|
+
"""
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
import numpy as np
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def _data_copy(data: np.ndarray) -> np.ndarray:
|
|
12
|
+
"""Return a defensive copy of the underlying quaternion data."""
|
|
13
|
+
return data.copy()
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def _to_numpy(data: np.ndarray, copy: bool = True,
|
|
17
|
+
dtype: np.dtype | None = None) -> np.ndarray:
|
|
18
|
+
"""Export quaternion data to an ndarray, optionally with dtype conversion.
|
|
19
|
+
|
|
20
|
+
When ``copy=False`` and ``dtype=None``, returns the internal array
|
|
21
|
+
without copying — caller must not mutate it.
|
|
22
|
+
"""
|
|
23
|
+
if copy is False and dtype is None:
|
|
24
|
+
return data
|
|
25
|
+
return np.array(data, dtype=dtype, copy=copy)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def _to_array(data: np.ndarray) -> np.ndarray:
|
|
29
|
+
"""Return a copy of the quaternion data as an ndarray."""
|
|
30
|
+
return data.copy()
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _dispatch_collection_ufunc(
|
|
34
|
+
self, ufunc, method: str, *inputs, **kwargs
|
|
35
|
+
) -> object:
|
|
36
|
+
"""Shared ``__array_ufunc__`` dispatch for QuatVector/QuatMatrix/QuatTensor.
|
|
37
|
+
|
|
38
|
+
The *self* parameter is unused in the body (dispatch works via *inputs*)
|
|
39
|
+
but is kept for signature compatibility with NumPy's protocol.
|
|
40
|
+
|
|
41
|
+
Returns ``NotImplemented`` for unsupported ufuncs or reduction methods.
|
|
42
|
+
"""
|
|
43
|
+
if method != '__call__' or kwargs.get('out') is not None:
|
|
44
|
+
return NotImplemented
|
|
45
|
+
a, b = (inputs[0], inputs[1]) if len(inputs) == 2 else (inputs[0], None)
|
|
46
|
+
if ufunc is np.add:
|
|
47
|
+
return a + b
|
|
48
|
+
if ufunc is np.subtract:
|
|
49
|
+
return a - b
|
|
50
|
+
if ufunc is np.multiply:
|
|
51
|
+
return a * b
|
|
52
|
+
if ufunc is np.true_divide or ufunc is np.floor_divide:
|
|
53
|
+
return a / b
|
|
54
|
+
if ufunc is np.negative:
|
|
55
|
+
return -a
|
|
56
|
+
return NotImplemented
|
quat/_checks.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
"""Validation helpers for quaternion collection types.
|
|
2
|
+
|
|
3
|
+
Each function accepts a ``(..., 4)`` ndarray of quaternion component data
|
|
4
|
+
and returns a per-element boolean ndarray (or scalar for Quaternion's 1-D
|
|
5
|
+
case). Used by QuatVector, QuatMatrix, and QuatTensor to avoid duplicating
|
|
6
|
+
the same numpy calls across three classes.
|
|
7
|
+
"""
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
import numpy as np
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def _vec_isnan(data: np.ndarray) -> np.ndarray:
|
|
13
|
+
"""True where any quaternion component is NaN along the last axis."""
|
|
14
|
+
return np.any(np.isnan(data), axis=-1)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def _vec_isinf(data: np.ndarray) -> np.ndarray:
|
|
18
|
+
"""True where any quaternion component is infinite along the last axis."""
|
|
19
|
+
return np.any(np.isinf(data), axis=-1)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _vec_isfinite(data: np.ndarray) -> np.ndarray:
|
|
23
|
+
"""True where all quaternion components are finite along the last axis."""
|
|
24
|
+
return np.all(np.isfinite(data), axis=-1)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def _vec_isclose(data: np.ndarray, other_data: np.ndarray,
|
|
28
|
+
rtol: float, atol: float) -> np.ndarray:
|
|
29
|
+
"""Element-wise closeness test — all four components must be close along the last axis."""
|
|
30
|
+
return np.isclose(data, other_data, rtol=rtol, atol=atol).all(axis=-1)
|
quat/_serialize.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
"""Shared serialization primitives for quaternion types.
|
|
2
|
+
|
|
3
|
+
Provides JSON and binary serialization helpers used by Quaternion,
|
|
4
|
+
QuatVector, QuatMatrix, and QuatTensor. Each type delegates its
|
|
5
|
+
``to_json`` / ``from_json`` / ``to_bytes`` / ``from_bytes`` methods
|
|
6
|
+
to these functions, passing only its type-specific label or id.
|
|
7
|
+
"""
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
import json as _json
|
|
10
|
+
import struct as _struct
|
|
11
|
+
import numpy as np
|
|
12
|
+
|
|
13
|
+
_TYPE_IDS: dict[int, str] = {0: "Quaternion", 1: "QuatVector", 2: "QuatMatrix", 3: "QuatTensor"}
|
|
14
|
+
"""Mapping from binary type_id to type name string."""
|
|
15
|
+
|
|
16
|
+
_TYPE_NAMES: dict[str, int] = {v: k for k, v in _TYPE_IDS.items()}
|
|
17
|
+
"""Reverse mapping from type name to binary type_id."""
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _serialize_to_json(type_name: str, data: np.ndarray) -> str:
|
|
21
|
+
"""Serialize quaternion data to a JSON string.
|
|
22
|
+
|
|
23
|
+
Args:
|
|
24
|
+
type_name: one of ``"Quaternion"``, ``"QuatVector"``, etc.
|
|
25
|
+
data: ndarray of shape ``(..., 4)``.
|
|
26
|
+
|
|
27
|
+
Returns:
|
|
28
|
+
JSON string ``{"type": "<type_name>", "data": [[...], ...]}``.
|
|
29
|
+
"""
|
|
30
|
+
return _json.dumps({"type": type_name, "data": data.tolist()})
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _deserialize_from_json(s: str, cls_map: dict) -> object:
|
|
34
|
+
"""Deserialize a JSON string back to a quaternion type instance.
|
|
35
|
+
|
|
36
|
+
Args:
|
|
37
|
+
s: JSON string produced by ``_serialize_to_json``.
|
|
38
|
+
cls_map: dict mapping type_name → constructor callable
|
|
39
|
+
(e.g. ``{"QuatVector": QuatVector}``).
|
|
40
|
+
|
|
41
|
+
Returns:
|
|
42
|
+
Instance of the type indicated by the JSON ``"type"`` field.
|
|
43
|
+
"""
|
|
44
|
+
d = _json.loads(s)
|
|
45
|
+
return cls_map[d["type"]](np.array(d["data"], dtype=float))
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def _serialize_bytes_shaped(type_id: int, data: np.ndarray) -> bytes:
|
|
49
|
+
"""Serialize quaternion data to compact binary form.
|
|
50
|
+
|
|
51
|
+
Format: ``<i`` type_id + ``<i`` ndim + int32[ndim] shape + float64[...].
|
|
52
|
+
|
|
53
|
+
Args:
|
|
54
|
+
type_id: binary type identifier (1=QuatVector, 2=QuatMatrix, 3=QuatTensor).
|
|
55
|
+
data: ndarray of shape ``(..., 4)``.
|
|
56
|
+
|
|
57
|
+
Returns:
|
|
58
|
+
Packed bytes.
|
|
59
|
+
"""
|
|
60
|
+
data64 = data.astype(np.float64)
|
|
61
|
+
shape = np.array(data64.shape, dtype=np.int32)
|
|
62
|
+
return _struct.pack('<ii', type_id, len(shape)) + shape.tobytes() + data64.tobytes()
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def _deserialize_bytes_shaped(b: bytes, cls_map: dict) -> object:
|
|
66
|
+
"""Deserialize binary bytes back to a quaternion collection type.
|
|
67
|
+
|
|
68
|
+
Args:
|
|
69
|
+
b: bytes produced by ``_serialize_bytes_shaped``.
|
|
70
|
+
cls_map: dict mapping type_id → constructor callable.
|
|
71
|
+
|
|
72
|
+
Returns:
|
|
73
|
+
Instance of the type indicated by the stored *type_id*.
|
|
74
|
+
"""
|
|
75
|
+
type_id, ndim = _struct.unpack_from('<ii', b, 0)
|
|
76
|
+
offset = 8
|
|
77
|
+
shape = np.frombuffer(b[offset:offset + ndim * 4], dtype=np.int32)
|
|
78
|
+
offset += ndim * 4
|
|
79
|
+
size = int(np.prod(shape))
|
|
80
|
+
data = np.frombuffer(b[offset:offset + size * 8], dtype=np.float64).reshape(shape)
|
|
81
|
+
return cls_map[type_id](data)
|
quat/algebra.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
# =========================================================================
|
|
2
|
+
# LLM-generated code. See README.md for full disclosure.
|
|
3
|
+
# =========================================================================
|
|
4
|
+
|
|
5
|
+
"""Low-level quaternion algebra — constants, Hamilton product, real-matrix tensor."""
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
import numpy as np
|
|
8
|
+
import numpy.typing as npt
|
|
9
|
+
|
|
10
|
+
_CONJ = np.array([1., -1., -1., -1.])
|
|
11
|
+
"""Conjugate mask: ``_CONJ * q`` negates the three imaginary components."""
|
|
12
|
+
|
|
13
|
+
_REAL_LEFT = np.zeros((4, 4, 4))
|
|
14
|
+
"""Left-regular real-representation tensor.
|
|
15
|
+
|
|
16
|
+
``L(q)[r, c] = Σ_k _REAL_LEFT[r, c, k] * q[k]`` yields the 4×4 real matrix
|
|
17
|
+
satisfying ``L(q) @ vec(x) = vec(q * x)``.
|
|
18
|
+
"""
|
|
19
|
+
_REAL_LEFT[0, 0, 0] = 1; _REAL_LEFT[0, 1, 1] = -1; _REAL_LEFT[0, 2, 2] = -1; _REAL_LEFT[0, 3, 3] = -1
|
|
20
|
+
_REAL_LEFT[1, 0, 1] = 1; _REAL_LEFT[1, 1, 0] = 1; _REAL_LEFT[1, 2, 3] = -1; _REAL_LEFT[1, 3, 2] = 1
|
|
21
|
+
_REAL_LEFT[2, 0, 2] = 1; _REAL_LEFT[2, 1, 3] = 1; _REAL_LEFT[2, 2, 0] = 1; _REAL_LEFT[2, 3, 1] = -1
|
|
22
|
+
_REAL_LEFT[3, 0, 3] = 1; _REAL_LEFT[3, 1, 2] = -1; _REAL_LEFT[3, 2, 1] = 1; _REAL_LEFT[3, 3, 0] = 1
|
|
23
|
+
|
|
24
|
+
_HAMILTON_TENSOR = np.zeros((4, 4, 4))
|
|
25
|
+
_HAMILTON_TENSOR[0, 0, 0] = 1; _HAMILTON_TENSOR[0, 1, 1] = -1; _HAMILTON_TENSOR[0, 2, 2] = -1; _HAMILTON_TENSOR[0, 3, 3] = -1
|
|
26
|
+
_HAMILTON_TENSOR[1, 0, 1] = 1; _HAMILTON_TENSOR[1, 1, 0] = 1; _HAMILTON_TENSOR[1, 2, 3] = 1; _HAMILTON_TENSOR[1, 3, 2] = -1
|
|
27
|
+
_HAMILTON_TENSOR[2, 0, 2] = 1; _HAMILTON_TENSOR[2, 1, 3] = -1; _HAMILTON_TENSOR[2, 2, 0] = 1; _HAMILTON_TENSOR[2, 3, 1] = 1
|
|
28
|
+
_HAMILTON_TENSOR[3, 0, 3] = 1; _HAMILTON_TENSOR[3, 1, 2] = 1; _HAMILTON_TENSOR[3, 2, 1] = -1; _HAMILTON_TENSOR[3, 3, 0] = 1
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
_SMALL_THRESHOLD = 500
|
|
32
|
+
"""Total-element threshold below which component-wise arithmetic is fastest."""
|
|
33
|
+
|
|
34
|
+
_LARGE_THRESHOLD = 5000
|
|
35
|
+
"""Total-element threshold above which full einsum optimisation pays off."""
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def _hamilton(p: npt.NDArray, q: npt.NDArray) -> npt.NDArray:
    """Vectorized Hamilton (quaternion) product ``p * q``.

    Picks a kernel from the combined operand size ``p.size + q.size``
    (scalar components, not quaternions):

    - up to ``_SMALL_THRESHOLD`` (500): component-wise arithmetic
    - up to ``_LARGE_THRESHOLD`` (5000): einsum without contraction-path
      optimisation
    - above that: einsum with full contraction-path optimisation

    Leading dimensions of the two operands broadcast against each other.

    Example:
        >>> from quat.algebra import _hamilton
        >>> import numpy as np
        >>> p = np.array([1., 0., 0., 0.])          # real unit
        >>> q = np.array([[1., 2., 3., 4.]])        # batch of one
        >>> _hamilton(p, q)
        array([[1., 2., 3., 4.]])
        >>> i = np.array([0., 1., 0., 0.])
        >>> j = np.array([0., 0., 1., 0.])
        >>> _hamilton(i, j)                          # i*j = k
        array([0., 0., 0., 1.])
    """
    workload = p.size + q.size
    if workload <= _SMALL_THRESHOLD:
        kernel = _hamilton_component
    elif workload <= _LARGE_THRESHOLD:
        kernel = _hamilton_einsum_noopt
    else:
        kernel = _hamilton_einsum
    return kernel(p, q)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def _hamilton_component(p: npt.NDArray, q: npt.NDArray) -> npt.NDArray:
|
|
69
|
+
"""Component-wise Hamilton product — fastest for small batches."""
|
|
70
|
+
a1, b1, c1, d1 = p[..., 0], p[..., 1], p[..., 2], p[..., 3]
|
|
71
|
+
a2, b2, c2, d2 = q[..., 0], q[..., 1], q[..., 2], q[..., 3]
|
|
72
|
+
shp = np.broadcast_shapes(p.shape[:-1], q.shape[:-1]) + (4,)
|
|
73
|
+
out = np.empty(shp)
|
|
74
|
+
out[..., 0] = a1*a2 - b1*b2 - c1*c2 - d1*d2
|
|
75
|
+
out[..., 1] = a1*b2 + b1*a2 + c1*d2 - d1*c2
|
|
76
|
+
out[..., 2] = a1*c2 - b1*d2 + c1*a2 + d1*b2
|
|
77
|
+
out[..., 3] = a1*d2 + b1*c2 - c1*b2 + d1*a2
|
|
78
|
+
return out
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def _hamilton_einsum_noopt(p: npt.NDArray, q: npt.NDArray) -> npt.NDArray:
    """Hamilton product via a single einsum with ``optimize=False``.

    Intended for medium batches, where searching for a contraction path
    costs more than it saves.
    """
    subscripts = 'rck,...c,...k->...r'
    return np.einsum(subscripts, _HAMILTON_TENSOR, p, q, optimize=False)
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def _hamilton_einsum(p: npt.NDArray, q: npt.NDArray) -> npt.NDArray:
    """Hamilton product via a single einsum with ``optimize=True``.

    Intended for large batches, where the contraction-path search is
    amortised over enough data to pay off.
    """
    subscripts = 'rck,...c,...k->...r'
    return np.einsum(subscripts, _HAMILTON_TENSOR, p, q, optimize=True)
|