senoquant 1.0.0b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- senoquant/__init__.py +6 -0
- senoquant/_reader.py +7 -0
- senoquant/_widget.py +33 -0
- senoquant/napari.yaml +83 -0
- senoquant/reader/__init__.py +5 -0
- senoquant/reader/core.py +369 -0
- senoquant/tabs/__init__.py +15 -0
- senoquant/tabs/batch/__init__.py +10 -0
- senoquant/tabs/batch/backend.py +641 -0
- senoquant/tabs/batch/config.py +270 -0
- senoquant/tabs/batch/frontend.py +1283 -0
- senoquant/tabs/batch/io.py +326 -0
- senoquant/tabs/batch/layers.py +86 -0
- senoquant/tabs/quantification/__init__.py +1 -0
- senoquant/tabs/quantification/backend.py +228 -0
- senoquant/tabs/quantification/features/__init__.py +80 -0
- senoquant/tabs/quantification/features/base.py +142 -0
- senoquant/tabs/quantification/features/marker/__init__.py +5 -0
- senoquant/tabs/quantification/features/marker/config.py +69 -0
- senoquant/tabs/quantification/features/marker/dialog.py +437 -0
- senoquant/tabs/quantification/features/marker/export.py +879 -0
- senoquant/tabs/quantification/features/marker/feature.py +119 -0
- senoquant/tabs/quantification/features/marker/morphology.py +285 -0
- senoquant/tabs/quantification/features/marker/rows.py +654 -0
- senoquant/tabs/quantification/features/marker/thresholding.py +46 -0
- senoquant/tabs/quantification/features/roi.py +346 -0
- senoquant/tabs/quantification/features/spots/__init__.py +5 -0
- senoquant/tabs/quantification/features/spots/config.py +62 -0
- senoquant/tabs/quantification/features/spots/dialog.py +477 -0
- senoquant/tabs/quantification/features/spots/export.py +1292 -0
- senoquant/tabs/quantification/features/spots/feature.py +112 -0
- senoquant/tabs/quantification/features/spots/morphology.py +279 -0
- senoquant/tabs/quantification/features/spots/rows.py +241 -0
- senoquant/tabs/quantification/frontend.py +815 -0
- senoquant/tabs/segmentation/__init__.py +1 -0
- senoquant/tabs/segmentation/backend.py +131 -0
- senoquant/tabs/segmentation/frontend.py +1009 -0
- senoquant/tabs/segmentation/models/__init__.py +5 -0
- senoquant/tabs/segmentation/models/base.py +146 -0
- senoquant/tabs/segmentation/models/cpsam/details.json +65 -0
- senoquant/tabs/segmentation/models/cpsam/model.py +150 -0
- senoquant/tabs/segmentation/models/default_2d/details.json +69 -0
- senoquant/tabs/segmentation/models/default_2d/model.py +664 -0
- senoquant/tabs/segmentation/models/default_3d/details.json +69 -0
- senoquant/tabs/segmentation/models/default_3d/model.py +682 -0
- senoquant/tabs/segmentation/models/hf.py +71 -0
- senoquant/tabs/segmentation/models/nuclear_dilation/__init__.py +1 -0
- senoquant/tabs/segmentation/models/nuclear_dilation/details.json +26 -0
- senoquant/tabs/segmentation/models/nuclear_dilation/model.py +96 -0
- senoquant/tabs/segmentation/models/perinuclear_rings/__init__.py +1 -0
- senoquant/tabs/segmentation/models/perinuclear_rings/details.json +34 -0
- senoquant/tabs/segmentation/models/perinuclear_rings/model.py +132 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/__init__.py +2 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/__init__.py +3 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/__init__.py +6 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/generate.py +470 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/prepare.py +273 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/rawdata.py +112 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/transform.py +384 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/__init__.py +0 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/blocks.py +184 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/losses.py +79 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/nets.py +165 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/predict.py +467 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/probability.py +67 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/train.py +148 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/io/__init__.py +163 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/__init__.py +52 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/base_model.py +329 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_isotropic.py +160 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_projection.py +178 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_standard.py +446 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_upsampling.py +54 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/config.py +254 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/pretrained.py +119 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/scripts/__init__.py +0 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/scripts/care_predict.py +180 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/__init__.py +5 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/plot_utils.py +159 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/six.py +18 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/tf.py +644 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/utils.py +272 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/version.py +1 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/docs/source/conf.py +368 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/setup.py +68 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_datagen.py +169 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_models.py +462 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_utils.py +166 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tools/create_zip_contents.py +34 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/__init__.py +30 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/big.py +624 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/bioimageio_utils.py +494 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/data/__init__.py +39 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/__init__.py +10 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/geom2d.py +215 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/geom3d.py +349 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/matching.py +483 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/__init__.py +28 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/base.py +1217 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/model2d.py +594 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/model3d.py +696 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/nms.py +384 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/__init__.py +2 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/plot.py +74 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/render.py +298 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/rays3d.py +373 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/sample_patches.py +65 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/__init__.py +0 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/predict2d.py +90 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/predict3d.py +93 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/utils.py +408 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/version.py +1 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/__init__.py +45 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/__init__.py +17 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/cli.py +55 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/core.py +285 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/__init__.py +15 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/cli.py +36 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/divisibility.py +193 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/probe.py +100 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/receptive_field.py +182 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/rf_cli.py +48 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/valid_sizes.py +278 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/post/__init__.py +8 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/post/core.py +157 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/pre/__init__.py +17 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/pre/core.py +226 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/predict/__init__.py +5 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/predict/core.py +401 -0
- senoquant/tabs/settings/__init__.py +1 -0
- senoquant/tabs/settings/backend.py +29 -0
- senoquant/tabs/settings/frontend.py +19 -0
- senoquant/tabs/spots/__init__.py +1 -0
- senoquant/tabs/spots/backend.py +139 -0
- senoquant/tabs/spots/frontend.py +800 -0
- senoquant/tabs/spots/models/__init__.py +5 -0
- senoquant/tabs/spots/models/base.py +94 -0
- senoquant/tabs/spots/models/rmp/details.json +61 -0
- senoquant/tabs/spots/models/rmp/model.py +499 -0
- senoquant/tabs/spots/models/udwt/details.json +103 -0
- senoquant/tabs/spots/models/udwt/model.py +482 -0
- senoquant/utils.py +25 -0
- senoquant-1.0.0b1.dist-info/METADATA +193 -0
- senoquant-1.0.0b1.dist-info/RECORD +148 -0
- senoquant-1.0.0b1.dist-info/WHEEL +5 -0
- senoquant-1.0.0b1.dist-info/entry_points.txt +2 -0
- senoquant-1.0.0b1.dist-info/licenses/LICENSE +28 -0
- senoquant-1.0.0b1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,165 @@
|
|
|
1
|
+
from __future__ import print_function, unicode_literals, absolute_import, division
|
|
2
|
+
from six.moves import range, zip, map, reduce, filter
|
|
3
|
+
|
|
4
|
+
from ..utils.tf import keras_import
|
|
5
|
+
Input, Conv2D, Conv3D, Activation, Lambda, Add, Concatenate = keras_import('layers', 'Input', 'Conv2D', 'Conv3D', 'Activation', 'Lambda', 'Add', 'Concatenate')
|
|
6
|
+
Model = keras_import('models', 'Model')
|
|
7
|
+
from .blocks import unet_block
|
|
8
|
+
import re
|
|
9
|
+
|
|
10
|
+
from ..utils import _raise, backend_channels_last
|
|
11
|
+
import numpy as np
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def custom_unet(input_shape,
                last_activation,
                n_depth=2,
                n_filter_base=16,
                kernel_size=(3,3,3),
                n_conv_per_depth=2,
                activation="relu",
                batch_norm=False,
                dropout=0.0,
                pool_size=(2,2,2),
                n_channel_out=1,
                residual=False,
                prob_out=False,
                eps_scale=1e-3):
    """Build a Keras model consisting of a U-Net block followed by an output head.

    Parameters
    ----------
    input_shape : tuple
        Shape handed to the Keras ``Input`` layer (named ``"input"``).
    last_activation : str
        Activation applied to the final output; must not be ``None``.
    n_depth, n_filter_base, kernel_size, n_conv_per_depth, activation, batch_norm, dropout, pool_size
        Forwarded unchanged to ``unet_block``. The length of ``kernel_size``
        selects ``Conv2D`` (length 2) or ``Conv3D`` (otherwise) for the output
        convolutions, and every kernel entry must be odd.
    n_channel_out : int
        Number of filters of the 1x1 output convolution(s).
    residual : bool
        If True, the input is added to the linear output before the last
        activation; requires the number of input channels to equal ``n_channel_out``.
    prob_out : bool
        If True, a second softplus-activated 1x1 convolution (shifted by
        ``eps_scale``) is concatenated to the output along the channel axis.
    eps_scale : float
        Constant added to the softplus output when ``prob_out`` is True.

    Returns
    -------
    Model
        Keras model mapping the input layer to the assembled output.

    Raises
    ------
    ValueError
        If ``last_activation`` is None, or if ``residual`` is requested with
        mismatching input/output channel counts. An even kernel size is
        reported through ``_raise`` with a ``ValueError``.
    """
    if last_activation is None:
        raise ValueError("last activation has to be given (e.g. 'sigmoid', 'relu')!")

    if not all(s % 2 == 1 for s in kernel_size):
        _raise(ValueError('kernel size should be odd in all dimensions.'))

    channel_axis = -1 if backend_channels_last() else 1

    n_dim = len(kernel_size)
    conv = Conv2D if n_dim == 2 else Conv3D
    pointwise = (1,) * n_dim  # 1x1(x1) kernel used by both output heads

    net_input = Input(input_shape, name="input")
    features = unet_block(
        n_depth, n_filter_base, kernel_size,
        activation=activation, dropout=dropout, batch_norm=batch_norm,
        n_conv_per_depth=n_conv_per_depth, pool=pool_size,
    )(net_input)

    output = conv(n_channel_out, pointwise, activation='linear')(features)
    if residual:
        # the skip connection adds input and output element-wise, so channels must agree
        n_channel_in = input_shape[-1] if backend_channels_last() else input_shape[0]
        if n_channel_out != n_channel_in:
            raise ValueError("number of input and output channels must be the same for a residual net.")
        output = Add()([output, net_input])
    output = Activation(activation=last_activation)(output)

    if prob_out:
        # softplus is non-negative; adding eps_scale keeps the scale strictly positive
        scale = conv(n_channel_out, pointwise, activation='softplus')(features)
        scale = Lambda(lambda x: x+np.float32(eps_scale))(scale)
        output = Concatenate(axis=channel_axis)([output, scale])

    return Model(inputs=net_input, outputs=output)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def common_unet(n_dim=2, n_depth=1, kern_size=3, n_first=16, n_channel_out=1, residual=True, prob_out=False, last_activation='linear'):
    """Return a builder for a common CARE network based on U-Net [1]_ and residual learning [2]_.

    The returned function defers model construction until the input shape is
    known and then delegates to :func:`custom_unet`.

    Parameters
    ----------
    n_dim : int
        number of image dimensions (2 or 3)
    n_depth : int
        number of resolution levels of U-Net architecture
    kern_size : int
        size of convolution filter in all image dimensions
    n_first : int
        number of convolution filters for first U-Net resolution level (value is doubled after each downsampling operation)
    n_channel_out : int
        number of channels of the predicted output image
    residual : bool
        if True, model will internally predict the residual w.r.t. the input (typically better)
        requires number of input and output image channels to be equal
    prob_out : bool
        standard regression (False) or probabilistic prediction (True)
        if True, model will predict two values for each input pixel (mean and positive scale value)
    last_activation : str
        name of activation function for the final output layer

    Returns
    -------
    function
        Function to construct the network, which takes as argument the shape of the input image

    Example
    -------
    >>> model = common_unet(2, 1,3,16, 1, True, False)(input_shape)

    References
    ----------
    .. [1] Olaf Ronneberger, Philipp Fischer, Thomas Brox, *U-Net: Convolutional Networks for Biomedical Image Segmentation*, MICCAI 2015
    .. [2] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. *Deep Residual Learning for Image Recognition*, CVPR 2016
    """
    def _build_this(input_shape):
        # the same kernel size is used along every image dimension; pooling is fixed to 2
        return custom_unet(
            input_shape,
            last_activation,
            n_depth=n_depth,
            n_filter_base=n_first,
            kernel_size=(kern_size,) * n_dim,
            pool_size=(2,) * n_dim,
            n_channel_out=n_channel_out,
            residual=residual,
            prob_out=prob_out,
        )

    return _build_this
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
modelname = re.compile(r"^(?P<model>resunet|unet)(?P<n_dim>\d)(?P<prob_out>p)?_(?P<n_depth>\d+)_(?P<kern_size>\d+)_(?P<n_first>\d+)(_(?P<n_channel_out>\d+)out)?(_(?P<last_activation>.+)-last)?$")


