senoquant 1.0.0b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148)
  1. senoquant/__init__.py +6 -0
  2. senoquant/_reader.py +7 -0
  3. senoquant/_widget.py +33 -0
  4. senoquant/napari.yaml +83 -0
  5. senoquant/reader/__init__.py +5 -0
  6. senoquant/reader/core.py +369 -0
  7. senoquant/tabs/__init__.py +15 -0
  8. senoquant/tabs/batch/__init__.py +10 -0
  9. senoquant/tabs/batch/backend.py +641 -0
  10. senoquant/tabs/batch/config.py +270 -0
  11. senoquant/tabs/batch/frontend.py +1283 -0
  12. senoquant/tabs/batch/io.py +326 -0
  13. senoquant/tabs/batch/layers.py +86 -0
  14. senoquant/tabs/quantification/__init__.py +1 -0
  15. senoquant/tabs/quantification/backend.py +228 -0
  16. senoquant/tabs/quantification/features/__init__.py +80 -0
  17. senoquant/tabs/quantification/features/base.py +142 -0
  18. senoquant/tabs/quantification/features/marker/__init__.py +5 -0
  19. senoquant/tabs/quantification/features/marker/config.py +69 -0
  20. senoquant/tabs/quantification/features/marker/dialog.py +437 -0
  21. senoquant/tabs/quantification/features/marker/export.py +879 -0
  22. senoquant/tabs/quantification/features/marker/feature.py +119 -0
  23. senoquant/tabs/quantification/features/marker/morphology.py +285 -0
  24. senoquant/tabs/quantification/features/marker/rows.py +654 -0
  25. senoquant/tabs/quantification/features/marker/thresholding.py +46 -0
  26. senoquant/tabs/quantification/features/roi.py +346 -0
  27. senoquant/tabs/quantification/features/spots/__init__.py +5 -0
  28. senoquant/tabs/quantification/features/spots/config.py +62 -0
  29. senoquant/tabs/quantification/features/spots/dialog.py +477 -0
  30. senoquant/tabs/quantification/features/spots/export.py +1292 -0
  31. senoquant/tabs/quantification/features/spots/feature.py +112 -0
  32. senoquant/tabs/quantification/features/spots/morphology.py +279 -0
  33. senoquant/tabs/quantification/features/spots/rows.py +241 -0
  34. senoquant/tabs/quantification/frontend.py +815 -0
  35. senoquant/tabs/segmentation/__init__.py +1 -0
  36. senoquant/tabs/segmentation/backend.py +131 -0
  37. senoquant/tabs/segmentation/frontend.py +1009 -0
  38. senoquant/tabs/segmentation/models/__init__.py +5 -0
  39. senoquant/tabs/segmentation/models/base.py +146 -0
  40. senoquant/tabs/segmentation/models/cpsam/details.json +65 -0
  41. senoquant/tabs/segmentation/models/cpsam/model.py +150 -0
  42. senoquant/tabs/segmentation/models/default_2d/details.json +69 -0
  43. senoquant/tabs/segmentation/models/default_2d/model.py +664 -0
  44. senoquant/tabs/segmentation/models/default_3d/details.json +69 -0
  45. senoquant/tabs/segmentation/models/default_3d/model.py +682 -0
  46. senoquant/tabs/segmentation/models/hf.py +71 -0
  47. senoquant/tabs/segmentation/models/nuclear_dilation/__init__.py +1 -0
  48. senoquant/tabs/segmentation/models/nuclear_dilation/details.json +26 -0
  49. senoquant/tabs/segmentation/models/nuclear_dilation/model.py +96 -0
  50. senoquant/tabs/segmentation/models/perinuclear_rings/__init__.py +1 -0
  51. senoquant/tabs/segmentation/models/perinuclear_rings/details.json +34 -0
  52. senoquant/tabs/segmentation/models/perinuclear_rings/model.py +132 -0
  53. senoquant/tabs/segmentation/stardist_onnx_utils/__init__.py +2 -0
  54. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/__init__.py +3 -0
  55. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/__init__.py +6 -0
  56. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/generate.py +470 -0
  57. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/prepare.py +273 -0
  58. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/rawdata.py +112 -0
  59. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/transform.py +384 -0
  60. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/__init__.py +0 -0
  61. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/blocks.py +184 -0
  62. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/losses.py +79 -0
  63. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/nets.py +165 -0
  64. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/predict.py +467 -0
  65. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/probability.py +67 -0
  66. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/train.py +148 -0
  67. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/io/__init__.py +163 -0
  68. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/__init__.py +52 -0
  69. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/base_model.py +329 -0
  70. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_isotropic.py +160 -0
  71. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_projection.py +178 -0
  72. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_standard.py +446 -0
  73. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_upsampling.py +54 -0
  74. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/config.py +254 -0
  75. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/pretrained.py +119 -0
  76. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/scripts/__init__.py +0 -0
  77. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/scripts/care_predict.py +180 -0
  78. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/__init__.py +5 -0
  79. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/plot_utils.py +159 -0
  80. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/six.py +18 -0
  81. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/tf.py +644 -0
  82. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/utils.py +272 -0
  83. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/version.py +1 -0
  84. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/docs/source/conf.py +368 -0
  85. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/setup.py +68 -0
  86. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_datagen.py +169 -0
  87. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_models.py +462 -0
  88. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_utils.py +166 -0
  89. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tools/create_zip_contents.py +34 -0
  90. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/__init__.py +30 -0
  91. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/big.py +624 -0
  92. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/bioimageio_utils.py +494 -0
  93. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/data/__init__.py +39 -0
  94. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/__init__.py +10 -0
  95. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/geom2d.py +215 -0
  96. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/geom3d.py +349 -0
  97. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/matching.py +483 -0
  98. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/__init__.py +28 -0
  99. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/base.py +1217 -0
  100. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/model2d.py +594 -0
  101. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/model3d.py +696 -0
  102. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/nms.py +384 -0
  103. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/__init__.py +2 -0
  104. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/plot.py +74 -0
  105. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/render.py +298 -0
  106. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/rays3d.py +373 -0
  107. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/sample_patches.py +65 -0
  108. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/__init__.py +0 -0
  109. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/predict2d.py +90 -0
  110. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/predict3d.py +93 -0
  111. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/utils.py +408 -0
  112. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/version.py +1 -0
  113. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/__init__.py +45 -0
  114. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/__init__.py +17 -0
  115. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/cli.py +55 -0
  116. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/core.py +285 -0
  117. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/__init__.py +15 -0
  118. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/cli.py +36 -0
  119. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/divisibility.py +193 -0
  120. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/probe.py +100 -0
  121. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/receptive_field.py +182 -0
  122. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/rf_cli.py +48 -0
  123. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/valid_sizes.py +278 -0
  124. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/post/__init__.py +8 -0
  125. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/post/core.py +157 -0
  126. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/pre/__init__.py +17 -0
  127. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/pre/core.py +226 -0
  128. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/predict/__init__.py +5 -0
  129. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/predict/core.py +401 -0
  130. senoquant/tabs/settings/__init__.py +1 -0
  131. senoquant/tabs/settings/backend.py +29 -0
  132. senoquant/tabs/settings/frontend.py +19 -0
  133. senoquant/tabs/spots/__init__.py +1 -0
  134. senoquant/tabs/spots/backend.py +139 -0
  135. senoquant/tabs/spots/frontend.py +800 -0
  136. senoquant/tabs/spots/models/__init__.py +5 -0
  137. senoquant/tabs/spots/models/base.py +94 -0
  138. senoquant/tabs/spots/models/rmp/details.json +61 -0
  139. senoquant/tabs/spots/models/rmp/model.py +499 -0
  140. senoquant/tabs/spots/models/udwt/details.json +103 -0
  141. senoquant/tabs/spots/models/udwt/model.py +482 -0
  142. senoquant/utils.py +25 -0
  143. senoquant-1.0.0b1.dist-info/METADATA +193 -0
  144. senoquant-1.0.0b1.dist-info/RECORD +148 -0
  145. senoquant-1.0.0b1.dist-info/WHEEL +5 -0
  146. senoquant-1.0.0b1.dist-info/entry_points.txt +2 -0
  147. senoquant-1.0.0b1.dist-info/licenses/LICENSE +28 -0
  148. senoquant-1.0.0b1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,169 @@
