senoquant 1.0.0b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148)
  1. senoquant/__init__.py +6 -0
  2. senoquant/_reader.py +7 -0
  3. senoquant/_widget.py +33 -0
  4. senoquant/napari.yaml +83 -0
  5. senoquant/reader/__init__.py +5 -0
  6. senoquant/reader/core.py +369 -0
  7. senoquant/tabs/__init__.py +15 -0
  8. senoquant/tabs/batch/__init__.py +10 -0
  9. senoquant/tabs/batch/backend.py +641 -0
  10. senoquant/tabs/batch/config.py +270 -0
  11. senoquant/tabs/batch/frontend.py +1283 -0
  12. senoquant/tabs/batch/io.py +326 -0
  13. senoquant/tabs/batch/layers.py +86 -0
  14. senoquant/tabs/quantification/__init__.py +1 -0
  15. senoquant/tabs/quantification/backend.py +228 -0
  16. senoquant/tabs/quantification/features/__init__.py +80 -0
  17. senoquant/tabs/quantification/features/base.py +142 -0
  18. senoquant/tabs/quantification/features/marker/__init__.py +5 -0
  19. senoquant/tabs/quantification/features/marker/config.py +69 -0
  20. senoquant/tabs/quantification/features/marker/dialog.py +437 -0
  21. senoquant/tabs/quantification/features/marker/export.py +879 -0
  22. senoquant/tabs/quantification/features/marker/feature.py +119 -0
  23. senoquant/tabs/quantification/features/marker/morphology.py +285 -0
  24. senoquant/tabs/quantification/features/marker/rows.py +654 -0
  25. senoquant/tabs/quantification/features/marker/thresholding.py +46 -0
  26. senoquant/tabs/quantification/features/roi.py +346 -0
  27. senoquant/tabs/quantification/features/spots/__init__.py +5 -0
  28. senoquant/tabs/quantification/features/spots/config.py +62 -0
  29. senoquant/tabs/quantification/features/spots/dialog.py +477 -0
  30. senoquant/tabs/quantification/features/spots/export.py +1292 -0
  31. senoquant/tabs/quantification/features/spots/feature.py +112 -0
  32. senoquant/tabs/quantification/features/spots/morphology.py +279 -0
  33. senoquant/tabs/quantification/features/spots/rows.py +241 -0
  34. senoquant/tabs/quantification/frontend.py +815 -0
  35. senoquant/tabs/segmentation/__init__.py +1 -0
  36. senoquant/tabs/segmentation/backend.py +131 -0
  37. senoquant/tabs/segmentation/frontend.py +1009 -0
  38. senoquant/tabs/segmentation/models/__init__.py +5 -0
  39. senoquant/tabs/segmentation/models/base.py +146 -0
  40. senoquant/tabs/segmentation/models/cpsam/details.json +65 -0
  41. senoquant/tabs/segmentation/models/cpsam/model.py +150 -0
  42. senoquant/tabs/segmentation/models/default_2d/details.json +69 -0
  43. senoquant/tabs/segmentation/models/default_2d/model.py +664 -0
  44. senoquant/tabs/segmentation/models/default_3d/details.json +69 -0
  45. senoquant/tabs/segmentation/models/default_3d/model.py +682 -0
  46. senoquant/tabs/segmentation/models/hf.py +71 -0
  47. senoquant/tabs/segmentation/models/nuclear_dilation/__init__.py +1 -0
  48. senoquant/tabs/segmentation/models/nuclear_dilation/details.json +26 -0
  49. senoquant/tabs/segmentation/models/nuclear_dilation/model.py +96 -0
  50. senoquant/tabs/segmentation/models/perinuclear_rings/__init__.py +1 -0
  51. senoquant/tabs/segmentation/models/perinuclear_rings/details.json +34 -0
  52. senoquant/tabs/segmentation/models/perinuclear_rings/model.py +132 -0
  53. senoquant/tabs/segmentation/stardist_onnx_utils/__init__.py +2 -0
  54. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/__init__.py +3 -0
  55. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/__init__.py +6 -0
  56. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/generate.py +470 -0
  57. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/prepare.py +273 -0
  58. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/rawdata.py +112 -0
  59. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/transform.py +384 -0
  60. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/__init__.py +0 -0
  61. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/blocks.py +184 -0
  62. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/losses.py +79 -0
  63. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/nets.py +165 -0
  64. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/predict.py +467 -0
  65. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/probability.py +67 -0
  66. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/train.py +148 -0
  67. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/io/__init__.py +163 -0
  68. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/__init__.py +52 -0
  69. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/base_model.py +329 -0
  70. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_isotropic.py +160 -0
  71. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_projection.py +178 -0
  72. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_standard.py +446 -0
  73. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_upsampling.py +54 -0
  74. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/config.py +254 -0
  75. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/pretrained.py +119 -0
  76. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/scripts/__init__.py +0 -0
  77. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/scripts/care_predict.py +180 -0
  78. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/__init__.py +5 -0
  79. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/plot_utils.py +159 -0
  80. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/six.py +18 -0
  81. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/tf.py +644 -0
  82. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/utils.py +272 -0
  83. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/version.py +1 -0
  84. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/docs/source/conf.py +368 -0
  85. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/setup.py +68 -0
  86. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_datagen.py +169 -0
  87. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_models.py +462 -0
  88. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_utils.py +166 -0
  89. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tools/create_zip_contents.py +34 -0
  90. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/__init__.py +30 -0
  91. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/big.py +624 -0
  92. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/bioimageio_utils.py +494 -0
  93. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/data/__init__.py +39 -0
  94. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/__init__.py +10 -0
  95. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/geom2d.py +215 -0
  96. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/geom3d.py +349 -0
  97. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/matching.py +483 -0
  98. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/__init__.py +28 -0
  99. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/base.py +1217 -0
  100. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/model2d.py +594 -0
  101. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/model3d.py +696 -0
  102. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/nms.py +384 -0
  103. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/__init__.py +2 -0
  104. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/plot.py +74 -0
  105. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/render.py +298 -0
  106. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/rays3d.py +373 -0
  107. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/sample_patches.py +65 -0
  108. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/__init__.py +0 -0
  109. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/predict2d.py +90 -0
  110. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/predict3d.py +93 -0
  111. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/utils.py +408 -0
  112. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/version.py +1 -0
  113. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/__init__.py +45 -0
  114. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/__init__.py +17 -0
  115. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/cli.py +55 -0
  116. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/core.py +285 -0
  117. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/__init__.py +15 -0
  118. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/cli.py +36 -0
  119. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/divisibility.py +193 -0
  120. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/probe.py +100 -0
  121. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/receptive_field.py +182 -0
  122. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/rf_cli.py +48 -0
  123. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/valid_sizes.py +278 -0
  124. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/post/__init__.py +8 -0
  125. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/post/core.py +157 -0
  126. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/pre/__init__.py +17 -0
  127. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/pre/core.py +226 -0
  128. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/predict/__init__.py +5 -0
  129. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/predict/core.py +401 -0
  130. senoquant/tabs/settings/__init__.py +1 -0
  131. senoquant/tabs/settings/backend.py +29 -0
  132. senoquant/tabs/settings/frontend.py +19 -0
  133. senoquant/tabs/spots/__init__.py +1 -0
  134. senoquant/tabs/spots/backend.py +139 -0
  135. senoquant/tabs/spots/frontend.py +800 -0
  136. senoquant/tabs/spots/models/__init__.py +5 -0
  137. senoquant/tabs/spots/models/base.py +94 -0
  138. senoquant/tabs/spots/models/rmp/details.json +61 -0
  139. senoquant/tabs/spots/models/rmp/model.py +499 -0
  140. senoquant/tabs/spots/models/udwt/details.json +103 -0
  141. senoquant/tabs/spots/models/udwt/model.py +482 -0
  142. senoquant/utils.py +25 -0
  143. senoquant-1.0.0b1.dist-info/METADATA +193 -0
  144. senoquant-1.0.0b1.dist-info/RECORD +148 -0
  145. senoquant-1.0.0b1.dist-info/WHEEL +5 -0
  146. senoquant-1.0.0b1.dist-info/entry_points.txt +2 -0
  147. senoquant-1.0.0b1.dist-info/licenses/LICENSE +28 -0
  148. senoquant-1.0.0b1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,384 @@