def common_unet_by_name(model):
    r"""Shorthand notation for equivalent use of :func:`common_unet`.

    Parameters
    ----------
    model : str
        define model to be created via string, which is parsed as a regular expression:
        `^(?P<model>resunet|unet)(?P<n_dim>\d)(?P<prob_out>p)?_(?P<n_depth>\d+)_(?P<kern_size>\d+)_(?P<n_first>\d+)(_(?P<n_channel_out>\d+)out)?(_(?P<last_activation>.+)-last)?$`

    Returns
    -------
    function
        Calls :func:`common_unet` with the respective parameters.

    Raises
    ------
    ValueError
        If argument `model` is not a valid string according to the regular expression.

    Example
    -------
    >>> model = common_unet_by_name('resunet2_1_3_16_1out')(input_shape)
    >>> # equivalent to: model = common_unet(2, 1,3,16, 1, True, False)(input_shape)

    Todo
    ----
    Backslashes in docstring for regexp not rendered correctly.

    """
    match = modelname.fullmatch(model)
    if match is None:
        raise ValueError("model name '%s' unknown, must follow pattern '%s'" % (model, modelname.pattern))

    fields = match.groupdict()
    n_channel_out = fields['n_channel_out']
    options = {
        'n_dim': int(fields['n_dim']),
        'n_depth': int(fields['n_depth']),
        'kern_size': int(fields['kern_size']),
        'n_first': int(fields['n_first']),
        # the optional "p" suffix after the dimension selects probabilistic output
        'prob_out': fields['prob_out'] is not None,
        # the pattern only admits "unet" and "resunet" as model prefix
        'residual': fields['model'] == 'resunet',
        'n_channel_out': 1 if n_channel_out is None else int(n_channel_out),
    }
    # only override common_unet's default when the name spells out an activation
    if fields['last_activation'] is not None:
        options['last_activation'] = fields['last_activation']

    return common_unet(**options)
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
def receptive_field_unet(n_depth, kern_size, pool_size=2, n_dim=2, img_size=1024):
    """Receptive field for U-Net model (pre/post for each dimension).

    A linear U-Net is built via :func:`custom_unet` and evaluated on an
    impulse image (single 1 at the center) and on an all-zero image; the
    extent of the positions where the two responses differ, measured from
    the center, is reported per image dimension.

    Parameters
    ----------
    n_depth : int
        Passed to :func:`custom_unet` as ``n_depth``.
    kern_size : int
        Kernel size used along every image dimension.
    pool_size : int
        Pooling size used along every image dimension.
    n_dim : int
        Number of image dimensions.
    img_size : int
        Edge length of the probe image along every image dimension.

    Returns
    -------
    list of tuple
        One ``(pre, post)`` pair per image dimension.
    """
    # NOTE(review): the probe is laid out as (batch, *spatial, channel) and indexed
    # with [0, ..., 0], which presumably assumes a channels-last backend - confirm.
    probe_shape = (1,) + (img_size,) * n_dim + (1,)
    x = np.zeros(probe_shape)
    mid = tuple([s // 2 for s in x.shape[1:-1]])
    x[(slice(None),) + mid + (slice(None),)] = 1

    # purely linear activations so that the impulse response is not clipped anywhere
    model = custom_unet(
        x.shape[1:],
        n_depth=n_depth,
        kernel_size=[kern_size] * n_dim,
        pool_size=[pool_size] * n_dim,
        n_filter_base=8,
        activation='linear',
        last_activation='linear',
    )

    response = model.predict(x)[0, ..., 0]
    baseline = model.predict(0 * x)[0, ..., 0]
    affected = np.where(np.abs(response - baseline) > 0)

    extents = []
    for center, coords in zip(mid, affected):
        extents.append((center - np.min(coords), np.max(coords) - center))
    return extents
|
|
@@ -0,0 +1,467 @@
|
|
|
1
|
+
from __future__ import print_function, unicode_literals, absolute_import, division
|
|
2
|
+
from six.moves import range, zip, map, reduce, filter
|
|
3
|
+
|
|
4
|
+
from tqdm import tqdm
|
|
5
|
+
from ..utils import _raise, consume, move_channel_for_backend, backend_channels_last, axes_check_and_normalize, axes_dict
|
|
6
|
+
import warnings
|
|
7
|
+
import numpy as np
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def to_tensor(x,channel=None,single_sample=True):
    """Bring an array into the layout expected by ``move_channel_for_backend``.

    Parameters
    ----------
    x : numpy.ndarray
        Input array.
    channel : int or None
        Index of the channel axis of ``x``; if None, a new trailing channel
        axis of length 1 is appended.
    single_sample : bool
        If True, a new leading axis is prepended to ``x``.

    Returns
    -------
    Result of ``move_channel_for_backend`` applied to the expanded array and
    the (adjusted) channel index.
    """
    if single_sample:
        x = x[np.newaxis, ...]
        # a non-negative channel index moves one position to the right
        # because of the new leading axis; negative indices are unaffected
        if channel is not None and channel >= 0:
            channel = channel + 1

    if channel is None:
        x = np.expand_dims(x, -1)
        channel = -1

    return move_channel_for_backend(x, channel)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def from_tensor(x,channel=-1,single_sample=True):
    """Convert a backend-layout tensor back to an array with the channel axis at ``channel``.

    Parameters
    ----------
    x : numpy.ndarray
        Tensor whose first axis is the batch axis; the channel axis is the
        last one for channels-last backends and axis 1 otherwise.
    channel : int
        Destination index of the channel axis in the returned array.
    single_sample : bool
        If True, only the first batch entry is returned (batch axis removed).

    Returns
    -------
    numpy.ndarray
        Array with the channel axis moved to position ``channel``.
    """
    if backend_channels_last():
        source = -1
    else:
        # Channels-first tensors keep the channel at axis 1 of the batched
        # tensor; once the batch axis is dropped it sits at axis 0.
        source = 0 if single_sample else 1
    data = x[0] if single_sample else x
    return np.moveaxis(data, source, channel)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def tensor_num_channels(x):
    """Return the size of the channel axis of the batched tensor ``x``.

    The channel axis is the last one when ``backend_channels_last()`` is true
    and axis 1 otherwise.
    """
    channel_axis = -1 if backend_channels_last() else 1
    return x.shape[channel_axis]
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def predict_direct(keras_model,x,axes_in,axes_out=None,**kwargs):
    """Run ``keras_model.predict`` on ``x`` in a single call.

    Parameters
    ----------
    keras_model
        Object providing a ``predict(x, **kwargs)`` method.
    x : numpy.ndarray
        Input array whose number of dimensions must equal ``len(axes_in)``.
    axes_in : str
        Axes of ``x``; the positions of 'C' and 'S' are looked up via ``axes_dict``.
    axes_out : str, optional
        Axes of the returned prediction; defaults to ``axes_in``.
    **kwargs
        Forwarded to ``keras_model.predict``; ``verbose`` defaults to 0.

    Returns
    -------
    numpy.ndarray
        Prediction with its channel axis at the position given by ``axes_out``.
    """
    if axes_out is None:
        axes_out = axes_in
    if 'verbose' not in kwargs:
        kwargs['verbose'] = 0

    ax_in = axes_dict(axes_in)
    ax_out = axes_dict(axes_out)
    channel_in = ax_in['C']
    channel_out = ax_out['C']
    # without a sample axis the input is treated as one sample and gets a batch axis
    single_sample = ax_in['S'] is None

    if len(axes_in) != x.ndim:
        _raise(ValueError())

    tensor = to_tensor(x, channel=channel_in, single_sample=single_sample)
    raw = keras_model.predict(tensor, **kwargs)
    pred = from_tensor(raw, channel=channel_out, single_sample=single_sample)

    if len(axes_out) != pred.ndim:
        _raise(ValueError())
    return pred
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def predict_tiled(keras_model,x,n_tiles,block_sizes,tile_overlaps,axes_in,axes_out=None,pbar=None,**kwargs):
    """Predict on ``x`` tile by tile, splitting one axis per recursion level.

    The first axis with more than one requested tile is split with
    ``tile_iterator_1d``; each tile is handled by a recursive call in which
    that axis is no longer tiled. Once no axis is left to split, the tile is
    passed to ``predict_direct``. The cropped tile predictions are written
    into one output array.

    Parameters
    ----------
    keras_model
        Model passed through to ``predict_direct``.
    x : numpy.ndarray
        Input array (no sample axis allowed).
    n_tiles : sequence of int
        Number of tiles per axis of ``x``; must be 1 for the channel axis and
        for every axis that does not appear in ``axes_out``.
    block_sizes : sequence of int
        Block size per axis of ``x``.
    tile_overlaps : sequence of int
        Requested overlap per axis, converted to a number of blocks by
        rounding ``overlap / block_size`` up.
    axes_in : str
        Axes of ``x``; must contain 'C' and must not contain 'S'.
    axes_out : str, optional
        Axes of the prediction (a subset of ``axes_in`` containing 'C');
        defaults to ``axes_in``.
    pbar : optional
        If given, its ``update()`` is called once per directly predicted tile.
    **kwargs
        Forwarded to ``predict_direct``.

    Returns
    -------
    numpy.ndarray
        Assembled prediction, allocated with the dtype of ``x``.
    """
    # base case: nothing left to split, predict this piece in one go
    if all(t == 1 for t in n_tiles):
        result = predict_direct(keras_model, x, axes_in, axes_out, **kwargs)
        if pbar is not None:
            pbar.update()
        return result

    if axes_out is None:
        axes_out = axes_in
    axes_in = axes_check_and_normalize(axes_in, x.ndim)
    axes_out = axes_check_and_normalize(axes_out)
    assert 'S' not in axes_in
    assert 'C' in axes_in and 'C' in axes_out
    ax_in = axes_dict(axes_in)
    ax_out = axes_dict(axes_out)
    channel_in, channel_out = ax_in['C'], ax_out['C']

    assert set(axes_out).issubset(set(axes_in))
    axes_lost = set(axes_in) - set(axes_out)

    def _to_axes_out(seq, elem):
        # assumption: prediction size is same as input size along all axes,
        # except for channel (and lost axes)
        assert len(seq) == len(axes_in)
        # re-order from axes_in to axes_out semantics, then overwrite the channel entry
        reordered = [seq[ax_in[a]] for a in axes_out]
        reordered[ax_out['C']] = elem
        return tuple(reordered)

    assert x.ndim == len(n_tiles) == len(block_sizes)
    assert n_tiles[channel_in] == 1
    assert all(n_tiles[ax_in[a]] == 1 for a in axes_lost)
    assert all(np.isscalar(t) and 1<=t and int(t)==t for t in n_tiles)

    # first axis that still needs to be split
    axis = next(i for i, t in enumerate(n_tiles) if t > 1)
    block_size = block_sizes[axis]
    n_block_overlap = int(np.ceil(1. * tile_overlaps[axis] / block_size))

    # inside each tile this axis counts as a single tile
    n_tiles_remaining = list(n_tiles)
    n_tiles_remaining[axis] = 1

    tiles = tile_iterator_1d(
        x, axis=axis, n_tiles=n_tiles[axis],
        block_size=block_size, n_block_overlap=n_block_overlap,
    )

    dst = None
    for tile, s_src, s_dst in tiles:
        pred = predict_tiled(keras_model, tile, n_tiles_remaining, block_sizes, tile_overlaps,
                             axes_in, axes_out, pbar=pbar, **kwargs)

        if dst is None:
            # the channel count is only known after the first prediction
            # NOTE(review): the buffer uses x.dtype, not pred.dtype - confirm this is intended
            dst = np.empty(_to_axes_out(x.shape, pred.shape[channel_out]), dtype=x.dtype)

        dst[_to_axes_out(s_dst, slice(None))] = pred[_to_axes_out(s_src, slice(None))]

    return dst
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
class Tile(object):
    """One tile of a 1-d tiling, linked to the tile that precedes it.

    Parameters
    ----------
    n : int
        Total length that is being tiled.
    size : int
        Requested tile size; truncated to ``n`` (with the overlap reset to 0)
        if larger than ``n``, which is only allowed for the first tile.
    overlap : int
        Overlap on each side of a tile; ``size`` must exceed ``2*overlap``.
    prev : Tile or None
        Preceding tile, or None for the first tile.
    """

    def __init__(self, n, size, overlap, prev):
        self.n, self.size, self.overlap = int(n), int(size), int(overlap)
        if self.size > self.n:
            assert prev is None
            # a single tile covers everything, so no overlap is needed
            self.size, self.overlap = self.n, 0
        assert self.size > 2*self.overlap
        if prev is not None:
            assert not prev.at_end, "Previous tile already at end"
        self.prev = prev
        # both slices are computed once here and stored as plain attributes;
        # read_slice must be set first because the write slice depends on at_end
        self.read_slice = self._read_slice
        self.write_slice = self._write_slice

    @property
    def at_begin(self):
        """True for the first tile (no predecessor)."""
        return self.prev is None

    @property
    def at_end(self):
        """True if the region read by this tile ends at ``n``."""
        return self.read_slice.stop == self.n

    @property
    def _read_slice(self):
        """Region of the input that this tile reads."""
        if self.at_begin:
            start, stop = 0, self.size
        else:
            before = self.prev.read_slice
            start = before.stop - 2*self.overlap
            stop = start + self.size
            # shift the tile back if it would extend beyond the end
            shift = min(0, self.n - stop)
            start += shift
            stop += shift
            assert start > before.start
        assert start >= 0 and stop <= self.n
        return slice(start, stop)

    @property
    def _write_slice(self):
        """Region of the output that this tile is responsible for."""
        start = 0 if self.at_begin else self.prev.write_slice.stop
        if self.at_end:
            stop = self.n
        elif self.at_begin:
            stop = self.size - 1*self.overlap
        else:
            stop = start + self.size - 2*self.overlap
        return slice(start, stop)

    def __repr__(self):
        # '-' marks positions that are read, 'x'/'o' positions that are written
        # ('x' for the first and last tile, 'o' for tiles in between)
        marker = 'x' if (self.at_begin or self.at_end) else 'o'
        cells = np.array(list(' ' * self.n))
        cells[self.read_slice] = '-'
        cells[self.write_slice] = marker
        return ''.join(cells)
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
class Tiling(object):
    """Sequence of linked :class:`Tile` objects that together cover ``n`` positions.

    Parameters
    ----------
    n : int
        Total length to be tiled.
    size : int
        Tile size.
    overlap : int
        Overlap on each side of a tile.
    """

    def __init__(self, n, size, overlap):
        self.n = n
        self.size = size
        self.overlap = overlap
        # pass the arguments explicitly rather than via self.__dict__, which
        # only works while the instance holds exactly these three attributes
        tiles = [Tile(n=n, size=size, overlap=overlap, prev=None)]
        while not tiles[-1].at_end:
            tiles.append(Tile(n=n, size=size, overlap=overlap, prev=tiles[-1]))
        self.tiles = tiles

    def __len__(self):
        return len(self.tiles)

    def __repr__(self):
        return '\n'.join('{i:3}. {t}'.format(i=i,t=t) for i,t in enumerate(self.tiles,1))

    def slice_generator(self, block_size=1):
        """Yield ``(read, write, crop)`` slices for every tile.

        ``read`` and ``write`` are the tile's read/write slices multiplied by
        ``block_size``; ``crop`` selects, relative to the start of ``read``,
        the part that corresponds to ``write``.
        """
        def scale(sl):
            return slice(block_size * sl.start, block_size * sl.stop)

        def crop_slice(read, write):
            stop = write.stop - read.stop
            # a stop of 0 would produce an empty slice, so use None to run to the end
            return slice(write.start - read.start, stop if stop < 0 else None)

        for t in self.tiles:
            read, write = scale(t.read_slice), scale(t.write_slice)
            yield read, write, crop_slice(read, write)

    @staticmethod
    def for_n_tiles(n, n_tiles, overlap):
        """Return the tiling of ``n`` whose number of tiles is closest to ``n_tiles``.

        The tile size is grown from the smallest possible value
        (``2*overlap + 1``) until the tiling has at most ``n_tiles`` tiles;
        of the last two tilings tried, the one whose tile count is closer to
        ``n_tiles`` is returned (the smaller tile size wins a tie).

        Raises
        ------
        ValueError
            If ``n_tiles`` is smaller than 1. A tiling always has at least
            one tile, so the search could never terminate.
        """
        if n_tiles < 1:
            raise ValueError("n_tiles must be at least 1")

        tile_size = 2*overlap + 1  # start with smallest possible tile size
        previous = None
        current = Tiling(n, tile_size, overlap)
        while len(current) > n_tiles:
            tile_size += 1
            # keep the previous tiling so that it does not have to be rebuilt below
            previous, current = current, Tiling(n, tile_size, overlap)

        if previous is None:
            return current
        if abs(len(previous) - n_tiles) <= abs(len(current) - n_tiles):
            return previous
        return current
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
def total_n_tiles(x,n_tiles,block_sizes,n_block_overlaps,guarantee='size'):
    """Return the total number of tiles that tiling ``x`` will produce.

    Parameters
    ----------
    x : numpy.ndarray
        Array to be tiled; every axis must be evenly divisible by its block size.
    n_tiles : sequence of int
        Requested number of tiles per axis of ``x``.
    block_sizes : sequence of int
        Block size per axis of ``x``.
    n_block_overlaps : sequence of int
        Tile overlap (in blocks) per axis of ``x``.
    guarantee : str
        Either 'size' or 'n_tiles', with the same meaning as in ``tile_iterator_1d``.

    Returns
    -------
    int
        Product over all axes of the number of tiles used along that axis.
    """
    assert x.ndim == len(n_tiles) == len(block_sizes) == len(n_block_overlaps)
    assert guarantee in ('size', 'n_tiles')
    n_tiles_used = 1
    for n, n_tile, block_size, n_block_overlap in zip(x.shape, n_tiles, block_sizes, n_block_overlaps):
        assert n % block_size == 0
        n_blocks = n // block_size
        if guarantee == 'size':
            n_tiles_used *= len(Tiling.for_n_tiles(n_blocks, n_tile, n_block_overlap))
        else:
            # tile_iterator_1d clips the requested count to [1, n_blocks] in
            # this mode; apply the same clipping so the total matches the
            # number of tiles that are actually produced
            n_tiles_used *= int(np.clip(n_tile, 1, n_blocks))
    return n_tiles_used
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
def tile_iterator_1d(x,axis,n_tiles,block_size,n_block_overlap,guarantee='size'):
    """Tile iterator for one dimension of array x.

    Parameters
    ----------
    x : numpy.ndarray
        Input array
    axis : int
        Axis which should be tiled, all other axes are not tiled
    n_tiles : int
        Targeted number of tiles for axis of x (see guarantee below)
    block_size : int
        Axis of x is assumed to be evenly divisible by block_size
        All tiles are aligned with the block_size
    n_block_overlap : int
        Tiles will overlap at least this many blocks (see guarantee below)
    guarantee : str
        Can be either 'size' or 'n_tiles':
        'size':    The size of all tiles is guaranteed to be the same,
                   but the number of tiles can be different and the
                   amount of overlap can be larger than requested.
        'n_tiles': The size of tiles can be different at the beginning and end,
                   but the number of tiles is guaranteed to be the one requested.
                   The amount of overlap is also exactly as requested.

    Yields
    ------
    tuple
        ``(tile, s_src, s_dst)``: the tile of ``x``, the index that crops the
        tile (or a prediction of it) to its non-overlapping part, and the
        index of that part in the full array. Both indices are tuples of
        slices with one entry per dimension of ``x``.
    """
    n = x.shape[axis]

    if n % block_size != 0:
        _raise(ValueError("'x' must be evenly divisible by 'block_size' along 'axis'"))
    n_blocks = n // block_size

    if guarantee not in ('size', 'n_tiles'):
        _raise(ValueError("guarantee must be either 'size' or 'n_tiles'"))

    def _along_axis(sl):
        # full-range index for x with `sl` applied along the tiled axis only
        index = [slice(None)] * x.ndim
        index[axis] = sl
        return tuple(index)

    def _block_range(first, last):
        # slice from block `first` to block `last`; a zero end means "until the end"
        return _along_axis(slice(first*block_size, last*block_size if last != 0 else None))

    if guarantee == 'size':
        tiling = Tiling.for_n_tiles(n_blocks, n_tiles, n_block_overlap)
        for read, write, crop in tiling.slice_generator(block_size):
            # read: src in input image / tile
            # crop: crop of src for output / s_src
            # write: dst in output image / s_dst
            yield x[_along_axis(read)], _along_axis(crop), _along_axis(write)

    elif guarantee == 'n_tiles':
        n_tiles_valid = int(np.clip(n_tiles,1,n_blocks))
        if n_tiles != n_tiles_valid:
            warnings.warn("invalid value (%d) for 'n_tiles', changing to %d" % (n_tiles,n_tiles_valid))
            n_tiles = n_tiles_valid

        s, r = n_blocks // n_tiles, n_blocks % n_tiles  # tile size, blocks remainder
        assert n_tiles * s + r == n_blocks

        # list of sizes for each tile
        tile_sizes = s*np.ones(n_tiles,int)
        # distribute remaining blocks to tiles at beginning and end
        if r > 0:
            tile_sizes[:r//2] += 1
            tile_sizes[-(r-r//2):] += 1

        begin = 0
        last_tile = n_tiles - 1
        for i, size in enumerate(tile_sizes):
            # requested overlap before/after the tile; none at the outer borders
            pre = n_block_overlap if i > 0 else 0
            post = n_block_overlap if i < last_tile else 0

            end = begin + size
            # tile starts before block 0 -> adjust pre
            if begin - pre < 0:
                pre = begin
            # tile ends after last block -> adjust post
            if end + post > n_blocks:
                post = n_blocks - end

            yield (
                x[_block_range(begin - pre, end + post)],  # src in input image / tile
                _block_range(pre, -post),                  # crop of src for output / s_src
                _block_range(begin, end),                  # dst in output image / s_dst
            )
            begin = end

    else:
        assert False
|
|
346
|
+
|
|
347
|
+
|
|
348
|
+
|
|
349
|
+
def tile_iterator(x,n_tiles,block_sizes,n_block_overlaps,guarantee='size'):
    """Iterate over block-aligned, overlapping tiles of an n-d array.

    The array is tiled one axis after the other by chaining
    `tile_iterator_1d` calls. Every yielded item is a triple
    ``(tile, s_src, s_dst)`` where `tile` is a piece of `x` and the two
    slice tuples allow mapping the tile back to `x` (see the example).

    Notes
    -----
    - Tiles will not go beyond the array boundary (i.e. no padding).
      This means the shape of x must be evenly divisible by the respective block_size.
    - It is not guaranteed that all tiles have the same size if guarantee is not 'size'.

    Parameters
    ----------
    x : numpy.ndarray
        Input array.
    n_tiles : int or sequence of ints
        Number of tiles for each dimension of x.
    block_sizes : int or sequence of ints
        Block sizes for each dimension of x.
        The shape of x is assumed to be evenly divisible by block_sizes.
        All tiles are aligned with block_sizes.
    n_block_overlaps : int or sequence of ints
        Tiles will at least overlap this many blocks in each dimension.
    guarantee : str
        Can be either 'size' or 'n_tiles':
        'size':    The size of all tiles is guaranteed to be the same,
                   but the number of tiles can be different and the
                   amount of overlap can be larger than requested.
        'n_tiles': The size of tiles can be different at the beginning and end,
                   but the number of tiles is guaranteed to be the one requested.
                   The amount of overlap is also exactly as requested.

    Returns
    -------
    generator
        Yields ``(tile, s_src, s_dst)``; `s_src` and `s_dst` are tuples of
        slices with one entry per axis, and the stops of `s_src` are
        never None or negative.

    Example
    -------

    Duplicate an array tile-by-tile:

    >>> x = np.array(...)
    >>> y = np.empty_like(x)
    >>>
    >>> for tile,s_src,s_dst in tile_iterator(x, n_tiles, block_sizes, n_block_overlaps):
    >>>     y[s_dst] = tile[s_src]
    >>>
    >>> np.allclose(x,y)
    True

    """
    def _per_axis(value):
        # a scalar setting applies to every axis of x
        return (value,)*x.ndim if np.isscalar(value) else value

    n_tiles = _per_axis(n_tiles)
    block_sizes = _per_axis(block_sizes)
    n_block_overlaps = _per_axis(n_block_overlaps)

    assert x.ndim == len(n_tiles) == len(block_sizes) == len(n_block_overlaps)

    def _absolute(s, length):
        # replace an open (None) or negative stop by the equivalent
        # non-negative index for an axis of the given length
        stop = s.stop
        if stop is None:
            stop = length
        elif stop < 0:
            stop = length + stop
        return slice(s.start, stop)

    def _walk(block, axis, src, dst):
        # `src` and `dst` are shared by all recursion levels; each level
        # only writes the entry that belongs to its own axis
        is_last_axis = axis+1 == block.ndim
        pieces = tile_iterator_1d(block, axis, n_tiles[axis], block_sizes[axis], n_block_overlaps[axis], guarantee)
        for piece, s_src, s_dst in pieces:
            src[axis] = s_src[axis]
            dst[axis] = s_dst[axis]
            if not is_last_axis:
                # explicit loop instead of `yield from` (kept as in the
                # original, presumably for Python 2 compatibility)
                for item in _walk(piece, axis+1, src, dst):
                    yield item
                continue
            # all axes are tiled: emit the tile with cleaned-up source slices
            resolved = tuple(_absolute(s, length) for s, length in zip(src, piece.shape))
            yield piece, resolved, tuple(dst)

    return _walk(x, 0, [None]*x.ndim, [None]*x.ndim)
|
|
419
|
+
|
|
420
|
+
|
|
421
|
+
|
|
422
|
+
def tile_overlap(n_depth, kern_size, pool_size=2):
    """Look up the tabulated tile overlap for a network configuration.

    The values come from a fixed table keyed by
    ``(n_depth, kern_size, pool_size)``.
    NOTE(review): the original table was named `rf`, so the entries are
    presumably receptive-field sizes - confirm against upstream.

    Parameters
    ----------
    n_depth : int
        Tabulated for 1-5 (1-4 when pool_size is 4).
    kern_size : int
        Tabulated for 3, 5 and 7.
    pool_size : int
        Tabulated for 1, 2 and 4.

    Returns
    -------
    int
        The table entry for the requested combination.

    Raises
    ------
    ValueError
        If the combination is not contained in the table.
    """
    kernel_sizes = (3, 5, 7)
    # pool_size -> one row per n_depth (starting at 1),
    # each row holding the values for kernel sizes 3, 5 and 7
    rows_by_pool = {
        1: ((6, 12, 18),
            (10, 20, 30),
            (14, 28, 42),
            (18, 36, 54),
            (22, 44, 66)),
        2: ((9, 17, 25),
            (22, 43, 62),
            (46, 92, 138),
            (94, 188, 282),
            (190, 380, 570)),
        4: ((14, 27, 38),
            (58, 116, 158),
            (234, 468, 638),
            (938, 1876, 2558)),
    }
    table = {}
    for pool, rows in rows_by_pool.items():
        for depth, row in enumerate(rows, 1):
            for kern, value in zip(kernel_sizes, row):
                table[(depth, kern, pool)] = value

    key = (n_depth, kern_size, pool_size)
    if key in table:
        return table[key]
    raise ValueError('tile_overlap value for n_depth=%d, kern_size=%d, pool_size=%d not available.' % (n_depth, kern_size, pool_size))
|
|
443
|
+
|
|
444
|
+
|
|
445
|
+
|
|
446
|
+
class Progress(object):
    """Progress bar that is only created when there is enough work.

    The underlying ``tqdm`` bar is created on the first call to
    `update`, and only if `total` is greater than `thr`. Otherwise
    `update` does nothing.

    Parameters
    ----------
    total :
        Total number of steps; passed on as ``tqdm(total=...)``.
        Assigning a new value closes a currently open bar.
    thr :
        Threshold that `total` must exceed for a bar to be shown.
    """

    def __init__(self, total, thr=1):
        # `pbar` has to exist before `total` is assigned,
        # because the `total` setter calls close()
        self.pbar = None
        self.thr = thr
        self.total = total

    @property
    def total(self):
        return self._total

    @total.setter
    def total(self, total):
        # a bar created for the previous total is discarded
        self.close()
        self._total = total

    def update(self):
        """Advance the bar by one step (creating it first if necessary)."""
        if not self.total > self.thr:
            return
        if self.pbar is None:
            self.pbar = tqdm(total=self.total)
        bar = self.pbar
        bar.update()
        bar.refresh()

    def close(self):
        """Close the bar if one is open and forget it."""
        bar = self.pbar
        if bar is None:
            return
        bar.close()
        self.pbar = None
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
from __future__ import print_function, unicode_literals, absolute_import, division
|
|
2
|
+
from six.moves import range, zip, map, reduce, filter
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
from ..utils import _raise, consume
|
|
6
|
+
import warnings
|
|
7
|
+
import numpy as np
|
|
8
|
+
from scipy.stats import laplace
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class ProbabilisticPrediction(object):
    """Laplace distribution (independently per pixel).

    Wraps a frozen ``scipy.stats.laplace`` distribution parameterized by
    the arrays `loc` and `scale` and re-exports its methods (``rvs``,
    ``pdf``, ``cdf``, ``mean``, ``var``, ...) as instance attributes.

    Parameters
    ----------
    loc :
        Location parameter of the distribution; must provide ``shape``.
    scale :
        Scale parameter of the distribution; must have the same shape
        as `loc`.

    Raises
    ------
    ValueError
        If `loc` and `scale` do not have the same shape.
    """

    def __init__(self, loc, scale):
        # explicit check so that the error states what is wrong
        if loc.shape != scale.shape:
            raise ValueError(
                "loc and scale must have the same shape, got {} and {}".format(loc.shape, scale.shape)
            )
        #
        self._loc = loc
        self._scale = scale
        # expose methods from laplace object
        _laplace = laplace(loc=self._loc, scale=self._scale)
        self.rvs = _laplace.rvs
        self.pdf = _laplace.pdf
        self.logpdf = _laplace.logpdf
        self.cdf = _laplace.cdf
        self.logcdf = _laplace.logcdf
        self.sf = _laplace.sf
        self.logsf = _laplace.logsf
        self.ppf = _laplace.ppf
        self.isf = _laplace.isf
        self.moment = _laplace.moment
        self.stats = _laplace.stats
        self.entropy = _laplace.entropy
        self.expect = _laplace.expect
        self.median = _laplace.median
        self.mean = _laplace.mean
        self.var = _laplace.var
        self.std = _laplace.std
        self.interval = _laplace.interval

    def __getitem__(self, indices):
        """Return the prediction restricted to `indices` (applied to loc and scale)."""
        return ProbabilisticPrediction(loc=self._loc[indices], scale=self._scale[indices])

    def __len__(self):
        return len(self._loc)

    @property
    def shape(self):
        return self._loc.shape

    @property
    def ndim(self):
        return self._loc.ndim

    @property
    def size(self):
        return self._loc.size

    def scale(self):
        """Return the scale parameter.

        Unlike `shape`, `ndim` and `size` this is a regular method,
        i.e. it has to be called.
        """
        return self._scale

    def sampling_generator(self, n=None):
        """Yield samples drawn via ``rvs()``.

        Yields `n` samples, or samples endlessly if `n` is None.
        """
        if n is None:
            while True:
                yield self.rvs()
        else:
            for _ in range(n):
                yield self.rvs()
|