1
+ from __future__ import print_function, unicode_literals, absolute_import, division
2
+ from six.moves import range, zip, map, reduce, filter
3
+
4
+ # import warnings
5
+ import numpy as np
6
+ import pytest
7
+ from tifffile import imread
8
+ try:
9
+ from tifffile import imwrite as imsave
10
+ except ImportError:
11
+ from tifffile import imsave
12
+ from csbdeep.data import RawData, create_patches, create_patches_reduced_target
13
+ from csbdeep.io import load_training_data
14
+ from csbdeep.utils import Path, axes_dict, move_image_axes, backend_channels_last
15
+
16
+
17
+
18
def test_create_patches():
    """create_patches keeps X and Y aligned for many image/patch axis layouts.

    Targets are an affine function of the sources (y = 5 + 3*x). The
    assertion X ~= Y relies on create_patches' default normalization mapping
    both onto the same values (assumption about the default normalizer).
    """
    rng = np.random.RandomState(42)
    n_images, n_patches_per_image = 2, 4

    def make_raw_data(axes, shape):
        def generator():
            for _ in range(n_images):
                source = rng.uniform(size=shape)
                yield source, 5 + 3*source, axes, None
        return RawData(generator, n_images, '')

    def check(img_size, img_axes, patch_size, patch_axes):
        X, Y, XYaxes = create_patches(
            raw_data=make_raw_data(img_axes, img_size),
            patch_size=patch_size,
            patch_axes=patch_axes,
            n_patches_per_image=n_patches_per_image,
        )
        assert len(X) == n_images * n_patches_per_image
        assert np.allclose(X, Y, atol=1e-6)
        if patch_axes is not None:
            # Output axes: sample, channel, then the patch axes without C.
            assert XYaxes == 'SC' + patch_axes.replace('C', '')

    cases = (
        ((128, 128), 'YX', (32, 32), 'YX'),
        ((128, 128), 'YX', (32, 32), None),
        ((128, 128), 'YX', (32, 32), 'XY'),
        ((128, 128), 'YX', (32, 32, 1), 'XYC'),
        ((32, 48, 32), 'ZYX', (16, 32, 8), None),
        ((32, 48, 32), 'ZYX', (16, 32, 8), 'ZYX'),
        ((32, 48, 32), 'ZYX', (16, 32, 8), 'YXZ'),
        ((32, 48, 32), 'ZYX', (16, 32, 1, 8), 'YXCZ'),
    )
    for case in cases:
        check(*case)
