senoquant 1.0.0b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148) hide show
  1. senoquant/__init__.py +6 -0
  2. senoquant/_reader.py +7 -0
  3. senoquant/_widget.py +33 -0
  4. senoquant/napari.yaml +83 -0
  5. senoquant/reader/__init__.py +5 -0
  6. senoquant/reader/core.py +369 -0
  7. senoquant/tabs/__init__.py +15 -0
  8. senoquant/tabs/batch/__init__.py +10 -0
  9. senoquant/tabs/batch/backend.py +641 -0
  10. senoquant/tabs/batch/config.py +270 -0
  11. senoquant/tabs/batch/frontend.py +1283 -0
  12. senoquant/tabs/batch/io.py +326 -0
  13. senoquant/tabs/batch/layers.py +86 -0
  14. senoquant/tabs/quantification/__init__.py +1 -0
  15. senoquant/tabs/quantification/backend.py +228 -0
  16. senoquant/tabs/quantification/features/__init__.py +80 -0
  17. senoquant/tabs/quantification/features/base.py +142 -0
  18. senoquant/tabs/quantification/features/marker/__init__.py +5 -0
  19. senoquant/tabs/quantification/features/marker/config.py +69 -0
  20. senoquant/tabs/quantification/features/marker/dialog.py +437 -0
  21. senoquant/tabs/quantification/features/marker/export.py +879 -0
  22. senoquant/tabs/quantification/features/marker/feature.py +119 -0
  23. senoquant/tabs/quantification/features/marker/morphology.py +285 -0
  24. senoquant/tabs/quantification/features/marker/rows.py +654 -0
  25. senoquant/tabs/quantification/features/marker/thresholding.py +46 -0
  26. senoquant/tabs/quantification/features/roi.py +346 -0
  27. senoquant/tabs/quantification/features/spots/__init__.py +5 -0
  28. senoquant/tabs/quantification/features/spots/config.py +62 -0
  29. senoquant/tabs/quantification/features/spots/dialog.py +477 -0
  30. senoquant/tabs/quantification/features/spots/export.py +1292 -0
  31. senoquant/tabs/quantification/features/spots/feature.py +112 -0
  32. senoquant/tabs/quantification/features/spots/morphology.py +279 -0
  33. senoquant/tabs/quantification/features/spots/rows.py +241 -0
  34. senoquant/tabs/quantification/frontend.py +815 -0
  35. senoquant/tabs/segmentation/__init__.py +1 -0
  36. senoquant/tabs/segmentation/backend.py +131 -0
  37. senoquant/tabs/segmentation/frontend.py +1009 -0
  38. senoquant/tabs/segmentation/models/__init__.py +5 -0
  39. senoquant/tabs/segmentation/models/base.py +146 -0
  40. senoquant/tabs/segmentation/models/cpsam/details.json +65 -0
  41. senoquant/tabs/segmentation/models/cpsam/model.py +150 -0
  42. senoquant/tabs/segmentation/models/default_2d/details.json +69 -0
  43. senoquant/tabs/segmentation/models/default_2d/model.py +664 -0
  44. senoquant/tabs/segmentation/models/default_3d/details.json +69 -0
  45. senoquant/tabs/segmentation/models/default_3d/model.py +682 -0
  46. senoquant/tabs/segmentation/models/hf.py +71 -0
  47. senoquant/tabs/segmentation/models/nuclear_dilation/__init__.py +1 -0
  48. senoquant/tabs/segmentation/models/nuclear_dilation/details.json +26 -0
  49. senoquant/tabs/segmentation/models/nuclear_dilation/model.py +96 -0
  50. senoquant/tabs/segmentation/models/perinuclear_rings/__init__.py +1 -0
  51. senoquant/tabs/segmentation/models/perinuclear_rings/details.json +34 -0
  52. senoquant/tabs/segmentation/models/perinuclear_rings/model.py +132 -0
  53. senoquant/tabs/segmentation/stardist_onnx_utils/__init__.py +2 -0
  54. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/__init__.py +3 -0
  55. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/__init__.py +6 -0
  56. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/generate.py +470 -0
  57. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/prepare.py +273 -0
  58. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/rawdata.py +112 -0
  59. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/transform.py +384 -0
  60. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/__init__.py +0 -0
  61. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/blocks.py +184 -0
  62. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/losses.py +79 -0
  63. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/nets.py +165 -0
  64. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/predict.py +467 -0
  65. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/probability.py +67 -0
  66. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/train.py +148 -0
  67. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/io/__init__.py +163 -0
  68. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/__init__.py +52 -0
  69. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/base_model.py +329 -0
  70. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_isotropic.py +160 -0
  71. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_projection.py +178 -0
  72. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_standard.py +446 -0
  73. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_upsampling.py +54 -0
  74. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/config.py +254 -0
  75. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/pretrained.py +119 -0
  76. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/scripts/__init__.py +0 -0
  77. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/scripts/care_predict.py +180 -0
  78. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/__init__.py +5 -0
  79. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/plot_utils.py +159 -0
  80. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/six.py +18 -0
  81. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/tf.py +644 -0
  82. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/utils.py +272 -0
  83. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/version.py +1 -0
  84. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/docs/source/conf.py +368 -0
  85. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/setup.py +68 -0
  86. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_datagen.py +169 -0
  87. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_models.py +462 -0
  88. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_utils.py +166 -0
  89. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tools/create_zip_contents.py +34 -0
  90. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/__init__.py +30 -0
  91. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/big.py +624 -0
  92. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/bioimageio_utils.py +494 -0
  93. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/data/__init__.py +39 -0
  94. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/__init__.py +10 -0
  95. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/geom2d.py +215 -0
  96. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/geom3d.py +349 -0
  97. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/matching.py +483 -0
  98. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/__init__.py +28 -0
  99. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/base.py +1217 -0
  100. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/model2d.py +594 -0
  101. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/model3d.py +696 -0
  102. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/nms.py +384 -0
  103. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/__init__.py +2 -0
  104. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/plot.py +74 -0
  105. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/render.py +298 -0
  106. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/rays3d.py +373 -0
  107. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/sample_patches.py +65 -0
  108. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/__init__.py +0 -0
  109. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/predict2d.py +90 -0
  110. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/predict3d.py +93 -0
  111. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/utils.py +408 -0
  112. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/version.py +1 -0
  113. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/__init__.py +45 -0
  114. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/__init__.py +17 -0
  115. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/cli.py +55 -0
  116. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/core.py +285 -0
  117. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/__init__.py +15 -0
  118. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/cli.py +36 -0
  119. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/divisibility.py +193 -0
  120. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/probe.py +100 -0
  121. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/receptive_field.py +182 -0
  122. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/rf_cli.py +48 -0
  123. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/valid_sizes.py +278 -0
  124. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/post/__init__.py +8 -0
  125. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/post/core.py +157 -0
  126. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/pre/__init__.py +17 -0
  127. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/pre/core.py +226 -0
  128. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/predict/__init__.py +5 -0
  129. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/predict/core.py +401 -0
  130. senoquant/tabs/settings/__init__.py +1 -0
  131. senoquant/tabs/settings/backend.py +29 -0
  132. senoquant/tabs/settings/frontend.py +19 -0
  133. senoquant/tabs/spots/__init__.py +1 -0
  134. senoquant/tabs/spots/backend.py +139 -0
  135. senoquant/tabs/spots/frontend.py +800 -0
  136. senoquant/tabs/spots/models/__init__.py +5 -0
  137. senoquant/tabs/spots/models/base.py +94 -0
  138. senoquant/tabs/spots/models/rmp/details.json +61 -0
  139. senoquant/tabs/spots/models/rmp/model.py +499 -0
  140. senoquant/tabs/spots/models/udwt/details.json +103 -0
  141. senoquant/tabs/spots/models/udwt/model.py +482 -0
  142. senoquant/utils.py +25 -0
  143. senoquant-1.0.0b1.dist-info/METADATA +193 -0
  144. senoquant-1.0.0b1.dist-info/RECORD +148 -0
  145. senoquant-1.0.0b1.dist-info/WHEEL +5 -0
  146. senoquant-1.0.0b1.dist-info/entry_points.txt +2 -0
  147. senoquant-1.0.0b1.dist-info/licenses/LICENSE +28 -0
  148. senoquant-1.0.0b1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,45 @@
