senoquant 1.0.0b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- senoquant/__init__.py +6 -0
- senoquant/_reader.py +7 -0
- senoquant/_widget.py +33 -0
- senoquant/napari.yaml +83 -0
- senoquant/reader/__init__.py +5 -0
- senoquant/reader/core.py +369 -0
- senoquant/tabs/__init__.py +15 -0
- senoquant/tabs/batch/__init__.py +10 -0
- senoquant/tabs/batch/backend.py +641 -0
- senoquant/tabs/batch/config.py +270 -0
- senoquant/tabs/batch/frontend.py +1283 -0
- senoquant/tabs/batch/io.py +326 -0
- senoquant/tabs/batch/layers.py +86 -0
- senoquant/tabs/quantification/__init__.py +1 -0
- senoquant/tabs/quantification/backend.py +228 -0
- senoquant/tabs/quantification/features/__init__.py +80 -0
- senoquant/tabs/quantification/features/base.py +142 -0
- senoquant/tabs/quantification/features/marker/__init__.py +5 -0
- senoquant/tabs/quantification/features/marker/config.py +69 -0
- senoquant/tabs/quantification/features/marker/dialog.py +437 -0
- senoquant/tabs/quantification/features/marker/export.py +879 -0
- senoquant/tabs/quantification/features/marker/feature.py +119 -0
- senoquant/tabs/quantification/features/marker/morphology.py +285 -0
- senoquant/tabs/quantification/features/marker/rows.py +654 -0
- senoquant/tabs/quantification/features/marker/thresholding.py +46 -0
- senoquant/tabs/quantification/features/roi.py +346 -0
- senoquant/tabs/quantification/features/spots/__init__.py +5 -0
- senoquant/tabs/quantification/features/spots/config.py +62 -0
- senoquant/tabs/quantification/features/spots/dialog.py +477 -0
- senoquant/tabs/quantification/features/spots/export.py +1292 -0
- senoquant/tabs/quantification/features/spots/feature.py +112 -0
- senoquant/tabs/quantification/features/spots/morphology.py +279 -0
- senoquant/tabs/quantification/features/spots/rows.py +241 -0
- senoquant/tabs/quantification/frontend.py +815 -0
- senoquant/tabs/segmentation/__init__.py +1 -0
- senoquant/tabs/segmentation/backend.py +131 -0
- senoquant/tabs/segmentation/frontend.py +1009 -0
- senoquant/tabs/segmentation/models/__init__.py +5 -0
- senoquant/tabs/segmentation/models/base.py +146 -0
- senoquant/tabs/segmentation/models/cpsam/details.json +65 -0
- senoquant/tabs/segmentation/models/cpsam/model.py +150 -0
- senoquant/tabs/segmentation/models/default_2d/details.json +69 -0
- senoquant/tabs/segmentation/models/default_2d/model.py +664 -0
- senoquant/tabs/segmentation/models/default_3d/details.json +69 -0
- senoquant/tabs/segmentation/models/default_3d/model.py +682 -0
- senoquant/tabs/segmentation/models/hf.py +71 -0
- senoquant/tabs/segmentation/models/nuclear_dilation/__init__.py +1 -0
- senoquant/tabs/segmentation/models/nuclear_dilation/details.json +26 -0
- senoquant/tabs/segmentation/models/nuclear_dilation/model.py +96 -0
- senoquant/tabs/segmentation/models/perinuclear_rings/__init__.py +1 -0
- senoquant/tabs/segmentation/models/perinuclear_rings/details.json +34 -0
- senoquant/tabs/segmentation/models/perinuclear_rings/model.py +132 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/__init__.py +2 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/__init__.py +3 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/__init__.py +6 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/generate.py +470 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/prepare.py +273 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/rawdata.py +112 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/transform.py +384 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/__init__.py +0 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/blocks.py +184 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/losses.py +79 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/nets.py +165 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/predict.py +467 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/probability.py +67 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/train.py +148 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/io/__init__.py +163 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/__init__.py +52 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/base_model.py +329 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_isotropic.py +160 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_projection.py +178 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_standard.py +446 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_upsampling.py +54 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/config.py +254 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/pretrained.py +119 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/scripts/__init__.py +0 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/scripts/care_predict.py +180 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/__init__.py +5 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/plot_utils.py +159 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/six.py +18 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/tf.py +644 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/utils.py +272 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/version.py +1 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/docs/source/conf.py +368 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/setup.py +68 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_datagen.py +169 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_models.py +462 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_utils.py +166 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tools/create_zip_contents.py +34 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/__init__.py +30 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/big.py +624 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/bioimageio_utils.py +494 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/data/__init__.py +39 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/__init__.py +10 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/geom2d.py +215 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/geom3d.py +349 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/matching.py +483 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/__init__.py +28 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/base.py +1217 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/model2d.py +594 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/model3d.py +696 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/nms.py +384 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/__init__.py +2 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/plot.py +74 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/render.py +298 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/rays3d.py +373 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/sample_patches.py +65 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/__init__.py +0 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/predict2d.py +90 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/predict3d.py +93 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/utils.py +408 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/version.py +1 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/__init__.py +45 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/__init__.py +17 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/cli.py +55 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/core.py +285 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/__init__.py +15 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/cli.py +36 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/divisibility.py +193 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/probe.py +100 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/receptive_field.py +182 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/rf_cli.py +48 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/valid_sizes.py +278 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/post/__init__.py +8 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/post/core.py +157 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/pre/__init__.py +17 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/pre/core.py +226 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/predict/__init__.py +5 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/predict/core.py +401 -0
- senoquant/tabs/settings/__init__.py +1 -0
- senoquant/tabs/settings/backend.py +29 -0
- senoquant/tabs/settings/frontend.py +19 -0
- senoquant/tabs/spots/__init__.py +1 -0
- senoquant/tabs/spots/backend.py +139 -0
- senoquant/tabs/spots/frontend.py +800 -0
- senoquant/tabs/spots/models/__init__.py +5 -0
- senoquant/tabs/spots/models/base.py +94 -0
- senoquant/tabs/spots/models/rmp/details.json +61 -0
- senoquant/tabs/spots/models/rmp/model.py +499 -0
- senoquant/tabs/spots/models/udwt/details.json +103 -0
- senoquant/tabs/spots/models/udwt/model.py +482 -0
- senoquant/utils.py +25 -0
- senoquant-1.0.0b1.dist-info/METADATA +193 -0
- senoquant-1.0.0b1.dist-info/RECORD +148 -0
- senoquant-1.0.0b1.dist-info/WHEEL +5 -0
- senoquant-1.0.0b1.dist-info/entry_points.txt +2 -0
- senoquant-1.0.0b1.dist-info/licenses/LICENSE +28 -0
- senoquant-1.0.0b1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,682 @@
|
|
|
1
|
+
"""StarDist ONNX segmentation model implementation."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
import importlib.util
|
|
7
|
+
import sys
|
|
8
|
+
import types
|
|
9
|
+
from typing import TYPE_CHECKING
|
|
10
|
+
|
|
11
|
+
import numpy as np
|
|
12
|
+
import onnxruntime as ort
|
|
13
|
+
from scipy import ndimage as ndi
|
|
14
|
+
|
|
15
|
+
from senoquant.utils import layer_data_asarray
|
|
16
|
+
from ..hf import DEFAULT_REPO_ID, ensure_hf_model
|
|
17
|
+
from ..base import SenoQuantSegmentationModel
|
|
18
|
+
from senoquant.tabs.segmentation.stardist_onnx_utils.onnx_framework import (
|
|
19
|
+
normalize,
|
|
20
|
+
predict_tiled,
|
|
21
|
+
)
|
|
22
|
+
from senoquant.tabs.segmentation.stardist_onnx_utils.onnx_framework.inspect import (
|
|
23
|
+
make_probe_image,
|
|
24
|
+
)
|
|
25
|
+
if TYPE_CHECKING:
|
|
26
|
+
from senoquant.tabs.segmentation.stardist_onnx_utils.onnx_framework.inspect.valid_sizes import (
|
|
27
|
+
ValidSizePattern,
|
|
28
|
+
)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class StarDistOnnxModel(SenoQuantSegmentationModel):
|
|
32
|
+
"""StarDist ONNX 3D segmentation model.
|
|
33
|
+
|
|
34
|
+
This wrapper loads an exported StarDist 3D ONNX model, runs
|
|
35
|
+
preprocessing and tiled inference, and postprocesses the outputs into
|
|
36
|
+
instance labels using StarDist geometry and NMS utilities.
|
|
37
|
+
|
|
38
|
+
Notes
|
|
39
|
+
-----
|
|
40
|
+
- Inputs must be single-channel images in ZYX (3D) order.
|
|
41
|
+
- ONNX model outputs are assumed to be probability and distance maps.
|
|
42
|
+
"""
|
|
43
|
+
|
|
44
|
+
def __init__(self, models_root=None) -> None:
    """Set up the 3D StarDist ONNX wrapper and its per-model caches.

    Parameters
    ----------
    models_root : pathlib.Path or None
        Optional root directory for model storage. It is forwarded to the
        base class together with the model name ``"default_3d"``.
    """
    super().__init__("default_3d", models_root=models_root)

    # Caches keyed by the resolved ONNX model path.
    self._sessions: dict[Path, ort.InferenceSession] = {}
    self._div_by_cache: dict[Path, tuple[int, ...]] = {}
    self._overlap_cache: dict[Path, tuple[int, ...]] = {}
    self._valid_size_cache: dict[Path, list["ValidSizePattern"] | None] = {}

    # Lazily loaded rays class and compiled-extension availability flags.
    # The flags are refreshed by ``_ensure_stardist_lib_stubs``.
    self._rays_class = None
    self._has_stardist_2d_lib = False
    self._has_stardist_3d_lib = False
|
|
60
|
+
|
|
61
|
+
def run(self, **kwargs) -> dict:
    """Run StarDist ONNX 3D nuclear segmentation.

    Parameters
    ----------
    **kwargs
        task : str
            Must be ``"nuclear"`` for this model.
        layer : napari.layers.Image
            Single-channel 3D image layer in ZYX order.
        settings : dict or None
            Model settings keyed by ``details.json``. ``None`` (or a missing
            entry) means "use all defaults".

    Returns
    -------
    dict
        Dictionary with:
        - ``masks``: instance label image
        - ``prob``: probability map
        - ``dist``: distance/ray map
        - ``info``: NMS metadata (points, prob, dist)

    Raises
    ------
    ValueError
        If ``task`` is not ``"nuclear"``, the layer is missing, or the image
        is not 3D.
    RuntimeError
        If the compiled StarDist 3D extensions are unavailable.
    """
    task = kwargs.get("task")
    if task != "nuclear":
        raise ValueError("StarDist ONNX only supports nuclear segmentation.")

    layer = kwargs.get("layer")
    # An explicit ``settings=None`` would otherwise crash on ``.get`` below.
    settings = kwargs.get("settings") or {}
    image = self._extract_layer_data(layer, required=True)
    # Pre-rescaling shape, handed to post-processing together with ``scale``.
    original_shape = image.shape

    if image.ndim != 3:
        raise ValueError("StarDist ONNX 3D expects a 3D (ZYX) image.")

    image = image.astype(np.float32, copy=False)
    image, scale = self._scale_input(image, settings)
    image = self._scale_intensity(image)
    if settings.get("normalize", True):
        pmin = float(settings.get("pmin", 1.0))
        pmax = float(settings.get("pmax", 99.8))
        image = normalize(image, pmin=pmin, pmax=pmax)

    model_path = self._resolve_model_path(image.ndim)
    session = self._get_session(image.ndim)
    input_name, output_names = self._resolve_io_names(session)

    # Tensor layouts used for the exported 3D model.
    input_layout = "NDHWC"
    prob_layout = "NDHWC"
    dist_layout = "NZYXR"

    grid = self._infer_grid(
        image,
        session,
        input_name,
        output_names,
        input_layout,
        prob_layout,
        model_path=model_path,
    )

    tile_shape, overlap = self._infer_tiling(
        image, model_path, session, input_name, output_names, input_layout
    )
    # _infer_tiling fills the div_by cache; ``grid`` is only a fallback.
    div_by = self._div_by_cache.get(model_path, grid)

    def _predict(active_session):
        """Run tiled inference on ``active_session`` with the shared settings."""
        return predict_tiled(
            image,
            active_session,
            input_name=input_name,
            output_names=output_names,
            grid=grid,
            input_layout=input_layout,
            prob_layout=prob_layout,
            dist_layout=dist_layout,
            tile_shape=tile_shape,
            overlap=overlap,
            div_by=div_by,
        )

    try:
        prob, dist = _predict(session)
    except Exception:
        # Only a CoreML-backed session gets a second chance. It is rebuilt
        # with the CPU provider (which also replaces the cached session) and
        # the prediction is retried once. Other failures propagate.
        if "CoreMLExecutionProvider" not in session.get_providers():
            raise
        session = self._get_session(
            image.ndim, providers_override=["CPUExecutionProvider"]
        )
        prob, dist = _predict(session)

    prob_thresh = float(settings.get("prob_thresh", 0.5))
    nms_thresh = float(settings.get("nms_thresh", 0.4))

    self._ensure_stardist_lib_stubs()
    from senoquant.tabs.segmentation.stardist_onnx_utils.onnx_framework import (
        instances_from_prediction_3d,
    )

    if not self._has_stardist_3d_lib:
        raise RuntimeError(
            "3D StarDist labeling requires compiled ops; build "
            "extensions in stardist_onnx_utils/_stardist/lib."
        )
    # One ray per channel of the distance map.
    rays = self._get_rays_class()(n=dist.shape[-1])
    labels, info = instances_from_prediction_3d(
        prob,
        dist,
        grid=grid,
        prob_thresh=prob_thresh,
        nms_thresh=nms_thresh,
        rays=rays,
        scale=scale,
        img_shape=original_shape,
    )

    return {"masks": labels, "prob": prob, "dist": dist, "info": info}
|
|
186
|
+
|
|
187
|
+
def _scale_input(
|
|
188
|
+
self, image: np.ndarray, settings: dict
|
|
189
|
+
) -> tuple[np.ndarray, dict[str, float] | None]:
|
|
190
|
+
"""Scale the input image to match training object sizes.
|
|
191
|
+
|
|
192
|
+
Parameters
|
|
193
|
+
----------
|
|
194
|
+
image : numpy.ndarray
|
|
195
|
+
Input 3D image in ZYX order.
|
|
196
|
+
settings : dict
|
|
197
|
+
Model settings containing the ``object_diameter_px`` entry.
|
|
198
|
+
|
|
199
|
+
Returns
|
|
200
|
+
-------
|
|
201
|
+
numpy.ndarray
|
|
202
|
+
Scaled image. If no scaling is requested, returns the input image.
|
|
203
|
+
dict[str, float] or None
|
|
204
|
+
Scale factors keyed by axis (``"Z"``, ``"Y"``, ``"X"``) for rescaling
|
|
205
|
+
predictions back to the original image space.
|
|
206
|
+
"""
|
|
207
|
+
diameter_px = float(settings.get("object_diameter_px", 30.0))
|
|
208
|
+
if diameter_px <= 0:
|
|
209
|
+
raise ValueError("Object diameter (px) must be positive.")
|
|
210
|
+
scale_factor = 30.0 / diameter_px
|
|
211
|
+
if np.isclose(scale_factor, 1.0):
|
|
212
|
+
return image, None
|
|
213
|
+
|
|
214
|
+
scale = (scale_factor, scale_factor, scale_factor)
|
|
215
|
+
scaled = ndi.zoom(image, scale, order=1)
|
|
216
|
+
if min(scaled.shape) < 1:
|
|
217
|
+
raise ValueError(
|
|
218
|
+
"Scaling factor produced an empty image; adjust object diameter."
|
|
219
|
+
)
|
|
220
|
+
return scaled.astype(np.float32, copy=False), {
|
|
221
|
+
"Z": scale_factor,
|
|
222
|
+
"Y": scale_factor,
|
|
223
|
+
"X": scale_factor,
|
|
224
|
+
}
|
|
225
|
+
|
|
226
|
+
@staticmethod
|
|
227
|
+
def _scale_intensity(image: np.ndarray) -> np.ndarray:
|
|
228
|
+
"""Scale image intensities into [0, 1] using min/max."""
|
|
229
|
+
imin = float(np.nanmin(image))
|
|
230
|
+
imax = float(np.nanmax(image))
|
|
231
|
+
if not np.isfinite(imin) or not np.isfinite(imax):
|
|
232
|
+
return image
|
|
233
|
+
if imax <= imin:
|
|
234
|
+
return image
|
|
235
|
+
return ((image - imin) / (imax - imin)).astype(np.float32, copy=False)
|
|
236
|
+
|
|
237
|
+
def _extract_layer_data(self, layer, required: bool) -> np.ndarray:
|
|
238
|
+
"""Return numpy data for a napari layer.
|
|
239
|
+
|
|
240
|
+
Parameters
|
|
241
|
+
----------
|
|
242
|
+
layer : object or None
|
|
243
|
+
Napari layer to convert.
|
|
244
|
+
required : bool
|
|
245
|
+
Whether a missing layer should raise an error.
|
|
246
|
+
|
|
247
|
+
Returns
|
|
248
|
+
-------
|
|
249
|
+
numpy.ndarray
|
|
250
|
+
Layer data as an array.
|
|
251
|
+
"""
|
|
252
|
+
if layer is None:
|
|
253
|
+
if required:
|
|
254
|
+
raise ValueError("Layer is required for StarDist ONNX.")
|
|
255
|
+
return None
|
|
256
|
+
return layer_data_asarray(layer)
|
|
257
|
+
|
|
258
|
+
def _get_session(
    self, ndim: int, *, providers_override: list[str] | None = None
) -> ort.InferenceSession:
    """Return (and cache) an ONNX Runtime session for the model matching ``ndim``.

    A cached session for the resolved model path is reused unless
    ``providers_override`` is given. In that case a new session is always
    built, and it replaces the cached one. An empty override list falls back
    to ``_preferred_providers()``.
    """
    model_path = self._resolve_model_path(ndim)
    cached = self._sessions.get(model_path)
    if cached is not None and providers_override is None:
        return cached

    if providers_override:
        chosen = providers_override
    else:
        chosen = self._preferred_providers()
    fresh = ort.InferenceSession(str(model_path), providers=chosen)
    self._sessions[model_path] = fresh
    return fresh
|
|
272
|
+
|
|
273
|
+
@staticmethod
def _preferred_providers() -> list[str]:
    """Return a provider list that prefers GPU/accelerator providers.

    Known providers are returned in the fixed priority order CUDA, ROCm,
    DirectML, CoreML, CPU, filtered to those ONNX Runtime reports as
    available. If none of them is available, every available provider is
    returned in the order ONNX Runtime reports them.
    """
    available = list(ort.get_available_providers())
    available_set = set(available)
    priority = (
        "CUDAExecutionProvider",
        "ROCMExecutionProvider",
        "DirectMLExecutionProvider",
        "CoreMLExecutionProvider",
        "CPUExecutionProvider",
    )
    providers = [provider for provider in priority if provider in available_set]
    # Fall back to ONNX Runtime's own ordering rather than arbitrary set
    # order, so the chosen providers are deterministic.
    return providers or available
|
|
288
|
+
|
|
289
|
+
def _infer_tiling(
    self,
    image: np.ndarray,
    model_path: Path,
    session: ort.InferenceSession,
    input_name: str,
    output_names: list[str],
    input_layout: str,
) -> tuple[tuple[int, ...], tuple[int, ...]]:
    """Infer tiling shape and overlap for ONNX tiled prediction.

    This method uses the ONNX inspection utilities to derive:
    - the per-axis divisibility requirement (``div_by``),
    - a recommended overlap based on the empirical receptive field, and
    - valid input-size patterns used to snap the tile shape.

    All three results are cached per ONNX model path, and failed lookups are
    cached too. The expensive inspection (graph parsing / RF probing)
    therefore runs only once per model. If an inspection step fails, safe
    fallbacks are used: ``div_by = (1, ...)``, ``overlap = (0, ...)`` and no
    size snapping.

    Parameters
    ----------
    image : numpy.ndarray
        Input image, used for its dimensionality and to size the tiles.
    model_path : pathlib.Path
        Path to the ONNX model, used as the cache key and inspection target.
    session, input_name, output_names
        Currently unused; accepted for interface compatibility.
    input_layout : str
        Input tensor layout, passed to valid-size pattern inference.

    Returns
    -------
    tuple[tuple[int, ...], tuple[int, ...]]
        A tuple ``(tile_shape, overlap)``, each a per-axis tuple with the
        same length as ``image.ndim``.

        ``tile_shape``: each axis is first capped at 1024 (only Y and X for
        3D input; every axis otherwise). It is then rounded down to a
        multiple of ``div_by``, but never below one ``div_by`` step, so an
        axis smaller than its ``div_by`` yields a tile larger than the image.
        When valid-size patterns are available, the tile is snapped to them;
        for 3D input the Z axis is passed as ``skip_axes``.

        ``overlap``: always clamped to ``[0, tile_size - 1]``.
    """
    ndim = image.ndim

    div_by = self._div_by_cache.get(model_path)
    if div_by is None:
        try:
            from senoquant.tabs.segmentation.stardist_onnx_utils.onnx_framework.inspect import (
                infer_div_by,
            )

            div_by = infer_div_by(model_path, ndim=ndim)
        except Exception:
            # Best effort: an uninspectable graph imposes no divisibility.
            div_by = (1,) * ndim
        self._div_by_cache[model_path] = div_by

    overlap = self._overlap_cache.get(model_path)
    if overlap is None:
        try:
            from senoquant.tabs.segmentation.stardist_onnx_utils.onnx_framework.inspect.receptive_field import (
                recommend_tile_overlap,
            )

            overlap = recommend_tile_overlap(model_path, ndim=ndim)
        except Exception:
            overlap = (0,) * ndim
        self._overlap_cache[model_path] = overlap

    max_tile = 1024
    if ndim == 3:
        # Keep the full Z extent; only cap the in-plane axes.
        capped_shape = (
            image.shape[0],
            min(image.shape[1], max_tile),
            min(image.shape[2], max_tile),
        )
    else:
        capped_shape = tuple(min(size, max_tile) for size in image.shape)

    tile_shape = tuple(
        max(div, (size // div) * div) if div > 0 else size
        for size, div in zip(capped_shape, div_by)
    )

    # Membership test (not ``.get``) so a cached failure (None) is not
    # re-inspected on every call.
    if model_path in self._valid_size_cache:
        patterns = self._valid_size_cache[model_path]
    else:
        try:
            from senoquant.tabs.segmentation.stardist_onnx_utils.onnx_framework.inspect.valid_sizes import (
                infer_valid_size_patterns_from_path,
            )

            patterns = infer_valid_size_patterns_from_path(
                model_path,
                input_layout,
                ndim,
            )
        except Exception:
            patterns = None
        self._valid_size_cache[model_path] = patterns

    if patterns:
        from senoquant.tabs.segmentation.stardist_onnx_utils.onnx_framework.inspect.valid_sizes import (
            snap_shape,
        )

        skip = (0,) if ndim == 3 else ()
        tile_shape = snap_shape(tile_shape, patterns, skip_axes=skip)

    overlap = tuple(
        max(0, min(int(ov), max(0, ts - 1)))
        for ov, ts in zip(overlap, tile_shape)
    )
    return tile_shape, overlap
|
|
404
|
+
|
|
405
|
+
def _resolve_model_path(self, ndim: int) -> Path:
|
|
406
|
+
"""Resolve the ONNX model file for 2D or 3D inference.
|
|
407
|
+
|
|
408
|
+
Parameters
|
|
409
|
+
----------
|
|
410
|
+
ndim : int
|
|
411
|
+
Spatial dimensionality (2 or 3).
|
|
412
|
+
|
|
413
|
+
Returns
|
|
414
|
+
-------
|
|
415
|
+
pathlib.Path
|
|
416
|
+
Path to the ONNX model file.
|
|
417
|
+
|
|
418
|
+
Raises
|
|
419
|
+
------
|
|
420
|
+
FileNotFoundError
|
|
421
|
+
If no ONNX model file is found.
|
|
422
|
+
ValueError
|
|
423
|
+
If multiple candidates are found without a default name.
|
|
424
|
+
"""
|
|
425
|
+
if ndim != 3:
|
|
426
|
+
raise ValueError("StarDist ONNX 3D expects a 3D model.")
|
|
427
|
+
default_filename = "default_3d.onnx"
|
|
428
|
+
candidates = [
|
|
429
|
+
self.model_dir / "onnx_models" / default_filename,
|
|
430
|
+
self.model_dir / default_filename,
|
|
431
|
+
self.model_dir / "onnx_models" / "stardist3d_3D_demo.onnx",
|
|
432
|
+
self.model_dir / "stardist3d_3D_demo.onnx",
|
|
433
|
+
self.model_dir / "stardist3d.onnx",
|
|
434
|
+
]
|
|
435
|
+
|
|
436
|
+
for path in candidates:
|
|
437
|
+
if path.exists():
|
|
438
|
+
return path
|
|
439
|
+
|
|
440
|
+
try:
|
|
441
|
+
downloaded = ensure_hf_model(
|
|
442
|
+
default_filename,
|
|
443
|
+
self.model_dir / "onnx_models",
|
|
444
|
+
repo_id=DEFAULT_REPO_ID,
|
|
445
|
+
)
|
|
446
|
+
except RuntimeError:
|
|
447
|
+
downloaded = None
|
|
448
|
+
if downloaded is not None and downloaded.exists():
|
|
449
|
+
return downloaded
|
|
450
|
+
|
|
451
|
+
matches = []
|
|
452
|
+
for folder in (self.model_dir / "onnx_models", self.model_dir):
|
|
453
|
+
if folder.exists():
|
|
454
|
+
matches.extend(sorted(folder.glob("*.onnx")))
|
|
455
|
+
|
|
456
|
+
if len(matches) == 1:
|
|
457
|
+
return matches[0]
|
|
458
|
+
if len(matches) > 1:
|
|
459
|
+
raise ValueError(
|
|
460
|
+
"Multiple ONNX files found; keep one or use default file names."
|
|
461
|
+
)
|
|
462
|
+
raise FileNotFoundError(
|
|
463
|
+
"No ONNX model found. Place the exported model in the model folder "
|
|
464
|
+
"or allow SenoQuant to download it from the model repository."
|
|
465
|
+
)
|
|
466
|
+
|
|
467
|
+
def _resolve_io_names(self, session: ort.InferenceSession) -> tuple[str, list[str]]:
    """Resolve the input name and the ``[prob, dist]`` output names.

    The first graph input is used as the image input. Outputs are matched
    in three passes. Each pass only fills roles that are still unassigned,
    and the same output is never assigned to both roles.

    1. By name: an output whose lower-cased name contains ``"prob"`` or
       ``"dist"``.
    2. By channel count: a concrete (integer) last dimension of ``1`` marks
       the probability map, and a concrete last dimension ``> 1`` marks the
       distance map. Symbolic or unknown dimensions are ignored.
    3. By position: remaining roles take the remaining outputs in graph
       order. A model with no usable hints therefore resolves to
       ``outputs[0], outputs[1]``.

    Parameters
    ----------
    session : onnxruntime.InferenceSession
        Session whose graph inputs/outputs are inspected.

    Returns
    -------
    tuple[str, list[str]]
        ``(input_name, [prob_name, dist_name])``.

    Raises
    ------
    RuntimeError
        If the model has no inputs or fewer than two outputs.
    """
    inputs = session.get_inputs()
    outputs = session.get_outputs()
    if not inputs:
        raise RuntimeError("ONNX model has no inputs.")
    if len(outputs) < 2:
        raise RuntimeError("ONNX model must have prob and dist outputs.")

    input_name = inputs[0].name

    prob = None
    dist = None

    # Pass 1: tensor names.
    for output in outputs:
        name = output.name.lower()
        if "prob" in name and prob is None:
            prob = output
        elif "dist" in name and dist is None:
            dist = output

    # Pass 2: channel count of the last axis, trusting only integer dims.
    if prob is None or dist is None:
        for output in outputs:
            if output is prob or output is dist:
                continue
            shape = output.shape or []
            channel = shape[-1] if shape else None
            if not isinstance(channel, int):
                continue
            if channel == 1 and prob is None:
                prob = output
            elif channel > 1 and dist is None:
                dist = output

    # Pass 3: positional fallback that keeps any role already resolved.
    if prob is None or dist is None:
        remaining = [o for o in outputs if o is not prob and o is not dist]
        if prob is None:
            prob = remaining.pop(0)
        if dist is None:
            dist = remaining.pop(0)

    return input_name, [prob.name, dist.name]
|
|
500
|
+
|
|
501
|
+
def _ensure_stardist_lib_stubs(self) -> None:
    """Ensure StarDist modules import without compiled extensions.

    Steps performed:

    - Prepend the vendored ``_csbdeep`` directory to ``sys.path`` if it exists
      and is not already present.
    - If ``..._stardist`` is not yet in ``sys.modules``, register a bare
      package module for it whose ``__path__`` points at the vendored
      ``_stardist`` folder. Because the entry already exists, a later import
      of that package does not execute its ``__init__``.
    - Register (or refresh) ``..._stardist.lib`` with a ``__path__`` covering
      the vendored ``lib`` folder plus any matching
      ``senoquant/.../_stardist/lib`` folder found under ``sys.path`` entries.
    - Detect compiled ``stardist2d*`` / ``stardist3d*`` shared libraries
      (``.so`` / ``.pyd``) in those folders and store the result on
      ``self._has_stardist_2d_lib`` / ``self._has_stardist_3d_lib``.
    - For a missing library, install a stub module whose entry points raise
      ``RuntimeError``. For a present library, drop a previously registered
      module that has no ``__file__`` (such as a stub created here), so a
      later import can load the real extension.

    Side effects: mutates ``sys.path`` and ``sys.modules`` process-wide.
    """
    utils_root = self._get_utils_root()
    csbdeep_root = utils_root / "_csbdeep"
    if csbdeep_root.exists():
        csbdeep_path = str(csbdeep_root)
        # Avoid inserting the same path again on repeated calls.
        if csbdeep_path not in sys.path:
            sys.path.insert(0, csbdeep_path)

    stardist_pkg = (
        "senoquant.tabs.segmentation.stardist_onnx_utils._stardist"
    )
    if stardist_pkg not in sys.modules:
        # Placeholder package: submodules resolve via __path__ without
        # running the package's own __init__.
        pkg = types.ModuleType(stardist_pkg)
        pkg.__path__ = [str(utils_root / "_stardist")]
        sys.modules[stardist_pkg] = pkg

    base_pkg = f"{stardist_pkg}.lib"
    lib_dirs = [utils_root / "_stardist" / "lib"]
    # Also search other installs of this package reachable via sys.path
    # for a lib folder that may hold compiled extensions.
    for entry in list(sys.path):
        if not entry:
            continue
        try:
            candidate = (
                Path(entry)
                / "senoquant"
                / "tabs"
                / "segmentation"
                / "stardist_onnx_utils"
                / "_stardist"
                / "lib"
            )
        except Exception:
            continue
        if candidate.exists():
            lib_dirs.append(candidate)

    if base_pkg in sys.modules:
        pkg = sys.modules[base_pkg]
        pkg.__path__ = [str(p) for p in lib_dirs]
    else:
        pkg = types.ModuleType(base_pkg)
        pkg.__path__ = [str(p) for p in lib_dirs]
        sys.modules[base_pkg] = pkg

    # Shared placeholder for every compiled entry point that is missing.
    def _stub(*_args, **_kwargs):
        raise RuntimeError("StarDist compiled ops are unavailable.")

    has_2d = False
    has_3d = False
    for lib_dir in lib_dirs:
        has_2d = has_2d or any(lib_dir.glob("stardist2d*.so")) or any(
            lib_dir.glob("stardist2d*.pyd")
        )
        has_3d = has_3d or any(lib_dir.glob("stardist3d*.so")) or any(
            lib_dir.glob("stardist3d*.pyd")
        )
    self._has_stardist_2d_lib = has_2d
    self._has_stardist_3d_lib = has_3d

    mod2d = f"{base_pkg}.stardist2d"
    # Library present: discard a file-less (stub) module registered earlier.
    if has_2d and mod2d in sys.modules:
        if getattr(sys.modules[mod2d], "__file__", None) is None:
            del sys.modules[mod2d]
    # Library absent: register a stub unless something is already there.
    if not has_2d and mod2d not in sys.modules:
        module = types.ModuleType(mod2d)
        module.c_star_dist = _stub
        module.c_non_max_suppression_inds_old = _stub
        module.c_non_max_suppression_inds = _stub
        sys.modules[mod2d] = module

    mod3d = f"{base_pkg}.stardist3d"
    # Same handling as the 2D module above.
    if has_3d and mod3d in sys.modules:
        if getattr(sys.modules[mod3d], "__file__", None) is None:
            del sys.modules[mod3d]
    if not has_3d and mod3d not in sys.modules:
        module = types.ModuleType(mod3d)
        module.c_star_dist3d = _stub
        module.c_polyhedron_to_label = _stub
        module.c_non_max_suppression_inds = _stub
        sys.modules[mod3d] = module
|
|
586
|
+
|
|
587
|
+
def _get_rays_class(self):
    """Return the StarDist ``Rays_GoldenSpiral`` class, loading it only once.

    The class is taken from the vendored ``_stardist/rays3d.py`` file
    located under ``self._get_utils_root()``. It is executed as a standalone
    module named ``senoquant_stardist_rays3d``, and the result is memoised
    on ``self._rays_class``.

    Raises
    ------
    FileNotFoundError
        If ``rays3d.py`` does not exist.
    ImportError
        If no import spec/loader can be built for the file.
    """
    cached = self._rays_class
    if cached is None:
        source = self._get_utils_root().joinpath("_stardist", "rays3d.py")
        if not source.exists():
            raise FileNotFoundError("Could not locate StarDist rays3d.py.")

        spec = importlib.util.spec_from_file_location(
            "senoquant_stardist_rays3d", source
        )
        loader = None if spec is None else spec.loader
        if loader is None:
            raise ImportError("Failed to load StarDist rays3d module.")

        rays_module = importlib.util.module_from_spec(spec)
        loader.exec_module(rays_module)
        cached = self._rays_class = rays_module.Rays_GoldenSpiral
    return cached
|
|
605
|
+
|
|
606
|
+
def _get_utils_root(self) -> Path:
    """Locate the ``stardist_onnx_utils`` directory.

    The directory is resolved relative to this source file: it is the
    ``stardist_onnx_utils`` child of the file's third ancestor directory
    (``parents[2]`` of the resolved path).
    """
    this_file = Path(__file__).resolve()
    anchor = this_file.parents[2]
    return anchor.joinpath("stardist_onnx_utils")
|
|
609
|
+
|
|
610
|
+
def _infer_grid(
    self,
    image: np.ndarray,
    session: ort.InferenceSession,
    input_name: str,
    output_names: list[str],
    input_layout: str,
    prob_layout: str,
    *,
    model_path: Path | None = None,
) -> tuple[int, ...]:
    """Estimate the model's output stride along each spatial axis.

    A small probe image is run through the network once. The spatial size
    of the first output (the probability map) is then compared with the
    probe's own size, axis by axis.

    Parameters
    ----------
    image : numpy.ndarray
        Image the probe is derived from (via ``self._make_probe_image``).
    session : onnxruntime.InferenceSession
        Session used to run the probe.
    input_name : str
        Name of the ONNX input tensor.
    output_names : list[str]
        ONNX output tensor names; the first one is treated as the
        probability map.
    input_layout : str
        Input tensor layout. ``"NHWC"``/``"NDHWC"`` get a trailing channel
        axis; any other value gets the channel axis right after the batch
        axis.
    prob_layout : str
        Probability output layout: ``"NHWC"``, ``"NDHWC"``, ``"NCHW"`` or
        ``"NCDHW"``.
    model_path : pathlib.Path or None, optional
        Forwarded to ``self._make_probe_image``.

    Returns
    -------
    tuple[int, ...]
        Rounded input/output size ratio per spatial axis. Each value is at
        least 1, and an axis whose output size is 0 or None yields 1.

    Raises
    ------
    ValueError
        If ``prob_layout`` is not one of the supported layouts.
    """
    probe = self._make_probe_image(
        image, model_path=model_path, input_layout=input_layout
    )
    channels_last = input_layout in ("NHWC", "NDHWC")
    if channels_last:
        batched = probe[np.newaxis, ..., np.newaxis]
    else:
        batched = probe[np.newaxis, np.newaxis, ...]

    prob = session.run(output_names, {input_name: batched})[0]

    if prob_layout in ("NHWC", "NDHWC"):
        spatial_out = prob.shape[1:-1]
    elif prob_layout in ("NCHW", "NCDHW"):
        spatial_out = prob.shape[2:]
    else:
        raise ValueError(f"Unsupported prob layout {prob_layout}.")

    return tuple(
        1 if size_out in (0, None) else max(1, int(round(size_in / size_out)))
        for size_in, size_out in zip(probe.shape, spatial_out)
    )
|
|
667
|
+
|
|
668
|
+
def _make_probe_image(
    self,
    image: np.ndarray,
    *,
    model_path: Path | None = None,
    input_layout: str | None = None,
) -> np.ndarray:
    """Build the small probe image used for grid inference.

    Thin wrapper around ``make_probe_image`` that passes this instance's
    ``_div_by_cache`` and ``_valid_size_cache`` along with the caller's
    arguments.
    """
    shared_caches = {
        "div_by_cache": self._div_by_cache,
        "valid_size_cache": self._valid_size_cache,
    }
    return make_probe_image(
        image,
        model_path=model_path,
        input_layout=input_layout,
        **shared_caches,
    )
|