50
+
51
+
52
+
53
def test_create_patches_reduced_target():
    """Targets that are means of the source over random axes are patched consistently.

    For every layout, the patched target must equal the mean of the patched
    source over the reduced axes.
    """
    rng = np.random.RandomState(42)
    n_images, n_patches_per_image = 2, 4

    def make_raw_data(axes, shape):
        # Randomly pick 1..len(axes)-1 distinct axes to reduce, and whether
        # the reduced target keeps singleton dimensions.
        n_reduced = rng.choice(len(axes) - 1) + 1
        red_axes = ''.join(rng.choice(tuple(axes), n_reduced, replace=False))
        keepdims = bool(rng.choice((True, False)))
        reduce_over = tuple(axes_dict(axes)[a] for a in red_axes)

        def generator():
            for _ in range(n_images):
                source = rng.uniform(size=shape)
                target = np.mean(source, axis=reduce_over, keepdims=keepdims)
                yield source, target, axes, None

        return RawData(generator, n_images, ''), red_axes, keepdims

    def identity_normalization(patches_x, patches_y, *args):
        return patches_x, patches_y

    def check(red_none, img_size, img_axes, patch_size, patch_axes):
        raw_data, red_axes, keepdims = make_raw_data(img_axes, img_size)

        # Along reduced axes the patch spans the whole image: either the full
        # extent, or None (expected to be resolved by the function itself).
        patch_index = axes_dict(img_axes if patch_axes is None else patch_axes)
        image_index = axes_dict(img_axes)
        sizes = list(patch_size)
        for a in red_axes:
            sizes[patch_index[a]] = None if red_none else img_size[image_index[a]]

        if keepdims:
            target_axes = rng.choice((None, img_axes))
        else:
            target_axes = ''.join(a for a in img_axes if a not in red_axes)

        X, Y, XYaxes = create_patches_reduced_target(
            raw_data=raw_data,
            patch_size=sizes,
            patch_axes=patch_axes,
            n_patches_per_image=n_patches_per_image,
            reduction_axes=red_axes,
            target_axes=target_axes,
            normalization=identity_normalization,
            verbose=False,
        )
        assert len(X) == n_images * n_patches_per_image
        X_reduced = np.mean(X, axis=tuple(axes_dict(XYaxes)[a] for a in red_axes), keepdims=True)
        assert np.max(np.abs(X_reduced - Y)) < 1e-5

    cases = (
        ((128, 128), 'YX', (32, 32), 'YX'),
        ((128, 128), 'YX', (32, 32), None),
        ((128, 128), 'YX', (32, 32), 'XY'),
        ((128, 128), 'YX', (32, 32, 1), 'XYC'),
        ((32, 48, 32), 'ZYX', (16, 32, 8), None),
        ((32, 48, 32), 'ZYX', (16, 32, 8), 'ZYX'),
        ((32, 48, 32), 'ZYX', (16, 32, 8), 'YXZ'),
        ((32, 48, 32), 'ZYX', (16, 32, 1, 8), 'YXCZ'),
        ((128, 2, 128), 'YCX', (32, 2, 32), 'YCX'),
        ((3, 128, 128), 'CYX', (3, 32, 32), None),
        ((128, 128, 4), 'YXC', (4, 32, 32), 'CXY'),
        ((128, 128, 5), 'YXC', (32, 32, 5), 'XYC'),
        ((32, 48, 2, 32), 'ZYCX', (16, 32, 2, 8), None),
        ((32, 3, 48, 32), 'ZCYX', (3, 16, 32, 8), 'CZYX'),
        ((4, 32, 48, 32), 'CZYX', (16, 32, 8, 4), 'YXZC'),
        ((32, 48, 32, 2), 'ZYXC', (16, 32, 2, 8), 'YXCZ'),
    )
    for red_none in (True, False):
        for case in cases:
            check(red_none, *case)
