pfc-geometry 0.2.14__tar.gz → 0.2.16__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32) hide show
  1. {pfc_geometry-0.2.14 → pfc_geometry-0.2.16}/.gitignore +1 -1
  2. {pfc_geometry-0.2.14 → pfc_geometry-0.2.16}/PKG-INFO +5 -2
  3. {pfc_geometry-0.2.14 → pfc_geometry-0.2.16}/pyproject.toml +7 -2
  4. {pfc_geometry-0.2.14 → pfc_geometry-0.2.16}/src/geometry/base.py +95 -13
  5. {pfc_geometry-0.2.14 → pfc_geometry-0.2.16}/src/geometry/gps.py +10 -0
  6. {pfc_geometry-0.2.14 → pfc_geometry-0.2.16}/src/geometry/point.py +123 -38
  7. {pfc_geometry-0.2.14 → pfc_geometry-0.2.16}/src/geometry/quaternion.py +68 -6
  8. pfc_geometry-0.2.16/src/geometry/time.py +86 -0
  9. {pfc_geometry-0.2.14 → pfc_geometry-0.2.16}/src/geometry/transformation.py +25 -9
  10. pfc_geometry-0.2.16/src/geometry/utils.py +99 -0
  11. {pfc_geometry-0.2.14 → pfc_geometry-0.2.16}/tests/test_base.py +21 -1
  12. {pfc_geometry-0.2.14 → pfc_geometry-0.2.16}/tests/test_point.py +20 -1
  13. {pfc_geometry-0.2.14 → pfc_geometry-0.2.16}/tests/test_quaternion.py +25 -1
  14. pfc_geometry-0.2.16/tests/test_time.py +25 -0
  15. pfc_geometry-0.2.16/tests/test_utils.py +41 -0
  16. {pfc_geometry-0.2.14 → pfc_geometry-0.2.16}/uv.lock +108 -16
  17. pfc_geometry-0.2.14/src/geometry/time.py +0 -43
  18. {pfc_geometry-0.2.14 → pfc_geometry-0.2.16}/.dockerignore +0 -0
  19. {pfc_geometry-0.2.14 → pfc_geometry-0.2.16}/.github/workflows/publish_pypi.yml +0 -0
  20. {pfc_geometry-0.2.14 → pfc_geometry-0.2.16}/LICENSE +0 -0
  21. {pfc_geometry-0.2.14 → pfc_geometry-0.2.16}/README.md +0 -0
  22. {pfc_geometry-0.2.14 → pfc_geometry-0.2.16}/src/geometry/__init__.py +0 -0
  23. {pfc_geometry-0.2.14 → pfc_geometry-0.2.16}/src/geometry/checks.py +0 -0
  24. {pfc_geometry-0.2.14 → pfc_geometry-0.2.16}/src/geometry/coordinate_frame.py +0 -0
  25. {pfc_geometry-0.2.14 → pfc_geometry-0.2.16}/src/geometry/mass.py +0 -0
  26. {pfc_geometry-0.2.14 → pfc_geometry-0.2.16}/src/geometry/py.typed +0 -0
  27. {pfc_geometry-0.2.14 → pfc_geometry-0.2.16}/tests/__init__.py +0 -0
  28. {pfc_geometry-0.2.14 → pfc_geometry-0.2.16}/tests/test_coord.py +0 -0
  29. {pfc_geometry-0.2.14 → pfc_geometry-0.2.16}/tests/test_gps.py +0 -0
  30. {pfc_geometry-0.2.14 → pfc_geometry-0.2.16}/tests/test_mass.py +0 -0
  31. {pfc_geometry-0.2.14 → pfc_geometry-0.2.16}/tests/test_remove_outliers.csv +0 -0
  32. {pfc_geometry-0.2.14 → pfc_geometry-0.2.16}/tests/test_transform.py +0 -0
@@ -1,5 +1,5 @@
1
1
  .vscode/
2
-
2
+ _trials/
3
3
  # Byte-compiled / optimized / DLL files
4
4
  __pycache__/
5
5
  *.py[cod]
