swcgeom 0.18.1__py3-none-any.whl → 0.19.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of swcgeom might be problematic. Click here for more details.

Files changed (68)
  1. swcgeom/__init__.py +12 -1
  2. swcgeom/analysis/__init__.py +6 -6
  3. swcgeom/analysis/feature_extractor.py +22 -24
  4. swcgeom/analysis/features.py +18 -40
  5. swcgeom/analysis/lmeasure.py +227 -323
  6. swcgeom/analysis/sholl.py +17 -23
  7. swcgeom/analysis/trunk.py +23 -28
  8. swcgeom/analysis/visualization.py +37 -44
  9. swcgeom/analysis/visualization3d.py +16 -25
  10. swcgeom/analysis/volume.py +33 -47
  11. swcgeom/core/__init__.py +12 -13
  12. swcgeom/core/branch.py +10 -17
  13. swcgeom/core/branch_tree.py +3 -2
  14. swcgeom/core/compartment.py +1 -1
  15. swcgeom/core/node.py +3 -6
  16. swcgeom/core/path.py +11 -16
  17. swcgeom/core/population.py +32 -51
  18. swcgeom/core/swc.py +25 -16
  19. swcgeom/core/swc_utils/__init__.py +10 -12
  20. swcgeom/core/swc_utils/assembler.py +5 -12
  21. swcgeom/core/swc_utils/base.py +40 -31
  22. swcgeom/core/swc_utils/checker.py +3 -8
  23. swcgeom/core/swc_utils/io.py +32 -47
  24. swcgeom/core/swc_utils/normalizer.py +17 -23
  25. swcgeom/core/swc_utils/subtree.py +13 -20
  26. swcgeom/core/tree.py +61 -51
  27. swcgeom/core/tree_utils.py +36 -49
  28. swcgeom/core/tree_utils_impl.py +4 -6
  29. swcgeom/images/__init__.py +2 -2
  30. swcgeom/images/augmentation.py +23 -39
  31. swcgeom/images/contrast.py +22 -46
  32. swcgeom/images/folder.py +32 -34
  33. swcgeom/images/io.py +80 -121
  34. swcgeom/transforms/__init__.py +13 -13
  35. swcgeom/transforms/base.py +28 -19
  36. swcgeom/transforms/branch.py +31 -41
  37. swcgeom/transforms/branch_tree.py +3 -1
  38. swcgeom/transforms/geometry.py +13 -4
  39. swcgeom/transforms/image_preprocess.py +2 -0
  40. swcgeom/transforms/image_stack.py +40 -35
  41. swcgeom/transforms/images.py +31 -24
  42. swcgeom/transforms/mst.py +27 -40
  43. swcgeom/transforms/neurolucida_asc.py +13 -13
  44. swcgeom/transforms/path.py +4 -0
  45. swcgeom/transforms/population.py +4 -0
  46. swcgeom/transforms/tree.py +16 -11
  47. swcgeom/transforms/tree_assembler.py +37 -54
  48. swcgeom/utils/__init__.py +12 -12
  49. swcgeom/utils/download.py +7 -14
  50. swcgeom/utils/dsu.py +12 -0
  51. swcgeom/utils/ellipse.py +26 -14
  52. swcgeom/utils/file.py +8 -13
  53. swcgeom/utils/neuromorpho.py +78 -92
  54. swcgeom/utils/numpy_helper.py +15 -12
  55. swcgeom/utils/plotter_2d.py +10 -16
  56. swcgeom/utils/plotter_3d.py +7 -9
  57. swcgeom/utils/renderer.py +16 -8
  58. swcgeom/utils/sdf.py +12 -23
  59. swcgeom/utils/solid_geometry.py +58 -2
  60. swcgeom/utils/transforms.py +164 -100
  61. swcgeom/utils/volumetric_object.py +29 -53
  62. {swcgeom-0.18.1.dist-info → swcgeom-0.19.0.dist-info}/METADATA +7 -6
  63. swcgeom-0.19.0.dist-info/RECORD +67 -0
  64. {swcgeom-0.18.1.dist-info → swcgeom-0.19.0.dist-info}/WHEEL +1 -1
  65. swcgeom/_version.py +0 -16
  66. swcgeom-0.18.1.dist-info/RECORD +0 -68
  67. {swcgeom-0.18.1.dist-info → swcgeom-0.19.0.dist-info/licenses}/LICENSE +0 -0
  68. {swcgeom-0.18.1.dist-info → swcgeom-0.19.0.dist-info}/top_level.txt +0 -0
