swcgeom 0.19.4__cp312-cp312-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of swcgeom might be problematic. Click here for more details.

Files changed (72)
  1. swcgeom/__init__.py +21 -0
  2. swcgeom/analysis/__init__.py +13 -0
  3. swcgeom/analysis/feature_extractor.py +454 -0
  4. swcgeom/analysis/features.py +218 -0
  5. swcgeom/analysis/lmeasure.py +750 -0
  6. swcgeom/analysis/sholl.py +201 -0
  7. swcgeom/analysis/trunk.py +183 -0
  8. swcgeom/analysis/visualization.py +191 -0
  9. swcgeom/analysis/visualization3d.py +81 -0
  10. swcgeom/analysis/volume.py +143 -0
  11. swcgeom/core/__init__.py +19 -0
  12. swcgeom/core/branch.py +129 -0
  13. swcgeom/core/branch_tree.py +65 -0
  14. swcgeom/core/compartment.py +107 -0
  15. swcgeom/core/node.py +130 -0
  16. swcgeom/core/path.py +155 -0
  17. swcgeom/core/population.py +341 -0
  18. swcgeom/core/swc.py +247 -0
  19. swcgeom/core/swc_utils/__init__.py +19 -0
  20. swcgeom/core/swc_utils/assembler.py +35 -0
  21. swcgeom/core/swc_utils/base.py +180 -0
  22. swcgeom/core/swc_utils/checker.py +107 -0
  23. swcgeom/core/swc_utils/io.py +204 -0
  24. swcgeom/core/swc_utils/normalizer.py +163 -0
  25. swcgeom/core/swc_utils/subtree.py +70 -0
  26. swcgeom/core/tree.py +384 -0
  27. swcgeom/core/tree_utils.py +277 -0
  28. swcgeom/core/tree_utils_impl.py +58 -0
  29. swcgeom/images/__init__.py +9 -0
  30. swcgeom/images/augmentation.py +149 -0
  31. swcgeom/images/contrast.py +87 -0
  32. swcgeom/images/folder.py +217 -0
  33. swcgeom/images/io.py +578 -0
  34. swcgeom/images/loaders/__init__.py +8 -0
  35. swcgeom/images/loaders/pbd.cpython-312-darwin.so +0 -0
  36. swcgeom/images/loaders/pbd.pyx +523 -0
  37. swcgeom/images/loaders/raw.cpython-312-darwin.so +0 -0
  38. swcgeom/images/loaders/raw.pyx +183 -0
  39. swcgeom/transforms/__init__.py +20 -0
  40. swcgeom/transforms/base.py +136 -0
  41. swcgeom/transforms/branch.py +223 -0
  42. swcgeom/transforms/branch_tree.py +74 -0
  43. swcgeom/transforms/geometry.py +270 -0
  44. swcgeom/transforms/image_preprocess.py +107 -0
  45. swcgeom/transforms/image_stack.py +219 -0
  46. swcgeom/transforms/images.py +206 -0
  47. swcgeom/transforms/mst.py +183 -0
  48. swcgeom/transforms/neurolucida_asc.py +498 -0
  49. swcgeom/transforms/path.py +56 -0
  50. swcgeom/transforms/population.py +36 -0
  51. swcgeom/transforms/tree.py +265 -0
  52. swcgeom/transforms/tree_assembler.py +161 -0
  53. swcgeom/utils/__init__.py +18 -0
  54. swcgeom/utils/debug.py +23 -0
  55. swcgeom/utils/download.py +119 -0
  56. swcgeom/utils/dsu.py +58 -0
  57. swcgeom/utils/ellipse.py +131 -0
  58. swcgeom/utils/file.py +90 -0
  59. swcgeom/utils/neuromorpho.py +581 -0
  60. swcgeom/utils/numpy_helper.py +70 -0
  61. swcgeom/utils/plotter_2d.py +134 -0
  62. swcgeom/utils/plotter_3d.py +35 -0
  63. swcgeom/utils/renderer.py +145 -0
  64. swcgeom/utils/sdf.py +324 -0
  65. swcgeom/utils/solid_geometry.py +154 -0
  66. swcgeom/utils/transforms.py +367 -0
  67. swcgeom/utils/volumetric_object.py +483 -0
  68. swcgeom-0.19.4.dist-info/METADATA +86 -0
  69. swcgeom-0.19.4.dist-info/RECORD +72 -0
  70. swcgeom-0.19.4.dist-info/WHEEL +5 -0
  71. swcgeom-0.19.4.dist-info/licenses/LICENSE +201 -0
  72. swcgeom-0.19.4.dist-info/top_level.txt +1 -0
