senoquant 1.0.0b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148)
  1. senoquant/__init__.py +6 -0
  2. senoquant/_reader.py +7 -0
  3. senoquant/_widget.py +33 -0
  4. senoquant/napari.yaml +83 -0
  5. senoquant/reader/__init__.py +5 -0
  6. senoquant/reader/core.py +369 -0
  7. senoquant/tabs/__init__.py +15 -0
  8. senoquant/tabs/batch/__init__.py +10 -0
  9. senoquant/tabs/batch/backend.py +641 -0
  10. senoquant/tabs/batch/config.py +270 -0
  11. senoquant/tabs/batch/frontend.py +1283 -0
  12. senoquant/tabs/batch/io.py +326 -0
  13. senoquant/tabs/batch/layers.py +86 -0
  14. senoquant/tabs/quantification/__init__.py +1 -0
  15. senoquant/tabs/quantification/backend.py +228 -0
  16. senoquant/tabs/quantification/features/__init__.py +80 -0
  17. senoquant/tabs/quantification/features/base.py +142 -0
  18. senoquant/tabs/quantification/features/marker/__init__.py +5 -0
  19. senoquant/tabs/quantification/features/marker/config.py +69 -0
  20. senoquant/tabs/quantification/features/marker/dialog.py +437 -0
  21. senoquant/tabs/quantification/features/marker/export.py +879 -0
  22. senoquant/tabs/quantification/features/marker/feature.py +119 -0
  23. senoquant/tabs/quantification/features/marker/morphology.py +285 -0
  24. senoquant/tabs/quantification/features/marker/rows.py +654 -0
  25. senoquant/tabs/quantification/features/marker/thresholding.py +46 -0
  26. senoquant/tabs/quantification/features/roi.py +346 -0
  27. senoquant/tabs/quantification/features/spots/__init__.py +5 -0
  28. senoquant/tabs/quantification/features/spots/config.py +62 -0
  29. senoquant/tabs/quantification/features/spots/dialog.py +477 -0
  30. senoquant/tabs/quantification/features/spots/export.py +1292 -0
  31. senoquant/tabs/quantification/features/spots/feature.py +112 -0
  32. senoquant/tabs/quantification/features/spots/morphology.py +279 -0
  33. senoquant/tabs/quantification/features/spots/rows.py +241 -0
  34. senoquant/tabs/quantification/frontend.py +815 -0
  35. senoquant/tabs/segmentation/__init__.py +1 -0
  36. senoquant/tabs/segmentation/backend.py +131 -0
  37. senoquant/tabs/segmentation/frontend.py +1009 -0
  38. senoquant/tabs/segmentation/models/__init__.py +5 -0
  39. senoquant/tabs/segmentation/models/base.py +146 -0
  40. senoquant/tabs/segmentation/models/cpsam/details.json +65 -0
  41. senoquant/tabs/segmentation/models/cpsam/model.py +150 -0
  42. senoquant/tabs/segmentation/models/default_2d/details.json +69 -0
  43. senoquant/tabs/segmentation/models/default_2d/model.py +664 -0
  44. senoquant/tabs/segmentation/models/default_3d/details.json +69 -0
  45. senoquant/tabs/segmentation/models/default_3d/model.py +682 -0
  46. senoquant/tabs/segmentation/models/hf.py +71 -0
  47. senoquant/tabs/segmentation/models/nuclear_dilation/__init__.py +1 -0
  48. senoquant/tabs/segmentation/models/nuclear_dilation/details.json +26 -0
  49. senoquant/tabs/segmentation/models/nuclear_dilation/model.py +96 -0
  50. senoquant/tabs/segmentation/models/perinuclear_rings/__init__.py +1 -0
  51. senoquant/tabs/segmentation/models/perinuclear_rings/details.json +34 -0
  52. senoquant/tabs/segmentation/models/perinuclear_rings/model.py +132 -0
  53. senoquant/tabs/segmentation/stardist_onnx_utils/__init__.py +2 -0
  54. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/__init__.py +3 -0
  55. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/__init__.py +6 -0
  56. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/generate.py +470 -0
  57. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/prepare.py +273 -0
  58. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/rawdata.py +112 -0
  59. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/transform.py +384 -0
  60. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/__init__.py +0 -0
  61. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/blocks.py +184 -0
  62. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/losses.py +79 -0
  63. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/nets.py +165 -0
  64. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/predict.py +467 -0
  65. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/probability.py +67 -0
  66. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/train.py +148 -0
  67. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/io/__init__.py +163 -0
  68. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/__init__.py +52 -0
  69. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/base_model.py +329 -0
  70. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_isotropic.py +160 -0
  71. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_projection.py +178 -0
  72. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_standard.py +446 -0
  73. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_upsampling.py +54 -0
  74. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/config.py +254 -0
  75. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/pretrained.py +119 -0
  76. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/scripts/__init__.py +0 -0
  77. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/scripts/care_predict.py +180 -0
  78. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/__init__.py +5 -0
  79. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/plot_utils.py +159 -0
  80. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/six.py +18 -0
  81. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/tf.py +644 -0
  82. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/utils.py +272 -0
  83. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/version.py +1 -0
  84. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/docs/source/conf.py +368 -0
  85. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/setup.py +68 -0
  86. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_datagen.py +169 -0
  87. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_models.py +462 -0
  88. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_utils.py +166 -0
  89. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tools/create_zip_contents.py +34 -0
  90. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/__init__.py +30 -0
  91. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/big.py +624 -0
  92. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/bioimageio_utils.py +494 -0
  93. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/data/__init__.py +39 -0
  94. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/__init__.py +10 -0
  95. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/geom2d.py +215 -0
  96. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/geom3d.py +349 -0
  97. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/matching.py +483 -0
  98. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/__init__.py +28 -0
  99. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/base.py +1217 -0
  100. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/model2d.py +594 -0
  101. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/model3d.py +696 -0
  102. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/nms.py +384 -0
  103. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/__init__.py +2 -0
  104. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/plot.py +74 -0
  105. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/render.py +298 -0
  106. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/rays3d.py +373 -0
  107. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/sample_patches.py +65 -0
  108. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/__init__.py +0 -0
  109. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/predict2d.py +90 -0
  110. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/predict3d.py +93 -0
  111. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/utils.py +408 -0
  112. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/version.py +1 -0
  113. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/__init__.py +45 -0
  114. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/__init__.py +17 -0
  115. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/cli.py +55 -0
  116. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/core.py +285 -0
  117. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/__init__.py +15 -0
  118. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/cli.py +36 -0
  119. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/divisibility.py +193 -0
  120. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/probe.py +100 -0
  121. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/receptive_field.py +182 -0
  122. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/rf_cli.py +48 -0
  123. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/valid_sizes.py +278 -0
  124. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/post/__init__.py +8 -0
  125. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/post/core.py +157 -0
  126. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/pre/__init__.py +17 -0
  127. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/pre/core.py +226 -0
  128. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/predict/__init__.py +5 -0
  129. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/predict/core.py +401 -0
  130. senoquant/tabs/settings/__init__.py +1 -0
  131. senoquant/tabs/settings/backend.py +29 -0
  132. senoquant/tabs/settings/frontend.py +19 -0
  133. senoquant/tabs/spots/__init__.py +1 -0
  134. senoquant/tabs/spots/backend.py +139 -0
  135. senoquant/tabs/spots/frontend.py +800 -0
  136. senoquant/tabs/spots/models/__init__.py +5 -0
  137. senoquant/tabs/spots/models/base.py +94 -0
  138. senoquant/tabs/spots/models/rmp/details.json +61 -0
  139. senoquant/tabs/spots/models/rmp/model.py +499 -0
  140. senoquant/tabs/spots/models/udwt/details.json +103 -0
  141. senoquant/tabs/spots/models/udwt/model.py +482 -0
  142. senoquant/utils.py +25 -0
  143. senoquant-1.0.0b1.dist-info/METADATA +193 -0
  144. senoquant-1.0.0b1.dist-info/RECORD +148 -0
  145. senoquant-1.0.0b1.dist-info/WHEEL +5 -0
  146. senoquant-1.0.0b1.dist-info/entry_points.txt +2 -0
  147. senoquant-1.0.0b1.dist-info/licenses/LICENSE +28 -0
  148. senoquant-1.0.0b1.dist-info/top_level.txt +1 -0
senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/generate.py
@@ -0,0 +1,470 @@
+ # -*- coding: utf-8 -*-
+ from __future__ import print_function, unicode_literals, absolute_import, division
+ from six.moves import range, zip, map, reduce, filter
+ from six import string_types
+
+ import numpy as np
+ import sys, os, warnings
+
+ from tqdm import tqdm
+ from ..utils import _raise, consume, compose, normalize_mi_ma, axes_dict, axes_check_and_normalize, choice
+ from ..utils.six import Path
+ from ..io import save_training_data
+
+ from .transform import Transform, permute_axes, broadcast_target
+
+
+
+ ## Patch filter
+
+ def no_background_patches(threshold=0.4, percentile=99.9):
+
+     """Returns a patch filter to be used by :func:`create_patches` to determine for each image pair which patches
+     are eligible for sampling. The purpose is to only sample patches from "interesting" regions of the raw image that
+     actually contain a substantial amount of non-background signal. To that end, a maximum filter is applied to the target image
+     to find the largest values in a region.
+
+     Parameters
+     ----------
+     threshold : float, optional
+         Scalar threshold between 0 and 1 that will be multiplied with the (outlier-robust)
+         maximum of the image (see `percentile` below) to denote a lower bound.
+         Only patches with a maximum value above this lower bound are eligible to be sampled.
+     percentile : float, optional
+         Percentile value to denote the (outlier-robust) maximum of an image, i.e. should be close to 100.
+
+     Returns
+     -------
+     function
+         Function that takes an image pair `(y,x)` and the patch size as arguments and
+         returns a binary mask of the same size as the image (to denote the locations
+         eligible for sampling for :func:`create_patches`). At least one pixel of the
+         binary mask must be ``True``, otherwise there are no patches to sample.
+
+     Raises
+     ------
+     ValueError
+         Illegal arguments.
+     """
+
+     (np.isscalar(percentile) and 0 <= percentile <= 100) or _raise(ValueError())
+     (np.isscalar(threshold) and 0 <= threshold <= 1) or _raise(ValueError())
+
+     from scipy.ndimage import maximum_filter
+     def _filter(datas, patch_size, dtype=np.float32):
+         image = datas[0]
+         if dtype is not None:
+             image = image.astype(dtype)
+         # make max filter patch_size smaller to avoid patches with only a few non-background pixels close to the image border
+         patch_size = [(p//2 if p>1 else p) for p in patch_size]
+         filtered = maximum_filter(image, patch_size, mode='constant')
+         return filtered > threshold * np.percentile(image,percentile)
+     return _filter
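
For orientation, a minimal sketch of how the returned patch filter behaves; the image below is a hypothetical example, not part of the package:

    import numpy as np

    # hypothetical target image: one bright square on a dark background
    y = np.zeros((256, 256), dtype=np.float32)
    y[100:140, 100:140] = 1.0

    patch_filter = no_background_patches(threshold=0.4, percentile=99.9)
    mask = patch_filter((y, y), patch_size=(64, 64))  # datas[0] is the target image

    # mask is True where the local maximum exceeds 0.4 * (99.9th percentile of y),
    # i.e. near the bright square, and False in pure background
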
+
+
+
+ ## Sample patches
+
+ def sample_patches_from_multiple_stacks(datas, patch_size, n_samples, datas_mask=None, patch_filter=None, verbose=False):
+     """ sample matching patches of size `patch_size` from all arrays in `datas` """
+
+     # TODO: some of these checks are already required in 'create_patches'
+     len(patch_size)==datas[0].ndim or _raise(ValueError())
+
+     if not all(( a.shape == datas[0].shape for a in datas )):
+         raise ValueError("all input shapes must be the same: %s" % (" / ".join(str(a.shape) for a in datas)))
+
+     if not all(( 0 < s <= d for s,d in zip(patch_size,datas[0].shape) )):
+         raise ValueError("patch_size %s negative or larger than data shape %s along some dimensions" % (str(patch_size), str(datas[0].shape)))
+
+     if patch_filter is None:
+         patch_mask = np.ones(datas[0].shape,dtype=bool)
+     else:
+         patch_mask = patch_filter(datas, patch_size)
+
+     if datas_mask is not None:
+         # TODO: Test this
+         warnings.warn('Using pixel masks for raw/transformed images not tested.')
+         datas_mask.shape == datas[0].shape or _raise(ValueError())
+         datas_mask.dtype == bool or _raise(ValueError())
+         from scipy.ndimage import minimum_filter
+         patch_mask &= minimum_filter(datas_mask, patch_size, mode='constant', cval=False)
+
+     # get the valid indices
+
+     border_slices = tuple([slice(s // 2, d - s + s // 2 + 1) for s, d in zip(patch_size, datas[0].shape)])
+     valid_inds = np.where(patch_mask[border_slices])
+     n_valid = len(valid_inds[0])
+
+     if n_valid == 0:
+         raise ValueError("'patch_filter' didn't return any region to sample from")
+
+     sample_inds = choice(range(n_valid), n_samples, replace=(n_valid < n_samples))
+
+     # valid_inds = [v + s.start for s, v in zip(border_slices, valid_inds)] # slow for large n_valid
+     # rand_inds = [v[sample_inds] for v in valid_inds]
+     rand_inds = [v[sample_inds] + s.start for s, v in zip(border_slices, valid_inds)]
+
+     # res = [np.stack([data[r[0] - patch_size[0] // 2:r[0] + patch_size[0] - patch_size[0] // 2,
+     #                        r[1] - patch_size[1] // 2:r[1] + patch_size[1] - patch_size[1] // 2,
+     #                        r[2] - patch_size[2] // 2:r[2] + patch_size[2] - patch_size[2] // 2,
+     #                        ] for r in zip(*rand_inds)]) for data in datas]
+
+     res = [np.stack([data[tuple(slice(_r-(_p//2),_r+_p-(_p//2)) for _r,_p in zip(r,patch_size))] for r in zip(*rand_inds)]) for data in datas]
+
+     return res
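
To make the index bookkeeping above concrete, a worked instance of the `border_slices` computation (illustrative numbers):

    # image dimension d = 256, patch size s = 64:
    #   slice(s // 2, d - s + s // 2 + 1) == slice(32, 225)
    # so valid patch centers r lie in [32, 224], and the extracted window
    # slice(r - s//2, r + s - s//2) == slice(r - 32, r + 32) always fits
    # fully inside the image (slice(0, 64) up to slice(192, 256)).
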
+
+
+
+ ## Create training data
+
+ def _valid_low_high_percentiles(ps):
+     return isinstance(ps,(list,tuple,np.ndarray)) and len(ps)==2 and all(map(np.isscalar,ps)) and (0<=ps[0]<ps[1]<=100)
+
+ def _memory_check(n_required_memory_bytes, thresh_free_frac=0.5, thresh_abs_bytes=1024*1024**2):
+     try:
+         # raise ImportError
+         import psutil
+         mem = psutil.virtual_memory()
+         mem_frac = n_required_memory_bytes / mem.available
+         if mem_frac > 1:
+             raise MemoryError('Not enough available memory.')
+         elif mem_frac > thresh_free_frac:
+             print('Warning: will use at least %.0f MB (%.1f%%) of available memory.\n' % (n_required_memory_bytes/1024**2,100*mem_frac), file=sys.stderr)
+             sys.stderr.flush()
+     except ImportError:
+         if n_required_memory_bytes > thresh_abs_bytes:
+             print('Warning: will use at least %.0f MB of memory.\n' % (n_required_memory_bytes/1024**2), file=sys.stderr)
+             sys.stderr.flush()
+
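For scale, the byte estimate that :func:`create_patches` (below) passes into this check is `2 * n_patches * np.prod(patch_size) * 4`, i.e. two float32 arrays `X` and `Y`. A worked instance with illustrative numbers:

    # 16 raw images x 1 transform x 16 patches of shape (32, 128, 128):
    #   n_patches = 256 and np.prod(patch_size) = 524288, so
    #   2 * 256 * 524288 * 4 bytes = 1073741824 bytes = 1 GiB
    # With psutil installed, a warning is printed if this exceeds half of the
    # available memory; without psutil, if it exceeds thresh_abs_bytes (1 GiB).
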
+ def sample_percentiles(pmin=(1,3), pmax=(99.5,99.9)):
+     """Sample percentile values from a uniform distribution.
+
+     Parameters
+     ----------
+     pmin : tuple
+         Tuple of two values that denotes the interval for sampling low percentiles.
+     pmax : tuple
+         Tuple of two values that denotes the interval for sampling high percentiles.
+
+     Returns
+     -------
+     function
+         Function without arguments that returns `(pl,ph)`, where `pl` (`ph`) is a sampled low (high) percentile.
+
+     Raises
+     ------
+     ValueError
+         Illegal arguments.
+     """
+     _valid_low_high_percentiles(pmin) or _raise(ValueError(pmin))
+     _valid_low_high_percentiles(pmax) or _raise(ValueError(pmax))
+     pmin[1] < pmax[0] or _raise(ValueError())
+     return lambda: (np.random.uniform(*pmin), np.random.uniform(*pmax))
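
A quick illustration of the returned sampler; the drawn values are random within the given intervals:

    sample = sample_percentiles(pmin=(1, 3), pmax=(99.5, 99.9))
    pl, ph = sample()  # e.g. (2.17, 99.73); always 1 <= pl <= 3 and 99.5 <= ph <= 99.9
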
+
+
+ def norm_percentiles(percentiles=sample_percentiles(), relu_last=False):
+     """Normalize extracted patches based on percentiles from the corresponding raw image.
+
+     Parameters
+     ----------
+     percentiles : tuple, optional
+         A tuple (`pmin`, `pmax`) or a function that returns such a tuple, where the extracted patches
+         are (affinely) normalized such that a value of 0 (1) corresponds to the `pmin`-th (`pmax`-th) percentile
+         of the raw image (default: :func:`sample_percentiles`).
+     relu_last : bool, optional
+         Flag to indicate whether the last activation of the CARE network is/will be using
+         a ReLU activation function (default: ``False``).
+
+     Returns
+     -------
+     function
+         Function that does percentile-based normalization to be used in :func:`create_patches`.
+
+     Raises
+     ------
+     ValueError
+         Illegal arguments.
+
+     Todo
+     ----
+     ``relu_last`` flag problematic/inelegant.
+
+     """
+     if callable(percentiles):
+         _tmp = percentiles()
+         _valid_low_high_percentiles(_tmp) or _raise(ValueError(_tmp))
+         get_percentiles = percentiles
+     else:
+         _valid_low_high_percentiles(percentiles) or _raise(ValueError(percentiles))
+         get_percentiles = lambda: percentiles
+
+     def _normalize(patches_x,patches_y, x,y,mask,channel):
+         pmins, pmaxs = zip(*(get_percentiles() for _ in patches_x))
+         percentile_axes = None if channel is None else tuple((d for d in range(x.ndim) if d != channel))
+         _perc = lambda a,p: np.percentile(a,p,axis=percentile_axes,keepdims=True)
+         patches_x_norm = normalize_mi_ma(patches_x, _perc(x,pmins), _perc(x,pmaxs))
+         if relu_last:
+             pmins = np.zeros_like(pmins)
+         patches_y_norm = normalize_mi_ma(patches_y, _perc(y,pmins), _perc(y,pmaxs))
+         return patches_x_norm, patches_y_norm
+
+     return _normalize
+
+
+ def create_patches(
+         raw_data,
+         patch_size,
+         n_patches_per_image,
+         patch_axes = None,
+         save_file = None,
+         transforms = None,
+         patch_filter = no_background_patches(),
+         normalization = norm_percentiles(),
+         shuffle = True,
+         verbose = True,
+     ):
+     """Create normalized training data to be used for neural network training.
+
+     Parameters
+     ----------
+     raw_data : :class:`RawData`
+         Object that yields matching pairs of raw images.
+     patch_size : tuple
+         Shape of the patches to be extracted from raw images.
+         Must be compatible with the number of dimensions and axes of the raw images.
+         As a general rule, use a power of two along all XYZT axes, or at least divisible by 8.
+     n_patches_per_image : int
+         Number of patches to be sampled/extracted from each raw image pair (after transformations, see below).
+     patch_axes : str or None
+         Axes of the extracted patches. If ``None``, will be assumed to be equal to that of the transformed raw data.
+     save_file : str or None
+         File name to save training data to disk in ``.npz`` format (see :func:`csbdeep.io.save_training_data`).
+         If ``None``, data will not be saved.
+     transforms : list or tuple, optional
+         List of :class:`Transform` objects that apply additional transformations to the raw images.
+         This can be used to augment the set of raw images (e.g., by including rotations).
+         Set to ``None`` to disable. Default: ``None``.
+     patch_filter : function, optional
+         Function to determine for each image pair which patches are eligible to be extracted
+         (default: :func:`no_background_patches`). Set to ``None`` to disable.
+     normalization : function, optional
+         Function that takes arguments `(patches_x, patches_y, x, y, mask, channel)`, whose purpose is to
+         normalize the patches (`patches_x`, `patches_y`) extracted from the associated raw images
+         (`x`, `y`, with `mask`; see :class:`RawData`). Default: :func:`norm_percentiles`.
+     shuffle : bool, optional
+         Randomly shuffle all extracted patches.
+     verbose : bool, optional
+         Display overview of images, transforms, etc.
+
+     Returns
+     -------
+     tuple(:class:`numpy.ndarray`, :class:`numpy.ndarray`, str)
+         Returns a tuple (`X`, `Y`, `axes`) with the normalized extracted patches from all (transformed) raw images
+         and their axes.
+         `X` is the array of patches extracted from source images with `Y` being the array of corresponding target patches.
+         The shape of `X` and `Y` is as follows: `(n_total_patches, n_channels, ...)`.
+         For single-channel images, `n_channels` will be 1.
+
+     Raises
+     ------
+     ValueError
+         Various reasons.
+
+     Example
+     -------
+     >>> raw_data = RawData.from_folder(basepath='data', source_dirs=['source1','source2'], target_dir='GT', axes='ZYX')
+     >>> X, Y, XY_axes = create_patches(raw_data, patch_size=(32,128,128), n_patches_per_image=16)
+
+     Todo
+     ----
+     - Save created patches directly to disk using :class:`numpy.memmap` or similar?
+       Would allow working with large data that doesn't fit in memory.
+
+     """
+     ## images and transforms
+     if transforms is None:
+         transforms = []
+     transforms = list(transforms)
+     if patch_axes is not None:
+         transforms.append(permute_axes(patch_axes))
+     if len(transforms) == 0:
+         transforms.append(Transform.identity())
+
+     if normalization is None:
+         normalization = lambda patches_x, patches_y, x, y, mask, channel: (patches_x, patches_y)
+
+     image_pairs, n_raw_images = raw_data.generator(), raw_data.size
+     tf = Transform(*zip(*transforms)) # convert list of Transforms into Transform of lists
+     image_pairs = compose(*tf.generator)(image_pairs) # combine all transformations with raw images as input
+     n_transforms = np.prod(tf.size)
+     n_images = n_raw_images * n_transforms
+     n_patches = n_images * n_patches_per_image
+     n_required_memory_bytes = 2 * n_patches*np.prod(patch_size) * 4
+
+     ## memory check
+     _memory_check(n_required_memory_bytes)
+
+     ## summary
+     if verbose:
+         print('='*66)
+         print('%5d raw images x %4d transformations = %5d images' % (n_raw_images,n_transforms,n_images))
+         print('%5d images x %4d patches per image = %5d patches in total' % (n_images,n_patches_per_image,n_patches))
+         print('='*66)
+         print('Input data:')
+         print(raw_data.description)
+         print('='*66)
+         print('Transformations:')
+         for t in transforms:
+             print('{t.size} x {t.name}'.format(t=t))
+         print('='*66)
+         print('Patch size:')
+         print(" x ".join(str(p) for p in patch_size))
+         print('=' * 66)
+
+     sys.stdout.flush()
+
+     ## sample patches from each pair of transformed raw images
+     X = np.empty((n_patches,)+tuple(patch_size),dtype=np.float32)
+     Y = np.empty_like(X)
+
+     for i, (x,y,_axes,mask) in tqdm(enumerate(image_pairs),total=n_images,disable=(not verbose)):
+         if i >= n_images:
+             warnings.warn('more raw images (or transformations thereof) than expected, skipping excess images.')
+             break
+         if i==0:
+             axes = axes_check_and_normalize(_axes,len(patch_size))
+             channel = axes_dict(axes)['C']
+         # checks
+         # len(axes) >= x.ndim or _raise(ValueError())
+         axes == axes_check_and_normalize(_axes) or _raise(ValueError('not all images have the same axes.'))
+         x.shape == y.shape or _raise(ValueError())
+         mask is None or mask.shape == x.shape or _raise(ValueError())
+         (channel is None or (isinstance(channel,int) and 0<=channel<x.ndim)) or _raise(ValueError())
+         channel is None or patch_size[channel]==x.shape[channel] or _raise(ValueError('extracted patches must contain all channels.'))
+
+         _Y,_X = sample_patches_from_multiple_stacks((y,x), patch_size, n_patches_per_image, mask, patch_filter)
+
+         s = slice(i*n_patches_per_image,(i+1)*n_patches_per_image)
+         X[s], Y[s] = normalization(_X,_Y, x,y,mask,channel)
+
+     if shuffle:
+         shuffle_inplace(X,Y)
+
+     axes = 'SC'+axes.replace('C','')
+     if channel is None:
+         X = np.expand_dims(X,1)
+         Y = np.expand_dims(Y,1)
+     else:
+         X = np.moveaxis(X, 1+channel, 1)
+         Y = np.moveaxis(Y, 1+channel, 1)
+
+     if save_file is not None:
+         print('Saving data to %s.' % str(Path(save_file)))
+         save_training_data(save_file, X, Y, axes)
+
+     return X,Y,axes
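
To illustrate the final axes handling, assume single-channel 2D raw images with axes 'YX' (an illustrative input, not prescribed by the package):

    # 8 images, n_patches_per_image=16, patch_size=(128, 128), axes 'YX':
    #   no 'C' in the input axes, so a channel axis is inserted at position 1:
    #   X.shape == Y.shape == (128, 1, 128, 128)
    #   axes == 'SCYX'   # 'S' = sample axis, 'C' = channel axis
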
367
+
368
+
369
+ def create_patches_reduced_target(
370
+ raw_data,
371
+ patch_size,
372
+ n_patches_per_image,
373
+ reduction_axes,
374
+ target_axes = None, # TODO: this should rather be part of RawData and also exposed to transforms
375
+ **kwargs
376
+ ):
377
+ """Create normalized training data to be used for neural network training.
378
+
379
+ In contrast to :func:`create_patches`, it is assumed that the target image has reduced
380
+ dimensionality (i.e. size 1) along one or several axes (`reduction_axes`).
381
+
382
+ Parameters
383
+ ----------
384
+ raw_data : :class:`RawData`
385
+ See :func:`create_patches`.
386
+ patch_size : tuple
387
+ See :func:`create_patches`.
388
+ n_patches_per_image : int
389
+ See :func:`create_patches`.
390
+ reduction_axes : str
391
+ Axes where the target images have a reduced dimension (i.e. size 1) compared to the source images.
392
+ target_axes : str
393
+ Axes of the raw target images. If ``None``, will be assumed to be equal to that of the raw source images.
394
+ kwargs : dict
395
+ Additional parameters as in :func:`create_patches`.
396
+
397
+ Returns
398
+ -------
399
+ tuple(:class:`numpy.ndarray`, :class:`numpy.ndarray`, str)
400
+ See :func:`create_patches`. Note that the shape of the target data will be 1 along all reduction axes.
401
+
402
+ """
403
+ reduction_axes = axes_check_and_normalize(reduction_axes,disallowed='S')
404
+
405
+ transforms = kwargs.get('transforms')
406
+ if transforms is None:
407
+ transforms = []
408
+ transforms = list(transforms)
409
+ transforms.insert(0,broadcast_target(target_axes))
410
+ kwargs['transforms'] = transforms
411
+
412
+ save_file = kwargs.pop('save_file',None)
413
+
414
+ if any(s is None for s in patch_size):
415
+ patch_axes = kwargs.get('patch_axes')
416
+ if patch_axes is not None:
417
+ _transforms = list(transforms)
418
+ _transforms.append(permute_axes(patch_axes))
419
+ else:
420
+ _transforms = transforms
421
+ tf = Transform(*zip(*_transforms))
422
+ image_pairs = compose(*tf.generator)(raw_data.generator())
423
+ x,y,axes,mask = next(image_pairs) # get the first entry from the generator
424
+ patch_size = list(patch_size)
425
+ for i,(a,s) in enumerate(zip(axes,patch_size)):
426
+ if s is not None: continue
427
+ a in reduction_axes or _raise(ValueError("entry of patch_size is None for non reduction axis %s." % a))
428
+ patch_size[i] = x.shape[i]
429
+ patch_size = tuple(patch_size)
430
+ del x,y,axes,mask
431
+
432
+ X,Y,axes = create_patches (
433
+ raw_data = raw_data,
434
+ patch_size = patch_size,
435
+ n_patches_per_image = n_patches_per_image,
436
+ **kwargs
437
+ )
438
+
439
+ ax = axes_dict(axes)
440
+ for a in reduction_axes:
441
+ a in axes or _raise(ValueError("reduction axis %d not present in extracted patches" % a))
442
+ n_dims = Y.shape[ax[a]]
443
+ if n_dims == 1:
444
+ warnings.warn("extracted target patches already have dimensionality 1 along reduction axis %s." % a)
445
+ else:
446
+ t = np.take(Y,(1,),axis=ax[a])
447
+ Y = np.take(Y,(0,),axis=ax[a])
448
+ i = np.random.choice(Y.size,size=100)
449
+ if not np.all(t.flat[i]==Y.flat[i]):
450
+ warnings.warn("extracted target patches vary along reduction axis %s." % a)
451
+
452
+ if save_file is not None:
453
+ print('Saving data to %s.' % str(Path(save_file)))
454
+ save_training_data(save_file, X, Y, axes)
455
+
456
+ return X,Y,axes
+
+
+ ## Misc
+
+ def shuffle_inplace(*arrs,**kwargs):
+     seed = kwargs.pop('seed', None)
+     if seed is None:
+         rng = np.random
+     else:
+         rng = np.random.RandomState(seed=seed)
+     state = rng.get_state()
+     for a in arrs:
+         rng.set_state(state)
+         rng.shuffle(a)
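
Because every array is shuffled from the same saved RNG state, corresponding entries stay aligned across arrays; a minimal check (illustrative):

    X = np.arange(10)
    Y = np.arange(10) * 10
    shuffle_inplace(X, Y, seed=0)
    assert np.all(Y == 10 * X)  # the same permutation was applied to both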