senoquant 1.0.0b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148)
  1. senoquant/__init__.py +6 -0
  2. senoquant/_reader.py +7 -0
  3. senoquant/_widget.py +33 -0
  4. senoquant/napari.yaml +83 -0
  5. senoquant/reader/__init__.py +5 -0
  6. senoquant/reader/core.py +369 -0
  7. senoquant/tabs/__init__.py +15 -0
  8. senoquant/tabs/batch/__init__.py +10 -0
  9. senoquant/tabs/batch/backend.py +641 -0
  10. senoquant/tabs/batch/config.py +270 -0
  11. senoquant/tabs/batch/frontend.py +1283 -0
  12. senoquant/tabs/batch/io.py +326 -0
  13. senoquant/tabs/batch/layers.py +86 -0
  14. senoquant/tabs/quantification/__init__.py +1 -0
  15. senoquant/tabs/quantification/backend.py +228 -0
  16. senoquant/tabs/quantification/features/__init__.py +80 -0
  17. senoquant/tabs/quantification/features/base.py +142 -0
  18. senoquant/tabs/quantification/features/marker/__init__.py +5 -0
  19. senoquant/tabs/quantification/features/marker/config.py +69 -0
  20. senoquant/tabs/quantification/features/marker/dialog.py +437 -0
  21. senoquant/tabs/quantification/features/marker/export.py +879 -0
  22. senoquant/tabs/quantification/features/marker/feature.py +119 -0
  23. senoquant/tabs/quantification/features/marker/morphology.py +285 -0
  24. senoquant/tabs/quantification/features/marker/rows.py +654 -0
  25. senoquant/tabs/quantification/features/marker/thresholding.py +46 -0
  26. senoquant/tabs/quantification/features/roi.py +346 -0
  27. senoquant/tabs/quantification/features/spots/__init__.py +5 -0
  28. senoquant/tabs/quantification/features/spots/config.py +62 -0
  29. senoquant/tabs/quantification/features/spots/dialog.py +477 -0
  30. senoquant/tabs/quantification/features/spots/export.py +1292 -0
  31. senoquant/tabs/quantification/features/spots/feature.py +112 -0
  32. senoquant/tabs/quantification/features/spots/morphology.py +279 -0
  33. senoquant/tabs/quantification/features/spots/rows.py +241 -0
  34. senoquant/tabs/quantification/frontend.py +815 -0
  35. senoquant/tabs/segmentation/__init__.py +1 -0
  36. senoquant/tabs/segmentation/backend.py +131 -0
  37. senoquant/tabs/segmentation/frontend.py +1009 -0
  38. senoquant/tabs/segmentation/models/__init__.py +5 -0
  39. senoquant/tabs/segmentation/models/base.py +146 -0
  40. senoquant/tabs/segmentation/models/cpsam/details.json +65 -0
  41. senoquant/tabs/segmentation/models/cpsam/model.py +150 -0
  42. senoquant/tabs/segmentation/models/default_2d/details.json +69 -0
  43. senoquant/tabs/segmentation/models/default_2d/model.py +664 -0
  44. senoquant/tabs/segmentation/models/default_3d/details.json +69 -0
  45. senoquant/tabs/segmentation/models/default_3d/model.py +682 -0
  46. senoquant/tabs/segmentation/models/hf.py +71 -0
  47. senoquant/tabs/segmentation/models/nuclear_dilation/__init__.py +1 -0
  48. senoquant/tabs/segmentation/models/nuclear_dilation/details.json +26 -0
  49. senoquant/tabs/segmentation/models/nuclear_dilation/model.py +96 -0
  50. senoquant/tabs/segmentation/models/perinuclear_rings/__init__.py +1 -0
  51. senoquant/tabs/segmentation/models/perinuclear_rings/details.json +34 -0
  52. senoquant/tabs/segmentation/models/perinuclear_rings/model.py +132 -0
  53. senoquant/tabs/segmentation/stardist_onnx_utils/__init__.py +2 -0
  54. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/__init__.py +3 -0
  55. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/__init__.py +6 -0
  56. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/generate.py +470 -0
  57. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/prepare.py +273 -0
  58. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/rawdata.py +112 -0
  59. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/transform.py +384 -0
  60. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/__init__.py +0 -0
  61. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/blocks.py +184 -0
  62. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/losses.py +79 -0
  63. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/nets.py +165 -0
  64. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/predict.py +467 -0
  65. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/probability.py +67 -0
  66. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/train.py +148 -0
  67. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/io/__init__.py +163 -0
  68. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/__init__.py +52 -0
  69. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/base_model.py +329 -0
  70. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_isotropic.py +160 -0
  71. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_projection.py +178 -0
  72. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_standard.py +446 -0
  73. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_upsampling.py +54 -0
  74. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/config.py +254 -0
  75. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/pretrained.py +119 -0
  76. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/scripts/__init__.py +0 -0
  77. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/scripts/care_predict.py +180 -0
  78. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/__init__.py +5 -0
  79. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/plot_utils.py +159 -0
  80. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/six.py +18 -0
  81. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/tf.py +644 -0
  82. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/utils.py +272 -0
  83. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/version.py +1 -0
  84. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/docs/source/conf.py +368 -0
  85. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/setup.py +68 -0
  86. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_datagen.py +169 -0
  87. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_models.py +462 -0
  88. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_utils.py +166 -0
  89. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tools/create_zip_contents.py +34 -0
  90. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/__init__.py +30 -0
  91. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/big.py +624 -0
  92. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/bioimageio_utils.py +494 -0
  93. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/data/__init__.py +39 -0
  94. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/__init__.py +10 -0
  95. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/geom2d.py +215 -0
  96. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/geom3d.py +349 -0
  97. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/matching.py +483 -0
  98. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/__init__.py +28 -0
  99. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/base.py +1217 -0
  100. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/model2d.py +594 -0
  101. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/model3d.py +696 -0
  102. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/nms.py +384 -0
  103. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/__init__.py +2 -0
  104. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/plot.py +74 -0
  105. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/render.py +298 -0
  106. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/rays3d.py +373 -0
  107. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/sample_patches.py +65 -0
  108. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/__init__.py +0 -0
  109. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/predict2d.py +90 -0
  110. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/predict3d.py +93 -0
  111. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/utils.py +408 -0
  112. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/version.py +1 -0
  113. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/__init__.py +45 -0
  114. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/__init__.py +17 -0
  115. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/cli.py +55 -0
  116. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/core.py +285 -0
  117. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/__init__.py +15 -0
  118. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/cli.py +36 -0
  119. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/divisibility.py +193 -0
  120. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/probe.py +100 -0
  121. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/receptive_field.py +182 -0
  122. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/rf_cli.py +48 -0
  123. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/valid_sizes.py +278 -0
  124. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/post/__init__.py +8 -0
  125. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/post/core.py +157 -0
  126. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/pre/__init__.py +17 -0
  127. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/pre/core.py +226 -0
  128. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/predict/__init__.py +5 -0
  129. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/predict/core.py +401 -0
  130. senoquant/tabs/settings/__init__.py +1 -0
  131. senoquant/tabs/settings/backend.py +29 -0
  132. senoquant/tabs/settings/frontend.py +19 -0
  133. senoquant/tabs/spots/__init__.py +1 -0
  134. senoquant/tabs/spots/backend.py +139 -0
  135. senoquant/tabs/spots/frontend.py +800 -0
  136. senoquant/tabs/spots/models/__init__.py +5 -0
  137. senoquant/tabs/spots/models/base.py +94 -0
  138. senoquant/tabs/spots/models/rmp/details.json +61 -0
  139. senoquant/tabs/spots/models/rmp/model.py +499 -0
  140. senoquant/tabs/spots/models/udwt/details.json +103 -0
  141. senoquant/tabs/spots/models/udwt/model.py +482 -0
  142. senoquant/utils.py +25 -0
  143. senoquant-1.0.0b1.dist-info/METADATA +193 -0
  144. senoquant-1.0.0b1.dist-info/RECORD +148 -0
  145. senoquant-1.0.0b1.dist-info/WHEEL +5 -0
  146. senoquant-1.0.0b1.dist-info/entry_points.txt +2 -0
  147. senoquant-1.0.0b1.dist-info/licenses/LICENSE +28 -0
  148. senoquant-1.0.0b1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,160 @@
