senoquant 1.0.0b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148)
  1. senoquant/__init__.py +6 -0
  2. senoquant/_reader.py +7 -0
  3. senoquant/_widget.py +33 -0
  4. senoquant/napari.yaml +83 -0
  5. senoquant/reader/__init__.py +5 -0
  6. senoquant/reader/core.py +369 -0
  7. senoquant/tabs/__init__.py +15 -0
  8. senoquant/tabs/batch/__init__.py +10 -0
  9. senoquant/tabs/batch/backend.py +641 -0
  10. senoquant/tabs/batch/config.py +270 -0
  11. senoquant/tabs/batch/frontend.py +1283 -0
  12. senoquant/tabs/batch/io.py +326 -0
  13. senoquant/tabs/batch/layers.py +86 -0
  14. senoquant/tabs/quantification/__init__.py +1 -0
  15. senoquant/tabs/quantification/backend.py +228 -0
  16. senoquant/tabs/quantification/features/__init__.py +80 -0
  17. senoquant/tabs/quantification/features/base.py +142 -0
  18. senoquant/tabs/quantification/features/marker/__init__.py +5 -0
  19. senoquant/tabs/quantification/features/marker/config.py +69 -0
  20. senoquant/tabs/quantification/features/marker/dialog.py +437 -0
  21. senoquant/tabs/quantification/features/marker/export.py +879 -0
  22. senoquant/tabs/quantification/features/marker/feature.py +119 -0
  23. senoquant/tabs/quantification/features/marker/morphology.py +285 -0
  24. senoquant/tabs/quantification/features/marker/rows.py +654 -0
  25. senoquant/tabs/quantification/features/marker/thresholding.py +46 -0
  26. senoquant/tabs/quantification/features/roi.py +346 -0
  27. senoquant/tabs/quantification/features/spots/__init__.py +5 -0
  28. senoquant/tabs/quantification/features/spots/config.py +62 -0
  29. senoquant/tabs/quantification/features/spots/dialog.py +477 -0
  30. senoquant/tabs/quantification/features/spots/export.py +1292 -0
  31. senoquant/tabs/quantification/features/spots/feature.py +112 -0
  32. senoquant/tabs/quantification/features/spots/morphology.py +279 -0
  33. senoquant/tabs/quantification/features/spots/rows.py +241 -0
  34. senoquant/tabs/quantification/frontend.py +815 -0
  35. senoquant/tabs/segmentation/__init__.py +1 -0
  36. senoquant/tabs/segmentation/backend.py +131 -0
  37. senoquant/tabs/segmentation/frontend.py +1009 -0
  38. senoquant/tabs/segmentation/models/__init__.py +5 -0
  39. senoquant/tabs/segmentation/models/base.py +146 -0
  40. senoquant/tabs/segmentation/models/cpsam/details.json +65 -0
  41. senoquant/tabs/segmentation/models/cpsam/model.py +150 -0
  42. senoquant/tabs/segmentation/models/default_2d/details.json +69 -0
  43. senoquant/tabs/segmentation/models/default_2d/model.py +664 -0
  44. senoquant/tabs/segmentation/models/default_3d/details.json +69 -0
  45. senoquant/tabs/segmentation/models/default_3d/model.py +682 -0
  46. senoquant/tabs/segmentation/models/hf.py +71 -0
  47. senoquant/tabs/segmentation/models/nuclear_dilation/__init__.py +1 -0
  48. senoquant/tabs/segmentation/models/nuclear_dilation/details.json +26 -0
  49. senoquant/tabs/segmentation/models/nuclear_dilation/model.py +96 -0
  50. senoquant/tabs/segmentation/models/perinuclear_rings/__init__.py +1 -0
  51. senoquant/tabs/segmentation/models/perinuclear_rings/details.json +34 -0
  52. senoquant/tabs/segmentation/models/perinuclear_rings/model.py +132 -0
  53. senoquant/tabs/segmentation/stardist_onnx_utils/__init__.py +2 -0
  54. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/__init__.py +3 -0
  55. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/__init__.py +6 -0
  56. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/generate.py +470 -0
  57. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/prepare.py +273 -0
  58. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/rawdata.py +112 -0
  59. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/transform.py +384 -0
  60. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/__init__.py +0 -0
  61. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/blocks.py +184 -0
  62. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/losses.py +79 -0
  63. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/nets.py +165 -0
  64. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/predict.py +467 -0
  65. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/probability.py +67 -0
  66. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/train.py +148 -0
  67. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/io/__init__.py +163 -0
  68. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/__init__.py +52 -0
  69. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/base_model.py +329 -0
  70. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_isotropic.py +160 -0
  71. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_projection.py +178 -0
  72. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_standard.py +446 -0
  73. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_upsampling.py +54 -0
  74. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/config.py +254 -0
  75. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/pretrained.py +119 -0
  76. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/scripts/__init__.py +0 -0
  77. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/scripts/care_predict.py +180 -0
  78. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/__init__.py +5 -0
  79. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/plot_utils.py +159 -0
  80. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/six.py +18 -0
  81. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/tf.py +644 -0
  82. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/utils.py +272 -0
  83. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/version.py +1 -0
  84. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/docs/source/conf.py +368 -0
  85. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/setup.py +68 -0
  86. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_datagen.py +169 -0
  87. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_models.py +462 -0
  88. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_utils.py +166 -0
  89. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tools/create_zip_contents.py +34 -0
  90. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/__init__.py +30 -0
  91. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/big.py +624 -0
  92. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/bioimageio_utils.py +494 -0
  93. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/data/__init__.py +39 -0
  94. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/__init__.py +10 -0
  95. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/geom2d.py +215 -0
  96. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/geom3d.py +349 -0
  97. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/matching.py +483 -0
  98. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/__init__.py +28 -0
  99. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/base.py +1217 -0
  100. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/model2d.py +594 -0
  101. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/model3d.py +696 -0
  102. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/nms.py +384 -0
  103. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/__init__.py +2 -0
  104. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/plot.py +74 -0
  105. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/render.py +298 -0
  106. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/rays3d.py +373 -0
  107. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/sample_patches.py +65 -0
  108. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/__init__.py +0 -0
  109. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/predict2d.py +90 -0
  110. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/predict3d.py +93 -0
  111. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/utils.py +408 -0
  112. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/version.py +1 -0
  113. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/__init__.py +45 -0
  114. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/__init__.py +17 -0
  115. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/cli.py +55 -0
  116. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/core.py +285 -0
  117. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/__init__.py +15 -0
  118. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/cli.py +36 -0
  119. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/divisibility.py +193 -0
  120. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/probe.py +100 -0
  121. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/receptive_field.py +182 -0
  122. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/rf_cli.py +48 -0
  123. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/valid_sizes.py +278 -0
  124. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/post/__init__.py +8 -0
  125. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/post/core.py +157 -0
  126. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/pre/__init__.py +17 -0
  127. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/pre/core.py +226 -0
  128. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/predict/__init__.py +5 -0
  129. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/predict/core.py +401 -0
  130. senoquant/tabs/settings/__init__.py +1 -0
  131. senoquant/tabs/settings/backend.py +29 -0
  132. senoquant/tabs/settings/frontend.py +19 -0
  133. senoquant/tabs/spots/__init__.py +1 -0
  134. senoquant/tabs/spots/backend.py +139 -0
  135. senoquant/tabs/spots/frontend.py +800 -0
  136. senoquant/tabs/spots/models/__init__.py +5 -0
  137. senoquant/tabs/spots/models/base.py +94 -0
  138. senoquant/tabs/spots/models/rmp/details.json +61 -0
  139. senoquant/tabs/spots/models/rmp/model.py +499 -0
  140. senoquant/tabs/spots/models/udwt/details.json +103 -0
  141. senoquant/tabs/spots/models/udwt/model.py +482 -0
  142. senoquant/utils.py +25 -0
  143. senoquant-1.0.0b1.dist-info/METADATA +193 -0
  144. senoquant-1.0.0b1.dist-info/RECORD +148 -0
  145. senoquant-1.0.0b1.dist-info/WHEEL +5 -0
  146. senoquant-1.0.0b1.dist-info/entry_points.txt +2 -0
  147. senoquant-1.0.0b1.dist-info/licenses/LICENSE +28 -0
  148. senoquant-1.0.0b1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,165 @@
