senoquant 1.0.0b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148) hide show
  1. senoquant/__init__.py +6 -0
  2. senoquant/_reader.py +7 -0
  3. senoquant/_widget.py +33 -0
  4. senoquant/napari.yaml +83 -0
  5. senoquant/reader/__init__.py +5 -0
  6. senoquant/reader/core.py +369 -0
  7. senoquant/tabs/__init__.py +15 -0
  8. senoquant/tabs/batch/__init__.py +10 -0
  9. senoquant/tabs/batch/backend.py +641 -0
  10. senoquant/tabs/batch/config.py +270 -0
  11. senoquant/tabs/batch/frontend.py +1283 -0
  12. senoquant/tabs/batch/io.py +326 -0
  13. senoquant/tabs/batch/layers.py +86 -0
  14. senoquant/tabs/quantification/__init__.py +1 -0
  15. senoquant/tabs/quantification/backend.py +228 -0
  16. senoquant/tabs/quantification/features/__init__.py +80 -0
  17. senoquant/tabs/quantification/features/base.py +142 -0
  18. senoquant/tabs/quantification/features/marker/__init__.py +5 -0
  19. senoquant/tabs/quantification/features/marker/config.py +69 -0
  20. senoquant/tabs/quantification/features/marker/dialog.py +437 -0
  21. senoquant/tabs/quantification/features/marker/export.py +879 -0
  22. senoquant/tabs/quantification/features/marker/feature.py +119 -0
  23. senoquant/tabs/quantification/features/marker/morphology.py +285 -0
  24. senoquant/tabs/quantification/features/marker/rows.py +654 -0
  25. senoquant/tabs/quantification/features/marker/thresholding.py +46 -0
  26. senoquant/tabs/quantification/features/roi.py +346 -0
  27. senoquant/tabs/quantification/features/spots/__init__.py +5 -0
  28. senoquant/tabs/quantification/features/spots/config.py +62 -0
  29. senoquant/tabs/quantification/features/spots/dialog.py +477 -0
  30. senoquant/tabs/quantification/features/spots/export.py +1292 -0
  31. senoquant/tabs/quantification/features/spots/feature.py +112 -0
  32. senoquant/tabs/quantification/features/spots/morphology.py +279 -0
  33. senoquant/tabs/quantification/features/spots/rows.py +241 -0
  34. senoquant/tabs/quantification/frontend.py +815 -0
  35. senoquant/tabs/segmentation/__init__.py +1 -0
  36. senoquant/tabs/segmentation/backend.py +131 -0
  37. senoquant/tabs/segmentation/frontend.py +1009 -0
  38. senoquant/tabs/segmentation/models/__init__.py +5 -0
  39. senoquant/tabs/segmentation/models/base.py +146 -0
  40. senoquant/tabs/segmentation/models/cpsam/details.json +65 -0
  41. senoquant/tabs/segmentation/models/cpsam/model.py +150 -0
  42. senoquant/tabs/segmentation/models/default_2d/details.json +69 -0
  43. senoquant/tabs/segmentation/models/default_2d/model.py +664 -0
  44. senoquant/tabs/segmentation/models/default_3d/details.json +69 -0
  45. senoquant/tabs/segmentation/models/default_3d/model.py +682 -0
  46. senoquant/tabs/segmentation/models/hf.py +71 -0
  47. senoquant/tabs/segmentation/models/nuclear_dilation/__init__.py +1 -0
  48. senoquant/tabs/segmentation/models/nuclear_dilation/details.json +26 -0
  49. senoquant/tabs/segmentation/models/nuclear_dilation/model.py +96 -0
  50. senoquant/tabs/segmentation/models/perinuclear_rings/__init__.py +1 -0
  51. senoquant/tabs/segmentation/models/perinuclear_rings/details.json +34 -0
  52. senoquant/tabs/segmentation/models/perinuclear_rings/model.py +132 -0
  53. senoquant/tabs/segmentation/stardist_onnx_utils/__init__.py +2 -0
  54. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/__init__.py +3 -0
  55. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/__init__.py +6 -0
  56. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/generate.py +470 -0
  57. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/prepare.py +273 -0
  58. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/rawdata.py +112 -0
  59. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/transform.py +384 -0
  60. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/__init__.py +0 -0
  61. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/blocks.py +184 -0
  62. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/losses.py +79 -0
  63. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/nets.py +165 -0
  64. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/predict.py +467 -0
  65. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/probability.py +67 -0
  66. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/train.py +148 -0
  67. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/io/__init__.py +163 -0
  68. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/__init__.py +52 -0
  69. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/base_model.py +329 -0
  70. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_isotropic.py +160 -0
  71. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_projection.py +178 -0
  72. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_standard.py +446 -0
  73. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_upsampling.py +54 -0
  74. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/config.py +254 -0
  75. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/pretrained.py +119 -0
  76. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/scripts/__init__.py +0 -0
  77. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/scripts/care_predict.py +180 -0
  78. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/__init__.py +5 -0
  79. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/plot_utils.py +159 -0
  80. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/six.py +18 -0
  81. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/tf.py +644 -0
  82. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/utils.py +272 -0
  83. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/version.py +1 -0
  84. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/docs/source/conf.py +368 -0
  85. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/setup.py +68 -0
  86. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_datagen.py +169 -0
  87. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_models.py +462 -0
  88. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_utils.py +166 -0
  89. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tools/create_zip_contents.py +34 -0
  90. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/__init__.py +30 -0
  91. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/big.py +624 -0
  92. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/bioimageio_utils.py +494 -0
  93. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/data/__init__.py +39 -0
  94. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/__init__.py +10 -0
  95. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/geom2d.py +215 -0
  96. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/geom3d.py +349 -0
  97. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/matching.py +483 -0
  98. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/__init__.py +28 -0
  99. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/base.py +1217 -0
  100. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/model2d.py +594 -0
  101. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/model3d.py +696 -0
  102. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/nms.py +384 -0
  103. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/__init__.py +2 -0
  104. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/plot.py +74 -0
  105. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/render.py +298 -0
  106. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/rays3d.py +373 -0
  107. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/sample_patches.py +65 -0
  108. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/__init__.py +0 -0
  109. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/predict2d.py +90 -0
  110. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/predict3d.py +93 -0
  111. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/utils.py +408 -0
  112. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/version.py +1 -0
  113. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/__init__.py +45 -0
  114. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/__init__.py +17 -0
  115. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/cli.py +55 -0
  116. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/core.py +285 -0
  117. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/__init__.py +15 -0
  118. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/cli.py +36 -0
  119. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/divisibility.py +193 -0
  120. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/probe.py +100 -0
  121. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/receptive_field.py +182 -0
  122. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/rf_cli.py +48 -0
  123. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/valid_sizes.py +278 -0
  124. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/post/__init__.py +8 -0
  125. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/post/core.py +157 -0
  126. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/pre/__init__.py +17 -0
  127. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/pre/core.py +226 -0
  128. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/predict/__init__.py +5 -0
  129. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/predict/core.py +401 -0
  130. senoquant/tabs/settings/__init__.py +1 -0
  131. senoquant/tabs/settings/backend.py +29 -0
  132. senoquant/tabs/settings/frontend.py +19 -0
  133. senoquant/tabs/spots/__init__.py +1 -0
  134. senoquant/tabs/spots/backend.py +139 -0
  135. senoquant/tabs/spots/frontend.py +800 -0
  136. senoquant/tabs/spots/models/__init__.py +5 -0
  137. senoquant/tabs/spots/models/base.py +94 -0
  138. senoquant/tabs/spots/models/rmp/details.json +61 -0
  139. senoquant/tabs/spots/models/rmp/model.py +499 -0
  140. senoquant/tabs/spots/models/udwt/details.json +103 -0
  141. senoquant/tabs/spots/models/udwt/model.py +482 -0
  142. senoquant/utils.py +25 -0
  143. senoquant-1.0.0b1.dist-info/METADATA +193 -0
  144. senoquant-1.0.0b1.dist-info/RECORD +148 -0
  145. senoquant-1.0.0b1.dist-info/WHEEL +5 -0
  146. senoquant-1.0.0b1.dist-info/entry_points.txt +2 -0
  147. senoquant-1.0.0b1.dist-info/licenses/LICENSE +28 -0
  148. senoquant-1.0.0b1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,100 @@