1
+ """ONNX tiling and prediction framework for StarDist."""
2
+
3
__all__ = [
    "normalize",
    "pad_for_tiling",
    "pad_to_multiple",
    "unpad_to_shape",
    "TilingSpec",
    "default_tiling_spec",
    "predict_tiled",
    "instances_from_prediction_2d",
    "instances_from_prediction_3d",
    "DEFAULT_2D_MODEL",
    "DEFAULT_3D_MODEL",
    "convert_model_to_onnx",
    "convert_pretrained_2d",
    "convert_pretrained_3d",
    "infer_div_by",
    "summarize_model_io",
]

# Public name -> submodule of this package that defines it. Every entry in
# ``__all__`` must appear here so ``__getattr__`` can resolve it lazily.
_LAZY_ATTRS = {
    # from .pre
    "normalize": "pre",
    "pad_for_tiling": "pre",
    "pad_to_multiple": "pre",
    "unpad_to_shape": "pre",
    # from .predict
    "TilingSpec": "predict",
    "default_tiling_spec": "predict",
    "predict_tiled": "predict",
    # from .post
    "instances_from_prediction_2d": "post",
    "instances_from_prediction_3d": "post",
    # from .convert
    "DEFAULT_2D_MODEL": "convert",
    "DEFAULT_3D_MODEL": "convert",
    "convert_model_to_onnx": "convert",
    "convert_pretrained_2d": "convert",
    "convert_pretrained_3d": "convert",
    # from .inspect
    "infer_div_by": "inspect",
    "summarize_model_io": "inspect",
}