1
+ from __future__ import print_function, unicode_literals, absolute_import, division
2
+ from six.moves import range, zip, map, reduce, filter
3
+
4
+ from ..utils.tf import keras_import
5
+ Input, Conv2D, Conv3D, Activation, Lambda, Add, Concatenate = keras_import('layers', 'Input', 'Conv2D', 'Conv3D', 'Activation', 'Lambda', 'Add', 'Concatenate')
6
+ Model = keras_import('models', 'Model')
7
+ from .blocks import unet_block
8
+ import re
9
+
10
+ from ..utils import _raise, backend_channels_last
11
+ import numpy as np
12
+
13
+
14
def custom_unet(input_shape,
                last_activation,
                n_depth=2,
                n_filter_base=16,
                kernel_size=(3,3,3),
                n_conv_per_depth=2,
                activation="relu",
                batch_norm=False,
                dropout=0.0,
                pool_size=(2,2,2),
                n_channel_out=1,
                residual=False,
                prob_out=False,
                eps_scale=1e-3):
    """Build a (optionally residual / probabilistic) U-Net Keras model.

    Parameters
    ----------
    input_shape : tuple of int
        Shape of the network input (without batch dimension).
    last_activation : str
        Activation for the final output layer (must be given, e.g. 'sigmoid', 'relu').
    n_depth : int
        Number of resolution levels of the U-Net.
    n_filter_base : int
        Number of filters at the first resolution level.
    kernel_size : tuple of int
        Convolution kernel size; its length determines 2D vs. 3D convolutions.
    n_conv_per_depth : int
        Convolutions per resolution level.
    activation : str
        Activation used inside the U-Net blocks.
    batch_norm : bool
        Whether to use batch normalization.
    dropout : float
        Dropout rate inside the U-Net blocks.
    pool_size : tuple of int
        Pooling size per downsampling step.
    n_channel_out : int
        Number of output channels.
    residual : bool
        If True, add the input to the output (requires matching channel counts).
    prob_out : bool
        If True, append a positive 'scale' output channel per output channel
        (for probabilistic prediction).
    eps_scale : float
        Lower bound added to the softplus scale output to keep it positive.

    Returns
    -------
    keras Model
        The constructed model.

    Raises
    ------
    ValueError
        If `last_activation` is None, `kernel_size` is not odd in all
        dimensions, or a residual net has mismatched channel counts.
    """
    if last_activation is None:
        raise ValueError("last activation has to be given (e.g. 'sigmoid', 'relu')!")

    # odd kernels keep convolutions spatially centered
    if not all(s % 2 == 1 for s in kernel_size):
        raise ValueError('kernel size should be odd in all dimensions.')

    channel_axis = -1 if backend_channels_last() else 1

    n_dim = len(kernel_size)
    conv = Conv2D if n_dim == 2 else Conv3D

    # renamed from 'input' to avoid shadowing the builtin
    net_input = Input(input_shape, name="input")
    unet = unet_block(n_depth, n_filter_base, kernel_size,
                      activation=activation, dropout=dropout, batch_norm=batch_norm,
                      n_conv_per_depth=n_conv_per_depth, pool=pool_size)(net_input)

    final = conv(n_channel_out, (1,)*n_dim, activation='linear')(unet)
    if residual:
        if not (n_channel_out == input_shape[-1] if backend_channels_last() else n_channel_out == input_shape[0]):
            raise ValueError("number of input and output channels must be the same for a residual net.")
        final = Add()([final, net_input])
    final = Activation(activation=last_activation)(final)

    if prob_out:
        # second head predicts a strictly positive per-pixel scale
        scale = conv(n_channel_out, (1,)*n_dim, activation='softplus')(unet)
        scale = Lambda(lambda x: x + np.float32(eps_scale))(scale)
        final = Concatenate(axis=channel_axis)([final, scale])

    return Model(inputs=net_input, outputs=final)