@@ -1,10 +1,13 @@
1
- Metadata-Version: 2.3
1
+ Metadata-Version: 2.4
2
2
  Name: pfc-geometry
3
- Version: 0.2.14
3
+ Version: 0.2.16
4
4
  Summary: A library for working with 3D geometry.
5
+ License-File: LICENSE
5
6
  Requires-Python: >=3.12
7
+ Requires-Dist: numpy-quaternion>=2024.0.7
6
8
  Requires-Dist: numpy>=2.1.3
7
9
  Requires-Dist: pandas>=2.2.3
10
+ Requires-Dist: rowan>=1.3.2
8
11
  Description-Content-Type: text/markdown
9
12
 
10
13
  # geometry #
@@ -1,10 +1,15 @@
1
1
  [project]
2
2
  name = "pfc-geometry"
3
- version="0.2.14"
3
+ version="0.2.16"
4
4
  description = "A library for working with 3D geometry."
5
5
  readme = "README.md"
6
6
  requires-python = ">=3.12"
7
- dependencies = ["numpy>=2.1.3", "pandas>=2.2.3"]
7
+ dependencies = [
8
+ "numpy-quaternion>=2024.0.7",
9
+ "numpy>=2.1.3",
10
+ "pandas>=2.2.3",
11
+ "rowan>=1.3.2",
12
+ ]
8
13
 
9
14
  [build-system]
10
15
  requires = ["hatchling"]
@@ -10,7 +10,8 @@ You should have received a copy of the GNU General Public License along with
10
10
  this program. If not, see <http://www.gnu.org/licenses/>.