def __getattr__(name):
    """Resolve a public name lazily from the submodule that defines it (PEP 562).

    The submodule is imported only on first access, so importing this package
    does not import ``pre``/``predict``/``post``/``convert``/``inspect``.

    Raises
    ------
    AttributeError
        If ``name`` is not one of the package's lazily exported names.
    """
    try:
        submodule = _LAZY_ATTRS[name]
    except KeyError:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}") from None
    from importlib import import_module

    # Relative to this package, equivalent to ``from . import <submodule>``.
    module = import_module(f".{submodule}", __name__)
    return getattr(module, name)


def __dir__():
    """List module globals plus the lazily exported public names."""
    return sorted(set(globals()) | set(__all__))
@@ -0,0 +1,17 @@
1
+ """ONNX conversion helpers for StarDist models."""
2
+
3
+ from .core import (
4
+ DEFAULT_2D_MODEL,
5
+ DEFAULT_3D_MODEL,
6
+ convert_model_to_onnx,
7
+ convert_pretrained_2d,
8
+ convert_pretrained_3d,
9
+ )
10
+
11
# Public API of the ``convert`` subpackage: the default pretrained model names
# and the StarDist Keras -> ONNX conversion entry points re-exported from
# ``.core``.
__all__ = [
    "DEFAULT_2D_MODEL",
    "DEFAULT_3D_MODEL",
    "convert_model_to_onnx",
    "convert_pretrained_2d",
    "convert_pretrained_3d",
]
@@ -0,0 +1,55 @@
1
+ """CLI for converting StarDist models to ONNX."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ from pathlib import Path
7
+
8
+ from .core import (
9
+ DEFAULT_2D_MODEL,
10
+ DEFAULT_3D_MODEL,
11
+ convert_pretrained_2d,
12
+ convert_pretrained_3d,
13
+ )
14
+
15
+
16
+ def main() -> None:
17
+ parser = argparse.ArgumentParser(description="Convert StarDist models to ONNX.")
18
+ parser.add_argument(
19
+ "--dim",
20
+ choices=("2", "3", "2d", "3d"),
21
+ default="2d",
22
+ help="Model dimensionality.",
23
+ )
24
+ parser.add_argument(
25
+ "--model",
26
+ default=None,
27
+ help="Pretrained model name/alias or model directory path.",
28
+ )
29
+ parser.add_argument(
30
+ "--output",
31
+ default=".",
32
+ help="Output directory or ONNX file path.",
33
+ )
34
+ parser.add_argument(
35
+ "--opset",
36
+ type=int,
37
+ default=18,
38
+ help="ONNX opset version to export.",
39
+ )
40
+ args = parser.parse_args()
41
+
42
+ dim = 2 if args.dim in ("2", "2d") else 3
43
+ model_name = args.model or (DEFAULT_2D_MODEL if dim == 2 else DEFAULT_3D_MODEL)
44
+ output = Path(args.output)
45
+
46
+ if dim == 2:
47
+ path = convert_pretrained_2d(model_name, output, opset=args.opset)
48
+ else:
49
+ path = convert_pretrained_3d(model_name, output, opset=args.opset)
50
+
51
+ print(f"Saved ONNX model to {path}")
52
+
53
+
54
+ if __name__ == "__main__":
55
+ main()
@@ -0,0 +1,285 @@
1
+ """Convert StarDist Keras models to ONNX."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from pathlib import Path
6
+ import re
7
+ import sys
8
+ import types
9
+ import importlib
10
+ import tempfile
11
+
12
+
13
+ DEFAULT_2D_MODEL = "2D_versatile_fluo"
14
+ DEFAULT_3D_MODEL = "3D_demo"
15
+
16
+
17
+ def convert_pretrained_2d(
18
+ model_name: str = DEFAULT_2D_MODEL,
19
+ output: str | Path = ".",
20
+ *,
21
+ opset: int = 18,
22
+ ) -> Path:
23
+ """Convert a pretrained StarDist2D model to ONNX.
24
+
25
+ Parameters
26
+ ----------
27
+ model_name : str, optional
28
+ Pretrained model name or alias. Defaults to ``2D_versatile_fluo``.
29
+ output : str or pathlib.Path, optional
30
+ Output directory or ONNX file path. Defaults to the current directory.
31
+ opset : int, optional
32
+ ONNX opset version to export. Defaults to 13.
33
+
34
+ Returns
35
+ -------
36
+ pathlib.Path
37
+ Path to the saved ONNX model.
38
+ """
39
+ model = _load_stardist_model(2, model_name)
40
+ output_path = _resolve_output_path(output, f"stardist2d_{_safe_name(model_name)}.onnx")
41
+ return convert_model_to_onnx(model, output_path, opset=opset)
42
+
43
+
44
+ def convert_pretrained_3d(
45
+ model_name: str = DEFAULT_3D_MODEL,
46
+ output: str | Path = ".",
47
+ *,
48
+ opset: int = 18,
49
+ ) -> Path:
50
+ """Convert a pretrained StarDist3D model to ONNX.
51
+
52
+ Parameters
53
+ ----------
54
+ model_name : str, optional
55
+ Pretrained model name or alias. Defaults to ``3D_demo``.
56
+ output : str or pathlib.Path, optional
57
+ Output directory or ONNX file path. Defaults to the current directory.
58
+ opset : int, optional
59
+ ONNX opset version to export. Defaults to 13.
60
+
61
+ Returns
62
+ -------
63
+ pathlib.Path
64
+ Path to the saved ONNX model.
65
+ """
66
+ model = _load_stardist_model(3, model_name)
67
+ output_path = _resolve_output_path(output, f"stardist3d_{_safe_name(model_name)}.onnx")
68
+ return convert_model_to_onnx(model, output_path, opset=opset)
69
+
70
+
71
def convert_model_to_onnx(model, output_path: str | Path, *, opset: int = 18) -> Path:
    """Convert a StarDist model instance to ONNX.

    Export routes are tried in order until one succeeds:

    1. Keras ``export()`` to a temporary SavedModel, then tf2onnx
       ``from_saved_model``. Any exception from this route triggers the
       fallbacks below.
    2. tf2onnx ``from_keras`` on the in-memory Keras model (retried without
       ``output_names`` if the first call raises ``TypeError``).
    3. If route 2 raises a ``ValueError`` mentioning ``explicit_paddings``,
       a frozen-graph export with empty ``explicit_paddings`` attributes
       stripped.

    Note that ``model.keras_model.trainable`` is set to ``False`` and is not
    restored.

    Parameters
    ----------
    model : object
        StarDist2D or StarDist3D instance with a ``keras_model`` attribute.
    output_path : str or pathlib.Path
        File path to save the ONNX model. Missing parent directories are
        created.
    opset : int, optional
        ONNX opset version to export. Defaults to 18.

    Returns
    -------
    pathlib.Path
        Path to the saved ONNX model.

    Raises
    ------
    RuntimeError
        If TensorFlow or tf2onnx cannot be imported.
    """
    tf = _import_tensorflow()
    tf2onnx = _import_tf2onnx()

    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    keras_model = model.keras_model
    keras_model.trainable = False

    # Signature from the first model input: a dynamic (None) batch axis is
    # pinned to 1; the remaining axes are kept as declared (possibly None).
    input_tensor = keras_model.inputs[0]
    input_name = input_tensor.name.split(":")[0]
    input_shape = list(input_tensor.shape)
    if input_shape and input_shape[0] is None:
        input_shape[0] = 1
    input_signature = (tf.TensorSpec(tuple(input_shape), input_tensor.dtype, name=input_name),)

    try:
        _convert_via_saved_model(tf2onnx, keras_model, input_signature, opset, output_path)
    except Exception:
        # The SavedModel route is best-effort; fall back to from_keras.
        try:
            _convert_via_keras(tf2onnx, keras_model, input_signature, opset, output_path)
        except ValueError as exc:
            # Handles the explicit_paddings failure from EITHER from_keras
            # attempt; previously only the output_names-less retry reached
            # this fallback.
            if "explicit_paddings" not in str(exc):
                raise
            _convert_via_frozen_graph(
                tf2onnx, tf, keras_model, input_signature, opset, output_path
            )
    return output_path