1
+ from __future__ import print_function, unicode_literals, absolute_import, division
2
+
3
+ import numpy as np
4
+ from scipy.ndimage import zoom
5
+
6
+ from csbdeep.internals.probability import ProbabilisticPrediction
7
+ from .care_standard import CARE
8
+ from ..internals.predict import predict_direct
9
+ from ..data import PercentileNormalizer, PadAndCropResizer
10
+ from ..utils import _raise, axes_check_and_normalize
11
+
12
+
13
class IsotropicCARE(CARE):
    """CARE network for isotropic image reconstruction.

    Extends :class:`csbdeep.models.CARE` by replacing prediction
    (:func:`predict`, :func:`predict_probabilistic`) to do isotropic reconstruction.
    """

    def predict(self, img, axes, factor, normalizer=PercentileNormalizer(), resizer=PadAndCropResizer(), batch_size=8):
        """Apply neural network to raw image for isotropic reconstruction.

        See :func:`CARE.predict` for documentation.

        Parameters
        ----------
        factor : float
            Upsampling factor for Z axis. It is important that this is chosen in correspondence
            to the subsampling factor used during training data generation.
        batch_size : int
            Number of image slices that are processed together by the neural network.
            Reduce this value if out of memory errors occur.

        Returns
        -------
        :class:`numpy.ndarray`
            The restored image, i.e. the first element of the tuple returned
            by :func:`_predict_mean_and_scale`.

        Raises
        ------
        ValueError
            If ``axes`` has no ``Z`` axis, the number of input channels does not
            match the model, ``factor`` is not a positive scalar, or the network
            output has an unexpected number of channels.

        Notes
        -----
        The default ``normalizer`` and ``resizer`` objects are created once,
        when the class is defined, and are shared by every call that relies on
        the defaults. NOTE(review): if these objects store per-call state,
        pass fresh instances when predicting from several threads -- confirm.
        """
        return self._predict_mean_and_scale(img, axes, factor, normalizer, resizer, batch_size)[0]

    def predict_probabilistic(self, img, axes, factor, normalizer=PercentileNormalizer(), resizer=PadAndCropResizer(), batch_size=8):
        """Apply neural network to raw image to predict probability distribution for isotropic restored image.

        See :func:`CARE.predict_probabilistic` for documentation.

        Parameters
        ----------
        factor : float
            Upsampling factor for Z axis. It is important that this is chosen in correspondence
            to the subsampling factor used during training data generation.
        batch_size : int
            Number of image slices that are processed together by the neural network.
            Reduce this value if out of memory errors occur.

        Returns
        -------
        :class:`ProbabilisticPrediction`
            Built from the per-pixel ``(mean, scale)`` pair computed by
            :func:`_predict_mean_and_scale`.

        Raises
        ------
        ValueError
            If the model is not probabilistic, or for the same input problems
            as :func:`predict`.
        """
        self.config.probabilistic or _raise(ValueError('This is not a probabilistic model.'))
        mean, scale = self._predict_mean_and_scale(img, axes, factor, normalizer, resizer, batch_size)
        return ProbabilisticPrediction(mean, scale)

    def _predict_mean_and_scale(self, img, axes, factor, normalizer, resizer, batch_size):
        """Apply neural network to raw image to restore isotropic resolution.

        See :func:`predict` for parameter explanations.

        The Z axis is first upsampled by ``factor`` (linear interpolation).
        The network is then applied to two differently rotated copies of the
        upsampled volume, and the two predictions are fused with a geometric
        mean.

        Returns
        -------
        tuple(:class:`numpy.ndarray`, :class:`numpy.ndarray` or None)
            If model is probabilistic, returns a tuple `(mean, scale)` that defines the parameters
            of per-pixel Laplace distributions. Otherwise, returns the restored image via a tuple `(restored,None)`

        Raises
        ------
        ValueError
            If ``axes`` has no ``Z`` axis, the input channel count does not match
            the model, ``factor`` is not a positive scalar, or the network output
            has an unexpected number of channels.
        """
        normalizer, resizer = self._check_normalizer_resizer(normalizer, resizer)
        axes = axes_check_and_normalize(axes, img.ndim)
        'Z' in axes or _raise(ValueError(
            "axes must include 'Z' for isotropic reconstruction, but are '{axes}'.".format(axes=axes)))
        # work internally in 'CZ...' order: channel first, Z second
        axes_tmp = 'CZ' + axes.replace('Z', '').replace('C', '')
        _permute_axes = self._make_permute_axes(axes, axes_tmp)
        channel = 0

        x = _permute_axes(img)

        self.config.n_channel_in == x.shape[channel] or _raise(ValueError(
            "model expects {expected} input channel(s), but image has {actual}.".format(
                expected=self.config.n_channel_in, actual=x.shape[channel])))
        (np.isscalar(factor) and factor > 0) or _raise(ValueError(
            "factor must be a positive scalar, but is {factor!r}.".format(factor=factor)))

        def scale_z(arr, factor):
            # linear (order=1) zoom of the Z axis only; the 4-tuple of zoom
            # factors means this expects a 4D (C, Z, Y, X) array
            return zoom(arr, (1, factor, 1, 1), order=1)

        # normalize
        x = normalizer.before(x, axes_tmp)

        # scale z up (second axis)
        x_scaled = scale_z(x, factor)

        # resize: make (x,y,z) image dimensions divisible by power of 2 to allow downsampling steps in unet
        x_scaled = resizer.before(x_scaled, axes_tmp, self._axes_div_by(axes_tmp))

        # move channel to the end (axes_predict semantics)
        x_scaled = np.moveaxis(x_scaled, channel, -1)
        axes_predict = 'S' + axes_tmp[2:] + 'C'
        channel = -1

        # u1: first rotation and prediction
        x_rot1 = self._rotate(x_scaled, axis=1, copy=False)
        u_rot1 = predict_direct(self.keras_model, x_rot1, axes_predict, batch_size=batch_size, verbose=0)
        u1 = self._rotate(u_rot1, -1, axis=1, copy=False)

        # u2: second rotation and prediction
        x_rot2 = self._rotate(self._rotate(x_scaled, axis=2, copy=False), axis=0, copy=False)
        u_rot2 = predict_direct(self.keras_model, x_rot2, axes_predict, batch_size=batch_size, verbose=0)
        u2 = self._rotate(self._rotate(u_rot2, -1, axis=0, copy=False), -1, axis=2, copy=False)

        # probabilistic models predict (mean, scale) per output channel
        n_channel_predicted = self.config.n_channel_out * (2 if self.config.probabilistic else 1)
        for u_rot in (u_rot1, u_rot2):
            u_rot.shape[channel] == n_channel_predicted or _raise(ValueError(
                "network predicted {actual} channel(s), expected {expected}.".format(
                    actual=u_rot.shape[channel], expected=n_channel_predicted)))

        # move channel back to the front (axes_tmp semantics)
        u1 = np.moveaxis(u1, channel, 0)
        u2 = np.moveaxis(u2, channel, 0)
        channel = 0

        # resize after prediction
        u1 = resizer.after(u1, axes_tmp)
        u2 = resizer.after(u2, axes_tmp)

        # combine u1 & u2
        mean1, scale1 = self._mean_and_scale_from_prediction(u1, axis=channel)
        mean2, scale2 = self._mean_and_scale_from_prediction(u2, axis=channel)
        # fuse the two means with a geometric mean (negative values clipped to 0
        # first); the arithmetic mean (mean1 + mean2) / 2 is the obvious alternative
        mean = np.sqrt(np.maximum(mean1, 0) * np.maximum(mean2, 0))
        scale = None
        if self.config.probabilistic:
            # keep the per-pixel larger of the two scales
            scale = np.maximum(scale1, scale2)

        if normalizer.do_after and self.config.n_channel_in == self.config.n_channel_out:
            mean, scale = normalizer.after(mean, scale, axes_tmp)

        mean, scale = _permute_axes(mean, undo=True), _permute_axes(scale, undo=True)

        return mean, scale

    @staticmethod
    def _rotate(arr, k=1, axis=1, copy=True):
        """Rotate ``arr`` by ``k`` quarter turns in a plane that excludes ``axis``.

        ``axis`` is temporarily moved to the end, the rotation is applied to
        axes 0 and 1 of the rearranged array, and ``axis`` is then moved back
        to its original position. ``k`` is taken modulo 4, so ``k=-1`` undoes
        ``k=1``. With ``copy=False`` the result is a view of ``arr``; with
        ``copy=True`` it is a view of a fresh copy.
        """
        if copy:
            arr = arr.copy()

        k = k % 4

        arr = np.rollaxis(arr, axis, arr.ndim)

        if k == 0:
            res = arr
        elif k == 1:
            res = arr[::-1].swapaxes(0, 1)
        elif k == 2:
            res = arr[::-1, ::-1]
        else:
            res = arr.swapaxes(0, 1)[::-1]

        res = np.rollaxis(res, -1, axis)
        return res
