senoquant 1.0.0b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- senoquant/__init__.py +6 -0
- senoquant/_reader.py +7 -0
- senoquant/_widget.py +33 -0
- senoquant/napari.yaml +83 -0
- senoquant/reader/__init__.py +5 -0
- senoquant/reader/core.py +369 -0
- senoquant/tabs/__init__.py +15 -0
- senoquant/tabs/batch/__init__.py +10 -0
- senoquant/tabs/batch/backend.py +641 -0
- senoquant/tabs/batch/config.py +270 -0
- senoquant/tabs/batch/frontend.py +1283 -0
- senoquant/tabs/batch/io.py +326 -0
- senoquant/tabs/batch/layers.py +86 -0
- senoquant/tabs/quantification/__init__.py +1 -0
- senoquant/tabs/quantification/backend.py +228 -0
- senoquant/tabs/quantification/features/__init__.py +80 -0
- senoquant/tabs/quantification/features/base.py +142 -0
- senoquant/tabs/quantification/features/marker/__init__.py +5 -0
- senoquant/tabs/quantification/features/marker/config.py +69 -0
- senoquant/tabs/quantification/features/marker/dialog.py +437 -0
- senoquant/tabs/quantification/features/marker/export.py +879 -0
- senoquant/tabs/quantification/features/marker/feature.py +119 -0
- senoquant/tabs/quantification/features/marker/morphology.py +285 -0
- senoquant/tabs/quantification/features/marker/rows.py +654 -0
- senoquant/tabs/quantification/features/marker/thresholding.py +46 -0
- senoquant/tabs/quantification/features/roi.py +346 -0
- senoquant/tabs/quantification/features/spots/__init__.py +5 -0
- senoquant/tabs/quantification/features/spots/config.py +62 -0
- senoquant/tabs/quantification/features/spots/dialog.py +477 -0
- senoquant/tabs/quantification/features/spots/export.py +1292 -0
- senoquant/tabs/quantification/features/spots/feature.py +112 -0
- senoquant/tabs/quantification/features/spots/morphology.py +279 -0
- senoquant/tabs/quantification/features/spots/rows.py +241 -0
- senoquant/tabs/quantification/frontend.py +815 -0
- senoquant/tabs/segmentation/__init__.py +1 -0
- senoquant/tabs/segmentation/backend.py +131 -0
- senoquant/tabs/segmentation/frontend.py +1009 -0
- senoquant/tabs/segmentation/models/__init__.py +5 -0
- senoquant/tabs/segmentation/models/base.py +146 -0
- senoquant/tabs/segmentation/models/cpsam/details.json +65 -0
- senoquant/tabs/segmentation/models/cpsam/model.py +150 -0
- senoquant/tabs/segmentation/models/default_2d/details.json +69 -0
- senoquant/tabs/segmentation/models/default_2d/model.py +664 -0
- senoquant/tabs/segmentation/models/default_3d/details.json +69 -0
- senoquant/tabs/segmentation/models/default_3d/model.py +682 -0
- senoquant/tabs/segmentation/models/hf.py +71 -0
- senoquant/tabs/segmentation/models/nuclear_dilation/__init__.py +1 -0
- senoquant/tabs/segmentation/models/nuclear_dilation/details.json +26 -0
- senoquant/tabs/segmentation/models/nuclear_dilation/model.py +96 -0
- senoquant/tabs/segmentation/models/perinuclear_rings/__init__.py +1 -0
- senoquant/tabs/segmentation/models/perinuclear_rings/details.json +34 -0
- senoquant/tabs/segmentation/models/perinuclear_rings/model.py +132 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/__init__.py +2 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/__init__.py +3 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/__init__.py +6 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/generate.py +470 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/prepare.py +273 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/rawdata.py +112 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/transform.py +384 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/__init__.py +0 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/blocks.py +184 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/losses.py +79 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/nets.py +165 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/predict.py +467 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/probability.py +67 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/train.py +148 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/io/__init__.py +163 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/__init__.py +52 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/base_model.py +329 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_isotropic.py +160 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_projection.py +178 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_standard.py +446 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_upsampling.py +54 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/config.py +254 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/pretrained.py +119 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/scripts/__init__.py +0 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/scripts/care_predict.py +180 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/__init__.py +5 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/plot_utils.py +159 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/six.py +18 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/tf.py +644 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/utils.py +272 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/version.py +1 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/docs/source/conf.py +368 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/setup.py +68 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_datagen.py +169 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_models.py +462 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_utils.py +166 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tools/create_zip_contents.py +34 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/__init__.py +30 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/big.py +624 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/bioimageio_utils.py +494 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/data/__init__.py +39 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/__init__.py +10 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/geom2d.py +215 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/geom3d.py +349 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/matching.py +483 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/__init__.py +28 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/base.py +1217 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/model2d.py +594 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/model3d.py +696 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/nms.py +384 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/__init__.py +2 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/plot.py +74 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/render.py +298 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/rays3d.py +373 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/sample_patches.py +65 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/__init__.py +0 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/predict2d.py +90 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/predict3d.py +93 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/utils.py +408 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/version.py +1 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/__init__.py +45 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/__init__.py +17 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/cli.py +55 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/core.py +285 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/__init__.py +15 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/cli.py +36 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/divisibility.py +193 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/probe.py +100 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/receptive_field.py +182 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/rf_cli.py +48 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/valid_sizes.py +278 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/post/__init__.py +8 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/post/core.py +157 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/pre/__init__.py +17 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/pre/core.py +226 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/predict/__init__.py +5 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/predict/core.py +401 -0
- senoquant/tabs/settings/__init__.py +1 -0
- senoquant/tabs/settings/backend.py +29 -0
- senoquant/tabs/settings/frontend.py +19 -0
- senoquant/tabs/spots/__init__.py +1 -0
- senoquant/tabs/spots/backend.py +139 -0
- senoquant/tabs/spots/frontend.py +800 -0
- senoquant/tabs/spots/models/__init__.py +5 -0
- senoquant/tabs/spots/models/base.py +94 -0
- senoquant/tabs/spots/models/rmp/details.json +61 -0
- senoquant/tabs/spots/models/rmp/model.py +499 -0
- senoquant/tabs/spots/models/udwt/details.json +103 -0
- senoquant/tabs/spots/models/udwt/model.py +482 -0
- senoquant/utils.py +25 -0
- senoquant-1.0.0b1.dist-info/METADATA +193 -0
- senoquant-1.0.0b1.dist-info/RECORD +148 -0
- senoquant-1.0.0b1.dist-info/WHEEL +5 -0
- senoquant-1.0.0b1.dist-info/entry_points.txt +2 -0
- senoquant-1.0.0b1.dist-info/licenses/LICENSE +28 -0
- senoquant-1.0.0b1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
"""Utilities for creating valid probe inputs for ONNX inspection."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
|
|
9
|
+
from .divisibility import infer_div_by
|
|
10
|
+
from .valid_sizes import infer_valid_size_patterns_from_path, snap_size
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def make_probe_image(
    image: np.ndarray,
    *,
    model_path: Path | None = None,
    input_layout: str | None = None,
    div_by_cache: dict[Path, tuple[int, ...]] | None = None,
    valid_size_cache: dict[Path, list[object] | None] | None = None,
) -> np.ndarray:
    """Create a small probe image aligned with ONNX size constraints.

    The image is first cropped from the origin to at most 256 samples per
    axis for 2-D input (64 otherwise); axes with at least 16 samples are
    rounded down to a multiple of 16.  If both ``model_path`` and
    ``input_layout`` are given, the crop is then snapped to the model's valid
    size patterns (preferred) or, failing that, to its divisibility
    requirements, cropping or reflect-padding as needed.

    Parameters
    ----------
    image : numpy.ndarray
        Input image array used to derive probe size.
    model_path : pathlib.Path or None, optional
        ONNX model path used for inspecting size constraints.
    input_layout : str or None, optional
        Model input layout (e.g., "NHWC", "NDHWC") used for size inspection.
    div_by_cache : dict or None, optional
        Cache for divisibility requirements keyed by model path. Only
        successful results are stored.
    valid_size_cache : dict or None, optional
        Cache for valid size patterns keyed by model path. A stored ``None``
        records that pattern probing failed for that model and is honoured
        on later calls, so the probe is not repeated.

    Returns
    -------
    numpy.ndarray
        Probe image padded/cropped to a valid spatial size.
    """
    target = 256 if image.ndim == 2 else 64
    probe_shape = []
    for dim in image.shape:
        size = min(dim, target)
        if size >= 16:
            size -= size % 16
        probe_shape.append(max(1, size))

    probe = image[tuple(slice(0, s) for s in probe_shape)]

    if model_path is None or input_layout is None:
        return probe

    # Use a membership test rather than ``.get``: a cached ``None`` means
    # "probing already failed", which must not trigger another full probe.
    if valid_size_cache is not None and model_path in valid_size_cache:
        patterns = valid_size_cache[model_path]
    else:
        try:
            patterns = infer_valid_size_patterns_from_path(
                model_path,
                input_layout,
                image.ndim,
            )
        except Exception:
            # Best effort: fall back to divisibility-based snapping below.
            patterns = None
        if valid_size_cache is not None:
            valid_size_cache[model_path] = patterns

    div_by = None
    if div_by_cache is not None:
        div_by = div_by_cache.get(model_path)
    if div_by is None:
        try:
            div_by = infer_div_by(model_path, ndim=image.ndim)
        except Exception:
            # Best effort: without constraints the plain crop is returned.
            div_by = None
        if div_by_cache is not None and div_by is not None:
            div_by_cache[model_path] = div_by

    desired = list(probe.shape)
    if patterns:
        desired = [
            max(1, snap_size(int(size), patterns[axis]))
            for axis, size in enumerate(desired)
        ]
    elif div_by:
        # Round down to a multiple of ``d`` but never below ``d`` itself.
        desired = [
            max(int(d), (int(size) // int(d)) * int(d)) if d else int(size)
            for size, d in zip(desired, div_by)
        ]

    desired = [max(1, int(size)) for size in desired]
    # Crop axes that are too long, then pad (at the end) axes that are short.
    probe = probe[
        tuple(slice(0, min(s, d)) for s, d in zip(probe.shape, desired))
    ]
    pads = [(0, max(0, d - s)) for s, d in zip(probe.shape, desired)]
    if any(pad_after > 0 for _, pad_after in pads):
        probe = np.pad(probe, pads, mode="reflect")
    return probe
|
|
@@ -0,0 +1,182 @@
|
|
|
1
|
+
"""Empirically estimate receptive field and tiling overlap for ONNX models.
|
|
2
|
+
|
|
3
|
+
This module mirrors StarDist's empirical receptive-field estimation:
|
|
4
|
+
run the model once on a single-pixel impulse and once on zeros, then
|
|
5
|
+
measure the spatial support of the difference in the probability output.
|
|
6
|
+
The measured extents define the overlap needed to avoid tile boundary
|
|
7
|
+
artifacts in tiled prediction.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
from pathlib import Path
|
|
13
|
+
from typing import Iterable
|
|
14
|
+
|
|
15
|
+
import numpy as np
|
|
16
|
+
|
|
17
|
+
from .divisibility import infer_div_by
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def infer_receptive_field(
    model_path: str | Path,
    ndim: int | None = None,
    input_shape: tuple[int, ...] | None = None,
    eps: float = 0.0,
) -> tuple[tuple[int, int], ...]:
    """Measure the receptive field of an ONNX model from its impulse response.

    The model is run once on an all-zero input and once on the same input
    with a single 1.0 at the spatial center. The region where the two
    probability outputs differ by more than ``eps`` gives the extents.

    Parameters
    ----------
    model_path : str or pathlib.Path
        Path to the ONNX model file.
    ndim : int or None, optional
        Spatial dimensionality (2 or 3). If None, inferred from input rank.
    input_shape : tuple[int, ...] or None, optional
        Spatial shape for the probe input. If None, 256 (2D) or 64 per axis
        is used, rounded up to the inferred divisibility of each axis.
    eps : float, optional
        Threshold used to detect non-zero influence in the output. Default 0.0.

    Returns
    -------
    tuple[tuple[int, int], ...]
        Per-axis receptive field extents as (left, right) offsets from the
        center pixel/voxel in input coordinates.

    Raises
    ------
    ValueError
        If ``input_shape`` does not have ``ndim`` entries.
    RuntimeError
        If the impulse produces no detectable difference in the output.

    Notes
    -----
    - Probe tensors are built as ``(1, *input_shape, 1)`` float32 arrays.
    - The output is upsampled back to input resolution with
      ``scipy.ndimage.zoom(..., order=0)`` using the rounded input/output
      size ratio per axis.
    """
    import onnxruntime as ort
    from scipy.ndimage import zoom

    model_path = Path(model_path)
    session = ort.InferenceSession(str(model_path))

    input_name = session.get_inputs()[0].name
    output_names = [node.name for node in session.get_outputs()]

    if ndim is None:
        ndim = _infer_ndim_from_input(session)

    if input_shape is None:
        # Start from a power-of-two edge and round each axis up to the
        # model's divisibility so the probe does not hit shape mismatches.
        default_edge = 256 if ndim == 2 else 64
        input_shape = tuple(
            _round_up(default_edge, divisor)
            for divisor in infer_div_by(model_path, ndim=ndim)
        )

    if len(input_shape) != ndim:
        raise ValueError("input_shape must match ndim.")

    # Zero baseline and an impulse copy with a single 1.0 at the center.
    center = tuple(extent // 2 for extent in input_shape)
    baseline_in = np.zeros((1, *input_shape, 1), dtype=np.float32)
    impulse_in = baseline_in.copy()
    impulse_in[(0, *center, 0)] = 1.0

    response = _run_prob(session, output_names, input_name, impulse_in, ndim)
    baseline = _run_prob(session, output_names, input_name, baseline_in, ndim)

    # Per-axis ratio of input size to output size (at least 1).
    grid = tuple(
        max(1, int(round(n_in / n_out)))
        for n_in, n_out in zip(input_shape, response.shape)
    )
    response = zoom(response, grid, order=0)
    baseline = zoom(baseline, grid, order=0)

    # Coordinates at which the impulse changed the output.
    support = np.where(np.abs(response - baseline) > eps)
    if any(len(hits) == 0 for hits in support):
        raise RuntimeError("Failed to detect receptive field; try a larger input_shape.")

    extents = []
    for mid, hits in zip(center, support):
        extents.append((mid - int(np.min(hits)), int(np.max(hits)) - mid))
    return tuple(extents)
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def recommend_tile_overlap(
    model_path: str | Path,
    ndim: int | None = None,
    input_shape: tuple[int, ...] | None = None,
    eps: float = 0.0,
) -> tuple[int, ...]:
    """Derive a per-axis tile overlap from the empirical receptive field.

    For each axis the overlap is the larger of the two (left, right)
    extents returned by ``infer_receptive_field``.

    Parameters
    ----------
    model_path : str or pathlib.Path
        Path to the ONNX model file.
    ndim : int or None, optional
        Spatial dimensionality (2 or 3). If None, inferred from input rank.
    input_shape : tuple[int, ...] or None, optional
        Spatial probe input shape. If None, a default shape is used.
    eps : float, optional
        Threshold used to detect non-zero influence in the output.

    Returns
    -------
    tuple[int, ...]
        Per-axis overlap in input pixels.
    """
    extents = infer_receptive_field(
        model_path=model_path,
        ndim=ndim,
        input_shape=input_shape,
        eps=eps,
    )
    return tuple(map(max, extents))
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def _run_prob(session, output_names, input_name, input_tensor, ndim: int) -> np.ndarray:
    """Execute the session and return its probability map without batch/channel axes."""
    raw_outputs = session.run(output_names, {input_name: input_tensor})
    return _to_spatial(_select_prob_output(raw_outputs), ndim)
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def _select_prob_output(outputs: list[np.ndarray]) -> np.ndarray:
|
|
146
|
+
"""Pick the probability output from ONNX outputs."""
|
|
147
|
+
for arr in outputs:
|
|
148
|
+
if arr.ndim >= 4 and arr.shape[-1] == 1:
|
|
149
|
+
return arr
|
|
150
|
+
return outputs[0]
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
def _to_spatial(prob: np.ndarray, ndim: int) -> np.ndarray:
|
|
154
|
+
"""Convert a batched prob tensor into spatial layout (YX/ZYX)."""
|
|
155
|
+
if ndim == 2:
|
|
156
|
+
if prob.ndim == 4 and prob.shape[-1] == 1:
|
|
157
|
+
return prob[0, ..., 0]
|
|
158
|
+
if prob.ndim == 4 and prob.shape[1] == 1:
|
|
159
|
+
return prob[0, 0, ...]
|
|
160
|
+
if ndim == 3:
|
|
161
|
+
if prob.ndim == 5 and prob.shape[-1] == 1:
|
|
162
|
+
return prob[0, ..., 0]
|
|
163
|
+
if prob.ndim == 5 and prob.shape[1] == 1:
|
|
164
|
+
return prob[0, 0, ...]
|
|
165
|
+
raise ValueError("Unsupported prob output layout.")
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
def _infer_ndim_from_input(session) -> int:
|
|
169
|
+
"""Infer spatial dimensionality from ONNX session input rank."""
|
|
170
|
+
shape = session.get_inputs()[0].shape
|
|
171
|
+
if len(shape) == 4:
|
|
172
|
+
return 2
|
|
173
|
+
if len(shape) == 5:
|
|
174
|
+
return 3
|
|
175
|
+
raise ValueError(f"Unsupported input rank {len(shape)}.")
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
def _round_up(value: int, multiple: int) -> int:
|
|
179
|
+
"""Round up ``value`` to the next multiple."""
|
|
180
|
+
if multiple <= 0:
|
|
181
|
+
return value
|
|
182
|
+
return int(np.ceil(value / multiple) * multiple)
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
"""CLI for empirical receptive-field estimation."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import argparse
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
|
|
8
|
+
from .receptive_field import infer_receptive_field, recommend_tile_overlap
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def _parse_args() -> argparse.Namespace:
|
|
12
|
+
parser = argparse.ArgumentParser(description="Estimate ONNX receptive field.")
|
|
13
|
+
parser.add_argument("model", type=Path, help="Path to the ONNX model.")
|
|
14
|
+
parser.add_argument("--ndim", type=int, choices=(2, 3), default=None)
|
|
15
|
+
parser.add_argument(
|
|
16
|
+
"--shape",
|
|
17
|
+
type=int,
|
|
18
|
+
nargs="+",
|
|
19
|
+
default=None,
|
|
20
|
+
help="Spatial input shape (e.g. --shape 256 256 or --shape 64 64 64).",
|
|
21
|
+
)
|
|
22
|
+
parser.add_argument("--eps", type=float, default=0.0)
|
|
23
|
+
return parser.parse_args()
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def main() -> None:
    """Estimate the receptive field of a model and print a short report.

    Reads the model path and optional ``--ndim``, ``--shape`` and ``--eps``
    settings from the command line, then prints the per-axis receptive field
    and the recommended tile overlap.
    """
    args = _parse_args()
    input_shape = tuple(args.shape) if args.shape else None
    rf = infer_receptive_field(
        model_path=args.model,
        ndim=args.ndim,
        input_shape=input_shape,
        eps=args.eps,
    )
    # The overlap is the larger extent per axis, which is what
    # recommend_tile_overlap returns. Deriving it from ``rf`` avoids loading
    # the model and running both probe inferences a second time.
    overlap = tuple(max(pair) for pair in rf)

    print(f"Model: {args.model}")
    print(f"Receptive field: {rf}")
    print(f"Recommended overlap: {overlap}")


if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,278 @@
|
|
|
1
|
+
"""Empirically infer valid spatial sizes for ONNX model inputs.
|
|
2
|
+
|
|
3
|
+
This module probes the ONNX runtime by running the model on small inputs and
|
|
4
|
+
recording which spatial sizes succeed. It then summarizes valid sizes as
|
|
5
|
+
periodic residues (e.g., sizes of the form ``16k`` or ``16k+1``).
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import math
|
|
11
|
+
from dataclasses import dataclass
|
|
12
|
+
from typing import Iterable, Sequence
|
|
13
|
+
|
|
14
|
+
import numpy as np
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@dataclass(frozen=True)
class ValidSizePattern:
    """Immutable summary of the sizes one spatial axis accepts.

    A size is treated as valid when ``size % period`` is one of ``residues``;
    ``min_valid`` is the lower bound used when snapping sizes.

    Attributes
    ----------
    period : int
        Periodicity for valid sizes.
    residues : tuple[int, ...]
        Allowed ``size % period`` residues.
    min_valid : int
        Smallest observed valid size.
    """

    period: int
    residues: tuple[int, ...]
    min_valid: int
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def infer_valid_size_patterns(
    session,
    input_name: str,
    output_names: Iterable[str],
    input_layout: str,
    ndim: int,
    max_probe: int = 64,
) -> list[ValidSizePattern]:
    """Infer, per axis, which input sizes the model executes on.

    A symmetric base size that runs is found first. Each axis is then swept
    from 1 to ``max_probe`` while the other axes stay at the base size, and
    the sizes that run are summarised as a period plus residues.

    Parameters
    ----------
    session : onnxruntime.InferenceSession
        ONNX Runtime session used to execute the model.
    input_name : str
        Name of the ONNX input tensor.
    output_names : Iterable[str]
        Output tensor names to request during inference.
    input_layout : str
        Input layout string (e.g., ``"NHWC"`` or ``"NDHWC"``).
    ndim : int
        Spatial dimensionality (2 or 3).
    max_probe : int, optional
        Maximum size to probe per axis. Default is 64.

    Returns
    -------
    list[ValidSizePattern]
        One entry per axis describing periodic valid size residues.

    Raises
    ------
    ValueError
        If ``ndim`` is not 2 or 3.
    RuntimeError
        If no valid sizes can be found within the probe range.
    """
    if ndim not in (2, 3):
        raise ValueError("ndim must be 2 or 3.")

    names = list(output_names)
    base = _find_valid_base(
        session, input_name, names, input_layout, ndim, max_probe
    )

    def runs_with(axis: int, size: int) -> bool:
        # Vary one axis only; the remaining axes stay at the known-good base.
        shape = [base] * ndim
        shape[axis] = size
        return _try_run(session, input_name, names, input_layout, shape)

    results: list[ValidSizePattern] = []
    found_any = False
    for axis in range(ndim):
        accepted = [
            size for size in range(1, max_probe + 1) if runs_with(axis, size)
        ]
        if accepted:
            found_any = True
            period, residues = _infer_period_and_residues(accepted, max_probe)
            results.append(
                ValidSizePattern(
                    period=int(period),
                    residues=tuple(int(r) for r in residues),
                    min_valid=int(min(accepted)),
                )
            )
        else:
            # Nothing ran on this axis: record an unconstrained placeholder.
            results.append(ValidSizePattern(period=1, residues=(0,), min_valid=1))

    if not found_any:
        raise RuntimeError(
            f"No valid sizes found within 1..{max_probe} for any axis."
        )
    return results
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def infer_valid_size_patterns_from_path(
    model_path,
    input_layout: str,
    ndim: int,
    max_probe: int = 64,
) -> list[ValidSizePattern]:
    """Open a quiet ONNX session for ``model_path`` and probe its valid sizes.

    The session is created here with its log severity level set to 4 so that
    failing probe runs do not flood the log. The first model input and all
    model outputs are used for probing.

    Parameters
    ----------
    model_path : str or pathlib.Path
        Path to the ONNX model file.
    input_layout : str
        Input layout string (e.g., ``"NHWC"`` or ``"NDHWC"``).
    ndim : int
        Spatial dimensionality (2 or 3).
    max_probe : int, optional
        Maximum size to probe per axis. Default is 64.

    Returns
    -------
    list[ValidSizePattern]
        One entry per axis describing periodic valid size residues.
    """
    import onnxruntime as ort

    quiet = ort.SessionOptions()
    # Probe failures are expected; keep ORT from logging each of them.
    quiet.log_severity_level = 4
    session = ort.InferenceSession(str(model_path), sess_options=quiet)
    return infer_valid_size_patterns(
        session,
        session.get_inputs()[0].name,
        [node.name for node in session.get_outputs()],
        input_layout,
        ndim,
        max_probe=max_probe,
    )
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def _find_valid_base(
|
|
152
|
+
session,
|
|
153
|
+
input_name: str,
|
|
154
|
+
output_names: list[str],
|
|
155
|
+
input_layout: str,
|
|
156
|
+
ndim: int,
|
|
157
|
+
max_probe: int,
|
|
158
|
+
) -> int:
|
|
159
|
+
"""Return the smallest symmetric size that executes successfully."""
|
|
160
|
+
for size in range(1, max_probe + 1):
|
|
161
|
+
shape = [size] * ndim
|
|
162
|
+
if _try_run(session, input_name, output_names, input_layout, shape):
|
|
163
|
+
return size
|
|
164
|
+
raise RuntimeError("Failed to find any valid base size for probing.")
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def _try_run(
|
|
168
|
+
session,
|
|
169
|
+
input_name: str,
|
|
170
|
+
output_names: list[str],
|
|
171
|
+
input_layout: str,
|
|
172
|
+
spatial_shape: list[int],
|
|
173
|
+
) -> bool:
|
|
174
|
+
"""Return True if the model runs on the given spatial shape."""
|
|
175
|
+
if input_layout in ("NHWC", "NDHWC"):
|
|
176
|
+
input_tensor = np.zeros((1, *spatial_shape, 1), dtype=np.float32)
|
|
177
|
+
elif input_layout in ("NCHW", "NCDHW"):
|
|
178
|
+
input_tensor = np.zeros((1, 1, *spatial_shape), dtype=np.float32)
|
|
179
|
+
else:
|
|
180
|
+
raise ValueError(f"Unsupported input layout {input_layout}.")
|
|
181
|
+
try:
|
|
182
|
+
session.run(list(output_names), {input_name: input_tensor})
|
|
183
|
+
except Exception:
|
|
184
|
+
return False
|
|
185
|
+
return True
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
def _infer_period_and_residues(
|
|
189
|
+
valid_sizes: list[int], max_probe: int
|
|
190
|
+
) -> tuple[int, list[int]]:
|
|
191
|
+
"""Infer periodicity and residue set for valid sizes."""
|
|
192
|
+
valid_set = set(valid_sizes)
|
|
193
|
+
if not valid_set:
|
|
194
|
+
return 1, [0]
|
|
195
|
+
|
|
196
|
+
min_valid = min(valid_set)
|
|
197
|
+
for period in range(1, max_probe + 1):
|
|
198
|
+
residues = {v % period for v in valid_set}
|
|
199
|
+
ok = True
|
|
200
|
+
for size in range(min_valid, max_probe + 1):
|
|
201
|
+
if (size % period in residues) != (size in valid_set):
|
|
202
|
+
ok = False
|
|
203
|
+
break
|
|
204
|
+
if ok:
|
|
205
|
+
return period, sorted(residues)
|
|
206
|
+
|
|
207
|
+
if len(valid_sizes) < 2:
|
|
208
|
+
return max(1, valid_sizes[0]), [valid_sizes[0] % max(1, valid_sizes[0])]
|
|
209
|
+
|
|
210
|
+
diffs = [b - a for a, b in zip(valid_sizes, valid_sizes[1:]) if b > a]
|
|
211
|
+
period = diffs[0]
|
|
212
|
+
for d in diffs[1:]:
|
|
213
|
+
period = math.gcd(period, d)
|
|
214
|
+
residues = sorted({v % period for v in valid_set})
|
|
215
|
+
return max(1, period), residues
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
def snap_size(size: int, pattern: ValidSizePattern) -> int:
    """Adjust a size to the nearest valid residue at or below ``size``.

    Sizes at or below ``pattern.min_valid`` snap to ``min_valid``. Otherwise
    the search walks down from ``size`` by at most one period (never below
    ``min_valid``). If nothing matches there, it walks up by at most one
    period.

    Parameters
    ----------
    size : int
        Proposed size.
    pattern : ValidSizePattern
        Valid size pattern for the axis.

    Returns
    -------
    int
        Snapped valid size.

    Raises
    ------
    ValueError
        If no integer is congruent to any of ``pattern.residues`` modulo the
        period (for example, when ``residues`` is empty). An unbounded upward
        search would never terminate in that case.
    """
    period = max(1, int(pattern.period))
    residues = {int(r) for r in pattern.residues}
    min_valid = int(pattern.min_valid)
    if size <= min_valid:
        return min_valid

    # Prefer shrinking: check size, size-1, ... down to one period below,
    # but never below the smallest size known to be valid.
    lowest = max(min_valid, size - period)
    for candidate in range(size, lowest - 1, -1):
        if candidate % period in residues:
            return candidate

    # Nothing at or below ``size``: grow instead. Together with ``size``
    # (already checked above) these candidates cover every residue class.
    for candidate in range(size + 1, size + period):
        if candidate % period in residues:
            return candidate

    raise ValueError(
        f"No size is congruent to residues {sorted(residues)} modulo {period}."
    )
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
def snap_shape(
    shape: Sequence[int],
    patterns: Sequence[ValidSizePattern],
    *,
    skip_axes: Sequence[int] = (),
) -> tuple[int, ...]:
    """Apply ``snap_size`` to every axis of a shape, except skipped axes.

    Parameters
    ----------
    shape : Sequence[int]
        Proposed spatial shape.
    patterns : Sequence[ValidSizePattern]
        Per-axis valid size patterns, indexed by axis position.
    skip_axes : Sequence[int], optional
        Axes to leave unchanged (e.g., skip Z for 3D models). Their sizes
        are only converted to ``int``.

    Returns
    -------
    tuple[int, ...]
        Snapped spatial shape.
    """
    return tuple(
        int(size) if axis in skip_axes else snap_size(int(size), patterns[axis])
        for axis, size in enumerate(shape)
    )
|