11
11
  """
12
12
 
13
- from typing import Self
13
+ from __future__ import annotations
14
+ from typing import Self, Literal
14
15
  import numpy as np
15
16
  import numpy.typing as npt
16
17
  import pandas as pd
@@ -92,6 +93,14 @@ class Base:
92
93
  else:
93
94
  raise TypeError(f"Empty {self.__class__.__name__} not allowed")
94
95
 
96
+ def to_numpy(self, cols: str | list = None):
97
+ cols = self.cols if cols is None else cols
98
+ return np.column_stack([getattr(self, c) for c in cols])
99
+
100
+ @classmethod
101
+ def from_numpy(Cls, data: npt.NDArray, cols: str | list):
102
+ return Cls(np.column_stack([data[:, cols.index(col)] for col in Cls.cols]))
103
+
95
104
  @classmethod
96
105
  def _clean_data(cls, data) -> npt.NDArray[np.float64]:
97
106
  assert isinstance(data, np.ndarray)
@@ -247,14 +256,17 @@ class Base:
247
256
  def dot(self, other: Self) -> Self:
248
257
  return np.einsum("ij,ij->i", self.data, other)
249
258
 
250
- def diff(self, dt: np.array) -> Self:
259
+ def diff(
260
+ self, dt: npt.NDArray, method: Literal["gradient", "diff"] = "gradient"
261
+ ) -> Self:
251
262
  if not pd.api.types.is_list_like(dt):
252
263
  dt = np.full(len(self), dt)
253
- assert len(dt) == len(self)
254
- return self.__class__(
255
- np.gradient(self.data, axis=0)
256
- / np.tile(dt, (len(self.__class__.cols), 1)).T
257
- )
264
+ self, dt = Base.length_check(self, dt)
265
+ diff_method = np.gradient if method == "gradient" else np.diff
266
+
267
+ data = diff_method(self.data, axis=0)
268
+ dt = dt if method == "gradient" else dt[:-1]
269
+ return self.__class__(data / np.tile(dt, (len(self.__class__.cols), 1)).T)
258
270
 
259
271
  def to_pandas(self, prefix="", suffix="", columns=None, index=None):
260
272
  if columns is not None:
@@ -263,6 +275,10 @@ class Base:
263
275
  cols = [prefix + col + suffix for col in self.__class__.cols]
264
276
  return pd.DataFrame(self.data, columns=cols, index=index)
265
277
 
278
+ @property
279
+ def df(self):
280
+ return self.to_pandas()
281
+
266
282
  def tile(self, count) -> Self:
267
283
  return self.__class__(np.tile(self.data, (count, 1)))
268
284
 
@@ -347,15 +363,81 @@ class Base:
347
363
  def fill_zeros(self):
348
364
  """fills zero length rows with the previous or next non-zero value"""
349
365
  return self.__class__(
350
- pd.DataFrame(np.where(
351
- np.tile(abs(self) == 0, (3, 1)).T,
352
- np.full(self.data.shape, np.nan),
353
- self.data,
354
- )).ffill().bfill().to_numpy()
366
+ pd.DataFrame(
367
+ np.where(
368
+ np.tile(abs(self) == 0, (3, 1)).T,
369
+ np.full(self.data.shape, np.nan),
370
+ self.data,
371
+ )
372
+ )
373
+ .ffill()
374
+ .bfill()
375
+ .to_numpy()
355
376
  )
356
377
 
357
378
  def ffill(self):
358
379
  return self.__class__(pd.DataFrame(self.data).ffill().to_numpy())
359
380
 
360
381
  def bfill(self):
361
- return self.__class__(pd.DataFrame(self.data).bfill().to_numpy())
382
+ return self.__class__(pd.DataFrame(self.data).bfill().to_numpy())
383
+
384
+ def linterp(
385
+ self,
386
+ index: npt.NDArray | pd.Index,
387
+ extrapolate: Literal["throw", "nearest"] = "throw",
388
+ ):
389
+ "linear interpolation"
390
+ index = pd.Index(np.arange(len(self)) if index is None else index)
391
+ assert len(index) == len(self)
392
+ assert pd.Index(index).is_monotonic_increasing
393
+
394
+ def dolinterp(ts: npt.NDArray | Number):
395
+ starts = index.get_indexer(ts, method="ffill")
396
+ stops = index.get_indexer(ts, method="bfill")
397
+ if np.any(starts * stops < 0) and extrapolate=="throw":
398
+ raise Exception("Cannot extrapolate beyond parent range")
399
+ return self.__class__(np.column_stack(
400
+ [
401
+ np.interp(
402
+ ts, index, self.data[:, i], self.data[0, i], self.data[-1, i]
403
+ )
404
+ for i, col in enumerate(self.cols)
405
+ ]
406
+ ))
407
+ # return lambda t: a + (b - a) * np.clip(t, 0, 1)
408
+ return dolinterp
409
+
410
+ def bspline(self, index: npt.NDArray | pd.Index = None):
411
+ from scipy.interpolate import make_interp_spline
412
+
413
+ bspline = make_interp_spline(
414
+ np.arange(len(self)) if index is None else index, self.data, axis=0
415
+ )
416
+ return lambda i: self.__class__(bspline(i))
417
+
418
+ def interpolate(self, index: npt.NDArray | pd.Index = None, method:str=None):
419
+ if method is None:
420
+ match (self.__class__.__name__):
421
+ case "Point":
422
+ method="bspline"
423
+ case "Quaternion":
424
+ method="slerp"
425
+ case "Time":
426
+ method="linterp"
427
+ return getattr(self, method)(index)
428
+
429
+ def plot(self, index=None, **kwargs):
430
+ import plotly.graph_objects as go
431
+
432
+ fig = go.Figure()
433
+ for col in self.cols:
434
+ fig.add_trace(
435
+ go.Scatter(
436
+ x=np.arange(len(self)) if index is None else index,
437
+ y=getattr(self, col),
438
+ name=col,
439
+ **kwargs,
440
+ )
441
+ )
442
+ # df = self.to_pandas(self.__class__.__name__[0], index=index)
443
+ return fig
@@ -13,6 +13,7 @@ import math
13
13
  from geometry.base import Base
14
14
  from geometry.point import Point
15
15
  from typing import List, Union
16
+ import numpy.typing as npt
16
17
  import numpy as np
17
18
  import pandas as pd
18
19
 
@@ -71,6 +72,15 @@ class GPS(Base):
71
72
  self.alt - pin.z
72
73
  )
73
74
 
75
+ def bspline(self, index: npt.NDArray | pd.Index = None):
76
+
77
+ def interpolator(i):
78
+ ps: Point = self - self[0]
79
+ ips = ps.bspline(index)(i)
80
+ return self[0].offset(ips)
81
+
82
+ return interpolator
83
+
74
84
 
75
85
  '''
76
86
  Extract from ardupilot:
@@ -9,7 +9,9 @@ FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.
9
9
  You should have received a copy of the GNU General Public License along with
10
10
  this program. If not, see <http://www.gnu.org/licenses/>.
11
11
  """
