optisketch 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
optisketch/__init__.py ADDED
@@ -0,0 +1,46 @@
1
+ """A light raytracing library."""
2
+
3
+ from importlib.metadata import PackageNotFoundError, version
4
+
5
try:
    # Look up the installed distribution by its published name ("optisketch",
    # the wheel this module ships in); looking up any other name would report
    # "uninstalled" or, worse, a different package's version.
    __version__ = version("optisketch")
except PackageNotFoundError:
    # Running from a source tree / not installed as a distribution.
    __version__ = "uninstalled"
__author__ = "Kevin Yamauchi"
__email__ = "kevin.yamauchi@gmail.com"
11
+
12
+ from optisketch import analysis, backends, fluorescence, optimize, viz
13
+ from optisketch.kernels import _trace_surfaces, trace
14
+ from optisketch.rays import (
15
+ Rays,
16
+ Surface,
17
+ System,
18
+ _Params,
19
+ _Structure,
20
+ load_system,
21
+ save_system,
22
+ )
23
+ from optisketch.sources import Source, collimated_source, emit, point_source
24
+ from optisketch.systems import SystemBuilder
25
+
26
# Public names of the package, kept in code-point order: capitalised names,
# then the underscore-prefixed names that are nevertheless re-exported, then
# submodules and functions.
__all__ = [
    "Rays", "Source", "Surface", "System", "SystemBuilder",
    "_Params", "_Structure", "_trace_surfaces",
    "analysis", "backends", "collimated_source", "emit", "fluorescence",
    "load_system", "optimize", "point_source", "save_system", "trace", "viz",
]
@@ -0,0 +1,22 @@
1
+ """Read-only analysis functions (Phases 4-5).
2
+
3
+ * :func:`spot` — spot-diagram statistics for a ray bundle.
4
+ * :func:`psf` — geometric point-spread-function kernel.
5
+ * :func:`irradiance` — flux-density histogram at a plane.
6
+ * :func:`image_sim` — incoherent image formation by PSF convolution.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ from optisketch.analysis.image_sim import image_sim
12
+ from optisketch.analysis.irradiance import irradiance
13
+ from optisketch.analysis.psf import psf
14
+ from optisketch.analysis.spot import SpotStats, spot
15
+
16
# Analysis entry points re-exported here, plus the SpotStats result type.
__all__ = ["SpotStats", "image_sim", "irradiance", "psf", "spot"]
@@ -0,0 +1,230 @@
1
+ """Incoherent image simulation by PSF convolution (Phase 5).
2
+
3
+ :func:`image_sim` forms an image as the incoherent sum, over object depth
4
+ slices, of each slice convolved with the system PSF appropriate to its
5
+ ``(field, depth, focus)``. Two modes are provided:
6
+
7
+ * ``psf="single"`` — one lateral PSF per depth (shift-invariant); each slice is
8
+ FFT-convolved with that single kernel.
9
+ * ``psf="varying"`` — PSFs sampled on a coarse ``grid`` of field points per
10
+ depth. Each field region is selected by a separable partition-of-unity weight
11
+ map; the object slice is multiplied by the weight, convolved with the local
12
+ PSF, and the contributions are summed. Because the weights sum to one
13
+ everywhere, a system whose PSF is field-independent reproduces the
14
+ ``psf="single"`` result exactly.
15
+
16
+ A scalar ``focus`` yields a 2-D image; an array ``focus`` yields a 3-D focal
17
+ stack.
18
+ """
19
+
20
+ from __future__ import annotations
21
+
22
+ from typing import TYPE_CHECKING, Any
23
+
24
+ import numpy as np
25
+
26
+ from optisketch.analysis.psf import psf as _psf
27
+
28
+ if TYPE_CHECKING:
29
+ from optisketch.rays import System
30
+
31
+
32
def _axis_hats(n: int, g: int) -> tuple[np.ndarray, np.ndarray]:
    """Return separable partition-of-unity hat weights along one axis.

    For *g* nodes spread evenly over ``[0, n-1]``, each pixel is assigned
    linear-interpolation weights to the nodes such that the weights sum to one
    at every pixel (a partition of unity).

    Parameters
    ----------
    n : int
        Number of pixels along the axis.
    g : int
        Number of field-sample nodes along the axis. Must be at least 1.

    Returns
    -------
    weights : numpy.ndarray
        Array of shape ``(n, g)``; column *k* is the weight of node *k*.
    nodes : numpy.ndarray
        Node pixel positions, shape ``(g,)``.

    Raises
    ------
    ValueError
        If *g* is smaller than 1. With no nodes the weights could not sum to
        one, and callers would silently produce an all-zero result.
    """
    if g < 1:
        raise ValueError(f"g must be >= 1 (at least one node per axis), got {g}.")
    coords = np.arange(n, dtype=np.float64)
    nodes = np.linspace(0.0, n - 1, g)
    # Interpolating the k-th unit vector over the nodes gives the hat function
    # of node k; stacking the g hats column-wise yields the (n, g) weight map.
    weights = np.stack(
        [np.interp(coords, nodes, basis) for basis in np.eye(g, dtype=np.float64)],
        axis=1,
    )
    return weights, nodes
61
+
62
+
63
def _pixel_to_field(node_px: float, n: int, half_extent: float) -> float:
    """Map a node pixel position to a lateral field coordinate (mm).

    The axis covers ``[-half_extent, +half_extent]`` with *n* pixels of pitch
    ``2 * half_extent / n``. This is the same pitch :func:`image_sim` uses to
    size the PSF kernel, so pixel *i* is centred at
    ``(i - (n - 1) / 2) * pitch``. The outermost pixel centres therefore lie
    half a pixel inside ``±half_extent``. Fractional *node_px* values
    interpolate linearly between pixel centres.

    Parameters
    ----------
    node_px : float
        Node position in pixel units (0 .. n-1).
    n : int
        Number of pixels along the axis.
    half_extent : float
        Half-width of the field of view along this axis (mm).

    Returns
    -------
    float
        Field coordinate in millimetres of that pixel position, measured from
        the array centre.
    """
    if n <= 1:
        # A single pixel (or an empty axis) sits at the array centre.
        return 0.0
    pitch = 2.0 * half_extent / n
    return (node_px - (n - 1) / 2.0) * pitch
84
+
85
+
86
def image_sim(
    system: System,
    obj: Any,
    extent: float | tuple[float, float],
    *,
    psf: str = "varying",
    field: tuple[float, float] = (0.0, 0.0),
    grid: tuple[int, int] = (3, 3),
    focus: float | Any = 0.0,
    wavelength: float | None = None,
    psf_grid: tuple[int, int] = (31, 31),
    n_rays: int = 256,
    z_object: float = -100.0,
    depth_extent: float = 0.0,
) -> Any:
    """Simulate the incoherent image of *obj* formed by *system*.

    Each object depth slice is FFT-convolved (``mode="same"``) with a PSF
    kernel from :func:`~optisketch.analysis.psf.psf`, and the per-slice
    results are summed.

    Parameters
    ----------
    system : System
        Imaging system.
    obj : array
        Object intensity. Shape ``(ny, nx)`` for a planar object or
        ``(nz, ny, nx)`` for a volume (incoherently summed over depth).
    extent : float or tuple of float
        Lateral half-width of the field of view (mm). Scalar applies to both
        axes; ``(ey, ex)`` sets them separately.
    psf : str, optional
        ``"varying"`` (field-dependent, default) or ``"single"``
        (shift-invariant).
    field : tuple of float, optional
        Base lateral field offset ``(x, y)`` (mm) added to every PSF
        evaluation.
    grid : tuple of int, optional
        ``(gy, gx)`` coarse field-sampling grid for ``psf="varying"``; both
        entries must be at least 1. Ignored for ``psf="single"``.
        Default ``(3, 3)``.
    focus : float or array, optional
        Detector focus offset(s) from ``system.image_z`` (mm). A scalar yields
        a 2-D image; a 1-D array yields a 3-D ``(nf, ny, nx)`` stack.
    wavelength : float, optional
        Imaging wavelength (µm). Defaults to the system's first wavelength.
    psf_grid : tuple of int, optional
        ``(ny, nx)`` grid of each evaluated PSF kernel. Default ``(31, 31)``.
    n_rays : int, optional
        Pupil samples per PSF evaluation. Default 256.
    z_object : float, optional
        Nominal object-plane z-position (mm). Default ``-100.0``.
    depth_extent : float, optional
        Half-range of object axial depth mapped across the ``nz`` slices (mm).
        Default 0 (all slices at the nominal plane).

    Returns
    -------
    array
        Simulated image of shape ``(ny, nx)`` (scalar *focus*) or
        ``(nf, ny, nx)`` (array *focus*).

    Raises
    ------
    ValueError
        If *psf* is not ``"single"`` or ``"varying"``, if *obj* is not
        2-D/3-D, or if ``psf="varying"`` and an entry of *grid* is below 1.
    """
    if psf not in ("single", "varying"):
        raise ValueError(f"Unknown psf mode {psf!r}. Choose 'single' or 'varying'.")

    be = system.backend

    shape = tuple(obj.shape)
    if len(shape) == 2:
        ny, nx = shape
        slices = [obj]
        nz = 1
    elif len(shape) == 3:
        nz, ny, nx = shape
        slices = [obj[zi] for zi in range(nz)]
    else:
        raise ValueError("obj must be 2-D (ny, nx) or 3-D (nz, ny, nx).")

    if isinstance(extent, (tuple, list)):
        ext_y, ext_x = float(extent[0]), float(extent[1])
    else:
        ext_y = ext_x = float(extent)

    # The PSF kernel must share the object's pixel pitch so the discrete
    # convolution is physically meaningful. Pixel pitch = full width / pixels.
    # NOTE: psf() is given one scalar half-width for both axes, so the kernel's
    # y pitch (2 * psf_extent / psf_grid[0]) matches the object's only when
    # psf_grid is square and ext_y / ny == ext_x / nx.
    pitch = (2.0 * ext_x) / nx
    psf_extent = pitch * psf_grid[1] / 2.0

    # focus may be a scalar (2-D result) or a 1-D array of focal planes.
    focus_np = be.to_numpy(focus)
    scalar_focus = np.ndim(focus_np) == 0
    focus_arr = np.atleast_1d(np.asarray(focus_np, dtype=np.float64))

    # Slices are spread evenly over [-depth_extent, +depth_extent].
    if nz == 1:
        depths = [0.0]
    else:
        depths = list(np.linspace(-depth_extent, depth_extent, nz))

    # Varying mode: per-axis hat weights plus the field point each node
    # samples. The field points do not depend on focus or depth, so they are
    # computed once. The full-size 2-D weight maps are still built on demand,
    # to avoid holding gy * gx image-sized arrays at once.
    if psf == "varying":
        gy, gx = grid
        if gy < 1 or gx < 1:
            # Zero nodes would convolve nothing and silently return zeros.
            raise ValueError(
                f"grid entries must be >= 1 for psf='varying', got {tuple(grid)!r}."
            )
        wy, nodes_y = _axis_hats(ny, gy)
        wx, nodes_x = _axis_hats(nx, gx)
        field_y = [field[1] + _pixel_to_field(nodes_y[gi], ny, ext_y) for gi in range(gy)]
        field_x = [field[0] + _pixel_to_field(nodes_x[gj], nx, ext_x) for gj in range(gx)]

    def _kernel(field_pt: tuple[float, float], depth: float, focus_val: float) -> Any:
        """Evaluate the PSF kernel for one (field, depth, focus) sample."""
        return _psf(
            system,
            field=field_pt,
            depth=depth,
            focus=focus_val,
            wavelength=wavelength,
            n_rays=n_rays,
            grid=psf_grid,
            extent=psf_extent,
            z_object=z_object,
        )

    images = []
    for f in focus_arr:
        focus_val = float(f)
        acc = be.zeros((ny, nx))
        for depth, slc in zip(depths, slices):
            if psf == "single":
                acc = acc + be.fftconvolve(slc, _kernel(field, depth, focus_val), mode="same")
                continue
            for gi in range(gy):
                for gj in range(gx):
                    # The weight maps sum to one over all nodes, so a
                    # field-independent PSF reproduces the "single" result.
                    wmap = be.asarray(np.outer(wy[:, gi], wx[:, gj]))
                    k = _kernel((field_x[gj], field_y[gi]), depth, focus_val)
                    acc = acc + be.fftconvolve(slc * wmap, k, mode="same")
        images.append(acc)

    if scalar_focus:
        return images[0]
    return be.stack(images, axis=0)
@@ -0,0 +1,79 @@
1
+ """Irradiance (flux density) estimation at a detector plane (Phase 4).
2
+
3
+ :func:`irradiance` traces a :class:`~optisketch.sources.Source` through a
4
+ system, propagates the exiting rays to a target plane, and accumulates a
5
+ weighted 2-D histogram of the valid hits. It is a thin wrapper over
6
+ :func:`~optisketch.kernels._trace_surfaces` and the backend ``histogram2d``.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ from typing import TYPE_CHECKING, Any
12
+
13
+ from optisketch.kernels import _propagate_to_plane, _trace_surfaces
14
+ from optisketch.sources import emit
15
+
16
+ if TYPE_CHECKING:
17
+ from optisketch.rays import System
18
+ from optisketch.sources import Source
19
+
20
+
21
def irradiance(
    system: System,
    source: Source,
    z: float,
    grid: tuple[int, int] = (64, 64),
    *,
    extent: float | tuple[float, float] | None = None,
) -> Any:
    """Bin the rays of *source* that reach the plane ``z`` into a 2-D grid.

    The source is emitted into *system*, traced through all of its surfaces,
    and propagated in free space to ``z``. Each valid hit contributes its
    per-ray intensity to the bin it lands in. A bin holds the summed weight,
    not divided by bin area. Invalid rays contribute nothing. If *extent* is
    omitted, the window is sized to contain every valid hit. The grid total
    then equals the summed weight of the valid rays.

    Parameters
    ----------
    system : System
        Optical system to trace through.
    source : Source
        Ray source to emit.
    z : float
        Detector-plane z-position (mm).
    grid : tuple of int, optional
        ``(ny, nx)`` histogram shape. Default ``(64, 64)``.
    extent : float or tuple of float, optional
        Half-width (mm) of the origin-centred window. A scalar is used for
        both axes; ``(ey, ex)`` gives each axis its own. None derives it from
        the valid hits.

    Returns
    -------
    array
        2-D histogram of shape *grid*, indexed ``[iy, ix]``.
    """
    be = system.backend
    traced, _ = _trace_surfaces(emit(source, system), system.structure, system.params, be)
    hits = _propagate_to_plane(traced, float(z), be)

    ok = hits.valid
    fill = be.zeros_like(hits.x)
    weights = be.where(ok, hits.i, be.zeros_like(hits.i))

    if extent is not None:
        if isinstance(extent, (tuple, list)):
            half_y, half_x = float(extent[0]), float(extent[1])
        else:
            half_x = half_y = float(extent)
    else:
        # Farthest valid hit on each axis, plus a small pad so the window is
        # never zero-width (e.g. when every valid hit sits at the origin).
        far_x = float(be.to_numpy(be.max(be.where(ok, be.abs(hits.x), fill))))
        far_y = float(be.to_numpy(be.max(be.where(ok, be.abs(hits.y), fill))))
        half_x = far_x + 1e-6
        half_y = far_y + 1e-6

    # Invalid rays are moved to the origin (with zero weight) so NaN
    # coordinates never reach the histogram.
    x_in = be.where(ok, hits.x, fill)
    y_in = be.where(ok, hits.y, fill)
    window = ((-half_y, half_y), (-half_x, half_x))
    return be.histogram2d(y_in, x_in, bins=grid, range=window, weights=weights)
