senoquant 1.0.0b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148) hide show
  1. senoquant/__init__.py +6 -0
  2. senoquant/_reader.py +7 -0
  3. senoquant/_widget.py +33 -0
  4. senoquant/napari.yaml +83 -0
  5. senoquant/reader/__init__.py +5 -0
  6. senoquant/reader/core.py +369 -0
  7. senoquant/tabs/__init__.py +15 -0
  8. senoquant/tabs/batch/__init__.py +10 -0
  9. senoquant/tabs/batch/backend.py +641 -0
  10. senoquant/tabs/batch/config.py +270 -0
  11. senoquant/tabs/batch/frontend.py +1283 -0
  12. senoquant/tabs/batch/io.py +326 -0
  13. senoquant/tabs/batch/layers.py +86 -0
  14. senoquant/tabs/quantification/__init__.py +1 -0
  15. senoquant/tabs/quantification/backend.py +228 -0
  16. senoquant/tabs/quantification/features/__init__.py +80 -0
  17. senoquant/tabs/quantification/features/base.py +142 -0
  18. senoquant/tabs/quantification/features/marker/__init__.py +5 -0
  19. senoquant/tabs/quantification/features/marker/config.py +69 -0
  20. senoquant/tabs/quantification/features/marker/dialog.py +437 -0
  21. senoquant/tabs/quantification/features/marker/export.py +879 -0
  22. senoquant/tabs/quantification/features/marker/feature.py +119 -0
  23. senoquant/tabs/quantification/features/marker/morphology.py +285 -0
  24. senoquant/tabs/quantification/features/marker/rows.py +654 -0
  25. senoquant/tabs/quantification/features/marker/thresholding.py +46 -0
  26. senoquant/tabs/quantification/features/roi.py +346 -0
  27. senoquant/tabs/quantification/features/spots/__init__.py +5 -0
  28. senoquant/tabs/quantification/features/spots/config.py +62 -0
  29. senoquant/tabs/quantification/features/spots/dialog.py +477 -0
  30. senoquant/tabs/quantification/features/spots/export.py +1292 -0
  31. senoquant/tabs/quantification/features/spots/feature.py +112 -0
  32. senoquant/tabs/quantification/features/spots/morphology.py +279 -0
  33. senoquant/tabs/quantification/features/spots/rows.py +241 -0
  34. senoquant/tabs/quantification/frontend.py +815 -0
  35. senoquant/tabs/segmentation/__init__.py +1 -0
  36. senoquant/tabs/segmentation/backend.py +131 -0
  37. senoquant/tabs/segmentation/frontend.py +1009 -0
  38. senoquant/tabs/segmentation/models/__init__.py +5 -0
  39. senoquant/tabs/segmentation/models/base.py +146 -0
  40. senoquant/tabs/segmentation/models/cpsam/details.json +65 -0
  41. senoquant/tabs/segmentation/models/cpsam/model.py +150 -0
  42. senoquant/tabs/segmentation/models/default_2d/details.json +69 -0
  43. senoquant/tabs/segmentation/models/default_2d/model.py +664 -0
  44. senoquant/tabs/segmentation/models/default_3d/details.json +69 -0
  45. senoquant/tabs/segmentation/models/default_3d/model.py +682 -0
  46. senoquant/tabs/segmentation/models/hf.py +71 -0
  47. senoquant/tabs/segmentation/models/nuclear_dilation/__init__.py +1 -0
  48. senoquant/tabs/segmentation/models/nuclear_dilation/details.json +26 -0
  49. senoquant/tabs/segmentation/models/nuclear_dilation/model.py +96 -0
  50. senoquant/tabs/segmentation/models/perinuclear_rings/__init__.py +1 -0
  51. senoquant/tabs/segmentation/models/perinuclear_rings/details.json +34 -0
  52. senoquant/tabs/segmentation/models/perinuclear_rings/model.py +132 -0
  53. senoquant/tabs/segmentation/stardist_onnx_utils/__init__.py +2 -0
  54. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/__init__.py +3 -0
  55. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/__init__.py +6 -0
  56. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/generate.py +470 -0
  57. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/prepare.py +273 -0
  58. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/rawdata.py +112 -0
  59. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/transform.py +384 -0
  60. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/__init__.py +0 -0
  61. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/blocks.py +184 -0
  62. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/losses.py +79 -0
  63. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/nets.py +165 -0
  64. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/predict.py +467 -0
  65. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/probability.py +67 -0
  66. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/train.py +148 -0
  67. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/io/__init__.py +163 -0
  68. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/__init__.py +52 -0
  69. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/base_model.py +329 -0
  70. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_isotropic.py +160 -0
  71. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_projection.py +178 -0
  72. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_standard.py +446 -0
  73. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_upsampling.py +54 -0
  74. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/config.py +254 -0
  75. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/pretrained.py +119 -0
  76. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/scripts/__init__.py +0 -0
  77. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/scripts/care_predict.py +180 -0
  78. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/__init__.py +5 -0
  79. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/plot_utils.py +159 -0
  80. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/six.py +18 -0
  81. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/tf.py +644 -0
  82. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/utils.py +272 -0
  83. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/version.py +1 -0
  84. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/docs/source/conf.py +368 -0
  85. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/setup.py +68 -0
  86. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_datagen.py +169 -0
  87. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_models.py +462 -0
  88. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_utils.py +166 -0
  89. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tools/create_zip_contents.py +34 -0
  90. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/__init__.py +30 -0
  91. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/big.py +624 -0
  92. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/bioimageio_utils.py +494 -0
  93. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/data/__init__.py +39 -0
  94. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/__init__.py +10 -0
  95. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/geom2d.py +215 -0
  96. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/geom3d.py +349 -0
  97. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/matching.py +483 -0
  98. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/__init__.py +28 -0
  99. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/base.py +1217 -0
  100. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/model2d.py +594 -0
  101. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/model3d.py +696 -0
  102. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/nms.py +384 -0
  103. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/__init__.py +2 -0
  104. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/plot.py +74 -0
  105. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/render.py +298 -0
  106. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/rays3d.py +373 -0
  107. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/sample_patches.py +65 -0
  108. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/__init__.py +0 -0
  109. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/predict2d.py +90 -0
  110. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/predict3d.py +93 -0
  111. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/utils.py +408 -0
  112. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/version.py +1 -0
  113. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/__init__.py +45 -0
  114. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/__init__.py +17 -0
  115. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/cli.py +55 -0
  116. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/core.py +285 -0
  117. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/__init__.py +15 -0
  118. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/cli.py +36 -0
  119. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/divisibility.py +193 -0
  120. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/probe.py +100 -0
  121. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/receptive_field.py +182 -0
  122. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/rf_cli.py +48 -0
  123. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/valid_sizes.py +278 -0
  124. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/post/__init__.py +8 -0
  125. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/post/core.py +157 -0
  126. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/pre/__init__.py +17 -0
  127. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/pre/core.py +226 -0
  128. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/predict/__init__.py +5 -0
  129. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/predict/core.py +401 -0
  130. senoquant/tabs/settings/__init__.py +1 -0
  131. senoquant/tabs/settings/backend.py +29 -0
  132. senoquant/tabs/settings/frontend.py +19 -0
  133. senoquant/tabs/spots/__init__.py +1 -0
  134. senoquant/tabs/spots/backend.py +139 -0
  135. senoquant/tabs/spots/frontend.py +800 -0
  136. senoquant/tabs/spots/models/__init__.py +5 -0
  137. senoquant/tabs/spots/models/base.py +94 -0
  138. senoquant/tabs/spots/models/rmp/details.json +61 -0
  139. senoquant/tabs/spots/models/rmp/model.py +499 -0
  140. senoquant/tabs/spots/models/udwt/details.json +103 -0
  141. senoquant/tabs/spots/models/udwt/model.py +482 -0
  142. senoquant/utils.py +25 -0
  143. senoquant-1.0.0b1.dist-info/METADATA +193 -0
  144. senoquant-1.0.0b1.dist-info/RECORD +148 -0
  145. senoquant-1.0.0b1.dist-info/WHEEL +5 -0
  146. senoquant-1.0.0b1.dist-info/entry_points.txt +2 -0
  147. senoquant-1.0.0b1.dist-info/licenses/LICENSE +28 -0
  148. senoquant-1.0.0b1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,664 @@
