senoquant 1.0.0b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148) hide show
  1. senoquant/__init__.py +6 -0
  2. senoquant/_reader.py +7 -0
  3. senoquant/_widget.py +33 -0
  4. senoquant/napari.yaml +83 -0
  5. senoquant/reader/__init__.py +5 -0
  6. senoquant/reader/core.py +369 -0
  7. senoquant/tabs/__init__.py +15 -0
  8. senoquant/tabs/batch/__init__.py +10 -0
  9. senoquant/tabs/batch/backend.py +641 -0
  10. senoquant/tabs/batch/config.py +270 -0
  11. senoquant/tabs/batch/frontend.py +1283 -0
  12. senoquant/tabs/batch/io.py +326 -0
  13. senoquant/tabs/batch/layers.py +86 -0
  14. senoquant/tabs/quantification/__init__.py +1 -0
  15. senoquant/tabs/quantification/backend.py +228 -0
  16. senoquant/tabs/quantification/features/__init__.py +80 -0
  17. senoquant/tabs/quantification/features/base.py +142 -0
  18. senoquant/tabs/quantification/features/marker/__init__.py +5 -0
  19. senoquant/tabs/quantification/features/marker/config.py +69 -0
  20. senoquant/tabs/quantification/features/marker/dialog.py +437 -0
  21. senoquant/tabs/quantification/features/marker/export.py +879 -0
  22. senoquant/tabs/quantification/features/marker/feature.py +119 -0
  23. senoquant/tabs/quantification/features/marker/morphology.py +285 -0
  24. senoquant/tabs/quantification/features/marker/rows.py +654 -0
  25. senoquant/tabs/quantification/features/marker/thresholding.py +46 -0
  26. senoquant/tabs/quantification/features/roi.py +346 -0
  27. senoquant/tabs/quantification/features/spots/__init__.py +5 -0
  28. senoquant/tabs/quantification/features/spots/config.py +62 -0
  29. senoquant/tabs/quantification/features/spots/dialog.py +477 -0
  30. senoquant/tabs/quantification/features/spots/export.py +1292 -0
  31. senoquant/tabs/quantification/features/spots/feature.py +112 -0
  32. senoquant/tabs/quantification/features/spots/morphology.py +279 -0
  33. senoquant/tabs/quantification/features/spots/rows.py +241 -0
  34. senoquant/tabs/quantification/frontend.py +815 -0
  35. senoquant/tabs/segmentation/__init__.py +1 -0
  36. senoquant/tabs/segmentation/backend.py +131 -0
  37. senoquant/tabs/segmentation/frontend.py +1009 -0
  38. senoquant/tabs/segmentation/models/__init__.py +5 -0
  39. senoquant/tabs/segmentation/models/base.py +146 -0
  40. senoquant/tabs/segmentation/models/cpsam/details.json +65 -0
  41. senoquant/tabs/segmentation/models/cpsam/model.py +150 -0
  42. senoquant/tabs/segmentation/models/default_2d/details.json +69 -0
  43. senoquant/tabs/segmentation/models/default_2d/model.py +664 -0
  44. senoquant/tabs/segmentation/models/default_3d/details.json +69 -0
  45. senoquant/tabs/segmentation/models/default_3d/model.py +682 -0
  46. senoquant/tabs/segmentation/models/hf.py +71 -0
  47. senoquant/tabs/segmentation/models/nuclear_dilation/__init__.py +1 -0
  48. senoquant/tabs/segmentation/models/nuclear_dilation/details.json +26 -0
  49. senoquant/tabs/segmentation/models/nuclear_dilation/model.py +96 -0
  50. senoquant/tabs/segmentation/models/perinuclear_rings/__init__.py +1 -0
  51. senoquant/tabs/segmentation/models/perinuclear_rings/details.json +34 -0
  52. senoquant/tabs/segmentation/models/perinuclear_rings/model.py +132 -0
  53. senoquant/tabs/segmentation/stardist_onnx_utils/__init__.py +2 -0
  54. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/__init__.py +3 -0
  55. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/__init__.py +6 -0
  56. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/generate.py +470 -0
  57. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/prepare.py +273 -0
  58. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/rawdata.py +112 -0
  59. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/transform.py +384 -0
  60. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/__init__.py +0 -0
  61. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/blocks.py +184 -0
  62. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/losses.py +79 -0
  63. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/nets.py +165 -0
  64. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/predict.py +467 -0
  65. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/probability.py +67 -0
  66. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/train.py +148 -0
  67. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/io/__init__.py +163 -0
  68. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/__init__.py +52 -0
  69. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/base_model.py +329 -0
  70. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_isotropic.py +160 -0
  71. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_projection.py +178 -0
  72. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_standard.py +446 -0
  73. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_upsampling.py +54 -0
  74. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/config.py +254 -0
  75. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/pretrained.py +119 -0
  76. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/scripts/__init__.py +0 -0
  77. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/scripts/care_predict.py +180 -0
  78. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/__init__.py +5 -0
  79. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/plot_utils.py +159 -0
  80. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/six.py +18 -0
  81. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/tf.py +644 -0
  82. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/utils.py +272 -0
  83. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/version.py +1 -0
  84. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/docs/source/conf.py +368 -0
  85. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/setup.py +68 -0
  86. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_datagen.py +169 -0
  87. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_models.py +462 -0
  88. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_utils.py +166 -0
  89. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tools/create_zip_contents.py +34 -0
  90. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/__init__.py +30 -0
  91. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/big.py +624 -0
  92. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/bioimageio_utils.py +494 -0
  93. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/data/__init__.py +39 -0
  94. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/__init__.py +10 -0
  95. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/geom2d.py +215 -0
  96. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/geom3d.py +349 -0
  97. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/matching.py +483 -0
  98. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/__init__.py +28 -0
  99. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/base.py +1217 -0
  100. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/model2d.py +594 -0
  101. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/model3d.py +696 -0
  102. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/nms.py +384 -0
  103. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/__init__.py +2 -0
  104. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/plot.py +74 -0
  105. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/render.py +298 -0
  106. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/rays3d.py +373 -0
  107. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/sample_patches.py +65 -0
  108. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/__init__.py +0 -0
  109. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/predict2d.py +90 -0
  110. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/predict3d.py +93 -0
  111. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/utils.py +408 -0
  112. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/version.py +1 -0
  113. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/__init__.py +45 -0
  114. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/__init__.py +17 -0
  115. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/cli.py +55 -0
  116. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/core.py +285 -0
  117. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/__init__.py +15 -0
  118. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/cli.py +36 -0
  119. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/divisibility.py +193 -0
  120. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/probe.py +100 -0
  121. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/receptive_field.py +182 -0
  122. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/rf_cli.py +48 -0
  123. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/valid_sizes.py +278 -0
  124. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/post/__init__.py +8 -0
  125. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/post/core.py +157 -0
  126. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/pre/__init__.py +17 -0
  127. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/pre/core.py +226 -0
  128. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/predict/__init__.py +5 -0
  129. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/predict/core.py +401 -0
  130. senoquant/tabs/settings/__init__.py +1 -0
  131. senoquant/tabs/settings/backend.py +29 -0
  132. senoquant/tabs/settings/frontend.py +19 -0
  133. senoquant/tabs/spots/__init__.py +1 -0
  134. senoquant/tabs/spots/backend.py +139 -0
  135. senoquant/tabs/spots/frontend.py +800 -0
  136. senoquant/tabs/spots/models/__init__.py +5 -0
  137. senoquant/tabs/spots/models/base.py +94 -0
  138. senoquant/tabs/spots/models/rmp/details.json +61 -0
  139. senoquant/tabs/spots/models/rmp/model.py +499 -0
  140. senoquant/tabs/spots/models/udwt/details.json +103 -0
  141. senoquant/tabs/spots/models/udwt/model.py +482 -0
  142. senoquant/utils.py +25 -0
  143. senoquant-1.0.0b1.dist-info/METADATA +193 -0
  144. senoquant-1.0.0b1.dist-info/RECORD +148 -0
  145. senoquant-1.0.0b1.dist-info/WHEEL +5 -0
  146. senoquant-1.0.0b1.dist-info/entry_points.txt +2 -0
  147. senoquant-1.0.0b1.dist-info/licenses/LICENSE +28 -0
  148. senoquant-1.0.0b1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,65 @@
