senoquant 1.0.0b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- senoquant/__init__.py +6 -0
- senoquant/_reader.py +7 -0
- senoquant/_widget.py +33 -0
- senoquant/napari.yaml +83 -0
- senoquant/reader/__init__.py +5 -0
- senoquant/reader/core.py +369 -0
- senoquant/tabs/__init__.py +15 -0
- senoquant/tabs/batch/__init__.py +10 -0
- senoquant/tabs/batch/backend.py +641 -0
- senoquant/tabs/batch/config.py +270 -0
- senoquant/tabs/batch/frontend.py +1283 -0
- senoquant/tabs/batch/io.py +326 -0
- senoquant/tabs/batch/layers.py +86 -0
- senoquant/tabs/quantification/__init__.py +1 -0
- senoquant/tabs/quantification/backend.py +228 -0
- senoquant/tabs/quantification/features/__init__.py +80 -0
- senoquant/tabs/quantification/features/base.py +142 -0
- senoquant/tabs/quantification/features/marker/__init__.py +5 -0
- senoquant/tabs/quantification/features/marker/config.py +69 -0
- senoquant/tabs/quantification/features/marker/dialog.py +437 -0
- senoquant/tabs/quantification/features/marker/export.py +879 -0
- senoquant/tabs/quantification/features/marker/feature.py +119 -0
- senoquant/tabs/quantification/features/marker/morphology.py +285 -0
- senoquant/tabs/quantification/features/marker/rows.py +654 -0
- senoquant/tabs/quantification/features/marker/thresholding.py +46 -0
- senoquant/tabs/quantification/features/roi.py +346 -0
- senoquant/tabs/quantification/features/spots/__init__.py +5 -0
- senoquant/tabs/quantification/features/spots/config.py +62 -0
- senoquant/tabs/quantification/features/spots/dialog.py +477 -0
- senoquant/tabs/quantification/features/spots/export.py +1292 -0
- senoquant/tabs/quantification/features/spots/feature.py +112 -0
- senoquant/tabs/quantification/features/spots/morphology.py +279 -0
- senoquant/tabs/quantification/features/spots/rows.py +241 -0
- senoquant/tabs/quantification/frontend.py +815 -0
- senoquant/tabs/segmentation/__init__.py +1 -0
- senoquant/tabs/segmentation/backend.py +131 -0
- senoquant/tabs/segmentation/frontend.py +1009 -0
- senoquant/tabs/segmentation/models/__init__.py +5 -0
- senoquant/tabs/segmentation/models/base.py +146 -0
- senoquant/tabs/segmentation/models/cpsam/details.json +65 -0
- senoquant/tabs/segmentation/models/cpsam/model.py +150 -0
- senoquant/tabs/segmentation/models/default_2d/details.json +69 -0
- senoquant/tabs/segmentation/models/default_2d/model.py +664 -0
- senoquant/tabs/segmentation/models/default_3d/details.json +69 -0
- senoquant/tabs/segmentation/models/default_3d/model.py +682 -0
- senoquant/tabs/segmentation/models/hf.py +71 -0
- senoquant/tabs/segmentation/models/nuclear_dilation/__init__.py +1 -0
- senoquant/tabs/segmentation/models/nuclear_dilation/details.json +26 -0
- senoquant/tabs/segmentation/models/nuclear_dilation/model.py +96 -0
- senoquant/tabs/segmentation/models/perinuclear_rings/__init__.py +1 -0
- senoquant/tabs/segmentation/models/perinuclear_rings/details.json +34 -0
- senoquant/tabs/segmentation/models/perinuclear_rings/model.py +132 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/__init__.py +2 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/__init__.py +3 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/__init__.py +6 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/generate.py +470 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/prepare.py +273 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/rawdata.py +112 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/transform.py +384 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/__init__.py +0 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/blocks.py +184 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/losses.py +79 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/nets.py +165 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/predict.py +467 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/probability.py +67 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/train.py +148 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/io/__init__.py +163 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/__init__.py +52 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/base_model.py +329 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_isotropic.py +160 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_projection.py +178 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_standard.py +446 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_upsampling.py +54 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/config.py +254 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/pretrained.py +119 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/scripts/__init__.py +0 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/scripts/care_predict.py +180 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/__init__.py +5 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/plot_utils.py +159 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/six.py +18 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/tf.py +644 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/utils.py +272 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/version.py +1 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/docs/source/conf.py +368 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/setup.py +68 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_datagen.py +169 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_models.py +462 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_utils.py +166 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tools/create_zip_contents.py +34 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/__init__.py +30 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/big.py +624 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/bioimageio_utils.py +494 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/data/__init__.py +39 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/__init__.py +10 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/geom2d.py +215 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/geom3d.py +349 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/matching.py +483 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/__init__.py +28 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/base.py +1217 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/model2d.py +594 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/model3d.py +696 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/nms.py +384 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/__init__.py +2 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/plot.py +74 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/render.py +298 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/rays3d.py +373 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/sample_patches.py +65 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/__init__.py +0 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/predict2d.py +90 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/predict3d.py +93 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/utils.py +408 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/version.py +1 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/__init__.py +45 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/__init__.py +17 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/cli.py +55 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/core.py +285 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/__init__.py +15 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/cli.py +36 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/divisibility.py +193 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/probe.py +100 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/receptive_field.py +182 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/rf_cli.py +48 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/valid_sizes.py +278 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/post/__init__.py +8 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/post/core.py +157 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/pre/__init__.py +17 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/pre/core.py +226 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/predict/__init__.py +5 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/predict/core.py +401 -0
- senoquant/tabs/settings/__init__.py +1 -0
- senoquant/tabs/settings/backend.py +29 -0
- senoquant/tabs/settings/frontend.py +19 -0
- senoquant/tabs/spots/__init__.py +1 -0
- senoquant/tabs/spots/backend.py +139 -0
- senoquant/tabs/spots/frontend.py +800 -0
- senoquant/tabs/spots/models/__init__.py +5 -0
- senoquant/tabs/spots/models/base.py +94 -0
- senoquant/tabs/spots/models/rmp/details.json +61 -0
- senoquant/tabs/spots/models/rmp/model.py +499 -0
- senoquant/tabs/spots/models/udwt/details.json +103 -0
- senoquant/tabs/spots/models/udwt/model.py +482 -0
- senoquant/utils.py +25 -0
- senoquant-1.0.0b1.dist-info/METADATA +193 -0
- senoquant-1.0.0b1.dist-info/RECORD +148 -0
- senoquant-1.0.0b1.dist-info/WHEEL +5 -0
- senoquant-1.0.0b1.dist-info/entry_points.txt +2 -0
- senoquant-1.0.0b1.dist-info/licenses/LICENSE +28 -0
- senoquant-1.0.0b1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
"""ONNX tiling and prediction framework for StarDist."""
|
|
2
|
+
|
|
3
|
+
# Public names of the ONNX tiling/prediction framework. None of these is
# imported when the package itself is imported: the module-level
# ``__getattr__`` imports the owning submodule (``pre``, ``predict``,
# ``post``, ``convert`` or ``inspect``) the first time a name is accessed.
__all__ = [
    "normalize",
    "pad_for_tiling",
    "pad_to_multiple",
    "unpad_to_shape",
    "TilingSpec",
    "default_tiling_spec",
    "predict_tiled",
    "instances_from_prediction_2d",
    "instances_from_prediction_3d",
    "DEFAULT_2D_MODEL",
    "DEFAULT_3D_MODEL",
    "convert_model_to_onnx",
    "convert_pretrained_2d",
    "convert_pretrained_3d",
    "infer_div_by",
    "summarize_model_io",
]
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def __getattr__(name):
    """Resolve public names lazily from the framework submodules (PEP 562).

    The submodule that owns ``name`` is imported only when that name is first
    looked up on this package, so importing the package stays cheap.

    Parameters
    ----------
    name : str
        Attribute name requested on the package.

    Returns
    -------
    object
        The attribute of the same name from the owning submodule.

    Raises
    ------
    AttributeError
        If ``name`` is not one of the lazily exported names.
    """
    if name in {"normalize", "pad_for_tiling", "pad_to_multiple", "unpad_to_shape"}:
        from . import pre as submodule
    elif name in {"TilingSpec", "default_tiling_spec", "predict_tiled"}:
        from . import predict as submodule
    elif name in {"instances_from_prediction_2d", "instances_from_prediction_3d"}:
        from . import post as submodule
    elif name in {
        "DEFAULT_2D_MODEL",
        "DEFAULT_3D_MODEL",
        "convert_model_to_onnx",
        "convert_pretrained_2d",
        "convert_pretrained_3d",
    }:
        from . import convert as submodule
    elif name in {"infer_div_by", "summarize_model_io"}:
        from . import inspect as submodule
    else:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
    return getattr(submodule, name)
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
"""ONNX conversion helpers for StarDist models."""
|
|
2
|
+
|
|
3
|
+
from .core import (
|
|
4
|
+
DEFAULT_2D_MODEL,
|
|
5
|
+
DEFAULT_3D_MODEL,
|
|
6
|
+
convert_model_to_onnx,
|
|
7
|
+
convert_pretrained_2d,
|
|
8
|
+
convert_pretrained_3d,
|
|
9
|
+
)
|
|
10
|
+
|
|
11
|
+
# Public conversion API, re-exported from ``.core`` by the import above.
__all__ = [
    "DEFAULT_2D_MODEL",
    "DEFAULT_3D_MODEL",
    "convert_model_to_onnx",
    "convert_pretrained_2d",
    "convert_pretrained_3d",
]
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
"""CLI for converting StarDist models to ONNX."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import argparse
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
|
|
8
|
+
from .core import (
|
|
9
|
+
DEFAULT_2D_MODEL,
|
|
10
|
+
DEFAULT_3D_MODEL,
|
|
11
|
+
convert_pretrained_2d,
|
|
12
|
+
convert_pretrained_3d,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def main() -> None:
    """Parse command-line options and export a StarDist model to ONNX.

    ``--dim`` selects the 2D or 3D converter. When ``--model`` is omitted, the
    matching default pretrained model is used. The path of the written ONNX
    file is printed on success.
    """
    parser = argparse.ArgumentParser(description="Convert StarDist models to ONNX.")
    add = parser.add_argument
    add("--dim", choices=("2", "3", "2d", "3d"), default="2d", help="Model dimensionality.")
    add("--model", default=None, help="Pretrained model name/alias or model directory path.")
    add("--output", default=".", help="Output directory or ONNX file path.")
    add("--opset", type=int, default=18, help="ONNX opset version to export.")
    args = parser.parse_args()

    # Pick the converter and its default model together so they cannot disagree.
    if args.dim in ("2", "2d"):
        converter, fallback_model = convert_pretrained_2d, DEFAULT_2D_MODEL
    else:
        converter, fallback_model = convert_pretrained_3d, DEFAULT_3D_MODEL

    selected_model = args.model or fallback_model
    saved_path = converter(selected_model, Path(args.output), opset=args.opset)
    print(f"Saved ONNX model to {saved_path}")
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
# Run the conversion CLI when this module is executed directly.
if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,285 @@
|
|
|
1
|
+
"""Convert StarDist Keras models to ONNX."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
import re
|
|
7
|
+
import sys
|
|
8
|
+
import types
|
|
9
|
+
import importlib
|
|
10
|
+
import tempfile
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
# Pretrained StarDist model names used when the caller does not specify one
# (default ``model_name`` of ``convert_pretrained_2d``/``convert_pretrained_3d``).
DEFAULT_2D_MODEL = "2D_versatile_fluo"
DEFAULT_3D_MODEL = "3D_demo"
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def convert_pretrained_2d(
    model_name: str = DEFAULT_2D_MODEL,
    output: str | Path = ".",
    *,
    opset: int = 18,
) -> Path:
    """Convert a pretrained StarDist2D model to ONNX.

    Parameters
    ----------
    model_name : str, optional
        Pretrained model name or alias, or a path to an existing model
        directory. Defaults to ``2D_versatile_fluo``.
    output : str or pathlib.Path, optional
        Output directory or ONNX file path. A path without an ``.onnx``
        suffix is treated as a directory, and the file inside it is named
        ``stardist2d_<sanitized model_name>.onnx``. Defaults to the current
        directory.
    opset : int, optional
        ONNX opset version to export. Defaults to 18.

    Returns
    -------
    pathlib.Path
        Path to the saved ONNX model.

    Raises
    ------
    ValueError
        If ``model_name`` is not a directory and the pretrained lookup
        returns no model for it.
    """
    model = _load_stardist_model(2, model_name)
    output_path = _resolve_output_path(output, f"stardist2d_{_safe_name(model_name)}.onnx")
    return convert_model_to_onnx(model, output_path, opset=opset)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def convert_pretrained_3d(
    model_name: str = DEFAULT_3D_MODEL,
    output: str | Path = ".",
    *,
    opset: int = 18,
) -> Path:
    """Convert a pretrained StarDist3D model to ONNX.

    Parameters
    ----------
    model_name : str, optional
        Pretrained model name or alias, or a path to an existing model
        directory. Defaults to ``3D_demo``.
    output : str or pathlib.Path, optional
        Output directory or ONNX file path. A path without an ``.onnx``
        suffix is treated as a directory, and the file inside it is named
        ``stardist3d_<sanitized model_name>.onnx``. Defaults to the current
        directory.
    opset : int, optional
        ONNX opset version to export. Defaults to 18.

    Returns
    -------
    pathlib.Path
        Path to the saved ONNX model.

    Raises
    ------
    ValueError
        If ``model_name`` is not a directory and the pretrained lookup
        returns no model for it.
    """
    model = _load_stardist_model(3, model_name)
    output_path = _resolve_output_path(output, f"stardist3d_{_safe_name(model_name)}.onnx")
    return convert_model_to_onnx(model, output_path, opset=opset)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def convert_model_to_onnx(model, output_path: str | Path, *, opset: int = 18) -> Path:
    """Convert a StarDist model instance to ONNX.

    Export strategies are tried in order until one succeeds:

    1. Keras ``export()`` to a temporary SavedModel, converted with tf2onnx.
       Any failure here falls through to the next strategy.
    2. ``tf2onnx.convert.from_keras``: first with explicit ``output_names``,
       then without them if that call raises ``TypeError``.
    3. A frozen-graph conversion, used only when step 2 raises a
       ``ValueError`` whose message mentions ``explicit_paddings``.

    Parameters
    ----------
    model : object
        StarDist2D or StarDist3D instance with a ``keras_model`` attribute.
    output_path : str or pathlib.Path
        File path to save the ONNX model. Missing parent directories are
        created.
    opset : int, optional
        ONNX opset version to export. Defaults to 18.

    Returns
    -------
    pathlib.Path
        Path to the saved ONNX model.

    Raises
    ------
    RuntimeError
        If TensorFlow or tf2onnx cannot be imported.
    ValueError
        If ``from_keras`` fails with a ``ValueError`` unrelated to
        ``explicit_paddings``.

    Notes
    -----
    ``model.keras_model.trainable`` is set to ``False`` as a side effect.
    """
    tf = _import_tensorflow()
    tf2onnx = _import_tf2onnx()

    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    keras_model = model.keras_model
    keras_model.trainable = False
    input_signature = _keras_input_signature(tf, keras_model)

    try:
        _convert_via_saved_model(tf2onnx, keras_model, input_signature, opset, output_path)
    except Exception:
        # The SavedModel route is best-effort. On any failure, fall back to
        # the direct Keras converters. Errors raised below carry this one as
        # their exception context.
        try:
            _convert_via_from_keras(tf2onnx, keras_model, input_signature, opset, output_path)
        except ValueError as exc:
            # Apply the frozen-graph workaround to either from_keras attempt,
            # not only to the retry without ``output_names``.
            if "explicit_paddings" not in str(exc):
                raise
            _convert_via_frozen_graph(
                tf2onnx, tf, keras_model, input_signature, opset, output_path
            )
    return output_path