1
+ """StarDist ONNX segmentation model implementation."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from pathlib import Path
6
+ import importlib.util
7
+ import sys
8
+ import types
9
+ from typing import TYPE_CHECKING
10
+
11
+ import numpy as np
12
+ import onnxruntime as ort
13
+ from scipy import ndimage as ndi
14
+
15
+ from senoquant.utils import layer_data_asarray
16
+ from ..hf import DEFAULT_REPO_ID, ensure_hf_model
17
+ from ..base import SenoQuantSegmentationModel
18
+ from senoquant.tabs.segmentation.stardist_onnx_utils.onnx_framework import (
19
+ normalize,
20
+ predict_tiled,
21
+ )
22
+ from senoquant.tabs.segmentation.stardist_onnx_utils.onnx_framework.inspect import (
23
+ make_probe_image,
24
+ )
25
+ if TYPE_CHECKING:
26
+ from senoquant.tabs.segmentation.stardist_onnx_utils.onnx_framework.inspect.valid_sizes import (
27
+ ValidSizePattern,
28
+ )
29
+
30
+
31
class StarDistOnnxModel(SenoQuantSegmentationModel):
    """StarDist ONNX 2D segmentation model.

    This wrapper loads an exported StarDist 2D ONNX model, runs
    preprocessing and tiled inference, and postprocesses the outputs into
    instance labels using StarDist geometry and NMS utilities.

    Notes
    -----
    - Inputs must be single-channel images in YX (2D) order.
    - ONNX model outputs are assumed to be probability and distance maps.
    - If inference fails on a session that uses the CoreML execution
      provider, the session is recreated on the CPU provider and inference
      (grid probing + tiled prediction) is retried once.
    """

    # Input images are rescaled by ``17.44 / object_diameter_px``;
    # presumably 17.44 px is the typical object diameter the model was
    # trained on (TODO confirm against the training setup).
    _TRAINING_OBJECT_DIAMETER_PX = 17.44
    # Default per-axis divisibility requirement for tile shapes.
    _TILE_MULTIPLE = 16
    # Upper bound on the tile size along each axis.
    _MAX_TILE_SIZE = 1024

    # Tensor layouts of the exported 2D model's input and outputs.
    _INPUT_LAYOUT = "NHWC"
    _PROB_LAYOUT = "NHWC"
    _DIST_LAYOUT = "NYXR"

    _COREML_PROVIDER = "CoreMLExecutionProvider"
    _CPU_PROVIDER = "CPUExecutionProvider"

    def __init__(self, models_root=None) -> None:
        """Initialize the StarDist ONNX model wrapper.

        Parameters
        ----------
        models_root : pathlib.Path or None
            Optional root directory for model storage.
        """
        super().__init__("default_2d", models_root=models_root)
        # ONNX Runtime sessions keyed by resolved model path.
        self._sessions: dict[Path, ort.InferenceSession] = {}
        # Lazily loaded ``Rays_GoldenSpiral`` class (see _get_rays_class).
        self._rays_class = None
        # Availability of the compiled StarDist extensions; refreshed by
        # _ensure_stardist_lib_stubs.
        self._has_stardist_2d_lib = False
        self._has_stardist_3d_lib = False
        # Per-model inspection caches. The first two are shared with
        # ``make_probe_image``, which may read or populate them.
        self._div_by_cache: dict[Path, tuple[int, ...]] = {}
        self._overlap_cache: dict[Path, tuple[int, ...]] = {}
        self._valid_size_cache: dict[Path, list["ValidSizePattern"] | None] = {}

    def run(self, **kwargs) -> dict:
        """Run StarDist ONNX for nuclear segmentation.

        Parameters
        ----------
        **kwargs
            task : str
                Must be "nuclear" for this model.
            layer : napari.layers.Image
                Single-channel 2D (YX) image layer.
            settings : dict or None, optional
                Model settings keyed by ``details.json``. Keys read here:
                ``object_diameter_px``, ``normalize``, ``pmin``, ``pmax``,
                ``prob_thresh`` and ``nms_thresh``.

        Returns
        -------
        dict
            Dictionary with:
            - ``masks``: instance label image
            - ``prob``: probability map
            - ``dist``: distance/ray map
            - ``info``: metadata returned by ``instances_from_prediction_2d``

        Raises
        ------
        ValueError
            If ``task`` is not "nuclear", the layer is missing, the image is
            not 2D, or the object diameter is not a positive finite number.
        RuntimeError
            If the compiled StarDist 2D extensions are unavailable.
        FileNotFoundError
            If no ONNX model file can be found or downloaded.
        """
        task = kwargs.get("task")
        if task != "nuclear":
            raise ValueError("StarDist ONNX only supports nuclear segmentation.")

        layer = kwargs.get("layer")
        # Treat an explicit ``settings=None`` the same as "no settings".
        settings = kwargs.get("settings") or {}
        image = self._extract_layer_data(layer, required=True)
        original_shape = image.shape
        if image.ndim != 2:
            raise ValueError("StarDist ONNX 2D expects a 2D (YX) image.")

        # Parse postprocessing settings up front so bad values fail before
        # the (potentially long) inference instead of after it.
        prob_thresh = float(settings.get("prob_thresh", 0.5))
        nms_thresh = float(settings.get("nms_thresh", 0.4))

        image = image.astype(np.float32, copy=False)
        image, scale = self._scale_input(image, settings)
        image = self._scale_intensity(image)
        if settings.get("normalize", True):
            pmin = float(settings.get("pmin", 1.0))
            pmax = float(settings.get("pmax", 99.8))
            image = normalize(image, pmin=pmin, pmax=pmax)

        # The compiled ops are only needed for postprocessing, but checking
        # them first avoids loading the model and running tiled inference
        # just to fail afterwards.
        self._ensure_stardist_lib_stubs()
        if not self._has_stardist_2d_lib:
            raise RuntimeError(
                "StarDist 2D compiled ops are missing. Build the "
                "extensions in stardist_onnx_utils/_stardist/lib."
            )
        from senoquant.tabs.segmentation.stardist_onnx_utils.onnx_framework import (
            instances_from_prediction_2d,
        )

        # Resolve once and reuse; resolution may attempt a download.
        model_path = self._resolve_model_path(image.ndim)
        session = self._get_session(image.ndim, model_path=model_path)
        input_name, output_names = self._resolve_io_names(session)

        try:
            prob, dist, grid = self._predict(
                image, session, model_path, input_name, output_names
            )
        except Exception:
            # Only CoreML-backed sessions get a CPU retry; any failure on
            # other providers is re-raised unchanged. The retry covers the
            # grid probe as well, since it runs on the same session.
            if self._COREML_PROVIDER not in session.get_providers():
                raise
            session = self._get_session(
                image.ndim,
                providers_override=[self._CPU_PROVIDER],
                model_path=model_path,
            )
            prob, dist, grid = self._predict(
                image, session, model_path, input_name, output_names
            )

        labels, info = instances_from_prediction_2d(
            prob,
            dist,
            grid=grid,
            prob_thresh=prob_thresh,
            nms_thresh=nms_thresh,
            scale=scale,
            img_shape=original_shape,
        )
        return {"masks": labels, "prob": prob, "dist": dist, "info": info}

    def _predict(
        self,
        image: np.ndarray,
        session: ort.InferenceSession,
        model_path: Path,
        input_name: str,
        output_names: list[str],
    ) -> tuple[np.ndarray, np.ndarray, tuple[int, ...]]:
        """Probe the output grid and run tiled inference on one session.

        Parameters
        ----------
        image : numpy.ndarray
            Preprocessed 2D input image.
        session : onnxruntime.InferenceSession
            Session to run inference on.
        model_path : pathlib.Path
            Resolved model path, used as the key for the inspection caches.
        input_name : str
            ONNX input tensor name.
        output_names : list[str]
            ONNX output tensor names in ``[prob, dist]`` order.

        Returns
        -------
        tuple
            ``(prob, dist, grid)``: the two arrays returned by
            ``predict_tiled`` and the grid inferred from a probe tile.
        """
        grid = self._infer_grid(
            image,
            session,
            input_name,
            output_names,
            self._INPUT_LAYOUT,
            self._PROB_LAYOUT,
            model_path=model_path,
        )
        tile_shape, overlap = self._infer_tiling(
            image, model_path, session, input_name, output_names, self._INPUT_LAYOUT
        )
        # _infer_tiling always stores a div_by for this model, so ``grid``
        # is only a defensive fallback.
        div_by = self._div_by_cache.get(model_path, grid)
        prob, dist = predict_tiled(
            image,
            session,
            input_name=input_name,
            output_names=output_names,
            grid=grid,
            input_layout=self._INPUT_LAYOUT,
            prob_layout=self._PROB_LAYOUT,
            dist_layout=self._DIST_LAYOUT,
            tile_shape=tile_shape,
            overlap=overlap,
            div_by=div_by,
        )
        return prob, dist, grid

    def _scale_input(
        self, image: np.ndarray, settings: dict
    ) -> tuple[np.ndarray, dict[str, float] | None]:
        """Rescale the input image so objects match the training size.

        Parameters
        ----------
        image : numpy.ndarray
            Input 2D image in YX order.
        settings : dict
            Model settings; ``object_diameter_px`` (default 30.0) is read.

        Returns
        -------
        numpy.ndarray
            Scaled float32 image, or the input unchanged when the scale
            factor is close to 1.
        dict[str, float] or None
            Scale factors keyed by axis (``"Y"``, ``"X"``) for mapping
            predictions back to the original image space, or ``None`` if
            no scaling was applied.

        Raises
        ------
        ValueError
            If the diameter is not a positive finite number, or scaling
            produces an empty image.
        """
        diameter_px = float(settings.get("object_diameter_px", 30.0))
        # Also reject NaN/inf, which would otherwise reach ndi.zoom.
        if not np.isfinite(diameter_px) or diameter_px <= 0:
            raise ValueError("Object diameter (px) must be positive.")
        scale_factor = self._TRAINING_OBJECT_DIAMETER_PX / diameter_px
        if np.isclose(scale_factor, 1.0):
            return image, None

        # Bilinear (order=1) resampling, same factor on both axes.
        scaled = ndi.zoom(image, (scale_factor, scale_factor), order=1)
        if min(scaled.shape) < 1:
            raise ValueError(
                "Scaling factor produced an empty image; adjust object diameter."
            )
        return scaled.astype(np.float32, copy=False), {"Y": scale_factor, "X": scale_factor}

    @staticmethod
    def _scale_intensity(image: np.ndarray) -> np.ndarray:
        """Scale image intensities into [0, 1] using the NaN-ignoring min/max.

        Empty images, images whose min/max are not finite, and constant
        images are returned unchanged.
        """
        if image.size == 0:
            return image
        imin = float(np.nanmin(image))
        imax = float(np.nanmax(image))
        if not (np.isfinite(imin) and np.isfinite(imax)) or imax <= imin:
            return image
        return ((image - imin) / (imax - imin)).astype(np.float32, copy=False)

    def _extract_layer_data(self, layer, required: bool) -> np.ndarray | None:
        """Return numpy data for a napari layer.

        Parameters
        ----------
        layer : object or None
            Napari layer to convert.
        required : bool
            Whether a missing layer should raise an error.

        Returns
        -------
        numpy.ndarray or None
            Layer data as an array, or ``None`` when the layer is missing
            and not required.

        Raises
        ------
        ValueError
            If ``layer`` is ``None`` and ``required`` is true.
        """
        if layer is None:
            if required:
                raise ValueError("Layer is required for StarDist ONNX.")
            return None
        return layer_data_asarray(layer)

    def _get_session(
        self,
        ndim: int,
        *,
        providers_override: list[str] | None = None,
        model_path: Path | None = None,
    ) -> ort.InferenceSession:
        """Return (and cache) an ONNX Runtime session for the model.

        Parameters
        ----------
        ndim : int
            Spatial dimensionality; only used to resolve the model path
            when ``model_path`` is not given.
        providers_override : list[str] or None, optional
            If given (and non-empty), always build a fresh session with
            these providers. The new session replaces the cached one.
        model_path : pathlib.Path or None, optional
            Already-resolved model path; avoids resolving it again.

        Returns
        -------
        onnxruntime.InferenceSession
            Cached or newly created session.
        """
        if model_path is None:
            model_path = self._resolve_model_path(ndim)
        session = self._sessions.get(model_path)
        if session is None or providers_override is not None:
            providers = providers_override or self._preferred_providers()
            session = ort.InferenceSession(str(model_path), providers=providers)
            self._sessions[model_path] = session
        return session

    @staticmethod
    def _preferred_providers() -> list[str]:
        """Return available providers, preferring GPU/accelerator ones.

        Known providers are ordered CUDA, ROCm, DirectML, CoreML, CPU. If
        none of them is available, onnxruntime's own list is returned in
        its original (priority) order.
        """
        available = ort.get_available_providers()
        preferred = [
            "CUDAExecutionProvider",
            "ROCMExecutionProvider",
            "DirectMLExecutionProvider",
            "CoreMLExecutionProvider",
            "CPUExecutionProvider",
        ]
        providers = [provider for provider in preferred if provider in available]
        return providers or list(available)

    def _infer_tiling(
        self,
        image: np.ndarray,
        model_path: Path,
        session: ort.InferenceSession,
        input_name: str,
        output_names: list[str],
        input_layout: str,
    ) -> tuple[tuple[int, ...], tuple[int, ...]]:
        """Infer tile shape and overlap for ONNX tiled prediction.

        Three per-model values are used, each cached by ``model_path``:

        - ``div_by``: read from ``self._div_by_cache``. That cache is also
          passed to ``make_probe_image``, which may have populated it.
          Otherwise ``(16,) * ndim`` is stored and used.
        - ``overlap``: recommended from the model's receptive field via
          ``recommend_tile_overlap``, or ``(0,) * ndim`` if that fails.
        - valid input-size patterns: inferred via
          ``infer_valid_size_patterns_from_path``, or ``None`` (no snapping)
          if that fails.

        Failed lookups are cached too, so inspection runs at most once per
        model.

        Parameters
        ----------
        image : numpy.ndarray
            Input image; its shape bounds the tile shape.
        model_path : pathlib.Path
            Path to the ONNX model, used as the cache key.
        session, input_name, output_names
            Accepted for interface compatibility; not used here.
        input_layout : str
            Input layout passed to the valid-size inspection.

        Returns
        -------
        tuple[tuple[int, ...], tuple[int, ...]]
            ``(tile_shape, overlap)``, one entry per image axis.

            Computing ``tile_shape``:

            - Each axis is capped at 1024.
            - It is then rounded down to a multiple of ``div_by``, but never
              below ``div_by``. Axes smaller than ``div_by`` therefore get a
              tile larger than the image (presumably padded by
              ``predict_tiled`` -- confirm).
            - When valid-size patterns are known, the shape is snapped to
              them and re-rounded to a multiple of 16 (minimum 16).

            ``overlap`` is clamped to ``[0, tile_size - 1]``.
        """
        ndim = image.ndim

        div_by = self._div_by_cache.get(model_path)
        if div_by is None:
            div_by = (self._TILE_MULTIPLE,) * ndim
            self._div_by_cache[model_path] = div_by

        overlap = self._overlap_cache.get(model_path)
        if overlap is None:
            overlap = self._recommend_overlap(model_path, ndim)
            self._overlap_cache[model_path] = overlap

        capped_shape = tuple(min(size, self._MAX_TILE_SIZE) for size in image.shape)
        tile_shape = tuple(
            max(div, (size // div) * div) if div > 0 else size
            for size, div in zip(capped_shape, div_by)
        )

        # ``None`` is a legitimate cached result ("no patterns available"),
        # so test membership instead of ``.get(...) is None``; otherwise a
        # failed inspection would be re-run on every call.
        if model_path in self._valid_size_cache:
            patterns = self._valid_size_cache[model_path]
        else:
            patterns = self._load_valid_size_patterns(model_path, input_layout, ndim)
            self._valid_size_cache[model_path] = patterns

        if patterns:
            from senoquant.tabs.segmentation.stardist_onnx_utils.onnx_framework.inspect.valid_sizes import (
                snap_shape,
            )

            step = self._TILE_MULTIPLE
            tile_shape = tuple(
                max(step, (ts // step) * step)
                for ts in snap_shape(tile_shape, patterns)
            )

        overlap = tuple(
            max(0, min(int(ov), max(0, ts - 1)))
            for ov, ts in zip(overlap, tile_shape)
        )
        return tile_shape, overlap

    @staticmethod
    def _recommend_overlap(model_path: Path, ndim: int) -> tuple[int, ...]:
        """Return the receptive-field based tile overlap, or zeros on failure."""
        try:
            from senoquant.tabs.segmentation.stardist_onnx_utils.onnx_framework.inspect.receptive_field import (
                recommend_tile_overlap,
            )

            return recommend_tile_overlap(model_path, ndim=ndim)
        except Exception:
            # Best effort: inspection is optional; fall back to no overlap.
            return (0,) * ndim

    @staticmethod
    def _load_valid_size_patterns(
        model_path: Path, input_layout: str, ndim: int
    ) -> list["ValidSizePattern"] | None:
        """Return valid input-size patterns for the model, or ``None`` on failure."""
        try:
            from senoquant.tabs.segmentation.stardist_onnx_utils.onnx_framework.inspect.valid_sizes import (
                infer_valid_size_patterns_from_path,
            )

            return infer_valid_size_patterns_from_path(model_path, input_layout, ndim)
        except Exception:
            # Best effort: without patterns, tiles are simply not snapped.
            return None

    def _resolve_model_path(self, ndim: int) -> Path:
        """Resolve the ONNX model file for 2D inference.

        Lookup order:

        1. The first existing file among the known file names, in
           ``model_dir/onnx_models`` and ``model_dir``.
        2. A download of ``default_2d.onnx`` via ``ensure_hf_model``.
        3. The single ``*.onnx`` file found in those two folders.

        Parameters
        ----------
        ndim : int
            Spatial dimensionality; must be 2.

        Returns
        -------
        pathlib.Path
            Path to the ONNX model file.

        Raises
        ------
        ValueError
            If ``ndim`` is not 2, or if several unnamed ``*.onnx`` candidates
            are found.
        FileNotFoundError
            If no ONNX model file is found.
        """
        if ndim != 2:
            raise ValueError("StarDist ONNX 2D expects a 2D model.")
        default_filename = "default_2d.onnx"
        onnx_dir = self.model_dir / "onnx_models"
        candidates = [
            onnx_dir / default_filename,
            self.model_dir / default_filename,
            onnx_dir / "stardist_mod_2d.onnx",
            onnx_dir / "stardist2d_2D_versatile_fluo.onnx",
            self.model_dir / "stardist_mod_2d.onnx",
            self.model_dir / "stardist2d_2D_versatile_fluo.onnx",
            self.model_dir / "stardist2d.onnx",
        ]
        for path in candidates:
            if path.exists():
                return path

        try:
            downloaded = ensure_hf_model(
                default_filename,
                onnx_dir,
                repo_id=DEFAULT_REPO_ID,
            )
        except RuntimeError:
            # Download is optional; fall through to scanning local folders.
            downloaded = None
        if downloaded is not None and downloaded.exists():
            return downloaded

        matches = []
        for folder in (onnx_dir, self.model_dir):
            if folder.exists():
                matches.extend(sorted(folder.glob("*.onnx")))

        if len(matches) == 1:
            return matches[0]
        if len(matches) > 1:
            raise ValueError(
                "Multiple ONNX files found; keep one or use default file names."
            )
        raise FileNotFoundError(
            "No ONNX model found. Place the exported model in the model folder "
            "or allow SenoQuant to download it from the model repository."
        )

    def _resolve_io_names(self, session: ort.InferenceSession) -> tuple[str, list[str]]:
        """Resolve the input name and the ``[prob, dist]`` output names.

        Outputs are matched in three steps:

        1. By name ("prob" / "dist" substrings, case-insensitive).
        2. By the last dimension of the declared shape: 1 means prob, any
           other known value means dist. One output is never used for both.
        3. Positionally, as outputs 0 and 1.

        Raises
        ------
        RuntimeError
            If the model has no inputs or fewer than two outputs.
        """
        inputs = session.get_inputs()
        outputs = session.get_outputs()
        if not inputs:
            raise RuntimeError("ONNX model has no inputs.")
        if len(outputs) < 2:
            raise RuntimeError("ONNX model must have prob and dist outputs.")

        input_name = inputs[0].name

        prob = None
        dist = None
        for output in outputs:
            name = output.name.lower()
            if "prob" in name and prob is None:
                prob = output
            elif "dist" in name and dist is None:
                dist = output

        if prob is None or dist is None:
            for output in outputs:
                shape = output.shape or []
                channel = shape[-1] if shape else None
                if channel == 1 and prob is None and output is not dist:
                    prob = output
                elif channel not in (None, 1) and dist is None and output is not prob:
                    dist = output

        if prob is None or dist is None:
            prob, dist = outputs[0], outputs[1]

        return input_name, [prob.name, dist.name]

    def _ensure_stardist_lib_stubs(self) -> None:
        """Make the vendored StarDist package importable without compiled ops.

        Side effects (all process-global):

        - Prepends the vendored ``_csbdeep`` directory to ``sys.path`` when it
          exists and is not already there.
        - Registers a package module for the vendored ``_stardist`` package
          if it is not already in ``sys.modules``.
        - Sets ``_stardist.lib.__path__`` to every candidate ``_stardist/lib``
          directory: the vendored copy first, then one per ``sys.path`` entry.
        - For each 2D/3D extension:
          - If no compiled library (``*.so`` / ``*.pyd``) is found, installs a
            stub module whose functions raise ``RuntimeError``.
          - If one is found, drops a previously installed stub (a module
            without ``__file__``) so the real extension can be imported.

        Availability is recorded in ``self._has_stardist_2d_lib`` and
        ``self._has_stardist_3d_lib``.
        """
        utils_root = self._get_utils_root()

        csbdeep_root = utils_root / "_csbdeep"
        if csbdeep_root.exists():
            csbdeep_path = str(csbdeep_root)
            if csbdeep_path not in sys.path:
                sys.path.insert(0, csbdeep_path)

        stardist_pkg = "senoquant.tabs.segmentation.stardist_onnx_utils._stardist"
        if stardist_pkg not in sys.modules:
            package = types.ModuleType(stardist_pkg)
            package.__path__ = [str(utils_root / "_stardist")]
            sys.modules[stardist_pkg] = package

        base_pkg = f"{stardist_pkg}.lib"
        lib_dirs = self._find_stardist_lib_dirs(utils_root)
        lib_paths = [str(lib_dir) for lib_dir in lib_dirs]
        if base_pkg in sys.modules:
            sys.modules[base_pkg].__path__ = lib_paths
        else:
            lib_package = types.ModuleType(base_pkg)
            lib_package.__path__ = lib_paths
            sys.modules[base_pkg] = lib_package

        def _has_extension(stem: str) -> bool:
            return any(
                any(lib_dir.glob(f"{stem}*{suffix}"))
                for lib_dir in lib_dirs
                for suffix in (".so", ".pyd")
            )

        has_2d = _has_extension("stardist2d")
        has_3d = _has_extension("stardist3d")
        self._has_stardist_2d_lib = has_2d
        self._has_stardist_3d_lib = has_3d

        def _stub(*_args, **_kwargs):
            raise RuntimeError("StarDist compiled ops are unavailable.")

        extensions = (
            (
                "stardist2d",
                has_2d,
                (
                    "c_star_dist",
                    "c_non_max_suppression_inds_old",
                    "c_non_max_suppression_inds",
                ),
            ),
            (
                "stardist3d",
                has_3d,
                (
                    "c_star_dist3d",
                    "c_polyhedron_to_label",
                    "c_non_max_suppression_inds",
                ),
            ),
        )
        for ext_name, available, symbols in extensions:
            mod_name = f"{base_pkg}.{ext_name}"
            if available:
                # Remove a stub installed by an earlier call (stubs have no
                # ``__file__``) so the real compiled module gets imported.
                if mod_name in sys.modules and getattr(
                    sys.modules[mod_name], "__file__", None
                ) is None:
                    del sys.modules[mod_name]
            elif mod_name not in sys.modules:
                stub_module = types.ModuleType(mod_name)
                for symbol in symbols:
                    setattr(stub_module, symbol, _stub)
                sys.modules[mod_name] = stub_module

    @staticmethod
    def _find_stardist_lib_dirs(utils_root: Path) -> list[Path]:
        """Return candidate ``_stardist/lib`` directories without duplicates.

        The vendored directory under ``utils_root`` always comes first, even
        if it does not exist. It is followed by every existing
        ``senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/lib``
        directory below a ``sys.path`` entry.
        """
        relative = Path(
            "senoquant", "tabs", "segmentation", "stardist_onnx_utils", "_stardist", "lib"
        )
        vendored = utils_root / "_stardist" / "lib"
        # dict preserves insertion order and de-duplicates by path string.
        found: dict[str, Path] = {str(vendored): vendored}
        for entry in list(sys.path):
            if not entry:
                continue
            try:
                candidate = Path(entry) / relative
                exists = candidate.exists()
            except (TypeError, ValueError, OSError):
                # Non-path entries (e.g. bytes) or unreadable locations.
                continue
            if exists:
                found.setdefault(str(candidate), candidate)
        return list(found.values())

    def _get_rays_class(self):
        """Load and cache the StarDist ``Rays_GoldenSpiral`` class.

        The vendored ``rays3d.py`` is loaded directly from its file path
        under the module name ``senoquant_stardist_rays3d``.

        Raises
        ------
        FileNotFoundError
            If ``rays3d.py`` cannot be found.
        ImportError
            If an import spec cannot be created for it.
        """
        if self._rays_class is not None:
            return self._rays_class

        rays_path = self._get_utils_root() / "_stardist" / "rays3d.py"
        if not rays_path.exists():
            raise FileNotFoundError("Could not locate StarDist rays3d.py.")

        module_name = "senoquant_stardist_rays3d"
        spec = importlib.util.spec_from_file_location(module_name, rays_path)
        if spec is None or spec.loader is None:
            raise ImportError("Failed to load StarDist rays3d module.")
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        self._rays_class = module.Rays_GoldenSpiral
        return self._rays_class

    def _get_utils_root(self) -> Path:
        """Return the ``stardist_onnx_utils`` package root.

        This file lives in ``segmentation/models/default_2d/``, so
        ``parents[2]`` is the ``segmentation`` package directory.
        """
        return Path(__file__).resolve().parents[2] / "stardist_onnx_utils"

    def _infer_grid(
        self,
        image: np.ndarray,
        session: ort.InferenceSession,
        input_name: str,
        output_names: list[str],
        input_layout: str,
        prob_layout: str,
        *,
        model_path: Path | None = None,
    ) -> tuple[int, ...]:
        """Infer model grid/stride by running a probe tile.

        The grid along each axis is the rounded ratio of the probe's input
        size to the probability output size, and is at least 1.

        Parameters
        ----------
        image : numpy.ndarray
            Input image.
        session : onnxruntime.InferenceSession
            ONNX Runtime session.
        input_name : str
            ONNX input tensor name.
        output_names : list[str]
            ONNX output tensor names (prob, dist); only prob is evaluated.
        input_layout : str
            Input layout string (e.g., "NHWC", "NDHWC").
        prob_layout : str
            Probability output layout string.
        model_path : pathlib.Path or None, optional
            Model path forwarded to the probe-image builder.

        Returns
        -------
        tuple[int, ...]
            Estimated grid/stride per axis.

        Raises
        ------
        ValueError
            If ``prob_layout`` is not a supported channels-last or
            channels-first layout.
        """
        probe = self._make_probe_image(
            image, model_path=model_path, input_layout=input_layout
        )
        if input_layout in ("NHWC", "NDHWC"):
            input_tensor = probe[np.newaxis, ..., np.newaxis]
        else:
            input_tensor = probe[np.newaxis, np.newaxis, ...]

        # Only the probability map is needed to measure the output stride.
        prob = session.run([output_names[0]], {input_name: input_tensor})[0]
        if prob_layout in ("NHWC", "NDHWC"):
            out_shape = prob.shape[1:-1]
        elif prob_layout in ("NCHW", "NCDHW"):
            out_shape = prob.shape[2:]
        else:
            raise ValueError(f"Unsupported prob layout {prob_layout}.")

        return tuple(
            1 if dim_out in (0, None) else max(1, int(round(dim_in / dim_out)))
            for dim_in, dim_out in zip(probe.shape, out_shape)
        )

    def _make_probe_image(
        self,
        image: np.ndarray,
        *,
        model_path: Path | None = None,
        input_layout: str | None = None,
    ) -> np.ndarray:
        """Create a small probe image for grid inference.

        Delegates to ``make_probe_image``, sharing this instance's div_by and
        valid-size caches (which that helper may read or populate).
        """
        return make_probe_image(
            image,
            model_path=model_path,
            input_layout=input_layout,
            div_by_cache=self._div_by_cache,
            valid_size_cache=self._valid_size_cache,
        )