senoquant 1.0.0b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148) hide show
  1. senoquant/__init__.py +6 -0
  2. senoquant/_reader.py +7 -0
  3. senoquant/_widget.py +33 -0
  4. senoquant/napari.yaml +83 -0
  5. senoquant/reader/__init__.py +5 -0
  6. senoquant/reader/core.py +369 -0
  7. senoquant/tabs/__init__.py +15 -0
  8. senoquant/tabs/batch/__init__.py +10 -0
  9. senoquant/tabs/batch/backend.py +641 -0
  10. senoquant/tabs/batch/config.py +270 -0
  11. senoquant/tabs/batch/frontend.py +1283 -0
  12. senoquant/tabs/batch/io.py +326 -0
  13. senoquant/tabs/batch/layers.py +86 -0
  14. senoquant/tabs/quantification/__init__.py +1 -0
  15. senoquant/tabs/quantification/backend.py +228 -0
  16. senoquant/tabs/quantification/features/__init__.py +80 -0
  17. senoquant/tabs/quantification/features/base.py +142 -0
  18. senoquant/tabs/quantification/features/marker/__init__.py +5 -0
  19. senoquant/tabs/quantification/features/marker/config.py +69 -0
  20. senoquant/tabs/quantification/features/marker/dialog.py +437 -0
  21. senoquant/tabs/quantification/features/marker/export.py +879 -0
  22. senoquant/tabs/quantification/features/marker/feature.py +119 -0
  23. senoquant/tabs/quantification/features/marker/morphology.py +285 -0
  24. senoquant/tabs/quantification/features/marker/rows.py +654 -0
  25. senoquant/tabs/quantification/features/marker/thresholding.py +46 -0
  26. senoquant/tabs/quantification/features/roi.py +346 -0
  27. senoquant/tabs/quantification/features/spots/__init__.py +5 -0
  28. senoquant/tabs/quantification/features/spots/config.py +62 -0
  29. senoquant/tabs/quantification/features/spots/dialog.py +477 -0
  30. senoquant/tabs/quantification/features/spots/export.py +1292 -0
  31. senoquant/tabs/quantification/features/spots/feature.py +112 -0
  32. senoquant/tabs/quantification/features/spots/morphology.py +279 -0
  33. senoquant/tabs/quantification/features/spots/rows.py +241 -0
  34. senoquant/tabs/quantification/frontend.py +815 -0
  35. senoquant/tabs/segmentation/__init__.py +1 -0
  36. senoquant/tabs/segmentation/backend.py +131 -0
  37. senoquant/tabs/segmentation/frontend.py +1009 -0
  38. senoquant/tabs/segmentation/models/__init__.py +5 -0
  39. senoquant/tabs/segmentation/models/base.py +146 -0
  40. senoquant/tabs/segmentation/models/cpsam/details.json +65 -0
  41. senoquant/tabs/segmentation/models/cpsam/model.py +150 -0
  42. senoquant/tabs/segmentation/models/default_2d/details.json +69 -0
  43. senoquant/tabs/segmentation/models/default_2d/model.py +664 -0
  44. senoquant/tabs/segmentation/models/default_3d/details.json +69 -0
  45. senoquant/tabs/segmentation/models/default_3d/model.py +682 -0
  46. senoquant/tabs/segmentation/models/hf.py +71 -0
  47. senoquant/tabs/segmentation/models/nuclear_dilation/__init__.py +1 -0
  48. senoquant/tabs/segmentation/models/nuclear_dilation/details.json +26 -0
  49. senoquant/tabs/segmentation/models/nuclear_dilation/model.py +96 -0
  50. senoquant/tabs/segmentation/models/perinuclear_rings/__init__.py +1 -0
  51. senoquant/tabs/segmentation/models/perinuclear_rings/details.json +34 -0
  52. senoquant/tabs/segmentation/models/perinuclear_rings/model.py +132 -0
  53. senoquant/tabs/segmentation/stardist_onnx_utils/__init__.py +2 -0
  54. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/__init__.py +3 -0
  55. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/__init__.py +6 -0
  56. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/generate.py +470 -0
  57. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/prepare.py +273 -0
  58. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/rawdata.py +112 -0
  59. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/transform.py +384 -0
  60. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/__init__.py +0 -0
  61. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/blocks.py +184 -0
  62. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/losses.py +79 -0
  63. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/nets.py +165 -0
  64. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/predict.py +467 -0
  65. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/probability.py +67 -0
  66. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/train.py +148 -0
  67. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/io/__init__.py +163 -0
  68. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/__init__.py +52 -0
  69. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/base_model.py +329 -0
  70. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_isotropic.py +160 -0
  71. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_projection.py +178 -0
  72. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_standard.py +446 -0
  73. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_upsampling.py +54 -0
  74. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/config.py +254 -0
  75. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/pretrained.py +119 -0
  76. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/scripts/__init__.py +0 -0
  77. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/scripts/care_predict.py +180 -0
  78. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/__init__.py +5 -0
  79. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/plot_utils.py +159 -0
  80. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/six.py +18 -0
  81. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/tf.py +644 -0
  82. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/utils.py +272 -0
  83. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/version.py +1 -0
  84. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/docs/source/conf.py +368 -0
  85. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/setup.py +68 -0
  86. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_datagen.py +169 -0
  87. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_models.py +462 -0
  88. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_utils.py +166 -0
  89. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tools/create_zip_contents.py +34 -0
  90. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/__init__.py +30 -0
  91. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/big.py +624 -0
  92. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/bioimageio_utils.py +494 -0
  93. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/data/__init__.py +39 -0
  94. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/__init__.py +10 -0
  95. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/geom2d.py +215 -0
  96. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/geom3d.py +349 -0
  97. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/matching.py +483 -0
  98. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/__init__.py +28 -0
  99. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/base.py +1217 -0
  100. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/model2d.py +594 -0
  101. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/model3d.py +696 -0
  102. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/nms.py +384 -0
  103. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/__init__.py +2 -0
  104. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/plot.py +74 -0
  105. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/render.py +298 -0
  106. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/rays3d.py +373 -0
  107. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/sample_patches.py +65 -0
  108. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/__init__.py +0 -0
  109. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/predict2d.py +90 -0
  110. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/predict3d.py +93 -0
  111. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/utils.py +408 -0
  112. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/version.py +1 -0
  113. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/__init__.py +45 -0
  114. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/__init__.py +17 -0
  115. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/cli.py +55 -0
  116. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/core.py +285 -0
  117. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/__init__.py +15 -0
  118. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/cli.py +36 -0
  119. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/divisibility.py +193 -0
  120. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/probe.py +100 -0
  121. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/receptive_field.py +182 -0
  122. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/rf_cli.py +48 -0
  123. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/valid_sizes.py +278 -0
  124. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/post/__init__.py +8 -0
  125. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/post/core.py +157 -0
  126. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/pre/__init__.py +17 -0
  127. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/pre/core.py +226 -0
  128. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/predict/__init__.py +5 -0
  129. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/predict/core.py +401 -0
  130. senoquant/tabs/settings/__init__.py +1 -0
  131. senoquant/tabs/settings/backend.py +29 -0
  132. senoquant/tabs/settings/frontend.py +19 -0
  133. senoquant/tabs/spots/__init__.py +1 -0
  134. senoquant/tabs/spots/backend.py +139 -0
  135. senoquant/tabs/spots/frontend.py +800 -0
  136. senoquant/tabs/spots/models/__init__.py +5 -0
  137. senoquant/tabs/spots/models/base.py +94 -0
  138. senoquant/tabs/spots/models/rmp/details.json +61 -0
  139. senoquant/tabs/spots/models/rmp/model.py +499 -0
  140. senoquant/tabs/spots/models/udwt/details.json +103 -0
  141. senoquant/tabs/spots/models/udwt/model.py +482 -0
  142. senoquant/utils.py +25 -0
  143. senoquant-1.0.0b1.dist-info/METADATA +193 -0
  144. senoquant-1.0.0b1.dist-info/RECORD +148 -0
  145. senoquant-1.0.0b1.dist-info/WHEEL +5 -0
  146. senoquant-1.0.0b1.dist-info/entry_points.txt +2 -0
  147. senoquant-1.0.0b1.dist-info/licenses/LICENSE +28 -0
  148. senoquant-1.0.0b1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,682 @@