58
+
59
+
60
+
61
def common_unet(n_dim=2, n_depth=1, kern_size=3, n_first=16, n_channel_out=1, residual=True, prob_out=False, last_activation='linear'):
    """Construct a common CARE network based on U-Net [1]_ and residual learning [2]_.

    Parameters
    ----------
    n_dim : int
        Number of image dimensions (2 or 3).
    n_depth : int
        Number of resolution levels of the U-Net architecture.
    kern_size : int
        Size of the convolution filter in all image dimensions.
    n_first : int
        Number of convolution filters at the first U-Net resolution level
        (doubled after each downsampling operation).
    n_channel_out : int
        Number of channels of the predicted output image.
    residual : bool
        If True, the model internally predicts the residual w.r.t. the input
        (typically better); requires equal input and output channel counts.
    prob_out : bool
        Standard regression (False) or probabilistic prediction (True);
        if True, two values (mean and positive scale) are predicted per pixel.
    last_activation : str
        Name of the activation function for the final output layer.

    Returns
    -------
    function
        Function that builds the network given the input image shape.

    Example
    -------
    >>> model = common_unet(2, 1,3,16, 1, True, False)(input_shape)

    References
    ----------
    .. [1] Olaf Ronneberger, Philipp Fischer, Thomas Brox, *U-Net: Convolutional Networks for Biomedical Image Segmentation*, MICCAI 2015
    .. [2] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. *Deep Residual Learning for Image Recognition*, CVPR 2016
    """
    # expand scalar settings to per-dimension tuples once, up front
    kernel_size = (kern_size,) * n_dim
    pool_size = (2,) * n_dim

    def _build_this(input_shape):
        return custom_unet(
            input_shape, last_activation, n_depth, n_first, kernel_size,
            pool_size=pool_size, n_channel_out=n_channel_out,
            residual=residual, prob_out=prob_out,
        )

    return _build_this
102
+
103
+
104
+
105
modelname = re.compile(r"^(?P<model>resunet|unet)(?P<n_dim>\d)(?P<prob_out>p)?_(?P<n_depth>\d+)_(?P<kern_size>\d+)_(?P<n_first>\d+)(_(?P<n_channel_out>\d+)out)?(_(?P<last_activation>.+)-last)?$")
def common_unet_by_name(model):
    r"""Shorthand notation equivalent to calling :func:`common_unet`.

    Parameters
    ----------
    model : str
        Model specification string, parsed with the regular expression:
        `^(?P<model>resunet|unet)(?P<n_dim>\d)(?P<prob_out>p)?_(?P<n_depth>\d+)_(?P<kern_size>\d+)_(?P<n_first>\d+)(_(?P<n_channel_out>\d+)out)?(_(?P<last_activation>.+)-last)?$`

    Returns
    -------
    function
        Result of :func:`common_unet` with the parsed parameters.

    Raises
    ------
    ValueError
        If `model` does not match the expected pattern.

    Example
    -------
    >>> model = common_unet_by_name('resunet2_1_3_16_1out')(input_shape)
    >>> # equivalent to: model = common_unet(2, 1,3,16, 1, True, False)(input_shape)
    """
    m = modelname.fullmatch(model)
    if m is None:
        raise ValueError("model name '%s' unknown, must follow pattern '%s'" % (model, modelname.pattern))

    g = m.groupdict()
    options = {
        'n_depth':   int(g['n_depth']),
        'n_first':   int(g['n_first']),
        'kern_size': int(g['kern_size']),
        'n_dim':     int(g['n_dim']),
        # trailing 'p' flag selects probabilistic output
        'prob_out':  g['prob_out'] is not None,
        # 'resunet' prefix selects the residual variant
        'residual':  g['model'] == 'resunet',
    }
    options['n_channel_out'] = 1 if g['n_channel_out'] is None else int(g['n_channel_out'])
    if g['last_activation'] is not None:
        options['last_activation'] = g['last_activation']

    return common_unet(**options)