@@ -0,0 +1,134 @@
1
+ # SPDX-FileCopyrightText: 2022 - 2025 Zexin Yuan <pypi@yzx9.xyz>
2
+ #
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ """2D Plotting utils."""
6
+
7
+ import matplotlib.pyplot as plt
8
+ import numpy as np
9
+ import numpy.typing as npt
10
+ from matplotlib import cm
11
+ from matplotlib.axes import Axes
12
+ from matplotlib.collections import LineCollection, PatchCollection
13
+ from matplotlib.colors import Colormap, Normalize
14
+ from matplotlib.figure import Figure
15
+ from matplotlib.patches import Circle
16
+
17
+ from swcgeom.utils.renderer import Camera
18
+ from swcgeom.utils.transforms import to_homogeneous, translate3d
19
+
__all__ = [
    "draw_lines",
    "draw_direction_indicator",
    "draw_circles",
    "get_fig_ax",
]
21
+
22
+
def draw_lines(
    ax: Axes,
    lines: npt.NDArray[np.floating],
    camera: Camera,
    joinstyle="round",
    capstyle="round",
    **kwargs,
) -> LineCollection:
    """Project 3D line segments through a camera and draw them on 2D axes.

    Args:
        ax: The plot axes.
        lines: Line segments as an array of shape (N, 2, 3). Axis 1 holds the
            two end points of each segment; axis 2 holds their (x, y, z)
            coordinates.
        camera: Camera used to project the segments onto the plot plane.
        joinstyle: Forwarded to `~matplotlib.collections.LineCollection`.
        capstyle: Forwarded to `~matplotlib.collections.LineCollection`.
        **kwargs: Forwarded to `~matplotlib.collections.LineCollection`.

    Returns:
        The line collection added to `ax`.
    """
    mvp = camera.MVP
    # Translate by the camera position after projecting (keep origin).
    transform = translate3d(*camera.position).dot(mvp)

    def project(points):
        # Homogeneous column-vector transform; keep only the 2D screen coords.
        homogeneous = to_homogeneous(points, 1)
        return np.dot(transform, homogeneous.T).T[:, 0:2]

    projected_starts = project(lines[:, 0])
    projected_ends = project(lines[:, 1])
    segments = np.stack([projected_starts, projected_ends], axis=1)

    collection = LineCollection(  # type: ignore
        segments, joinstyle=joinstyle, capstyle=capstyle, **kwargs
    )
    return ax.add_collection(collection)  # type: ignore
51
+
52
+
def draw_direction_indicator(
    ax: Axes, camera: Camera, loc: tuple[float, float]
) -> None:
    """Draw labelled x/y/z arrows showing how the world axes appear in the view.

    Args:
        ax: The plot axes.
        camera: Camera whose model-view matrix (`camera.MV`) orients the arrows.
        loc: Origin of the arrows, in axes coordinates (`ax.transAxes`).
    """
    x, y = loc

    # Homogeneous points, one per row: unit x, unit y, unit z, and the origin.
    points = np.array(
        [
            [1, 0, 0, 1],
            [0, 1, 0, 1],
            [0, 0, 1, 1],
            [0, 0, 0, 1],
        ],
        dtype=np.float64,
    )
    # Transform each point as a column vector (MV @ p), the same convention
    # `draw_lines` uses. Subtracting the transformed origin keeps the camera
    # translation from skewing the arrow directions.
    transformed = np.dot(camera.MV, points.T).T
    directions = transformed[:3, :3] - transformed[3, :3]

    arrow_length, text_offset = 0.05, 0.05  # TODO: may still overlap
    text_colors = [["x", "red"], ["y", "green"], ["z", "blue"]]
    for (dx, dy, dz), (text, color) in zip(directions, text_colors):
        # Skip an axis that points (almost) straight into or out of the screen.
        if 1 - abs(dz) < 1e-5:
            continue

        ax.arrow(
            x,
            y,
            arrow_length * dx,
            arrow_length * dy,
            head_length=0.02,
            head_width=0.01,
            color=color,
            transform=ax.transAxes,
        )

        ax.text(
            x + (arrow_length + text_offset) * dx,
            y + (arrow_length + text_offset) * dy,
            text,
            color=color,
            transform=ax.transAxes,
            horizontalalignment="center",
            verticalalignment="center",
        )