def _convert_via_keras(tf2onnx, keras_model, input_signature, opset, output_path):
    """Run ``tf2onnx.convert.from_keras`` on ``keras_model``.

    The first attempt names the ONNX outputs after the Keras output tensors.
    If that raises ``TypeError``, the call is repeated without
    ``output_names``. Any other exception (notably ``ValueError``) propagates.
    """
    try:
        output_names = [out.name.split(":")[0] for out in keras_model.outputs]
        tf2onnx.convert.from_keras(
            keras_model,
            input_signature=input_signature,
            opset=opset,
            output_path=str(output_path),
            output_names=output_names,
        )
    except TypeError:
        # NOTE(review): presumably for tf2onnx versions whose from_keras does
        # not accept ``output_names``; confirm which versions need this.
        tf2onnx.convert.from_keras(
            keras_model,
            input_signature=input_signature,
            opset=opset,
            output_path=str(output_path),
        )
130
+
131
+
132
+ def _load_stardist_model(ndim: int, name_or_path: str):
133
+ _ensure_csbdeep_on_path()
134
+ _ensure_stardist_stub()
135
+ if ndim == 2:
136
+ module = importlib.import_module(
137
+ "senoquant.tabs.segmentation.stardist_onnx_utils._stardist.models"
138
+ )
139
+ model_cls = module.StarDist2D
140
+ elif ndim == 3:
141
+ module = importlib.import_module(
142
+ "senoquant.tabs.segmentation.stardist_onnx_utils._stardist.models"
143
+ )
144
+ model_cls = module.StarDist3D
145
+ else:
146
+ raise ValueError("ndim must be 2 or 3.")
147
+
148
+ model_path = Path(name_or_path)
149
+ if model_path.is_dir():
150
+ return model_cls(None, name=model_path.name, basedir=str(model_path.parent))
151
+ model = model_cls.from_pretrained(name_or_path)
152
+ if model is None:
153
+ raise ValueError(f"Unknown pretrained model: {name_or_path}")
154
+ return model
155
+
156
+
157
+ def _resolve_output_path(output: str | Path, default_name: str) -> Path:
158
+ output_path = Path(output)
159
+ if output_path.suffix.lower() != ".onnx":
160
+ output_path = output_path / default_name
161
+ return output_path
162
+
163
+
164
+ def _safe_name(name: str) -> str:
165
+ return re.sub(r"[^A-Za-z0-9._-]+", "_", name).strip("_")
166
+
167
+
168
+ def _ensure_csbdeep_on_path() -> None:
169
+ root = Path(__file__).resolve().parents[2]
170
+ csbdeep_root = root / "_csbdeep"
171
+ if csbdeep_root.exists():
172
+ csbdeep_path = str(csbdeep_root)
173
+ if csbdeep_path not in sys.path:
174
+ sys.path.insert(0, csbdeep_path)
175
+
176
+
177
+ def _ensure_stardist_stub() -> None:
178
+ base_pkg = "senoquant.tabs.segmentation.stardist_onnx_utils._stardist"
179
+ root = Path(__file__).resolve().parents[2] / "_stardist"
180
+ if base_pkg not in sys.modules:
181
+ pkg = types.ModuleType(base_pkg)
182
+ pkg.__path__ = [str(root)]
183
+ sys.modules[base_pkg] = pkg
184
+ geom_name = f"{base_pkg}.geometry"
185
+ if geom_name not in sys.modules:
186
+ geom = types.ModuleType(geom_name)
187
+
188
+ def _stub(*_args, **_kwargs):
189
+ raise RuntimeError("StarDist geometry helpers are unavailable in converter.")
190
+
191
+ geom.star_dist = _stub
192
+ geom.dist_to_coord = _stub
193
+ geom.polygons_to_label = _stub
194
+ geom.star_dist3D = _stub
195
+ geom.polyhedron_to_label = _stub
196
+ sys.modules[geom_name] = geom
197
+
198
+
199
+ def _import_tensorflow():
200
+ try:
201
+ import tensorflow as tf
202
+ except ImportError as exc:
203
+ raise RuntimeError("TensorFlow is required to export StarDist models.") from exc
204
+ return tf
205
+
206
+
207
+ def _import_tf2onnx():
208
+ try:
209
+ import numpy as np
210
+ # tf2onnx still references deprecated numpy aliases in some versions.
211
+ for alias, value in {
212
+ "bool": np.bool_,
213
+ "object": np.object_,
214
+ }.items():
215
+ if not hasattr(np, alias):
216
+ setattr(np, alias, value)
217
+ import tf2onnx
218
+ except ImportError as exc:
219
+ raise RuntimeError("tf2onnx is required to export StarDist models.") from exc
220
+ return tf2onnx
221
+
222
+
223
def _convert_via_frozen_graph(tf2onnx, tf, keras_model, input_signature, opset, output_path):
    """Convert ``keras_model`` to ONNX through a frozen TensorFlow graph.

    The model is traced into a concrete function for ``input_signature`` and
    its variables are folded into constants. Empty ``explicit_paddings``
    attributes are then stripped, and the resulting ``GraphDef`` is passed to
    ``tf2onnx.convert.from_graph_def``, which writes ``output_path``.

    Raises
    ------
    RuntimeError
        If TensorFlow's ``convert_variables_to_constants_v2`` cannot be
        imported.
    """

    @tf.function
    def _model_fn(*args):
        return keras_model(*args, training=False)

    concrete = _model_fn.get_concrete_function(*input_signature)

    try:
        from tensorflow.python.framework.convert_to_constants import (
            convert_variables_to_constants_v2,
        )
    except ImportError as exc:
        raise RuntimeError("TensorFlow constants converter is unavailable.") from exc

    frozen_func = convert_variables_to_constants_v2(concrete)
    graph_def = frozen_func.graph.as_graph_def()
    # Graph tensor names of the frozen function's inputs and outputs.
    inputs = [tensor.name for tensor in frozen_func.inputs]
    outputs = [tensor.name for tensor in frozen_func.outputs]

    _strip_empty_explicit_paddings(graph_def)

    # Keyword arguments only. The second positional parameter of
    # ``from_graph_def`` is ``name``, so the former positional TypeError retry
    # passed the input names as the graph name and the output names as inputs.
    # A TypeError now surfaces with its real cause instead.
    tf2onnx.convert.from_graph_def(
        graph_def,
        input_names=inputs,
        output_names=outputs,
        opset=opset,
        output_path=str(output_path),
    )