149
+
150
+
151
+
152
def receptive_field_unet(n_depth, kern_size, pool_size=2, n_dim=2, img_size=1024):
    """Receptive field for U-Net model (pre/post for each dimension)."""
    # impulse image: a single 1 at the center, zeros elsewhere
    x = np.zeros((1,) + (img_size,)*n_dim + (1,))
    mid = tuple(s // 2 for s in x.shape[1:-1])
    x[(slice(None),) + mid + (slice(None),)] = 1

    model = custom_unet(
        x.shape[1:],
        n_depth=n_depth, kernel_size=[kern_size]*n_dim, pool_size=[pool_size]*n_dim,
        n_filter_base=8, activation='linear', last_activation='linear',
    )

    # pixels whose response differs from the zero-input baseline are
    # influenced by the center impulse, i.e. inside the receptive field
    response = model.predict(x)[0, ..., 0]
    baseline = model.predict(0 * x)[0, ..., 0]
    ind = np.where(np.abs(response - baseline) > 0)
    return [(m - np.min(i), np.max(i) - m) for m, i in zip(mid, ind)]
@@ -0,0 +1,467 @@
1
+ from __future__ import print_function, unicode_literals, absolute_import, division
2
+ from six.moves import range, zip, map, reduce, filter
3
+
4
+ from tqdm import tqdm
5
+ from ..utils import _raise, consume, move_channel_for_backend, backend_channels_last, axes_check_and_normalize, axes_dict
6
+ import warnings
7
+ import numpy as np
8
+
9
+
10
+
11
def to_tensor(x, channel=None, single_sample=True):
    """Convert array `x` to backend tensor layout.

    Adds a sample axis if `single_sample`, adds a trailing channel axis when
    `channel` is None, then moves the channel to the backend's position.
    """
    if single_sample:
        # prepend sample axis; a non-negative channel index shifts by one
        x = x[np.newaxis]
        if channel is not None and channel >= 0:
            channel += 1
    if channel is None:
        x = np.expand_dims(x, -1)
        channel = -1
    return move_channel_for_backend(x, channel)
19
+
20
+
21
+
22
def from_tensor(x, channel=-1, single_sample=True):
    """Inverse of `to_tensor`: drop the sample axis (if `single_sample`) and
    move the backend channel axis to position `channel`."""
    arr = x[0] if single_sample else x
    backend_channel = -1 if backend_channels_last() else 1
    return np.moveaxis(arr, backend_channel, channel)
24
+
25
+
26
+
27
def tensor_num_channels(x):
    """Number of channels of tensor `x` at the backend's channel position."""
    axis = -1 if backend_channels_last() else 1
    return x.shape[axis]
29
+
30
+
31
+
32
def predict_direct(keras_model, x, axes_in, axes_out=None, **kwargs):
    """Predict `x` with `keras_model` in one shot (no tiling).

    Converts `x` to the backend tensor layout according to `axes_in`,
    runs `keras_model.predict`, and converts the result back per `axes_out`.
    """
    if axes_out is None:
        axes_out = axes_in
    kwargs.setdefault('verbose', 0)

    ax_in, ax_out = axes_dict(axes_in), axes_dict(axes_out)
    channel_in, channel_out = ax_in['C'], ax_out['C']
    # no 'S' axis means x is a single sample
    single_sample = ax_in['S'] is None

    if len(axes_in) != x.ndim:
        raise ValueError()

    tensor = to_tensor(x, channel=channel_in, single_sample=single_sample)
    pred = from_tensor(keras_model.predict(tensor, **kwargs),
                       channel=channel_out, single_sample=single_sample)

    if len(axes_out) != pred.ndim:
        raise ValueError()
    return pred
45
+
46
+
47
+
48
def predict_tiled(keras_model,x,n_tiles,block_sizes,tile_overlaps,axes_in,axes_out=None,pbar=None,**kwargs):
    """Tiled prediction: split `x` into overlapping tiles, predict each, reassemble.

    Recursively tiles one axis at a time (the first axis with more than one
    tile), delegating each tile to another `predict_tiled` call; the base case
    (all tile counts 1) calls `predict_direct`.

    Parameters
    ----------
    keras_model : model with a `predict` method.
    x : numpy.ndarray
        Input image; must not have a sample ('S') axis.
    n_tiles : sequence of int
        Number of tiles per axis of `x` (1 = no tiling for that axis).
    block_sizes : sequence of int
        Block alignment per axis (see `tile_iterator_1d`).
    tile_overlaps : sequence of int
        Required overlap (in pixels) per axis; converted to blocks here.
    axes_in, axes_out : str
        Axes semantics of input and output; `axes_out` defaults to `axes_in`
        and must be a subset of `axes_in` (plus it may differ in channels).
    pbar : object with `update()` or None
        Progress reporting; updated once per direct prediction.

    Returns
    -------
    numpy.ndarray
        Assembled prediction with `axes_out` semantics and same dtype as `x`.
    """

    # base case: nothing left to tile -> single direct prediction
    if all(t==1 for t in n_tiles):
        pred = predict_direct(keras_model,x,axes_in,axes_out,**kwargs)
        if pbar is not None:
            pbar.update()
        return pred

    ###

    if axes_out is None:
        axes_out = axes_in
    axes_in, axes_out = axes_check_and_normalize(axes_in,x.ndim), axes_check_and_normalize(axes_out)
    assert 'S' not in axes_in
    assert 'C' in axes_in and 'C' in axes_out
    ax_in, ax_out = axes_dict(axes_in), axes_dict(axes_out)
    channel_in, channel_out = ax_in['C'], ax_out['C']

    assert set(axes_out).issubset(set(axes_in))
    # axes present in the input but dropped from the output
    axes_lost = set(axes_in).difference(set(axes_out))

    def _to_axes_out(seq,elem):
        # assumption: prediction size is same as input size along all axes, except for channel (and lost axes)
        assert len(seq) == len(axes_in)
        # 1. re-order 'seq' from axes_in to axes_out semantics
        seq = [seq[ax_in[a]] for a in axes_out]
        # 2. replace value at channel position with 'elem'
        seq[ax_out['C']] = elem
        return tuple(seq)

    ###

    assert x.ndim == len(n_tiles) == len(block_sizes)
    # channel and lost axes must not be tiled
    assert n_tiles[channel_in] == 1
    assert all(n_tiles[ax_in[a]] == 1 for a in axes_lost)
    assert all(np.isscalar(t) and 1<=t and int(t)==t for t in n_tiles)

    # first axis > 1: tile along this axis now, recurse for the rest
    axis = next(i for i,t in enumerate(n_tiles) if t>1)

    block_size = block_sizes[axis]
    tile_overlap = tile_overlaps[axis]
    # overlap in pixels -> overlap in whole blocks (rounded up)
    n_block_overlap = int(np.ceil(1.* tile_overlap / block_size))

    # print(f"axis={axis},n_tiles={n_tiles[axis]},block_size={block_size},tile_overlap={tile_overlap},n_block_overlap={n_block_overlap}")

    # the recursive calls must not tile this axis again
    n_tiles_remaining = list(n_tiles)
    n_tiles_remaining[axis] = 1

    dst = None
    for tile, s_src, s_dst in tile_iterator_1d(x,axis=axis,n_tiles=n_tiles[axis],block_size=block_size,n_block_overlap=n_block_overlap):

        pred = predict_tiled(keras_model,tile,n_tiles_remaining,block_sizes,tile_overlaps,axes_in,axes_out,pbar=pbar,**kwargs)

        # if any(t>1 for t in n_tiles_remaining):
        #     pred = predict_tiled(keras_model,tile,n_tiles_remaining,block_sizes,tile_overlaps,axes_in,axes_out,pbar=pbar,**kwargs)
        # else:
        #     # tmp
        #     pred = tile
        #     if pbar is not None:
        #         pbar.update()

        if dst is None:
            # allocate output lazily: channel count only known after first prediction
            dst_shape = _to_axes_out(x.shape, pred.shape[channel_out])
            dst = np.empty(dst_shape, dtype=x.dtype)

        # convert the 1-d tiling slices to output-axes order (full channel slice)
        s_src = _to_axes_out(s_src, slice(None))
        s_dst = _to_axes_out(s_dst, slice(None))

        dst[s_dst] = pred[s_src]

    return dst
121
+
122
+
123
+
124
class Tile(object):
    """One tile of a 1-d tiling of an extent of length `n`.

    Tiles form a linked chain via `prev`. Each tile has a `read_slice`
    (the region read from the input, including overlap) and a `write_slice`
    (the region it is responsible for writing to the output).
    """

    def __init__(self, n, size, overlap, prev):
        self.n = int(n)
        self.size = int(size)
        self.overlap = int(overlap)
        if self.n < self.size:
            # whole extent fits in one tile: shrink it and drop the overlap
            assert prev is None
            self.size = self.n
            self.overlap = 0
        assert self.size > 2 * self.overlap
        if prev is not None:
            assert not prev.at_end, "Previous tile already at end"
        self.prev = prev
        # freeze both slices now; they depend only on construction-time state
        self.read_slice = self._read_slice
        self.write_slice = self._write_slice

    @property
    def at_begin(self):
        """True for the first tile of the chain."""
        return self.prev is None

    @property
    def at_end(self):
        """True if this tile reaches the end of the extent."""
        return self.read_slice.stop == self.n

    @property
    def _read_slice(self):
        if self.at_begin:
            lo, hi = 0, self.size
        else:
            prev_read = self.prev.read_slice
            # advance so that this tile overlaps the previous one by 2*overlap
            lo = prev_read.stop - 2 * self.overlap
            hi = lo + self.size
            # shift back if the tile would run past the end of the extent
            shift = min(0, self.n - hi)
            lo, hi = lo + shift, hi + shift
            assert lo > prev_read.start
        assert lo >= 0 and hi <= self.n
        return slice(lo, hi)

    @property
    def _write_slice(self):
        if self.at_begin:
            # first tile also owns its leading margin
            if self.at_end:
                return slice(0, self.n)
            return slice(0, self.size - self.overlap)
        prev_stop = self.prev.write_slice.stop
        if self.at_end:
            # last tile owns everything up to the end
            return slice(prev_stop, self.n)
        # interior tile: owns its center, excluding overlap on both sides
        return slice(prev_stop, prev_stop + self.size - 2 * self.overlap)

    def __repr__(self):
        # ASCII picture: '-' read-only, 'x'/'o' written region
        chars = np.array(list(' ' * self.n))
        chars[self.read_slice] = '-'
        chars[self.write_slice] = 'x' if (self.at_begin or self.at_end) else 'o'
        return ''.join(chars)
183
+
184
+
185
+
186
class Tiling(object):
    """Chain of overlapping `Tile`s covering an extent of length `n`."""

    def __init__(self, n, size, overlap):
        self.n = n
        self.size = size
        self.overlap = overlap
        # build the chain tile by tile until the extent is covered
        tiles = [Tile(n=n, size=size, overlap=overlap, prev=None)]
        while not tiles[-1].at_end:
            tiles.append(Tile(n=n, size=size, overlap=overlap, prev=tiles[-1]))
        self.tiles = tiles

    def __len__(self):
        return len(self.tiles)

    def __repr__(self):
        lines = ('{i:3}. {t}'.format(i=i, t=t) for i, t in enumerate(self.tiles, 1))
        return '\n'.join(lines)

    def slice_generator(self, block_size=1):
        """Yield (read, write, crop) slices, scaled from blocks to elements.

        `crop` maps the written region into the coordinates of the read tile.
        """
        def scaled(sl):
            return slice(block_size * sl.start, block_size * sl.stop)

        for t in self.tiles:
            read, write = scaled(t.read_slice), scaled(t.write_slice)
            stop = write.stop - read.stop
            # negative stop crops the trailing overlap; None keeps the tail
            crop = slice(write.start - read.start, stop if stop < 0 else None)
            yield read, write, crop

    @staticmethod
    def for_n_tiles(n, n_tiles, overlap):
        """Return a Tiling of `n` with (approximately) `n_tiles` tiles."""
        smallest_size = 2 * overlap + 1
        # grow the tile size from the smallest possible until the tile
        # count drops to at most n_tiles
        size = smallest_size
        while len(Tiling(n, size, overlap)) > n_tiles:
            size += 1
        if size == smallest_size:
            return Tiling(n, size, overlap)
        # pick whichever neighboring size gets closest to the requested count
        candidates = (
            Tiling(n, size - 1, overlap),
            Tiling(n, size, overlap),
        )
        diffs = [np.abs(len(c) - n_tiles) for c in candidates]
        return candidates[np.argmin(diffs)]
226
+
227
+
228
+
229
def total_n_tiles(x, n_tiles, block_sizes, n_block_overlaps, guarantee='size'):
    """Total number of tiles that `tile_iterator` will produce for `x`.

    For guarantee 'size' the actual tiling is simulated per axis; for
    'n_tiles' the requested counts are taken at face value.
    """
    assert x.ndim == len(n_tiles) == len(block_sizes) == len(n_block_overlaps)
    assert guarantee in ('size', 'n_tiles')
    total = 1
    for n, n_tile, block_size, n_block_overlap in zip(x.shape, n_tiles, block_sizes, n_block_overlaps):
        assert n % block_size == 0
        n_blocks = n // block_size
        if guarantee == 'size':
            # actual count may differ from the request; simulate the tiling
            total *= len(Tiling.for_n_tiles(n_blocks, n_tile, n_block_overlap))
        else:
            total *= n_tile
    return total
241
+
242
+
243
+
244
def tile_iterator_1d(x,axis,n_tiles,block_size,n_block_overlap,guarantee='size'):
    """Tile iterator for one dimension of array x.

    Yields tuples ``(tile, s_src, s_dst)`` such that
    ``output[s_dst] = processed_tile[s_src]`` reassembles the full result.

    Parameters
    ----------
    x : numpy.ndarray
        Input array
    axis : int
        Axis which should be tiled, all other axes are not tiled
    n_tiles : int
        Targeted number of tiles for axis of x (see guarantee below)
    block_size : int
        Axis of x is assumed to be evenly divisible by block_size
        All tiles are aligned with the block_size
    n_block_overlap : int
        Tiles will overlap at least this many blocks (see guarantee below)
    guarantee : str
        Can be either 'size' or 'n_tiles':
        'size': The size of all tiles is guaranteed to be the same,
                but the number of tiles can be different and the
                amount of overlap can be larger than requested.
        'n_tiles': The size of tiles can be different at the beginning and end,
                   but the number of tiles is guaranteed to be the one requested.
                   The amount of overlap is also exactly as requested.

    Raises
    ------
    ValueError
        If `x` is not evenly divisible by `block_size` along `axis`, or
        if `guarantee` is not 'size' or 'n_tiles'.
    """
    n = x.shape[axis]

    n % block_size == 0 or _raise(ValueError("'x' must be evenly divisible by 'block_size' along 'axis'"))
    n_blocks = n // block_size

    guarantee in ('size', 'n_tiles') or _raise(ValueError("guarantee must be either 'size' or 'n_tiles'"))

    if guarantee == 'size':
        # delegate the block-level layout to Tiling, then scale back to elements
        tiling = Tiling.for_n_tiles(n_blocks, n_tiles, n_block_overlap)

        def ndim_slices(t):
            # embed the 1-d slice `t` into a full n-dim slice tuple
            sl = [slice(None)] * x.ndim
            sl[axis] = t
            return tuple(sl)

        for read, write, crop in tiling.slice_generator(block_size):
            tile_in = read    # src in input image / tile
            tile_out = write  # dst in output image / s_dst
            tile_crop = crop  # crop of src for output / s_src
            yield x[ndim_slices(tile_in)], ndim_slices(tile_crop), ndim_slices(tile_out)

    elif guarantee == 'n_tiles':
        # clamp requested tile count to the number of available blocks
        n_tiles_valid = int(np.clip(n_tiles,1,n_blocks))
        if n_tiles != n_tiles_valid:
            warnings.warn("invalid value (%d) for 'n_tiles', changing to %d" % (n_tiles,n_tiles_valid))
            n_tiles = n_tiles_valid

        s = n_blocks // n_tiles  # tile size
        r = n_blocks % n_tiles   # blocks remainder
        assert n_tiles * s + r == n_blocks

        # list of sizes for each tile
        tile_sizes = s*np.ones(n_tiles,int)
        # distribute remaining blocks to tiles at beginning and end
        if r > 0:
            tile_sizes[:r//2] += 1
            tile_sizes[-(r-r//2):] += 1

        # n_block_overlap = int(np.ceil(92 / block_size))
        # n_block_overlap -= 1
        # print(n_block_overlap)

        # (pre,post) offsets for each tile; first/last tile have no
        # leading/trailing overlap respectively
        off = [(n_block_overlap if i > 0 else 0, n_block_overlap if i < n_tiles-1 else 0) for i in range(n_tiles)]

        # tile_starts = np.concatenate(([0],np.cumsum(tile_sizes[:-1])))
        # print([(_st-_pre,_st+_sz+_post) for (_st,_sz,(_pre,_post)) in zip(tile_starts,tile_sizes,off)])

        def to_slice(t):
            # convert a (start_block, stop_block) pair into a full n-dim
            # slice tuple in element coordinates; stop of 0 means 'to the end'
            sl = [slice(None)] * x.ndim
            sl[axis] = slice(
                t[0]*block_size,
                t[1]*block_size if t[1]!=0 else None)
            return tuple(sl)

        start = 0
        for i in range(n_tiles):
            off_pre, off_post = off[i]

            # tile starts before block 0 -> adjust off_pre
            if start-off_pre < 0:
                off_pre = start
            # tile end after last block -> adjust off_post
            if start+tile_sizes[i]+off_post > n_blocks:
                off_post = n_blocks-start-tile_sizes[i]

            tile_in = (start-off_pre,start+tile_sizes[i]+off_post)  # src in input image / tile
            tile_out = (start,start+tile_sizes[i])                  # dst in output image / s_dst
            tile_crop = (off_pre,-off_post)                         # crop of src for output / s_src

            yield x[to_slice(tile_in)], to_slice(tile_crop), to_slice(tile_out)
            start += tile_sizes[i]

    else:
        assert False
346
+
347
+
348
+
349
def tile_iterator(x,n_tiles,block_sizes,n_block_overlaps,guarantee='size'):
    """Tile iterator for n-d arrays.

    Yields block-aligned tiles (`block_sizes`) that have at least
    a certain amount of overlapping blocks (`n_block_overlaps`)
    with their neighbors. Also yields slices that allow to map each
    tile back to the original array x.

    Notes
    -----
    - Tiles will not go beyond the array boundary (i.e. no padding).
      This means the shape of x must be evenly divisible by the respective block_size.
    - It is not guaranteed that all tiles have the same size if guarantee is not 'size'.

    Parameters
    ----------
    x : numpy.ndarray
        Input array.
    n_tiles : int or sequence of ints
        Number of tiles for each dimension of x.
    block_sizes : int or sequence of ints
        Block sizes for each dimension of x.
        The shape of x is assumed to be evenly divisible by block_sizes.
        All tiles are aligned with block_sizes.
    n_block_overlaps : int or sequence of ints
        Tiles will at least overlap this many blocks in each dimension.
    guarantee : str
        Can be either 'size' or 'n_tiles':
        'size': The size of all tiles is guaranteed to be the same,
                but the number of tiles can be different and the
                amount of overlap can be larger than requested.
        'n_tiles': The size of tiles can be different at the beginning and end,
                   but the number of tiles is guaranteed to be the one requested.
                   The amount of overlap is also exactly as requested.

    Example
    -------

    Duplicate an array tile-by-tile:

    >>> x = np.array(...)
    >>> y = np.empty_like(x)
    >>>
    >>> for tile,s_src,s_dst in tile_iterator(x, n_tiles, block_sizes, n_block_overlaps):
    >>>     y[s_dst] = tile[s_src]
    >>>
    >>> np.allclose(x,y)
    True

    """
    # broadcast scalar arguments to one value per dimension
    if np.isscalar(n_tiles): n_tiles = (n_tiles,)*x.ndim
    if np.isscalar(block_sizes): block_sizes = (block_sizes,)*x.ndim
    if np.isscalar(n_block_overlaps): n_block_overlaps = (n_block_overlaps,)*x.ndim

    assert x.ndim == len(n_tiles) == len(block_sizes) == len(n_block_overlaps)

    # recursively tile one axis at a time, accumulating the per-axis
    # src/dst slices; yields at the innermost axis
    def _accumulate(tile_in,axis,src,dst):
        for tile, s_src, s_dst in tile_iterator_1d(tile_in, axis, n_tiles[axis], block_sizes[axis], n_block_overlaps[axis], guarantee):
            src[axis] = s_src[axis]
            dst[axis] = s_dst[axis]
            if axis+1 == tile_in.ndim:
                # remove None and negative slicing (resolve against the
                # concrete tile shape so the slices are plain start/stop)
                src = [slice(s.start, size if s.stop is None else (s.stop if s.stop >= 0 else size + s.stop)) for s,size in zip(src,tile.shape)]
                yield tile, tuple(src), tuple(dst)
            else:
                # yield from _accumulate(tile, axis+1, src, dst)
                for entry in _accumulate(tile, axis+1, src, dst):
                    yield entry

    return _accumulate(x, 0, [None]*x.ndim, [None]*x.ndim)
419
+
420
+
421
+
422
def tile_overlap(n_depth, kern_size, pool_size=2):
    """Pre-computed tile overlap for a U-Net of the given configuration.

    Values were tabulated (presumably via `receptive_field_unet`) for the
    supported combinations of depth, kernel size, and pool size.

    Raises
    ------
    ValueError
        If no value is tabulated for the given combination.
    """
    overlaps = {(1, 3, 1): 6, (1, 5, 1): 12, (1, 7, 1): 18,
                (2, 3, 1): 10, (2, 5, 1): 20, (2, 7, 1): 30,
                (3, 3, 1): 14, (3, 5, 1): 28, (3, 7, 1): 42,
                (4, 3, 1): 18, (4, 5, 1): 36, (4, 7, 1): 54,
                (5, 3, 1): 22, (5, 5, 1): 44, (5, 7, 1): 66,
                #
                (1, 3, 2): 9, (1, 5, 2): 17, (1, 7, 2): 25,
                (2, 3, 2): 22, (2, 5, 2): 43, (2, 7, 2): 62,
                (3, 3, 2): 46, (3, 5, 2): 92, (3, 7, 2): 138,
                (4, 3, 2): 94, (4, 5, 2): 188, (4, 7, 2): 282,
                (5, 3, 2): 190, (5, 5, 2): 380, (5, 7, 2): 570,
                #
                (1, 3, 4): 14, (1, 5, 4): 27, (1, 7, 4): 38,
                (2, 3, 4): 58, (2, 5, 4): 116, (2, 7, 4): 158,
                (3, 3, 4): 234, (3, 5, 4): 468, (3, 7, 4): 638,
                (4, 3, 4): 938, (4, 5, 4): 1876, (4, 7, 4): 2558}
    key = (n_depth, kern_size, pool_size)
    if key not in overlaps:
        raise ValueError('tile_overlap value for n_depth=%d, kern_size=%d, pool_size=%d not available.' % key)
    return overlaps[key]
443
+
444
+
445
+
446
class Progress(object):
    """Lazily-created tqdm progress bar, shown only if `total` exceeds `thr`."""
    def __init__(self, total, thr=1):
        # pbar must be set before `total` (its setter calls close())
        self.pbar = None
        self.total = total
        self.thr = thr
    @property
    def total(self):
        """Expected number of `update()` calls."""
        return self._total
    @total.setter
    def total(self, total):
        # changing the total invalidates any existing bar
        self.close()
        self._total = total
    def update(self):
        """Advance the bar by one step, creating it on first use."""
        if self.total > self.thr:
            if self.pbar is None:
                self.pbar = tqdm(total=self.total)
            self.pbar.update()
            self.pbar.refresh()
    def close(self):
        """Close and discard the underlying tqdm bar, if any."""
        if self.pbar is not None:
            self.pbar.close()
            self.pbar = None
@@ -0,0 +1,67 @@
1
+ from __future__ import print_function, unicode_literals, absolute_import, division
2
+ from six.moves import range, zip, map, reduce, filter
3
+
4
+
5
+ from ..utils import _raise, consume
6
+ import warnings
7
+ import numpy as np
8
+ from scipy.stats import laplace
9
+
10
+
11
class ProbabilisticPrediction(object):
    """Laplace distribution (independently per pixel).

    Wraps a frozen :func:`scipy.stats.laplace` with per-pixel `loc` and
    `scale` arrays, exposing its distribution methods directly on the
    instance, plus array-like indexing/shape accessors.
    """

    # distribution methods delegated to the frozen scipy object
    _DELEGATED = ('rvs', 'pdf', 'logpdf', 'cdf', 'logcdf', 'sf', 'logsf',
                  'ppf', 'isf', 'moment', 'stats', 'entropy', 'expect',
                  'median', 'mean', 'var', 'std', 'interval')

    def __init__(self, loc, scale):
        if loc.shape != scale.shape:
            raise ValueError()
        self._loc = loc
        self._scale = scale
        # bind the frozen distribution's methods onto this instance
        dist = laplace(loc=self._loc, scale=self._scale)
        for name in self._DELEGATED:
            setattr(self, name, getattr(dist, name))

    def __getitem__(self, indices):
        """Return a new prediction restricted to the indexed pixels."""
        return ProbabilisticPrediction(loc=self._loc[indices], scale=self._scale[indices])

    def __len__(self):
        return len(self._loc)

    @property
    def shape(self):
        return self._loc.shape

    @property
    def ndim(self):
        return self._loc.ndim

    @property
    def size(self):
        return self._loc.size

    def scale(self):
        # NOTE: plain method (not a property), unlike shape/ndim/size
        return self._scale

    def sampling_generator(self, n=None):
        """Yield random samples; endlessly if `n` is None, else `n` of them."""
        drawn = 0
        while n is None or drawn < n:
            yield self.rvs()
            drawn += 1