senoquant 1.0.0b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- senoquant/__init__.py +6 -0
- senoquant/_reader.py +7 -0
- senoquant/_widget.py +33 -0
- senoquant/napari.yaml +83 -0
- senoquant/reader/__init__.py +5 -0
- senoquant/reader/core.py +369 -0
- senoquant/tabs/__init__.py +15 -0
- senoquant/tabs/batch/__init__.py +10 -0
- senoquant/tabs/batch/backend.py +641 -0
- senoquant/tabs/batch/config.py +270 -0
- senoquant/tabs/batch/frontend.py +1283 -0
- senoquant/tabs/batch/io.py +326 -0
- senoquant/tabs/batch/layers.py +86 -0
- senoquant/tabs/quantification/__init__.py +1 -0
- senoquant/tabs/quantification/backend.py +228 -0
- senoquant/tabs/quantification/features/__init__.py +80 -0
- senoquant/tabs/quantification/features/base.py +142 -0
- senoquant/tabs/quantification/features/marker/__init__.py +5 -0
- senoquant/tabs/quantification/features/marker/config.py +69 -0
- senoquant/tabs/quantification/features/marker/dialog.py +437 -0
- senoquant/tabs/quantification/features/marker/export.py +879 -0
- senoquant/tabs/quantification/features/marker/feature.py +119 -0
- senoquant/tabs/quantification/features/marker/morphology.py +285 -0
- senoquant/tabs/quantification/features/marker/rows.py +654 -0
- senoquant/tabs/quantification/features/marker/thresholding.py +46 -0
- senoquant/tabs/quantification/features/roi.py +346 -0
- senoquant/tabs/quantification/features/spots/__init__.py +5 -0
- senoquant/tabs/quantification/features/spots/config.py +62 -0
- senoquant/tabs/quantification/features/spots/dialog.py +477 -0
- senoquant/tabs/quantification/features/spots/export.py +1292 -0
- senoquant/tabs/quantification/features/spots/feature.py +112 -0
- senoquant/tabs/quantification/features/spots/morphology.py +279 -0
- senoquant/tabs/quantification/features/spots/rows.py +241 -0
- senoquant/tabs/quantification/frontend.py +815 -0
- senoquant/tabs/segmentation/__init__.py +1 -0
- senoquant/tabs/segmentation/backend.py +131 -0
- senoquant/tabs/segmentation/frontend.py +1009 -0
- senoquant/tabs/segmentation/models/__init__.py +5 -0
- senoquant/tabs/segmentation/models/base.py +146 -0
- senoquant/tabs/segmentation/models/cpsam/details.json +65 -0
- senoquant/tabs/segmentation/models/cpsam/model.py +150 -0
- senoquant/tabs/segmentation/models/default_2d/details.json +69 -0
- senoquant/tabs/segmentation/models/default_2d/model.py +664 -0
- senoquant/tabs/segmentation/models/default_3d/details.json +69 -0
- senoquant/tabs/segmentation/models/default_3d/model.py +682 -0
- senoquant/tabs/segmentation/models/hf.py +71 -0
- senoquant/tabs/segmentation/models/nuclear_dilation/__init__.py +1 -0
- senoquant/tabs/segmentation/models/nuclear_dilation/details.json +26 -0
- senoquant/tabs/segmentation/models/nuclear_dilation/model.py +96 -0
- senoquant/tabs/segmentation/models/perinuclear_rings/__init__.py +1 -0
- senoquant/tabs/segmentation/models/perinuclear_rings/details.json +34 -0
- senoquant/tabs/segmentation/models/perinuclear_rings/model.py +132 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/__init__.py +2 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/__init__.py +3 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/__init__.py +6 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/generate.py +470 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/prepare.py +273 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/rawdata.py +112 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/transform.py +384 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/__init__.py +0 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/blocks.py +184 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/losses.py +79 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/nets.py +165 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/predict.py +467 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/probability.py +67 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/train.py +148 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/io/__init__.py +163 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/__init__.py +52 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/base_model.py +329 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_isotropic.py +160 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_projection.py +178 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_standard.py +446 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_upsampling.py +54 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/config.py +254 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/pretrained.py +119 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/scripts/__init__.py +0 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/scripts/care_predict.py +180 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/__init__.py +5 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/plot_utils.py +159 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/six.py +18 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/tf.py +644 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/utils.py +272 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/version.py +1 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/docs/source/conf.py +368 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/setup.py +68 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_datagen.py +169 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_models.py +462 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_utils.py +166 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tools/create_zip_contents.py +34 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/__init__.py +30 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/big.py +624 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/bioimageio_utils.py +494 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/data/__init__.py +39 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/__init__.py +10 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/geom2d.py +215 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/geom3d.py +349 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/matching.py +483 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/__init__.py +28 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/base.py +1217 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/model2d.py +594 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/model3d.py +696 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/nms.py +384 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/__init__.py +2 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/plot.py +74 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/render.py +298 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/rays3d.py +373 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/sample_patches.py +65 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/__init__.py +0 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/predict2d.py +90 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/predict3d.py +93 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/utils.py +408 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/version.py +1 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/__init__.py +45 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/__init__.py +17 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/cli.py +55 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/core.py +285 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/__init__.py +15 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/cli.py +36 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/divisibility.py +193 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/probe.py +100 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/receptive_field.py +182 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/rf_cli.py +48 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/valid_sizes.py +278 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/post/__init__.py +8 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/post/core.py +157 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/pre/__init__.py +17 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/pre/core.py +226 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/predict/__init__.py +5 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/predict/core.py +401 -0
- senoquant/tabs/settings/__init__.py +1 -0
- senoquant/tabs/settings/backend.py +29 -0
- senoquant/tabs/settings/frontend.py +19 -0
- senoquant/tabs/spots/__init__.py +1 -0
- senoquant/tabs/spots/backend.py +139 -0
- senoquant/tabs/spots/frontend.py +800 -0
- senoquant/tabs/spots/models/__init__.py +5 -0
- senoquant/tabs/spots/models/base.py +94 -0
- senoquant/tabs/spots/models/rmp/details.json +61 -0
- senoquant/tabs/spots/models/rmp/model.py +499 -0
- senoquant/tabs/spots/models/udwt/details.json +103 -0
- senoquant/tabs/spots/models/udwt/model.py +482 -0
- senoquant/utils.py +25 -0
- senoquant-1.0.0b1.dist-info/METADATA +193 -0
- senoquant-1.0.0b1.dist-info/RECORD +148 -0
- senoquant-1.0.0b1.dist-info/WHEEL +5 -0
- senoquant-1.0.0b1.dist-info/entry_points.txt +2 -0
- senoquant-1.0.0b1.dist-info/licenses/LICENSE +28 -0
- senoquant-1.0.0b1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,160 @@
|
|
|
1
|
+
from __future__ import print_function, unicode_literals, absolute_import, division
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
from scipy.ndimage import zoom
|
|
5
|
+
|
|
6
|
+
from csbdeep.internals.probability import ProbabilisticPrediction
|
|
7
|
+
from .care_standard import CARE
|
|
8
|
+
from ..internals.predict import predict_direct
|
|
9
|
+
from ..data import PercentileNormalizer, PadAndCropResizer
|
|
10
|
+
from ..utils import _raise, axes_check_and_normalize
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class IsotropicCARE(CARE):
    """CARE network for isotropic image reconstruction.

    Extends :class:`csbdeep.models.CARE` by replacing prediction
    (:func:`predict`, :func:`predict_probabilistic`) to do isotropic reconstruction.
    """

    def predict(self, img, axes, factor, normalizer=PercentileNormalizer(), resizer=PadAndCropResizer(), batch_size=8):
        """Apply neural network to raw image for isotropic reconstruction.

        See :func:`CARE.predict` for documentation.

        Parameters
        ----------
        factor : float
            Upsampling factor for Z axis. It is important that this is chosen in correspondence
            to the subsampling factor used during training data generation.
        batch_size : int
            Number of image slices that are processed together by the neural network.
            Reduce this value if out of memory errors occur.

        """
        # Discard the scale component; only the restored mean image is returned here.
        return self._predict_mean_and_scale(img, axes, factor, normalizer, resizer, batch_size)[0]


    def predict_probabilistic(self, img, axes, factor, normalizer=PercentileNormalizer(), resizer=PadAndCropResizer(), batch_size=8):
        """Apply neural network to raw image to predict probability distribution for isotropic restored image.

        See :func:`CARE.predict_probabilistic` for documentation.

        Parameters
        ----------
        factor : float
            Upsampling factor for Z axis. It is important that this is chosen in correspondence
            to the subsampling factor used during training data generation.
        batch_size : int
            Number of image slices that are processed together by the neural network.
            Reduce this value if out of memory errors occur.

        """
        # csbdeep idiom: `expr or _raise(exc)` acts as an inline assertion.
        self.config.probabilistic or _raise(ValueError('This is not a probabilistic model.'))
        mean, scale = self._predict_mean_and_scale(img, axes, factor, normalizer, resizer, batch_size)
        return ProbabilisticPrediction(mean, scale)


    def _predict_mean_and_scale(self, img, axes, factor, normalizer, resizer, batch_size):
        """Apply neural network to raw image to restore isotropic resolution.

        See :func:`predict` for parameter explanations.

        Returns
        -------
        tuple(:class:`numpy.ndarray`, :class:`numpy.ndarray` or None)
            If model is probabilistic, returns a tuple `(mean, scale)` that defines the parameters
            of per-pixel Laplace distributions. Otherwise, returns the restored image via a tuple `(restored,None)`

        """
        normalizer, resizer = self._check_normalizer_resizer(normalizer, resizer)
        axes = axes_check_and_normalize(axes,img.ndim)
        # A Z axis is mandatory: it is the axis being upsampled to isotropy.
        'Z' in axes or _raise(ValueError())
        # Canonical working layout: channel first, then Z, then the remaining axes.
        axes_tmp = 'CZ' + axes.replace('Z','').replace('C','')
        _permute_axes = self._make_permute_axes(axes, axes_tmp)
        channel = 0

        x = _permute_axes(img)

        self.config.n_channel_in == x.shape[channel] or _raise(ValueError())
        np.isscalar(factor) and factor > 0 or _raise(ValueError())

        def scale_z(arr,factor):
            # Linear interpolation along axis 1 (Z in axes_tmp layout) only.
            return zoom(arr,(1,factor,1,1),order=1)

        # normalize
        x = normalizer.before(x,axes_tmp)

        # scale z up (second axis)
        x_scaled = scale_z(x,factor)

        # resize: make (x,y,z) image dimensions divisible by power of 2 to allow downsampling steps in unet
        x_scaled = resizer.before(x_scaled, axes_tmp, self._axes_div_by(axes_tmp))

        # move channel to the end (axes_predict semantics)
        x_scaled = np.moveaxis(x_scaled, channel, -1)
        axes_predict = 'S' + axes_tmp[2:] + 'C'
        channel = -1

        # The network is applied to two differently rotated copies of the stack,
        # and the two restorations are fused below.
        # u1: first rotation and prediction
        x_rot1 = self._rotate(x_scaled, axis=1, copy=False)
        u_rot1 = predict_direct(self.keras_model, x_rot1, axes_predict, batch_size=batch_size, verbose=0)
        # Undo the rotation (k=-1 reverses k=1).
        u1 = self._rotate(u_rot1, -1, axis=1, copy=False)

        # u2: second rotation and prediction
        x_rot2 = self._rotate(self._rotate(x_scaled, axis=2, copy=False), axis=0, copy=False)
        u_rot2 = predict_direct(self.keras_model, x_rot2, axes_predict, batch_size=batch_size, verbose=0)
        # Undo the two rotations in reverse order.
        u2 = self._rotate(self._rotate(u_rot2, -1, axis=0, copy=False), -1, axis=2, copy=False)

        # Probabilistic models emit (mean, scale) per output channel, hence 2x channels.
        n_channel_predicted = self.config.n_channel_out * (2 if self.config.probabilistic else 1)
        u_rot1.shape[channel] == n_channel_predicted or _raise(ValueError())
        u_rot2.shape[channel] == n_channel_predicted or _raise(ValueError())

        # move channel back to the front (axes_tmp semantics)
        u1 = np.moveaxis(u1, channel, 0)
        u2 = np.moveaxis(u2, channel, 0)
        channel = 0

        # resize after prediction
        u1 = resizer.after(u1, axes_tmp)
        u2 = resizer.after(u2, axes_tmp)

        # combine u1 & u2
        mean1, scale1 = self._mean_and_scale_from_prediction(u1,axis=channel)
        mean2, scale2 = self._mean_and_scale_from_prediction(u2,axis=channel)
        # avg = lambda u1,u2: (u1+u2)/2 # arithmetic mean
        # Negative values are clipped to 0 before taking the geometric mean.
        avg = lambda u1,u2: np.sqrt(np.maximum(u1,0)*np.maximum(u2,0)) # geometric mean
        mean, scale = avg(mean1,mean2), None
        if self.config.probabilistic:
            # Conservative fusion of uncertainties: keep the larger per-pixel scale.
            scale = np.maximum(scale1,scale2)

        if normalizer.do_after and self.config.n_channel_in==self.config.n_channel_out:
            mean, scale = normalizer.after(mean, scale, axes_tmp)

        # Restore the caller's original axes order (scale may be None here;
        # presumably _make_permute_axes tolerates None — TODO confirm).
        mean, scale = _permute_axes(mean,undo=True), _permute_axes(scale,undo=True)

        return mean, scale


    @staticmethod
    def _rotate(arr, k=1, axis=1, copy=True):
        """Rotate by 90 degrees around the first 2 axes."""
        if copy:
            arr = arr.copy()

        k = k % 4

        # Temporarily move `axis` to the end so the rotation operates on the
        # first two remaining axes; rolled back after rotating.
        arr = np.rollaxis(arr, axis, arr.ndim)

        if k == 0:
            res = arr
        elif k == 1:
            res = arr[::-1].swapaxes(0, 1)
        elif k == 2:
            res = arr[::-1, ::-1]
        else:
            res = arr.swapaxes(0, 1)[::-1]

        res = np.rollaxis(res, -1, axis)
        return res