260
+
261
+
262
+ def _convert_via_saved_model(tf2onnx, keras_model, input_signature, opset, output_path):
263
+ if not hasattr(keras_model, "export"):
264
+ raise RuntimeError("Keras model does not support export().")
265
+ export_dir = Path(tempfile.mkdtemp(prefix="stardist_saved_model_"))
266
+ keras_model.export(
267
+ str(export_dir),
268
+ format="tf_saved_model",
269
+ input_signature=input_signature,
270
+ )
271
+ if hasattr(tf2onnx.convert, "from_saved_model"):
272
+ tf2onnx.convert.from_saved_model(
273
+ str(export_dir),
274
+ output_path=str(output_path),
275
+ opset=opset,
276
+ )
277
+ else:
278
+ raise RuntimeError("tf2onnx does not support from_saved_model.")
279
+
280
+
281
+ def _strip_empty_explicit_paddings(graph_def):
282
+ for node in graph_def.node:
283
+ attr = node.attr.get("explicit_paddings")
284
+ if attr is not None and len(attr.list.i) == 0:
285
+ del node.attr["explicit_paddings"]
@@ -0,0 +1,15 @@
1
+ """ONNX model inspection utilities."""
2
+
3
+ from .divisibility import infer_div_by, summarize_model_io
4
+ from .receptive_field import infer_receptive_field, recommend_tile_overlap
5
+ from .valid_sizes import infer_valid_size_patterns
6
+ from .probe import make_probe_image
7
+
8
# Public API of the ``inspect`` subpackage, re-exported from its submodules
# ``divisibility``, ``receptive_field``, ``valid_sizes`` and ``probe``.
__all__ = [
    "infer_div_by",
    "summarize_model_io",
    "infer_receptive_field",
    "recommend_tile_overlap",
    "infer_valid_size_patterns",
    "make_probe_image",
]
@@ -0,0 +1,36 @@
1
+ """CLI for inspecting StarDist ONNX models."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ from pathlib import Path
7
+
8
+ from .divisibility import infer_div_by, summarize_model_io
9
+
10
+
11
+ def _parse_args() -> argparse.Namespace:
12
+ parser = argparse.ArgumentParser(description="Inspect an ONNX model.")
13
+ parser.add_argument("model", type=Path, help="Path to the ONNX model.")
14
+ parser.add_argument("--ndim", type=int, choices=(2, 3), default=None)
15
+ return parser.parse_args()
16
+
17
+
18
+ def main() -> None:
19
+ args = _parse_args()
20
+ # Summarize model IO shapes to show dynamic/static dims.
21
+ summary = summarize_model_io(args.model)
22
+ # Infer the spatial divisibility required by the graph.
23
+ div_by = infer_div_by(args.model, ndim=args.ndim)
24
+
25
+ print(f"Model: {args.model}")
26
+ print("Inputs:")
27
+ for idx, dims in enumerate(summary["inputs"]):
28
+ print(f" [{idx}] {dims}")
29
+ print("Outputs:")
30
+ for idx, dims in enumerate(summary["outputs"]):
31
+ print(f" [{idx}] {dims}")
32
+ print(f"Inferred div_by: {div_by}")
33
+
34
+
35
+ if __name__ == "__main__":
36
+ main()
@@ -0,0 +1,193 @@
1
+ """Infer input divisibility constraints from an ONNX graph.
2
+
3
+ This module inspects ONNX graphs to infer the minimal spatial divisibility
4
+ required to run the model without shape mismatches through down/upsampling
5
+ paths (e.g., U-Net skip connections).
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from pathlib import Path
11
+ from typing import Iterable
12
+
13
+
14
def infer_div_by(model_path: str | Path, ndim: int | None = None) -> tuple[int, ...]:
    """Infer the spatial divisibility required by an ONNX model.

    This inspects the graph to estimate the cumulative downsampling factor
    along spatial axes. The result is the minimal per-axis multiple that the
    model input should be divisible by to avoid internal shape mismatches
    (e.g., concatenation of encoder/decoder feature maps).

    Parameters
    ----------
    model_path : str or pathlib.Path
        Path to the ONNX model file.
    ndim : int or None, optional
        Number of spatial dimensions (2 or 3). If ``None``, the input rank is
        used to infer dimensionality (rank 4 -> 2D, rank 5 -> 3D).

    Returns
    -------
    tuple[int, ...]
        Per-axis divisibility requirement with exactly ``ndim`` entries
        (e.g., ``(16, 16)`` or ``(8, 8, 8)``).

    Raises
    ------
    ValueError
        If the model has no graph inputs, or ``ndim`` is None and the input
        rank is neither 4 nor 5.

    Notes
    -----
    - The algorithm tracks cumulative scaling factors by propagating
      per-axis scale values through the graph, in ``graph.node`` order.
    - Downsampling ops (Conv/Pool with stride > 1) increase the scale.
    - Upsampling ops (ConvTranspose/Resize) reduce the scale.
    - Strides/scales are aligned to the last ``ndim`` axes. When a node
      provides fewer than ``ndim`` strides, the missing leading axes are
      treated as unscaled, so the result always has ``ndim`` entries.
    - The maximum scale observed across the graph is returned.
    """
    # Load the ONNX graph and find the primary input tensor.
    model = _load_onnx(model_path)
    input_name = model.graph.input[0].name if model.graph.input else None
    if input_name is None:
        raise ValueError("ONNX model has no graph inputs.")

    # Determine the number of spatial dimensions if not specified.
    if ndim is None:
        ndim = _infer_ndim(model)

    # Collect initializer tensors so we can read Resize scales, etc.
    init_map = _initializers(model)

    # Map tensor name -> per-axis scale relative to the original input.
    scale_map: dict[str, list[float]] = {input_name: [1.0] * ndim}
    # Track the maximum cumulative downsample per axis across the graph.
    max_scale = [1.0] * ndim

    for node in model.graph.node:
        # Only inputs derived from the graph input carry a spatial scale.
        input_scales = [scale_map[name] for name in node.input if name in scale_map]
        # Merge multiple inputs by taking the maximum scale per axis.
        base = (
            [max(values) for values in zip(*input_scales)]
            if input_scales
            else [1.0] * ndim
        )
        factor = _node_scale_factor(node, ndim, init_map)

        # Propagate the updated scale to all outputs of this node.
        out_scale = [b * f for b, f in zip(base, factor)]
        for output in node.output:
            scale_map[output] = out_scale
        max_scale = [max(m, s) for m, s in zip(max_scale, out_scale)]

    # Convert to integer divisibility requirements.
    return tuple(int(round(s)) if s >= 1 else 1 for s in max_scale)


