senoquant 1.0.0b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- senoquant/__init__.py +6 -0
- senoquant/_reader.py +7 -0
- senoquant/_widget.py +33 -0
- senoquant/napari.yaml +83 -0
- senoquant/reader/__init__.py +5 -0
- senoquant/reader/core.py +369 -0
- senoquant/tabs/__init__.py +15 -0
- senoquant/tabs/batch/__init__.py +10 -0
- senoquant/tabs/batch/backend.py +641 -0
- senoquant/tabs/batch/config.py +270 -0
- senoquant/tabs/batch/frontend.py +1283 -0
- senoquant/tabs/batch/io.py +326 -0
- senoquant/tabs/batch/layers.py +86 -0
- senoquant/tabs/quantification/__init__.py +1 -0
- senoquant/tabs/quantification/backend.py +228 -0
- senoquant/tabs/quantification/features/__init__.py +80 -0
- senoquant/tabs/quantification/features/base.py +142 -0
- senoquant/tabs/quantification/features/marker/__init__.py +5 -0
- senoquant/tabs/quantification/features/marker/config.py +69 -0
- senoquant/tabs/quantification/features/marker/dialog.py +437 -0
- senoquant/tabs/quantification/features/marker/export.py +879 -0
- senoquant/tabs/quantification/features/marker/feature.py +119 -0
- senoquant/tabs/quantification/features/marker/morphology.py +285 -0
- senoquant/tabs/quantification/features/marker/rows.py +654 -0
- senoquant/tabs/quantification/features/marker/thresholding.py +46 -0
- senoquant/tabs/quantification/features/roi.py +346 -0
- senoquant/tabs/quantification/features/spots/__init__.py +5 -0
- senoquant/tabs/quantification/features/spots/config.py +62 -0
- senoquant/tabs/quantification/features/spots/dialog.py +477 -0
- senoquant/tabs/quantification/features/spots/export.py +1292 -0
- senoquant/tabs/quantification/features/spots/feature.py +112 -0
- senoquant/tabs/quantification/features/spots/morphology.py +279 -0
- senoquant/tabs/quantification/features/spots/rows.py +241 -0
- senoquant/tabs/quantification/frontend.py +815 -0
- senoquant/tabs/segmentation/__init__.py +1 -0
- senoquant/tabs/segmentation/backend.py +131 -0
- senoquant/tabs/segmentation/frontend.py +1009 -0
- senoquant/tabs/segmentation/models/__init__.py +5 -0
- senoquant/tabs/segmentation/models/base.py +146 -0
- senoquant/tabs/segmentation/models/cpsam/details.json +65 -0
- senoquant/tabs/segmentation/models/cpsam/model.py +150 -0
- senoquant/tabs/segmentation/models/default_2d/details.json +69 -0
- senoquant/tabs/segmentation/models/default_2d/model.py +664 -0
- senoquant/tabs/segmentation/models/default_3d/details.json +69 -0
- senoquant/tabs/segmentation/models/default_3d/model.py +682 -0
- senoquant/tabs/segmentation/models/hf.py +71 -0
- senoquant/tabs/segmentation/models/nuclear_dilation/__init__.py +1 -0
- senoquant/tabs/segmentation/models/nuclear_dilation/details.json +26 -0
- senoquant/tabs/segmentation/models/nuclear_dilation/model.py +96 -0
- senoquant/tabs/segmentation/models/perinuclear_rings/__init__.py +1 -0
- senoquant/tabs/segmentation/models/perinuclear_rings/details.json +34 -0
- senoquant/tabs/segmentation/models/perinuclear_rings/model.py +132 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/__init__.py +2 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/__init__.py +3 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/__init__.py +6 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/generate.py +470 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/prepare.py +273 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/rawdata.py +112 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/transform.py +384 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/__init__.py +0 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/blocks.py +184 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/losses.py +79 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/nets.py +165 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/predict.py +467 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/probability.py +67 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/train.py +148 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/io/__init__.py +163 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/__init__.py +52 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/base_model.py +329 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_isotropic.py +160 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_projection.py +178 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_standard.py +446 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_upsampling.py +54 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/config.py +254 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/pretrained.py +119 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/scripts/__init__.py +0 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/scripts/care_predict.py +180 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/__init__.py +5 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/plot_utils.py +159 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/six.py +18 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/tf.py +644 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/utils.py +272 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/version.py +1 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/docs/source/conf.py +368 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/setup.py +68 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_datagen.py +169 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_models.py +462 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_utils.py +166 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tools/create_zip_contents.py +34 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/__init__.py +30 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/big.py +624 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/bioimageio_utils.py +494 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/data/__init__.py +39 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/__init__.py +10 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/geom2d.py +215 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/geom3d.py +349 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/matching.py +483 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/__init__.py +28 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/base.py +1217 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/model2d.py +594 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/model3d.py +696 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/nms.py +384 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/__init__.py +2 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/plot.py +74 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/render.py +298 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/rays3d.py +373 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/sample_patches.py +65 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/__init__.py +0 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/predict2d.py +90 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/predict3d.py +93 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/utils.py +408 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/version.py +1 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/__init__.py +45 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/__init__.py +17 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/cli.py +55 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/core.py +285 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/__init__.py +15 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/cli.py +36 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/divisibility.py +193 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/probe.py +100 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/receptive_field.py +182 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/rf_cli.py +48 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/valid_sizes.py +278 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/post/__init__.py +8 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/post/core.py +157 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/pre/__init__.py +17 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/pre/core.py +226 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/predict/__init__.py +5 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/predict/core.py +401 -0
- senoquant/tabs/settings/__init__.py +1 -0
- senoquant/tabs/settings/backend.py +29 -0
- senoquant/tabs/settings/frontend.py +19 -0
- senoquant/tabs/spots/__init__.py +1 -0
- senoquant/tabs/spots/backend.py +139 -0
- senoquant/tabs/spots/frontend.py +800 -0
- senoquant/tabs/spots/models/__init__.py +5 -0
- senoquant/tabs/spots/models/base.py +94 -0
- senoquant/tabs/spots/models/rmp/details.json +61 -0
- senoquant/tabs/spots/models/rmp/model.py +499 -0
- senoquant/tabs/spots/models/udwt/details.json +103 -0
- senoquant/tabs/spots/models/udwt/model.py +482 -0
- senoquant/utils.py +25 -0
- senoquant-1.0.0b1.dist-info/METADATA +193 -0
- senoquant-1.0.0b1.dist-info/RECORD +148 -0
- senoquant-1.0.0b1.dist-info/WHEEL +5 -0
- senoquant-1.0.0b1.dist-info/entry_points.txt +2 -0
- senoquant-1.0.0b1.dist-info/licenses/LICENSE +28 -0
- senoquant-1.0.0b1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,446 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
from __future__ import print_function, unicode_literals, absolute_import, division
|
|
3
|
+
|
|
4
|
+
import warnings
|
|
5
|
+
import numpy as np
|
|
6
|
+
from six import string_types
|
|
7
|
+
|
|
8
|
+
from csbdeep.internals.probability import ProbabilisticPrediction
|
|
9
|
+
from .config import Config
|
|
10
|
+
from .base_model import BaseModel, suppress_without_basedir
|
|
11
|
+
|
|
12
|
+
from ..utils import _raise, axes_check_and_normalize, axes_dict, move_image_axes
|
|
13
|
+
from ..utils.six import Path
|
|
14
|
+
from ..utils.tf import export_SavedModel, IS_TF_1, keras_import, CARETensorBoardImage
|
|
15
|
+
from ..version import __version__ as package_version
|
|
16
|
+
from ..data import Normalizer, NoNormalizer, PercentileNormalizer
|
|
17
|
+
from ..data import Resizer, NoResizer, PadAndCropResizer
|
|
18
|
+
from ..internals.predict import predict_tiled, tile_overlap, Progress, total_n_tiles
|
|
19
|
+
from ..internals import nets, train
|
|
20
|
+
|
|
21
|
+
from packaging.version import Version
|
|
22
|
+
# Keras module handle returned by ``keras_import()`` (from ..utils.tf); it is
# consulted in ``CARE.prepare_for_training`` to read ``keras.__version__`` when
# choosing the optimizer's learning-rate keyword.
keras = keras_import()

