senoquant 1.0.0b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- senoquant/__init__.py +6 -0
- senoquant/_reader.py +7 -0
- senoquant/_widget.py +33 -0
- senoquant/napari.yaml +83 -0
- senoquant/reader/__init__.py +5 -0
- senoquant/reader/core.py +369 -0
- senoquant/tabs/__init__.py +15 -0
- senoquant/tabs/batch/__init__.py +10 -0
- senoquant/tabs/batch/backend.py +641 -0
- senoquant/tabs/batch/config.py +270 -0
- senoquant/tabs/batch/frontend.py +1283 -0
- senoquant/tabs/batch/io.py +326 -0
- senoquant/tabs/batch/layers.py +86 -0
- senoquant/tabs/quantification/__init__.py +1 -0
- senoquant/tabs/quantification/backend.py +228 -0
- senoquant/tabs/quantification/features/__init__.py +80 -0
- senoquant/tabs/quantification/features/base.py +142 -0
- senoquant/tabs/quantification/features/marker/__init__.py +5 -0
- senoquant/tabs/quantification/features/marker/config.py +69 -0
- senoquant/tabs/quantification/features/marker/dialog.py +437 -0
- senoquant/tabs/quantification/features/marker/export.py +879 -0
- senoquant/tabs/quantification/features/marker/feature.py +119 -0
- senoquant/tabs/quantification/features/marker/morphology.py +285 -0
- senoquant/tabs/quantification/features/marker/rows.py +654 -0
- senoquant/tabs/quantification/features/marker/thresholding.py +46 -0
- senoquant/tabs/quantification/features/roi.py +346 -0
- senoquant/tabs/quantification/features/spots/__init__.py +5 -0
- senoquant/tabs/quantification/features/spots/config.py +62 -0
- senoquant/tabs/quantification/features/spots/dialog.py +477 -0
- senoquant/tabs/quantification/features/spots/export.py +1292 -0
- senoquant/tabs/quantification/features/spots/feature.py +112 -0
- senoquant/tabs/quantification/features/spots/morphology.py +279 -0
- senoquant/tabs/quantification/features/spots/rows.py +241 -0
- senoquant/tabs/quantification/frontend.py +815 -0
- senoquant/tabs/segmentation/__init__.py +1 -0
- senoquant/tabs/segmentation/backend.py +131 -0
- senoquant/tabs/segmentation/frontend.py +1009 -0
- senoquant/tabs/segmentation/models/__init__.py +5 -0
- senoquant/tabs/segmentation/models/base.py +146 -0
- senoquant/tabs/segmentation/models/cpsam/details.json +65 -0
- senoquant/tabs/segmentation/models/cpsam/model.py +150 -0
- senoquant/tabs/segmentation/models/default_2d/details.json +69 -0
- senoquant/tabs/segmentation/models/default_2d/model.py +664 -0
- senoquant/tabs/segmentation/models/default_3d/details.json +69 -0
- senoquant/tabs/segmentation/models/default_3d/model.py +682 -0
- senoquant/tabs/segmentation/models/hf.py +71 -0
- senoquant/tabs/segmentation/models/nuclear_dilation/__init__.py +1 -0
- senoquant/tabs/segmentation/models/nuclear_dilation/details.json +26 -0
- senoquant/tabs/segmentation/models/nuclear_dilation/model.py +96 -0
- senoquant/tabs/segmentation/models/perinuclear_rings/__init__.py +1 -0
- senoquant/tabs/segmentation/models/perinuclear_rings/details.json +34 -0
- senoquant/tabs/segmentation/models/perinuclear_rings/model.py +132 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/__init__.py +2 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/__init__.py +3 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/__init__.py +6 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/generate.py +470 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/prepare.py +273 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/rawdata.py +112 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/transform.py +384 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/__init__.py +0 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/blocks.py +184 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/losses.py +79 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/nets.py +165 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/predict.py +467 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/probability.py +67 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/train.py +148 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/io/__init__.py +163 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/__init__.py +52 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/base_model.py +329 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_isotropic.py +160 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_projection.py +178 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_standard.py +446 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_upsampling.py +54 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/config.py +254 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/pretrained.py +119 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/scripts/__init__.py +0 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/scripts/care_predict.py +180 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/__init__.py +5 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/plot_utils.py +159 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/six.py +18 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/tf.py +644 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/utils.py +272 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/version.py +1 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/docs/source/conf.py +368 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/setup.py +68 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_datagen.py +169 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_models.py +462 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_utils.py +166 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tools/create_zip_contents.py +34 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/__init__.py +30 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/big.py +624 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/bioimageio_utils.py +494 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/data/__init__.py +39 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/__init__.py +10 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/geom2d.py +215 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/geom3d.py +349 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/matching.py +483 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/__init__.py +28 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/base.py +1217 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/model2d.py +594 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/model3d.py +696 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/nms.py +384 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/__init__.py +2 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/plot.py +74 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/render.py +298 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/rays3d.py +373 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/sample_patches.py +65 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/__init__.py +0 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/predict2d.py +90 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/predict3d.py +93 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/utils.py +408 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/version.py +1 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/__init__.py +45 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/__init__.py +17 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/cli.py +55 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/core.py +285 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/__init__.py +15 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/cli.py +36 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/divisibility.py +193 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/probe.py +100 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/receptive_field.py +182 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/rf_cli.py +48 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/valid_sizes.py +278 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/post/__init__.py +8 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/post/core.py +157 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/pre/__init__.py +17 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/pre/core.py +226 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/predict/__init__.py +5 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/predict/core.py +401 -0
- senoquant/tabs/settings/__init__.py +1 -0
- senoquant/tabs/settings/backend.py +29 -0
- senoquant/tabs/settings/frontend.py +19 -0
- senoquant/tabs/spots/__init__.py +1 -0
- senoquant/tabs/spots/backend.py +139 -0
- senoquant/tabs/spots/frontend.py +800 -0
- senoquant/tabs/spots/models/__init__.py +5 -0
- senoquant/tabs/spots/models/base.py +94 -0
- senoquant/tabs/spots/models/rmp/details.json +61 -0
- senoquant/tabs/spots/models/rmp/model.py +499 -0
- senoquant/tabs/spots/models/udwt/details.json +103 -0
- senoquant/tabs/spots/models/udwt/model.py +482 -0
- senoquant/utils.py +25 -0
- senoquant-1.0.0b1.dist-info/METADATA +193 -0
- senoquant-1.0.0b1.dist-info/RECORD +148 -0
- senoquant-1.0.0b1.dist-info/WHEEL +5 -0
- senoquant-1.0.0b1.dist-info/entry_points.txt +2 -0
- senoquant-1.0.0b1.dist-info/licenses/LICENSE +28 -0
- senoquant-1.0.0b1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,169 @@
|
|
|
1
|
+
from __future__ import print_function, unicode_literals, absolute_import, division
|
|
2
|
+
from six.moves import range, zip, map, reduce, filter
|
|
3
|
+
|
|
4
|
+
# import warnings
|
|
5
|
+
import numpy as np
|
|
6
|
+
import pytest
|
|
7
|
+
from tifffile import imread
|
|
8
|
+
try:
|
|
9
|
+
from tifffile import imwrite as imsave
|
|
10
|
+
except ImportError:
|
|
11
|
+
from tifffile import imsave
|
|
12
|
+
from csbdeep.data import RawData, create_patches, create_patches_reduced_target
|
|
13
|
+
from csbdeep.io import load_training_data
|
|
14
|
+
from csbdeep.utils import Path, axes_dict, move_image_axes, backend_channels_last
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def test_create_patches():
    """Run ``create_patches`` on 2D and 3D images with several patch layouts.

    Each synthetic target is an affine function of its source image
    (``y = 5 + 3*x``); the test expects the returned source and target
    patches to agree within ``1e-6``.
    NOTE(review): that presumably relies on the default normalization
    applied by ``create_patches`` - confirm against csbdeep.
    """
    rng = np.random.RandomState(42)
    n_images, n_patches_per_image = 2, 4

    def get_data(n_images, axes, shape):
        # Lazily yields (source, target, axes, mask); the mask is always None.
        def _gen():
            for _ in range(n_images):
                source = rng.uniform(size=shape)
                target = 5 + 3*source
                yield source, target, axes, None
        return RawData(_gen, n_images, '')

    def _create(img_size, img_axes, patch_size, patch_axes):
        X, Y, XYaxes = create_patches(
            raw_data=get_data(n_images, img_axes, img_size),
            patch_size=patch_size,
            patch_axes=patch_axes,
            n_patches_per_image=n_patches_per_image,
        )
        assert len(X) == n_images*n_patches_per_image
        assert np.allclose(X, Y, atol=1e-6)
        if patch_axes is None:
            return
        # Expected layout: sample axis, channel axis, then the requested axes.
        assert XYaxes == 'SC'+patch_axes.replace('C','')

    # Keep this order: every case draws from the same seeded generator.
    cases = (
        ((128,128), 'YX', (32,32), 'YX'),
        ((128,128), 'YX', (32,32), None),
        ((128,128), 'YX', (32,32), 'XY'),
        ((128,128), 'YX', (32,32,1), 'XYC'),
        ((32,48,32), 'ZYX', (16,32,8), None),
        ((32,48,32), 'ZYX', (16,32,8), 'ZYX'),
        ((32,48,32), 'ZYX', (16,32,8), 'YXZ'),
        ((32,48,32), 'ZYX', (16,32,1,8), 'YXCZ'),
    )
    for case in cases:
        _create(*case)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def test_create_patches_reduced_target():
    """Run ``create_patches_reduced_target`` with randomly reduced axes.

    For every case a random non-empty, proper subset of the image axes is
    chosen as reduction axes and the target is the mean of the source over
    those axes. The returned source patches, averaged over the same axes,
    must match the returned target patches to within ``1e-5``.
    """
    rng = np.random.RandomState(42)
    n_images, n_patches_per_image = 2, 4

    def get_data(n_images, axes, shape):
        # Between 1 and len(axes)-1 axes are reduced, drawn without replacement.
        n_reduced = rng.choice(len(axes)-1)+1
        red_axes = ''.join(rng.choice(tuple(axes),n_reduced,replace=False))
        keepdims = bool(rng.choice((True,False)))
        reduce_over = lambda: tuple(axes_dict(axes)[a] for a in red_axes)

        def _gen():
            for _ in range(n_images):
                source = rng.uniform(size=shape)
                target = np.mean(source, axis=reduce_over(), keepdims=keepdims)
                yield source, target, axes, None

        return RawData(_gen, n_images, ''), red_axes, keepdims

    def _identity_normalization(patches_x, patches_y, *args):
        # Leave the patches untouched so the mean comparison below is exact.
        return patches_x, patches_y

    def _create(red_none, img_size, img_axes, patch_size, patch_axes):
        raw_data, red_axes, keepdims = get_data(n_images, img_axes, img_size)

        # Along reduced axes the patch size becomes None or the full image size.
        layout = img_axes if patch_axes is None else patch_axes
        patch_size = list(patch_size)
        for a in red_axes:
            if red_none:
                patch_size[axes_dict(layout)[a]] = None
            else:
                patch_size[axes_dict(layout)[a]] = img_size[axes_dict(img_axes)[a]]

        # Computed before the call, exactly where the original evaluated it,
        # so the draws from the seeded generator happen in the same order.
        if keepdims:
            target_axes = rng.choice((None,img_axes))
        else:
            target_axes = ''.join(a for a in img_axes if a not in red_axes)

        X, Y, XYaxes = create_patches_reduced_target(
            raw_data=raw_data,
            patch_size=patch_size,
            patch_axes=patch_axes,
            n_patches_per_image=n_patches_per_image,
            reduction_axes=red_axes,
            target_axes=target_axes,
            normalization=_identity_normalization,
            verbose=False,
        )
        assert len(X) == n_images*n_patches_per_image
        reduced = np.mean(X, axis=tuple(axes_dict(XYaxes)[a] for a in red_axes), keepdims=True)
        assert np.max(np.abs(reduced-Y)) < 1e-5

    # Keep this order: every case draws from the same seeded generator.
    cases = (
        ((128,128), 'YX', (32,32), 'YX'),
        ((128,128), 'YX', (32,32), None),
        ((128,128), 'YX', (32,32), 'XY'),
        ((128,128), 'YX', (32,32,1), 'XYC'),
        ((32,48,32), 'ZYX', (16,32,8), None),
        ((32,48,32), 'ZYX', (16,32,8), 'ZYX'),
        ((32,48,32), 'ZYX', (16,32,8), 'YXZ'),
        ((32,48,32), 'ZYX', (16,32,1,8), 'YXCZ'),
        ((128,2,128), 'YCX', (32,2,32), 'YCX'),
        ((3,128,128), 'CYX', (3,32,32), None),
        ((128,128,4), 'YXC', (4,32,32), 'CXY'),
        ((128,128,5), 'YXC', (32,32,5), 'XYC'),
        ((32,48,2,32), 'ZYCX', (16,32,2,8), None),
        ((32,3,48,32), 'ZCYX', (3,16,32,8), 'CZYX'),
        ((4,32,48,32), 'CZYX', (16,32,8,4), 'YXZC'),
        ((32,48,32,2), 'ZYXC', (16,32,2,8), 'YXCZ'),
    )
    for b in (True,False):
        for case in cases:
            _create(b, *case)
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def test_create_save_and_load(tmpdir):
    """Round-trip patches through ``create_patches(save_file=...)`` and
    ``load_training_data``.

    Patches are written to ``data.npz`` inside ``tmpdir``, loaded back, moved
    to the axes order returned by ``create_patches`` and compared with the
    in-memory patches. The ``validation_split`` and ``n_images`` options of
    ``load_training_data`` are exercised as well.

    Parameters
    ----------
    tmpdir
        pytest temporary-directory fixture.
    """
    rng = np.random.RandomState(42)
    tmpdir = Path(str(tmpdir))
    save_file = str(tmpdir / 'data.npz')

    n_images, n_patches_per_image = 2, 4

    def _create(img_size, img_axes, patch_size, patch_axes):
        U, V = (rng.uniform(size=(n_images,)+img_size) for _ in range(2))
        X, Y, XYaxes = create_patches(
            raw_data=RawData.from_arrays(U, V, img_axes),
            patch_size=patch_size,
            patch_axes=patch_axes,
            n_patches_per_image=n_patches_per_image,
            save_file=save_file,
        )
        (_X, _Y), val_data, _XYaxes = load_training_data(save_file, verbose=True)
        assert val_data is None

        # backend_channels_last is a helper function and has to be called:
        # the bare function object is always truthy, which made the original
        # expression pick index -1 regardless of the backend data format.
        channel_index = -1 if backend_channels_last() else 1
        assert _XYaxes[channel_index] == 'C'

        _X, _Y = (move_image_axes(u, fr=_XYaxes, to=XYaxes) for u in (_X, _Y))
        assert np.allclose(X, _X, atol=1e-6)
        assert np.allclose(Y, _Y, atol=1e-6)
        assert set(XYaxes) == set(_XYaxes)

        # The second element of the result is the validation data (unpacked
        # as val_data above). The original only inspected element [2], the
        # axes, which says nothing about whether a split was produced.
        _, split_val_data, split_axes = load_training_data(save_file, validation_split=0.5)
        assert split_axes is not None
        assert split_val_data is not None

        assert all(len(x)==3 for x in load_training_data(save_file, n_images=3)[0])

    _create((  64,64),  'YX', (16,16  ), None)
    _create((  64,64),  'YX', (16,16  ), 'YX')
    _create((  64,64),  'YX', (16,16,1), 'YXC')
    _create((1,64,64), 'CYX', (  16,16), 'YX')
    _create((1,64,64), 'CYX', (1,16,16), None)
    _create((64,3,64), 'YCX', (3,16,16), 'CYX')
    _create((64,3,64), 'YCX', (16,16,3), 'YXC')
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def test_rawdata_from_folder(tmpdir):
    """Write random images into ``X`` and ``Y`` sub-folders of ``tmpdir`` and
    read them back through ``RawData.from_folder``.

    Every yielded source/target must match one of the written images, the
    reported axes must equal the requested ones and no mask is expected.
    """
    rng = np.random.RandomState(42)
    tmpdir = Path(str(tmpdir))

    n_images, img_size, img_axes = 3, (64,64), 'YX'
    stack_shape = (n_images,)+img_size
    # 'X' is drawn before 'Y', matching the original draw order.
    data = {
        'X': rng.uniform(size=stack_shape).astype(np.float32),
        'Y': rng.uniform(size=stack_shape).astype(np.float32),
    }

    for name, images in data.items():
        folder = tmpdir/name
        folder.mkdir(exist_ok=True)
        for index, img in enumerate(images):
            imsave(str(folder/('img_%02d.tif'%index)), img)

    raw_data = RawData.from_folder(str(tmpdir), ['X'], 'Y', img_axes)
    assert raw_data.size == n_images
    for x, y, axes, mask in raw_data.generator():
        assert mask is None
        assert axes == img_axes
        # Only membership is checked, not the order in which files are read.
        assert any(np.allclose(x, u) for u in data['X'])
        assert any(np.allclose(y, u) for u in data['Y'])