92
+
93
+
def draw_circles(
    ax: Axes,
    x: npt.NDArray,
    y: npt.NDArray,
    *,
    y_min: float | None = None,
    y_max: float | None = None,
    cmap: str | Colormap = "viridis",
) -> PatchCollection:
    """Draw a sequence of circles centered at the origin, colored by value.

    Args:
        ax: The plot axes.
        x: Radius of each circle.
        y: Value of each circle, mapped to its color through `cmap`.
        y_min: Lower bound of the color normalization; defaults to `y.min()`.
        y_max: Upper bound of the color normalization; defaults to `y.max()`.
        cmap: A colormap, or the name of a registered colormap.

    Returns:
        The patch collection added to `ax`.
    """
    y_min = y.min() if y_min is None else y_min
    y_max = y.max() if y_max is None else y_max
    norm = Normalize(y_min, y_max)

    # `matplotlib.cm.get_cmap` was removed in Matplotlib 3.9; `pyplot.get_cmap`
    # is the supported lookup by name.
    color_map = cmap if isinstance(cmap, Colormap) else plt.get_cmap(cmap)
    colors = color_map(norm(y))

    # Circles are added in reverse input order. Presumably x is increasing, so
    # larger circles go first and smaller ones stay visible on top.
    circles = [
        Circle((0, 0), radius, color=color)
        for radius, color in reversed(list(zip(x, colors)))
    ]
    patches = PatchCollection(circles, match_original=True)
    patches.set_cmap(color_map)
    patches.set_norm(norm)
    patches: PatchCollection = ax.add_collection(patches)  # type: ignore

    ax.set_aspect(1)
    ax.autoscale()
    return patches
123
+
124
+
def get_fig_ax(
    fig: Figure | None = None, ax: Axes | None = None
) -> tuple[Figure, Axes]:
    """Resolve a (figure, axes) pair, filling in whichever is missing.

    Args:
        fig: Figure to use. If omitted, the figure owning `ax` is used, or the
            current pyplot figure if `ax` is omitted too.
        ax: Axes to use. If omitted, the current axes of the resolved figure
            are used.

    Returns:
        The resolved `(fig, ax)` pair.
    """
    if fig is None and ax is not None:
        fig = ax.get_figure()
        assert fig is not None, "expecting a figure from the axes"

    if fig is None:
        fig = plt.gcf()

    if ax is None:
        # Take the axes from the resolved figure rather than pyplot's current
        # figure, so a caller-supplied `fig` gets its own axes back.
        ax = fig.gca()

    return fig, ax
@@ -0,0 +1,35 @@
1
+ # SPDX-FileCopyrightText: 2022 - 2025 Zexin Yuan <pypi@yzx9.xyz>
2
+ #
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ """3D Plotting utils."""
6
+
7
+ import numpy as np
8
+ import numpy.typing as npt
9
+ from mpl_toolkits.mplot3d import Axes3D
10
+ from mpl_toolkits.mplot3d.art3d import Line3DCollection
11
+
__all__ = [
    "draw_lines_3d",
]
13
+
14
+
def draw_lines_3d(
    ax: Axes3D,
    lines: npt.NDArray[np.floating],
    joinstyle="round",
    capstyle="round",
    **kwargs,
):
    """Draw 3D line segments on 3D axes.

    Args:
        ax: The 3D plot axes.
        lines: Line segments as an array of shape (N, 2, 3). Axis 1 holds the
            two end points of each segment; axis 2 holds their (x, y, z)
            coordinates.
        joinstyle: Forwarded to `~mpl_toolkits.mplot3d.art3d.Line3DCollection`.
        capstyle: Forwarded to `~mpl_toolkits.mplot3d.art3d.Line3DCollection`.
        **kwargs: Forwarded to `~mpl_toolkits.mplot3d.art3d.Line3DCollection`.

    Returns:
        Whatever `ax.add_collection3d` returns for the new collection.
    """
    segments = Line3DCollection(
        lines,
        capstyle=capstyle,
        joinstyle=joinstyle,
        **kwargs,
    )
    return ax.add_collection3d(segments)