1
+ """provides a faster sampling function"""
2
+
3
+ import numpy as np
4
+ from csbdeep.utils import _raise, choice
5
+
6
+
7
def sample_patches(datas, patch_size, n_samples, valid_inds=None, verbose=False):
    """optimized version of csbdeep.data.sample_patches_from_multiple_stacks

    Draws `n_samples` patches of shape `patch_size` (centered at randomly
    chosen valid positions) from every array in `datas`, using the same
    center positions for all arrays.
    """
    if len(patch_size) != datas[0].ndim:
        raise ValueError()

    ref_shape = datas[0].shape
    if any(a.shape != ref_shape for a in datas):
        raise ValueError("all input shapes must be the same: %s" % (" / ".join(str(a.shape) for a in datas)))

    if any(not (0 < p <= d) for p, d in zip(patch_size, ref_shape)):
        raise ValueError("patch_size %s negative or larger than data shape %s along some dimensions" % (str(patch_size), str(ref_shape)))

    if valid_inds is None:
        # default: every position whose patch fits inside the array
        per_axis = tuple(np.arange(p // 2, s - p // 2 + 1) for s, p in zip(ref_shape, patch_size))
        valid_inds = tuple(grid.ravel() for grid in np.meshgrid(*per_axis))

    n_valid = len(valid_inds[0])
    if n_valid == 0:
        raise ValueError("no regions to sample from!")

    # sample with replacement only when there are fewer positions than requested
    idx = choice(range(n_valid), n_samples, replace=(n_valid < n_samples))
    picked = [v[idx] for v in valid_inds]

    def _patch_slice(center):
        # slice covering a patch of the requested size centered at `center`
        return tuple(slice(c - (p // 2), c + p - (p // 2)) for c, p in zip(center, patch_size))

    return [np.stack([data[_patch_slice(c)] for c in zip(*picked)]) for data in datas]
32
+
33
+
34
def get_valid_inds(img, patch_size, patch_filter=None):
    """Return all indices of `img` that can serve as patch center points.

    A center point is valid if
      - a patch of shape `patch_size` centered there lies fully inside `img`, and
      - it is inside the boolean mask returned by `patch_filter` (if provided).

    img: np.ndarray
    patch_size: tuple of ints
        the width of patches per img dimension,
    patch_filter: None or callable
        a function with signature patch_filter(img, patch_size) returning a boolean mask
    """
    if len(patch_size) != img.ndim:
        raise ValueError()

    if any(not (0 < p <= s) for p, s in zip(patch_size, img.shape)):
        raise ValueError("patch_size %s negative or larger than image shape %s along some dimensions" % (str(patch_size), str(img.shape)))

    if patch_filter is None:
        # no mask given: every in-border position is valid (fast path)
        per_axis = tuple(np.arange(p // 2, s - p + p // 2 + 1).astype(np.uint32) for p, s in zip(patch_size, img.shape))
        return tuple(grid.ravel() for grid in np.meshgrid(*per_axis, indexing='ij'))

    patch_mask = patch_filter(img, patch_size)

    # restrict the mask to positions whose patch stays inside the image,
    # then shift the resulting indices back into full-image coordinates
    border_slices = tuple(slice(p // 2, s - p + p // 2 + 1) for p, s in zip(patch_size, img.shape))
    inds = np.where(patch_mask[border_slices])
    return tuple((v + s.start).astype(np.uint32) for s, v in zip(border_slices, inds))
@@ -0,0 +1,90 @@
1
+ """
2
+
3
+ Command line script to perform prediction in 2D
4
+
5
+ """
6
+
7
+
8
+ import os
9
+ import sys
10
+ import numpy as np
11
+ from tqdm import tqdm
12
+ import json
13
+ import argparse
14
+ import pprint
15
+ import pathlib
16
+ import warnings
17
+
18
def main():
    """CLI entry point: run a 2D StarDist model on one or more TIFF images.

    For each input file: read, percentile-normalize, predict instance labels,
    and write one label TIFF (named via the --outname template) to --outdir.
    Raises ValueError for unknown models, unsupported dimensionality, or an
    axes string that does not match the image rank.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, description="""
    Prediction script for a 2D stardist model, usage: stardist-predict -i input.tif -m model_folder_or_pretrained_name -o output_folder

    """)
    parser.add_argument("-i","--input", type=str, nargs="+", required=True, help = "input file (tiff)")
    parser.add_argument("-o","--outdir", type=str, default='.', help = "output directory")
    # FIX: this previously used nargs="+", which turned args.outname into a list
    # whenever the flag was given explicitly and broke args.outname.format(...) below.
    parser.add_argument("--outname", type=str, default='{img}.stardist.tif', help = "output file name (tiff)")

    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument('-m', '--model', type=str, default=None, help = "model folder / pretrained model to use")
    parser.add_argument("--axes", type=str, default = None, help = "axes to use for the input, e.g. 'XYC'")
    parser.add_argument("--n_tiles", type=int, nargs=2, default = None, help = "number of tiles to use for prediction")
    parser.add_argument("--pnorm", type=float, nargs=2, default = [1,99.8], help = "pmin/pmax to use for normalization")
    parser.add_argument("--prob_thresh", type=float, default=None, help = "prob_thresh for model (if not given use model default)")
    parser.add_argument("--nms_thresh", type=float, default=None, help = "nms_thresh for model (if not given use model default)")

    parser.add_argument("-v", "--verbose", action='store_true')

    args = parser.parse_args()

    # heavy imports are deliberately deferred until after argument parsing
    from csbdeep.utils import normalize
    from csbdeep.models.base_model import get_registered_models
    from stardist.models import StarDist2D
    from imageio import imread
    from tifffile import imwrite

    get_registered_models(StarDist2D, verbose=True)

    # a directory is treated as a trained model folder, anything else as a pretrained model name
    if pathlib.Path(args.model).is_dir():
        model = StarDist2D(None, name=args.model)
    else:
        model = StarDist2D.from_pretrained(args.model)

    if model is None:
        raise ValueError(f"unknown model: {args.model}\navailable models:\n {get_registered_models(StarDist2D, verbose=True)}")

    for fname in args.input:
        if args.verbose:
            print(f'reading image {fname}')

        img = imread(fname)

        if img.ndim not in (2,3):
            raise ValueError('currently only 2d and 3d images are supported by the prediction script')

        # default axes depend on image rank (2D grayscale vs 2D + channel)
        if args.axes is None:
            args.axes = {2:'YX',3:'YXC'}[img.ndim]

        if len(args.axes) != img.ndim:
            raise ValueError(f'dimension of input ({img.ndim}) not the same as length of given axes ({len(args.axes)})')

        if args.verbose:
            print(f'loaded image of size {img.shape}')
            print('normalizing...')

        img = normalize(img, *args.pnorm)

        labels, _ = model.predict_instances(img,
                                            n_tiles=args.n_tiles,
                                            prob_thresh=args.prob_thresh,
                                            nms_thresh=args.nms_thresh)
        out = pathlib.Path(args.outdir)
        out.mkdir(parents=True, exist_ok=True)

        imwrite(out/args.outname.format(img=pathlib.Path(fname).with_suffix('').name), labels, compression='zlib')
87
+
88
+
89
# Allow running this module directly as a script.
if __name__ == '__main__':
    main()
@@ -0,0 +1,93 @@
1
+ """
2
+
3
+ Command line script to perform prediction in 3D
4
+
5
+ """
6
+
7
+
8
+ import os
9
+ import sys
10
+ import numpy as np
11
+ from tqdm import tqdm
12
+ import json
13
+ import argparse
14
+ import pprint
15
+ import pathlib
16
+ import warnings
17
+
18
def main():
    """CLI entry point: run a 3D StarDist model on one or more TIFF images.

    For each input file: read (TIFF only in 3D), percentile-normalize, predict
    instance labels, and write one label TIFF (named via --outname) to --outdir.
    Raises ValueError for unknown models, non-TIFF inputs, unsupported
    dimensionality, or an axes string that does not match the image rank.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, description="""
    Prediction script for a 3D stardist model, usage: stardist-predict -i input.tif -m model_folder_or_pretrained_name -o output_folder

    """)
    parser.add_argument("-i","--input", type=str, nargs="+", required=True, help = "input file (tiff)")
    parser.add_argument("-o","--outdir", type=str, default='.', help = "output directory")
    # FIX: this previously used nargs="+", which turned args.outname into a list
    # whenever the flag was given explicitly and broke args.outname.format(...) below.
    parser.add_argument("--outname", type=str, default='{img}.stardist.tif', help = "output file name (tiff)")

    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument('-m', '--model', type=str, default=None, help = "model folder / pretrained model to use")
    parser.add_argument("--axes", type=str, default = None, help = "axes to use for the input, e.g. 'XYC'")
    parser.add_argument("--n_tiles", type=int, nargs=3, default = None, help = "number of tiles to use for prediction")
    parser.add_argument("--pnorm", type=float, nargs=2, default = [1,99.8], help = "pmin/pmax to use for normalization")
    parser.add_argument("--prob_thresh", type=float, default=None, help = "prob_thresh for model (if not given use model default)")
    parser.add_argument("--nms_thresh", type=float, default=None, help = "nms_thresh for model (if not given use model default)")

    parser.add_argument("-v", "--verbose", action='store_true')

    args = parser.parse_args()

    # heavy imports are deliberately deferred until after argument parsing
    from csbdeep.utils import normalize
    from csbdeep.models.base_model import get_registered_models
    from stardist.models import StarDist3D
    from tifffile import imwrite, imread

    get_registered_models(StarDist3D, verbose=True)

    # a directory is treated as a trained model folder, anything else as a pretrained model name
    if pathlib.Path(args.model).is_dir():
        model = StarDist3D(None, name=args.model)
    else:
        model = StarDist3D.from_pretrained(args.model)

    if model is None:
        # FIX: this message previously called get_registered_models(StarDist2D, ...),
        # but StarDist2D is not imported in this 3D script -> NameError on the error path.
        raise ValueError(f"unknown model: {args.model}\navailable models:\n {get_registered_models(StarDist3D, verbose=True)}")

    for fname in args.input:
        if args.verbose:
            print(f'reading image {fname}')

        if pathlib.Path(fname).suffix.lower() not in (".tif", ".tiff"):
            raise ValueError('only tiff files supported in 3D for now')

        img = imread(fname)

        if img.ndim not in (3,4):
            raise ValueError('currently only 3d (or 4D with channel) images are supported by the prediction script')

        # default axes depend on image rank (3D volume vs 3D + channel)
        if args.axes is None:
            args.axes = {3:'ZYX',4:'ZYXC'}[img.ndim]

        if len(args.axes) != img.ndim:
            raise ValueError(f'dimension of input ({img.ndim}) not the same as length of given axes ({len(args.axes)})')

        if args.verbose:
            print(f'loaded image of size {img.shape}')
            print('normalizing...')

        img = normalize(img, *args.pnorm)

        labels, _ = model.predict_instances(img,
                                            n_tiles=args.n_tiles,
                                            prob_thresh=args.prob_thresh,
                                            nms_thresh=args.nms_thresh)
        out = pathlib.Path(args.outdir)
        out.mkdir(parents=True, exist_ok=True)

        imwrite(out/args.outname.format(img=pathlib.Path(fname).with_suffix('').name), labels, compression='zlib')
90
+
91
+
92
# Allow running this module directly as a script.
if __name__ == '__main__':
    main()
@@ -0,0 +1,408 @@
1
+ from __future__ import print_function, unicode_literals, absolute_import, division
2
+
3
+ import numpy as np
4
+ import warnings
5
+ import os
6
+ import datetime
7
+ from tqdm import tqdm
8
+ from collections import defaultdict
9
+ from zipfile import ZipFile, ZIP_DEFLATED
10
+ from scipy.ndimage import distance_transform_edt, binary_fill_holes
11
+ from scipy.ndimage import find_objects
12
+ from scipy.optimize import minimize_scalar
13
+ from skimage.measure import regionprops
14
+ from csbdeep.utils import _raise
15
+ from csbdeep.utils.six import Path
16
+ from collections.abc import Iterable
17
+
18
+ from .matching import matching_dataset, _check_label_array
19
+
20
+
21
# Optional fast EDT backend (https://github.com/seung-lab/euclidean-distance-transform-3d).
# If it imports, also decide how many threads it may use; otherwise fall back to scipy later.
try:
    from edt import edt
    _edt_available = True
    # Upper bound = CPUs available to this process; os.sched_getaffinity is not
    # available on every platform, hence the broad fallback to a fixed cap.
    try: _edt_parallel_max = len(os.sched_getaffinity(0))
    except: _edt_parallel_max = 128
    _edt_parallel_default = 4
    # The thread count can be overridden via the STARDIST_EDT_NUM_THREADS env var.
    _edt_parallel = os.environ.get('STARDIST_EDT_NUM_THREADS', _edt_parallel_default)
    try:
        # env values are strings; clamp to the detected CPU budget
        _edt_parallel = min(_edt_parallel_max, int(_edt_parallel))
    except ValueError as e:
        warnings.warn(f"Invalid value ({_edt_parallel}) for STARDIST_EDT_NUM_THREADS. Using default value ({_edt_parallel_default}) instead.")
        _edt_parallel = _edt_parallel_default
    del _edt_parallel_default, _edt_parallel_max
except ImportError:
    _edt_available = False
    # warnings.warn("Could not find package edt... \nConsider installing it with \n  pip install edt\nto improve training data generation performance.")
    pass
38
+
39
+
40
def gputools_available():
    """Return True if the optional `gputools` package can be imported."""
    try:
        import gputools  # noqa: F401
    except Exception:
        # narrowed from a bare `except:` so KeyboardInterrupt/SystemExit are no
        # longer swallowed; any import-time failure still means "not available"
        return False
    return True
46
+
47
+
48
def path_absolute(path_relative):
    """Resolve `path_relative` against the directory containing this module."""
    module_dir = os.path.dirname(os.path.abspath(__file__))
    return os.path.join(module_dir, path_relative)
52
+
53
+
54
+ def _is_power_of_2(i):
55
+ assert i > 0
56
+ e = np.log2(i)
57
+ return e == int(e)
58
+
59
+
60
def _normalize_grid(grid, n):
    """Validate `grid` as an n-tuple of power-of-two scalars and return it as ints.

    Raises ValueError if `grid` is not iterable, has the wrong length, or
    contains values that are not powers of two.
    """
    try:
        grid = tuple(grid)
        well_formed = (len(grid) == n
                       and all(np.isscalar(g) for g in grid)
                       and all(_is_power_of_2(g) for g in grid))
        if not well_formed:
            raise TypeError()
        return tuple(int(g) for g in grid)
    except (TypeError, AssertionError):
        # AssertionError comes from _is_power_of_2 for non-positive values
        raise ValueError("grid = {grid} must be a list/tuple of length {n} with values that are power of 2".format(grid=grid, n=n))
69
+
70
+
71
def edt_prob(lbl_img, anisotropy=None):
    """Compute a per-object normalized Euclidean distance transform.

    Dispatches to the fast `edt` backend when it was importable at module
    load time, otherwise falls back to the scipy implementation.
    """
    if not _edt_available:
        return _edt_prob_scipy(lbl_img, anisotropy=anisotropy)
    return _edt_prob_edt(lbl_img, anisotropy=anisotropy)
77
+
78
def _edt_prob_edt(lbl_img, anisotropy=None):
    """Perform EDT on each labeled object and normalize.
    Internally uses https://github.com/seung-lab/euclidean-distance-transform-3d
    that can handle multiple labels at once

    Returns a float array of the same shape where each object's distance
    values are scaled so the object's maximum is ~1.
    """
    # the edt package requires a contiguous array
    lbl_img = np.ascontiguousarray(lbl_img)
    # a single non-zero label filling the whole image has no background reference
    constant_img = lbl_img.min() == lbl_img.max() and lbl_img.flat[0] > 0
    if constant_img:
        warnings.warn("EDT of constant label image is ill-defined. (Assuming background around it.)")
    # we just need to compute the edt once but then normalize it for each object
    prob = edt(lbl_img, anisotropy=anisotropy, black_border=constant_img, parallel=_edt_parallel)
    objects = find_objects(lbl_img)
    for i,sl in enumerate(objects,1):
        # i: object label id, sl: slices of object in lbl_img
        if sl is None: continue
        _mask = lbl_img[sl]==i
        # normalize it
        # (prob[sl] is a basic-slicing view, so the in-place divide updates `prob`;
        #  +1e-10 guards against division by zero for degenerate objects)
        prob[sl][_mask] /= np.max(prob[sl][_mask]+1e-10)
    return prob
97
+
98
def _edt_prob_scipy(lbl_img, anisotropy=None):
    """Perform EDT on each labeled object and normalize."""
    # grow: widen a bounding-box slice by 1 pixel on each side that is interior
    def grow(sl,interior):
        return tuple(slice(s.start-int(w[0]),s.stop+int(w[1])) for s,w in zip(sl,interior))
    # shrink: slice that undoes the growth done by `grow`
    def shrink(interior):
        return tuple(slice(int(w[0]),(-1 if w[1] else None)) for w in interior)
    # a single non-zero label filling the whole image has no background reference
    constant_img = lbl_img.min() == lbl_img.max() and lbl_img.flat[0] > 0
    if constant_img:
        # pad with background so the distance transform is well-defined
        lbl_img = np.pad(lbl_img, ((1,1),)*lbl_img.ndim, mode='constant')
        warnings.warn("EDT of constant label image is ill-defined. (Assuming background around it.)")
    objects = find_objects(lbl_img)
    prob = np.zeros(lbl_img.shape,np.float32)
    for i,sl in enumerate(objects,1):
        # i: object label id, sl: slices of object in lbl_img
        if sl is None: continue
        # which sides of the bounding box do not touch the image border
        interior = [(s.start>0,s.stop<sz) for s,sz in zip(sl,lbl_img.shape)]
        # 1. grow object slice by 1 for all interior object bounding boxes
        # 2. perform (correct) EDT for object with label id i
        # 3. extract EDT for object of original slice and normalize
        # 4. store edt for object only for pixels of given label id i
        shrink_slice = shrink(interior)
        grown_mask = lbl_img[grow(sl,interior)]==i
        mask = grown_mask[shrink_slice]
        # NOTE: local `edt` shadows the optional module-level `edt` function here
        edt = distance_transform_edt(grown_mask, sampling=anisotropy)[shrink_slice][mask]
        prob[sl][mask] = edt/(np.max(edt)+1e-10)
    if constant_img:
        # remove the background padding added above
        prob = prob[(slice(1,-1),)*lbl_img.ndim].copy()
    return prob
126
+
127
+
128
+ def _fill_label_holes(lbl_img, **kwargs):
129
+ lbl_img_filled = np.zeros_like(lbl_img)
130
+ for l in (set(np.unique(lbl_img)) - set([0])):
131
+ mask = lbl_img==l
132
+ mask_filled = binary_fill_holes(mask,**kwargs)
133
+ lbl_img_filled[mask_filled] = l
134
+ return lbl_img_filled
135
+
136
+
137
def fill_label_holes(lbl_img, **kwargs):
    """Fill small holes in label image.

    Works per object bounding box (via find_objects) rather than per label
    mask over the full image; kwargs are forwarded to binary_fill_holes.
    """
    # TODO: refactor 'fill_label_holes' and 'edt_prob' to share code
    # grow: widen a bounding-box slice by 1 pixel on each side that is interior
    def grow(sl,interior):
        return tuple(slice(s.start-int(w[0]),s.stop+int(w[1])) for s,w in zip(sl,interior))
    # shrink: slice that undoes the growth done by `grow`
    def shrink(interior):
        return tuple(slice(int(w[0]),(-1 if w[1] else None)) for w in interior)
    objects = find_objects(lbl_img)
    lbl_img_filled = np.zeros_like(lbl_img)
    for i,sl in enumerate(objects,1):
        if sl is None: continue
        # which sides of the bounding box do not touch the image border;
        # growing by 1 there avoids treating border-touching cavities as holes
        interior = [(s.start>0,s.stop<sz) for s,sz in zip(sl,lbl_img.shape)]
        shrink_slice = shrink(interior)
        grown_mask = lbl_img[grow(sl,interior)]==i
        mask_filled = binary_fill_holes(grown_mask,**kwargs)[shrink_slice]
        lbl_img_filled[sl][mask_filled] = i
    if lbl_img.min() < 0:
        # preserve (and fill holes in) negative labels ('find_objects' ignores these)
        lbl_neg_filled = -fill_label_holes(-np.minimum(lbl_img, 0))
        mask = lbl_neg_filled < 0
        lbl_img_filled[mask] = lbl_neg_filled[mask]
    return lbl_img_filled
159
+
160
+
161
def sample_points(n_samples, mask, prob=None, b=2):
    """Draw `n_samples` random (row, col) points from the True pixels of `mask`.

    b: border width (in pixels) to exclude, since predictions near the image
       boundary may be unreliable; None or <= 0 disables the exclusion.
    prob: optional weight image used for weighted sampling.
    Returns an array of shape (n_samples, 2).
    """
    if b is None or b <= 0:
        valid = mask
    else:
        border = np.zeros_like(mask)
        border[b:-b, b:-b] = True
        valid = mask & border

    rows, cols = np.nonzero(valid)

    if prob is None:
        idx = np.random.choice(len(rows), n_samples, replace=True)
    else:
        # weighted sampling proportional to `prob` at the candidate points
        weights = prob[rows, cols].astype(np.float64)
        weights /= np.sum(weights)
        idx = np.random.choice(len(rows), n_samples, replace=True, p=weights)

    return np.stack((rows[idx], cols[idx]), axis=-1)
183
+
184
+
185
def calculate_extents(lbl, func=np.median):
    """ Aggregate bounding box sizes of objects in label images. """
    # a 4D array or any non-array iterable is treated as a collection of label images
    is_collection = ((isinstance(lbl, np.ndarray) and lbl.ndim == 4)
                     or (not isinstance(lbl, np.ndarray) and isinstance(lbl, Iterable)))
    if is_collection:
        per_image = [calculate_extents(single, func) for single in lbl]
        return func(np.stack(per_image, axis=0), axis=0)

    n = lbl.ndim
    if n not in (2, 3):
        raise ValueError("label image should be 2- or 3-dimensional (or pass a list of these)")

    regions = regionprops(lbl)
    if not regions:
        return np.zeros(n)
    # bbox is (min_0..min_{n-1}, max_0..max_{n-1}); extent = max - min per axis
    extents = np.array([np.array(r.bbox[n:]) - np.array(r.bbox[:n]) for r in regions])
    return func(extents, axis=0)
199
+
200
+
201
def polyroi_bytearray(x,y,pos=None,subpixel=True):
    """ Byte array of polygon roi with provided x and y coordinates
    See https://github.com/imagej/imagej1/blob/master/ij/io/RoiDecoder.java

    x, y: polygon vertex coordinates (flattened)
    pos: optional 1-based slice position (C, Z, or T) stored in the header
    subpixel: also store float32 coordinates after the int16 block
    """
    import struct
    # helpers emitting big-endian fixed-width fields of the ImageJ ROI format
    def _int16(x):
        return int(x).to_bytes(2, byteorder='big', signed=True)
    def _uint16(x):
        return int(x).to_bytes(2, byteorder='big', signed=False)
    def _int32(x):
        return int(x).to_bytes(4, byteorder='big', signed=True)
    def _float(x):
        return struct.pack(">f", x)

    subpixel = bool(subpixel)
    # add offset since pixel center is at (0.5,0.5) in ImageJ
    x_raw = np.asarray(x).ravel() + 0.5
    y_raw = np.asarray(y).ravel() + 0.5
    # integer coordinates for the mandatory int16 section
    x = np.round(x_raw)
    y = np.round(y_raw)
    assert len(x) == len(y)
    top, left, bottom, right = y.min(), x.min(), y.max(), x.max() # bbox

    n_coords = len(x)
    bytes_header = 64
    # header + int16 x/y pairs + (optionally) float32 x/y pairs
    bytes_total = bytes_header + n_coords*2*2 + subpixel*n_coords*2*4
    B = [0] * bytes_total
    B[ 0: 4] = map(ord,'Iout')   # magic start
    B[ 4: 6] = _int16(227)       # version
    B[ 6: 8] = _int16(0)         # roi type (0 = polygon)
    B[ 8:10] = _int16(top)       # bbox top
    B[10:12] = _int16(left)      # bbox left
    B[12:14] = _int16(bottom)    # bbox bottom
    B[14:16] = _int16(right)     # bbox right
    B[16:18] = _uint16(n_coords) # number of coordinates
    if subpixel:
        B[50:52] = _int16(128)   # subpixel resolution (option flag)
    if pos is not None:
        B[56:60] = _int32(pos)   # position (C, Z, or T)

    # int16 coordinates are stored relative to the bounding box origin,
    # all x values first, then all y values
    for i,(_x,_y) in enumerate(zip(x,y)):
        xs = bytes_header + 2*i
        ys = xs + 2*n_coords
        B[xs:xs+2] = _int16(_x - left)
        B[ys:ys+2] = _int16(_y - top)

    if subpixel:
        # float32 coordinates are absolute (not bbox-relative)
        base1 = bytes_header + n_coords*2*2
        base2 = base1 + n_coords*4
        for i,(_x,_y) in enumerate(zip(x_raw,y_raw)):
            xs = base1 + 4*i
            ys = base2 + 4*i
            B[xs:xs+4] = _float(_x)
            B[ys:ys+4] = _float(_y)

    return bytearray(B)
257
+
258
+
259
def export_imagej_rois(fname, polygons, set_position=True, subpixel=True, compression=ZIP_DEFLATED):
    """ polygons assumed to be a list of arrays with shape (id,2,c)

    Writes one '.roi' entry per polygon into a zip archive at `fname`.zip,
    grouped by 1-based position index.
    """
    # a single array is treated as one group of polygons
    if isinstance(polygons, np.ndarray):
        polygons = (polygons,)

    fname = Path(fname)
    # avoid a double '.zip.zip' suffix
    if fname.suffix == '.zip':
        fname = fname.with_suffix('')

    with ZipFile(str(fname) + '.zip', mode='w', compression=compression) as archive:
        for pos, group in enumerate(polygons, start=1):
            for index, poly in enumerate(group, start=1):
                roi_bytes = polyroi_bytearray(poly[1], poly[0],
                                              pos=(pos if set_position else None),
                                              subpixel=subpixel)
                archive.writestr('{pos:03d}_{i:03d}.roi'.format(pos=pos, i=index), roi_bytes)
274
+
275
+
276
def optimize_threshold(Y, Yhat, model, nms_thresh, measure='accuracy', iou_threshs=[0.3,0.5,0.7], bracket=None, tol=1e-2, maxiter=20, verbose=1):
    """Tune prob_thresh for provided (fixed) nms_thresh to maximize matching score (for given measure and averaged over iou_threshs).

    Parameters
    ----------
    Y : sequence of ndarray
        Ground-truth label images.
    Yhat : sequence of (prob, dist) tuples
        Predictions corresponding to Y; only prob's max is used for the default bracket.
    model : object
        Must provide ``_instances_from_prediction`` (StarDist-style model).
    nms_thresh : scalar
        Fixed non-maximum-suppression threshold (ValueError if not scalar).
    measure : str
        Field of the matching statistics to maximize (looked up via ``_asdict()``).
    iou_threshs : scalar or list of scalars
        IoU thresholds the score is averaged over.
        NOTE(review): mutable list default — never mutated here (only rebound),
        but a tuple default would be safer.
    bracket : (lo, hi) tuple or None
        Search interval for prob_thresh; defaults to (max_prob/2, max_prob).
    tol : float
        Optimizer tolerance for the golden-section search.
    maxiter : int
        Maximum number of objective evaluations.
    verbose : int
        1 -> tqdm progress bar; >1 -> per-evaluation printout instead.

    Returns
    -------
    (float, float)
        The optimal prob_thresh and the corresponding (positive) score.
    """
    np.isscalar(nms_thresh) or _raise(ValueError("nms_thresh must be a scalar"))
    iou_threshs = [iou_threshs] if np.isscalar(iou_threshs) else iou_threshs
    # memoization cache: golden-section search can re-evaluate the same threshold
    values = dict()

    if bracket is None:
        # default search interval derived from the largest predicted probability
        max_prob = max([np.max(prob) for prob, dist in Yhat])
        bracket = max_prob/2, max_prob

    with tqdm(total=maxiter, disable=(verbose!=1), desc="NMS threshold = %g" % nms_thresh) as progress:

        def fn(thr):
            # objective for minimize_scalar: negative mean matching score at this prob_thresh
            prob_thresh = np.clip(thr, *bracket)
            value = values.get(prob_thresh)
            if value is None:
                Y_instances = [model._instances_from_prediction(y.shape, *prob_dist, prob_thresh=prob_thresh, nms_thresh=nms_thresh)[0] for y,prob_dist in zip(Y,Yhat)]
                stats = matching_dataset(Y, Y_instances, thresh=iou_threshs, show_progress=False, parallel=True)
                # average the chosen measure over all IoU thresholds
                values[prob_thresh] = value = np.mean([s._asdict()[measure] for s in stats])
            if verbose > 1:
                print("{now} thresh: {prob_thresh:f} {measure}: {value:f}".format(
                    now = datetime.datetime.now().strftime('%H:%M:%S'),
                    prob_thresh = prob_thresh,
                    measure = measure,
                    value = value,
                ), flush=True)
            else:
                progress.update()
                progress.set_postfix_str("{prob_thresh:.3f} -> {value:.3f}".format(prob_thresh=prob_thresh, value=value))
                progress.refresh()
            # negated: minimize_scalar minimizes, but we want to maximize the score
            return -value

        opt = minimize_scalar(fn, method='golden', bracket=bracket, tol=tol, options={'maxiter': maxiter})

    verbose > 1 and print('\n',opt, flush=True)
    # undo the negation when reporting the best score
    return opt.x, -opt.fun
313
+
314
+
315
+ def _invert_dict(d):
316
+ """ return v-> [k_1,k_2,k_3....] for k,v in d"""
317
+ res = defaultdict(list)
318
+ for k,v in d.items():
319
+ res[v].append(k)
320
+ return res
321
+
322
+
323
def mask_to_categorical(y, n_classes, classes, return_cls_dict=False):
    """Generate a multi-channel categorical class map from an integer label image.

    Parameters
    ----------
    y : n-dimensional ndarray
        Integer label array.
    n_classes : int
        Number of different classes (without background).
    classes : dict, int, or None
        Label-to-class assignment. Either a dict {label -> class_id}, where
        class_id may be 0 (background), 1..n_classes (object class), or None
        (ignore object: prob set to -1 for its pixels, except the background
        channel); or a single int/None broadcast to all labels.
    return_cls_dict : bool
        If True, additionally return the inverted {class_id -> [labels]} dict.

    Returns
    -------
    ndarray
        Probability map of shape y.shape + (n_classes+1,); channel 0 is the
        background. If return_cls_dict is True, a (map, cls_dict) tuple.
    """
    _check_label_array(y, 'y')
    # background channel is added on top, so n_classes must be >= 1
    if not (np.issubdtype(type(n_classes), np.integer) and n_classes>=1):
        raise ValueError(f"n_classes is '{n_classes}' but should be a positive integer")

    # labels actually present in the image (background 0 excluded)
    y_labels = np.unique(y[y>0]).tolist()

    # broadcast a scalar/None assignment to every present label
    if np.issubdtype(type(classes), np.integer) or classes is None:
        classes = {lbl: classes for lbl in y_labels}
    elif not isinstance(classes, dict):
        raise ValueError("classes should be dict, single scalar, or None!")

    if not set(y_labels).issubset(set(classes.keys())):
        raise ValueError(f"all gt labels should be present in class dict provided \ngt_labels found\n{set(y_labels)}\nclass dict labels provided\n{set(classes.keys())}")

    # class_id -> list of labels (inverse of the classes dict)
    cls_dict = _invert_dict(classes)

    # probability map with an extra leading background channel
    out = np.zeros(y.shape+(n_classes+1,), np.float32)

    for cls, labels in cls_dict.items():
        selected = np.isin(y, labels)
        if cls is None:
            # prob == -1 will be used in the loss to ignore the object
            out[selected, :] = -1
        elif np.issubdtype(type(cls), np.integer) and 0 <= cls <= n_classes:
            out[selected, cls] = 1
        else:
            raise ValueError(f"Wrong class id '{cls}' (for n_classes={n_classes})")

    # set 0/1 background prob (unaffected by None values for class ids)
    out[...,0] = (y==0)

    return (out, cls_dict) if return_cls_dict else out
386
+
387
+
388
+ def _is_floatarray(x):
389
+ return isinstance(x.dtype.type(0),np.floating)
390
+
391
+
392
def abspath(root, relpath):
    """Resolve *relpath* against *root* and return the absolute path string.

    If *root* is an existing directory, *relpath* is joined to it directly;
    otherwise *root* is treated as a file path and *relpath* is joined to
    its parent directory.
    """
    from pathlib import Path
    base = Path(root)
    anchor = base if base.is_dir() else base.parent
    return str((anchor / relpath).absolute())
400
+
401
+
402
def grid_divisible_patch_size(patch_size, grid, warn=True):
    """Round every entry of *patch_size* up to the nearest multiple of the
    corresponding *grid* entry.

    Parameters
    ----------
    patch_size, grid : sequences of equal length
        Patch shape and the grid it must be divisible by.
    warn : bool
        If True, emit a warning whenever the patch size had to be enlarged.

    Returns
    -------
    tuple of int
        The (possibly enlarged) grid-divisible patch size.
    """
    patch_size, grid = tuple(patch_size), tuple(grid)
    assert len(patch_size) == len(grid)
    rounded_up = []
    for sh, g in zip(patch_size, grid):
        rounded_up.append(int(np.ceil(sh/g)*g))
    patch_size_divisible = tuple(rounded_up)
    if warn and patch_size != patch_size_divisible:
        warnings.warn(f"increasing patch_size from {patch_size} to {patch_size_divisible}, since it was not evenly divisible by grid {grid}")
    return patch_size_divisible
@@ -0,0 +1 @@
1
+ __version__ = '0.9.2'