1
+ """Utilities for creating valid probe inputs for ONNX inspection."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from pathlib import Path
6
+
7
+ import numpy as np
8
+
9
+ from .divisibility import infer_div_by
10
+ from .valid_sizes import infer_valid_size_patterns_from_path, snap_size
11
+
12
+
13
def make_probe_image(
    image: np.ndarray,
    *,
    model_path: Path | None = None,
    input_layout: str | None = None,
    div_by_cache: dict[Path, tuple[int, ...]] | None = None,
    valid_size_cache: dict[Path, list[object] | None] | None = None,
) -> np.ndarray:
    """Create a small probe image aligned with ONNX size constraints.

    The image is first cropped to at most 256 (2D) or 64 (other ranks)
    pixels per axis, rounded down to a multiple of 16 where possible. When a
    model path and layout are given, the crop is then snapped to the model's
    empirically valid sizes, or failing that to its divisibility
    requirements. Axes that become larger than the crop are padded with
    ``mode="reflect"``.

    Parameters
    ----------
    image : numpy.ndarray
        Input image array used to derive probe size.
    model_path : pathlib.Path or None, optional
        ONNX model path used for inspecting size constraints.
    input_layout : str or None, optional
        Model input layout (e.g., "NHWC", "NDHWC") used for size inspection.
    div_by_cache : dict or None, optional
        Cache for divisibility requirements keyed by model path. Only
        successful inferences are stored.
    valid_size_cache : dict or None, optional
        Cache for valid size patterns keyed by model path. Failed probes are
        stored as ``None`` so they are not retried on every call.

    Returns
    -------
    numpy.ndarray
        Probe image padded/cropped to a valid spatial size.
    """
    target = 256 if image.ndim == 2 else 64
    probe_shape = []
    for dim in image.shape:
        size = min(dim, target)
        if size >= 16:
            # Prefer multiples of 16; small axes are kept as-is.
            size -= size % 16
        probe_shape.append(max(1, size))

    probe = image[tuple(slice(0, s) for s in probe_shape)]

    if model_path is None or input_layout is None:
        return probe

    # Membership (not ``.get``) so a cached ``None`` means "probing already
    # failed for this model" instead of triggering another expensive probe.
    if valid_size_cache is not None and model_path in valid_size_cache:
        patterns = valid_size_cache[model_path]
    else:
        try:
            patterns = infer_valid_size_patterns_from_path(
                model_path,
                input_layout,
                image.ndim,
            )
        except Exception:
            # Best-effort: fall back to divisibility rules or no snapping.
            patterns = None
        if valid_size_cache is not None:
            valid_size_cache[model_path] = patterns

    div_by = None
    if div_by_cache is not None:
        div_by = div_by_cache.get(model_path)
    if div_by is None:
        try:
            div_by = infer_div_by(model_path, ndim=image.ndim)
        except Exception:
            div_by = None
        if div_by_cache is not None and div_by is not None:
            div_by_cache[model_path] = div_by

    desired = list(probe.shape)
    if patterns:
        desired = [
            max(1, snap_size(int(size), patterns[axis]))
            for axis, size in enumerate(desired)
        ]
    elif div_by:
        # Extend with "no constraint" entries so zip() never drops trailing
        # axes when div_by is shorter than the probe rank.
        factors = list(div_by) + [None] * max(0, len(desired) - len(div_by))
        desired = [
            max(int(d), (int(size) // int(d)) * int(d)) if d else int(size)
            for size, d in zip(desired, factors)
        ]

    desired = [max(1, int(size)) for size in desired]
    crop_slices = tuple(slice(0, min(s, d)) for s, d in zip(probe.shape, desired))
    probe = probe[crop_slices]
    pads = [(0, max(0, d - s)) for s, d in zip(probe.shape, desired)]
    if any(pad_after > 0 for _, pad_after in pads):
        probe = np.pad(probe, pads, mode="reflect")
    return probe
@@ -0,0 +1,182 @@
1
+ """Empirically estimate receptive field and tiling overlap for ONNX models.
2
+
3
+ This module mirrors StarDist's empirical receptive-field estimation:
4
+ run the model once on a single-pixel impulse and once on zeros, then
5
+ measure the spatial support of the difference in the probability output.
6
+ The measured extents define the overlap needed to avoid tile boundary
7
+ artifacts in tiled prediction.
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ from pathlib import Path
13
+ from typing import Iterable
14
+
15
+ import numpy as np
16
+
17
+ from .divisibility import infer_div_by
18
+
19
+
20
def infer_receptive_field(
    model_path: str | Path,
    ndim: int | None = None,
    input_shape: tuple[int, ...] | None = None,
    eps: float = 0.0,
) -> tuple[tuple[int, int], ...]:
    """Measure the receptive field of an ONNX model from its impulse response.

    Following StarDist's empirical approach, the model is run twice: once on
    an all-zero input and once on the same input with a single 1.0 at the
    spatial centre. The extent of the region where the two probability maps
    differ (by more than ``eps``) is the receptive field.

    Parameters
    ----------
    model_path : str or pathlib.Path
        Location of the ONNX model file.
    ndim : int or None, optional
        Number of spatial axes (2 or 3). ``None`` derives it from the rank of
        the model's first input.
    input_shape : tuple[int, ...] or None, optional
        Spatial probe shape. ``None`` uses 256 per axis when ``ndim == 2`` and
        64 per axis otherwise, each rounded up to the divisibility reported by
        ``infer_div_by``.
    eps : float, optional
        Absolute-difference threshold above which an output pixel counts as
        influenced by the impulse. Default 0.0.

    Returns
    -------
    tuple[tuple[int, int], ...]
        For each spatial axis, ``(left, right)`` distances from the impulse
        position to the furthest influenced pixel, in input coordinates.

    Raises
    ------
    ValueError
        If ``input_shape`` does not have ``ndim`` entries.
    RuntimeError
        If no output pixel responds to the impulse.

    Notes
    -----
    - Probe tensors are channels-last: ``(1, *input_shape, 1)``.
    - The probability output is the first output of rank >= 4 with a
      trailing dimension of 1, falling back to the first output.
    - Outputs are brought back to input resolution with nearest-neighbour
      ``scipy.ndimage.zoom`` using the per-axis input/output size ratio.
    """
    import onnxruntime as ort
    from scipy.ndimage import zoom

    model_path = Path(model_path)
    session = ort.InferenceSession(str(model_path))
    input_name = session.get_inputs()[0].name
    output_names = [out.name for out in session.get_outputs()]

    spatial_ndim = ndim if ndim is not None else _infer_ndim_from_input(session)

    if input_shape is None:
        # Power-of-two default, rounded up so every axis satisfies the
        # model's divisibility requirement.
        base = 256 if spatial_ndim == 2 else 64
        divisors = infer_div_by(model_path, ndim=spatial_ndim)
        input_shape = tuple(_round_up(base, d) for d in divisors)

    if len(input_shape) != spatial_ndim:
        raise ValueError("input_shape must match ndim.")

    # Channels-last probes: an all-zero baseline and a centred impulse.
    center = tuple(extent // 2 for extent in input_shape)
    baseline_in = np.zeros((1, *input_shape, 1), dtype=np.float32)
    impulse_in = np.zeros_like(baseline_in)
    impulse_in[(0, *center, 0)] = 1.0

    response = _run_prob(session, output_names, input_name, impulse_in, spatial_ndim)
    baseline = _run_prob(session, output_names, input_name, baseline_in, spatial_ndim)

    # Per-axis downsampling factor of the output grid relative to the input.
    factors = tuple(
        max(1, int(round(n_in / n_out)))
        for n_in, n_out in zip(input_shape, response.shape)
    )
    response = zoom(response, factors, order=0)
    baseline = zoom(baseline, factors, order=0)

    touched = np.nonzero(np.abs(response - baseline) > eps)
    if any(idx.size == 0 for idx in touched):
        raise RuntimeError("Failed to detect receptive field; try a larger input_shape.")

    return tuple(
        (c - int(idx.min()), int(idx.max()) - c) for c, idx in zip(center, touched)
    )
102
+
103
+
104
def recommend_tile_overlap(
    model_path: str | Path,
    ndim: int | None = None,
    input_shape: tuple[int, ...] | None = None,
    eps: float = 0.0,
) -> tuple[int, ...]:
    """Suggest a per-axis tile overlap from the measured receptive field.

    For every spatial axis the overlap is the larger of the left and right
    receptive-field extents returned by ``infer_receptive_field``.

    Parameters
    ----------
    model_path : str or pathlib.Path
        Location of the ONNX model file.
    ndim : int or None, optional
        Number of spatial axes (2 or 3); ``None`` derives it from the model.
    input_shape : tuple[int, ...] or None, optional
        Spatial probe shape; ``None`` selects a default.
    eps : float, optional
        Threshold for treating an output difference as a response.

    Returns
    -------
    tuple[int, ...]
        Overlap per axis, in input pixels.
    """
    extents = infer_receptive_field(
        model_path=model_path,
        ndim=ndim,
        input_shape=input_shape,
        eps=eps,
    )
    return tuple(max(left, right) for left, right in extents)
135
+
136
+
137
def _run_prob(session, output_names, input_name, input_tensor, ndim: int) -> np.ndarray:
    """Execute the session once and return its probability map without batch/channel axes.

    The raw outputs are filtered by ``_select_prob_output`` and reshaped by
    ``_to_spatial`` into a YX (2D) or ZYX (3D) array.
    """
    feed = {input_name: input_tensor}
    raw_outputs = session.run(output_names, feed)
    return _to_spatial(_select_prob_output(raw_outputs), ndim)
143
+
144
+
145
def _select_prob_output(outputs: list[np.ndarray]) -> np.ndarray:
    """Return the output that looks like a probability map.

    Heuristic: the first output of rank >= 4 whose trailing axis has length
    1. When none matches, the first output is returned; an empty ``outputs``
    raises ``IndexError``.
    """
    fallback = outputs[0]
    return next(
        (arr for arr in outputs if arr.ndim >= 4 and arr.shape[-1] == 1),
        fallback,
    )
151
+
152
+
153
def _to_spatial(prob: np.ndarray, ndim: int) -> np.ndarray:
    """Strip the batch and singleton channel axes from a probability tensor.

    Accepts a rank ``ndim + 2`` tensor for ``ndim`` in {2, 3}. A trailing
    singleton axis (channels-last) is checked before a singleton axis 1
    (channels-first). The first batch entry is returned as a YX/ZYX array.

    Raises
    ------
    ValueError
        If the tensor rank or channel placement is not recognised.
    """
    expected_rank = ndim + 2
    if ndim in (2, 3) and prob.ndim == expected_rank:
        if prob.shape[-1] == 1:
            return prob[0, ..., 0]
        if prob.shape[1] == 1:
            return prob[0, 0, ...]
    raise ValueError("Unsupported prob output layout.")
166
+
167
+
168
def _infer_ndim_from_input(session) -> int:
    """Map the rank of the session's first input to a spatial dimensionality.

    Rank 4 maps to 2 and rank 5 maps to 3; any other rank raises
    ``ValueError``.
    """
    rank = len(session.get_inputs()[0].shape)
    spatial = {4: 2, 5: 3}.get(rank)
    if spatial is None:
        raise ValueError(f"Unsupported input rank {rank}.")
    return spatial
176
+
177
+
178
def _round_up(value: int, multiple: int) -> int:
    """Round ``value`` up to the nearest multiple of ``multiple``.

    A non-positive ``multiple`` is treated as "no constraint" and ``value`` is
    returned unchanged. Exact integer ceil-division is used instead of a float
    division so large values are not subject to floating-point rounding.
    """
    if multiple <= 0:
        return value
    # -(-a // b) is ceil(a / b) computed exactly in integer arithmetic.
    return int(-(-value // multiple) * multiple)
@@ -0,0 +1,48 @@
1
+ """CLI for empirical receptive-field estimation."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ from pathlib import Path
7
+
8
+ from .receptive_field import infer_receptive_field, recommend_tile_overlap
9
+
10
+
11
def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse command-line arguments for the receptive-field CLI.

    Parameters
    ----------
    argv : list of str or None, optional
        Arguments to parse, excluding the program name. ``None`` lets
        argparse read ``sys.argv[1:]``, matching the previous behavior.

    Returns
    -------
    argparse.Namespace
        Parsed arguments with ``model``, ``ndim``, ``shape`` and ``eps``.
    """
    parser = argparse.ArgumentParser(description="Estimate ONNX receptive field.")
    parser.add_argument("model", type=Path, help="Path to the ONNX model.")
    parser.add_argument(
        "--ndim",
        type=int,
        choices=(2, 3),
        default=None,
        help="Spatial dimensionality; inferred from the model input rank if omitted.",
    )
    parser.add_argument(
        "--shape",
        type=int,
        nargs="+",
        default=None,
        help="Spatial input shape (e.g. --shape 256 256 or --shape 64 64 64).",
    )
    parser.add_argument(
        "--eps",
        type=float,
        default=0.0,
        help="Threshold for detecting a non-zero response (default: 0.0).",
    )
    return parser.parse_args(argv)
24
+
25
+
26
def main() -> None:
    """Estimate an ONNX model's receptive field and print it with the tile overlap.

    The model is loaded and probed once. The recommended overlap is derived
    from the measured extents, rather than calling
    ``recommend_tile_overlap``, which would load and run the model again.
    """
    args = _parse_args()
    input_shape = tuple(args.shape) if args.shape else None
    rf = infer_receptive_field(
        model_path=args.model,
        ndim=args.ndim,
        input_shape=input_shape,
        eps=args.eps,
    )
    # Same rule as recommend_tile_overlap: the larger of the (left, right)
    # extents per axis, applied to the field we already measured.
    overlap = tuple(max(pair) for pair in rf)

    print(f"Model: {args.model}")
    print(f"Receptive field: {rf}")
    print(f"Recommended overlap: {overlap}")
45
+
46
+
47
# Run the CLI when this module is executed directly.
if __name__ == "__main__":
    main()
@@ -0,0 +1,278 @@
1
+ """Empirically infer valid spatial sizes for ONNX model inputs.
2
+
3
+ This module probes the ONNX runtime by running the model on small inputs and
4
+ recording which spatial sizes succeed. It then summarizes valid sizes as
5
+ periodic residues (e.g., sizes of the form ``16k`` or ``16k+1``).
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import math
11
+ from dataclasses import dataclass
12
+ from typing import Iterable, Sequence
13
+
14
+ import numpy as np
15
+
16
+
17
@dataclass(frozen=True)
class ValidSizePattern:
    """Immutable description of which sizes a model accepts along one axis.

    A size is described by its remainder modulo ``period``: the allowed
    remainders are listed in ``residues``. ``min_valid`` records the smallest
    size that was observed to run; ``snap_size`` never returns anything below
    it.
    """

    # Repetition length of the validity pattern.
    period: int
    # Allowed values of ``size % period``.
    residues: tuple[int, ...]
    # Smallest size observed to execute successfully.
    min_valid: int
+
35
+
36
def infer_valid_size_patterns(
    session,
    input_name: str,
    output_names: Iterable[str],
    input_layout: str,
    ndim: int,
    max_probe: int = 64,
) -> list[ValidSizePattern]:
    """Run the model on many shapes to discover which sizes each axis accepts.

    A symmetric base size that executes is found first. Then, one axis at a
    time, every size in ``1..max_probe`` is tried while the other axes stay at
    the base size. The successful sizes are summarised as a periodic residue
    pattern.

    Parameters
    ----------
    session : onnxruntime.InferenceSession
        Session used to execute the model.
    input_name : str
        Name of the model's input tensor.
    output_names : Iterable[str]
        Output tensors requested on each run.
    input_layout : str
        Input layout, e.g. ``"NHWC"`` or ``"NDHWC"``.
    ndim : int
        Number of spatial axes (2 or 3).
    max_probe : int, optional
        Largest size tried per axis. Default is 64.

    Returns
    -------
    list[ValidSizePattern]
        One pattern per spatial axis.

    Raises
    ------
    ValueError
        If ``ndim`` is not 2 or 3.
    RuntimeError
        If no base size or no valid size per axis can be found in range.
    """
    if ndim not in (2, 3):
        raise ValueError("ndim must be 2 or 3.")

    names = list(output_names)
    base = _find_valid_base(session, input_name, names, input_layout, ndim, max_probe)

    def _runs(axis: int, size: int) -> bool:
        trial = [base] * ndim
        trial[axis] = size
        return _try_run(session, input_name, names, input_layout, trial)

    patterns: list[ValidSizePattern] = []
    found_any = False
    for axis in range(ndim):
        accepted = [size for size in range(1, max_probe + 1) if _runs(axis, size)]
        if accepted:
            found_any = True
            period, residues = _infer_period_and_residues(accepted, max_probe)
            patterns.append(
                ValidSizePattern(
                    period=int(period),
                    residues=tuple(int(r) for r in residues),
                    min_valid=int(min(accepted)),
                )
            )
        else:
            # Nothing ran on this axis: record a permissive placeholder.
            patterns.append(ValidSizePattern(period=1, residues=(0,), min_valid=1))

    if not found_any:
        raise RuntimeError(
            f"No valid sizes found within 1..{max_probe} for any axis."
        )
    return patterns
107
+
108
+
109
def infer_valid_size_patterns_from_path(
    model_path,
    input_layout: str,
    ndim: int,
    max_probe: int = 64,
) -> list[ValidSizePattern]:
    """Load a model into a throwaway session and probe its valid input sizes.

    The session is created with a raised log-severity threshold so the
    errors produced by intentionally failing probe runs are not printed.

    Parameters
    ----------
    model_path : str or pathlib.Path
        Location of the ONNX model file.
    input_layout : str
        Input layout, e.g. ``"NHWC"`` or ``"NDHWC"``.
    ndim : int
        Number of spatial axes (2 or 3).
    max_probe : int, optional
        Largest size tried per axis. Default is 64.

    Returns
    -------
    list[ValidSizePattern]
        One pattern per spatial axis.
    """
    import onnxruntime as ort

    options = ort.SessionOptions()
    # Level 4 keeps error logs from expected probe failures quiet.
    options.log_severity_level = 4
    session = ort.InferenceSession(str(model_path), sess_options=options)
    first_input = session.get_inputs()[0]
    names = [output.name for output in session.get_outputs()]
    return infer_valid_size_patterns(
        session,
        first_input.name,
        names,
        input_layout,
        ndim,
        max_probe=max_probe,
    )
149
+
150
+
151
def _find_valid_base(
    session,
    input_name: str,
    output_names: list[str],
    input_layout: str,
    ndim: int,
    max_probe: int,
) -> int:
    """Find the smallest cube/square side length (1..max_probe) that executes.

    Raises
    ------
    RuntimeError
        If no size in ``1..max_probe`` runs successfully.
    """
    found = next(
        (
            side
            for side in range(1, max_probe + 1)
            if _try_run(session, input_name, output_names, input_layout, [side] * ndim)
        ),
        None,
    )
    if found is None:
        raise RuntimeError("Failed to find any valid base size for probing.")
    return found
165
+
166
+
167
def _try_run(
    session,
    input_name: str,
    output_names: list[str],
    input_layout: str,
    spatial_shape: list[int],
) -> bool:
    """Report whether the model executes on a zero tensor of the given spatial shape.

    A single-channel float32 tensor with batch size 1 is built according to
    ``input_layout``. Any exception raised by ``session.run`` is treated as
    "shape not accepted".

    Raises
    ------
    ValueError
        If ``input_layout`` is not one of NHWC, NDHWC, NCHW, NCDHW.
    """
    if input_layout in ("NHWC", "NDHWC"):
        tensor_shape = (1, *spatial_shape, 1)
    elif input_layout in ("NCHW", "NCDHW"):
        tensor_shape = (1, 1, *spatial_shape)
    else:
        raise ValueError(f"Unsupported input layout {input_layout}.")

    dummy = np.zeros(tensor_shape, dtype=np.float32)
    try:
        session.run(list(output_names), {input_name: dummy})
    except Exception:
        return False
    else:
        return True
186
+
187
+
188
def _infer_period_and_residues(
    valid_sizes: list[int], max_probe: int
) -> tuple[int, list[int]]:
    """Infer the shortest period and residue set that explain ``valid_sizes``.

    Periods ``1..max_probe`` are tried in order. A period is accepted when,
    for every size from the smallest valid size up to ``max_probe``, "its
    residue is in the residue set" agrees exactly with "it was observed
    valid". If no period fits, the GCD of the gaps between consecutive
    distinct valid sizes is used instead.

    Parameters
    ----------
    valid_sizes : list[int]
        Sizes that executed successfully (any order, duplicates allowed).
    max_probe : int
        Largest probed size; bounds both the period search and the check.

    Returns
    -------
    tuple[int, list[int]]
        ``(period, sorted_residues)``; ``(1, [0])`` for empty input.
    """
    valid_set = set(valid_sizes)
    if not valid_set:
        return 1, [0]

    min_valid = min(valid_set)
    for period in range(1, max_probe + 1):
        residues = {v % period for v in valid_set}
        if all(
            (size % period in residues) == (size in valid_set)
            for size in range(min_valid, max_probe + 1)
        ):
            return period, sorted(residues)

    # Fallback: use sorted, de-duplicated sizes so the gap computation does
    # not depend on caller ordering (unsorted input previously produced no
    # positive gaps and crashed with IndexError).
    ordered = sorted(valid_set)
    if len(ordered) < 2:
        return max(1, ordered[0]), [0]

    period = 0
    for smaller, larger in zip(ordered, ordered[1:]):
        period = math.gcd(period, larger - smaller)
    period = max(1, period)
    return period, sorted({v % period for v in valid_set})
216
+
217
+
218
def snap_size(size: int, pattern: ValidSizePattern) -> int:
    """Adjust a size to the nearest valid size at or below ``size``.

    Sizes at or below ``pattern.min_valid`` snap to ``min_valid``. Otherwise
    the search walks down from ``size`` (at most one period) for a size whose
    residue modulo ``period`` is allowed. If it would drop below ``min_valid``
    first, it walks upward instead.

    Parameters
    ----------
    size : int
        Proposed size.
    pattern : ValidSizePattern
        Valid size pattern for the axis.

    Returns
    -------
    int
        Snapped valid size.

    Raises
    ------
    ValueError
        If ``pattern.residues`` is empty and ``size`` exceeds ``min_valid``
        (no size could ever match; previously this looped forever).
    """
    period = max(1, int(pattern.period))
    # Normalise residues into [0, period) so out-of-range entries still match
    # and the upward search below is guaranteed to terminate.
    residues = {int(r) % period for r in pattern.residues}
    min_valid = int(pattern.min_valid)
    if size <= min_valid:
        return min_valid
    if not residues:
        raise ValueError("ValidSizePattern has no residues; cannot snap size.")
    for delta in range(period + 1):
        candidate = size - delta
        if candidate < min_valid:
            break
        if candidate % period in residues:
            return candidate
    # Only reachable when min_valid's own residue is not allowed: take the
    # next allowed size above ``size`` (found within one period).
    candidate = size
    while candidate % period not in residues:
        candidate += 1
    return candidate
248
+
249
+
250
def snap_shape(
    shape: Sequence[int],
    patterns: Sequence[ValidSizePattern],
    *,
    skip_axes: Sequence[int] = (),
) -> tuple[int, ...]:
    """Apply ``snap_size`` to every axis of a spatial shape.

    Parameters
    ----------
    shape : Sequence[int]
        Proposed spatial shape.
    patterns : Sequence[ValidSizePattern]
        Valid size pattern for each axis, indexed like ``shape``.
    skip_axes : Sequence[int], optional
        Axis indices returned as ``int(size)`` without snapping.

    Returns
    -------
    tuple[int, ...]
        The snapped shape.
    """
    return tuple(
        int(size) if axis in skip_axes else snap_size(int(size), patterns[axis])
        for axis, size in enumerate(shape)
    )
@@ -0,0 +1,8 @@
1
+ """Post-processing utilities for ONNX StarDist inference."""
2
+
3
+ from .core import instances_from_prediction_2d, instances_from_prediction_3d
4
+
5
# Names exported by ``from ... import *`` for this post-processing package.
__all__ = [
    "instances_from_prediction_2d",
    "instances_from_prediction_3d",
]