@@ -0,0 +1,145 @@
1
+ # SPDX-FileCopyrightText: 2022 - 2025 Zexin Yuan <pypi@yzx9.xyz>
2
+ #
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ """Rendering related utils."""
6
+
7
+ from functools import cached_property
8
+ from typing import Literal, cast
9
+
10
+ import numpy as np
11
+ import numpy.typing as npt
12
+ from typing_extensions import Self
13
+
14
+ from swcgeom.utils.transforms import (
15
+ Vec3f,
16
+ model_view_transformation,
17
+ orthographic_projection_simple,
18
+ )
19
+
__all__ = [
    "CameraOptions",
    "Camera",
    "SimpleCamera",
    "palette",
]

# A camera described by vectors (see `SimpleCamera.from_options`).
CameraOption = Vec3f | tuple[Vec3f, Vec3f] | tuple[Vec3f, Vec3f, Vec3f]

# Named view presets; each maps to a `(position, look_at, up)` triple that is
# passed positionally to `SimpleCamera`.
CameraPreset = Literal["xy", "yz", "zx", "yx", "zy", "xz"]
CameraPresets: dict[CameraPreset, tuple[Vec3f, Vec3f, Vec3f]] = {
    #      position         look_at           up
    "xy": ((0.0, 0.0, 0.0), (0.0, 0.0, -1.0), (0.0, 1.0, 0.0)),
    "yz": ((0.0, 0.0, 0.0), (-1.0, 0.0, 0.0), (0.0, 0.0, 1.0)),
    "zx": ((0.0, 0.0, 0.0), (0.0, -1.0, 0.0), (1.0, 0.0, 0.0)),
    "yx": ((0.0, 0.0, 0.0), (0.0, 0.0, -1.0), (0.0, -1.0, 0.0)),
    "zy": ((0.0, 0.0, 0.0), (-1.0, 0.0, 0.0), (0.0, 0.0, -1.0)),
    "xz": ((0.0, 0.0, 0.0), (0.0, -1.0, 0.0), (-1.0, 0.0, 0.0)),
}

# Anything accepted by `SimpleCamera.from_options`.
CameraOptions = CameraOption | CameraPreset
33
+
34
+
class Camera:
    """Base camera holding a position, a look-at target and an up vector.

    Subclasses provide the model-view matrix `MV` and the projection matrix
    `P`; `MVP` is derived from the two.
    """

    _position: Vec3f
    _look_at: Vec3f
    _up: Vec3f

    @property
    def position(self) -> Vec3f:
        """Camera position."""
        return self._position

    @property
    def look_at(self) -> Vec3f:
        """Look-at target of the camera."""
        return self._look_at

    @property
    def up(self) -> Vec3f:
        """Up vector of the camera."""
        return self._up

    @property
    def MV(self) -> npt.NDArray[np.float32]:
        """Model-view matrix; must be provided by a subclass."""
        raise NotImplementedError()

    @property
    def P(self) -> npt.NDArray[np.float32]:
        """Projection matrix; must be provided by a subclass."""
        raise NotImplementedError()

    @property
    def MVP(self) -> npt.NDArray[np.float32]:
        """Combined matrix, computed as ``P @ MV``."""
        projection = self.P
        return np.dot(projection, self.MV)