swcgeom/utils/neuromorpho.py CHANGED
@@ -15,10 +15,7 @@
15
15
 
16
16
  """NeuroMorpho.org.
17
17
 
18
- Examples
19
- --------
20
-
21
- Metadata:
18
+ Metadata Example:
22
19
 
23
20
  ```json
24
21
  {
@@ -80,9 +77,7 @@ Metadata:
80
77
  }
81
78
  ```
82
79
 
83
- Notes
84
- -----
85
- All denpendencies need to be installed, try:
80
+ NOTE: All denpendencies need to be installed, try:
86
81
 
87
82
  ```sh
88
83
  pip install swcgeom[all]
@@ -97,7 +92,7 @@ import math
97
92
  import os
98
93
  import urllib.parse
99
94
  from collections.abc import Callable, Iterable
100
- from typing import Any, Literal, Optional
95
+ from typing import Any, Literal
101
96
 
102
97
  from tqdm import tqdm
103
98
 
@@ -140,7 +135,7 @@ DOWNLOAD_CONFIGS: dict[RESOURCES, tuple[str, int]] = {
140
135
  "log_source": (URL_LOG_SOURCE, 512 * GB),
141
136
  }
142
137
 
143
- # fmt:off
138
+ # fmt: off
144
139
  # Test version: 8.5.25 (2023-08-01)
145
140
  # No ETAs for future version
146
141
  invalid_ids = [
@@ -166,7 +161,7 @@ def neuromorpho_is_valid(metadata: dict[str, Any]) -> bool:
166
161
 
167
162
 
168
163
  def neuromorpho_convert_lmdb_to_swc(
169
- root: str, dest: Optional[str] = None, *, verbose: bool = False, **kwargs
164
+ root: str, dest: str | None = None, *, verbose: bool = False, **kwargs
170
165
  ) -> None:
171
166
  nmo = NeuroMorpho(root, verbose=verbose)
172
167
  nmo.convert_lmdb_to_swc(dest, **kwargs)
@@ -182,11 +177,9 @@ class NeuroMorpho:
182
177
  self, root: str, *, url_base: str = URL_BASE, verbose: bool = False
183
178
  ) -> None:
184
179
  """
185
- Parameters
186
- ----------
187
- root : str
188
- verbose : bool, default False
189
- Show verbose log.
180
+ Args:
181
+ root: str
182
+ verbose: Show verbose log.
190
183
  """
191
184
 
192
185
  super().__init__()
@@ -252,34 +245,15 @@ class NeuroMorpho:
252
245
  # pylint: disable-next=too-many-locals
253
246
  def convert_lmdb_to_swc(
254
247
  self,
255
- dest: Optional[str] = None,
248
+ dest: str | None = None,
256
249
  *,
257
- group_by: Optional[str | Callable[[dict[str, Any]], str | None]] = None,
258
- where: Optional[Callable[[dict[str, Any]], bool]] = None,
250
+ group_by: str | Callable[[dict[str, Any]], str | None] | None = None,
251
+ where: Callable[[dict[str, Any]], bool] | None = None,
259
252
  encoding: str | None = "utf-8",
260
253
  ) -> None:
261
254
  r"""Convert lmdb format to SWCs.
262
255
 
263
- Parameters
264
- ----------
265
- path : str
266
- dest : str, optional
267
- If None, use `path/swc`.
268
- group_by : str | (metadata: dict[str, Any]) -> str | None, optional
269
- Group neurons by metadata. If a None is returned then no
270
- grouping. If a string is entered, use it as a metadata
271
- attribute name for grouping, e.g.: `archive`, `species`.
272
- where : (metadata: dict[str, Any]) -> bool, optional
273
- Filter neurons by metadata.
274
- encoding : str | None, default to `utf-8`
275
- Change swc encoding, part of the original data is not utf-8
276
- encoded. If is None, keep the original encoding format.
277
- verbose : bool, default False
278
- Print verbose info.
279
-
280
- Notes
281
- -----
282
- We are asserting the following folder.
256
+ NOTE: We are asserting the following folder.
283
257
 