1
+ """StarDist ONNX segmentation model implementation."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from pathlib import Path
6
+ import importlib.util
7
+ import sys
8
+ import types
9
+ from typing import TYPE_CHECKING
10
+
11
+ import numpy as np
12
+ import onnxruntime as ort
13
+ from scipy import ndimage as ndi
14
+
15
+ from senoquant.utils import layer_data_asarray
16
+ from ..hf import DEFAULT_REPO_ID, ensure_hf_model
17
+ from ..base import SenoQuantSegmentationModel
18
+ from senoquant.tabs.segmentation.stardist_onnx_utils.onnx_framework import (
19
+ normalize,
20
+ predict_tiled,
21
+ )
22
+ from senoquant.tabs.segmentation.stardist_onnx_utils.onnx_framework.inspect import (
23
+ make_probe_image,
24
+ )
25
+ if TYPE_CHECKING:
26
+ from senoquant.tabs.segmentation.stardist_onnx_utils.onnx_framework.inspect.valid_sizes import (
27
+ ValidSizePattern,
28
+ )
29
+
30
+
31
+ class StarDistOnnxModel(SenoQuantSegmentationModel):
32
+ """StarDist ONNX 3D segmentation model.
33
+
34
+ This wrapper loads an exported StarDist 3D ONNX model, runs
35
+ preprocessing and tiled inference, and postprocesses the outputs into
36
+ instance labels using StarDist geometry and NMS utilities.
37
+
38
+ Notes
39
+ -----
40
+ - Inputs must be single-channel images in ZYX (3D) order.
41
+ - ONNX model outputs are assumed to be probability and distance maps.
42
+ """
43
+
44
def __init__(self, models_root=None) -> None:
    """Create the ``default_3d`` model wrapper with empty caches.

    Parameters
    ----------
    models_root : pathlib.Path or None
        Optional root directory for model storage; forwarded unchanged
        to the base-class constructor.
    """
    super().__init__("default_3d", models_root=models_root)

    # Flags describing whether compiled StarDist extensions were found;
    # they are refreshed by ``_ensure_stardist_lib_stubs``.
    self._has_stardist_2d_lib = False
    self._has_stardist_3d_lib = False

    # Ray class loaded on first use by ``_get_rays_class``.
    self._rays_class = None

    # Caches keyed by the resolved ONNX model path.
    self._sessions: dict[Path, ort.InferenceSession] = {}
    self._div_by_cache: dict[Path, tuple[int, ...]] = {}
    self._overlap_cache: dict[Path, tuple[int, ...]] = {}
    self._valid_size_cache: dict[Path, list["ValidSizePattern"] | None] = {}
60
+
61
def run(self, **kwargs) -> dict:
    """Run StarDist ONNX for nuclear segmentation on a 3D image.

    Parameters
    ----------
    **kwargs
        task : str
            Must be "nuclear" for this model.
        layer : napari.layers.Image
            Single-channel image layer; its data must be 3D (ZYX).
        settings : dict or None
            Model settings keyed by ``details.json``. The keys read here
            are ``normalize``, ``pmin``, ``pmax``, ``prob_thresh`` and
            ``nms_thresh`` (``object_diameter_px`` is read by
            ``_scale_input``). A missing or ``None`` value means
            "use defaults".

    Returns
    -------
    dict
        Dictionary with:
        - ``masks``: instance label image
        - ``prob``: probability map
        - ``dist``: distance/ray map
        - ``info``: NMS metadata (points, prob, dist)

    Raises
    ------
    ValueError
        If ``task`` is not "nuclear", the layer is missing, or the image
        is not 3D.
    RuntimeError
        If the compiled 3D StarDist extension is not available.
    """
    task = kwargs.get("task")
    if task != "nuclear":
        raise ValueError("StarDist ONNX only supports nuclear segmentation.")

    layer = kwargs.get("layer")
    # ``settings=None`` passed explicitly must behave like an omitted
    # argument instead of failing on ``None.get``.
    settings = kwargs.get("settings") or {}
    image = self._extract_layer_data(layer, required=True)
    original_shape = image.shape

    if image.ndim != 3:
        raise ValueError("StarDist ONNX 3D expects a 3D (ZYX) image.")

    # Preprocessing: float32 -> optional resize -> min/max scaling ->
    # optional percentile normalization.
    image = image.astype(np.float32, copy=False)
    image, scale = self._scale_input(image, settings)
    image = self._scale_intensity(image)
    if settings.get("normalize", True):
        pmin = float(settings.get("pmin", 1.0))
        pmax = float(settings.get("pmax", 99.8))
        image = normalize(image, pmin=pmin, pmax=pmax)

    model_path = self._resolve_model_path(image.ndim)
    session = self._get_session(image.ndim)
    input_name, output_names = self._resolve_io_names(session)

    input_layout = "NDHWC"
    prob_layout = "NDHWC"
    dist_layout = "NZYXR"

    grid = self._infer_grid(
        image,
        session,
        input_name,
        output_names,
        input_layout,
        prob_layout,
        model_path=model_path,
    )

    tile_shape, overlap = self._infer_tiling(
        image, model_path, session, input_name, output_names, input_layout
    )
    # ``_infer_tiling`` fills the divisibility cache; fall back to the
    # grid if no entry exists for this model.
    div_by = self._div_by_cache.get(model_path, grid)

    def _predict(active_session):
        """Run tiled inference with the shared arguments on one session."""
        return predict_tiled(
            image,
            active_session,
            input_name=input_name,
            output_names=output_names,
            grid=grid,
            input_layout=input_layout,
            prob_layout=prob_layout,
            dist_layout=dist_layout,
            tile_shape=tile_shape,
            overlap=overlap,
            div_by=div_by,
        )

    try:
        prob, dist = _predict(session)
    except Exception:
        # Only a session that includes the CoreML provider gets a second
        # attempt on CPU; any other failure is propagated unchanged.
        if "CoreMLExecutionProvider" not in session.get_providers():
            raise
        session = self._get_session(
            image.ndim, providers_override=["CPUExecutionProvider"]
        )
        prob, dist = _predict(session)

    prob_thresh = float(settings.get("prob_thresh", 0.5))
    nms_thresh = float(settings.get("nms_thresh", 0.4))

    # Stubs must be registered before the post-processing import so the
    # StarDist Python utilities can be imported without compiled modules.
    self._ensure_stardist_lib_stubs()
    from senoquant.tabs.segmentation.stardist_onnx_utils.onnx_framework import (
        instances_from_prediction_3d,
    )

    if not self._has_stardist_3d_lib:
        raise RuntimeError(
            "3D StarDist labeling requires compiled ops; build "
            "extensions in stardist_onnx_utils/_stardist/lib."
        )
    # The number of rays is taken from the last axis of the distance map.
    rays = self._get_rays_class()(n=dist.shape[-1])
    labels, info = instances_from_prediction_3d(
        prob,
        dist,
        grid=grid,
        prob_thresh=prob_thresh,
        nms_thresh=nms_thresh,
        rays=rays,
        scale=scale,
        img_shape=original_shape,
    )

    return {"masks": labels, "prob": prob, "dist": dist, "info": info}
186
+
187
+ def _scale_input(
188
+ self, image: np.ndarray, settings: dict
189
+ ) -> tuple[np.ndarray, dict[str, float] | None]:
190
+ """Scale the input image to match training object sizes.
191
+
192
+ Parameters
193
+ ----------
194
+ image : numpy.ndarray
195
+ Input 3D image in ZYX order.
196
+ settings : dict
197
+ Model settings containing the ``object_diameter_px`` entry.
198
+
199
+ Returns
200
+ -------
201
+ numpy.ndarray
202
+ Scaled image. If no scaling is requested, returns the input image.
203
+ dict[str, float] or None
204
+ Scale factors keyed by axis (``"Z"``, ``"Y"``, ``"X"``) for rescaling
205
+ predictions back to the original image space.
206
+ """
207
+ diameter_px = float(settings.get("object_diameter_px", 30.0))
208
+ if diameter_px <= 0:
209
+ raise ValueError("Object diameter (px) must be positive.")
210
+ scale_factor = 30.0 / diameter_px
211
+ if np.isclose(scale_factor, 1.0):
212
+ return image, None
213
+
214
+ scale = (scale_factor, scale_factor, scale_factor)
215
+ scaled = ndi.zoom(image, scale, order=1)
216
+ if min(scaled.shape) < 1:
217
+ raise ValueError(
218
+ "Scaling factor produced an empty image; adjust object diameter."
219
+ )
220
+ return scaled.astype(np.float32, copy=False), {
221
+ "Z": scale_factor,
222
+ "Y": scale_factor,
223
+ "X": scale_factor,
224
+ }
225
+
226
+ @staticmethod
227
+ def _scale_intensity(image: np.ndarray) -> np.ndarray:
228
+ """Scale image intensities into [0, 1] using min/max."""
229
+ imin = float(np.nanmin(image))
230
+ imax = float(np.nanmax(image))
231
+ if not np.isfinite(imin) or not np.isfinite(imax):
232
+ return image
233
+ if imax <= imin:
234
+ return image
235
+ return ((image - imin) / (imax - imin)).astype(np.float32, copy=False)
236
+
237
+ def _extract_layer_data(self, layer, required: bool) -> np.ndarray:
238
+ """Return numpy data for a napari layer.
239
+
240
+ Parameters
241
+ ----------
242
+ layer : object or None
243
+ Napari layer to convert.
244
+ required : bool
245
+ Whether a missing layer should raise an error.
246
+
247
+ Returns
248
+ -------
249
+ numpy.ndarray
250
+ Layer data as an array.
251
+ """
252
+ if layer is None:
253
+ if required:
254
+ raise ValueError("Layer is required for StarDist ONNX.")
255
+ return None
256
+ return layer_data_asarray(layer)
257
+
258
+ def _get_session(
259
+ self, ndim: int, *, providers_override: list[str] | None = None
260
+ ) -> ort.InferenceSession:
261
+ """Return (and cache) an ONNX Runtime session for 2D or 3D models."""
262
+ model_path = self._resolve_model_path(ndim)
263
+ session = self._sessions.get(model_path)
264
+ if session is None or providers_override is not None:
265
+ providers = providers_override or self._preferred_providers()
266
+ session = ort.InferenceSession(
267
+ str(model_path),
268
+ providers=providers,
269
+ )
270
+ self._sessions[model_path] = session
271
+ return session
272
+
273
@staticmethod
def _preferred_providers() -> list[str]:
    """Return the execution providers to request, accelerators first.

    Returns
    -------
    list[str]
        The available providers from the preference list below, in that
        order. If none of them is available, every provider reported by
        ONNX Runtime is returned in the order it was reported, so the
        result is deterministic.
    """
    # Keep the reported list as-is so the fallback does not depend on
    # set iteration order.
    available = list(ort.get_available_providers())
    preference_order = (
        "CUDAExecutionProvider",
        "ROCMExecutionProvider",
        "DirectMLExecutionProvider",
        "CoreMLExecutionProvider",
        "CPUExecutionProvider",
    )
    providers = [name for name in preference_order if name in available]
    return providers or available
288
+
289
+ def _infer_tiling(
290
+ self,
291
+ image: np.ndarray,
292
+ model_path: Path,
293
+ session: ort.InferenceSession,
294
+ input_name: str,
295
+ output_names: list[str],
296
+ input_layout: str,
297
+ ) -> tuple[tuple[int, ...], tuple[int, ...]]:
298
+ """Infer tiling shape and overlap for ONNX tiled prediction.
299
+
300
+ This method uses the ONNX inspection utilities to derive:
301
+ - the per-axis divisibility requirement (``div_by``), and
302
+ - a recommended overlap based on the empirical receptive field.
303
+
304
+ The inferred values are cached per ONNX model path so the expensive
305
+ inspection (graph parsing / RF probing) only happens once per model.
306
+ If inspection fails for any reason, safe fallbacks are used:
307
+ ``div_by = (1, ... )`` and ``overlap = (0, ... )``.
308
+
309
+ Parameters
310
+ ----------
311
+ image : numpy.ndarray
312
+ Input image used to determine spatial dimensionality and to
313
+ clamp tile shape/overlap to valid ranges.
314
+ model_path : pathlib.Path
315
+ Path to the ONNX model, used as a cache key for inferred values.
316
+
317
+ Returns
318
+ -------
319
+ tuple[tuple[int, ...], tuple[int, ...]]
320
+ A tuple ``(tile_shape, overlap)``, each a per-axis tuple with
321
+ the same length as ``image.ndim``. ``tile_shape`` is rounded
322
+ down to the nearest multiple of ``div_by`` (never exceeding the
323
+ input size), and ``overlap`` is clamped to ``[0, tile_size - 1]``.
324
+ The XY tile sizes are capped at 1024 pixels per axis to avoid
325
+ feeding overly large tiles to the ONNX model.
326
+ """
327
+ ndim = image.ndim
328
+ div_by = self._div_by_cache.get(model_path)
329
+ if div_by is None:
330
+ try:
331
+ from senoquant.tabs.segmentation.stardist_onnx_utils.onnx_framework.inspect import (
332
+ infer_div_by,
333
+ )
334
+ except Exception:
335
+ div_by = (1,) * ndim
336
+ else:
337
+ try:
338
+ div_by = infer_div_by(model_path, ndim=ndim)
339
+ except Exception:
340
+ div_by = (1,) * ndim
341
+ self._div_by_cache[model_path] = div_by
342
+
343
+ overlap = self._overlap_cache.get(model_path)
344
+ if overlap is None:
345
+ try:
346
+ from senoquant.tabs.segmentation.stardist_onnx_utils.onnx_framework.inspect.receptive_field import (
347
+ recommend_tile_overlap,
348
+ )
349
+ except Exception:
350
+ overlap = (0,) * ndim
351
+ else:
352
+ try:
353
+ overlap = recommend_tile_overlap(model_path, ndim=ndim)
354
+ except Exception:
355
+ overlap = (0,) * ndim
356
+ self._overlap_cache[model_path] = overlap
357
+
358
+ max_tile = 1024
359
+ if image.ndim == 3:
360
+ capped_shape = (
361
+ image.shape[0],
362
+ min(image.shape[1], max_tile),
363
+ min(image.shape[2], max_tile),
364
+ )
365
+ else:
366
+ capped_shape = tuple(min(size, max_tile) for size in image.shape)
367
+
368
+ tile_shape = tuple(
369
+ max(div, (size // div) * div) if div > 0 else size
370
+ for size, div in zip(capped_shape, div_by)
371
+ )
372
+
373
+ patterns = self._valid_size_cache.get(model_path)
374
+ if patterns is None:
375
+ try:
376
+ from senoquant.tabs.segmentation.stardist_onnx_utils.onnx_framework.inspect.valid_sizes import (
377
+ infer_valid_size_patterns_from_path,
378
+ )
379
+ except Exception:
380
+ patterns = None
381
+ else:
382
+ try:
383
+ patterns = infer_valid_size_patterns_from_path(
384
+ model_path,
385
+ input_layout,
386
+ ndim,
387
+ )
388
+ except Exception:
389
+ patterns = None
390
+ self._valid_size_cache[model_path] = patterns
391
+
392
+ if patterns:
393
+ from senoquant.tabs.segmentation.stardist_onnx_utils.onnx_framework.inspect.valid_sizes import (
394
+ snap_shape,
395
+ )
396
+
397
+ skip = (0,) if ndim == 3 else ()
398
+ tile_shape = snap_shape(tile_shape, patterns, skip_axes=skip)
399
+ overlap = tuple(
400
+ max(0, min(int(ov), max(0, ts - 1)))
401
+ for ov, ts in zip(overlap, tile_shape)
402
+ )
403
+ return tile_shape, overlap
404
+
405
+ def _resolve_model_path(self, ndim: int) -> Path:
406
+ """Resolve the ONNX model file for 2D or 3D inference.
407
+
408
+ Parameters
409
+ ----------
410
+ ndim : int
411
+ Spatial dimensionality (2 or 3).
412
+
413
+ Returns
414
+ -------
415
+ pathlib.Path
416
+ Path to the ONNX model file.
417
+
418
+ Raises
419
+ ------
420
+ FileNotFoundError
421
+ If no ONNX model file is found.
422
+ ValueError
423
+ If multiple candidates are found without a default name.
424
+ """
425
+ if ndim != 3:
426
+ raise ValueError("StarDist ONNX 3D expects a 3D model.")
427
+ default_filename = "default_3d.onnx"
428
+ candidates = [
429
+ self.model_dir / "onnx_models" / default_filename,
430
+ self.model_dir / default_filename,
431
+ self.model_dir / "onnx_models" / "stardist3d_3D_demo.onnx",
432
+ self.model_dir / "stardist3d_3D_demo.onnx",
433
+ self.model_dir / "stardist3d.onnx",
434
+ ]
435
+
436
+ for path in candidates:
437
+ if path.exists():
438
+ return path
439
+
440
+ try:
441
+ downloaded = ensure_hf_model(
442
+ default_filename,
443
+ self.model_dir / "onnx_models",
444
+ repo_id=DEFAULT_REPO_ID,
445
+ )
446
+ except RuntimeError:
447
+ downloaded = None
448
+ if downloaded is not None and downloaded.exists():
449
+ return downloaded
450
+
451
+ matches = []
452
+ for folder in (self.model_dir / "onnx_models", self.model_dir):
453
+ if folder.exists():
454
+ matches.extend(sorted(folder.glob("*.onnx")))
455
+
456
+ if len(matches) == 1:
457
+ return matches[0]
458
+ if len(matches) > 1:
459
+ raise ValueError(
460
+ "Multiple ONNX files found; keep one or use default file names."
461
+ )
462
+ raise FileNotFoundError(
463
+ "No ONNX model found. Place the exported model in the model folder "
464
+ "or allow SenoQuant to download it from the model repository."
465
+ )
466
+
467
+ def _resolve_io_names(self, session: ort.InferenceSession) -> tuple[str, list[str]]:
468
+ """Resolve input and output tensor names for prob/dist inference."""
469
+ inputs = session.get_inputs()
470
+ outputs = session.get_outputs()
471
+ if not inputs:
472
+ raise RuntimeError("ONNX model has no inputs.")
473
+ if len(outputs) < 2:
474
+ raise RuntimeError("ONNX model must have prob and dist outputs.")
475
+
476
+ input_name = inputs[0].name
477
+
478
+ prob = None
479
+ dist = None
480
+ for output in outputs:
481
+ name = output.name.lower()
482
+ if "prob" in name and prob is None:
483
+ prob = output
484
+ elif "dist" in name and dist is None:
485
+ dist = output
486
+
487
+ if prob is None or dist is None:
488
+ for output in outputs:
489
+ shape = output.shape or []
490
+ channel = shape[-1] if shape else None
491
+ if channel == 1 and prob is None:
492
+ prob = output
493
+ elif channel not in (None, 1) and dist is None:
494
+ dist = output
495
+
496
+ if prob is None or dist is None:
497
+ prob, dist = outputs[0], outputs[1]
498
+
499
+ return input_name, [prob.name, dist.name]
500
+
501
+ def _ensure_stardist_lib_stubs(self) -> None:
502
+ """Ensure StarDist modules import without compiled extensions.
503
+
504
+ This registers minimal stubs for compiled modules when shared
505
+ libraries are absent, allowing Python utilities to import.
506
+ """
507
+ utils_root = self._get_utils_root()
508
+ csbdeep_root = utils_root / "_csbdeep"
509
+ if csbdeep_root.exists():
510
+ csbdeep_path = str(csbdeep_root)
511
+ if csbdeep_path not in sys.path:
512
+ sys.path.insert(0, csbdeep_path)
513
+
514
+ stardist_pkg = (
515
+ "senoquant.tabs.segmentation.stardist_onnx_utils._stardist"
516
+ )
517
+ if stardist_pkg not in sys.modules:
518
+ pkg = types.ModuleType(stardist_pkg)
519
+ pkg.__path__ = [str(utils_root / "_stardist")]
520
+ sys.modules[stardist_pkg] = pkg
521
+
522
+ base_pkg = f"{stardist_pkg}.lib"
523
+ lib_dirs = [utils_root / "_stardist" / "lib"]
524
+ for entry in list(sys.path):
525
+ if not entry:
526
+ continue
527
+ try:
528
+ candidate = (
529
+ Path(entry)
530
+ / "senoquant"
531
+ / "tabs"
532
+ / "segmentation"
533
+ / "stardist_onnx_utils"
534
+ / "_stardist"
535
+ / "lib"
536
+ )
537
+ except Exception:
538
+ continue
539
+ if candidate.exists():
540
+ lib_dirs.append(candidate)
541
+
542
+ if base_pkg in sys.modules:
543
+ pkg = sys.modules[base_pkg]
544
+ pkg.__path__ = [str(p) for p in lib_dirs]
545
+ else:
546
+ pkg = types.ModuleType(base_pkg)
547
+ pkg.__path__ = [str(p) for p in lib_dirs]
548
+ sys.modules[base_pkg] = pkg
549
+
550
+ def _stub(*_args, **_kwargs):
551
+ raise RuntimeError("StarDist compiled ops are unavailable.")
552
+
553
+ has_2d = False
554
+ has_3d = False
555
+ for lib_dir in lib_dirs:
556
+ has_2d = has_2d or any(lib_dir.glob("stardist2d*.so")) or any(
557
+ lib_dir.glob("stardist2d*.pyd")
558
+ )
559
+ has_3d = has_3d or any(lib_dir.glob("stardist3d*.so")) or any(
560
+ lib_dir.glob("stardist3d*.pyd")
561
+ )
562
+ self._has_stardist_2d_lib = has_2d
563
+ self._has_stardist_3d_lib = has_3d
564
+
565
+ mod2d = f"{base_pkg}.stardist2d"
566
+ if has_2d and mod2d in sys.modules:
567
+ if getattr(sys.modules[mod2d], "__file__", None) is None:
568
+ del sys.modules[mod2d]
569
+ if not has_2d and mod2d not in sys.modules:
570
+ module = types.ModuleType(mod2d)
571
+ module.c_star_dist = _stub
572
+ module.c_non_max_suppression_inds_old = _stub
573
+ module.c_non_max_suppression_inds = _stub
574
+ sys.modules[mod2d] = module
575
+
576
+ mod3d = f"{base_pkg}.stardist3d"
577
+ if has_3d and mod3d in sys.modules:
578
+ if getattr(sys.modules[mod3d], "__file__", None) is None:
579
+ del sys.modules[mod3d]
580
+ if not has_3d and mod3d not in sys.modules:
581
+ module = types.ModuleType(mod3d)
582
+ module.c_star_dist3d = _stub
583
+ module.c_polyhedron_to_label = _stub
584
+ module.c_non_max_suppression_inds = _stub
585
+ sys.modules[mod3d] = module
586
+
587
+ def _get_rays_class(self):
588
+ """Load and cache the StarDist Rays_GoldenSpiral class."""
589
+ if self._rays_class is not None:
590
+ return self._rays_class
591
+
592
+ utils_root = self._get_utils_root()
593
+ rays_path = utils_root / "_stardist" / "rays3d.py"
594
+ if not rays_path.exists():
595
+ raise FileNotFoundError("Could not locate StarDist rays3d.py.")
596
+
597
+ module_name = "senoquant_stardist_rays3d"
598
+ spec = importlib.util.spec_from_file_location(module_name, rays_path)
599
+ if spec is None or spec.loader is None:
600
+ raise ImportError("Failed to load StarDist rays3d module.")
601
+ module = importlib.util.module_from_spec(spec)
602
+ spec.loader.exec_module(module)
603
+ self._rays_class = module.Rays_GoldenSpiral
604
+ return self._rays_class
605
+
606
+ def _get_utils_root(self) -> Path:
607
+ """Return the stardist_onnx_utils package root."""
608
+ return Path(__file__).resolve().parents[2] / "stardist_onnx_utils"
609
+
610
+ def _infer_grid(
611
+ self,
612
+ image: np.ndarray,
613
+ session: ort.InferenceSession,
614
+ input_name: str,
615
+ output_names: list[str],
616
+ input_layout: str,
617
+ prob_layout: str,
618
+ *,
619
+ model_path: Path | None = None,
620
+ ) -> tuple[int, ...]:
621
+ """Infer model grid/stride by running a probe tile.
622
+
623
+ Parameters
624
+ ----------
625
+ image : numpy.ndarray
626
+ Input image.
627
+ session : onnxruntime.InferenceSession
628
+ ONNX Runtime session.
629
+ input_name : str
630
+ ONNX input tensor name.
631
+ output_names : list[str]
632
+ ONNX output tensor names (prob, dist).
633
+ input_layout : str
634
+ Input layout string (e.g., "NHWC", "NDHWC").
635
+ prob_layout : str
636
+ Probability output layout string.
637
+
638
+ Returns
639
+ -------
640
+ tuple[int, ...]
641
+ Estimated grid/stride per axis.
642
+ """
643
+ probe = self._make_probe_image(
644
+ image, model_path=model_path, input_layout=input_layout
645
+ )
646
+ if input_layout in ("NHWC", "NDHWC"):
647
+ input_tensor = probe[np.newaxis, ..., np.newaxis]
648
+ else:
649
+ input_tensor = probe[np.newaxis, np.newaxis, ...]
650
+
651
+ prob = session.run(output_names, {input_name: input_tensor})[0]
652
+ if prob_layout in ("NHWC", "NDHWC"):
653
+ out_shape = prob.shape[1:-1]
654
+ elif prob_layout in ("NCHW", "NCDHW"):
655
+ out_shape = prob.shape[2:]
656
+ else:
657
+ raise ValueError(f"Unsupported prob layout {prob_layout}.")
658
+
659
+ grid = []
660
+ for dim_in, dim_out in zip(probe.shape, out_shape):
661
+ if dim_out in (0, None):
662
+ grid.append(1)
663
+ continue
664
+ ratio = dim_in / dim_out
665
+ grid.append(max(1, int(round(ratio))))
666
+ return tuple(grid)
667
+
668
def _make_probe_image(
    self,
    image: np.ndarray,
    *,
    model_path: Path | None = None,
    input_layout: str | None = None,
) -> np.ndarray:
    """Create a small probe image for grid inference.

    Delegates to ``make_probe_image``, forwarding ``model_path`` and
    ``input_layout`` together with this instance's divisibility and
    valid-size caches.
    """
    probe_options = {
        "model_path": model_path,
        "input_layout": input_layout,
        "div_by_cache": self._div_by_cache,
        "valid_size_cache": self._valid_size_cache,
    }
    return make_probe_image(image, **probe_options)