112
+
113
+
114
+
115
def test_create_save_and_load(tmpdir):
    """Patches written by create_patches round-trip through load_training_data."""
    rng = np.random.RandomState(42)
    tmpdir = Path(str(tmpdir))
    save_file = str(tmpdir / 'data.npz')
    n_images, n_patches_per_image = 2, 4

    def check(img_size, img_axes, patch_size, patch_axes):
        shape = (n_images,) + img_size
        U = rng.uniform(size=shape)
        V = rng.uniform(size=shape)
        X, Y, XYaxes = create_patches(
            raw_data=RawData.from_arrays(U, V, img_axes),
            patch_size=patch_size,
            patch_axes=patch_axes,
            n_patches_per_image=n_patches_per_image,
            save_file=save_file,
        )
        (X_loaded, Y_loaded), val_data, XYaxes_loaded = load_training_data(save_file, verbose=True)
        assert val_data is None
        # backend_channels_last is a function: it must be called. Used as a
        # bare name it is always truthy, so only position -1 was ever checked.
        assert XYaxes_loaded[-1 if backend_channels_last() else 1] == 'C'
        X_loaded, Y_loaded = (move_image_axes(u, fr=XYaxes_loaded, to=XYaxes) for u in (X_loaded, Y_loaded))
        assert np.allclose(X, X_loaded, atol=1e-6)
        assert np.allclose(Y, Y_loaded, atol=1e-6)
        assert set(XYaxes) == set(XYaxes_loaded)
        assert load_training_data(save_file, validation_split=0.5)[2] is not None
        assert all(len(x) == 3 for x in load_training_data(save_file, n_images=3)[0])

    check((64, 64), 'YX', (16, 16), None)
    check((64, 64), 'YX', (16, 16), 'YX')
    check((64, 64), 'YX', (16, 16, 1), 'YXC')
    check((1, 64, 64), 'CYX', (16, 16), 'YX')
    check((1, 64, 64), 'CYX', (1, 16, 16), None)
    check((64, 3, 64), 'YCX', (3, 16, 16), 'CYX')
    check((64, 3, 64), 'YCX', (16, 16, 3), 'YXC')
147
+
148
+
149
+
150
def test_rawdata_from_folder(tmpdir):
    """RawData.from_folder yields every saved (X, Y) pair exactly once, correctly paired."""
    rng = np.random.RandomState(42)
    tmpdir = Path(str(tmpdir))

    n_images, img_size, img_axes = 3, (64, 64), 'YX'
    data = {'X': rng.uniform(size=(n_images,) + img_size).astype(np.float32),
            'Y': rng.uniform(size=(n_images,) + img_size).astype(np.float32)}

    # Source and target images share file names across the X/ and Y/ folders.
    for name, images in data.items():
        (tmpdir / name).mkdir(exist_ok=True)
        for i, img in enumerate(images):
            imsave(str(tmpdir / name / ('img_%02d.tif' % i)), img)

    raw_data = RawData.from_folder(str(tmpdir), ['X'], 'Y', img_axes)
    assert raw_data.size == n_images

    n_seen = 0
    for x, y, axes, mask in raw_data.generator():
        n_seen += 1
        assert mask is None
        assert axes == img_axes
        # Identify which saved source image this is; the target is expected
        # to be the same-named file (assumption: from_folder pairs by name).
        matches = [i for i, u in enumerate(data['X']) if np.allclose(x, u)]
        assert len(matches) == 1
        assert np.allclose(y, data['Y'][matches[0]])
    # The generator must neither drop nor duplicate samples.
    assert n_seen == n_images