1
+ # -*- coding: utf-8 -*-
2
+ from __future__ import print_function, unicode_literals, absolute_import, division
3
+ from six.moves import range, zip, map, reduce, filter
4
+ from six import string_types
5
+
6
+ import numpy as np
7
+ from collections import namedtuple
8
+ import sys, os, warnings
9
+
10
+ from ..utils import _raise, consume, axes_check_and_normalize, axes_dict, move_image_axes
11
+
12
+
13
+
14
class Transform(namedtuple('Transform', ('name', 'generator', 'size'))):
    """Named triple ``(name, generator, size)`` describing one data transformation.

    Extension of :func:`collections.namedtuple` with three fields: `name`,
    `generator`, and `size`.

    Parameters
    ----------
    name : str
        Human-readable name of the applied transformation.
    generator : function
        Function that takes a generator as input and itself returns a generator;
        input and returned generator have the same structure as that of
        :class:`RawData`. The returned generator is meant to augment the images
        provided by the input generator through additional transformations.
        It is important that the returned generator also includes every input
        tuple unchanged.
    size : int
        Number of transformations applied to every image (obtained from the
        input generator).
    """

    @staticmethod
    def identity():
        """Build the no-op transformation.

        Returns
        -------
        Transform
            Transform named ``'Identity'`` (``size == 1``) whose generator
            re-yields every input item as-is.
        """
        def _passthrough(inputs):
            for record in inputs:
                yield record

        return Transform('Identity', _passthrough, 1)
