senoquant 1.0.0b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- senoquant/__init__.py +6 -0
- senoquant/_reader.py +7 -0
- senoquant/_widget.py +33 -0
- senoquant/napari.yaml +83 -0
- senoquant/reader/__init__.py +5 -0
- senoquant/reader/core.py +369 -0
- senoquant/tabs/__init__.py +15 -0
- senoquant/tabs/batch/__init__.py +10 -0
- senoquant/tabs/batch/backend.py +641 -0
- senoquant/tabs/batch/config.py +270 -0
- senoquant/tabs/batch/frontend.py +1283 -0
- senoquant/tabs/batch/io.py +326 -0
- senoquant/tabs/batch/layers.py +86 -0
- senoquant/tabs/quantification/__init__.py +1 -0
- senoquant/tabs/quantification/backend.py +228 -0
- senoquant/tabs/quantification/features/__init__.py +80 -0
- senoquant/tabs/quantification/features/base.py +142 -0
- senoquant/tabs/quantification/features/marker/__init__.py +5 -0
- senoquant/tabs/quantification/features/marker/config.py +69 -0
- senoquant/tabs/quantification/features/marker/dialog.py +437 -0
- senoquant/tabs/quantification/features/marker/export.py +879 -0
- senoquant/tabs/quantification/features/marker/feature.py +119 -0
- senoquant/tabs/quantification/features/marker/morphology.py +285 -0
- senoquant/tabs/quantification/features/marker/rows.py +654 -0
- senoquant/tabs/quantification/features/marker/thresholding.py +46 -0
- senoquant/tabs/quantification/features/roi.py +346 -0
- senoquant/tabs/quantification/features/spots/__init__.py +5 -0
- senoquant/tabs/quantification/features/spots/config.py +62 -0
- senoquant/tabs/quantification/features/spots/dialog.py +477 -0
- senoquant/tabs/quantification/features/spots/export.py +1292 -0
- senoquant/tabs/quantification/features/spots/feature.py +112 -0
- senoquant/tabs/quantification/features/spots/morphology.py +279 -0
- senoquant/tabs/quantification/features/spots/rows.py +241 -0
- senoquant/tabs/quantification/frontend.py +815 -0
- senoquant/tabs/segmentation/__init__.py +1 -0
- senoquant/tabs/segmentation/backend.py +131 -0
- senoquant/tabs/segmentation/frontend.py +1009 -0
- senoquant/tabs/segmentation/models/__init__.py +5 -0
- senoquant/tabs/segmentation/models/base.py +146 -0
- senoquant/tabs/segmentation/models/cpsam/details.json +65 -0
- senoquant/tabs/segmentation/models/cpsam/model.py +150 -0
- senoquant/tabs/segmentation/models/default_2d/details.json +69 -0
- senoquant/tabs/segmentation/models/default_2d/model.py +664 -0
- senoquant/tabs/segmentation/models/default_3d/details.json +69 -0
- senoquant/tabs/segmentation/models/default_3d/model.py +682 -0
- senoquant/tabs/segmentation/models/hf.py +71 -0
- senoquant/tabs/segmentation/models/nuclear_dilation/__init__.py +1 -0
- senoquant/tabs/segmentation/models/nuclear_dilation/details.json +26 -0
- senoquant/tabs/segmentation/models/nuclear_dilation/model.py +96 -0
- senoquant/tabs/segmentation/models/perinuclear_rings/__init__.py +1 -0
- senoquant/tabs/segmentation/models/perinuclear_rings/details.json +34 -0
- senoquant/tabs/segmentation/models/perinuclear_rings/model.py +132 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/__init__.py +2 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/__init__.py +3 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/__init__.py +6 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/generate.py +470 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/prepare.py +273 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/rawdata.py +112 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/transform.py +384 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/__init__.py +0 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/blocks.py +184 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/losses.py +79 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/nets.py +165 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/predict.py +467 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/probability.py +67 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/train.py +148 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/io/__init__.py +163 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/__init__.py +52 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/base_model.py +329 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_isotropic.py +160 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_projection.py +178 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_standard.py +446 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_upsampling.py +54 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/config.py +254 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/pretrained.py +119 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/scripts/__init__.py +0 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/scripts/care_predict.py +180 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/__init__.py +5 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/plot_utils.py +159 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/six.py +18 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/tf.py +644 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/utils.py +272 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/version.py +1 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/docs/source/conf.py +368 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/setup.py +68 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_datagen.py +169 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_models.py +462 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_utils.py +166 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tools/create_zip_contents.py +34 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/__init__.py +30 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/big.py +624 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/bioimageio_utils.py +494 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/data/__init__.py +39 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/__init__.py +10 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/geom2d.py +215 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/geom3d.py +349 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/matching.py +483 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/__init__.py +28 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/base.py +1217 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/model2d.py +594 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/model3d.py +696 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/nms.py +384 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/__init__.py +2 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/plot.py +74 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/render.py +298 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/rays3d.py +373 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/sample_patches.py +65 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/__init__.py +0 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/predict2d.py +90 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/predict3d.py +93 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/utils.py +408 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/version.py +1 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/__init__.py +45 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/__init__.py +17 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/cli.py +55 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/core.py +285 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/__init__.py +15 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/cli.py +36 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/divisibility.py +193 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/probe.py +100 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/receptive_field.py +182 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/rf_cli.py +48 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/valid_sizes.py +278 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/post/__init__.py +8 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/post/core.py +157 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/pre/__init__.py +17 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/pre/core.py +226 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/predict/__init__.py +5 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/predict/core.py +401 -0
- senoquant/tabs/settings/__init__.py +1 -0
- senoquant/tabs/settings/backend.py +29 -0
- senoquant/tabs/settings/frontend.py +19 -0
- senoquant/tabs/spots/__init__.py +1 -0
- senoquant/tabs/spots/backend.py +139 -0
- senoquant/tabs/spots/frontend.py +800 -0
- senoquant/tabs/spots/models/__init__.py +5 -0
- senoquant/tabs/spots/models/base.py +94 -0
- senoquant/tabs/spots/models/rmp/details.json +61 -0
- senoquant/tabs/spots/models/rmp/model.py +499 -0
- senoquant/tabs/spots/models/udwt/details.json +103 -0
- senoquant/tabs/spots/models/udwt/model.py +482 -0
- senoquant/utils.py +25 -0
- senoquant-1.0.0b1.dist-info/METADATA +193 -0
- senoquant-1.0.0b1.dist-info/RECORD +148 -0
- senoquant-1.0.0b1.dist-info/WHEEL +5 -0
- senoquant-1.0.0b1.dist-info/entry_points.txt +2 -0
- senoquant-1.0.0b1.dist-info/licenses/LICENSE +28 -0
- senoquant-1.0.0b1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,470 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
from __future__ import print_function, unicode_literals, absolute_import, division
|
|
3
|
+
from six.moves import range, zip, map, reduce, filter
|
|
4
|
+
from six import string_types
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
import sys, os, warnings
|
|
8
|
+
|
|
9
|
+
from tqdm import tqdm
|
|
10
|
+
from ..utils import _raise, consume, compose, normalize_mi_ma, axes_dict, axes_check_and_normalize, choice
|
|
11
|
+
from ..utils.six import Path
|
|
12
|
+
from ..io import save_training_data
|
|
13
|
+
|
|
14
|
+
from .transform import Transform, permute_axes, broadcast_target
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
## Patch filter
|
|
19
|
+
|
|
20
|
+
def no_background_patches(threshold=0.4, percentile=99.9):
    """Build a patch filter that rejects patches consisting mostly of background.

    The returned function is meant to be passed as `patch_filter` to
    :func:`create_patches`. It applies a maximum filter to the first image of
    the given tuple and marks as eligible every location whose local maximum
    exceeds `threshold` times the `percentile`-th percentile of that image.

    Parameters
    ----------
    threshold : float, optional
        Scalar in [0, 1]. It is multiplied with the (outlier-robust) image
        maximum, see `percentile`, to obtain the lower bound a local maximum
        must exceed.
    percentile : float, optional
        Scalar in [0, 100] used as the (outlier-robust) maximum of the image,
        i.e. it should be close to 100.

    Returns
    -------
    function
        Function ``_filter(datas, patch_size, dtype=np.float32)`` that returns
        a boolean mask with the same shape as ``datas[0]``. At least one entry
        must be ``True``, otherwise there is nothing to sample from.

    Raises
    ------
    ValueError
        Illegal arguments.
    """
    percentile_ok = np.isscalar(percentile) and 0 <= percentile <= 100
    if not percentile_ok:
        _raise(ValueError())
    threshold_ok = np.isscalar(threshold) and 0 <= threshold <= 1
    if not threshold_ok:
        _raise(ValueError())

    from scipy.ndimage import maximum_filter

    def _filter(datas, patch_size, dtype=np.float32):
        # only the first image of the tuple is inspected
        reference = datas[0] if dtype is None else datas[0].astype(dtype)
        # halve the filter window so that a few non-background pixels close to
        # the image border do not make a whole patch eligible
        window = [p if p <= 1 else p // 2 for p in patch_size]
        local_max = maximum_filter(reference, window, mode='constant')
        lower_bound = threshold * np.percentile(reference, percentile)
        return local_max > lower_bound

    return _filter
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
## Sample patches
|
|
67
|
+
|
|
68
|
+
def sample_patches_from_multiple_stacks(datas, patch_size, n_samples, datas_mask=None, patch_filter=None, verbose=False):
    """Sample matching patches of size `patch_size` from all arrays in `datas`.

    Parameters
    ----------
    datas : sequence of arrays
        Arrays of identical shape; patches are cut at the same positions from each.
    patch_size : sequence of int
        Patch extent along every dimension of the arrays.
    n_samples : int
        Number of patches to draw (with replacement only if fewer valid
        positions than `n_samples` exist).
    datas_mask : boolean array or None
        Optional pixel mask with the shape of the arrays (path is flagged as untested).
    patch_filter : function or None
        Called as ``patch_filter(datas, patch_size)`` to obtain the mask of
        eligible patch centers; all positions are eligible if ``None``.
    verbose : bool
        Not used by this function.

    Returns
    -------
    list of arrays
        One stacked array of `n_samples` patches per entry of `datas`.

    Raises
    ------
    ValueError
        Inconsistent shapes, illegal `patch_size`, or no eligible region.
    """
    # TODO: some of these checks are already required in 'create_patches'
    if not (len(patch_size) == datas[0].ndim):
        _raise(ValueError())

    reference_shape = datas[0].shape

    if any(a.shape != reference_shape for a in datas):
        raise ValueError("all input shapes must be the same: %s" % (" / ".join(str(a.shape) for a in datas)))

    fits = all(0 < s <= d for s, d in zip(patch_size, reference_shape))
    if not fits:
        raise ValueError("patch_size %s negative or larger than data shape %s along some dimensions" % (str(patch_size), str(reference_shape)))

    if patch_filter is not None:
        patch_mask = patch_filter(datas, patch_size)
    else:
        patch_mask = np.ones(reference_shape, dtype=bool)

    if datas_mask is not None:
        # TODO: Test this
        warnings.warn('Using pixel masks for raw/transformed images not tested.')
        if not (datas_mask.shape == reference_shape):
            _raise(ValueError())
        if not (datas_mask.dtype == bool):
            _raise(ValueError())
        from scipy.ndimage import minimum_filter
        patch_mask &= minimum_filter(datas_mask, patch_size, mode='constant', cval=False)

    # restrict candidate centers to positions where a full patch fits
    border_slices = tuple(
        slice(s // 2, d - s + s // 2 + 1) for s, d in zip(patch_size, reference_shape)
    )
    valid_inds = np.where(patch_mask[border_slices])
    n_valid = len(valid_inds[0])

    if n_valid == 0:
        raise ValueError("'patch_filter' didn't return any region to sample from")

    sample_inds = choice(range(n_valid), n_samples, replace=(n_valid < n_samples))

    # add the border offset only to the sampled indices
    # (shifting all valid indices first is slow for large n_valid)
    rand_inds = [v[sample_inds] + s.start for s, v in zip(border_slices, valid_inds)]

    # one tuple of slices per sampled center, shared by all arrays
    windows = [
        tuple(slice(_r - (_p // 2), _r + _p - (_p // 2)) for _r, _p in zip(center, patch_size))
        for center in zip(*rand_inds)
    ]

    res = []
    for data in datas:
        res.append(np.stack([data[w] for w in windows]))
    return res
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
## Create training data
|
|
120
|
+
|
|
121
|
+
def _valid_low_high_percentiles(ps):
    """Return whether `ps` is a (low, high) pair of scalar percentiles with 0 <= low < high <= 100."""
    if not isinstance(ps, (list, tuple, np.ndarray)):
        return False
    if len(ps) != 2:
        return False
    if not all(np.isscalar(p) for p in ps):
        return False
    return 0 <= ps[0] < ps[1] <= 100
|
|
123
|
+
|
|
124
|
+
def _memory_check(n_required_memory_bytes, thresh_free_frac=0.5, thresh_abs_bytes=1024*1024**2):
    """Warn about (or refuse) an allocation of `n_required_memory_bytes` bytes.

    If ``psutil`` can be imported, the requirement is compared against the
    currently available memory: a :class:`MemoryError` is raised if it exceeds
    what is available, and a warning is printed to stderr if it exceeds the
    fraction `thresh_free_frac` of it. If ``psutil`` is not installed, a
    warning is printed to stderr when more than `thresh_abs_bytes` are required.
    """
    try:
        import psutil
        available = psutil.virtual_memory().available
        fraction = n_required_memory_bytes / available
        if fraction > 1:
            raise MemoryError('Not enough available memory.')
        if fraction > thresh_free_frac:
            message = 'Warning: will use at least %.0f MB (%.1f%%) of available memory.\n' % (n_required_memory_bytes/1024**2, 100*fraction)
            print(message, file=sys.stderr)
            sys.stderr.flush()
    except ImportError:
        # without psutil only an absolute threshold can be checked
        if n_required_memory_bytes > thresh_abs_bytes:
            message = 'Warning: will use at least %.0f MB of memory.\n' % (n_required_memory_bytes/1024**2)
            print(message, file=sys.stderr)
            sys.stderr.flush()
|
|
139
|
+
|
|
140
|
+
def sample_percentiles(pmin=(1,3), pmax=(99.5,99.9)):
    """Create a sampler that draws a low and a high percentile uniformly at random.

    Parameters
    ----------
    pmin : tuple
        Two values giving the interval from which the low percentile is drawn.
    pmax : tuple
        Two values giving the interval from which the high percentile is drawn.

    Returns
    -------
    function
        Function without arguments that returns `(pl,ph)`, where `pl` (`ph`)
        is a sampled low (high) percentile.

    Raises
    ------
    ValueError
        Illegal arguments.
    """
    for interval in (pmin, pmax):
        if not _valid_low_high_percentiles(interval):
            _raise(ValueError(interval))
    # the two sampling intervals must not overlap
    if not (pmin[1] < pmax[0]):
        _raise(ValueError())

    def _sample():
        low = np.random.uniform(*pmin)
        high = np.random.uniform(*pmax)
        return (low, high)

    return _sample
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def norm_percentiles(percentiles=sample_percentiles(), relu_last=False):
    """Create a normalization function based on percentiles of the raw images.

    Parameters
    ----------
    percentiles : tuple or function, optional
        A tuple (`pmin`, `pmax`) or a function without arguments returning such
        a tuple. Extracted patches are (affinely) normalized such that 0 (1)
        corresponds to the `pmin`-th (`pmax`-th) percentile of the raw image
        (default: :func:`sample_percentiles`). A function is called once per patch.
    relu_last : bool, optional
        If ``True``, the low percentiles are replaced by zeros when
        normalizing the target patches (default: ``False``).

    Returns
    -------
    function
        Function with arguments `(patches_x, patches_y, x, y, mask, channel)`
        to be used as `normalization` in :func:`create_patches`.

    Raises
    ------
    ValueError
        Illegal arguments.

    Todo
    ----
    ``relu_last`` flag problematic/inelegant.

    """
    if callable(percentiles):
        # probe the callable once to validate what it produces
        probe = percentiles()
        if not _valid_low_high_percentiles(probe):
            _raise(ValueError(probe))
        get_percentiles = percentiles
    else:
        if not _valid_low_high_percentiles(percentiles):
            _raise(ValueError(percentiles))

        def get_percentiles():
            return percentiles

    def _normalize(patches_x,patches_y, x,y,mask,channel):
        # one (low, high) percentile pair per patch
        pmins, pmaxs = zip(*(get_percentiles() for _ in patches_x))

        # percentiles are computed over all axes except the channel axis
        if channel is None:
            percentile_axes = None
        else:
            percentile_axes = tuple(d for d in range(x.ndim) if d != channel)

        def _perc(a, p):
            return np.percentile(a, p, axis=percentile_axes, keepdims=True)

        patches_x_norm = normalize_mi_ma(patches_x, _perc(x,pmins), _perc(x,pmaxs))
        if relu_last:
            pmins = np.zeros_like(pmins)
        patches_y_norm = normalize_mi_ma(patches_y, _perc(y,pmins), _perc(y,pmaxs))
        return patches_x_norm, patches_y_norm

    return _normalize
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
def create_patches(
        raw_data,
        patch_size,
        n_patches_per_image,
        patch_axes    = None,
        save_file     = None,
        transforms    = None,
        patch_filter  = no_background_patches(),
        normalization = norm_percentiles(),
        shuffle       = True,
        verbose       = True,
    ):
    """Create normalized training data to be used for neural network training.

    Parameters
    ----------
    raw_data : :class:`RawData`
        Object that yields matching pairs of raw images.
    patch_size : tuple
        Shape of the patches to be extracted from raw images.
        Must be compatible with the number of dimensions and axes of the raw images.
        As a general rule, use a power of two along all XYZT axes, or at least divisible by 8.
    n_patches_per_image : int
        Number of patches to be sampled/extracted from each raw image pair (after transformations, see below).
    patch_axes : str or None
        Axes of the extracted patches. If ``None``, will assume to be equal to that of transformed raw data.
    save_file : str or None
        File name to save training data to disk in ``.npz`` format (see :func:`csbdeep.io.save_training_data`).
        If ``None``, data will not be saved.
    transforms : list or tuple, optional
        List of :class:`Transform` objects that apply additional transformations to the raw images.
        This can be used to augment the set of raw images (e.g., by including rotations).
        Set to ``None`` to disable. Default: ``None``.
    patch_filter : function, optional
        Function to determine for each image pair which patches are eligible to be extracted
        (default: :func:`no_background_patches`). Set to ``None`` to disable.
    normalization : function, optional
        Function that takes arguments `(patches_x, patches_y, x, y, mask, channel)`, whose purpose is to
        normalize the patches (`patches_x`, `patches_y`) extracted from the associated raw images
        (`x`, `y`, with `mask`; see :class:`RawData`). Default: :func:`norm_percentiles`.
    shuffle : bool, optional
        Randomly shuffle all extracted patches.
    verbose : bool, optional
        Display overview of images, transforms, etc.

    Returns
    -------
    tuple(:class:`numpy.ndarray`, :class:`numpy.ndarray`, str)
        Returns a tuple (`X`, `Y`, `axes`) with the normalized extracted patches from all (transformed) raw images
        and their axes.
        `X` is the array of patches extracted from source images with `Y` being the array of corresponding target patches.
        The shape of `X` and `Y` is as follows: `(n_total_patches, n_channels, ...)`.
        For single-channel images, `n_channels` will be 1.
        If fewer image pairs are yielded than announced by `raw_data.size`, a warning is issued
        and only the patches that were actually extracted are returned.

    Raises
    ------
    ValueError
        Various reasons, including `raw_data` not yielding any image pair.

    Example
    -------
    >>> raw_data = RawData.from_folder(basepath='data', source_dirs=['source1','source2'], target_dir='GT', axes='ZYX')
    >>> X, Y, XY_axes = create_patches(raw_data, patch_size=(32,128,128), n_patches_per_image=16)

    Todo
    ----
    - Save created patches directly to disk using :class:`numpy.memmap` or similar?
      Would allow to work with large data that doesn't fit in memory.

    """
    ## images and transforms
    if transforms is None:
        transforms = []
    transforms = list(transforms)
    if patch_axes is not None:
        transforms.append(permute_axes(patch_axes))
    if len(transforms) == 0:
        transforms.append(Transform.identity())

    if normalization is None:
        normalization = lambda patches_x, patches_y, x, y, mask, channel: (patches_x, patches_y)

    image_pairs, n_raw_images = raw_data.generator(), raw_data.size
    tf = Transform(*zip(*transforms)) # convert list of Transforms into Transform of lists
    image_pairs = compose(*tf.generator)(image_pairs) # combine all transformations with raw images as input
    n_transforms = np.prod(tf.size)
    n_images = n_raw_images * n_transforms
    n_patches = n_images * n_patches_per_image
    # two float32 arrays (X and Y), 4 bytes per element
    n_required_memory_bytes = 2 * n_patches*np.prod(patch_size) * 4

    ## memory check
    _memory_check(n_required_memory_bytes)

    ## summary
    if verbose:
        print('='*66)
        print('%5d raw images x %4d transformations   = %5d images' % (n_raw_images,n_transforms,n_images))
        print('%5d images     x %4d patches per image = %5d patches in total' % (n_images,n_patches_per_image,n_patches))
        print('='*66)
        print('Input data:')
        print(raw_data.description)
        print('='*66)
        print('Transformations:')
        for t in transforms:
            print('{t.size} x {t.name}'.format(t=t))
        print('='*66)
        print('Patch size:')
        print(" x ".join(str(p) for p in patch_size))
        print('=' * 66)

    sys.stdout.flush()

    ## sample patches from each pair of transformed raw images
    X = np.empty((n_patches,)+tuple(patch_size),dtype=np.float32)
    Y = np.empty_like(X)

    # number of image pairs whose patches were actually written to X and Y
    n_processed = 0

    for i, (x,y,_axes,mask) in tqdm(enumerate(image_pairs),total=n_images,disable=(not verbose)):
        if i >= n_images:
            warnings.warn('more raw images (or transformations thereof) than expected, skipping excess images.')
            break
        if i==0:
            axes = axes_check_and_normalize(_axes,len(patch_size))
            channel = axes_dict(axes)['C']
        # checks
        # len(axes) >= x.ndim or _raise(ValueError())
        axes == axes_check_and_normalize(_axes) or _raise(ValueError('not all images have the same axes.'))
        x.shape == y.shape or _raise(ValueError())
        mask is None or mask.shape == x.shape or _raise(ValueError())
        (channel is None or (isinstance(channel,int) and 0<=channel<x.ndim)) or _raise(ValueError())
        channel is None or patch_size[channel]==x.shape[channel] or _raise(ValueError('extracted patches must contain all channels.'))

        # the target image comes first because the patch filter inspects the first array
        _Y,_X = sample_patches_from_multiple_stacks((y,x), patch_size, n_patches_per_image, mask, patch_filter)

        s = slice(i*n_patches_per_image,(i+1)*n_patches_per_image)
        X[s], Y[s] = normalization(_X,_Y, x,y,mask,channel)
        n_processed = i + 1

    if n_processed == 0:
        # without a single image pair neither the axes nor any patch data exist
        raise ValueError('no image pairs were provided by raw_data, cannot create patches.')

    if n_processed < n_images:
        # X and Y were allocated with np.empty; rows that were never written
        # hold arbitrary memory contents and must not end up in the training data
        warnings.warn('fewer raw images (or transformations thereof) than expected, discarding unfilled patches.')
        n_filled = n_processed * n_patches_per_image
        X, Y = X[:n_filled], Y[:n_filled]

    if shuffle:
        shuffle_inplace(X,Y)

    # move the channel axis (if any) directly after the sample axis
    axes = 'SC'+axes.replace('C','')
    if channel is None:
        X = np.expand_dims(X,1)
        Y = np.expand_dims(Y,1)
    else:
        X = np.moveaxis(X, 1+channel, 1)
        Y = np.moveaxis(Y, 1+channel, 1)

    if save_file is not None:
        print('Saving data to %s.' % str(Path(save_file)))
        save_training_data(save_file, X, Y, axes)

    return X,Y,axes
|
|
367
|
+
|
|
368
|
+
|
|
369
|
+
def create_patches_reduced_target(
        raw_data,
        patch_size,
        n_patches_per_image,
        reduction_axes,
        target_axes = None, # TODO: this should rather be part of RawData and also exposed to transforms
        **kwargs
    ):
    """Create normalized training data to be used for neural network training.

    In contrast to :func:`create_patches`, it is assumed that the target image has reduced
    dimensionality (i.e. size 1) along one or several axes (`reduction_axes`).

    Parameters
    ----------
    raw_data : :class:`RawData`
        See :func:`create_patches`.
    patch_size : tuple
        See :func:`create_patches`. An entry may be ``None`` for a reduction axis,
        in which case the full extent of the first (transformed) image along that axis is used.
    n_patches_per_image : int
        See :func:`create_patches`.
    reduction_axes : str
        Axes where the target images have a reduced dimension (i.e. size 1) compared to the source images.
    target_axes : str
        Axes of the raw target images. If ``None``, will be assumed to be equal to that of the raw source images.
    kwargs : dict
        Additional parameters as in :func:`create_patches`.

    Returns
    -------
    tuple(:class:`numpy.ndarray`, :class:`numpy.ndarray`, str)
        See :func:`create_patches`. Note that the shape of the target data will be 1 along all reduction axes.

    Raises
    ------
    ValueError
        If an entry of `patch_size` is ``None`` for a non-reduction axis, or if a
        reduction axis is not present in the extracted patches.

    """
    reduction_axes = axes_check_and_normalize(reduction_axes,disallowed='S')

    # the target is first broadcast to the source shape so that matching patches can be extracted
    transforms = kwargs.get('transforms')
    if transforms is None:
        transforms = []
    transforms = list(transforms)
    transforms.insert(0,broadcast_target(target_axes))
    kwargs['transforms'] = transforms

    # saving is deferred until the target has been reduced again
    save_file = kwargs.pop('save_file',None)

    if any(s is None for s in patch_size):
        # peek at the first transformed image pair to resolve unspecified patch extents
        patch_axes = kwargs.get('patch_axes')
        if patch_axes is not None:
            _transforms = list(transforms)
            _transforms.append(permute_axes(patch_axes))
        else:
            _transforms = transforms
        tf = Transform(*zip(*_transforms))
        image_pairs = compose(*tf.generator)(raw_data.generator())
        x,y,axes,mask = next(image_pairs) # get the first entry from the generator
        patch_size = list(patch_size)
        for i,(a,s) in enumerate(zip(axes,patch_size)):
            if s is not None: continue
            a in reduction_axes or _raise(ValueError("entry of patch_size is None for non reduction axis %s." % a))
            patch_size[i] = x.shape[i]
        patch_size = tuple(patch_size)
        del x,y,axes,mask

    X,Y,axes = create_patches (
        raw_data            = raw_data,
        patch_size          = patch_size,
        n_patches_per_image = n_patches_per_image,
        **kwargs
    )

    ax = axes_dict(axes)
    for a in reduction_axes:
        # `a` is an axis letter, hence it must be formatted with %s (not %d)
        a in axes or _raise(ValueError("reduction axis %s not present in extracted patches" % a))
        n_dims = Y.shape[ax[a]]
        if n_dims == 1:
            warnings.warn("extracted target patches already have dimensionality 1 along reduction axis %s." % a)
        else:
            # keep only the first slice along the reduction axis and spot-check on
            # 100 random entries that the second slice holds the same values
            t = np.take(Y,(1,),axis=ax[a])
            Y = np.take(Y,(0,),axis=ax[a])
            probe = np.random.choice(Y.size,size=100)
            if not np.all(t.flat[probe]==Y.flat[probe]):
                warnings.warn("extracted target patches vary along reduction axis %s." % a)

    if save_file is not None:
        print('Saving data to %s.' % str(Path(save_file)))
        save_training_data(save_file, X, Y, axes)

    return X,Y,axes
|
|
457
|
+
|
|
458
|
+
|
|
459
|
+
# Misc
|
|
460
|
+
|
|
461
|
+
def shuffle_inplace(*arrs,**kwargs):
    """Shuffle all given arrays in place using one and the same permutation.

    The random state is rewound to the same starting point before each array
    is shuffled, so arrays of equal length are permuted identically.

    Parameters
    ----------
    arrs : arrays
        Arrays to be shuffled in place along their first axis.
    seed : int or None, optional (keyword only)
        If given, a dedicated ``np.random.RandomState`` seeded with this value
        is used; otherwise the global ``np.random`` state is used.
    """
    seed = kwargs.pop('seed', None)
    rng = np.random if seed is None else np.random.RandomState(seed=seed)
    start_state = rng.get_state()
    for arr in arrs:
        rng.set_state(start_state)
        rng.shuffle(arr)
|