|
|
@@ -0,0 +1,462 @@
|
|
|
1
|
+
from __future__ import print_function, unicode_literals, absolute_import, division
|
|
2
|
+
from six.moves import range, zip, map, reduce, filter
|
|
3
|
+
|
|
4
|
+
from itertools import product
|
|
5
|
+
|
|
6
|
+
# import warnings
|
|
7
|
+
import numpy as np
|
|
8
|
+
import pytest
|
|
9
|
+
from csbdeep.data import NoNormalizer, NoResizer
|
|
10
|
+
from csbdeep.internals.predict import tile_overlap
|
|
11
|
+
from csbdeep.utils.tf import IS_KERAS_3_PLUS, BACKEND as K
|
|
12
|
+
|
|
13
|
+
from csbdeep.internals.nets import receptive_field_unet
|
|
14
|
+
from csbdeep.models import Config, CARE, UpsamplingCARE, IsotropicCARE
|
|
15
|
+
from csbdeep.models import ProjectionConfig, ProjectionCARE
|
|
16
|
+
from csbdeep.utils import axes_dict
|
|
17
|
+
from csbdeep.utils.six import FileNotFoundError
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def config_generator(cls=Config, **kwargs):
    """Yield one ``cls`` instance per combination of the given option values.

    Each keyword argument names a configuration option. A list or tuple is
    treated as the set of values to try for that option; any other value is
    treated as a single fixed choice. The Cartesian product of all choices
    is generated, and ``cls(**combination)`` is yielded for each one.

    ``axes`` must be among the keyword arguments.
    """
    assert 'axes' in kwargs
    names = list(kwargs)
    choices = []
    for value in kwargs.values():
        if isinstance(value, (list, tuple)):
            choices.append(value)
        else:
            choices.append([value])
    for combination in product(*choices):
        yield cls(**dict(zip(names, combination)))
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def test_config():
    """Check how ``Config`` normalizes valid axes strings and rejects bad ones."""
    data_format = K.image_data_format()
    assert data_format in ('channels_first','channels_last')

    def _with_channel(axes):
        # Upper-case the axes and add 'C' where the backend data format
        # expects it, unless a channel axis is already present.
        axes = axes.upper()
        if 'C' in axes:
            return axes
        if K.image_data_format() == 'channels_last':
            return axes+'C'
        return 'C'+axes

    # (axes given to Config, spatial axes expected back once 'C' is added)
    raw_cases = [
        ('yx', 'YX'),
        ('ytx', 'YTX'),
        ('zyx', 'ZYX'),
        ('YX', 'YX'),
        ('XYZ', 'XYZ'),
        ('XYT', 'XYT'),
        ('SYX', 'YX'),
        ('SXYZ', 'XYZ'),
        ('SXTY', 'XTY'),
        (_with_channel('YX'), 'YX'),
        (_with_channel('XYZ'), 'XYZ'),
        (_with_channel('XTY'), 'XTY'),
        (_with_channel('SYX'), 'YX'),
        (_with_channel('STYX'), 'TYX'),
        (_with_channel('SXYZ'), 'XYZ'),
    ]
    axes_list = [(given, _with_channel(expected)) for given, expected in raw_cases]

    for axes, axes_ref in axes_list:
        assert Config(axes).axes == axes_ref

    # Both calls of a pair share one pytest.raises block, so the block is
    # satisfied as soon as either of them raises ValueError.
    # NOTE(review): presumably which of the two is rejected depends on the
    # backend data format - confirm against csbdeep's Config.
    for channel_last, channel_first in (('XYC','CXY'), ('XYZC','CXYZ'), ('XTYC','CXTY')):
        with pytest.raises(ValueError):
            Config(channel_last)
            Config(channel_first)

    for invalid_axes in ('XYZT', 'tXYZ', 'XYS', 'XSYZ'):
        with pytest.raises(ValueError):
            Config(invalid_axes)
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
# Options not listed here are not passed and so stay at Config's defaults.
@pytest.mark.parametrize('config', config_generator(
    axes=['YX','ZYX'],
    n_channel_in=[1,2],
    n_channel_out=[1,2],
    probabilistic=[False,True],
    unet_residual=[False,True],
    unet_n_depth=[1,2],
))
def test_model_build_and_export(tmpdir,config):
    """Build ``CARE`` models in several ways and try to export them.

    Configurations are not filtered for validity: when ``config.is_valid()``
    is false the whole build sequence is expected to raise ``ValueError``.
    """
    K.clear_session()
    basedir = str(tmpdir)

    def _build():
        # No config and nothing stored in the base directory to load.
        with pytest.raises(FileNotFoundError):
            CARE(None,basedir=basedir)

        # A model without a base directory is fine as long as a config is given.
        CARE(config,name='model',basedir=None)
        with pytest.raises(ValueError):
            CARE(None,basedir=None)

        if IS_KERAS_3_PLUS:
            with pytest.raises(NotImplementedError):
                CARE(config,basedir=basedir).export_TF()
        else:
            CARE(config,basedir=basedir).export_TF()

        # NOTE(review): indentation was lost in this listing. Both calls are
        # kept inside pytest.warns on the assumption that the UserWarning is
        # raised when the same model folder is created a second time.
        with pytest.warns(UserWarning):
            CARE(config,name='model',basedir=basedir)
            CARE(config,name='model',basedir=basedir)
        # Re-open the model that now exists on disk without passing a config.
        CARE(None,name='model',basedir=basedir)

    if not config.is_valid():
        with pytest.raises(ValueError):
            _build()
        return
    _build()
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
# Only valid configurations are used; options not listed stay at their defaults.
@pytest.mark.parametrize('config', filter(lambda c: c.is_valid(), config_generator(
    axes=['YX','ZYX'],
    n_channel_in=[1,2],
    n_channel_out=[1,2],
    probabilistic=[False,True],
    unet_n_depth=[1],

    unet_kern_size=[3],
    unet_n_first=[4],
    unet_last_activation=['linear'],

    train_loss=['mae','laplace'],
    train_epochs=[2],
    train_steps_per_epoch=[2],
    train_batch_size=[2],
)))
def test_model_train(tmpdir,config):
    """Train a small ``CARE`` model for a few steps on random data.

    The same four random samples are used as training and validation data;
    the test only checks that training runs without raising.
    """
    rng = np.random.RandomState(42)
    K.clear_session()

    # Four samples with 32 pixels per spatial axis; the channel count is the
    # trailing dimension of each array.
    spatial = (32,)*config.n_dim
    X = rng.uniform(size=(4,)+spatial+(config.n_channel_in,))
    Y = rng.uniform(size=(4,)+spatial+(config.n_channel_out,))

    model = CARE(config,basedir=str(tmpdir))
    model.train(X,Y,(X,Y))
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
# Only valid configurations are used; options not listed stay at their defaults.
@pytest.mark.parametrize('config', filter(lambda c: c.is_valid(), config_generator(
    axes=['YX','ZYX'],
    n_channel_in=[1,2],
    n_channel_out=[1,2],
    probabilistic=[False,True],
    unet_n_depth=[2],

    unet_kern_size=[3],
    unet_n_first=[4],
    unet_last_activation=['linear'],
)))
def test_model_predict(tmpdir,config):
    """Check the output shape of ``CARE`` predictions on random images.

    Images are predicted without a channel axis (single-channel inputs only)
    and with a channel axis inserted at a random position.
    """
    rng = np.random.RandomState(42)
    normalizer, resizer = NoNormalizer(), NoResizer()

    K.clear_session()
    model = CARE(config,basedir=str(tmpdir))
    axes = config.axes

    def _predict(imdims,axes):
        img = rng.uniform(size=imdims)
        if config.probabilistic:
            prob = model.predict_probabilistic(img, axes, normalizer, resizer)
            mean, scale = prob.mean(), prob.scale()
            assert mean.shape == scale.shape
        else:
            mean = model.predict(img, axes, normalizer, resizer)

        if 'C' in axes:
            # Output keeps the input layout, with n_channel_out channels.
            # This writes into the caller's list, as the original did.
            imdims[axes_dict(axes)['C']] = config.n_channel_out
            assert mean.shape == tuple(imdims)
        elif config.n_channel_out == 1:
            assert mean.shape == img.shape
        else:
            # Several output channels are appended as a trailing axis.
            assert mean.shape == img.shape + (config.n_channel_out,)

    # Round each random size down to a multiple of 2**unet_n_depth.
    div_n = 2**config.unet_n_depth
    imdims = [(d//div_n)*div_n for d in list(rng.randint(20,40,size=config.n_dim))]

    if config.n_channel_in == 1:
        _predict(imdims,axes=axes.replace('C',''))

    channel = rng.randint(0,config.n_dim)
    imdims.insert(channel,config.n_channel_in)
    spatial_axes = axes.replace('C','')
    _axes = spatial_axes[:channel]+'C'+spatial_axes[channel:]
    _predict(imdims,axes=_axes)
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
# Only valid configurations are used; options not listed stay at their defaults.
@pytest.mark.parametrize('config', filter(lambda c: c.is_valid(), config_generator(
    axes=['YX','ZYX'],
    n_channel_in=[1,2],
    n_channel_out=[1,2],
    probabilistic=[False],
    unet_n_depth=[2,3],
    unet_kern_size=[3,5],

    unet_n_first=[4],
    unet_last_activation=['linear'],
)))
def test_model_predict_tiled(tmpdir,config):
    """
    Test that tiled prediction yields the same
    or similar result as compared to predicting
    the whole image at once.
    """
    rng = np.random.RandomState(42)
    normalizer, resizer = NoNormalizer(), NoResizer()

    K.clear_session()
    model = CARE(config,basedir=str(tmpdir))

    def _predict(imdims,axes,n_tiles):
        img = rng.uniform(size=imdims)
        # Reference prediction first (n_tiles=None), then the tiled one.
        mean, scale = model._predict_mean_and_scale(
            img, axes, normalizer, resizer, n_tiles=None)
        mean_tiled, scale_tiled = model._predict_mean_and_scale(
            img, axes, normalizer, resizer, n_tiles=n_tiles)
        assert mean.shape == mean_tiled.shape
        if config.probabilistic:
            assert scale.shape == scale_tiled.shape
        assert np.max(np.abs(mean-mean_tiled)) < 1e-3
        return mean, mean_tiled

    imdims = list(rng.randint(50,70,size=config.n_dim))
    if config.n_dim == 3:
        imdims[0] = 16 # make one dim small, otherwise test takes too long
    # Round every size down to a multiple of 2**unet_n_depth.
    div_n = 2**config.unet_n_depth
    imdims = [(d//div_n)*div_n for d in imdims]

    # The channel axis goes first, for the image shape and the axes string.
    imdims.insert(0,config.n_channel_in)
    axes = 'C'+config.axes.replace('C','')

    invalid_n_tiles = (
        -1,
        1.2,
        [1]+[1.2]*config.n_dim,
        [1]*config.n_dim,     # missing value for channel axis
        [2]+[1]*config.n_dim, # >1 tiles for channel axis
    )
    for n_tiles in invalid_n_tiles:
        with pytest.raises(ValueError):
            _predict(imdims,axes,n_tiles)

    # All three tilings are drawn up front, before any prediction consumes
    # further random numbers.
    random_tilings = [list(rng.randint(1,5,size=config.n_dim)) for _ in range(3)]
    for n_tiles in random_tilings:
        if config.n_channel_in == 1:
            _predict(imdims[1:],axes[1:],n_tiles)
        _predict(imdims,axes,[1]+n_tiles)

    # legacy api: tile only largest dimension
    # NOTE(review): indentation was lost in this listing. The unconditional
    # call is kept inside pytest.warns, since for multi-channel inputs the
    # block would otherwise be empty and could not produce the warning.
    n_blocks = np.max(imdims) // div_n
    for n_tiles in (2,5,n_blocks+1):
        with pytest.warns(UserWarning):
            if config.n_channel_in == 1:
                _predict(imdims[1:],axes[1:],n_tiles)
            _predict(imdims,axes,n_tiles)
|
|
282
|
+
|
|
283
|
+
|
|
284
|
+
|
|
285
|
+
@pytest.mark.parametrize('n_depth', (1,2,3,4,5))
@pytest.mark.parametrize('kern_size', (3,5))
@pytest.mark.parametrize('pool_size', (1,2))
# TODO: (pool_size=2, kern_size=7, n_depth>=2): works on CPU, but fails on GPU! (at least in TF 2.3.1, 2.5.0, 2.6.0)
def test_tile_overlap(n_depth, kern_size, pool_size):
    """Compare ``tile_overlap`` with the receptive field reported by
    ``receptive_field_unet`` for a 2D U-Net.

    The receptive field must be identical for both image axes, its two
    entries must differ by less than 10, it must fit inside the probe image,
    and its larger entry must equal ``tile_overlap`` for the same settings.
    """
    K.clear_session()
    if pool_size > 1:
        img_size = 1280
    else:
        img_size = 160

    rf_x, rf_y = receptive_field_unet(
        n_depth, kern_size, pool_size, n_dim=2, img_size=img_size)
    assert rf_x == rf_y

    first, second = rf_x[0], rf_x[1]
    assert np.abs(first-second) < 10
    assert sum(rf_x)+1 < img_size
    assert max(rf_x) == tile_overlap(n_depth,kern_size,pool_size)
|
|
298
|
+
# print("receptive field of n_depth %d and kernel size %d: %s"%(n_depth,kern_size,rf));
|
|
299
|
+
|
|
300
|
+
|
|
301
|
+
|
|
302
|
+
# Only valid configurations are used; options not listed stay at their defaults.
@pytest.mark.parametrize('config', filter(lambda c: c.is_valid(), config_generator(
    axes=['ZYX'],
    n_channel_in=[1,2],
    n_channel_out=[1,2],
    probabilistic=[False,True],
    unet_n_depth=[1],

    unet_kern_size=[3],
    unet_n_first=[4],
    unet_last_activation=['linear'],
)))
@pytest.mark.parametrize('factor', (2.5,3))
def test_model_upsampling_predict(tmpdir,config,factor):
    """Check that ``UpsamplingCARE`` scales the Z axis of its output by
    ``factor``.

    Images are predicted without a channel axis (single-channel inputs only)
    and with a channel axis inserted at a random position.
    """
    rng = np.random.RandomState(42)

    K.clear_session()
    model = UpsamplingCARE(config,basedir=None)
    axes = config.axes

    def _predict(imdims,axes):
        img = rng.uniform(size=imdims)
        if config.probabilistic:
            prob = model.predict_probabilistic(img, axes, factor, None, None)
            mean, scale = prob.mean(), prob.scale()
            assert mean.shape == scale.shape
        else:
            mean = model.predict(img, axes, factor, None, None)
        z_index = axes_dict(axes)['Z']
        assert imdims[z_index]*factor == mean.shape[z_index]

    # Round each random size down to a multiple of 2**(unet_n_depth+1).
    div_n = 2**(config.unet_n_depth+1)
    imdims = [(d//div_n)*div_n for d in list(rng.randint(20,40,size=config.n_dim))]

    if config.n_channel_in == 1:
        _predict(imdims,axes=axes.replace('C',''))

    channel = rng.randint(0,config.n_dim)
    imdims.insert(channel,config.n_channel_in)
    spatial_axes = axes.replace('C','')
    _axes = spatial_axes[:channel]+'C'+spatial_axes[channel:]
    _predict(imdims,axes=_axes)
|
|
346
|
+
|
|
347
|
+
|
|
348
|
+
|
|
349
|
+
@pytest.mark.parametrize('config', [
    candidate for candidate in config_generator(
        axes                 = ['YX'],
        n_channel_in         = [1,2],
        n_channel_out        = [1,2],
        probabilistic        = [False,True],
        unet_n_depth         = [1],
        unet_kern_size       = [3],
        unet_n_first         = [4],
        unet_last_activation = ['linear'],
    )
    if candidate.is_valid()
])
@pytest.mark.parametrize('factor', (2.5,3))
def test_model_isotropic_predict(tmpdir,config,factor):
    """Predict with ``IsotropicCARE`` on random input and check the Z extent.

    The model is configured for 'YX' axes; the images passed to it carry an
    additional 'Z' axis appended to the configured axes. Prediction is run
    without a channel axis (only when the model expects a single input channel)
    and then with a channel axis inserted at a random position. The 'Z' extent
    of the result must equal the input 'Z' extent multiplied by ``factor``.

    ``tmpdir`` is requested from pytest but not used in the body.
    """
    random_state = np.random.RandomState(42)

    K.clear_session()
    care = IsotropicCARE(config, basedir=None)
    volume_axes = config.axes + 'Z'

    def _check_prediction(shape, axes):
        # Draw a random image of the requested shape and run the prediction.
        image = random_state.uniform(size=shape)
        if not config.probabilistic:
            restored = care.predict(image, axes, factor, None, None)
        else:
            distribution = care.predict_probabilistic(image, axes, factor, None, None)
            restored, spread = distribution.mean(), distribution.scale()
            assert restored.shape == spread.shape
        z_index = axes_dict(axes)['Z']
        assert shape[z_index]*factor == restored.shape[z_index]

    # One extent more than config.n_dim because of the appended 'Z' axis;
    # each extent is rounded down to a multiple of 2**(unet_n_depth+1).
    raw_extents = list(random_state.randint(20, 40, size=config.n_dim+1))
    multiple = 2**(config.unet_n_depth+1)
    shape = [(extent//multiple)*multiple for extent in raw_extents]

    axes_without_channel = volume_axes.replace('C', '')

    if config.n_channel_in == 1:
        _check_prediction(shape, axes=axes_without_channel)

    # Insert a channel axis at a random position and predict again.
    channel_pos = random_state.randint(0, config.n_dim+1)
    shape.insert(channel_pos, config.n_channel_in)
    axes_with_channel = (
        axes_without_channel[:channel_pos] + 'C' + axes_without_channel[channel_pos:]
    )
    _check_prediction(shape, axes=axes_with_channel)
|
|
393
|
+
|
|
394
|
+
|
|
395
|
+
|
|
396
|
+
@pytest.mark.parametrize('config', [
    candidate for candidate in config_generator(
        ProjectionConfig,
        axes                 = ['ZYX'],
        n_channel_in         = [1,2],
        n_channel_out        = [1,2],
        probabilistic        = [False,True],
        unet_n_depth         = [1],
        unet_kern_size       = [3],
        unet_n_first         = [4],
        unet_last_activation = ['linear'],
        proj_n_depth         = [2,4],
    )
    if candidate.is_valid()
])
def test_model_projection_predict(tmpdir,config):
    """Predict with ``ProjectionCARE`` and check output shape and tiling.

    For every valid configuration the model is applied to random images,
    without a channel axis (only when the model expects a single input channel)
    and with a channel axis inserted at a random position. The expected output
    shape is the input shape with the channel extent replaced by (or, if there
    is no channel axis and more than one output channel, extended with)
    ``config.n_channel_out`` and with the projection axis removed.

    For non-probabilistic models the test additionally checks that tiling along
    'X' and 'Y' changes the result by less than 1e-3 and that tiling along the
    projection axis raises ``ValueError``.

    ``tmpdir`` is requested from pytest but not used in the body.
    """
    random_state = np.random.RandomState(42)

    K.clear_session()
    care = ProjectionCARE(config, basedir=None)
    model_axes = config.axes
    proj_axis = care.proj_params.axis

    def _check_prediction(shape, axes):
        image = random_state.uniform(size=shape)
        tiles = [1]*len(axes)
        index = axes_dict(axes)

        # NOTE(review): nesting reconstructed from a listing that lost its
        # indentation; the tiling checks are assumed to belong to the
        # non-probabilistic branch -- confirm against the upstream file.
        if not config.probabilistic:
            restored = care.predict(image, axes, None, None)

            tiles[index['X']] = 3
            tiles[index['Y']] = 2
            restored_tiled = care.predict(image, axes, None, None, n_tiles=tiles)
            worst_deviation = np.max(np.abs(restored-restored_tiled))
            assert worst_deviation < 1e-3

            # Tiling along the projection axis is expected to be rejected.
            with pytest.raises(ValueError):
                tiles[index[proj_axis]] = 2
                care.predict(image, axes, None, None, n_tiles=tiles)
        else:
            distribution = care.predict_probabilistic(image, axes, None, None)
            restored, spread = distribution.mean(), distribution.scale()
            assert restored.shape == spread.shape

        expected_shape = list(shape)
        if 'C' in axes:
            expected_shape[index['C']] = config.n_channel_out
        elif config.n_channel_out > 1:
            expected_shape.append(config.n_channel_out)

        # The projection axis does not appear in the output.
        expected_shape.pop(index[proj_axis])
        assert tuple(expected_shape) == restored.shape

    # Random extents, each rounded down to a multiple of the per-axis divisor
    # reported by the model.
    raw_extents = list(random_state.randint(30, 50, size=config.n_dim))
    divisors = care._axes_div_by(model_axes)
    shape = [
        (extent//divisor)*divisor
        for extent, divisor in zip(raw_extents, divisors)
    ]

    axes_without_channel = model_axes.replace('C', '')

    if config.n_channel_in == 1:
        _check_prediction(shape, axes=axes_without_channel)

    # Insert a channel axis at a random position and predict again.
    channel_pos = random_state.randint(0, config.n_dim)
    shape.insert(channel_pos, config.n_channel_in)
    axes_with_channel = (
        axes_without_channel[:channel_pos] + 'C' + axes_without_channel[channel_pos:]
    )
    _check_prediction(shape, axes=axes_with_channel)
|