63
+
64
+
class SimpleCamera(Camera):
    """Camera defined by a position, a look-at target and an up vector.

    The model-view matrix comes from `model_view_transformation` and the
    projection from `orthographic_projection_simple`; each is computed on
    first access and cached.
    """

    def __init__(self, position: Vec3f, look_at: Vec3f, up: Vec3f):
        self._position = position
        self._look_at = look_at
        self._up = up

    @cached_property
    def MV(self) -> npt.NDArray[np.float32]:  # pylint: disable=invalid-name
        """Model-view matrix built from position, look-at target and up."""
        return model_view_transformation(self.position, self.look_at, self.up)

    @cached_property
    def P(self) -> npt.NDArray[np.float32]:  # pylint: disable=invalid-name
        """Projection matrix from `orthographic_projection_simple`."""
        return orthographic_projection_simple()

    @classmethod
    def from_options(cls, camera: CameraOptions) -> Self:
        """Build a camera from a preset name or from vectors.

        Accepted forms:
            - a preset name, i.e. a key of `CameraPresets` such as ``"xy"``;
            - ``look_at``: one 3D vector; position is the origin, up is +y;
            - ``(look_at, up)``: position is the origin;
            - ``(position, look_at, up)``.

        Raises:
            KeyError: If a preset name is unknown.
        """
        if isinstance(camera, str):
            return cls(*CameraPresets[camera])

        if len(camera) == 2:
            return cls((0, 0, 0), camera[0], camera[1])

        # A single vector has scalar components; a triple of vectors does not.
        if np.ndim(camera[0]) == 0:
            return cls((0, 0, 0), cast(Vec3f, camera), (0, 1, 0))

        return cls(*cast(tuple[Vec3f, Vec3f, Vec3f], camera))
93
+
94
+
class Palette:
    """Color tables: a default palette and the Vaa3D color matching.

    Both tables are dicts mapping an integer index to a hex color string.
    """

    # pylint: disable=too-few-public-methods

    default: dict[int, str]
    vaa3d: dict[int, str]

    def __init__(self):
        default_colors = (
            "#F596AA",  # momo
            "#867835",  # kimirucha
            "#E2943B",  # kuchiba
            "#00896C",  # aotake
            "#B9887D",  # mizugaki
            "#2EA9DF",  # tsuyukusa
            "#66327C",  # sumire
            "#52433D",  # benikeshinezumi
        )
        self.default = {index: color for index, color in enumerate(default_colors)}

        vaa3d_colors = (
            "#ffffff",  # 0: white, undefined
            "#141414",  # 1: black, soma
            "#c81400",  # 2: red, axon
            "#0014c8",  # 3: blue, dendrite
            "#c800c8",  # 4: purple, apical dendrite
            # Hanchuan's extended colors (090331).
            "#00c8c8",  # 5: cyan
            "#dcc800",  # 6: yellow
            "#00c814",  # 7: green
            "#bc5e25",  # 8: coffee
            "#b4c878",  # 9: asparagus
            "#fa6478",  # 10: salmon
            "#78c8c8",  # 11: ice
            "#6478c8",  # 12: orchid
            # Hanchuan's further extended colors (111003).
            "#ff80a8",  # 13
            "#80ffa8",  # 14
            "#80a8ff",  # 15
            "#a8ff80",  # 16
            "#ffa880",  # 17
            "#a880ff",  # 18
            "#000000",  # 19: totally black (PHC, 2012-02-15)
            # Vaa3D reserves 20-275 for a matlab heat map (120209, WYN); only
            # the first entry is listed here.
            "#000083",  # 20
        )
        self.vaa3d = {index: color for index, color in enumerate(vaa3d_colors)}
143
+
144
+
# Shared module-level palette instance.
palette: Palette = Palette()
swcgeom/utils/sdf.py ADDED
@@ -0,0 +1,324 @@
1
+
2
+ # SPDX-FileCopyrightText: 2022 - 2025 Zexin Yuan <pypi@yzx9.xyz>
3
+ #
4
+ # SPDX-License-Identifier: Apache-2.0
5
+
6
+ """Signed distance functions.
7
+
8
+ Refs: https://iquilezles.org/articles/distfunctions/
9
+
10
+ NOTE: This module has been deprecated since v0.14.0, and will be removed in the future,
11
+ use `sdflit` instead.
12
+ """
13
+
14
+ import warnings
15
+ from abc import ABC, abstractmethod
16
+ from collections.abc import Iterable
17
+
18
+ import numpy as np
19
+ import numpy.typing as npt
20
+ from typing_extensions import deprecated
21
+
22
+ from swcgeom.utils.solid_geometry import project_vector_on_plane
23
+
__all__ = [
    "SDF",
    "SDFUnion", "SDFIntersection", "SDFDifference", "SDFCompose",
    "SDFSphere", "SDFFrustumCone", "SDFRoundCone",
]