# TensorFlow is imported directly so that ``tf.errors.ResourceExhaustedError``
# can be caught during tiled prediction (out-of-memory retry loop in CARE).
import tensorflow as tf
# Previous TF1/TF2 compatibility shim, kept for reference:
# if IS_TF_1:
#     import tensorflow as tf
# else:
#     import tensorflow.compat.v1 as tf
#     # tf.disable_v2_behavior()
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class CARE(BaseModel):
    """Standard CARE network for image restoration and enhancement.

    Uses a convolutional neural network created by :func:`csbdeep.internals.nets.common_unet`.
    Note that isotropic reconstruction and manifold extraction/projection are not supported here
    (see :class:`csbdeep.models.IsotropicCARE` ).

    Parameters
    ----------
    config : :class:`csbdeep.models.Config` or None
        Valid configuration of CARE network (see :func:`Config.is_valid`).
        Will be saved to disk as JSON (``config.json``).
        If set to ``None``, will be loaded from disk (must exist).
    name : str or None
        Model name. Uses a timestamp if set to ``None`` (default).
    basedir : str
        Directory that contains (or will contain) a folder with the given model name.
        Use ``None`` to disable saving (or loading) any data to (or from) disk (regardless of other parameters).

    Raises
    ------
    FileNotFoundError
        If ``config=None`` and config cannot be loaded from disk.
    ValueError
        Illegal arguments, including invalid configuration.

    Example
    -------
    >>> model = CARE(config, 'my_model')

    Attributes
    ----------
    config : :class:`csbdeep.models.Config`
        Configuration of CARE network, as provided during instantiation.
    keras_model : `Keras model <https://keras.io/getting-started/functional-api-guide/>`_
        Keras neural network model.
    name : str
        Model name.
    logdir : :class:`pathlib.Path`
        Path to model folder (which stores configuration, weights, etc.)
    """

    def __init__(self, config, name=None, basedir='.'):
        """See class docstring."""
        super(CARE, self).__init__(config=config, name=name, basedir=basedir)


    def _build(self):
        """Build the U-Net via :func:`nets.common_unet` and apply it to ``config.unet_input_shape``."""
        return nets.common_unet(
            n_dim           = self.config.n_dim,
            n_channel_out   = self.config.n_channel_out,
            prob_out        = self.config.probabilistic,
            residual        = self.config.unet_residual,
            n_depth         = self.config.unet_n_depth,
            kern_size       = self.config.unet_kern_size,
            n_first         = self.config.unet_n_first,
            last_activation = self.config.unet_last_activation,
        )(self.config.unet_input_shape)


    def prepare_for_training(self, optimizer=None, **kwargs):
        """Prepare for neural network training.

        Calls :func:`csbdeep.internals.train.prepare_model` and creates
        `Keras Callbacks <https://keras.io/callbacks/>`_ to be used for training.

        Note that this method will be implicitly called once by :func:`train`
        (with default arguments) if not done so explicitly beforehand.

        Parameters
        ----------
        optimizer : obj or None
            Instance of a `Keras Optimizer <https://keras.io/optimizers/>`_ to be used for training.
            If ``None`` (default), uses ``Adam`` with the learning rate specified in ``config``.
        kwargs : dict
            Additional arguments for :func:`csbdeep.internals.train.prepare_model`.

        """
        if optimizer is None:
            Adam = keras_import('optimizers', 'Adam')
            # Keras < 2.3 names the argument 'lr'; newer versions use 'learning_rate'.
            learning_rate = 'lr' if Version(getattr(keras, '__version__', '9.9.9')) < Version('2.3.0') else 'learning_rate'
            optimizer = Adam(**{learning_rate: self.config.train_learning_rate})
        self.callbacks = train.prepare_model(self.keras_model, optimizer, self.config.train_loss, **kwargs)

        if self.basedir is not None:
            self.callbacks += self._checkpoint_callbacks()

            # TensorBoard callbacks write below self.logdir, which only exists with a basedir.
            if self.config.train_tensorboard:
                if IS_TF_1:
                    from ..utils.tf import CARETensorBoard
                    self.callbacks.append(CARETensorBoard(log_dir=str(self.logdir), prefix_with_timestamp=False, n_images=3, write_images=True, prob_out=self.config.probabilistic))
                else:
                    from tensorflow.keras.callbacks import TensorBoard
                    self.callbacks.append(TensorBoard(log_dir=str(self.logdir/'logs'), write_graph=False, profile_batch=0))

        if self.config.train_reduce_lr is not None:
            ReduceLROnPlateau = keras_import('callbacks', 'ReduceLROnPlateau')
            # Copy so that filling in a default 'verbose' does not mutate the model's config.
            rlrop_params = dict(self.config.train_reduce_lr)
            rlrop_params.setdefault('verbose', True)
            # TF2: add as first callback to put 'lr' in the logs for TensorBoard
            self.callbacks.insert(0, ReduceLROnPlateau(**rlrop_params))

        self._model_prepared = True


    def train(self, X,Y, validation_data, epochs=None, steps_per_epoch=None, augmenter=None):
        """Train the neural network with the given data.

        Parameters
        ----------
        X : :class:`numpy.ndarray`
            Array of source images.
        Y : :class:`numpy.ndarray`
            Array of target images.
        validation_data : tuple(:class:`numpy.ndarray`, :class:`numpy.ndarray`)
            Tuple of arrays for source and target validation images.
        epochs : int
            Optional argument to use instead of the value from ``config``.
        steps_per_epoch : int
            Optional argument to use instead of the value from ``config``.
        augmenter : callable or None
            Optional data augmentation, passed unchanged to
            :class:`csbdeep.internals.train.DataWrapper` (expected signature: see DataWrapper).

        Returns
        -------
        ``History`` object
            See `Keras training history <https://keras.io/models/model/#fit>`_.

        Raises
        ------
        ValueError
            If ``validation_data`` is not a pair, or if a training image axis is not
            evenly divisible as required by the network.

        """
        ((isinstance(validation_data,(list,tuple)) and len(validation_data)==2)
            or _raise(ValueError('validation_data must be a pair of numpy arrays')))

        n_train, n_val = len(X), len(validation_data[0])
        frac_val = (1.0 * n_val) / (n_train + n_val)
        frac_warn = 0.05
        if frac_val < frac_warn:
            warnings.warn("small number of validation images (only %.1f%% of all images)" % (100*frac_val))
        axes = axes_check_and_normalize('S'+self.config.axes,X.ndim)
        ax = axes_dict(axes)

        # every axis must be divisible by the down/up-sampling factor of the network
        for a,div_by in zip(axes,self._axes_div_by(axes)):
            n = X.shape[ax[a]]
            if n % div_by != 0:
                raise ValueError(
                    "training images must be evenly divisible by %d along axis %s"
                    " (which has incompatible size %d)" % (div_by,a,n)
                )

        if epochs is None:
            epochs = self.config.train_epochs
        if steps_per_epoch is None:
            steps_per_epoch = self.config.train_steps_per_epoch

        if not self._model_prepared:
            self.prepare_for_training()

        if (self.config.train_tensorboard and self.basedir is not None and
            not IS_TF_1 and not any(isinstance(cb,CARETensorBoardImage) for cb in self.callbacks)):
            self.callbacks.append(CARETensorBoardImage(model=self.keras_model, data=validation_data,
                                                       log_dir=str(self.logdir/'logs'/'images'),
                                                       n_images=3, prob_out=self.config.probabilistic))

        training_data = train.DataWrapper(X, Y, self.config.train_batch_size, length=epochs*steps_per_epoch, augmenter=augmenter)

        fit = self.keras_model.fit_generator if IS_TF_1 else self.keras_model.fit
        history = fit(iter(training_data), validation_data=validation_data,
                      epochs=epochs, steps_per_epoch=steps_per_epoch,
                      callbacks=self.callbacks, verbose=1)
        self._training_finished()

        return history


    @suppress_without_basedir(warn=True)
    def export_TF(self, fname=None):
        """Export neural network via :func:`csbdeep.utils.tf.export_SavedModel`.

        Parameters
        ----------
        fname : str or None
            Path of the created SavedModel archive (will end with ".zip").
            If ``None``, "<model-directory>/TF_SavedModel.zip" will be used.

        """
        if fname is None:
            fname = self.logdir / 'TF_SavedModel.zip'
        else:
            fname = Path(fname)

        meta = {
            'type':          self.__class__.__name__,
            'version':       package_version,
            'probabilistic': self.config.probabilistic,
            'axes':          self.config.axes,
            'axes_div_by':   self._axes_div_by(self.config.axes),
            'tile_overlap':  self._axes_tile_overlap(self.config.axes),
        }
        export_SavedModel(self.keras_model, str(fname), meta=meta)
        print("\nModel exported in TensorFlow's SavedModel format:\n%s" % str(fname.resolve()))


    def predict(self, img, axes, normalizer=PercentileNormalizer(), resizer=PadAndCropResizer(), n_tiles=None):
        """Apply neural network to raw image to predict restored image.

        Parameters
        ----------
        img : :class:`numpy.ndarray`
            Raw input image
        axes : str
            Axes of the input ``img``.
        normalizer : :class:`csbdeep.data.Normalizer` or None
            Normalization of input image before prediction and (potentially) transformation back after prediction.
        resizer : :class:`csbdeep.data.Resizer` or None
            If necessary, input image is resized to enable neural network prediction and result is (possibly)
            resized to yield original image size.
        n_tiles : iterable or None
            Out of memory (OOM) errors can occur if the input image is too large.
            To avoid this problem, the input image is broken up into (overlapping) tiles
            that can then be processed independently and re-assembled to yield the restored image.
            This parameter denotes a tuple of the number of tiles for every image axis.
            Note that if the number of tiles is too low, it is adaptively increased until
            OOM errors are avoided, albeit at the expense of runtime.
            A value of ``None`` denotes that no tiling should initially be used.

        Returns
        -------
        :class:`numpy.ndarray`
            Returns the restored image. If the model is probabilistic, this denotes the `mean` parameter of
            the predicted per-pixel Laplace distributions (i.e., the expected restored image).
            Axes semantics are the same as in the input image. Only if the output is multi-channel and
            the input image didn't have a channel axis, then output channels are appended at the end.

        """
        return self._predict_mean_and_scale(img, axes, normalizer, resizer, n_tiles)[0]


    def predict_probabilistic(self, img, axes, normalizer=PercentileNormalizer(), resizer=PadAndCropResizer(), n_tiles=None):
        """Apply neural network to raw image to predict probability distribution for restored image.

        See :func:`predict` for parameter explanations.

        Returns
        -------
        :class:`csbdeep.internals.probability.ProbabilisticPrediction`
            Returns the probability distribution of the restored image.

        Raises
        ------
        ValueError
            If this is not a probabilistic model.

        """
        self.config.probabilistic or _raise(ValueError('This is not a probabilistic model.'))
        mean, scale = self._predict_mean_and_scale(img, axes, normalizer, resizer, n_tiles)
        return ProbabilisticPrediction(mean, scale)


    def _predict_mean_and_scale(self, img, axes, normalizer, resizer, n_tiles=None):
        """Apply neural network to raw image to predict restored image.

        See :func:`predict` for parameter explanations.

        Returns
        -------
        tuple(:class:`numpy.ndarray`, :class:`numpy.ndarray` or None)
            If model is probabilistic, returns a tuple `(mean, scale)` that defines the parameters
            of per-pixel Laplace distributions. Otherwise, returns the restored image via a tuple `(restored,None)`

        Raises
        ------
        ValueError
            For invalid ``n_tiles``, or if the image / network output channel counts do not
            match the configuration.
        MemoryError
            If prediction still runs out of memory after repeatedly increasing the number of tiles.

        """
        normalizer, resizer = self._check_normalizer_resizer(normalizer, resizer)
        # axes = axes_check_and_normalize(axes,img.ndim)

        # different kinds of axes
        # -> typical case: net_axes_in = net_axes_out, img_axes_in = img_axes_out
        img_axes_in = axes_check_and_normalize(axes,img.ndim)
        net_axes_in = self.config.axes
        net_axes_out = axes_check_and_normalize(self._axes_out)
        set(net_axes_out).issubset(set(net_axes_in)) or _raise(ValueError("different kinds of output than input axes"))
        net_axes_lost = set(net_axes_in).difference(set(net_axes_out))
        img_axes_out = ''.join(a for a in img_axes_in if a not in net_axes_lost)
        # print(' -> '.join((img_axes_in, net_axes_in, net_axes_out, img_axes_out)))
        tiling_axes = net_axes_out.replace('C','') # axes eligible for tiling

        _permute_axes = self._make_permute_axes(img_axes_in, net_axes_in, net_axes_out, img_axes_out)
        # _permute_axes: (img_axes_in -> net_axes_in), undo: (net_axes_out -> img_axes_out)
        x = _permute_axes(img)
        # x has net_axes_in semantics
        x_tiling_axis = tuple(axes_dict(net_axes_in)[a] for a in tiling_axes) # numerical axis ids for x

        channel_in = axes_dict(net_axes_in)['C']
        channel_out = axes_dict(net_axes_out)['C']
        net_axes_in_div_by = self._axes_div_by(net_axes_in)
        net_axes_in_overlaps = self._axes_tile_overlap(net_axes_in)
        if x.shape[channel_in] != self.config.n_channel_in:
            raise ValueError(
                "image has %d input channel(s), but the model expects %d (config.n_channel_in)"
                % (x.shape[channel_in], self.config.n_channel_in)
            )

        # TODO: refactor tiling stuff to make code more readable

        def _total_n_tiles(n_tiles):
            # reads the current x from the enclosing scope (after normalization/resizing)
            n_block_overlaps = [int(np.ceil(1.* tile_overlap / block_size)) for tile_overlap, block_size in zip(net_axes_in_overlaps, net_axes_in_div_by)]
            return total_n_tiles(x,n_tiles=n_tiles,block_sizes=net_axes_in_div_by,n_block_overlaps=n_block_overlaps,guarantee='size')

        _permute_axes_n_tiles = self._make_permute_axes(img_axes_in, net_axes_in)
        # _permute_axes_n_tiles: (img_axes_in <-> net_axes_in) to convert n_tiles between img and net axes
        def _permute_n_tiles(n,undo=False):
            # hack: move tiling axis around in the same way as the image was permuted by creating an array
            return _permute_axes_n_tiles(np.empty(n,bool),undo=undo).shape

        # to support old api: set scalar n_tiles value for the largest tiling axis
        if np.isscalar(n_tiles) and int(n_tiles)==n_tiles and 1<=n_tiles:
            largest_tiling_axis = [i for i in np.argsort(x.shape) if i in x_tiling_axis][-1]
            _n_tiles = [n_tiles if i==largest_tiling_axis else 1 for i in range(x.ndim)]
            n_tiles = _permute_n_tiles(_n_tiles,undo=True)
            warnings.warn("n_tiles should be a tuple with an entry for each image axis")
            print("Changing n_tiles to %s" % str(n_tiles))

        if n_tiles is None:
            n_tiles = [1]*img.ndim
        try:
            n_tiles = tuple(n_tiles)
            img.ndim == len(n_tiles) or _raise(TypeError())
        except TypeError:
            raise ValueError("n_tiles must be an iterable of length %d" % img.ndim)

        all(np.isscalar(t) and 1<=t and int(t)==t for t in n_tiles) or _raise(
            ValueError("all values of n_tiles must be integer values >= 1"))
        n_tiles = tuple(map(int,n_tiles))
        n_tiles = _permute_n_tiles(n_tiles)
        (all(n_tiles[i] == 1 for i in range(x.ndim) if i not in x_tiling_axis) or
            _raise(ValueError("entry of n_tiles > 1 only allowed for axes '%s'" % tiling_axes)))
        # n_tiles_limited = self._limit_tiling(x.shape,n_tiles,net_axes_in_div_by)
        # if any(np.array(n_tiles) != np.array(n_tiles_limited)):
        #     print("Limiting n_tiles to %s" % str(_permute_n_tiles(n_tiles_limited,undo=True)))
        # n_tiles = n_tiles_limited
        n_tiles = list(n_tiles)


        # normalize & resize
        x = normalizer.before(x, net_axes_in)
        x = resizer.before(x, net_axes_in, net_axes_in_div_by)

        done = False
        progress = Progress(_total_n_tiles(n_tiles),1)
        c = 0
        try:
            while not done:
                try:
                    # raise tf.errors.ResourceExhaustedError(None,None,None) # tmp
                    x = predict_tiled(self.keras_model,x,axes_in=net_axes_in,axes_out=net_axes_out,
                                      n_tiles=n_tiles,block_sizes=net_axes_in_div_by,tile_overlaps=net_axes_in_overlaps,pbar=progress)
                    # x has net_axes_out semantics
                    done = True
                except tf.errors.ResourceExhaustedError:
                    # TODO: how to test this code?
                    # n_tiles_prev = list(n_tiles) # make a copy
                    # double the number of tiles along the tiling axis with the largest tile size
                    tile_sizes_approx = np.array(x.shape) / np.array(n_tiles)
                    t = [i for i in np.argsort(tile_sizes_approx) if i in x_tiling_axis][-1]
                    n_tiles[t] *= 2
                    # n_tiles = self._limit_tiling(x.shape,n_tiles,net_axes_in_div_by)
                    # if all(np.array(n_tiles) == np.array(n_tiles_prev)):
                    #     raise MemoryError("Tile limit exceeded. Memory occupied by another process (notebook)?")
                    if c >= 8:
                        raise MemoryError("Giving up increasing number of tiles. Memory occupied by another process (notebook)?")
                    print('Out of memory, retrying with n_tiles = %s' % str(_permute_n_tiles(n_tiles,undo=True)))
                    progress.total = _total_n_tiles(n_tiles)
                    c += 1
        finally:
            # always release the progress bar, also when prediction fails
            progress.close()

        n_channel_predicted = self.config.n_channel_out * (2 if self.config.probabilistic else 1)
        if x.shape[channel_out] != n_channel_predicted:
            raise ValueError(
                "network produced %d output channel(s), but %d were expected from the configuration"
                % (x.shape[channel_out], n_channel_predicted)
            )

        x = resizer.after(x, net_axes_out)

        mean, scale = self._mean_and_scale_from_prediction(x,axis=channel_out)
        # mean and scale have net_axes_out semantics

        if normalizer.do_after and self.config.n_channel_in==self.config.n_channel_out:
            mean, scale = normalizer.after(mean, scale, net_axes_out)

        mean, scale = _permute_axes(mean,undo=True), _permute_axes(scale,undo=True)
        # mean and scale have img_axes_out semantics

        return mean, scale


    def _mean_and_scale_from_prediction(self,x,axis=-1):
        """Split a raw prediction into ``(mean, scale)``.

        For probabilistic models the first ``config.n_channel_out`` entries along ``axis``
        are the mean and the remaining ones the scale; otherwise returns ``(x, None)``.
        """
        # separate mean and scale
        if self.config.probabilistic:
            _n = self.config.n_channel_out
            assert x.shape[axis] == 2*_n
            slices = [slice(None) for _ in x.shape]
            slices[axis] = slice(None,_n)
            mean = x[tuple(slices)]
            slices[axis] = slice(_n,None)
            scale = x[tuple(slices)]
        else:
            mean, scale = x, None
        return mean, scale

    # def _limit_tiling(self,img_shape,n_tiles,block_sizes):
    #     img_shape, n_tiles, block_sizes = np.array(img_shape), np.array(n_tiles), np.array(block_sizes)
    #     n_tiles_limit = np.ceil(img_shape / block_sizes) # each tile must be at least one block in size
    #     return [int(t) for t in np.minimum(n_tiles,n_tiles_limit)]

    def _axes_div_by(self, query_axes):
        """Return, per axis in ``query_axes``, the size divisor required by the network (1 for non X/Y/Z/T axes)."""
        query_axes = axes_check_and_normalize(query_axes)
        # default: must be divisible by power of 2 to allow down/up-sampling steps in unet
        pool_div_by = 2**self.config.unet_n_depth
        return tuple((pool_div_by if a in 'XYZT' else 1) for a in query_axes)

    def _axes_tile_overlap(self, query_axes):
        """Return, per axis in ``query_axes``, the tile overlap used for tiled prediction (0 for non X/Y/Z/T axes)."""
        query_axes = axes_check_and_normalize(query_axes)
        overlap = tile_overlap(self.config.unet_n_depth, self.config.unet_kern_size)
        return tuple((overlap if a in 'XYZT' else 0) for a in query_axes)

    @property
    def _config_class(self):
        """Configuration class used by this model."""
        return Config
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
from __future__ import print_function, unicode_literals, absolute_import, division
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
from scipy.ndimage import zoom
|
|
5
|
+
|
|
6
|
+
from .care_standard import CARE
|
|
7
|
+
from ..data import PercentileNormalizer, PadAndCropResizer
|
|
8
|
+
from ..utils import _raise, axes_dict
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class UpsamplingCARE(CARE):
    """CARE network for combined image restoration and upsampling of one dimension.

    Extends :class:`csbdeep.models.CARE` by replacing prediction
    (:func:`predict`, :func:`predict_probabilistic`) to first upsample Z before image restoration.
    """

    def predict(self, img, axes, factor, normalizer=PercentileNormalizer(), resizer=PadAndCropResizer(), n_tiles=None):
        """Apply neural network to raw image with low-resolution Z axis.

        See :func:`CARE.predict` for documentation.

        Parameters
        ----------
        factor : float
            Upsampling factor for Z axis. It is important that this is chosen in correspondence
            to the subsampling factor used during training data generation.

        Raises
        ------
        ValueError
            If ``axes`` does not contain a Z axis.

        """
        img = self._upsample(img, axes, factor)
        return super(UpsamplingCARE, self).predict(img, axes, normalizer, resizer, n_tiles)


    def predict_probabilistic(self, img, axes, factor, normalizer=PercentileNormalizer(), resizer=PadAndCropResizer(), n_tiles=None):
        """Apply neural network to raw image with low-resolution Z axis for probabilistic prediction.

        See :func:`CARE.predict_probabilistic` for documentation.

        Parameters
        ----------
        factor : float
            Upsampling factor for Z axis. It is important that this is chosen in correspondence
            to the subsampling factor used during training data generation.

        Raises
        ------
        ValueError
            If ``axes`` does not contain a Z axis, or if this is not a probabilistic model.

        """
        img = self._upsample(img, axes, factor)
        return super(UpsamplingCARE, self).predict_probabilistic(img, axes, normalizer, resizer, n_tiles)


    @staticmethod
    def _upsample(img, axes, factor, axis='Z'):
        """Zoom ``img`` by ``factor`` along ``axis`` (linear interpolation); other axes are unchanged.

        Raises
        ------
        ValueError
            If ``axis`` is not one of the image ``axes``.
        """
        axis_index = axes_dict(axes)[axis]
        # axes_dict presumably maps axes missing from ``axes`` to None (csbdeep.utils; verify).
        # Indexing with None would broadcast the assignment below and zoom *every* axis.
        axis_index is not None or _raise(ValueError("image axes '%s' do not contain axis '%s' to upsample" % (axes, axis)))
        factors = np.ones(img.ndim)
        factors[axis_index] = factor
        return zoom(img,factors,order=1)
|