def _node_scale_factor(node, ndim: int, init_map) -> list[float]:
    """Return the ``ndim``-long per-axis scale factor applied by one node.

    Values > 1 denote downsampling and values < 1 upsampling. Ops not handled
    here (and handled ops without usable strides/scales) return all ones.
    """
    op_type = node.op_type
    if op_type in ("Conv", "MaxPool", "AveragePool"):
        strides = _get_attr_ints(node, "strides")
        if strides:
            return _align_to_ndim([float(s) for s in strides], ndim)
    elif op_type == "ConvTranspose":
        strides = _get_attr_ints(node, "strides")
        if strides:
            return _align_to_ndim([1.0 / float(s) if s else 1.0 for s in strides], ndim)
    elif op_type in ("Resize", "Upsample"):
        scales = _get_resize_scales(node, init_map)
        if scales is not None and len(scales) >= ndim:
            return [
                1.0 / float(s) if float(s) not in (0.0, 1.0) else 1.0
                for s in list(scales)[-ndim:]
            ]
    return [1.0] * ndim


def _align_to_ndim(values: list[float], ndim: int) -> list[float]:
    """Keep the last ``ndim`` values, left-padding with 1.0 when there are fewer."""
    tail = values[-ndim:]
    return [1.0] * (ndim - len(tail)) + tail