# Axis-aligned bounding box as a `(lower, upper)` pair of corner arrays, each
# of shape (3,).
AABB = tuple[npt.NDArray[np.float32], npt.NDArray[np.float32]]
37
+
38
+
class SDF(ABC):
    """Base class of signed distance functions.

    Subclasses implement `distance`. Points whose distance is <= 0 count as
    inside (see `is_in`). An optional axis-aligned `bounding_box` lets `is_in`
    skip evaluating points that lie outside it.
    """

    bounding_box: AABB | None = None

    def __call__(self, p: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]:
        return self.distance(p)

    @abstractmethod
    def distance(self, p: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]:
        """Calculate signed distance.

        Args:
            p: Hit points of shape (N, 3).

        Returns:
            distance: Distance array of shape (N,).
        """
        raise NotImplementedError()

    def is_in(self, p: npt.NDArray[np.float32]) -> npt.NDArray[np.bool_]:
        """Test which points lie inside (or on) the shape.

        Args:
            p: Points of shape (N, 3); converted to float32.

        Returns:
            is_in: Boolean array of shape (N,).
        """
        p = np.array(p, dtype=np.float32)
        assert p.ndim == 2 and p.shape[1] == 3, "p should be array of shape (N, 3)"

        candidates = self.is_in_bounding_box(p)
        inside = np.zeros(p.shape[0], dtype=np.bool_)
        # Only evaluate the distance for points inside the bounding box.
        inside[candidates] = self.distance(p[candidates]) <= 0
        return inside

    def is_in_bounding_box(self, p: npt.NDArray[np.float32]) -> npt.NDArray[np.bool_]:
        """Test which points lie inside the bounding box (bounds inclusive).

        Returns:
            is_in: Array of shape (N,).
                If bounding box is `None`, `True` will be returned.
        """
        box = self.bounding_box
        if box is None:
            return np.ones(p.shape[0], dtype=np.bool_)

        above_lower = np.all(p >= box[0], axis=1)
        below_upper = np.all(p <= box[1], axis=1)
        return np.logical_and(above_lower, below_upper)
84
+
85
+
class SDFUnion(SDF):
    """Union of multiple SDFs: the distance is the minimum over members."""

    def __init__(self, *sdfs: SDF) -> None:
        assert len(sdfs) != 0, "must combine at least one SDF"
        super().__init__()

        self.sdfs = sdfs

        boxes = [sdf.bounding_box for sdf in sdfs if sdf.bounding_box]
        # The union is only bounded when every member is bounded.
        if len(boxes) == len(sdfs):
            lowers = np.stack([box[0] for box in boxes])
            uppers = np.stack([box[1] for box in boxes])
            self.bounding_box = (lowers.min(axis=0), uppers.max(axis=0))

    def distance(self, p: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]:
        member_distances = [member(p) for member in self.sdfs]
        return np.min(member_distances, axis=0)

    def is_in(self, p: npt.NDArray[np.float32]) -> npt.NDArray[np.bool_]:
        """A point is inside the union if it is inside any member."""
        p = np.array(p, dtype=np.float32)
        assert p.ndim == 2 and p.shape[1] == 3, "p should be array of shape (N, 3)"

        in_box = self.is_in_bounding_box(p)
        subset = p[in_box]
        member_flags = np.stack([member.is_in(subset) for member in self.sdfs])
        result = np.zeros_like(in_box, dtype=np.bool_)
        result[in_box] = member_flags.any(axis=0)
        return result
115
+
116
+
class SDFIntersection(SDF):
    """Intersection of multiple SDFs: the distance is the maximum over members."""

    def __init__(self, *sdfs: SDF) -> None:
        assert len(sdfs) != 0, "must intersect at least one SDF"
        super().__init__()
        self.sdfs = sdfs

        # The intersection lies inside every member, so the overlap of whichever
        # member boxes exist bounds it. The overlap may be empty (lower > upper);
        # `is_in_bounding_box` then rejects every point.
        bounding_boxes = [sdf.bounding_box for sdf in self.sdfs if sdf.bounding_box]
        if bounding_boxes:
            self.bounding_box = (
                np.max(np.stack([box[0] for box in bounding_boxes]), axis=0),
                np.min(np.stack([box[1] for box in bounding_boxes]), axis=0),
            )

    def distance(self, p: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]:
        """Signed distance of shape (N,): per-point maximum over all members."""
        # The stack has shape (num_sdfs, N); reduce over the SDF axis, not the
        # point axis.
        distances = np.stack([sdf.distance(p) for sdf in self.sdfs])
        return np.max(distances, axis=0)

    def is_in(self, p: npt.NDArray[np.float32]) -> npt.NDArray[np.bool_]:
        """A point is inside the intersection if it is inside every member."""
        p = np.array(p, dtype=np.float32)
        assert p.ndim == 2 and p.shape[1] == 3, "p should be array of shape (N, 3)"

        in_box = self.is_in_bounding_box(p)
        p_in_box = p[in_box]
        is_in = np.stack([sdf.is_in(p_in_box) for sdf in self.sdfs])
        flags = np.full_like(in_box, False, dtype=np.bool_)
        flags[in_box] = np.all(is_in, axis=0)
        return flags