12
+
12
13
  from __future__ import annotations
14
+ from typing import Literal
13
15
  from .base import Base
14
16
  import numpy as np
15
17
  import pandas as pd
@@ -19,22 +21,38 @@ import numpy.typing as npt
19
21
 
20
22
 
21
23
  class Point(Base):
22
- cols=["x", "y", "z"]
24
+ cols = ["x", "y", "z"]
23
25
  from_np = [
24
- "sin","cos","tan",
25
- "arcsin","arccos","arctan",
26
+ "sin",
27
+ "cos",
28
+ "tan",
29
+ "arcsin",
30
+ "arccos",
31
+ "arctan",
26
32
  ]
27
33
 
34
+ @property
35
+ def xy(self):
36
+ return Point(self.x, self.y, np.zeros(len(self)))
37
+
38
+ @property
39
+ def yz(self):
40
+ return Point(np.zeros(len(self)), self.y, self.z)
41
+
42
+ @property
43
+ def zx(self):
44
+ return Point(self.x, np.zeros(len(self)), self.z)
45
+
28
46
  def scale(self, value) -> Point:
29
47
  with np.errstate(divide="ignore"):
30
- res = value/abs(self)
31
- res[res==np.inf] = 0
48
+ res = value / abs(self)
49
+ res[res == np.inf] = 0
32
50
  return self * res
33
-
51
+
34
52
  def unit(self) -> Point:
35
53
  return self.scale(1)
36
54
 
37
- def remove_outliers(self, nstds = 2):
55
+ def remove_outliers(self, nstds=2):
38
56
  ab = abs(self)
39
57
  std = np.nanstd(ab)
40
58
  mean = np.nanmean(ab)
@@ -50,69 +68,123 @@ class Point(Base):
50
68
 
51
69
  def max(self):
52
70
  return Point(np.max(self.data, axis=0))
53
-
71
+
54
72
  def min(self):
55
73
  return Point(np.min(self.data, axis=0))
56
74
 
57
75
  def angles(self, p2):
58
76
  return (self.cross(p2) / (abs(self) * abs(p2))).arcsin
59
-
77
+
60
78
  def planar_angles(self):
61
- return Point(np.arctan2(self.y, self.z), np.arctan2(self.z, self.x), np.arctan2(self.x, self.y))
79
+ return Point(
80
+ np.arctan2(self.y, self.z),
81
+ np.arctan2(self.z, self.x),
82
+ np.arctan2(self.x, self.y),
83
+ )
62
84
 
63
85
  def angle(self, p2):
64
86
  return abs(Point.angles(self, p2))
65
-
87
+
66
88
  @staticmethod
67
- def X(value: Number | npt.NDArray=1, count=1):
68
- return np.tile(value, count) * Point(1,0,0)
89
+ def X(value: Number | npt.NDArray = 1, count=1):
90
+ return np.tile(value, count) * Point(1, 0, 0)
69
91
 
70
92
  @staticmethod
71
93
  def Y(value=1, count=1):
72
- return np.tile(value, count) * Point(0,1,0)
94
+ return np.tile(value, count) * Point(0, 1, 0)
73
95
 
74
96
  @staticmethod
75
97
  def Z(value=1, count=1):
76
- return np.tile(value, count) * Point(0,0,1)
98
+ return np.tile(value, count) * Point(0, 0, 1)
77
99
 
78
100
  def rotate(self, rmat=np.ndarray):
79
101
  if len(rmat.shape) == 3:
80
102
  pass
81
103
  elif len(rmat.shape) == 2:
82
- rmat = np.reshape(rmat, (1, 3, 3 ))
104
+ rmat = np.reshape(rmat, (1, 3, 3))
83
105
  else:
84
106
  raise TypeError("expected a 3x3 matrix")
85
-
107
+
86
108
  return self.dot(rmat)
87
109
 
88
110
  def to_rotation_matrix(self):
89
- '''returns the rotation matrix based on a point representing Euler angles'''
111
+ """returns the rotation matrix based on a point representing Euler angles"""
90
112
  s = self.sin
91
113
  c = self.cos
92
- return np.array([
93
- [
94
- c.z * c.y,
95
- c.z * s.y * s.x - c.x * s.z,
96
- c.x * c.z * s.y + s.x * s.z
97
- ], [
98
- c.y * s.z,
99
- c.x * c.z + s.x * s.y * s.z,
100
- -1 * c.z * s.x + c.x * s.y * s.z
101
- ],
102
- [
103
- -1 * s.y,
104
- c.y * s.x,
105
- c.x * c.y
106
- ]
107
- ]).T
114
+ return np.transpose(
115
+ np.array(
116
+ [
117
+ [
118
+ c.z * c.y,
119
+ c.z * s.y * s.x - c.x * s.z,
120
+ c.x * c.z * s.y + s.x * s.z,
121
+ ],
122
+ [
123
+ c.y * s.z,
124
+ c.x * c.z + s.x * s.y * s.z,
125
+ -1 * c.z * s.x + c.x * s.y * s.z,
126
+ ],
127
+ [-1 * s.y, c.y * s.x, c.x * c.y],
128
+ ]
129
+ ),
130
+ (2, 0, 1),
131
+ )
132
+
133
+ def matrix(self):
134
+ return np.einsum("i...,...->i...", self.data, np.identity(3))
135
+
136
+ @staticmethod
137
+ def from_matrix(matrix):
138
+ return Point(matrix[:, 0, 0], matrix[:, 1, 1], matrix[:, 2, 2])
139
+
140
+ def skew_symmetric(self):
141
+ o = np.zeros(len(self))
142
+ return np.transpose(
143
+ np.array(
144
+ [[o, -self.z, self.y], [self.z, o, -self.x], [-self.y, self.x, o]]
145
+ ),
146
+ (2, 0, 1),
147
+ )
108
148
 
109
149
  @staticmethod
110
150
  def zeros(count=1):
111
- return Point(np.zeros((count,3)))
151
+ return Point(np.zeros((count, 3)))
112
152
 
113
153
  def bearing(self):
114
154
  return np.arctan2(self.y, self.x)
115
155
 
156
+ def plot3d(self, **kwargs):
157
+ import plotly.graph_objects as go
158
+ fig = go.Figure()
159
+
160
+ fig.add_trace(go.Scatter3d(x=self.x, y=self.y, z=self.z, **kwargs))
161
+ fig.update_layout(
162
+ scene=dict(aspectmode="data"),
163
+ )
164
+ return fig
165
+
166
+ def plotxy(self):
167
+ import plotly.express as px
168
+
169
+ return px.line(self.df, x="x", y="y").update_layout(
170
+ yaxis=dict(scaleanchor="x", scaleratio=1)
171
+ )
172
+
173
+ def plotyz(self):
174
+ import plotly.express as px
175
+
176
+ return px.line(self.df, x="y", y="z").update_layout(
177
+ yaxis=dict(scaleanchor="x", scaleratio=1, title="z"), xaxis=dict(title="y")
178
+ )
179
+
180
+ def plotzx(self):
181
+ import plotly.express as px
182
+
183
+ return px.line(self.df, x="z", y="x").update_layout(
184
+ yaxis=dict(scaleanchor="x", scaleratio=1, title="x"), xaxis=dict(title="z")
185
+ )
186
+
187
+
116
188
  def Points(*args, **kwargs):
117
189
  warn("Points is deprecated, you can now just use Point", DeprecationWarning)
118
190
  return Point(*args, **kwargs)
@@ -121,15 +193,19 @@ def Points(*args, **kwargs):
121
193
  def PX(length=1, count=1):
122
194
  return Point.X(length, count)
123
195
 