284
258
  ```text
285
259
  |- root
@@ -289,10 +263,23 @@ class NeuroMorpho:
289
263
  | | |- groups # output of groups if grouped
290
264
  ```
291
265
 
292
- See Also
293
- --------
294
- neuromorpho_is_valid :
295
- Recommended filter function, try `where=neuromorpho_is_valid`
266
+ Args:
267
+ path: str
268
+ dest: If None, use `path/swc`.
269
+ group_by: Group neurons by metadata.
270
+ If None, no grouping. If a string is entered, use it as a metadata
271
+ attribute name for grouping, e.g.: `archive`, `species`. If a callable
272
+ is entered, use it as a function `(metadata: dict[str, Any]) -> str | None\
273
+ to get the group name.
274
+ where: Filter neurons by metadata.
275
+ (metadata: dict[str, Any]) -> bool
276
+ encoding: Change swc encoding, part of the original data is not utf-8 encoded.
277
+ If is None, keep the original encoding format.default to `utf-8`
278
+ verbose: Print verbose info.
279
+
280
+ See Also:
281
+ neuromorpho_is_valid:
282
+ Recommended filter function, try `where=neuromorpho_is_valid`
296
283
  """
297
284
 
298
285
  import lmdb
@@ -302,9 +289,19 @@ class NeuroMorpho:
302
289
  where = where or (lambda _: True)
303
290
  if isinstance(group_by, str):
304
291
  key = group_by
305
- group_by = lambda v: v[key] # pylint: disable=unnecessary-lambda-assignment
292
+
293
+ def group_by_key(v):
294
+ return v[key]
295
+
296
+ group_by = group_by_key
297
+
306
298
  elif group_by is None:
307
- group_by = lambda _: None # pylint: disable=unnecessary-lambda-assignment
299
+
300
+ def no_group(v):
301
+ return None
302
+
303
+ group_by = no_group
304
+
308
305
  items = []
309
306
  for k, v in tx_m.cursor():
310
307
  metadata = json.loads(v)
@@ -336,9 +333,9 @@ class NeuroMorpho:
336
333
 
337
334
  if encoding is None:
338
335
  with open(fs, "wb") as f:
339
- f.write(bs) # type: ignore
336
+ f.write(bs)
340
337
  else:
341
- bs = io.BytesIO(bs) # type: ignore
338
+ bs = io.BytesIO(bs)
342
339
  with (
343
340
  open(fs, "w", encoding=encoding) as fw,
344
341
  FileReader(bs, encoding="detect") as fr,
@@ -355,27 +352,20 @@ class NeuroMorpho:
355
352
  self,
356
353
  path: str,
357
354
  *,
358
- pages: Optional[Iterable[int]] = None,
355
+ pages: Iterable[int] | None = None,
359
356
  page_size: int = API_PAGE_SIZE_MAX,
360
357
  **kwargs,
361
358
  ) -> list[int]:
362
359
  r"""Download all neuron metadata.
363
360
 
364
- Parameters
365
- ----------
366
- path : str
367
- Path to save data.
368
- pages : List of int, optional
369
- If is None, download all pages.
370
- verbose : bool, default False
371
- Show verbose log.
372
- **kwargs :
373
- Forwarding to `get`.
374
-
375
- Returns
376
- -------
377
- err_pages : List of int
378
- Failed pages.
361
+ Args:
362
+ path: Path to save data.
363
+ pages: If is None, download all pages.
364
+ verbose: Show verbose log.
365
+ **kwargs: Forwarding to `get`.
366
+
367
+ Returns:
368
+ err_pages: Failed pages.
379
369
  """
380
370
 
381
371
  # TODO: how to cache between versions?
@@ -410,32 +400,24 @@ class NeuroMorpho:
410
400
  path: str,
411
401
  path_metadata: str,
412
402
  *,
413
- keys: Optional[Iterable[bytes]] = None,
403
+ keys: Iterable[bytes] | None = None,
414
404
  override: bool = False,
415
405
  map_size: int = 512 * GB,
416
406
  **kwargs,
417
407
  ) -> list[bytes]:
418
408
  """Download files.
419
409
 
420
- Parameters
421
- ----------
422
- url : str
423
- path : str
424
- Path to save data.
425
- path_metadata : str
426
- Path to lmdb of metadata.
427
- keys : List of bytes, optional
428
- If exist, ignore `override` option. If None, download all key.
429
- override : bool, default False
430
- Override even exists.
431
- map_size : int, default 512GB
432
- **kwargs :
433
- Forwarding to `get`.
434
-
435
- Returns
436
- -------
437
- err_keys : List of str
438
- Failed keys.
410
+ Args:
411
+ url: URL of file.
412
+ path: Path to save data.
413
+ path_metadata: Path to lmdb of metadata.
414
+ keys: If exist, ignore `override` option. If None, download all key.
415
+ override: Override even exists, default to False
416
+ map_size: int, default 512GB
417
+ **kwargs: Forwarding to `get`.
418
+
419
+ Returns:
420
+ err_keys: Failed keys.
439
421
  """
440
422
 
441
423
  import lmdb
@@ -445,16 +427,16 @@ class NeuroMorpho:
445
427
  if keys is None:
446
428
  with env_m.begin() as tx_m:
447
429
  if override:
448
- keys = [k for k, v in tx_m.cursor()]
430
+ keys = [k for k, _ in tx_m.cursor()]
449
431
  else:
450
432
  with env_c.begin() as tx:
451
- keys = [k for k, v in tx_m.cursor() if tx.get(k) is None]
433
+ keys = [k for k, _ in tx_m.cursor() if tx.get(k) is None]
452
434
 
453
435
  err_keys = []
454
436
  for k in tqdm(keys) if self.verbose else keys:
455
437
  try:
456
438
  with env_m.begin() as tx:
457
- metadata = json.loads(tx.get(k).decode("utf-8")) # type: ignore
439
+ metadata = json.loads(tx.get(k).decode("utf-8"))
458
440
 
459
441
  swc = self._get_file(url, metadata, **kwargs)
460
442
  with env_c.begin(write=True) as tx:
@@ -485,10 +467,8 @@ class NeuroMorpho:
485
467
  def _get_file(self, url: str, metadata: dict[str, Any], **kwargs) -> bytes:
486
468
  """Get file.
487
469
 
488
- Returns
489
- -------
490
- bs : bytes
491
- Bytes of morphology file, encoding is NOT FIXED.
470
+ Returns:
471
+ bs: Bytes of morphology file, encoding is NOT FIXED.
492
472
  """
493
473
 
494
474
  archive = urllib.parse.quote(metadata["archive"].lower())
@@ -502,7 +482,7 @@ class NeuroMorpho:
502
482
  return self._get(url, **kwargs)
503
483
 
504
484
  def _get(
505
- self, url: str, *, timeout: int = 2 * 60, proxy: Optional[str] = None
485
+ self, url: str, *, timeout: int = 2 * 60, proxy: str | None = None
506
486
  ) -> bytes:
507
487
  if not url.startswith("http://") and not url.startswith("https://"):
508
488
  url = urllib.parse.urljoin(self.url_base, url)
@@ -529,9 +509,15 @@ class NeuroMorpho:
529
509
  self.ssl_context = ssl_context
530
510
  super().__init__(**kwargs)
531
511
 
532
- def init_poolmanager(self, connections, maxsize, block=False):
512
+ def init_poolmanager(
513
+ self, connections, maxsize, block=False, **pool_kwargs
514
+ ):
533
515
  super().init_poolmanager(
534
- connections, maxsize, block, ssl_context=self.ssl_context
516
+ connections,
517
+ maxsize,
518
+ block,
519
+ ssl_context=self.ssl_context,
520
+ **pool_kwargs,
535
521
  )
536
522
 
537
523
  def proxy_manager_for(self, proxy, **proxy_kwargs):
swcgeom/utils/numpy_helper.py CHANGED
@@ -32,18 +32,21 @@ def padding1d(
32
32
  ) -> npt.NDArray:
33
33
  """Padding x to array of shape (n,).
34
34
 
35
- Parameters
36
- ----------
37
- n : int
38
- Size of vector.
39
- v : np.ndarray, optional
40
- Input vector.
41
- padding_value : any, default to `0`.
42
- If x.shape[0] is less than n, the rest will be filled with
43
- padding value.
44
- dtype : np.DTypeLike, optional
45
- Data type of array. If specify, cast x to dtype, else dtype of
46
- x will used, otherwise defaults to `~numpy.float32`.
35
+ >>> padding1d(5, [1, 2, 3])
36
+ array([1., 2., 3., 0., 0.], dtype=float32)
37
+ >>> padding1d(5, [1, 2, 3], padding_value=6)
38
+ array([1., 2., 3., 6., 6.], dtype=float32)
39
+ >>> padding1d(5, [1, 2, 3], dtype=np.int64)
40
+ array([1, 2, 3, 0, 0])
41
+
42
+ Args:
43
+ n: Size of vector.
44
+ v: Input vector.
45
+ padding_value: Padding value.
46
+ If x.shape[0] is less than n, the rest will be filled with padding value.
47
+ dtype: Data type of array.
48
+ If specify, cast x to dtype, else dtype of x will used, otherwise defaults
49
+ to `~numpy.float32`.
47
50
  """
48
51
 
49
52
  if not isinstance(v, np.ndarray):
swcgeom/utils/plotter_2d.py CHANGED
@@ -15,8 +15,6 @@
15
15
 
16
16
  """2D Plotting utils."""
17
17
 
18
- from typing import Optional
19
-
20
18
  import matplotlib.pyplot as plt
21
19
  import numpy as np
22
20
  import numpy.typing as npt
@@ -43,18 +41,14 @@ def draw_lines(
43
41
  ) -> LineCollection:
44
42
  """Draw lines.
45
43
 
46
- Parameters
47
- ----------
48
- ax : ~matplotlib.axes.Axes
49
- lines : A collection of coords of lines
50
- Excepting a ndarray of shape (N, 2, 3), the axis-2 holds two points,
51
- and the axis-3 holds the coordinates (x, y, z).
52
- camera : Camera
53
- Camera position.
54
- **kwargs : dict[str, Unknown]
55
- Forwarded to `~matplotlib.collections.LineCollection`.
44
+ Args:
45
+ ax: The plot axes.
46
+ lines: A collection of coords of lines
47
+ Excepting a ndarray of shape (N, 2, 3), the axis-2 holds two points,
48
+ and the axis-3 holds the coordinates (x, y, z).
49
+ camera: Camera position.
50
+ **kwargs: Forwarded to `~matplotlib.collections.LineCollection`.
56
51
  """
57
-
58
52
  T = camera.MVP
59
53
  T = translate3d(*camera.position).dot(T) # keep origin
60
54
 
@@ -113,8 +107,8 @@ def draw_circles(
113
107
  x: npt.NDArray,
114
108
  y: npt.NDArray,
115
109
  *,
116
- y_min: Optional[float] = None,
117
- y_max: Optional[float] = None,
110
+ y_min: float | None = None,
111
+ y_max: float | None = None,
118
112
  cmap: str | Colormap = "viridis",
119
113
  ) -> PatchCollection:
120
114
  """Draw a sequential of circles."""
@@ -140,7 +134,7 @@ def draw_circles(
140
134
 
141
135
 
142
136
  def get_fig_ax(
143
- fig: Optional[Figure] = None, ax: Optional[Axes] = None
137
+ fig: Figure | None = None, ax: Axes | None = None
144
138
  ) -> tuple[Figure, Axes]:
145
139
  if fig is None and ax is not None:
146
140
  fig = ax.get_figure()
swcgeom/utils/plotter_3d.py CHANGED
@@ -32,17 +32,15 @@ def draw_lines_3d(
32
32
  ):
33
33
  """Draw lines.
34
34
 
35
- Parameters
36
- ----------
37
- ax : ~matplotlib.axes.Axes
38
- lines : A collection of coords of lines
39
- Excepting a ndarray of shape (N, 2, 3), the axis-2 holds two points,
40
- and the axis-3 holds the coordinates (x, y, z).
41
- **kwargs : dict[str, Unknown]
42
- Forwarded to `~mpl_toolkits.mplot3d.art3d.Line3DCollection`.
35
+ Args:
36
+ ax: The plot axes.
37
+ lines: A collection of coords of lines
38
+ Excepting a ndarray of shape (N, 2, 3), the axis-2 holds two points,
39
+ and the axis-3 holds the coordinates (x, y, z).
40
+ **kwargs: Forwarded to `~mpl_toolkits.mplot3d.art3d.Line3DCollection`.
43
41
  """
44
42
 
45
43
  line_collection = Line3DCollection(
46
44
  lines, joinstyle=joinstyle, capstyle=capstyle, **kwargs
47
- ) # type: ignore
45
+ )
48
46
  return ax.add_collection3d(line_collection)
swcgeom/utils/renderer.py CHANGED
@@ -48,21 +48,29 @@ class Camera:
48
48
  _look_at: Vec3f
49
49
  _up: Vec3f
50
50
 
51
- # fmt: off
52
51
  @property
53
- def position(self) -> Vec3f: return self._position
52
+ def position(self) -> Vec3f:
53
+ return self._position
54
+
54
55
  @property
55
- def look_at(self) -> Vec3f: return self._look_at
56
+ def look_at(self) -> Vec3f:
57
+ return self._look_at
58
+
56
59
  @property
57
- def up(self) -> Vec3f: return self._up
60
+ def up(self) -> Vec3f:
61
+ return self._up
58
62
 
59
63
  @property
60
- def MV(self) -> npt.NDArray[np.float32]: raise NotImplementedError()
64
+ def MV(self) -> npt.NDArray[np.float32]:
65
+ raise NotImplementedError()
66
+
61
67
  @property
62
- def P(self) -> npt.NDArray[np.float32]: raise NotImplementedError()
68
+ def P(self) -> npt.NDArray[np.float32]:
69
+ raise NotImplementedError()
70
+
63
71
  @property
64
- def MVP(self) -> npt.NDArray[np.float32]: return self.P.dot(self.MV)
65
- # fmt: on
72
+ def MVP(self) -> npt.NDArray[np.float32]:
73
+ return self.P.dot(self.MV)
66
74
 
67
75
 
68
76
  class SimpleCamera(Camera):
swcgeom/utils/sdf.py CHANGED
@@ -17,10 +17,8 @@
17
17
 
18
18
  Refs: https://iquilezles.org/articles/distfunctions/
19
19
 
20
- Note
21
- ----
22
- This module has been deprecated since v0.14.0, and will be removed in
23
- the future, use `sdflit` instead.
20
+ NOTE: This module has been deprecated since v0.14.0, and will be removed in the future,
21
+ use `sdflit` instead.
24
22
  """
25
23
 
26
24
  import warnings
@@ -60,15 +58,11 @@ class SDF(ABC):
60
58
  def distance(self, p: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]:
61
59
  """Calculate signed distance.
62
60
 
63
- Parmeters
64
- ---------
65
- p: ArrayLike
66
- Hit point p of shape (N, 3).
61
+ Args:
62
+ p: Hit point p of shape (N, 3).
67
63
 
68
- Returns
69
- -------
70
- distance : npt.NDArray[np.float32]
71
- Distance array of shape (3,).
64
+ Returns:
65
+ distance: Distance array of shape (3,).
72
66
  """
73
67
  raise NotImplementedError()
74
68
 
@@ -84,11 +78,9 @@ class SDF(ABC):
84
78
  def is_in_bounding_box(self, p: npt.NDArray[np.float32]) -> npt.NDArray[np.bool_]:
85
79
  """Is p in bounding box.
86
80
 
87
- Returns
88
- -------
89
- is_in : npt.NDArray[np.bool_]
90
- Array of shape (N,), if bounding box is `None`, `True` will
91
- be returned.
81
+ Returns:
82
+ is_in: Array of shape (N,).
83
+ If bounding box is `None`, `True` will be returned.
92
84
  """
93
85
 
94
86
  if self.bounding_box is None:
@@ -285,12 +277,9 @@ class SDFRoundCone(SDF):
285
277
  ) -> None:
286
278
  """SDF of round cone.
287
279
 
288
- Parmeters
289
- ---------
290
- a, b : ArrayLike
291
- Coordinates of point A/B of shape (3,).
292
- ra, rb : float
293
- Radius of point A/B.
280
+ Args:
281
+ a, b: Coordinates of point A/B of shape (3,).
282
+ ra, rb: Radius of point A/B.
294
283
  """
295
284
 
296
285
  self.a = np.array(a, dtype=np.float32)
swcgeom/utils/solid_geometry.py CHANGED
@@ -28,6 +28,15 @@ __all__ = [
28
28
 
29
29
 
30
30
  def find_unit_vector_on_plane(normal_vec3: npt.NDArray) -> npt.NDArray:
31
+ """Find a random unit vector on the plane defined by the normal vector.
32
+
33
+ >>> normal = np.array([0, 0, 1])
34
+ >>> u = find_unit_vector_on_plane(normal)
35
+ >>> np.allclose(np.dot(u, normal), 0) # Should be perpendicular
36
+ True
37
+ >>> np.allclose(np.linalg.norm(u), 1) # Should be unit length
38
+ True
39
+ """
31
40
  r = np.random.rand(3)
32
41
  r /= np.linalg.norm(r)
33
42
  while np.allclose(r, normal_vec3) or np.allclose(r, -normal_vec3):
@@ -45,6 +54,21 @@ def find_sphere_line_intersection(
45
54
  line_point_a: npt.NDArray,
46
55
  line_point_b: npt.NDArray,
47
56
  ) -> list[tuple[float, npt.NDArray[np.float64]]]:
57
+ """Find intersection points between a sphere and a line.
58
+
59
+ >>> center = np.array([0, 0, 0])
60
+ >>> radius = 1.0
61
+ >>> p1 = np.array([-2, 0, 0])
62
+ >>> p2 = np.array([2, 0, 0])
63
+ >>> intersections = find_sphere_line_intersection(center, radius, p1, p2)
64
+ >>> len(intersections)
65
+ 2
66
+ >>> np.allclose(intersections[0][1], [-1, 0, 0])
67
+ True
68
+ >>> np.allclose(intersections[1][1], [1, 0, 0])
69
+ True
70
+ """
71
+
48
72
  A = np.array(line_point_a)
49
73
  B = np.array(line_point_b)
50
74
  C = np.array(sphere_center)
@@ -74,8 +98,20 @@ def find_sphere_line_intersection(
74
98
 
75
99
 
76
100
  def project_point_on_line(
77
- point_a: npt.ArrayLike, direction_vector: npt.ArrayLike, point_p: npt.ArrayLike
101
+ point_a: npt.ArrayLike,
102
+ direction_vector: npt.ArrayLike,
103
+ point_p: npt.ArrayLike,
78
104
  ) -> npt.NDArray:
105
+ """Project a point onto a line defined by a point and direction vector.
106
+
107
+ >>> a = np.array([0, 0, 0])
108
+ >>> d = np.array([1, 0, 0])
109
+ >>> p = np.array([1, 1, 0])
110
+ >>> projection = project_point_on_line(a, d, p)
111
+ >>> np.allclose(projection, [1, 0, 0])
112
+ True
113
+ """
114
+
79
115
  A = np.array(point_a)
80
116
  n = np.array(direction_vector)
81
117
  P = np.array(point_p)
@@ -86,6 +122,14 @@ def project_point_on_line(
86
122
 
87
123
 
88
124
  def project_vector_on_vector(vec: npt.ArrayLike, target: npt.ArrayLike) -> npt.NDArray:
125
+ """Project one vector onto another.
126
+
127
+ >>> v = np.array([1, 1, 0])
128
+ >>> t = np.array([1, 0, 0])
129
+ >>> proj = project_vector_on_vector(v, t)
130
+ >>> np.allclose(proj, [1, 0, 0])
131
+ True
132
+ """
89
133
  v = np.array(vec)
90
134
  n = np.array(target)
91
135
 
@@ -95,8 +139,20 @@ def project_vector_on_vector(vec: npt.ArrayLike, target: npt.ArrayLike) -> npt.N
95
139
 
96
140
 
97
141
  def project_vector_on_plane(
98
- vec: npt.ArrayLike, plane_normal_vec: npt.ArrayLike
142
+ vec: npt.ArrayLike,
143
+ plane_normal_vec: npt.ArrayLike,
99
144
  ) -> npt.NDArray:
145
+ """Project a vector onto a plane defined by its normal vector.
146
+
147
+ >>> v = np.array([1, 1, 1])
148
+ >>> n = np.array([0, 0, 1])
149
+ >>> proj = project_vector_on_plane(v, n)
150
+ >>> np.allclose(proj, [1, 1, 0]) # Z component removed
151
+ True
152
+ >>> np.allclose(np.dot(proj, n), 0) # Should be perpendicular to normal
153
+ True
154
+ """
155
+
100
156
  v = np.array(vec)
101
157
  n = np.array(plane_normal_vec)
102
158