144
+
145
+
class SDFDifference(SDF):
    """Difference of two SDFs, A - B: points inside A but not inside B."""

    def __init__(self, sdf_a: SDF, sdf_b: SDF) -> None:
        super().__init__()
        self.sdf_a = sdf_a
        self.sdf_b = sdf_b

        # A - B is contained in A, so A's bounding box (if any) also bounds it.
        self.bounding_box = sdf_a.bounding_box

    def distance(self, p: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]:
        return np.maximum(self.sdf_a.distance(p), -self.sdf_b.distance(p))

    def is_in(self, p: npt.NDArray[np.float32]) -> npt.NDArray[np.bool_]:
        p = np.array(p, dtype=np.float32)
        assert p.ndim == 2 and p.shape[1] == 3, "p should be array of shape (N, 3)"

        in_box = self.is_in_bounding_box(p)
        subset = p[in_box]
        inside_a = self.sdf_a.is_in(subset)
        inside_b = self.sdf_b.is_in(subset)
        result = np.zeros_like(in_box, dtype=np.bool_)
        result[in_box] = inside_a & ~inside_b
        return result
172
+
173
+
174
+ @deprecated("Use `SDFUnion` instead")
175
+ class SDFCompose(SDFUnion):
176
+ """Compose multiple SDFs.
177
+
178
+ .. deprecated:: 0.14.0
179
+ Use :cls:`SDFUnion` instead.
180
+ """
181
+
182
+ def __init__(self, sdfs: Iterable[SDF]) -> None:
183
+ sdfs = list(sdfs)
184
+ if len(sdfs) == 1:
185
+ warnings.warn("compose only one SDF, use SDFCompose.compose instead")
186
+
187
+ super().__init__(*sdfs)
188
+
189
+ @staticmethod
190
+ def compose(sdfs: Iterable[SDF]) -> SDF:
191
+ sdfs = list(sdfs)
192
+ return SDFCompose(sdfs) if len(sdfs) != 1 else sdfs[0]
193
+
194
+
class SDFSphere(SDF):
    """SDF of a sphere given by its center and radius."""

    def __init__(self, center: npt.ArrayLike, radius: float) -> None:
        super().__init__()
        self.center = np.array(center)
        self.radius = radius
        assert tuple(self.center.shape) == (3,), "center should be vector of 3d"

        lower = self.center - self.radius
        upper = self.center + self.radius
        self.bounding_box = (lower, upper)

    def distance(self, p: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]:
        offsets = p - self.center
        return np.linalg.norm(offsets, axis=1) - self.radius