@@ -0,0 +1,144 @@
1
+ """Point-spread-function estimation by geometric ray histogramming (Phase 4).
2
+
3
+ :func:`psf` traces a point emitter through a :class:`~optisketch.rays.System`,
4
+ propagates the exiting rays to a (possibly defocused) detector plane, and
5
+ histograms the valid image-plane hits into a normalised 2-D kernel centred on
6
+ the spot centroid. Because the kernel is centred on the centroid, the result is
7
+ a shift-invariant impulse response suitable for convolution in
8
+ :func:`~optisketch.analysis.image_sim.image_sim`.
9
+
10
+ Sweeping ``depth`` (object-side axial position) or ``focus`` (detector-side
11
+ axial position) produces a through-focus stack.
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ from typing import TYPE_CHECKING, Any
17
+
18
+ from optisketch.kernels import _propagate_to_plane, _trace_surfaces
19
+ from optisketch.sources import emit, point_source
20
+
21
+ if TYPE_CHECKING:
22
+ from optisketch.rays import System
23
+
24
+
25
def _masked_centroid(xs: Any, ys: Any, valid: Any, be: Any) -> tuple[float, float, int]:
    """Average the coordinates of the samples selected by *valid*.

    Masked-out samples are replaced by zero before summing, so NaNs in them
    cannot leak into the result.

    Parameters
    ----------
    xs, ys : array
        Sample x- and y-coordinates.
    valid : bool array
        Which samples take part.
    be : Backend
        Array-computation backend.

    Returns
    -------
    cx, cy : float
        Mean x and y of the selected samples, or ``0.0, 0.0`` when none are
        selected.
    n : int
        How many samples were selected.
    """
    fill = be.zeros_like(xs)

    def _masked_total(values: Any) -> float:
        # Sum of *values* over the selected samples, as a Python float.
        return float(be.to_numpy(be.sum(be.where(valid, values, fill))))

    count = _masked_total(be.ones_like(xs))
    if count < 1.0:
        return 0.0, 0.0, 0
    return _masked_total(xs) / count, _masked_total(ys) / count, int(count)
56
+
57
+
58
def psf(
    system: System,
    field: tuple[float, float] = (0.0, 0.0),
    *,
    depth: float = 0.0,
    focus: float = 0.0,
    wavelength: float | None = None,
    n_rays: int = 256,
    grid: tuple[int, int] = (64, 64),
    extent: float | tuple[float, float] | None = None,
    z_object: float = -100.0,
) -> Any:
    """Estimate the geometric PSF kernel of *system* for a point emitter.

    A point source at lateral position *field* and axial position
    ``z_object + depth`` is traced through the system. The exiting rays are
    propagated to the detector plane ``image_z + focus`` and histogrammed into
    a grid centred on the spot centroid. The kernel is normalised to unit sum.

    Parameters
    ----------
    system : System
        Optical system to evaluate.
    field : tuple of float, optional
        ``(x, y)`` lateral position of the point emitter (mm). Default on-axis.
    depth : float, optional
        Axial offset of the emitter from the nominal object plane (mm).
        Default 0.
    focus : float, optional
        Axial offset of the detector plane from ``system.image_z`` (mm).
        Default 0.
    wavelength : float, optional
        Emission wavelength (µm). Defaults to the system's first wavelength.
    n_rays : int, optional
        Number of pupil samples (rays). Default 256.
    grid : tuple of int, optional
        ``(ny, nx)`` kernel grid shape. Default ``(64, 64)``.
    extent : float or tuple of float, optional
        Half-width of the kernel window about the centroid (mm). A scalar
        applies to both axes. ``(ey, ex)`` sets them separately, which lets a
        non-square *grid* keep equal pixel pitch on both axes. When None,
        one value derived from the geometric spread of the valid rays is used
        for both axes.
    z_object : float, optional
        Nominal object-plane z-position (mm). Default ``-100.0``.

    Returns
    -------
    array
        Normalised 2-D PSF kernel of shape *grid*, indexed ``[iy, ix]``. It is
        all zeros if no ray reaches the detector validly.
    """
    be = system.backend
    wl = float(system.wavelengths[0]) if wavelength is None else float(wavelength)

    src = point_source(
        (float(field[0]), float(field[1])),
        z_object=float(z_object) + float(depth),
        wavelength=wl,
        pupil_pattern="disk",
        n_samples=n_rays,
    )
    rays = emit(src, system)
    final, _ = _trace_surfaces(rays, system.structure, system.params, be)

    z_det = float(system.image_z) + float(focus)
    final = _propagate_to_plane(final, z_det, be)

    xs, ys, valid = final.x, final.y, final.valid
    cx, cy, _n = _masked_centroid(xs, ys, valid, be)

    if extent is None:
        # Distance of the farthest valid hit from the centroid, padded by 10 %
        # plus a tiny floor so the window is never zero-width.
        zeros = be.zeros_like(xs)
        dx = be.where(valid, xs - cx, zeros)
        dy = be.where(valid, ys - cy, zeros)
        r2 = dx * dx + dy * dy
        spread = float(be.to_numpy(be.sqrt(be.max(be.where(valid, r2, zeros)))))
        ey = ex = spread * 1.1 + 1e-6
    elif isinstance(extent, (tuple, list)):
        ey, ex = float(extent[0]), float(extent[1])
    else:
        ey = ex = float(extent)

    # weights: valid-ray intensity, zero elsewhere; invalid coords are parked
    # at the centroid so they stay in-range (and NaN-free) but add nothing.
    wv = be.where(valid, final.i, be.zeros_like(final.i))
    xs_safe = be.where(valid, xs, be.full_like(xs, cx))
    ys_safe = be.where(valid, ys, be.full_like(ys, cy))

    rng = ((cy - ey, cy + ey), (cx - ex, cx + ex))
    h = be.histogram2d(ys_safe, xs_safe, bins=grid, range=rng, weights=wv)

    # Normalise to unit sum; an empty histogram is divided by 1 and stays zero.
    total = be.sum(h)
    total_safe = be.where(total > 0.0, total, be.ones_like(total))
    return h / total_safe
@@ -0,0 +1,129 @@
1
+ """Spot-diagram statistics for a traced ray bundle (Phase 4).
2
+
3
+ :func:`spot` reduces a :class:`~optisketch.rays.Rays` bundle to centroid,
4
+ RMS, and geometric-radius statistics over the valid rays. All reductions are
5
+ NaN-safe: invalid rays are masked out via ``backend.where`` before any sum or
6
+ maximum, so a NaN in a missed/TIR ray never corrupts the result.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ from typing import TYPE_CHECKING, NamedTuple
12
+
13
+ if TYPE_CHECKING:
14
+ from optisketch.backends._protocol import Backend
15
+ from optisketch.rays import Rays
16
+
17
+
18
class SpotStats(NamedTuple):
    """Result of :func:`spot`: centroid, spread, and count of the valid hits.

    All lengths are in millimetres.

    Attributes
    ----------
    cx, cy : float
        Centroid of the valid ray hits (mm).
    rms : float
        Root-mean-square distance of the valid hits from the reference
        point (mm).
    geo_radius : float
        Largest distance of any valid hit from the reference point (mm).
    n_valid : int
        How many valid rays went into the statistics.
    """

    cx: float  # centroid x (mm)
    cy: float  # centroid y (mm)
    rms: float  # RMS radius about the reference point (mm)
    geo_radius: float  # maximum radius about the reference point (mm)
    n_valid: int  # number of contributing rays