43
+
44
+ # def flip(axis):
45
+ # """TODO"""
46
+ # def _gen(inputs):
47
+ # for x,y,m_in in inputs:
48
+ # axis < x.ndim or _raise(ValueError())
49
+ # yield x, y, m_in
50
+ # yield np.flip(x,axis), np.flip(y,axis), None if m_in is None else np.flip(m_in,axis)
51
+ # return Transform('Flip (axis=%d)'%axis, _gen, 2)
52
+
53
+
54
+
55
def anisotropic_distortions(
        subsample,
        psf,
        psf_axes       = None,
        poisson_noise  = False,
        gauss_sigma    = 0,
        subsample_axis = 'X',
        yield_target   = 'source',
        crop_threshold = 0.2,
    ):
    """Simulate anisotropic distortions.

    Modify the first image (obtained from input generator) along one axis to mimic the
    distortions that typically occur due to low resolution along the Z axis.
    Note that the modified image is finally upscaled to obtain the same resolution
    as the unmodified input image and is yielded as the 'source' image (see :class:`RawData`).
    The mask from the input generator is simply passed through.

    The following operations are applied to the image (in order):

    1. Convolution with PSF
    2. Poisson noise
    3. Gaussian noise
    4. Subsampling along ``subsample_axis``
    5. Upsampling along ``subsample_axis`` (to former size).


    Parameters
    ----------
    subsample : float
        Subsampling factor to mimic distortions along Z.
    psf : :class:`numpy.ndarray` or None
        Point spread function (PSF) that is supposed to mimic blurring
        of the microscope due to reduced axial resolution. Set to ``None`` to disable.
        The array passed in is never modified; a normalized copy is used.
    psf_axes : str or None
        Axes of the PSF. If ``None``, psf axes are assumed to be the same as of the image
        that it is applied to.
    poisson_noise : bool
        Flag to indicate whether Poisson noise should be applied to the image.
    gauss_sigma : float
        Standard deviation of white Gaussian noise to be added to the image.
    subsample_axis : str
        Subsampling image axis (default X).
    yield_target : str
        Which image from the input generator should be yielded by the generator ('source' or 'target').
        If 'source', the unmodified input/source image (from which the distorted image is computed)
        is yielded as the target image. If 'target', the target image from the input generator is simply
        passed through.
    crop_threshold : float
        The subsample factor must evenly divide the image size along the subsampling axis to prevent
        potential image misalignment. If this is not the case the subsample factors are
        modified and the raw image may be cropped along the subsampling axis
        up to a fraction indicated by `crop_threshold`.

    Returns
    -------
    Transform
        Returns a :class:`Transform` object intended to be used with :func:`create_patches`.

    Raises
    ------
    ValueError
        Various reasons (invalid arguments here; incompatible images/PSF, negative PSF
        values, or a missing target image lazily from the generator).

    """
    zoom_order = 1

    (np.isscalar(subsample) and subsample >= 1) or _raise(ValueError('subsample must be >= 1'))
    _subsample = subsample

    subsample_axis = axes_check_and_normalize(subsample_axis)
    len(subsample_axis)==1 or _raise(ValueError())

    psf is None or isinstance(psf,np.ndarray) or _raise(ValueError())
    if psf_axes is not None:
        psf_axes = axes_check_and_normalize(psf_axes)

    0 < crop_threshold < 1 or _raise(ValueError())

    yield_target in ('source','target') or _raise(ValueError())

    if psf is None and yield_target == 'source':
        warnings.warn(
            "It is strongly recommended to use an appropriate PSF to "
            "mimic the optical effects of the microscope. "
            "We found that training with synthesized anisotropic images "
            "that were created without a PSF "
            "can sometimes lead to unwanted artifacts in the reconstructed images."
        )


    def _make_normalize_data(axes_in):
        """Return a function that moves ``subsample_axis`` to the front of an image
        (or back to ``axes_in`` when called with ``undo=True``)."""
        axes_in = axes_check_and_normalize(axes_in)
        axes_out = subsample_axis
        # append every remaining axis of axes_in (keeping its order) after subsample_axis
        axes_out += ''.join(a for a in axes_in if a not in axes_out)

        def _normalize_data(data,undo=False):
            if undo:
                return move_image_axes(data, axes_out, axes_in)
            else:
                return move_image_axes(data, axes_in, axes_out)
        return _normalize_data


    def _scale_down_up(data,subsample):
        """Zoom axis 0 down by `subsample` (nearest neighbor), then back up (order `zoom_order`)."""
        from scipy.ndimage import zoom
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", UserWarning)
            factor = np.ones(data.ndim)
            factor[0] = subsample
            return zoom(zoom(data, 1/factor, order=0),
                        factor, order=zoom_order)


    def _adjust_subsample(d,s,c):
        """Adjust subsample factor `s` for an axis of length `d` (tolerated crop loss fraction `c`).

        Returns the (possibly rounded) subsample factor and the axis length the image
        must be cropped to so that it is an integer multiple of that factor.
        """
        from fractions import Fraction

        def crop_size(n_digits,frac):
            # `s` is read from the enclosing scope (the float value assigned below)
            _s = round(s,n_digits)
            _div = frac.denominator
            s_multiple_max = np.floor(d/_s)
            s_multiple = (s_multiple_max//_div)*_div
            size = s_multiple * _s
            assert np.allclose(size,round(size))
            return size

        def decimals(v,n_digits=None):
            # (digits after the decimal point as int, number of those digits)
            if n_digits is not None:
                v = round(v,n_digits)
            text = str(v)
            assert '.' in text
            digits = text[1+text.find('.'):]
            return int(digits), len(digits)

        s = float(s)
        dec, n_digits = decimals(s)
        frac = Fraction(dec,10**n_digits)
        # a multiple of s that is also an integer number must be
        # divisible by the denominator of the fraction that represents the decimal points

        # round off decimals points if needed
        while n_digits > 0 and (d-crop_size(n_digits,frac))/d > c:
            n_digits -= 1
            frac = Fraction(decimals(s,n_digits)[0], 10**n_digits)

        size = crop_size(n_digits,frac)
        if size == 0 or (d-size)/d > c:
            raise ValueError("subsample factor %g too large (crop_threshold=%g)" % (s,c))

        return round(s,n_digits), int(round(size))


    def _make_divisible_by_subsample(x,size):
        """Center-crop axis 0 of `x` to length `size`."""
        def _split_slice(v):
            return slice(None) if v==0 else slice(v//2,-(v-v//2))
        slices = [slice(None) for _ in x.shape]
        slices[0] = _split_slice(x.shape[0]-size)
        return x[tuple(slices)]


    def _generator(inputs):
        for img,y,axes,mask in inputs:

            if yield_target == 'source':
                y is None or np.allclose(img,y) or warnings.warn("ignoring 'target' image from input generator")
                target = img
            else:
                target = y
                # fail with a clear ValueError instead of an AttributeError on `target.shape`
                target is not None or _raise(ValueError("yield_target='target' requires a target image from the input generator"))

            img.shape == target.shape or _raise(ValueError())

            axes = axes_check_and_normalize(axes)
            _normalize_data = _make_normalize_data(axes)

            x = img.astype(np.float32, copy=False)

            if psf is not None:
                from scipy.signal import fftconvolve
                # astype(copy=False) returns `psf` itself if it already is float32, so the
                # normalization must be out-of-place to never modify the caller's array.
                _psf = psf.astype(np.float32,copy=False)
                np.min(_psf) >= 0 or _raise(ValueError('psf has negative values.'))
                _psf = _psf / np.sum(_psf)
                if psf_axes is not None:
                    _psf = move_image_axes(_psf, psf_axes, axes, True)
                x.ndim == _psf.ndim or _raise(ValueError('image and psf must have the same number of dimensions.'))

                if 'C' in axes:
                    ch = axes_dict(axes)['C']
                    n_channels = x.shape[ch]
                    # convolve with psf separately for every channel
                    if _psf.shape[ch] == 1 and n_channels > 1:
                        warnings.warn('applying same psf to every channel of the image.')
                    if _psf.shape[ch] in (1,n_channels):
                        x = np.stack([
                            fftconvolve(
                                np.take(x,   i,axis=ch),
                                np.take(_psf,i,axis=ch,mode='clip'),
                                mode='same'
                            )
                            for i in range(n_channels)
                        ],axis=ch)
                    else:
                        raise ValueError('number of psf channels (%d) incompatible with number of image channels (%d).' % (_psf.shape[ch],n_channels))
                else:
                    x = fftconvolve(x, _psf, mode='same')

            if bool(poisson_noise):
                x = np.random.poisson(np.maximum(0,x).astype(int)).astype(np.float32)

            if gauss_sigma > 0:
                noise = np.random.normal(0,gauss_sigma,size=x.shape).astype(np.float32)
                x = np.maximum(0,x+noise)

            if _subsample != 1:
                # move subsample_axis to the front, crop to a multiple of the factor, down/up-sample
                target = _normalize_data(target)
                x      = _normalize_data(x)

                subsample, subsample_size = _adjust_subsample(x.shape[0],_subsample,crop_threshold)
                if _subsample != subsample:
                    warnings.warn('changing subsample from %s to %s' % (str(_subsample),str(subsample)))

                target = _make_divisible_by_subsample(target,subsample_size)
                x      = _make_divisible_by_subsample(x,     subsample_size)
                x      = _scale_down_up(x,subsample)

                assert x.shape == target.shape, (x.shape, target.shape)

                target = _normalize_data(target,undo=True)
                x      = _normalize_data(x,     undo=True)

            yield x, target, axes, mask


    return Transform('Anisotropic distortion (along %s axis)' % subsample_axis, _generator, 1)
299
+
300
+
301
+
302
def permute_axes(axes):
    """Transformation to permute images axes.

    Parameters
    ----------
    axes : str
        Target axes, to which the input images will be permuted.

    Returns
    -------
    Transform
        Returns a :class:`Transform` object whose `generator` will
        perform the axes permutation of `x`, `y`, and `mask`.
        Singleton axes are added/removed as needed (``adjust_singletons=True``)
        consistently for all three images.

    """
    axes = axes_check_and_normalize(axes)
    def _generator(inputs):
        for x, y, axes_in, mask in inputs:
            axes_in = axes_check_and_normalize(axes_in)
            if axes_in != axes:
                x = move_image_axes(x, axes_in, axes, True)
                y = move_image_axes(y, axes_in, axes, True)
                if mask is not None:
                    # treat the mask exactly like x and y; without adjust_singletons the
                    # call raises whenever a singleton axis has to be added or removed
                    mask = move_image_axes(mask, axes_in, axes, True)
            yield x, y, axes, mask

    return Transform('Permute axes to %s' % axes, _generator, 1)
330
+
331
+
332
+
333
def crop_images(slices):
    """Transformation that crops the source, target and (optional) mask images.

    The same slices are applied to all images, so they must be compatible with
    the image size.

    Parameters
    ----------
    slices : list or tuple of slice
        One slice per image dimension.

    Returns
    -------
    Transform
        Returns a :class:`Transform` object whose `generator` will
        perform image cropping of `x`, `y`, and `mask`. The generator raises
        ``ValueError`` if the number of axes of an input differs from the
        number of slices.

    """
    region = tuple(slices)

    def _crop_optional(img):
        # the mask may be absent
        if img is None:
            return None
        return img[region]

    def _generator(inputs):
        for x, y, axes, mask in inputs:
            axes = axes_check_and_normalize(axes)
            if len(axes) != len(region):
                _raise(ValueError())
            yield x[region], y[region], axes, _crop_optional(mask)

    return Transform('Crop images (%s)' % str(region), _generator, 1)
358
+
359
+
360
+
361
def broadcast_target(target_axes=None):
    """Transformation to broadcast the target image to the shape of the source image.

    Parameters
    ----------
    target_axes : str
        Axes of the target image before broadcasting.
        If `None`, assumed to be the same as for the source image.

    Returns
    -------
    Transform
        Returns a :class:`Transform` object whose `generator` will
        perform broadcasting of `y` to match the shape of `x`.

    """
    def _align_target(y, axes_x):
        # reorder (and add/remove singleton axes of) y to match the source axes
        if target_axes is None:
            return y
        axes_y = axes_check_and_normalize(target_axes, length=y.ndim)
        return move_image_axes(y, axes_y, axes_x, True)

    def _generator(inputs):
        for x, y, axes_x, mask in inputs:
            aligned = _align_target(y, axes_x)
            yield x, np.broadcast_to(aligned, x.shape), axes_x, mask

    return Transform('Broadcast target image to the shape of source', _generator, 1)
@@ -0,0 +1,184 @@
1
+ from __future__ import print_function, unicode_literals, absolute_import, division
2
+ from six.moves import range, zip, map, reduce, filter
3
+
4
+ from ..utils import _raise, backend_channels_last
5
+
6
+ from ..utils.tf import keras_import, BACKEND as K
7
+ Conv2D, MaxPooling2D, UpSampling2D, Conv3D, MaxPooling3D, UpSampling3D, Cropping2D, Cropping3D, Concatenate, Add, Dropout, Activation, BatchNormalization = \
8
+ keras_import('layers', 'Conv2D', 'MaxPooling2D', 'UpSampling2D', 'Conv3D', 'MaxPooling3D', 'UpSampling3D', 'Cropping2D', 'Cropping3D', 'Concatenate', 'Add', 'Dropout', 'Activation', 'BatchNormalization')
9
+
10
+
11
+
12
def conv_block2(n_filter, n1, n2,
                activation="relu",
                border_mode="same",
                dropout=0.0,
                batch_norm=False,
                init="glorot_uniform",
                **kwargs):
    """Return a function that applies a 2D convolution block to a Keras tensor.

    Layers: ``Conv2D`` -> (``BatchNormalization`` -> ``Activation`` if `batch_norm`,
    otherwise the activation is fused into the convolution) -> ``Dropout`` (if `dropout` > 0).

    Parameters
    ----------
    n_filter : int
        Number of convolution filters.
    n1, n2 : int
        Kernel size.
    activation : str
        Activation function.
    border_mode : str
        Keras ``padding`` mode of the convolution.
    dropout : float or None
        Dropout rate; disabled if ``None`` or ``<= 0``.
    batch_norm : bool
        Whether to insert batch normalization between convolution and activation.
    init : str
        Kernel initializer of the convolution.
    **kwargs
        Passed on to ``Conv2D`` (e.g. ``name``).
    """
    # normalize over the channel axis (as resnet_block does); Keras' default
    # axis=-1 is only the channel axis for channels-last data
    channel_axis = -1 if backend_channels_last() else 1

    def _func(lay):
        if batch_norm:
            s = Conv2D(n_filter, (n1, n2), padding=border_mode, kernel_initializer=init, **kwargs)(lay)
            s = BatchNormalization(axis=channel_axis)(s)
            s = Activation(activation)(s)
        else:
            s = Conv2D(n_filter, (n1, n2), padding=border_mode, kernel_initializer=init, activation=activation, **kwargs)(lay)
        if dropout is not None and dropout > 0:
            s = Dropout(dropout)(s)
        return s

    return _func
32
+
33
+
34
+
35
def conv_block3(n_filter, n1, n2, n3,
                activation="relu",
                border_mode="same",
                dropout=0.0,
                batch_norm=False,
                init="glorot_uniform",
                **kwargs):
    """Return a function that applies a 3D convolution block to a Keras tensor.

    Layers: ``Conv3D`` -> (``BatchNormalization`` -> ``Activation`` if `batch_norm`,
    otherwise the activation is fused into the convolution) -> ``Dropout`` (if `dropout` > 0).

    Parameters
    ----------
    n_filter : int
        Number of convolution filters.
    n1, n2, n3 : int
        Kernel size.
    activation : str
        Activation function.
    border_mode : str
        Keras ``padding`` mode of the convolution.
    dropout : float or None
        Dropout rate; disabled if ``None`` or ``<= 0``.
    batch_norm : bool
        Whether to insert batch normalization between convolution and activation.
    init : str
        Kernel initializer of the convolution.
    **kwargs
        Passed on to ``Conv3D`` (e.g. ``name``).
    """
    # normalize over the channel axis (as resnet_block does); Keras' default
    # axis=-1 is only the channel axis for channels-last data
    channel_axis = -1 if backend_channels_last() else 1

    def _func(lay):
        if batch_norm:
            s = Conv3D(n_filter, (n1, n2, n3), padding=border_mode, kernel_initializer=init, **kwargs)(lay)
            s = BatchNormalization(axis=channel_axis)(s)
            s = Activation(activation)(s)
        else:
            s = Conv3D(n_filter, (n1, n2, n3), padding=border_mode, kernel_initializer=init, activation=activation, **kwargs)(lay)
        if dropout is not None and dropout > 0:
            s = Dropout(dropout)(s)
        return s

    return _func
55
+
56
+
57
+
58
def unet_block(n_depth=2, n_filter_base=16, kernel_size=(3,3), n_conv_per_depth=2,
               activation="relu",
               batch_norm=False,
               dropout=0.0,
               last_activation=None,
               pool=(2,2),
               kernel_init="glorot_uniform",
               expansion=2,
               prefix=''):
    """Return a function that applies a 2D or 3D U-Net to a Keras tensor.

    The network has `n_depth` down levels (convolutions followed by max pooling),
    a middle part, and `n_depth` up levels (upsampling, concatenation with the
    matching down-level output, convolutions). Level ``k`` uses
    ``int(n_filter_base * expansion**k)`` filters.

    Parameters
    ----------
    n_depth : int
        Number of down/up levels.
    n_filter_base : int
        Number of filters at level 0.
    kernel_size : tuple of int
        Convolution kernel size; its length (2 or 3) selects 2D or 3D layers.
    n_conv_per_depth : int
        Number of convolution blocks per level.
    activation : str
        Activation of all convolution blocks except the very last one.
    batch_norm : bool
        Use batch normalization in the convolution blocks.
    dropout : float
        Dropout rate of the convolution blocks.
    last_activation : str or None
        Activation of the very last convolution block (defaults to `activation`).
    pool : tuple of int
        Pooling/upsampling factors; must have the same length as `kernel_size`.
    kernel_init : str
        Kernel initializer of the convolutions.
    expansion : number
        Filter growth factor per level.
    prefix : str
        Prefix prepended to all explicitly named layers.

    Raises
    ------
    ValueError
        If `pool` and `kernel_size` differ in length, or the dimensionality is not 2 or 3.
    """
    if len(pool) != len(kernel_size):
        raise ValueError('kernel and pool sizes must match.')
    n_dim = len(kernel_size)
    if n_dim not in (2,3):
        raise ValueError('unet_block only 2d or 3d.')

    if n_dim == 2:
        conv_block, pooling, upsampling = conv_block2, MaxPooling2D, UpSampling2D
    else:
        conv_block, pooling, upsampling = conv_block3, MaxPooling3D, UpSampling3D

    final_activation = activation if last_activation is None else last_activation
    channel_axis = -1 if backend_channels_last() else 1

    def _conv(level, name, act=activation):
        # one convolution block with the shared settings; `name` gets the prefix
        return conv_block(int(n_filter_base * expansion ** level), *kernel_size,
                          dropout=dropout,
                          activation=act,
                          init=kernel_init,
                          batch_norm=batch_norm,
                          name=prefix + name)

    def _func(input):
        skips = []
        layer = input

        # down ...
        for depth in range(n_depth):
            for k in range(n_conv_per_depth):
                layer = _conv(depth, "down_level_%s_no_%s" % (depth, k))(layer)
            skips.append(layer)
            layer = pooling(pool, name=prefix + "max_%s" % depth)(layer)

        # middle
        for k in range(n_conv_per_depth - 1):
            layer = _conv(n_depth, "middle_%s" % k)(layer)
        layer = _conv(max(0, n_depth - 1), "middle_%s" % n_conv_per_depth)(layer)

        # ... and up with skip layers
        for depth in reversed(range(n_depth)):
            layer = Concatenate(axis=channel_axis)([upsampling(pool)(layer), skips[depth]])
            for k in range(n_conv_per_depth - 1):
                layer = _conv(depth, "up_level_%s_no_%s" % (depth, k))(layer)
            last_act = activation if depth > 0 else final_activation
            layer = _conv(max(0, depth - 1), "up_level_%s_no_%s" % (depth, n_conv_per_depth),
                          act=last_act)(layer)

        return layer

    return _func
134
+
135
+
136
+
137
def resnet_block(n_filter, kernel_size=(3,3), pool=(1,1), n_conv_per_block=2,
                 batch_norm=False, kernel_initializer='he_normal', activation='relu',
                 last_conv_bias_if_batch_norm=False):
    """Return a function that applies a 2D or 3D residual block to a Keras tensor.

    The first convolution uses ``strides=pool``; the last convolution has no
    activation before the residual addition. The input is passed through a
    ``1x1`` convolution (``strides=pool``) if it is pooled or its channel count
    differs from `n_filter`.

    Parameters
    ----------
    n_filter : int
        Number of filters of every convolution.
    kernel_size : tuple of int
        Kernel size; its length (2 or 3) selects 2D or 3D convolutions.
    pool : tuple of int
        Strides of the first convolution (and of the shortcut convolution).
    n_conv_per_block : int
        Number of convolutions in the residual branch (at least 2).
    batch_norm : bool
        Use batch normalization (convolutions of the residual branch then have no bias).
    kernel_initializer : str
        Kernel initializer of all convolutions.
    activation : str
        Activation function.
    last_conv_bias_if_batch_norm : bool
        Bias of the shortcut convolution when `batch_norm` is used.
        The default value for 'last_conv_bias_if_batch_norm' is 'False' for legacy reasons only.

    Raises
    ------
    ValueError
        If ``n_conv_per_block < 2``, `pool` and `kernel_size` differ in length,
        or the dimensionality is not 2 or 3.
    """

    n_conv_per_block >= 2 or _raise(ValueError('required: n_conv_per_block >= 2'))
    len(pool) == len(kernel_size) or _raise(ValueError('kernel and pool sizes must match.'))
    n_dim = len(kernel_size)
    n_dim in (2,3) or _raise(ValueError('resnet_block only 2d or 3d.'))

    conv_layer = Conv2D if n_dim == 2 else Conv3D
    conv_kwargs = dict (
        padding            = 'same',
        kernel_initializer = kernel_initializer,
    )
    channel_axis = -1 if backend_channels_last() else 1

    def f(inp):
        # first conv to prepare filter sizes and strides...
        x = conv_layer(n_filter, kernel_size, strides=pool, use_bias=not batch_norm, **conv_kwargs)(inp)
        if batch_norm:
            x = BatchNormalization(axis=channel_axis)(x)
        x = Activation(activation)(x)

        # middle conv
        for _ in range(n_conv_per_block-2):
            x = conv_layer(n_filter, kernel_size, use_bias=not batch_norm, **conv_kwargs)(x)
            if batch_norm:
                x = BatchNormalization(axis=channel_axis)(x)
            x = Activation(activation)(x)

        # last conv with no activation for residual addition
        x = conv_layer(n_filter, kernel_size, use_bias=not batch_norm, **conv_kwargs)(x)
        if batch_norm:
            x = BatchNormalization(axis=channel_axis)(x)

        # transform input if not compatible...
        # (compare against the channel axis, which is axis 1 for channels-first data)
        if any(p!=1 for p in pool) or n_filter != K.int_shape(inp)[channel_axis]:
            last_conv_bias = last_conv_bias_if_batch_norm if batch_norm else True
            inp = conv_layer(n_filter, (1,)*n_dim, strides=pool, use_bias=last_conv_bias, **conv_kwargs)(inp)

        x = Add()([inp, x])
        x = Activation(activation)(x)
        return x

    return f
@@ -0,0 +1,79 @@
1
+ from __future__ import print_function, unicode_literals, absolute_import, division
2
+ from six.moves import range, zip, map, reduce, filter
3
+
4
+ from ..utils import _raise, backend_channels_last
5
+
6
+ import numpy as np
7
+ from ..utils.tf import keras_import, BACKEND as K
8
+
9
+
10
+
11
def _mean_or_not(mean):
    """Return the reduction applied to per-pixel loss values.

    If `mean` is truthy, average over the last axis only (Keras' own losses also
    only average over ``axis=-1``, see
    https://github.com/keras-team/keras/blob/master/keras/losses.py);
    otherwise return the identity function.
    """
    if not mean:
        return lambda x: x
    return lambda x: K.mean(x, axis=-1)
15
+
16
+
17
def loss_laplace(mean=True):
    """Return a Laplace negative log-likelihood loss ``nll(y_true, y_pred)``.

    With ``n`` the number of channels of `y_true`, the first ``n`` channels of
    `y_pred` are used as location ``mu`` and the remaining ones as scale
    ``sigma``; the loss is ``|mu - y_true| / sigma + log(sigma) + log(2)``.
    The channel axis is chosen once, from the backend image data format at
    construction time. If `mean` is truthy, the result is averaged over the last axis.
    """
    R = _mean_or_not(mean)
    C = np.log(2.0)
    channels_last = backend_channels_last()

    def nll(y_true, y_pred):
        y_true = K.cast(y_true, K.floatx())
        if channels_last:
            n = K.shape(y_true)[-1]
            mu, sigma = y_pred[..., :n], y_pred[..., n:]
        else:
            n = K.shape(y_true)[1]
            mu, sigma = y_pred[:, :n, ...], y_pred[:, n:, ...]
        return R(K.abs((mu - y_true) / sigma) + K.log(sigma) + C)

    return nll
36
+
37
+
38
def loss_mae(mean=True):
    """Return a mean-absolute-error loss ``mae(y_true, y_pred)``.

    Only the first ``n`` channels of `y_pred` are compared, with ``n`` the number
    of channels of `y_true`; the channel axis is chosen once, from the backend
    image data format at construction time. If `mean` is truthy, the absolute
    errors are averaged over the last axis.
    """
    R = _mean_or_not(mean)
    channels_last = backend_channels_last()

    def mae(y_true, y_pred):
        y_true = K.cast(y_true, K.floatx())
        if channels_last:
            prediction = y_pred[..., :K.shape(y_true)[-1]]
        else:
            prediction = y_pred[:, :K.shape(y_true)[1], ...]
        return R(K.abs(prediction - y_true))

    return mae
52
+
53
+
54
def loss_mse(mean=True):
    """Return a mean-squared-error loss ``mse(y_true, y_pred)``.

    Only the first ``n`` channels of `y_pred` are compared, with ``n`` the number
    of channels of `y_true`; the channel axis is chosen once, from the backend
    image data format at construction time. If `mean` is truthy, the squared
    errors are averaged over the last axis.
    """
    R = _mean_or_not(mean)
    channels_last = backend_channels_last()

    def mse(y_true, y_pred):
        y_true = K.cast(y_true, K.floatx())
        if channels_last:
            prediction = y_pred[..., :K.shape(y_true)[-1]]
        else:
            prediction = y_pred[:, :K.shape(y_true)[1], ...]
        return R(K.square(prediction - y_true))

    return mse
68
+
69
+
70
def loss_thresh_weighted_decay(loss_per_pixel, thresh, w1, w2, alpha):
    """Return a loss that weights per-pixel losses depending on a target threshold.

    Pixels with ``y_true <= thresh`` are weighted by ``k1 = alpha*w1 + (1-alpha)``,
    all others by ``k2 = alpha*w2 + (1-alpha)``; the weighted values are then
    averaged over the channel axis.

    Parameters
    ----------
    loss_per_pixel : callable
        ``loss_per_pixel(y_true, y_pred)``; its result is combined element-wise
        with the threshold mask of `y_true`.
    thresh : float
        Threshold applied to `y_true`.
    w1, w2 : float
        Weights for pixels at/below and above the threshold.
    alpha : float
        Blending factor between the weights and 1.
    """
    def _loss(y_true, y_pred):
        y_true = K.cast(y_true, K.floatx())
        val = loss_per_pixel(y_true, y_pred)
        k1 = alpha * w1 + (1 - alpha)
        k2 = alpha * w2 + (1 - alpha)
        # K.switch with a non-scalar condition selects element-wise (tf.where);
        # the previous `K.tf.where` relied on a `tf` attribute that the
        # tensorflow.keras backend module does not expose.
        weighted = K.switch(K.less_equal(y_true, thresh), k1 * val, k2 * val)
        return K.mean(weighted, axis=(-1 if backend_channels_last() else 1))
    return _loss
79
+