@@ -0,0 +1,462 @@
1
+ from __future__ import print_function, unicode_literals, absolute_import, division
2
+ from six.moves import range, zip, map, reduce, filter
3
+
4
+ from itertools import product
5
+
6
+ # import warnings
7
+ import numpy as np
8
+ import pytest
9
+ from csbdeep.data import NoNormalizer, NoResizer
10
+ from csbdeep.internals.predict import tile_overlap
11
+ from csbdeep.utils.tf import IS_KERAS_3_PLUS, BACKEND as K
12
+
13
+ from csbdeep.internals.nets import receptive_field_unet
14
+ from csbdeep.models import Config, CARE, UpsamplingCARE, IsotropicCARE
15
+ from csbdeep.models import ProjectionConfig, ProjectionCARE
16
+ from csbdeep.utils import axes_dict
17
+ from csbdeep.utils.six import FileNotFoundError
18
+
19
+
20
+
21
def config_generator(cls=Config, **kwargs):
    """Yield a ``cls`` instance for every combination of keyword values.

    A list or tuple value is treated as a set of alternatives; any other value
    is used as a single fixed choice. ``axes`` must always be given (checked
    when the first configuration is requested).
    """
    assert 'axes' in kwargs
    names = list(kwargs)
    alternatives = [value if isinstance(value, (list, tuple)) else [value]
                    for value in kwargs.values()]
    for combination in product(*alternatives):
        yield cls(**{name: choice for name, choice in zip(names, combination)})
27
+
28
+
29
+
30
def test_config():
    """Config normalizes axes and rejects invalid axis layouts.

    Config appends (channels_last) or prepends (channels_first) a channel
    axis C when one is missing, and drops a leading sample axis S.
    """
    data_format = K.image_data_format()
    assert data_format in ('channels_first', 'channels_last')
    channels_last = data_format == 'channels_last'

    def _with_channel(axes):
        axes = axes.upper()
        if 'C' in axes:
            return axes
        return (axes + 'C') if channels_last else ('C' + axes)

    axes_list = [
        ('yx', _with_channel('YX')),
        ('ytx', _with_channel('YTX')),
        ('zyx', _with_channel('ZYX')),
        ('YX', _with_channel('YX')),
        ('XYZ', _with_channel('XYZ')),
        ('XYT', _with_channel('XYT')),
        ('SYX', _with_channel('YX')),
        ('SXYZ', _with_channel('XYZ')),
        ('SXTY', _with_channel('XTY')),
        (_with_channel('YX'), _with_channel('YX')),
        (_with_channel('XYZ'), _with_channel('XYZ')),
        (_with_channel('XTY'), _with_channel('XTY')),
        (_with_channel('SYX'), _with_channel('YX')),
        (_with_channel('STYX'), _with_channel('TYX')),
        (_with_channel('SXYZ'), _with_channel('XYZ')),
    ]
    for axes, axes_ref in axes_list:
        assert Config(axes).axes == axes_ref

    # Test the channel axis at the wrong end for this backend explicitly. The
    # original single pytest.raises around two constructors only showed that
    # one of them raised, and on one backend the second call never ran.
    for spatial in ('XY', 'XYZ', 'XTY'):
        correct = (spatial + 'C') if channels_last else ('C' + spatial)
        misplaced = ('C' + spatial) if channels_last else (spatial + 'C')
        assert Config(correct).axes == correct
        with pytest.raises(ValueError):
            Config(misplaced)

    # Z together with T, a lowercase t, and a sample axis that is not first
    # are all rejected.
    for axes in ('XYZT', 'tXYZ', 'XYS', 'XSYZ'):
        with pytest.raises(ValueError):
            Config(axes)
72
+
73
+
74
+
75
@pytest.mark.parametrize('config', config_generator(
    axes=['YX', 'ZYX'],
    n_channel_in=[1, 2],
    n_channel_out=[1, 2],
    probabilistic=[False, True],
    unet_residual=[False, True],
    unet_n_depth=[1, 2],
    # all other Config fields keep their defaults
))
def test_model_build_and_export(tmpdir, config):
    """CARE can be built, exported, re-created and reloaded for valid configs.

    Invalid configurations are expected to raise ValueError during building.
    """
    K.clear_session()
    basedir = str(tmpdir)

    def build():
        # Nothing is saved in basedir yet, so loading (config=None) fails.
        with pytest.raises(FileNotFoundError):
            CARE(None, basedir=basedir)

        # In-memory model: no basedir. Loading without a basedir is an error.
        CARE(config, name='model', basedir=None)
        with pytest.raises(ValueError):
            CARE(None, basedir=None)

        # TF SavedModel export is not available under Keras 3.
        if not IS_KERAS_3_PLUS:
            CARE(config, basedir=basedir).export_TF()
        else:
            with pytest.raises(NotImplementedError):
                CARE(config, basedir=basedir).export_TF()

        # Creating the same named model twice must emit at least one
        # UserWarning (presumably about reusing an existing folder).
        with pytest.warns(UserWarning):
            CARE(config, name='model', basedir=basedir)
            CARE(config, name='model', basedir=basedir)
        # Reload the model that was just written.
        CARE(None, name='model', basedir=basedir)

    if not config.is_valid():
        with pytest.raises(ValueError):
            build()
    else:
        build()