40
+
41
+
42
def _default_backend() -> Backend:
    """Create the fallback backend used when :func:`spot` is given none.

    The NumPy backend module is imported when this function is called rather
    than at module import.

    Returns
    -------
    Backend
        A new :class:`~optisketch.backends.NumpyBackend` instance.
    """
    from optisketch.backends import _numpy

    return _numpy.NumpyBackend()
53
+
54
+
55
def spot(
    rays: Rays,
    *,
    reference: str = "centroid",
    backend: Backend | None = None,
) -> SpotStats:
    """Reduce a ray bundle to spot-diagram statistics.

    Only rays flagged in ``rays.valid`` take part. Every masked-out ray is
    replaced by zero before any sum or maximum. The spread is measured either
    about the centroid of the valid hits (``"centroid"``) or about ray 0, the
    chief ray (``"chief"``).

    Parameters
    ----------
    rays : Rays
        Ray bundle, typically the output of a trace at the image plane.
    reference : str, optional
        ``"centroid"`` (default) or ``"chief"``: the point the RMS and
        geometric radius are measured from.
    backend : Backend, optional
        Array-computation backend; a NumPy backend is created when omitted.

    Returns
    -------
    SpotStats
        Centroid, RMS radius, geometric radius, and valid-ray count.

    Raises
    ------
    ValueError
        If *reference* is neither ``"centroid"`` nor ``"chief"``.
    """
    if reference not in ("centroid", "chief"):
        raise ValueError(
            f"Unknown reference {reference!r}. Choose 'centroid' or 'chief'."
        )

    be = _default_backend() if backend is None else backend

    mask = rays.valid
    fill = be.zeros_like(rays.x)
    count = be.sum(be.where(mask, be.ones_like(rays.x), fill))
    # Divide by at least one so an all-invalid bundle gives zeros, not NaN.
    denom = be.maximum(count, be.ones_like(count))

    centroid_x = be.sum(be.where(mask, rays.x, fill)) / denom
    centroid_y = be.sum(be.where(mask, rays.y, fill)) / denom

    # NOTE(review): in "chief" mode ray 0 is used even if it is itself invalid;
    # a NaN there would propagate into rms and geo_radius.
    if reference == "chief":
        ref_x, ref_y = rays.x[0], rays.y[0]
    else:
        ref_x, ref_y = centroid_x, centroid_y

    off_x = be.where(mask, rays.x - ref_x, fill)
    off_y = be.where(mask, rays.y - ref_y, fill)
    sq_radius = off_x * off_x + off_y * off_y

    rms = be.sqrt(be.sum(sq_radius) / denom)
    geo = be.sqrt(be.max(be.where(mask, sq_radius, fill)))

    as_np = be.to_numpy
    return SpotStats(
        cx=float(as_np(centroid_x)),
        cy=float(as_np(centroid_y)),
        rms=float(as_np(rms)),
        geo_radius=float(as_np(geo)),
        n_valid=int(as_np(count)),
    )
@@ -0,0 +1,38 @@
1
+ """Backend constructors.
2
+
3
+ Usage::
4
+
5
+ from optisketch.backends import numpy, jax
6
+
7
+ be = numpy()
8
+ be_jax = jax() # raises ImportError if jax is not installed
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ from optisketch.backends._numpy import NumpyBackend
14
+ from optisketch.backends._protocol import Backend, NotDifferentiable
15
+
16
+
17
def numpy() -> NumpyBackend:
    """Construct the NumPy backend.

    Returns
    -------
    NumpyBackend
        A new backend instance; every call creates a fresh one.
    """
    backend = NumpyBackend()
    return backend
20
+
21
+
22
def jax():
    """Construct the JAX backend.

    The JAX implementation module is imported only when this function runs.
    This module therefore imports nothing from it at load time.

    Returns
    -------
    JaxBackend
        A new backend instance.

    Raises
    ------
    ImportError
        If jax is not installed.
    """
    from optisketch.backends import _jax

    return _jax.JaxBackend()
30
+
31
+
32
# Names re-exported from the backend modules, followed by the two
# constructor functions defined in this module.
__all__ = ["Backend", "NotDifferentiable", "NumpyBackend", "jax", "numpy"]