160
+
@@ -0,0 +1,178 @@
1
+ # -*- coding: utf-8 -*-
2
+ from __future__ import print_function, unicode_literals, absolute_import, division
3
+
4
+ import numpy as np
5
+ from collections import namedtuple
6
+
7
+ from ..utils.tf import keras_import, BACKEND as K
8
+ Model = keras_import('models', 'Model')
9
+ Input, Conv3D, MaxPooling3D, UpSampling3D, Lambda, Multiply = keras_import('layers', 'Input', 'Conv3D', 'MaxPooling3D', 'UpSampling3D', 'Lambda', 'Multiply')
10
+ softmax = keras_import('activations', 'softmax')
11
+
12
+ from .care_standard import CARE
13
+ from .config import Config
14
+ from ..utils import _raise, axes_dict, axes_check_and_normalize
15
+ from ..internals import nets
16
+ from ..internals.predict import tile_overlap
17
+
18
+
19
class ProjectionConfig(Config):
    """Configuration for :class:`ProjectionCARE`.

    Adds the ``proj_*`` parameters of the projection sub-network on top of
    the base :class:`Config` parameters:

    - ``proj_axis``: axis that is projected away (default ``'Z'``).
    - ``proj_n_depth``, ``proj_n_filt``, ``proj_n_conv_per_depth``: depth,
      filter count and convolutions per depth level.
    - ``proj_kern``: convolution kernel size per spatial axis.
    - ``proj_pool``: pooling size per spatial axis (no pooling along
      ``proj_axis``).

    Any of these can be overridden through ``**kwargs``.
    """

    def __init__(self, axes='ZYX', n_channel_in=1, n_channel_out=1, probabilistic=False, allow_new_parameters=False, **kwargs):
        super(ProjectionConfig, self).__init__(axes, n_channel_in, n_channel_out, probabilistic)
        ax = axes_dict(self.axes)
        # Normalize the projection axis (e.g. 'z' -> 'Z') before using it as a key
        # into ``ax``, matching how ProjectionCARE.proj_params reads it. It is also
        # written back into kwargs because update_parameters() below re-applies
        # every keyword argument and would otherwise restore the raw value.
        proj_axis = axes_check_and_normalize(kwargs.get('proj_axis', 'Z'), length=1)
        if 'proj_axis' in kwargs:
            kwargs['proj_axis'] = proj_axis
        self.proj_axis = proj_axis
        self.proj_n_depth = 4
        self.proj_n_filt = 8
        self.proj_n_conv_per_depth = 1
        # same kernel size (3) along all three spatial axes
        self.proj_kern = (3, 3, 3)
        # no pooling along the projection axis, 2x pooling along the others
        self.proj_pool = tuple(1 if d == ax[proj_axis] else 2 for d in range(3))
        self.update_parameters(allow_new_parameters, **kwargs)