196
+
124
197
  def PY(length=1, count=1):
125
198
  return Point.Y(length, count)
126
199
 
200
+
127
201
  def PZ(length=1, count=1):
128
202
  return Point.Z(length, count)
129
203
 
204
+
130
205
  def P0(count=1):
131
206
  return Point.zeros(count)
132
207
 
208
+
133
209
  def ppmeth(func):
134
210
  def wrapper(a, b, *args, **kwargs):
135
211
  assert all([isinstance(arg, Point) for arg in args])
@@ -143,47 +219,56 @@ def ppmeth(func):
143
219
  @ppmeth
144
220
  def cross(a: Point, b: Point) -> Point:
145
221
  return Point(np.cross(a.data, b.data))
146
-
222
+
147
223
 
148
224
  @ppmeth
149
225
  def cos_angle_between(a: Point, b: Point) -> np.ndarray:
150
226
  return a.unit().dot(b.unit())
151
227
 
228
+
152
229
  @ppmeth
153
230
  def angle_between(a: Point, b: Point) -> np.ndarray:
154
231
  return np.arccos(a.cos_angle_between(b))
155
232
 
233
+
156
234
  @ppmeth
157
235
  def scalar_projection(a: Point, b: Point) -> Point:
158
236
  return a.cos_angle_between(b) * abs(a)
159
237
 
238
+
160
239
  @ppmeth
161
240
  def vector_projection(a: Point, b: Point) -> Point:
162
241
  return b.scale(a.scalar_projection(b))
163
242
 
243
+
164
244
  @ppmeth
165
245
  def vector_rejection(a: Point, b: Point) -> Point:
166
- return a - ((Point.dot(a, b)) / Point.dot(b,b)) * b
246
+ return a - ((Point.dot(a, b)) / Point.dot(b, b)) * b
167
247
 
168
248
 
169
249
  @ppmeth
170
250
  def is_parallel(a: Point, b: Point, tolerance=1e-6):
171
251
  return abs(a.cos_angle_between(b) - 1) < tolerance
172
252
 
253
+
173
254
  @ppmeth
174
255
  def is_perpendicular(a: Point, b: Point, tolerance=1e-6):
175
256
  return abs(a.dot(b)) < tolerance
176
257
 
258
+
177
259
  @ppmeth
178
260
  def min_angle_between(p1: Point, p2: Point):
179
261
  angle = angle_between(p1, p2) % np.pi
180
262
  return np.minimum(angle, np.pi - angle)
181
263
 
264
+
182
265
  def arbitrary_perpendicular(v: Point) -> Point:
183
266
  return Point(-v.y, v.x, 0).unit()
184
267
 
268
+
185
269
  def vector_norm(point: Point):
186
270
  return abs(point)
187
271
 
272
+
188
273
  def normalize_vector(point: Point):