208
+
209
+
class SDFFrustumCone(SDF):
    """SDF of a frustum cone (capped cone).

    The axis runs from point `a` (cap radius `ra`) to point `b` (cap radius
    `rb`).
    """

    def __init__(
        self, a: npt.ArrayLike, b: npt.ArrayLike, ra: float, rb: float
    ) -> None:
        super().__init__()
        self.a = np.array(a)
        self.b = np.array(b)
        self.ra = ra
        self.rb = rb
        assert tuple(self.a.shape) == (3,), "a should be vector of 3d"
        assert tuple(self.b.shape) == (3,), "b should be vector of 3d"

        self.bounding_box = self.get_bounding_box()

    def distance(self, p: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]:
        """Signed distance of points `p` of shape (N, 3); returns shape (N,)."""
        a, b, ra, rb = self.a, self.b, self.ra, self.rb

        rba = rb - ra
        baba = np.dot(b - a, b - a)
        papa = np.einsum("ij,ij->i", p - a, p - a)
        paba = np.dot(p - a, b - a) / baba
        # maybe negative due to numerical error
        x = np.sqrt(np.maximum(papa - paba * paba * baba, 0))
        cax = np.maximum(0.0, x - np.where(paba < 0.5, ra, rb))
        cay = np.abs(paba - 0.5) - 0.5
        k = rba * rba + baba
        f = np.clip((rba * (x - ra) + paba * baba) / k, 0.0, 1.0)
        cbx = x - ra - f * rba
        cby = paba - f
        s = np.where(np.logical_and(cbx < 0.0, cay < 0.0), -1.0, 1.0)
        return s * np.sqrt(
            np.minimum(cax * cax + cay * cay * baba, cbx * cbx + cby * cby * baba)
        )

    def get_bounding_box(self) -> AABB | None:
        """Axis-aligned bounding box of the frustum.

        The frustum is the convex hull of its two end disks, so its box is the
        union of the disks' boxes. A disk of radius r with unit normal n extends
        r * sqrt(1 - n_i**2) along axis i.
        """
        a, b, ra, rb = self.a, self.b, self.ra, self.rb
        axis = np.asarray(b - a, dtype=np.float64)
        length2 = float(np.dot(axis, axis))
        if length2 > 0:
            # clip guards against tiny negative values from rounding
            extent = np.sqrt(np.clip(1.0 - axis * axis / length2, 0.0, 1.0))
        else:
            # Degenerate frustum (a == b): fall back to a cube of half-size r.
            extent = np.ones(3)

        lower = np.minimum(a - extent * ra, b - extent * rb)
        upper = np.maximum(a + extent * ra, b + extent * rb)
        return lower.astype(np.float32), upper.astype(np.float32)
260
+
261
+
class SDFRoundCone(SDF):
    """Round cone: two spheres joined by the cone tangent to both."""

    def __init__(
        self, a: npt.ArrayLike, b: npt.ArrayLike, ra: float, rb: float
    ) -> None:
        """SDF of round cone.

        Args:
            a, b: Centers of spheres A/B, each of shape (3,).
            ra, rb: Radius of sphere A/B.
        """
        super().__init__()
        self.a = np.array(a, dtype=np.float32)
        self.b = np.array(b, dtype=np.float32)
        self.ra = ra
        self.rb = rb

        assert tuple(self.a.shape) == (3,), "a should be vector of 3d"
        assert tuple(self.b.shape) == (3,), "b should be vector of 3d"

        # The shape is the convex hull of the two spheres, so the union of the
        # spheres' boxes bounds it exactly.
        self.bounding_box = (
            np.min([self.a - self.ra, self.b - self.rb], axis=0).astype(np.float32),
            np.max([self.a + self.ra, self.b + self.rb], axis=0).astype(np.float32),
        )

    def distance(self, p: npt.ArrayLike) -> npt.NDArray[np.float32]:
        """Signed distance of points `p` of shape (N, 3); returns shape (N,)."""
        # pylint: disable=too-many-locals
        p = np.array(p, dtype=np.float32)
        assert p.ndim == 2 and p.shape[1] == 3, "p should be array of shape (N, 3)"

        a = self.a
        b = self.b
        ra = self.ra
        rb = self.rb

        # sampling independent computations (only depend on shape)
        ba = b - a
        l2 = np.dot(ba, ba)
        rr = ra - rb
        a2 = l2 - rr * rr
        il2 = 1.0 / l2

        # sampling dependent computations
        pa = p - a
        y = np.dot(pa, ba)
        z = y - l2
        x = pa * l2 - np.outer(y, ba)
        x2 = np.sum(x * x, axis=1)
        y2 = y * y * l2
        z2 = z * z * l2

        # single square root!
        k = np.sign(rr) * rr * rr * x2
        # default: distance to the conical side
        dis = (np.sqrt(x2 * a2 * il2) + y * rr) * il2 - ra

        # closest feature is sphere A
        rt = np.sign(y) * a2 * y2 < k
        dis[rt] = np.sqrt(x2[rt] + y2[rt]) * il2 - ra

        # closest feature is sphere B; applied last so it takes precedence over
        # the sphere-A branch, matching the reference formula's branch order
        lt = np.sign(z) * a2 * z2 > k
        dis[lt] = np.sqrt(x2[lt] + z2[lt]) * il2 - rb

        return dis