senoquant 1.0.0b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- senoquant/__init__.py +6 -0
- senoquant/_reader.py +7 -0
- senoquant/_widget.py +33 -0
- senoquant/napari.yaml +83 -0
- senoquant/reader/__init__.py +5 -0
- senoquant/reader/core.py +369 -0
- senoquant/tabs/__init__.py +15 -0
- senoquant/tabs/batch/__init__.py +10 -0
- senoquant/tabs/batch/backend.py +641 -0
- senoquant/tabs/batch/config.py +270 -0
- senoquant/tabs/batch/frontend.py +1283 -0
- senoquant/tabs/batch/io.py +326 -0
- senoquant/tabs/batch/layers.py +86 -0
- senoquant/tabs/quantification/__init__.py +1 -0
- senoquant/tabs/quantification/backend.py +228 -0
- senoquant/tabs/quantification/features/__init__.py +80 -0
- senoquant/tabs/quantification/features/base.py +142 -0
- senoquant/tabs/quantification/features/marker/__init__.py +5 -0
- senoquant/tabs/quantification/features/marker/config.py +69 -0
- senoquant/tabs/quantification/features/marker/dialog.py +437 -0
- senoquant/tabs/quantification/features/marker/export.py +879 -0
- senoquant/tabs/quantification/features/marker/feature.py +119 -0
- senoquant/tabs/quantification/features/marker/morphology.py +285 -0
- senoquant/tabs/quantification/features/marker/rows.py +654 -0
- senoquant/tabs/quantification/features/marker/thresholding.py +46 -0
- senoquant/tabs/quantification/features/roi.py +346 -0
- senoquant/tabs/quantification/features/spots/__init__.py +5 -0
- senoquant/tabs/quantification/features/spots/config.py +62 -0
- senoquant/tabs/quantification/features/spots/dialog.py +477 -0
- senoquant/tabs/quantification/features/spots/export.py +1292 -0
- senoquant/tabs/quantification/features/spots/feature.py +112 -0
- senoquant/tabs/quantification/features/spots/morphology.py +279 -0
- senoquant/tabs/quantification/features/spots/rows.py +241 -0
- senoquant/tabs/quantification/frontend.py +815 -0
- senoquant/tabs/segmentation/__init__.py +1 -0
- senoquant/tabs/segmentation/backend.py +131 -0
- senoquant/tabs/segmentation/frontend.py +1009 -0
- senoquant/tabs/segmentation/models/__init__.py +5 -0
- senoquant/tabs/segmentation/models/base.py +146 -0
- senoquant/tabs/segmentation/models/cpsam/details.json +65 -0
- senoquant/tabs/segmentation/models/cpsam/model.py +150 -0
- senoquant/tabs/segmentation/models/default_2d/details.json +69 -0
- senoquant/tabs/segmentation/models/default_2d/model.py +664 -0
- senoquant/tabs/segmentation/models/default_3d/details.json +69 -0
- senoquant/tabs/segmentation/models/default_3d/model.py +682 -0
- senoquant/tabs/segmentation/models/hf.py +71 -0
- senoquant/tabs/segmentation/models/nuclear_dilation/__init__.py +1 -0
- senoquant/tabs/segmentation/models/nuclear_dilation/details.json +26 -0
- senoquant/tabs/segmentation/models/nuclear_dilation/model.py +96 -0
- senoquant/tabs/segmentation/models/perinuclear_rings/__init__.py +1 -0
- senoquant/tabs/segmentation/models/perinuclear_rings/details.json +34 -0
- senoquant/tabs/segmentation/models/perinuclear_rings/model.py +132 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/__init__.py +2 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/__init__.py +3 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/__init__.py +6 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/generate.py +470 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/prepare.py +273 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/rawdata.py +112 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/transform.py +384 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/__init__.py +0 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/blocks.py +184 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/losses.py +79 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/nets.py +165 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/predict.py +467 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/probability.py +67 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/train.py +148 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/io/__init__.py +163 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/__init__.py +52 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/base_model.py +329 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_isotropic.py +160 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_projection.py +178 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_standard.py +446 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_upsampling.py +54 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/config.py +254 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/pretrained.py +119 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/scripts/__init__.py +0 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/scripts/care_predict.py +180 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/__init__.py +5 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/plot_utils.py +159 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/six.py +18 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/tf.py +644 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/utils.py +272 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/version.py +1 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/docs/source/conf.py +368 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/setup.py +68 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_datagen.py +169 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_models.py +462 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_utils.py +166 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tools/create_zip_contents.py +34 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/__init__.py +30 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/big.py +624 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/bioimageio_utils.py +494 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/data/__init__.py +39 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/__init__.py +10 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/geom2d.py +215 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/geom3d.py +349 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/matching.py +483 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/__init__.py +28 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/base.py +1217 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/model2d.py +594 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/model3d.py +696 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/nms.py +384 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/__init__.py +2 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/plot.py +74 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/render.py +298 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/rays3d.py +373 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/sample_patches.py +65 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/__init__.py +0 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/predict2d.py +90 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/predict3d.py +93 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/utils.py +408 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/version.py +1 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/__init__.py +45 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/__init__.py +17 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/cli.py +55 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/core.py +285 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/__init__.py +15 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/cli.py +36 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/divisibility.py +193 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/probe.py +100 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/receptive_field.py +182 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/rf_cli.py +48 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/valid_sizes.py +278 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/post/__init__.py +8 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/post/core.py +157 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/pre/__init__.py +17 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/pre/core.py +226 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/predict/__init__.py +5 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/predict/core.py +401 -0
- senoquant/tabs/settings/__init__.py +1 -0
- senoquant/tabs/settings/backend.py +29 -0
- senoquant/tabs/settings/frontend.py +19 -0
- senoquant/tabs/spots/__init__.py +1 -0
- senoquant/tabs/spots/backend.py +139 -0
- senoquant/tabs/spots/frontend.py +800 -0
- senoquant/tabs/spots/models/__init__.py +5 -0
- senoquant/tabs/spots/models/base.py +94 -0
- senoquant/tabs/spots/models/rmp/details.json +61 -0
- senoquant/tabs/spots/models/rmp/model.py +499 -0
- senoquant/tabs/spots/models/udwt/details.json +103 -0
- senoquant/tabs/spots/models/udwt/model.py +482 -0
- senoquant/utils.py +25 -0
- senoquant-1.0.0b1.dist-info/METADATA +193 -0
- senoquant-1.0.0b1.dist-info/RECORD +148 -0
- senoquant-1.0.0b1.dist-info/WHEEL +5 -0
- senoquant-1.0.0b1.dist-info/entry_points.txt +2 -0
- senoquant-1.0.0b1.dist-info/licenses/LICENSE +28 -0
- senoquant-1.0.0b1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
"""provides a faster sampling function"""
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
from csbdeep.utils import _raise, choice
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def sample_patches(datas, patch_size, n_samples, valid_inds=None, verbose=False):
    """Randomly cut patches at identical positions from several equally shaped arrays.

    Optimized version of ``csbdeep.data.sample_patches_from_multiple_stacks``.

    Parameters
    ----------
    datas : sequence of np.ndarray
        Arrays that must all have the same shape; one stack of patches is cut
        from each of them, using the same patch positions for all.
    patch_size : tuple of int
        Patch extent per axis; needs ``datas[0].ndim`` entries, each in ``(0, shape[i]]``.
    n_samples : int
        Number of patches to draw. Positions are drawn with replacement only if
        there are fewer valid positions than ``n_samples``.
    valid_inds : tuple of 1D arrays, optional
        Candidate patch centers, one coordinate array per axis (e.g. the output
        of ``get_valid_inds``). A patch centered at ``r`` spans
        ``[r - p//2, r + p - p//2)`` along an axis of patch size ``p``.
        Defaults to every center whose patch lies completely inside the arrays.
    verbose : bool
        Unused; kept for API compatibility.

    Returns
    -------
    list of np.ndarray
        One array of shape ``(n_samples, *patch_size)`` per entry of ``datas``.

    Raises
    ------
    ValueError
        On a ``patch_size`` of the wrong length, mismatched input shapes, a
        ``patch_size`` that is non-positive or larger than the data, or when
        there is no position to sample from.
    """
    if len(patch_size) != datas[0].ndim:
        raise ValueError()

    if not all(a.shape == datas[0].shape for a in datas):
        raise ValueError("all input shapes must be the same: %s" % (" / ".join(str(a.shape) for a in datas)))

    if not all(0 < s <= d for s, d in zip(patch_size, datas[0].shape)):
        raise ValueError("patch_size %s negative or larger than data shape %s along some dimensions" % (str(patch_size), str(datas[0].shape)))

    if valid_inds is None:
        # Valid centers satisfy r - p//2 >= 0 and r + p - p//2 <= s, i.e.
        # r in [p//2, s - p + p//2]. (The former bound s - p//2 was one too large
        # for odd p, producing truncated patches that np.stack could not combine.)
        axis_ranges = tuple(np.arange(p // 2, s - p + p // 2 + 1) for s, p in zip(datas[0].shape, patch_size))
        valid_inds = tuple(_s.ravel() for _s in np.meshgrid(*axis_ranges))

    n_valid = len(valid_inds[0])

    if n_valid == 0:
        raise ValueError("no regions to sample from!")

    # only sample with replacement when there are not enough distinct positions
    idx = choice(range(n_valid), n_samples, replace=(n_valid < n_samples))
    rand_inds = [v[idx] for v in valid_inds]

    res = [
        np.stack([
            data[tuple(slice(_r - (_p // 2), _r + _p - (_p // 2)) for _r, _p in zip(r, patch_size))]
            for r in zip(*rand_inds)
        ])
        for data in datas
    ]

    return res
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def get_valid_inds(img, patch_size, patch_filter=None):
    """Return all indices of ``img`` that can serve as patch centers.

    An index is returned if
    - a patch of ``patch_size`` centered there (spanning ``[r - p//2, r + p - p//2)``
      along each axis) lies completely inside ``img``, and
    - it is ``True`` in the boolean mask given by ``patch_filter`` (if provided).

    Parameters
    ----------
    img : np.ndarray
    patch_size : tuple of ints
        The width of patches per img dimension.
    patch_filter : None or callable
        A function with signature ``patch_filter(img, patch_size)`` returning a
        boolean mask (presumably of the same shape as ``img``).

    Returns
    -------
    tuple of np.ndarray
        One ``uint32`` coordinate array per axis of ``img``.

    Raises
    ------
    ValueError
        If ``patch_size`` has the wrong length, or any entry is not in
        ``(0, img.shape[i]]``.
    """
    if len(patch_size) != img.ndim:
        raise ValueError()

    if not all(0 < s <= d for s, d in zip(patch_size, img.shape)):
        raise ValueError("patch_size %s negative or larger than image shape %s along some dimensions" % (str(patch_size), str(img.shape)))

    # per-axis index range in which a centered patch stays inside the image
    border_slices = tuple(slice(p // 2, s - p + p // 2 + 1) for p, s in zip(patch_size, img.shape))

    if patch_filter is None:
        # only cut border indices (faster; no full-size mask is needed)
        axis_ranges = tuple(np.arange(sl.start, sl.stop).astype(np.uint32) for sl in border_slices)
        return tuple(g.ravel() for g in np.meshgrid(*axis_ranges, indexing='ij'))

    patch_mask = patch_filter(img, patch_size)

    # valid indices inside the border region, shifted back to full-image coordinates
    valid_inds = np.where(patch_mask[border_slices])
    return tuple((v + sl.start).astype(np.uint32) for sl, v in zip(border_slices, valid_inds))
|
|
File without changes
|
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
"""
|
|
2
|
+
|
|
3
|
+
Command line script to perform prediction in 2D
|
|
4
|
+
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
import os
|
|
9
|
+
import sys
|
|
10
|
+
import numpy as np
|
|
11
|
+
from tqdm import tqdm
|
|
12
|
+
import json
|
|
13
|
+
import argparse
|
|
14
|
+
import pprint
|
|
15
|
+
import pathlib
|
|
16
|
+
import warnings
|
|
17
|
+
|
|
18
|
+
def main():
    """Run a 2D StarDist model on one or more images from the command line.

    Each ``-i/--input`` file is read, percentile-normalized (``--pnorm``), and
    segmented with ``model.predict_instances``. The resulting label image is
    written as a zlib-compressed TIFF into ``--outdir``, named after the
    ``--outname`` template, in which ``{img}`` is replaced by the input file
    name without its suffix.

    Raises
    ------
    ValueError
        If the model is unknown, an image is neither 2D nor 3D, or its
        dimensionality does not match ``--axes``.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, description="""
Prediction script for a 2D stardist model, usage: stardist-predict -i input.tif -m model_folder_or_pretrained_name -o output_folder

""")
    parser.add_argument("-i","--input", type=str, nargs="+", required=True, help = "input file (tiff)")
    parser.add_argument("-o","--outdir", type=str, default='.', help = "output directory")
    # single template string: with nargs="+" a user-given value became a list and broke .format()
    parser.add_argument("--outname", type=str, default='{img}.stardist.tif', help = "output file name (tiff)")

    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument('-m', '--model', type=str, default=None, help = "model folder / pretrained model to use")
    parser.add_argument("--axes", type=str, default = None, help = "axes to use for the input, e.g. 'XYC'")
    parser.add_argument("--n_tiles", type=int, nargs=2, default = None, help = "number of tiles to use for prediction")
    parser.add_argument("--pnorm", type=float, nargs=2, default = [1,99.8], help = "pmin/pmax to use for normalization")
    parser.add_argument("--prob_thresh", type=float, default=None, help = "prob_thresh for model (if not given use model default)")
    parser.add_argument("--nms_thresh", type=float, default=None, help = "nms_thresh for model (if not given use model default)")

    parser.add_argument("-v", "--verbose", action='store_true')

    args = parser.parse_args()

    # heavy imports are deferred until after argument parsing so --help stays fast
    from csbdeep.utils import normalize
    from csbdeep.models.base_model import get_registered_models
    from stardist.models import StarDist2D
    from imageio import imread
    from tifffile import imwrite

    get_registered_models(StarDist2D, verbose=True)

    if pathlib.Path(args.model).is_dir():
        model = StarDist2D(None, name=args.model)
    else:
        model = StarDist2D.from_pretrained(args.model)

    if model is None:
        raise ValueError(f"unknown model: {args.model}\navailable models:\n {get_registered_models(StarDist2D, verbose=True)}")

    for fname in args.input:
        if args.verbose:
            print(f'reading image {fname}')

        img = imread(fname)

        if img.ndim not in (2, 3):
            raise ValueError('currently only 2d and 3d images are supported by the prediction script')

        # resolve axes per image, so a default chosen for one input does not leak into the next
        axes = args.axes if args.axes is not None else {2: 'YX', 3: 'YXC'}[img.ndim]

        if len(axes) != img.ndim:
            raise ValueError(f'dimension of input ({img.ndim}) not the same as length of given axes ({len(axes)})')

        if args.verbose:
            print(f'loaded image of size {img.shape}')
            print('normalizing...')

        img = normalize(img, *args.pnorm)

        n_tiles = args.n_tiles
        if n_tiles is not None and 'C' in axes.upper():
            # --n_tiles only covers the spatial axes; the channel axis is never tiled
            n_tiles = list(n_tiles)
            n_tiles.insert(axes.upper().index('C'), 1)

        labels, _ = model.predict_instances(img,
                                            axes=axes,
                                            n_tiles=n_tiles,
                                            prob_thresh=args.prob_thresh,
                                            nms_thresh=args.nms_thresh)
        out = pathlib.Path(args.outdir)
        out.mkdir(parents=True, exist_ok=True)

        imwrite(out / args.outname.format(img=pathlib.Path(fname).with_suffix('').name), labels, compression='zlib')


if __name__ == '__main__':
    main()
|
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
"""
|
|
2
|
+
|
|
3
|
+
Command line script to perform prediction in 3D
|
|
4
|
+
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
import os
|
|
9
|
+
import sys
|
|
10
|
+
import numpy as np
|
|
11
|
+
from tqdm import tqdm
|
|
12
|
+
import json
|
|
13
|
+
import argparse
|
|
14
|
+
import pprint
|
|
15
|
+
import pathlib
|
|
16
|
+
import warnings
|
|
17
|
+
|
|
18
|
+
def main():
    """Run a 3D StarDist model on one or more TIFF stacks from the command line.

    Each ``-i/--input`` file must be a ``.tif``/``.tiff`` file. It is read,
    percentile-normalized (``--pnorm``), and segmented with
    ``model.predict_instances``. The resulting label volume is written as a
    zlib-compressed TIFF into ``--outdir``, named after the ``--outname``
    template, in which ``{img}`` is replaced by the input file name without
    its suffix.

    Raises
    ------
    ValueError
        If the model is unknown, an input is not a TIFF file, an image is
        neither 3D nor 4D, or its dimensionality does not match ``--axes``.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, description="""
Prediction script for a 3D stardist model, usage: stardist-predict -i input.tif -m model_folder_or_pretrained_name -o output_folder

""")
    parser.add_argument("-i","--input", type=str, nargs="+", required=True, help = "input file (tiff)")
    parser.add_argument("-o","--outdir", type=str, default='.', help = "output directory")
    # single template string: with nargs="+" a user-given value became a list and broke .format()
    parser.add_argument("--outname", type=str, default='{img}.stardist.tif', help = "output file name (tiff)")

    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument('-m', '--model', type=str, default=None, help = "model folder / pretrained model to use")
    parser.add_argument("--axes", type=str, default = None, help = "axes to use for the input, e.g. 'XYC'")
    parser.add_argument("--n_tiles", type=int, nargs=3, default = None, help = "number of tiles to use for prediction")
    parser.add_argument("--pnorm", type=float, nargs=2, default = [1,99.8], help = "pmin/pmax to use for normalization")
    parser.add_argument("--prob_thresh", type=float, default=None, help = "prob_thresh for model (if not given use model default)")
    parser.add_argument("--nms_thresh", type=float, default=None, help = "nms_thresh for model (if not given use model default)")

    parser.add_argument("-v", "--verbose", action='store_true')

    args = parser.parse_args()

    # heavy imports are deferred until after argument parsing so --help stays fast
    from csbdeep.utils import normalize
    from csbdeep.models.base_model import get_registered_models
    from stardist.models import StarDist3D
    from tifffile import imwrite, imread

    get_registered_models(StarDist3D, verbose=True)

    if pathlib.Path(args.model).is_dir():
        model = StarDist3D(None, name=args.model)
    else:
        model = StarDist3D.from_pretrained(args.model)

    if model is None:
        # list the registered 3D models (StarDist2D is not imported in this script)
        raise ValueError(f"unknown model: {args.model}\navailable models:\n {get_registered_models(StarDist3D, verbose=True)}")

    for fname in args.input:
        if args.verbose:
            print(f'reading image {fname}')

        if pathlib.Path(fname).suffix.lower() not in (".tif", ".tiff"):
            raise ValueError('only tiff files supported in 3D for now')

        img = imread(fname)

        if img.ndim not in (3, 4):
            raise ValueError('currently only 3d (or 4D with channel) images are supported by the prediction script')

        # resolve axes per image, so a default chosen for one input does not leak into the next
        axes = args.axes if args.axes is not None else {3: 'ZYX', 4: 'ZYXC'}[img.ndim]

        if len(axes) != img.ndim:
            raise ValueError(f'dimension of input ({img.ndim}) not the same as length of given axes ({len(axes)})')

        if args.verbose:
            print(f'loaded image of size {img.shape}')
            print('normalizing...')

        img = normalize(img, *args.pnorm)

        n_tiles = args.n_tiles
        if n_tiles is not None and 'C' in axes.upper():
            # --n_tiles only covers the spatial axes; the channel axis is never tiled
            n_tiles = list(n_tiles)
            n_tiles.insert(axes.upper().index('C'), 1)

        labels, _ = model.predict_instances(img,
                                            axes=axes,
                                            n_tiles=n_tiles,
                                            prob_thresh=args.prob_thresh,
                                            nms_thresh=args.nms_thresh)
        out = pathlib.Path(args.outdir)
        out.mkdir(parents=True, exist_ok=True)

        imwrite(out / args.outname.format(img=pathlib.Path(fname).with_suffix('').name), labels, compression='zlib')


if __name__ == '__main__':
    main()
|
|
@@ -0,0 +1,408 @@
|
|
|
1
|
+
from __future__ import print_function, unicode_literals, absolute_import, division
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import warnings
|
|
5
|
+
import os
|
|
6
|
+
import datetime
|
|
7
|
+
from tqdm import tqdm
|
|
8
|
+
from collections import defaultdict
|
|
9
|
+
from zipfile import ZipFile, ZIP_DEFLATED
|
|
10
|
+
from scipy.ndimage import distance_transform_edt, binary_fill_holes
|
|
11
|
+
from scipy.ndimage import find_objects
|
|
12
|
+
from scipy.optimize import minimize_scalar
|
|
13
|
+
from skimage.measure import regionprops
|
|
14
|
+
from csbdeep.utils import _raise
|
|
15
|
+
from csbdeep.utils.six import Path
|
|
16
|
+
from collections.abc import Iterable
|
|
17
|
+
|
|
18
|
+
from .matching import matching_dataset, _check_label_array
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
# Optional fast multi-label EDT backend. Configures the thread count from
# STARDIST_EDT_NUM_THREADS (default 4), capped at the number of CPUs this
# process may run on (or 128 where that cannot be queried).
try:
    from edt import edt
    _edt_available = True
    try:
        _edt_parallel_max = len(os.sched_getaffinity(0))
    except (AttributeError, OSError):
        # os.sched_getaffinity does not exist on every platform (e.g. macOS, Windows)
        _edt_parallel_max = 128
    _edt_parallel_default = 4
    _edt_parallel = os.environ.get('STARDIST_EDT_NUM_THREADS', _edt_parallel_default)
    try:
        _edt_parallel = min(_edt_parallel_max, int(_edt_parallel))
    except ValueError:
        warnings.warn(f"Invalid value ({_edt_parallel}) for STARDIST_EDT_NUM_THREADS. Using default value ({_edt_parallel_default}) instead.")
        _edt_parallel = _edt_parallel_default
    del _edt_parallel_default, _edt_parallel_max
except ImportError:
    # 'edt' is optional; without it the scipy-based EDT implementation is used
    _edt_available = False
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def gputools_available():
    """Return True if the optional ``gputools`` package can be imported, else False.

    This is a best-effort probe: any ordinary exception raised while importing
    (not only ImportError) counts as "not available". ``KeyboardInterrupt`` and
    ``SystemExit`` are not swallowed.
    """
    try:
        import gputools  # noqa: F401
    except Exception:
        return False
    return True
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def path_absolute(path_relative):
    """Join ``path_relative`` onto the absolute directory that contains this module.

    The result is returned as a string; the target is not checked for existence.
    """
    module_dir = os.path.dirname(__file__)
    return os.path.join(os.path.abspath(module_dir), path_relative)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def _is_power_of_2(i):
|
|
55
|
+
assert i > 0
|
|
56
|
+
e = np.log2(i)
|
|
57
|
+
return e == int(e)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def _normalize_grid(grid,n):
    """Validate a subsampling grid and return it as a tuple of ``n`` ints.

    ``grid`` must be an iterable of exactly ``n`` scalars, each accepted by
    ``_is_power_of_2``.

    Raises
    ------
    ValueError
        If ``grid`` is not iterable, has the wrong length, or contains a
        non-scalar or non-power-of-two entry.
    """
    try:
        grid = tuple(grid)
        is_valid = (len(grid) == n
                    and all(np.isscalar(g) for g in grid)
                    and all(_is_power_of_2(g) for g in grid))
        if not is_valid:
            raise TypeError()
        return tuple(int(g) for g in grid)
    except (TypeError, AssertionError):
        raise ValueError("grid = {grid} must be a list/tuple of length {n} with values that are power of 2".format(grid=grid, n=n))
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def edt_prob(lbl_img, anisotropy=None):
    """Per-object normalized Euclidean distance transform of a label image.

    Dispatches to the multi-label ``edt`` package when it could be imported,
    and to the per-object scipy implementation otherwise.
    """
    implementation = _edt_prob_edt if _edt_available else _edt_prob_scipy
    return implementation(lbl_img, anisotropy=anisotropy)
|
|
77
|
+
|
|
78
|
+
def _edt_prob_edt(lbl_img, anisotropy=None):
    """Perform EDT on each labeled object and normalize it per object.

    Uses https://github.com/seung-lab/euclidean-distance-transform-3d, which
    handles all labels in a single pass. Each object's distances are then
    divided by that object's maximum (plus 1e-10).

    If the image holds one constant positive label, a warning is issued and the
    image border is treated as background (``black_border=True``).
    """
    lbl_img = np.ascontiguousarray(lbl_img)
    is_constant = lbl_img.min() == lbl_img.max() and lbl_img.flat[0] > 0
    if is_constant:
        warnings.warn("EDT of constant label image is ill-defined. (Assuming background around it.)")
    # one EDT for all labels; the per-object rescaling happens below
    dist = edt(lbl_img, anisotropy=anisotropy, black_border=is_constant, parallel=_edt_parallel)
    for label_id, bbox in enumerate(find_objects(lbl_img), start=1):
        if bbox is None:
            # label id not present in the image
            continue
        region = dist[bbox]  # view into dist, so in-place updates write through
        inside = lbl_img[bbox] == label_id
        region[inside] /= np.max(region[inside] + 1e-10)
    return dist
|
|
97
|
+
|
|
98
|
+
def _edt_prob_scipy(lbl_img, anisotropy=None):
|
|
99
|
+
"""Perform EDT on each labeled object and normalize."""
|
|
100
|
+
def grow(sl,interior):
|
|
101
|
+
return tuple(slice(s.start-int(w[0]),s.stop+int(w[1])) for s,w in zip(sl,interior))
|
|
102
|
+
def shrink(interior):
|
|
103
|
+
return tuple(slice(int(w[0]),(-1 if w[1] else None)) for w in interior)
|
|
104
|
+
constant_img = lbl_img.min() == lbl_img.max() and lbl_img.flat[0] > 0
|
|
105
|
+
if constant_img:
|
|
106
|
+
lbl_img = np.pad(lbl_img, ((1,1),)*lbl_img.ndim, mode='constant')
|
|
107
|
+
warnings.warn("EDT of constant label image is ill-defined. (Assuming background around it.)")
|
|
108
|
+
objects = find_objects(lbl_img)
|
|
109
|
+
prob = np.zeros(lbl_img.shape,np.float32)
|
|
110
|
+
for i,sl in enumerate(objects,1):
|
|
111
|
+
# i: object label id, sl: slices of object in lbl_img
|
|
112
|
+
if sl is None: continue
|
|
113
|
+
interior = [(s.start>0,s.stop<sz) for s,sz in zip(sl,lbl_img.shape)]
|
|
114
|
+
# 1. grow object slice by 1 for all interior object bounding boxes
|
|
115
|
+
# 2. perform (correct) EDT for object with label id i
|
|
116
|
+
# 3. extract EDT for object of original slice and normalize
|
|
117
|
+
# 4. store edt for object only for pixels of given label id i
|
|
118
|
+
shrink_slice = shrink(interior)
|
|
119
|
+
grown_mask = lbl_img[grow(sl,interior)]==i
|
|
120
|
+
mask = grown_mask[shrink_slice]
|
|
121
|
+
edt = distance_transform_edt(grown_mask, sampling=anisotropy)[shrink_slice][mask]
|
|
122
|
+
prob[sl][mask] = edt/(np.max(edt)+1e-10)
|
|
123
|
+
if constant_img:
|
|
124
|
+
prob = prob[(slice(1,-1),)*lbl_img.ndim].copy()
|
|
125
|
+
return prob
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def _fill_label_holes(lbl_img, **kwargs):
|
|
129
|
+
lbl_img_filled = np.zeros_like(lbl_img)
|
|
130
|
+
for l in (set(np.unique(lbl_img)) - set([0])):
|
|
131
|
+
mask = lbl_img==l
|
|
132
|
+
mask_filled = binary_fill_holes(mask,**kwargs)
|
|
133
|
+
lbl_img_filled[mask_filled] = l
|
|
134
|
+
return lbl_img_filled
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def fill_label_holes(lbl_img, **kwargs):
    """Fill small holes in label image.

    Each positive label is processed only within its bounding box (grown by one
    pixel on every side that does not touch the image border), which is much
    cheaper than filling over the whole image. Negative labels, which
    ``find_objects`` ignores, are handled by recursing on the negated negative
    part. Keyword arguments (e.g. ``structure``) are forwarded to
    ``binary_fill_holes`` for both positive and negative labels.
    """
    # TODO: refactor 'fill_label_holes' and 'edt_prob' to share code
    def grow(sl, interior):
        # extend each slice by one on its interior sides
        return tuple(slice(s.start - int(w[0]), s.stop + int(w[1])) for s, w in zip(sl, interior))

    def shrink(interior):
        # slices that crop a grown region back to the original bounding box
        return tuple(slice(int(w[0]), (-1 if w[1] else None)) for w in interior)

    objects = find_objects(lbl_img)
    lbl_img_filled = np.zeros_like(lbl_img)
    for i, sl in enumerate(objects, 1):
        if sl is None:
            continue
        interior = [(s.start > 0, s.stop < sz) for s, sz in zip(sl, lbl_img.shape)]
        shrink_slice = shrink(interior)
        grown_mask = lbl_img[grow(sl, interior)] == i
        mask_filled = binary_fill_holes(grown_mask, **kwargs)[shrink_slice]
        lbl_img_filled[sl][mask_filled] = i
    if lbl_img.min() < 0:
        # preserve (and fill holes in) negative labels ('find_objects' ignores these),
        # using the same fill options as for the positive labels
        lbl_neg_filled = -fill_label_holes(-np.minimum(lbl_img, 0), **kwargs)
        mask = lbl_neg_filled < 0
        lbl_img_filled[mask] = lbl_neg_filled[mask]
    return lbl_img_filled
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
def sample_points(n_samples, mask, prob=None, b=2):
    """Sample point coordinates (e.g. to draw some of the associated polygons).

    Draws ``n_samples`` positions with replacement from the True pixels of
    ``mask``, using the global NumPy random state. If ``b`` is a positive int,
    the outer ``b`` rows/columns are excluded. If ``prob`` is given, positions
    are weighted by ``prob`` at those pixels. Only the first two axes are used.

    Returns
    -------
    np.ndarray
        Integer array of shape ``(n_samples, 2)`` with (row, column) coordinates.
    """
    if b is None or not b > 0:
        border = True
    else:
        # ignore image boundary, since predictions may not be reliable
        border = np.zeros_like(mask)
        border[b:-b, b:-b] = True

    coords = np.nonzero(mask & border)
    rows, cols = coords[0], coords[1]
    n_candidates = len(rows)

    if prob is None:
        chosen = np.random.choice(n_candidates, n_samples, replace=True)
    else:
        # weighted sampling via prob
        weights = prob[rows, cols].astype(np.float64)
        weights /= np.sum(weights)
        chosen = np.random.choice(n_candidates, n_samples, replace=True, p=weights)

    return np.stack((rows[chosen], cols[chosen]), axis=-1)
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
def calculate_extents(lbl, func=np.median):
    """Aggregate bounding box sizes of objects in label images.

    ``lbl`` is a single 2D/3D label image, or a 4D array / non-array iterable
    of such images. For a single image, the per-axis bounding-box extents of
    all regions found by ``regionprops`` are reduced with ``func(..., axis=0)``.
    An image without regions yields zeros. For a collection, the per-image
    results are stacked and reduced again with ``func``.

    Raises
    ------
    ValueError
        If a single label image is not 2- or 3-dimensional.
    """
    if isinstance(lbl, np.ndarray):
        is_collection = lbl.ndim == 4
    else:
        is_collection = isinstance(lbl, Iterable)

    if is_collection:
        per_image = [calculate_extents(item, func) for item in lbl]
        return func(np.stack(per_image, axis=0), axis=0)

    n = lbl.ndim
    if n not in (2, 3):
        raise ValueError("label image should be 2- or 3-dimensional (or pass a list of these)")

    regions = regionprops(lbl)
    if not regions:
        return np.zeros(n)
    # bbox is (min_0, ..., min_{n-1}, max_0, ..., max_{n-1})
    extents = np.array([np.array(r.bbox[n:]) - np.array(r.bbox[:n]) for r in regions])
    return func(extents, axis=0)
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
def polyroi_bytearray(x, y, pos=None, subpixel=True):
    """Encode a polygon as an ImageJ ROI (``.roi``) byte array.

    See https://github.com/imagej/imagej1/blob/master/ij/io/RoiDecoder.java
    for the binary layout: a 64-byte header, then the integer x and y
    coordinates (int16, relative to the bounding box), then optionally the
    absolute float32 coordinates.

    Parameters
    ----------
    x, y : array_like
        Vertex coordinates (x horizontal, y vertical, 0-based pixel units).
        Both are flattened and must have the same, non-zero length.
    pos : int or None
        If given, stored as the ROI position (C, Z, or T) in the header.
    subpixel : bool
        If true, also store float coordinates and set the subpixel option flag.

    Returns
    -------
    bytearray
        The encoded ROI.

    Raises
    ------
    ValueError
        If ``x`` and ``y`` differ in length or are empty.
    """
    import struct

    def _int16(v):
        return int(v).to_bytes(2, byteorder='big', signed=True)

    def _uint16(v):
        return int(v).to_bytes(2, byteorder='big', signed=False)

    def _int32(v):
        return int(v).to_bytes(4, byteorder='big', signed=True)

    def _float(v):
        return struct.pack(">f", v)

    subpixel = bool(subpixel)
    # add offset since pixel center is at (0.5,0.5) in ImageJ
    x_raw = np.asarray(x).ravel() + 0.5
    y_raw = np.asarray(y).ravel() + 0.5
    # Explicit checks rather than ``assert`` (stripped under ``python -O``,
    # which would let ``zip`` below silently truncate mismatched inputs).
    if len(x_raw) != len(y_raw):
        raise ValueError(f"x and y must have the same length (got {len(x_raw)} and {len(y_raw)})")
    if len(x_raw) == 0:
        raise ValueError("polygon must have at least one vertex")
    x = np.round(x_raw)
    y = np.round(y_raw)
    top, left, bottom, right = y.min(), x.min(), y.max(), x.max()  # bbox

    n_coords = len(x)
    bytes_header = 64
    bytes_total = bytes_header + n_coords*2*2 + subpixel*n_coords*2*4
    B = bytearray(bytes_total)
    B[ 0: 4] = b'Iout'                 # magic start
    B[ 4: 6] = _int16(227)             # version
    B[ 6: 8] = _int16(0)               # roi type (0 = polygon)
    B[ 8:10] = _int16(top)             # bbox top
    B[10:12] = _int16(left)            # bbox left
    B[12:14] = _int16(bottom)          # bbox bottom
    B[14:16] = _int16(right)           # bbox right
    B[16:18] = _uint16(n_coords)       # number of coordinates
    if subpixel:
        B[50:52] = _int16(128)         # subpixel resolution (option flag)
    if pos is not None:
        B[56:60] = _int32(pos)         # position (C, Z, or T)

    # integer coordinates, relative to the bbox: all x values, then all y values
    for i, (_x, _y) in enumerate(zip(x, y)):
        xs = bytes_header + 2*i
        ys = xs + 2*n_coords
        B[xs:xs+2] = _int16(_x - left)
        B[ys:ys+2] = _int16(_y - top)

    if subpixel:
        # absolute float coordinates: all x values, then all y values
        base1 = bytes_header + n_coords*2*2
        base2 = base1 + n_coords*4
        for i, (_x, _y) in enumerate(zip(x_raw, y_raw)):
            xs = base1 + 4*i
            ys = base2 + 4*i
            B[xs:xs+4] = _float(_x)
            B[ys:ys+4] = _float(_y)

    return B
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
def export_imagej_rois(fname, polygons, set_position=True, subpixel=True, compression=ZIP_DEFLATED):
    """Write polygons as ImageJ ROIs into a ``.zip`` archive.

    Parameters
    ----------
    fname : str or path-like
        Output path. ``.zip`` is appended; an existing ``.zip`` suffix is not
        duplicated.
    polygons : ndarray or sequence of ndarray
        A single array, or a sequence of arrays (groups). Each group is assumed
        to have shape (id, 2, c). For every polygon ``poly`` in a group,
        ``poly[0]`` is used as the y and ``poly[1]`` as the x coordinates.
    set_position : bool
        If true, each ROI's position is the 1-based index of its group.
    subpixel : bool
        Forwarded to ``polyroi_bytearray``.
    compression : int
        ``zipfile`` compression method for the archive.

    Each ROI is stored as ``{group:03d}_{index:03d}.roi`` (both 1-based).
    """
    if isinstance(polygons, np.ndarray):
        polygons = (polygons,)

    path = Path(fname)
    base = path.with_suffix('') if path.suffix == '.zip' else path
    archive_name = str(base) + '.zip'

    with ZipFile(archive_name, mode='w', compression=compression) as roizip:
        for group_idx, group in enumerate(polygons, start=1):
            roi_pos = group_idx if set_position else None
            for poly_idx, poly in enumerate(group, start=1):
                payload = polyroi_bytearray(poly[1], poly[0], pos=roi_pos, subpixel=subpixel)
                roizip.writestr(f'{group_idx:03d}_{poly_idx:03d}.roi', payload)
|
|
274
|
+
|
|
275
|
+
|
|
276
|
+
def optimize_threshold(Y, Yhat, model, nms_thresh, measure='accuracy', iou_threshs=(0.3,0.5,0.7), bracket=None, tol=1e-2, maxiter=20, verbose=1):
    """Tune prob_thresh for provided (fixed) nms_thresh to maximize matching score (for given measure and averaged over iou_threshs).

    Parameters
    ----------
    Y : sequence of ndarray
        Ground-truth label images.
    Yhat : sequence of (prob, dist) pairs
        Predictions, one per image in ``Y``.
    model : object
        Must provide ``_instances_from_prediction(shape, prob, dist,
        prob_thresh=..., nms_thresh=...)``; element ``[0]`` of its result is
        used as the predicted label image.
    nms_thresh : scalar
        Fixed NMS threshold.
    measure : str
        Field of the matching statistics to maximize.
    iou_threshs : scalar or sequence of float
        IoU thresholds; the score is averaged over them.
    bracket : (float, float) or None
        Search interval for prob_thresh. Evaluated thresholds (and the returned
        one) are clipped to it. Defaults to ``(max_prob/2, max_prob)``, where
        ``max_prob`` is the largest predicted probability in ``Yhat``.
    tol, maxiter :
        Passed to scipy's golden-section search.
    verbose : int
        1: progress bar; >1: print every evaluation and the final result.

    Returns
    -------
    (prob_thresh, score)
        Best probability threshold (within ``bracket``) and its averaged score.

    Raises
    ------
    ValueError
        If ``nms_thresh`` is not a scalar.
    """
    np.isscalar(nms_thresh) or _raise(ValueError("nms_thresh must be a scalar"))
    # default is an immutable tuple to avoid a shared mutable default argument
    iou_threshs = [iou_threshs] if np.isscalar(iou_threshs) else iou_threshs
    values = {}  # cache: clipped prob_thresh -> score

    if bracket is None:
        max_prob = max([np.max(prob) for prob, dist in Yhat])
        bracket = max_prob/2, max_prob

    with tqdm(total=maxiter, disable=(verbose!=1), desc="NMS threshold = %g" % nms_thresh) as progress:

        def fn(thr):
            # the optimizer may probe outside the bracket; evaluate at the clipped value
            prob_thresh = np.clip(thr, *bracket)
            value = values.get(prob_thresh)
            if value is None:
                Y_instances = [model._instances_from_prediction(y.shape, *prob_dist, prob_thresh=prob_thresh, nms_thresh=nms_thresh)[0] for y,prob_dist in zip(Y,Yhat)]
                stats = matching_dataset(Y, Y_instances, thresh=iou_threshs, show_progress=False, parallel=True)
                values[prob_thresh] = value = np.mean([s._asdict()[measure] for s in stats])
            if verbose > 1:
                print("{now} thresh: {prob_thresh:f} {measure}: {value:f}".format(
                    now = datetime.datetime.now().strftime('%H:%M:%S'),
                    prob_thresh = prob_thresh,
                    measure = measure,
                    value = value,
                ), flush=True)
            else:
                progress.update()
                progress.set_postfix_str("{prob_thresh:.3f} -> {value:.3f}".format(prob_thresh=prob_thresh, value=value))
                progress.refresh()
            # minimize_scalar minimizes, so negate the score
            return -value

        opt = minimize_scalar(fn, method='golden', bracket=bracket, tol=tol, options={'maxiter': maxiter})

    if verbose > 1:
        print('\n',opt, flush=True)
    # opt.fun was computed at the clipped threshold, so return that threshold
    # (the raw opt.x may lie outside the bracket)
    return np.clip(opt.x, *bracket), -opt.fun
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
def _invert_dict(d):
|
|
316
|
+
""" return v-> [k_1,k_2,k_3....] for k,v in d"""
|
|
317
|
+
res = defaultdict(list)
|
|
318
|
+
for k,v in d.items():
|
|
319
|
+
res[v].append(k)
|
|
320
|
+
return res
|
|
321
|
+
|
|
322
|
+
|
|
323
|
+
def mask_to_categorical(y, n_classes, classes, return_cls_dict=False):
    """Build a multi-channel categorical class map from a label image.

    Parameters
    ----------
    y : n-dimensional ndarray
        Integer label array (0 is background).
    n_classes : int
        Number of object classes (without background); must be >= 1.
    classes : dict, integer, or None
        The label-to-class assignment:
          - dict {label -> class_id}, where class_id is
                0              -> background class
                1...n_classes  -> the respective object class
                None           -> ignore object (prob is set to -1 for the
                                  object's pixels, except in the background channel)
          - a single integer or None -> applied to every label in ``y``
    return_cls_dict : bool
        If true, also return the inverted mapping {class_id -> [labels]}.

    Returns
    -------
    ndarray of float32, shape y.shape + (n_classes+1,)
        Probability map; channel 0 is the background. With
        ``return_cls_dict``, a tuple ``(prob_map, cls_dict)``.

    Raises
    ------
    ValueError
        For an invalid ``n_classes``, an unsupported ``classes`` type, labels
        in ``y`` missing from ``classes``, or an out-of-range class id.
    """
    _check_label_array(y, 'y')
    valid_n = np.issubdtype(type(n_classes), np.integer) and n_classes >= 1
    if not valid_n:
        raise ValueError(f"n_classes is '{n_classes}' but should be a positive integer")

    y_labels = np.unique(y[y > 0]).tolist()

    # normalise ``classes`` to an explicit {label -> class_id} dict
    if np.issubdtype(type(classes), np.integer) or classes is None:
        classes = {label: classes for label in y_labels}
    elif not isinstance(classes, dict):
        raise ValueError("classes should be dict, single scalar, or None!")

    unassigned = set(y_labels) - set(classes.keys())
    if unassigned:
        raise ValueError(f"all gt labels should be present in class dict provided \ngt_labels found\n{set(y_labels)}\nclass dict labels provided\n{set(classes.keys())}")

    # class_id -> [labels]
    cls_dict = _invert_dict(classes)

    n_channels = n_classes + 1
    y_mask = np.zeros(y.shape + (n_channels,), np.float32)

    for cls_id, members in cls_dict.items():
        if cls_id is None:
            # prob == -1 will be used in the loss to ignore object
            y_mask[np.isin(y, members), :] = -1
            continue
        if not (np.issubdtype(type(cls_id), np.integer) and 0 <= cls_id <= n_classes):
            raise ValueError(f"Wrong class id '{cls_id}' (for n_classes={n_classes})")
        y_mask[np.isin(y, members), cls_id] = 1

    # set 0/1 background prob (unaffected by None values for class ids).
    # NOTE(review): this also overwrites the 1s written above for objects
    # mapped to class 0, leaving those pixels all-zero -- confirm intended.
    y_mask[..., 0] = (y == 0)

    return (y_mask, cls_dict) if return_cls_dict else y_mask
|
|
386
|
+
|
|
387
|
+
|
|
388
|
+
def _is_floatarray(x):
|
|
389
|
+
return isinstance(x.dtype.type(0),np.floating)
|
|
390
|
+
|
|
391
|
+
|
|
392
|
+
def abspath(root, relpath):
    """Resolve ``relpath`` against ``root`` and return an absolute path string.

    If ``root`` is an existing directory, ``relpath`` is interpreted relative
    to it. Otherwise (e.g. ``root`` is a file path) it is interpreted relative
    to ``root``'s parent directory. The result is made absolute but not
    normalised.
    """
    from pathlib import Path
    base = Path(root)
    anchor = base if base.is_dir() else base.parent
    return str((anchor / relpath).absolute())
|
|
400
|
+
|
|
401
|
+
|
|
402
|
+
def grid_divisible_patch_size(patch_size, grid, warn=True):
    """Round every patch dimension up to the next multiple of the grid.

    Parameters
    ----------
    patch_size : sequence of int
        Patch size per axis.
    grid : sequence of int
        Grid (subsampling) factor per axis; same length as ``patch_size``.
    warn : bool
        Emit a warning if any dimension had to be increased.

    Returns
    -------
    tuple of int
        Patch size that is evenly divisible by ``grid`` along every axis.

    Raises
    ------
    ValueError
        If ``patch_size`` and ``grid`` have different lengths.
    """
    patch_size, grid = tuple(patch_size), tuple(grid)
    # explicit check: ``assert`` is stripped under ``python -O``, after which
    # ``zip`` would silently truncate to the shorter sequence
    if len(patch_size) != len(grid):
        raise ValueError(f"patch_size {patch_size} and grid {grid} must have the same length")
    # exact integer ceiling division (no float rounding for large sizes)
    patch_size_divisible = tuple(int(-(-sh // g) * g) for sh, g in zip(patch_size, grid))
    if patch_size != patch_size_divisible and warn:
        warnings.warn(f"increasing patch_size from {patch_size} to {patch_size_divisible}, since it was not evenly divisible by grid {grid}")
    return patch_size_divisible
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# Module version string.
__version__ = '0.9.2'
|