121
+
122
+
123
+
124
@pytest.mark.parametrize('config', [cfg for cfg in config_generator(
    axes=['YX', 'ZYX'],
    n_channel_in=[1, 2],
    n_channel_out=[1, 2],
    probabilistic=[False, True],
    unet_n_depth=[1],
    unet_kern_size=[3],
    unet_n_first=[4],
    unet_last_activation=['linear'],
    train_loss=['mae', 'laplace'],
    train_epochs=[2],
    train_steps_per_epoch=[2],
    train_batch_size=[2],
) if cfg.is_valid()])
def test_model_train(tmpdir, config):
    """Training a tiny CARE model on random data runs end to end."""
    rng = np.random.RandomState(42)
    K.clear_session()
    spatial = (32,) * config.n_dim
    X = rng.uniform(size=(4,) + spatial + (config.n_channel_in,))
    Y = rng.uniform(size=(4,) + spatial + (config.n_channel_out,))
    model = CARE(config, basedir=str(tmpdir))
    # The training data doubles as validation data.
    model.train(X, Y, (X, Y))
153
+
154
+
155
+
156
@pytest.mark.parametrize('config', filter(lambda c: c.is_valid(), config_generator(
    axes=['YX', 'ZYX'],
    n_channel_in=[1, 2],
    n_channel_out=[1, 2],
    probabilistic=[False, True],
    unet_n_depth=[2],
    unet_kern_size=[3],
    unet_n_first=[4],
    unet_last_activation=['linear'],
)))
def test_model_predict(tmpdir, config):
    """CARE.predict / predict_probabilistic return the expected output shape.

    Both inputs without a channel axis (single-channel models only) and inputs
    with the channel axis at a random position are tested.
    """
    rng = np.random.RandomState(42)
    normalizer, resizer = NoNormalizer(), NoResizer()

    K.clear_session()
    model = CARE(config, basedir=str(tmpdir))
    axes = config.axes

    def _predict(imdims, axes):
        img = rng.uniform(size=imdims)
        if config.probabilistic:
            prob = model.predict_probabilistic(img, axes, normalizer, resizer)
            mean, scale = prob.mean(), prob.scale()
            assert mean.shape == scale.shape
        else:
            mean = model.predict(img, axes, normalizer, resizer)

        # Build the expected shape on a copy so the caller's imdims is never
        # mutated (the previous version overwrote imdims[channel] in place).
        shape_out = list(imdims)
        if 'C' in axes:
            shape_out[axes_dict(axes)['C']] = config.n_channel_out
        elif config.n_channel_out > 1:
            # Without an input channel axis, multiple outputs add a trailing one.
            shape_out.append(config.n_channel_out)
        assert mean.shape == tuple(shape_out)

    # Spatial sizes must be divisible by the U-Net's total downsampling.
    imdims = list(rng.randint(20, 40, size=config.n_dim))
    div_n = 2**config.unet_n_depth
    imdims = [(d//div_n)*div_n for d in imdims]

    if config.n_channel_in == 1:
        _predict(imdims, axes=axes.replace('C', ''))

    channel = rng.randint(0, config.n_dim)
    imdims.insert(channel, config.n_channel_in)
    _axes = axes.replace('C', '')
    _axes = _axes[:channel] + 'C' + _axes[channel:]
    _predict(imdims, axes=_axes)
210
+
211
+
212
+
213
@pytest.mark.parametrize('config', [cfg for cfg in config_generator(
    axes=['YX', 'ZYX'],
    n_channel_in=[1, 2],
    n_channel_out=[1, 2],
    probabilistic=[False],
    unet_n_depth=[2, 3],
    unet_kern_size=[3, 5],
    unet_n_first=[4],
    unet_last_activation=['linear'],
) if cfg.is_valid()])
def test_model_predict_tiled(tmpdir, config):
    """Tiled prediction must closely match predicting the whole image at once."""
    rng = np.random.RandomState(42)
    normalizer, resizer = NoNormalizer(), NoResizer()

    K.clear_session()
    model = CARE(config, basedir=str(tmpdir))

    def compare(shape, axes, n_tiles):
        img = rng.uniform(size=shape)
        mean, scale = model._predict_mean_and_scale(img, axes, normalizer, resizer, n_tiles=None)
        mean_tiled, scale_tiled = model._predict_mean_and_scale(img, axes, normalizer, resizer, n_tiles=n_tiles)
        assert mean.shape == mean_tiled.shape
        if config.probabilistic:
            assert scale.shape == scale_tiled.shape
        assert np.max(np.abs(mean - mean_tiled)) < 1e-3

    spatial = list(rng.randint(50, 70, size=config.n_dim))
    if config.n_dim == 3:
        spatial[0] = 16  # keep one dimension small so the 3D case stays fast
    div_n = 2**config.unet_n_depth
    spatial = [(d//div_n)*div_n for d in spatial]
    imdims = [config.n_channel_in] + spatial
    axes = 'C' + config.axes.replace('C', '')

    invalid_tiles = (
        -1, 1.2,
        [1] + [1.2]*config.n_dim,
        [1]*config.n_dim,          # no entry for the channel axis
        [2] + [1]*config.n_dim,    # more than one tile along the channel axis
    )
    for n_tiles in invalid_tiles:
        with pytest.raises(ValueError):
            compare(imdims, axes, n_tiles)

    # Draw all random tilings first, then evaluate them.
    random_tilings = [list(rng.randint(1, 5, size=config.n_dim)) for _ in range(3)]
    for n_tiles in random_tilings:
        if config.n_channel_in == 1:
            compare(imdims[1:], axes[1:], n_tiles)
        compare(imdims, axes, [1] + n_tiles)

    # Legacy API: a scalar n_tiles (tiling only the largest dimension) warns.
    n_blocks = np.max(imdims) // div_n
    for n_tiles in (2, 5, n_blocks + 1):
        with pytest.warns(UserWarning):
            if config.n_channel_in == 1:
                compare(imdims[1:], axes[1:], n_tiles)
            compare(imdims, axes, n_tiles)
282
+
283
+
284
+
285
@pytest.mark.parametrize('n_depth', (1, 2, 3, 4, 5))
@pytest.mark.parametrize('kern_size', (3, 5))
@pytest.mark.parametrize('pool_size', (1, 2))
# TODO: (pool_size=2, kern_size=7, n_depth>=2): works on CPU, but fails on GPU! (at least in TF 2.3.1, 2.5.0, 2.6.0)
def test_tile_overlap(n_depth, kern_size, pool_size):
    """tile_overlap equals the larger of the U-Net's two receptive-field extents."""
    K.clear_session()
    # Pooling enlarges the receptive field, so use a bigger probe image.
    img_size = 160 if pool_size <= 1 else 1280
    field_x, field_y = receptive_field_unet(n_depth, kern_size, pool_size, n_dim=2, img_size=img_size)
    assert field_x == field_y
    field = field_x
    assert np.abs(field[0] - field[1]) < 10
    assert sum(field) + 1 < img_size
    assert max(field) == tile_overlap(n_depth, kern_size, pool_size)
299
+
300
+
301
+
302
@pytest.mark.parametrize('config', [cfg for cfg in config_generator(
    axes=['ZYX'],
    n_channel_in=[1, 2],
    n_channel_out=[1, 2],
    probabilistic=[False, True],
    unet_n_depth=[1],
    unet_kern_size=[3],
    unet_n_first=[4],
    unet_last_activation=['linear'],
) if cfg.is_valid()])
@pytest.mark.parametrize('factor', (2.5, 3))
def test_model_upsampling_predict(tmpdir, config, factor):
    """UpsamplingCARE scales the Z extent of its output by ``factor``."""
    rng = np.random.RandomState(42)

    K.clear_session()
    model = UpsamplingCARE(config, basedir=None)
    axes = config.axes

    def check(shape, img_axes):
        img = rng.uniform(size=shape)
        if not config.probabilistic:
            mean = model.predict(img, img_axes, factor, None, None)
        else:
            prob = model.predict_probabilistic(img, img_axes, factor, None, None)
            mean, scale = prob.mean(), prob.scale()
            assert mean.shape == scale.shape
        z = axes_dict(img_axes)['Z']
        assert shape[z]*factor == mean.shape[z]

    raw_sizes = rng.randint(20, 40, size=config.n_dim)
    div_n = 2**(config.unet_n_depth + 1)
    shape = [(d//div_n)*div_n for d in raw_sizes]

    spatial_axes = axes.replace('C', '')
    if config.n_channel_in == 1:
        check(shape, spatial_axes)

    channel = rng.randint(0, config.n_dim)
    shape.insert(channel, config.n_channel_in)
    check(shape, spatial_axes[:channel] + 'C' + spatial_axes[channel:])
346
+
347
+
348
+
349
@pytest.mark.parametrize('config', [cfg for cfg in config_generator(
    axes=['YX'],
    n_channel_in=[1, 2],
    n_channel_out=[1, 2],
    probabilistic=[False, True],
    unet_n_depth=[1],
    unet_kern_size=[3],
    unet_n_first=[4],
    unet_last_activation=['linear'],
) if cfg.is_valid()])
@pytest.mark.parametrize('factor', (2.5, 3))
def test_model_isotropic_predict(tmpdir, config, factor):
    """IsotropicCARE (2D model, 3D input) scales the Z extent by ``factor``."""
    rng = np.random.RandomState(42)

    K.clear_session()
    model = IsotropicCARE(config, basedir=None)
    # The model is 2D; the input stack gets an extra Z axis.
    axes = config.axes + 'Z'

    def check(shape, img_axes):
        img = rng.uniform(size=shape)
        if not config.probabilistic:
            mean = model.predict(img, img_axes, factor, None, None)
        else:
            prob = model.predict_probabilistic(img, img_axes, factor, None, None)
            mean, scale = prob.mean(), prob.scale()
            assert mean.shape == scale.shape
        z = axes_dict(img_axes)['Z']
        assert shape[z]*factor == mean.shape[z]

    raw_sizes = rng.randint(20, 40, size=config.n_dim + 1)
    div_n = 2**(config.unet_n_depth + 1)
    shape = [(d//div_n)*div_n for d in raw_sizes]

    spatial_axes = axes.replace('C', '')
    if config.n_channel_in == 1:
        check(shape, spatial_axes)

    channel = rng.randint(0, config.n_dim + 1)
    shape.insert(channel, config.n_channel_in)
    check(shape, spatial_axes[:channel] + 'C' + spatial_axes[channel:])
393
+
394
+
395
+
396
@pytest.mark.parametrize('config', [cfg for cfg in config_generator(
    ProjectionConfig,
    axes=['ZYX'],
    n_channel_in=[1, 2],
    n_channel_out=[1, 2],
    probabilistic=[False, True],
    unet_n_depth=[1],
    unet_kern_size=[3],
    unet_n_first=[4],
    unet_last_activation=['linear'],
    proj_n_depth=[2, 4],
) if cfg.is_valid()])
def test_model_projection_predict(tmpdir, config):
    """ProjectionCARE removes the projection axis and supports X/Y tiling only."""
    rng = np.random.RandomState(42)

    K.clear_session()
    model = ProjectionCARE(config, basedir=None)
    axes = config.axes
    proj_axis = model.proj_params.axis

    def check(shape, img_axes):
        img = rng.uniform(size=shape)
        index = axes_dict(img_axes)

        if not config.probabilistic:
            mean = model.predict(img, img_axes, None, None)
        else:
            prob = model.predict_probabilistic(img, img_axes, None, None)
            mean, scale = prob.mean(), prob.scale()
            assert mean.shape == scale.shape

        # Tiling along X and Y must reproduce the untiled prediction.
        n_tiles = [1]*len(img_axes)
        n_tiles[index['X']] = 3
        n_tiles[index['Y']] = 2
        mean_tiled = model.predict(img, img_axes, None, None, n_tiles=n_tiles)
        assert np.max(np.abs(mean - mean_tiled)) < 1e-3

        # Tiling along the projection axis is rejected.
        n_tiles[index[proj_axis]] = 2
        with pytest.raises(ValueError):
            model.predict(img, img_axes, None, None, n_tiles=n_tiles)

        expected = list(shape)
        if 'C' in img_axes:
            expected[index['C']] = config.n_channel_out
        elif config.n_channel_out > 1:
            expected.append(config.n_channel_out)
        del expected[index[proj_axis]]
        assert tuple(expected) == mean.shape

    raw_sizes = rng.randint(30, 50, size=config.n_dim)
    # Round every size down to what the model requires for that axis.
    shape = [(d//div)*div for d, div in zip(raw_sizes, model._axes_div_by(axes))]

    spatial_axes = axes.replace('C', '')
    if config.n_channel_in == 1:
        check(shape, spatial_axes)

    channel = rng.randint(0, config.n_dim)
    shape.insert(channel, config.n_channel_in)
    check(shape, spatial_axes[:channel] + 'C' + spatial_axes[channel:])