31
+
32
+
33
+
34
class ProjectionCARE(CARE):
    """CARE network for combined image restoration and projection of one dimension.

    The Keras model chains two sub-networks (see :func:`_build`):

    1. A 3D surface-projection network. It predicts a softmax weighting along
       the projection axis and sums the weighted input stack over that axis.
    2. A 2D U-Net (:func:`nets.common_unet`) that restores the projected image.
    """

    @property
    def proj_params(self):
        """Validated projection parameters as a namedtuple (cached after first access).

        Fields are ``axis``, ``n_depth``, ``n_filt``, ``n_conv_per_depth``, ``kern``
        and ``pool``. They are read from the ``proj_*`` attributes of
        ``self.config``, with fallbacks for configs that lack them.

        Raises
        ------
        ValueError
            If the model is not 3D, the projection axis is not a model axis,
            the channel axis is not last, or a parameter is invalid.
        """
        assert self.config is not None
        try:
            return self._proj_params
        except AttributeError:
            pass

        # TODO: no need to be so cautious here, since there's now a dedicated ProjectionConfig class
        cfg = vars(self.config)
        p = {}
        p['axis'] = cfg.get('proj_axis', 'Z')
        p['n_depth'] = int(cfg.get('proj_n_depth', 4))
        p['n_filt'] = int(cfg.get('proj_n_filt', 8))
        p['n_conv_per_depth'] = int(cfg.get('proj_n_conv_per_depth', 1))
        p['axis'] = axes_check_and_normalize(p['axis'], length=1)

        ax = axes_dict(self.config.axes)
        len(self.config.axes) == 4 or _raise(ValueError("model must take 3D input, but axes are {self.config.axes}.".format(self=self)))
        ax[p['axis']] is not None or _raise(ValueError("projection axis {axis} not part of model axes {self.config.axes}".format(self=self, axis=p['axis'])))
        self.config.axes[-1] == 'C' or _raise(ValueError(
            "channel axis 'C' must be the last model axis, but axes are {axes}.".format(axes=self.config.axes)))
        (p['n_depth'] > 0 and p['n_filt'] > 0 and p['n_conv_per_depth'] > 0) or _raise(ValueError(
            ("proj_n_depth, proj_n_filt, and proj_n_conv_per_depth must all be positive, "
             "but are {n_depth}, {n_filt}, and {n_conv_per_depth}.").format(**p)))

        # default: kernel size 3 along every axis; no pooling along the projection axis
        p['kern'] = tuple(cfg.get('proj_kern', (3, 3, 3)))
        p['pool'] = tuple(cfg.get('proj_pool', (1 if d == ax[p['axis']] else 2 for d in range(3))))
        3 == len(p['pool']) == len(p['kern']) or _raise(ValueError(
            "proj_kern and proj_pool must both have length 3, but are {kern} and {pool}.".format(kern=p['kern'], pool=p['pool'])))
        all(isinstance(v, int) and v > 0 for v in p['kern']) or _raise(ValueError(
            "proj_kern must contain positive integers, but is {kern}.".format(kern=p['kern'])))
        all(isinstance(v, int) and v > 0 for v in p['pool']) or _raise(ValueError(
            "proj_pool must contain positive integers, but is {pool}.".format(pool=p['pool'])))

        self._proj_params = namedtuple('ProjectionParameters', p.keys())(*p.values())
        return self._proj_params

    def _repr_extra(self):
        """Extra line for the model's repr showing the projection parameters."""
        return "├─ {self.proj_params}\n".format(self=self)

    def _update_and_check_config(self):
        """Write the validated projection parameters back to the config as ``proj_*`` attributes."""
        assert self.config is not None
        for k, v in self.proj_params._asdict().items():
            setattr(self.config, 'proj_' + k, v)

    def _build(self):
        """Build the Keras model: 3D projection network followed by a 2D restoration U-Net."""
        # get parameters
        proj = self.proj_params
        proj_axis = axes_dict(self.config.axes)[proj.axis]

        # define surface projection network (3D -> 2D)
        inp = u = Input(self.config.unet_input_shape)

        def conv_layers(u):
            for _ in range(proj.n_conv_per_depth):
                u = Conv3D(proj.n_filt, proj.kern, padding='same', activation='relu')(u)
            return u

        # down
        for _ in range(proj.n_depth):
            u = conv_layers(u)
            u = MaxPooling3D(proj.pool)(u)
        # middle
        u = conv_layers(u)
        # up
        for _ in range(proj.n_depth):
            u = UpSampling3D(proj.pool)(u)
            u = conv_layers(u)
        u = Conv3D(1, proj.kern, padding='same', activation='linear')(u)
        # convert learned features along Z to surface probabilities
        # (add 1 to proj_axis because of batch dimension in tensorflow)
        u = Lambda(lambda x: softmax(x, axis=1+proj_axis))(u)
        # multiply Z probabilities with Z values in input stack
        u = Multiply()([inp, u])
        # perform surface projection by summing over weighted Z values
        u = Lambda(lambda x: K.sum(x, axis=1+proj_axis))(u)
        model_projection = Model(inp, u)

        # define denoising network (2D -> 2D)
        # (remove projected axis from input_shape)
        input_shape = list(self.config.unet_input_shape)
        del input_shape[proj_axis]
        model_denoising = nets.common_unet(
            n_dim           = self.config.n_dim-1,
            n_channel_out   = self.config.n_channel_out,
            prob_out        = self.config.probabilistic,
            residual        = self.config.unet_residual,
            n_depth         = self.config.unet_n_depth,
            kern_size       = self.config.unet_kern_size,
            n_first         = self.config.unet_n_first,
            last_activation = self.config.unet_last_activation,
        )(tuple(input_shape))

        # chain models together
        return Model(inp, model_denoising(model_projection(inp)))

    def train(self, X, Y, validation_data, **kwargs):
        """Train the model after dropping the singleton projection axis from the targets.

        ``Y`` (and ``Y_val`` of ``validation_data = (X_val, Y_val)``) must have
        size 1 along the projection axis, which is removed before delegating
        to :func:`CARE.train` with all remaining ``kwargs``.

        Raises
        ------
        ValueError
            If ``Y`` does not have size 1 along the projection axis.
        """
        axis_name = self.proj_params.axis
        # +1 because the first axis of X/Y is the sample axis
        proj_axis = 1 + axes_dict(self.config.axes)[axis_name]
        Y.shape[proj_axis] == 1 or _raise(ValueError(
            "Y must have size 1 along the projection axis '{axis}', but Y.shape = {shape}.".format(axis=axis_name, shape=Y.shape)))
        Y = np.take(Y, 0, axis=proj_axis)
        try:
            X_val, Y_val = validation_data
        except (TypeError, ValueError):
            # not an (X_val, Y_val) pair: pass it on unchanged so the base
            # class can validate and report it
            pass
        else:
            validation_data = X_val, np.take(Y_val, 0, axis=proj_axis)
        return super(ProjectionCARE, self).train(X, Y, validation_data, **kwargs)

    def _axes_div_by(self, query_axes):
        """Return, for each query axis, the size the input must be divisible by."""
        query_axes = axes_check_and_normalize(query_axes)
        proj = self.proj_params
        div_by = {
            a: max(a_proj_pool**proj.n_depth, 1 if a == proj.axis else 2**self.config.unet_n_depth)
            for a, a_proj_pool in zip(self.config.axes.replace('C', ''), proj.pool)
        }
        return tuple(div_by.get(a, 1) for a in query_axes)

    def _axes_tile_overlap(self, query_axes):
        """Return an approximate tile overlap per query axis (0 for the projection axis)."""
        query_axes = axes_check_and_normalize(query_axes)
        proj = self.proj_params
        unet_overlap = tile_overlap(self.config.unet_n_depth, self.config.unet_kern_size)
        overlap = {
            a: max(tile_overlap(proj.n_depth, a_proj_kern, a_proj_pool), unet_overlap)  # approx
            for a, a_proj_pool, a_proj_kern in zip(self.config.axes.replace('C', ''), proj.pool, proj.kern)
            if a != proj.axis
        }
        return tuple(overlap.get(a, 0) for a in query_axes)

    @property
    def _axes_out(self):
        """Model axes with the projection axis removed."""
        return ''.join(a for a in self.config.axes if a != self.proj_params.axis)

    @property
    def _config_class(self):
        """Configuration class used by this model."""
        return ProjectionConfig