def _keras_input_signature(tf, keras_model):
    """Build a one-element ``tf.TensorSpec`` tuple for the model's first input.

    A leading ``None`` (unknown) batch dimension is pinned to 1. All other
    dimensions keep the sizes reported by ``input_tensor.shape``.
    """
    input_tensor = keras_model.inputs[0]
    # Strip a TensorFlow ":<index>" suffix from the tensor name, if present.
    input_name = input_tensor.name.split(":")[0]
    input_shape = list(input_tensor.shape)
    if input_shape and input_shape[0] is None:
        input_shape[0] = 1
    return (tf.TensorSpec(tuple(input_shape), input_tensor.dtype, name=input_name),)


def _convert_via_from_keras(tf2onnx, keras_model, input_signature, opset, output_path):
    """Convert with ``tf2onnx.convert.from_keras``, naming outputs when accepted."""
    common = {
        "input_signature": input_signature,
        "opset": opset,
        "output_path": str(output_path),
    }
    try:
        output_names = [out.name.split(":")[0] for out in keras_model.outputs]
        tf2onnx.convert.from_keras(keras_model, output_names=output_names, **common)
    except TypeError:
        # NOTE(review): presumably for tf2onnx releases whose ``from_keras``
        # has no ``output_names`` parameter. Retry with default names.
        tf2onnx.convert.from_keras(keras_model, **common)
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def _load_stardist_model(ndim: int, name_or_path: str):
|
|
133
|
+
_ensure_csbdeep_on_path()
|
|
134
|
+
_ensure_stardist_stub()
|
|
135
|
+
if ndim == 2:
|
|
136
|
+
module = importlib.import_module(
|
|
137
|
+
"senoquant.tabs.segmentation.stardist_onnx_utils._stardist.models"
|
|
138
|
+
)
|
|
139
|
+
model_cls = module.StarDist2D
|
|
140
|
+
elif ndim == 3:
|
|
141
|
+
module = importlib.import_module(
|
|
142
|
+
"senoquant.tabs.segmentation.stardist_onnx_utils._stardist.models"
|
|
143
|
+
)
|
|
144
|
+
model_cls = module.StarDist3D
|
|
145
|
+
else:
|
|
146
|
+
raise ValueError("ndim must be 2 or 3.")
|
|
147
|
+
|
|
148
|
+
model_path = Path(name_or_path)
|
|
149
|
+
if model_path.is_dir():
|
|
150
|
+
return model_cls(None, name=model_path.name, basedir=str(model_path.parent))
|
|
151
|
+
model = model_cls.from_pretrained(name_or_path)
|
|
152
|
+
if model is None:
|
|
153
|
+
raise ValueError(f"Unknown pretrained model: {name_or_path}")
|
|
154
|
+
return model
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def _resolve_output_path(output: str | Path, default_name: str) -> Path:
|
|
158
|
+
output_path = Path(output)
|
|
159
|
+
if output_path.suffix.lower() != ".onnx":
|
|
160
|
+
output_path = output_path / default_name
|
|
161
|
+
return output_path
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def _safe_name(name: str) -> str:
|
|
165
|
+
return re.sub(r"[^A-Za-z0-9._-]+", "_", name).strip("_")
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
def _ensure_csbdeep_on_path() -> None:
|
|
169
|
+
root = Path(__file__).resolve().parents[2]
|
|
170
|
+
csbdeep_root = root / "_csbdeep"
|
|
171
|
+
if csbdeep_root.exists():
|
|
172
|
+
csbdeep_path = str(csbdeep_root)
|
|
173
|
+
if csbdeep_path not in sys.path:
|
|
174
|
+
sys.path.insert(0, csbdeep_path)
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
def _ensure_stardist_stub() -> None:
|
|
178
|
+
base_pkg = "senoquant.tabs.segmentation.stardist_onnx_utils._stardist"
|
|
179
|
+
root = Path(__file__).resolve().parents[2] / "_stardist"
|
|
180
|
+
if base_pkg not in sys.modules:
|
|
181
|
+
pkg = types.ModuleType(base_pkg)
|
|
182
|
+
pkg.__path__ = [str(root)]
|
|
183
|
+
sys.modules[base_pkg] = pkg
|
|
184
|
+
geom_name = f"{base_pkg}.geometry"
|
|
185
|
+
if geom_name not in sys.modules:
|
|
186
|
+
geom = types.ModuleType(geom_name)
|
|
187
|
+
|
|
188
|
+
def _stub(*_args, **_kwargs):
|
|
189
|
+
raise RuntimeError("StarDist geometry helpers are unavailable in converter.")
|
|
190
|
+
|
|
191
|
+
geom.star_dist = _stub
|
|
192
|
+
geom.dist_to_coord = _stub
|
|
193
|
+
geom.polygons_to_label = _stub
|
|
194
|
+
geom.star_dist3D = _stub
|
|
195
|
+
geom.polyhedron_to_label = _stub
|
|
196
|
+
sys.modules[geom_name] = geom
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
def _import_tensorflow():
|
|
200
|
+
try:
|
|
201
|
+
import tensorflow as tf
|
|
202
|
+
except ImportError as exc:
|
|
203
|
+
raise RuntimeError("TensorFlow is required to export StarDist models.") from exc
|
|
204
|
+
return tf
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
def _import_tf2onnx():
|
|
208
|
+
try:
|
|
209
|
+
import numpy as np
|
|
210
|
+
# tf2onnx still references deprecated numpy aliases in some versions.
|
|
211
|
+
for alias, value in {
|
|
212
|
+
"bool": np.bool_,
|
|
213
|
+
"object": np.object_,
|
|
214
|
+
}.items():
|
|
215
|
+
if not hasattr(np, alias):
|
|
216
|
+
setattr(np, alias, value)
|
|
217
|
+
import tf2onnx
|
|
218
|
+
except ImportError as exc:
|
|
219
|
+
raise RuntimeError("tf2onnx is required to export StarDist models.") from exc
|
|
220
|
+
return tf2onnx
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
def _convert_via_frozen_graph(tf2onnx, tf, keras_model, input_signature, opset, output_path):
    """Convert ``keras_model`` to ONNX through a frozen (constant-folded) GraphDef.

    The model is traced into a concrete ``tf.function`` with
    ``training=False``. Its variables are folded into constants, empty
    ``explicit_paddings`` attributes are stripped from the graph, and the
    result is written with ``tf2onnx.convert.from_graph_def``.

    Raises
    ------
    RuntimeError
        If TensorFlow's ``convert_variables_to_constants_v2`` cannot be
        imported.
    """
    # Check for the converter before doing any (potentially slow) tracing.
    try:
        from tensorflow.python.framework.convert_to_constants import (
            convert_variables_to_constants_v2,
        )
    except ImportError as exc:
        raise RuntimeError("TensorFlow constants converter is unavailable.") from exc

    @tf.function
    def _model_fn(*args):
        return keras_model(*args, training=False)

    concrete = _model_fn.get_concrete_function(*input_signature)
    frozen_func = convert_variables_to_constants_v2(concrete)
    graph_def = frozen_func.graph.as_graph_def()
    inputs = [tensor.name for tensor in frozen_func.inputs]
    outputs = [tensor.name for tensor in frozen_func.outputs]

    _strip_empty_explicit_paddings(graph_def)

    # Names must be passed by keyword. The second positional parameter of
    # ``from_graph_def`` is ``name``, so the former positional "fallback"
    # assigned the inputs to ``name`` and the outputs to ``input_names``.
    tf2onnx.convert.from_graph_def(
        graph_def,
        input_names=inputs,
        output_names=outputs,
        opset=opset,
        output_path=str(output_path),
    )
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
def _convert_via_saved_model(tf2onnx, keras_model, input_signature, opset, output_path):
|
|
263
|
+
if not hasattr(keras_model, "export"):
|
|
264
|
+
raise RuntimeError("Keras model does not support export().")
|
|
265
|
+
export_dir = Path(tempfile.mkdtemp(prefix="stardist_saved_model_"))
|
|
266
|
+
keras_model.export(
|
|
267
|
+
str(export_dir),
|
|
268
|
+
format="tf_saved_model",
|
|
269
|
+
input_signature=input_signature,
|
|
270
|
+
)
|
|
271
|
+
if hasattr(tf2onnx.convert, "from_saved_model"):
|
|
272
|
+
tf2onnx.convert.from_saved_model(
|
|
273
|
+
str(export_dir),
|
|
274
|
+
output_path=str(output_path),
|
|
275
|
+
opset=opset,
|
|
276
|
+
)
|
|
277
|
+
else:
|
|
278
|
+
raise RuntimeError("tf2onnx does not support from_saved_model.")
|
|
279
|
+
|
|
280
|
+
|
|
281
|
+
def _strip_empty_explicit_paddings(graph_def):
|
|
282
|
+
for node in graph_def.node:
|
|
283
|
+
attr = node.attr.get("explicit_paddings")
|
|
284
|
+
if attr is not None and len(attr.list.i) == 0:
|
|
285
|
+
del node.attr["explicit_paddings"]
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
"""ONNX model inspection utilities."""
|
|
2
|
+
|
|
3
|
+
from .divisibility import infer_div_by, summarize_model_io
|
|
4
|
+
from .receptive_field import infer_receptive_field, recommend_tile_overlap
|
|
5
|
+
from .valid_sizes import infer_valid_size_patterns
|
|
6
|
+
from .probe import make_probe_image
|
|
7
|
+
|
|
8
|
+
# Public inspection helpers, re-exported from the ``divisibility``,
# ``receptive_field``, ``valid_sizes`` and ``probe`` submodules imported above.
__all__ = [
    "infer_div_by",
    "summarize_model_io",
    "infer_receptive_field",
    "recommend_tile_overlap",
    "infer_valid_size_patterns",
    "make_probe_image",
]
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
"""CLI for inspecting StarDist ONNX models."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import argparse
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
|
|
8
|
+
from .divisibility import infer_div_by, summarize_model_io
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def _parse_args() -> argparse.Namespace:
|
|
12
|
+
parser = argparse.ArgumentParser(description="Inspect an ONNX model.")
|
|
13
|
+
parser.add_argument("model", type=Path, help="Path to the ONNX model.")
|
|
14
|
+
parser.add_argument("--ndim", type=int, choices=(2, 3), default=None)
|
|
15
|
+
return parser.parse_args()
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def main() -> None:
    """Print an ONNX model's input/output shapes and inferred divisibility."""
    args = _parse_args()
    # Shape summary shows which dimensions are static and which are dynamic.
    summary = summarize_model_io(args.model)
    # Spatial multiple the input must satisfy for the graph's down/upsampling.
    div_by = infer_div_by(args.model, ndim=args.ndim)

    print(f"Model: {args.model}")
    for heading, key in (("Inputs:", "inputs"), ("Outputs:", "outputs")):
        print(heading)
        for position, dims in enumerate(summary[key]):
            print(f" [{position}] {dims}")
    print(f"Inferred div_by: {div_by}")
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
# Run the inspection CLI when this module is executed directly.
if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,193 @@
|
|
|
1
|
+
"""Infer input divisibility constraints from an ONNX graph.
|
|
2
|
+
|
|
3
|
+
This module inspects ONNX graphs to infer the minimal spatial divisibility
|
|
4
|
+
required to run the model without shape mismatches through down/upsampling
|
|
5
|
+
paths (e.g., U-Net skip connections).
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
from pathlib import Path
|
|
11
|
+
from typing import Iterable
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def infer_div_by(model_path: str | Path, ndim: int | None = None) -> tuple[int, ...]:
    """Infer the spatial divisibility required by an ONNX model.

    Graph nodes are visited in order while a per-axis scale, relative to the
    first graph input, is tracked for every produced tensor:

    - a node with several tracked inputs starts from their per-axis maximum;
      a node with none starts from 1;
    - strided ``Conv``/``MaxPool``/``AveragePool`` multiply the scale by their
      strides;
    - strided ``ConvTranspose`` divides it by its strides;
    - ``Resize``/``Upsample`` divide it by the scale factors returned by
      ``_get_resize_scales`` (factors of 0 or 1 are ignored).

    The largest scale reached on each axis is the multiple an input dimension
    should be divisible by to avoid shape mismatches between encoder and
    decoder feature maps (e.g. at U-Net skip connections).

    Parameters
    ----------
    model_path : str or pathlib.Path
        Path to the ONNX model file.
    ndim : int or None, optional
        Number of spatial dimensions (2 or 3). If ``None``, it is inferred
        from the first graph input via ``_infer_ndim``.

    Returns
    -------
    tuple[int, ...]
        Per-axis divisibility requirement, e.g. ``(16, 16)`` or ``(8, 8, 8)``.
        Scales are rounded to the nearest integer; axes below 1 report 1.

    Raises
    ------
    ValueError
        If the graph has no inputs.
    """
    model = _load_onnx(model_path)
    graph_inputs = model.graph.input
    if not graph_inputs:
        raise ValueError("ONNX model has no graph inputs.")
    input_name = graph_inputs[0].name

    if ndim is None:
        ndim = _infer_ndim(model)

    # Initializer tensors, needed to read Resize/Upsample scale factors.
    init_map = _initializers(model)

    def node_factor(node):
        """Return the per-axis scale multiplier applied by ``node``."""
        op = node.op_type
        if op in ("Conv", "MaxPool", "AveragePool"):
            strides = _get_attr_ints(node, "strides")
            if strides:
                return [float(s) for s in strides[-ndim:]]
        elif op == "ConvTranspose":
            strides = _get_attr_ints(node, "strides")
            if strides:
                return [1.0 / float(s) if s else 1.0 for s in strides[-ndim:]]
        elif op in ("Resize", "Upsample"):
            scales = _get_resize_scales(node, init_map)
            if scales is not None and len(scales) >= ndim:
                return [
                    1.0 / float(s) if float(s) not in (0.0, 1.0) else 1.0
                    for s in scales[-ndim:]
                ]
        return [1.0] * ndim

    scale_map: dict[str, list[float]] = {input_name: [1.0] * ndim}
    max_scale = [1.0] * ndim
    for node in model.graph.node:
        known = [scale_map[name] for name in node.input if name in scale_map]
        base = [max(axis) for axis in zip(*known)] if known else [1.0] * ndim
        out_scale = [b * f for b, f in zip(base, node_factor(node))]
        for output in node.output:
            scale_map[output] = out_scale
        max_scale = [max(m, s) for m, s in zip(max_scale, out_scale)]

    return tuple(int(round(s)) if s >= 1 else 1 for s in max_scale)
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def summarize_model_io(model_path: str | Path) -> dict[str, list[list[str]]]:
    """Describe the input and output tensor shapes of an ONNX model.

    Parameters
    ----------
    model_path : str or pathlib.Path
        Location of the ONNX model file to inspect.

    Returns
    -------
    dict
        Mapping with an ``"inputs"`` key and an ``"outputs"`` key. Each value
        holds one list per graph input/output, made of per-dimension labels
        such as ``"1"`` or ``"H (dynamic)"``.
    """
    graph = _load_onnx(model_path).graph

    def _shapes(values) -> list[list[str]]:
        # Every graph value exposes its (possibly symbolic) shape on the
        # tensor type; render each one as display strings.
        return [_format_shape(value.type.tensor_type.shape) for value in values]

    # NOTE(review): in models with IR version < 4, ``graph.input`` may also
    # list initializers (weights); confirm the target models are newer.
    return {"inputs": _shapes(graph.input), "outputs": _shapes(graph.output)}
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def _load_onnx(model_path: str | Path):
|
|
127
|
+
"""Load an ONNX model, raising a helpful error if onnx is missing."""
|
|
128
|
+
try:
|
|
129
|
+
import onnx
|
|
130
|
+
except Exception as exc:
|
|
131
|
+
# Keep error explicit so users know to install the dependency.
|
|
132
|
+
raise RuntimeError("onnx is required for model inspection.") from exc
|
|
133
|
+
return onnx.load(str(model_path))
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def _initializers(model) -> dict[str, Iterable[float]]:
    """Build a name-to-array lookup of the model's initializer tensors.

    Parameters
    ----------
    model
        Loaded ONNX model; only ``model.graph.initializer`` is read.

    Returns
    -------
    dict
        Maps each initializer name to the array produced by
        ``onnx.numpy_helper.to_array``. If several initializers share a
        name, the last one wins.
    """
    from onnx import numpy_helper

    arrays = {}
    for tensor in model.graph.initializer:
        arrays[tensor.name] = numpy_helper.to_array(tensor)
    return arrays
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def _infer_ndim(model) -> int:
    """Guess the number of spatial axes from the rank of the first input.

    A rank-4 input yields 2 and a rank-5 input yields 3 (presumably
    ``N, C, H, W`` and ``N, C, D, H, W`` layouts; confirm against exporters).

    Raises
    ------
    ValueError
        If the graph has no inputs or the first input has another rank.
    """
    graph_inputs = model.graph.input
    if not graph_inputs:
        raise ValueError("ONNX model has no graph inputs.")

    rank = len(graph_inputs[0].type.tensor_type.shape.dim)
    spatial_dims = {4: 2, 5: 3}.get(rank)
    if spatial_dims is None:
        raise ValueError(f"Unsupported input rank {rank}; pass ndim explicitly.")
    return spatial_dims
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def _get_attr_ints(node, name: str) -> list[int] | None:
|
|
160
|
+
"""Extract INT/INTS attributes from a node."""
|
|
161
|
+
for attr in node.attribute:
|
|
162
|
+
if attr.name == name:
|
|
163
|
+
if attr.type == attr.INTS:
|
|
164
|
+
return list(attr.ints)
|
|
165
|
+
if attr.type == attr.INT:
|
|
166
|
+
return [attr.i]
|
|
167
|
+
return None
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
def _get_resize_scales(node, init_map: dict[str, Iterable[float]]):
    """Return resize scales from the ``scales`` input or node attributes.

    Only the tensor in the ``scales`` slot is considered. ``Resize`` from
    opset 11 onward takes ``(X, roi, scales, sizes)``, whereas ``Resize``-10
    and ``Upsample``-9 take ``(X, scales)``. The ``roi`` (crop box) and
    ``sizes`` (absolute output shape) tensors are not scale factors, so they
    are never returned, even when they are initializers.

    Parameters
    ----------
    node
        A ``Resize`` or ``Upsample`` graph node.
    init_map : dict
        Mapping of initializer name to its materialized values.

    Returns
    -------
    sequence of float or None
        Per-axis scale factors, or ``None`` when the scales are not available
        as an initializer or attribute (for example when only ``sizes`` is
        supplied).
    """
    inputs = list(node.input)
    # Three or more inputs means the opset-11+ layout, with ``roi`` at index 1
    # pushing ``scales`` to index 2; otherwise ``scales`` directly follows X.
    scales_index = 2 if len(inputs) >= 3 else 1
    if scales_index < len(inputs):
        scales_name = inputs[scales_index]
        # An empty name marks an omitted optional input.
        if scales_name and scales_name in init_map:
            return init_map[scales_name]
    # Older Upsample versions store scales as a FLOATS attribute instead.
    for attr in node.attribute:
        if attr.name == "scales" and attr.type == attr.FLOATS:
            return list(attr.floats)
    return None
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
def _format_shape(shape) -> list[str]:
    """Render each dimension of an ONNX tensor shape as a display string.

    Symbolic dimensions become ``"<name> (dynamic)"``, non-zero fixed sizes
    become their number, and anything else (unset or zero) becomes
    ``"? (dynamic)"``.
    """

    def _label(dim) -> str:
        if dim.dim_param:
            return f"{dim.dim_param} (dynamic)"
        if dim.dim_value:
            return str(dim.dim_value)
        return "? (dynamic)"

    return [_label(dim) for dim in shape.dim]
|