|
|
160
|
+
|
|
@@ -0,0 +1,178 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
from __future__ import print_function, unicode_literals, absolute_import, division
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
from collections import namedtuple
|
|
6
|
+
|
|
7
|
+
from ..utils.tf import keras_import, BACKEND as K
|
|
8
|
+
Model = keras_import('models', 'Model')
|
|
9
|
+
Input, Conv3D, MaxPooling3D, UpSampling3D, Lambda, Multiply = keras_import('layers', 'Input', 'Conv3D', 'MaxPooling3D', 'UpSampling3D', 'Lambda', 'Multiply')
|
|
10
|
+
softmax = keras_import('activations', 'softmax')
|
|
11
|
+
|
|
12
|
+
from .care_standard import CARE
|
|
13
|
+
from .config import Config
|
|
14
|
+
from ..utils import _raise, axes_dict, axes_check_and_normalize
|
|
15
|
+
from ..internals import nets
|
|
16
|
+
from ..internals.predict import tile_overlap
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class ProjectionConfig(Config):
    """Configuration for :class:`ProjectionCARE`.

    Extends :class:`Config` with the parameters of the surface-projection
    sub-network (`proj_*` attributes); any extra keyword arguments are passed
    on to :func:`update_parameters`.
    """

    def __init__(self, axes='ZYX', n_channel_in=1, n_channel_out=1, probabilistic=False, allow_new_parameters=False, **kwargs):
        super(ProjectionConfig, self).__init__(axes, n_channel_in, n_channel_out, probabilistic)
        axis_index = axes_dict(self.axes)
        # axis that the projection network collapses (default: Z)
        self.proj_axis = kwargs.get('proj_axis', 'Z')
        self.proj_n_depth = 4
        self.proj_n_filt = 8
        self.proj_n_conv_per_depth = 1
        proj_ax = axis_index[self.proj_axis]
        # NOTE(review): both branches of the conditional yield 3, i.e. the kernel
        # is always (3, 3, 3); kept as-is to mirror upstream csbdeep behavior.
        self.proj_kern = tuple(3 if d == proj_ax else 3 for d in range(3))
        # no pooling along the projection axis, 2x pooling along the others
        self.proj_pool = tuple(1 if d == proj_ax else 2 for d in range(3))
        self.update_parameters(allow_new_parameters, **kwargs)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class ProjectionCARE(CARE):
    """CARE network for combined image restoration and projection of one dimension."""

    @property
    def proj_params(self):
        """Projection sub-network parameters as a namedtuple.

        Computed lazily from ``self.config`` (falling back to defaults for
        configs that predate :class:`ProjectionConfig`) and cached on first
        access in ``self._proj_params``.
        """
        assert self.config is not None
        try:
            return self._proj_params
        except AttributeError:
            # TODO: no need to be so cautious here, since there's now a dedicated ProjectionConfig class
            p = {}
            # vars(...).get allows configs that lack the proj_* attributes.
            p['axis'] = vars(self.config).get('proj_axis', 'Z')
            p['n_depth'] = int(vars(self.config).get('proj_n_depth', 4))
            p['n_filt'] = int(vars(self.config).get('proj_n_filt', 8))
            p['n_conv_per_depth'] = int(vars(self.config).get('proj_n_conv_per_depth', 1))
            p['axis'] = axes_check_and_normalize(p['axis'],length=1)

            ax = axes_dict(self.config.axes)
            # csbdeep idiom: `expr or _raise(exc)` acts as an inline assertion.
            len(self.config.axes) == 4 or _raise(ValueError("model must take 3D input, but axes are {self.config.axes}.".format(self=self)))
            ax[p['axis']] is not None or _raise(ValueError("projection axis {axis} not part of model axes {self.config.axes}".format(self=self,axis=p['axis'])))
            self.config.axes[-1] == 'C' or _raise(ValueError())
            (p['n_depth'] > 0 and p['n_filt'] > 0 and p['n_conv_per_depth'] > 0) or _raise(ValueError())

            # NOTE(review): the default kern generator yields 3 on both branches
            # (always (3,3,3)); kept as-is to mirror upstream csbdeep.
            p['kern'] = tuple(vars(self.config).get('proj_kern', (3 if d==ax[p['axis']] else 3 for d in range(3))))
            # default pool: no pooling along the projection axis, 2x elsewhere
            p['pool'] = tuple(vars(self.config).get('proj_pool', (1 if d==ax[p['axis']] else 2 for d in range(3))))
            3 == len(p['pool']) == len(p['kern']) or _raise(ValueError())
            all(isinstance(v,int) and v > 0 for v in p['kern']) or _raise(ValueError())
            all(isinstance(v,int) and v > 0 for v in p['pool']) or _raise(ValueError())

            self._proj_params = namedtuple('ProjectionParameters',p.keys())(*p.values())
            return self._proj_params


    def _repr_extra(self):
        # Extra line appended to the model's repr (tree-style prefix).
        return "├─ {self.proj_params}\n".format(self=self)


    def _update_and_check_config(self):
        """Write the (validated) projection parameters back onto the config."""
        assert self.config is not None
        for k,v in self.proj_params._asdict().items():
            setattr(self.config, 'proj_'+k, v)


    def _build(self):
        """Build the combined Keras model: 3D projection net chained into a 2D CARE U-Net."""
        # get parameters
        proj = self.proj_params
        proj_axis = axes_dict(self.config.axes)[proj.axis]

        # define surface projection network (3D -> 2D)
        inp = u = Input(self.config.unet_input_shape)
        def conv_layers(u):
            # n_conv_per_depth stacked 3D convolutions with ReLU
            for _ in range(proj.n_conv_per_depth):
                u = Conv3D(proj.n_filt, proj.kern, padding='same', activation='relu')(u)
            return u
        # down
        for _ in range(proj.n_depth):
            u = conv_layers(u)
            u = MaxPooling3D(proj.pool)(u)
        # middle
        u = conv_layers(u)
        # up
        for _ in range(proj.n_depth):
            u = UpSampling3D(proj.pool)(u)
            u = conv_layers(u)
        u = Conv3D(1, proj.kern, padding='same', activation='linear')(u)
        # convert learned features along Z to surface probabilities
        # (add 1 to proj_axis because of batch dimension in tensorflow)
        u = Lambda(lambda x: softmax(x, axis=1+proj_axis))(u)
        # multiply Z probabilities with Z values in input stack
        u = Multiply()([inp, u])
        # perform surface projection by summing over weighted Z values
        u = Lambda(lambda x: K.sum(x, axis=1+proj_axis))(u)
        model_projection = Model(inp, u)

        # define denoising network (2D -> 2D)
        # (remove projected axis from input_shape)
        input_shape = list(self.config.unet_input_shape)
        del input_shape[proj_axis]
        model_denoising = nets.common_unet(
            n_dim           = self.config.n_dim-1,
            n_channel_out   = self.config.n_channel_out,
            prob_out        = self.config.probabilistic,
            residual        = self.config.unet_residual,
            n_depth         = self.config.unet_n_depth,
            kern_size       = self.config.unet_kern_size,
            n_first         = self.config.unet_n_first,
            last_activation = self.config.unet_last_activation,
        )(tuple(input_shape))

        # chain models together
        return Model(inp, model_denoising(model_projection(inp)))


    def train(self, X,Y, validation_data, **kwargs):
        """Train on 3D inputs X with single-slice targets Y along the projection axis.

        Y must have extent 1 along the projection axis; that axis is squeezed
        out of Y (and of the validation targets) before delegating to
        :func:`CARE.train`.
        """
        proj_axis = self.proj_params.axis
        # +1 for the leading batch/sample dimension of the training arrays
        proj_axis = 1+axes_dict(self.config.axes)[proj_axis]
        Y.shape[proj_axis] == 1 or _raise(ValueError())
        Y = np.take(Y,0,axis=proj_axis)
        try:
            X_val, Y_val = validation_data
            # Y_val.shape[proj_axis] == 1 or _raise(ValueError())
            validation_data = X_val, np.take(Y_val,0,axis=proj_axis)
        except:
            # NOTE(review): bare except silently leaves malformed validation_data
            # unchanged (best-effort) — consider narrowing to Exception.
            pass
        return super(ProjectionCARE, self).train(X,Y, validation_data, **kwargs)


    def _axes_div_by(self, query_axes):
        """Per-axis divisibility constraints imposed by projection + U-Net pooling."""
        query_axes = axes_check_and_normalize(query_axes)
        proj = self.proj_params
        div_by = {
            a : max(a_proj_pool**proj.n_depth, 1 if a==proj.axis else 2**self.config.unet_n_depth)
            for a,a_proj_pool in zip(self.config.axes.replace('C',''),proj.pool)
        }
        # axes not constrained by either network (e.g. C) are divisible by 1
        return tuple(div_by.get(a,1) for a in query_axes)


    def _axes_tile_overlap(self, query_axes):
        """Approximate per-axis tile overlap needed for seamless tiled prediction."""
        query_axes = axes_check_and_normalize(query_axes)
        proj = self.proj_params
        unet_overlap = tile_overlap(self.config.unet_n_depth, self.config.unet_kern_size)
        overlap = {
            a : max(tile_overlap(proj.n_depth, a_proj_kern, a_proj_pool), unet_overlap) # approx
            for a,a_proj_pool,a_proj_kern in zip(self.config.axes.replace('C',''),proj.pool,proj.kern)
            if a != proj.axis
        }
        # the projected axis (and any unknown axis) needs no overlap
        return tuple(overlap.get(a,0) for a in query_axes)


    @property
    def _axes_out(self):
        # Output axes: input axes with the projected axis removed.
        return ''.join(a for a in self.config.axes if a != self.proj_params.axis)


    @property
    def _config_class(self):
        # Config class instantiated when loading/creating this model.
        return ProjectionConfig
|