189
- return point / abs(point)
274
+ return point / abs(point)
@@ -11,13 +11,14 @@ this program. If not, see <http://www.gnu.org/licenses/>.
11
11
  """
12
12
  from __future__ import annotations
13
13
  from .point import Point
14
- from .base import Base, dprep
14
+ from .base import Base
15
15
  from geometry import PZ
16
16
  import numpy as np
17
17
  import numpy.typing as npt
18
18
  import pandas as pd
19
19
  from warnings import warn
20
20
  from numbers import Number
21
+ from typing import Callable, Literal
21
22
 
22
23
 
23
24
  class Quaternion(Base):
@@ -179,10 +180,10 @@ class Quaternion(Base):
179
180
  def body_rotate(self, rate: Point) -> Quaternion:
180
181
  return (self * Quaternion.from_axis_angle(rate)).norm()
181
182
 
182
- def diff(self, dt: Number | npt.NDArray) -> Point:
183
+ def diff(self, dt: Number | npt.NDArray = None) -> Point:
183
184
  """differentiate in the world frame"""
184
185
  if not pd.api.types.is_list_like(dt):
185
- dt = np.full(len(self), dt)
186
+ dt = np.full(len(self), 1 if not dt else dt)
186
187
  assert len(dt) == len(self)
187
188
  dt = dt * len(dt) / (len(dt) - 1)
188
189
 
@@ -192,10 +193,10 @@ class Quaternion(Base):
192
193
  ) / dt[:-1]
193
194
  return Point(np.vstack([ps.data, ps.data[-1,:]]))
194
195
 
195
- def body_diff(self, dt: Number | npt.NDArray) -> Point:
196
+ def body_diff(self, dt: Number | npt.NDArray = None) -> Point:
196
197
  """differentiate in the body frame"""
197
198
  if not pd.api.types.is_list_like(dt):
198
- dt = np.full(len(self), dt)
199
+ dt = np.full(len(self), 1 if not dt else dt)
199
200
  assert len(dt) == len(self)
200
201
  dt = dt * len(dt) / (len(dt) - 1)
201
202
 
@@ -258,7 +259,68 @@ class Quaternion(Base):
258
259
  p = Point.X()
259
260
  return self.transform_point(p).bearing()
260
261
 
261
-
262
+ def slerp(self, index: pd.Index | npt.NDArray = None, extrapolate:Literal["throw", "nearest"]="throw"):
263
+ index = pd.Index(np.arange(len(self)) if index is None else index)
264
+
265
+ assert len(index) == len(self)
266
+ assert pd.Index(index).is_monotonic_increasing
267
+ from rowan.interpolate import slerp
268
+ def doslerp(ts: npt.NDArray | Number) -> Quaternion:
269
+ starts = index.get_indexer(ts, method='ffill')
270
+ stops = index.get_indexer(ts, method='bfill')
271
+
272
+ #case interpolate match (start == stop - 1)
273
+ odata = slerp(
274
+ self[starts].to_numpy("xyzw"),
275
+ self[stops].to_numpy("xyzw"),
276
+ (ts - index[starts]) / (index[stops] - index[starts]),
277
+ True
278
+ )
279
+
280
+ #case exact match (start == stop)
281
+ exacts = starts == stops
282
+ odata[exacts] = self.to_numpy("xyzw")[starts[exacts]]
283
+
284
+ #case outside range above (start == index[-1], stop== -1)
285
+ aboves = stops==-1
286
+ if np.any(aboves):
287
+ if extrapolate=="throw":
288
+ raise Exception("Cannot slerp beyond range")
289
+ else:
290
+ odata[aboves] = self.to_numpy("xyzw")[-1, :]
291
+ #case outside range below (start == -1, stop==index[0])
292
+ belows = starts==-1
293
+ if np.any(belows):
294
+ if extrapolate=="throw":
295
+ raise Exception("Cannot slerp beyond range")
296
+ else:
297
+ odata[belows] = self.to_numpy("xyzw")[0, :]
298
+
299
+ return Quaternion.from_numpy( odata, "xyzw")
300
+
301
+ return doslerp
302
+
303
+
304
+ # @staticmethod
305
+ # def slerp(a: Quaternion, b: Quaternion):
306
+ # """spherical linear interpolation"""
307
+ # from rowan.interpolate import slerp
308
+ # def doslerp(t):
309
+ # xyzw = slerp(a.xyzw, b.xyzw, np.clip(t, 0, 1))
310
+ # return Quaternion(xyzw[:,3], xyzw[:,0], xyzw[:,1], xyzw[:,2])
311
+ # return doslerp
312
+
313
+ @staticmethod
314
+ def squad(p: Quaternion, a: Quaternion, b: Quaternion, q: Quaternion):
315
+ from rowan.interpolate import squad
316
+ def dosquad(t):
317
+ xyzq = squad(p.xyzw, a.xyzw, b.xyzw, q.xyzw, np.clip(t, 0, 1))
318
+ return Quaternion(xyzq[:,3], xyzq[:,0], xyzq[:,1], xyzq[:,2])
319
+ return dosquad
320
+
321
+ def plot_3d(self, size: float=3, vis:Literal["coord", "plane"]="coord"):
322
+ from geometry import Transformation
323
+ return Transformation(self).plot_3d(size, vis)
262
324
 
263
325
  def Q0(count=1):
264
326
  return Quaternion.zero(count)