103
+
104
+
105
def summarize_model_io(model_path: str | Path) -> dict[str, list[list[str]]]:
    """Describe the shapes of a model's graph inputs and outputs.

    Parameters
    ----------
    model_path : str or pathlib.Path
        Path to the ONNX model file.

    Returns
    -------
    dict
        ``{"inputs": [...], "outputs": [...]}``. Each entry is the list of
        dimension labels for one tensor (e.g. ``"1"``, ``"H (dynamic)"``).
    """
    graph = _load_onnx(model_path).graph

    def _shapes(values):
        return [_format_shape(value.type.tensor_type.shape) for value in values]

    return {"inputs": _shapes(graph.input), "outputs": _shapes(graph.output)}
124
+
125
+
126
+ def _load_onnx(model_path: str | Path):
127
+ """Load an ONNX model, raising a helpful error if onnx is missing."""
128
+ try:
129
+ import onnx
130
+ except Exception as exc:
131
+ # Keep error explicit so users know to install the dependency.
132
+ raise RuntimeError("onnx is required for model inspection.") from exc
133
+ return onnx.load(str(model_path))
134
+
135
+
136
+ def _initializers(model) -> dict[str, Iterable[float]]:
137
+ """Materialize ONNX initializers into a name -> numpy array map."""
138
+ from onnx import numpy_helper
139
+
140
+ return {
141
+ init.name: numpy_helper.to_array(init)
142
+ for init in model.graph.initializer
143
+ }
144
+
145
+
146
+ def _infer_ndim(model) -> int:
147
+ """Infer the spatial dimensionality from the model input rank."""
148
+ if not model.graph.input:
149
+ raise ValueError("ONNX model has no graph inputs.")
150
+ shape = model.graph.input[0].type.tensor_type.shape
151
+ rank = len(shape.dim)
152
+ if rank == 4:
153
+ return 2
154
+ if rank == 5:
155
+ return 3
156
+ raise ValueError(f"Unsupported input rank {rank}; pass ndim explicitly.")
157
+
158
+
159
+ def _get_attr_ints(node, name: str) -> list[int] | None:
160
+ """Extract INT/INTS attributes from a node."""
161
+ for attr in node.attribute:
162
+ if attr.name == name:
163
+ if attr.type == attr.INTS:
164
+ return list(attr.ints)
165
+ if attr.type == attr.INT:
166
+ return [attr.i]
167
+ return None
168
+
169
+
170
+ def _get_resize_scales(node, init_map: dict[str, Iterable[float]]):
171
+ """Return resize scales from initializer inputs or node attributes."""
172
+ # Newer ONNX Resize uses a scales tensor input.
173
+ for input_name in reversed(node.input):
174
+ if input_name in init_map:
175
+ return init_map[input_name]
176
+ # Older Resize/Upsample variants may store scales as attributes.
177
+ for attr in node.attribute:
178
+ if attr.name == "scales" and attr.type == attr.FLOATS:
179
+ return list(attr.floats)
180
+ return None
181
+
182
+
183
+ def _format_shape(shape) -> list[str]:
184
+ """Format an ONNX TensorShapeProto into a list of human-readable dims."""
185
+ dims: list[str] = []
186
+ for dim in shape.dim:
187
+ if dim.dim_param:
188
+ dims.append(f"{dim.dim_param} (dynamic)")
189
+ elif dim.dim_value:
190
+ dims.append(str(dim.dim_value))
191
+ else:
192
+ dims.append("? (dynamic)")
193
+ return dims