senoquant 1.0.0b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- senoquant/__init__.py +6 -0
- senoquant/_reader.py +7 -0
- senoquant/_widget.py +33 -0
- senoquant/napari.yaml +83 -0
- senoquant/reader/__init__.py +5 -0
- senoquant/reader/core.py +369 -0
- senoquant/tabs/__init__.py +15 -0
- senoquant/tabs/batch/__init__.py +10 -0
- senoquant/tabs/batch/backend.py +641 -0
- senoquant/tabs/batch/config.py +270 -0
- senoquant/tabs/batch/frontend.py +1283 -0
- senoquant/tabs/batch/io.py +326 -0
- senoquant/tabs/batch/layers.py +86 -0
- senoquant/tabs/quantification/__init__.py +1 -0
- senoquant/tabs/quantification/backend.py +228 -0
- senoquant/tabs/quantification/features/__init__.py +80 -0
- senoquant/tabs/quantification/features/base.py +142 -0
- senoquant/tabs/quantification/features/marker/__init__.py +5 -0
- senoquant/tabs/quantification/features/marker/config.py +69 -0
- senoquant/tabs/quantification/features/marker/dialog.py +437 -0
- senoquant/tabs/quantification/features/marker/export.py +879 -0
- senoquant/tabs/quantification/features/marker/feature.py +119 -0
- senoquant/tabs/quantification/features/marker/morphology.py +285 -0
- senoquant/tabs/quantification/features/marker/rows.py +654 -0
- senoquant/tabs/quantification/features/marker/thresholding.py +46 -0
- senoquant/tabs/quantification/features/roi.py +346 -0
- senoquant/tabs/quantification/features/spots/__init__.py +5 -0
- senoquant/tabs/quantification/features/spots/config.py +62 -0
- senoquant/tabs/quantification/features/spots/dialog.py +477 -0
- senoquant/tabs/quantification/features/spots/export.py +1292 -0
- senoquant/tabs/quantification/features/spots/feature.py +112 -0
- senoquant/tabs/quantification/features/spots/morphology.py +279 -0
- senoquant/tabs/quantification/features/spots/rows.py +241 -0
- senoquant/tabs/quantification/frontend.py +815 -0
- senoquant/tabs/segmentation/__init__.py +1 -0
- senoquant/tabs/segmentation/backend.py +131 -0
- senoquant/tabs/segmentation/frontend.py +1009 -0
- senoquant/tabs/segmentation/models/__init__.py +5 -0
- senoquant/tabs/segmentation/models/base.py +146 -0
- senoquant/tabs/segmentation/models/cpsam/details.json +65 -0
- senoquant/tabs/segmentation/models/cpsam/model.py +150 -0
- senoquant/tabs/segmentation/models/default_2d/details.json +69 -0
- senoquant/tabs/segmentation/models/default_2d/model.py +664 -0
- senoquant/tabs/segmentation/models/default_3d/details.json +69 -0
- senoquant/tabs/segmentation/models/default_3d/model.py +682 -0
- senoquant/tabs/segmentation/models/hf.py +71 -0
- senoquant/tabs/segmentation/models/nuclear_dilation/__init__.py +1 -0
- senoquant/tabs/segmentation/models/nuclear_dilation/details.json +26 -0
- senoquant/tabs/segmentation/models/nuclear_dilation/model.py +96 -0
- senoquant/tabs/segmentation/models/perinuclear_rings/__init__.py +1 -0
- senoquant/tabs/segmentation/models/perinuclear_rings/details.json +34 -0
- senoquant/tabs/segmentation/models/perinuclear_rings/model.py +132 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/__init__.py +2 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/__init__.py +3 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/__init__.py +6 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/generate.py +470 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/prepare.py +273 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/rawdata.py +112 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/transform.py +384 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/__init__.py +0 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/blocks.py +184 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/losses.py +79 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/nets.py +165 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/predict.py +467 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/probability.py +67 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/train.py +148 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/io/__init__.py +163 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/__init__.py +52 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/base_model.py +329 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_isotropic.py +160 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_projection.py +178 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_standard.py +446 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_upsampling.py +54 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/config.py +254 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/pretrained.py +119 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/scripts/__init__.py +0 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/scripts/care_predict.py +180 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/__init__.py +5 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/plot_utils.py +159 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/six.py +18 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/tf.py +644 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/utils.py +272 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/version.py +1 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/docs/source/conf.py +368 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/setup.py +68 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_datagen.py +169 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_models.py +462 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_utils.py +166 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tools/create_zip_contents.py +34 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/__init__.py +30 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/big.py +624 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/bioimageio_utils.py +494 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/data/__init__.py +39 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/__init__.py +10 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/geom2d.py +215 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/geom3d.py +349 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/matching.py +483 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/__init__.py +28 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/base.py +1217 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/model2d.py +594 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/model3d.py +696 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/nms.py +384 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/__init__.py +2 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/plot.py +74 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/render.py +298 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/rays3d.py +373 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/sample_patches.py +65 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/__init__.py +0 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/predict2d.py +90 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/predict3d.py +93 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/utils.py +408 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/version.py +1 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/__init__.py +45 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/__init__.py +17 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/cli.py +55 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/core.py +285 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/__init__.py +15 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/cli.py +36 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/divisibility.py +193 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/probe.py +100 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/receptive_field.py +182 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/rf_cli.py +48 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/valid_sizes.py +278 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/post/__init__.py +8 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/post/core.py +157 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/pre/__init__.py +17 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/pre/core.py +226 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/predict/__init__.py +5 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/predict/core.py +401 -0
- senoquant/tabs/settings/__init__.py +1 -0
- senoquant/tabs/settings/backend.py +29 -0
- senoquant/tabs/settings/frontend.py +19 -0
- senoquant/tabs/spots/__init__.py +1 -0
- senoquant/tabs/spots/backend.py +139 -0
- senoquant/tabs/spots/frontend.py +800 -0
- senoquant/tabs/spots/models/__init__.py +5 -0
- senoquant/tabs/spots/models/base.py +94 -0
- senoquant/tabs/spots/models/rmp/details.json +61 -0
- senoquant/tabs/spots/models/rmp/model.py +499 -0
- senoquant/tabs/spots/models/udwt/details.json +103 -0
- senoquant/tabs/spots/models/udwt/model.py +482 -0
- senoquant/utils.py +25 -0
- senoquant-1.0.0b1.dist-info/METADATA +193 -0
- senoquant-1.0.0b1.dist-info/RECORD +148 -0
- senoquant-1.0.0b1.dist-info/WHEEL +5 -0
- senoquant-1.0.0b1.dist-info/entry_points.txt +2 -0
- senoquant-1.0.0b1.dist-info/licenses/LICENSE +28 -0
- senoquant-1.0.0b1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,254 @@
|
|
|
1
|
+
from __future__ import print_function, unicode_literals, absolute_import, division
|
|
2
|
+
from six.moves import range, zip, map, reduce, filter
|
|
3
|
+
from six import string_types
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import argparse
|
|
7
|
+
import warnings
|
|
8
|
+
|
|
9
|
+
from packaging.version import Version
|
|
10
|
+
|
|
11
|
+
from ..utils.tf import keras_import, BACKEND as K
|
|
12
|
+
keras = keras_import()
|
|
13
|
+
|
|
14
|
+
from ..utils import _raise, axes_check_and_normalize, axes_dict, backend_channels_last
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class BaseConfig(argparse.Namespace):
    """Common base for model configurations.

    Normalizes the axes string, validates basic axes constraints, places the
    channel axis where the Keras backend expects it and stores the channel
    counts and checkpoint file names. Subclasses add model-specific defaults
    and a meaningful :meth:`is_valid`.
    """

    def __init__(self, axes='YX', n_channel_in=1, n_channel_out=1, allow_new_parameters=False, **kwargs):
        # normalize the axes string and record which axes are present
        axes = axes_check_and_normalize(axes)
        positions = axes_dict(axes)
        present = {name: (positions[name] is not None) for name in positions}

        if not (present['X'] and present['Y']):
            _raise(ValueError('lateral axes X and Y must be present.'))
        # NOTE: a "Z and T together" restriction is left to subclasses (see Config)
        if present['S'] and not axes.startswith('S'):
            _raise(ValueError('sample axis S must be first.'))

        # the network never sees the sample axis
        axes = axes.replace('S', '')
        # number of non-channel axes
        n_dim = len(axes) - axes.count('C')

        # TODO (original): the config is not independent of the backend; data could
        # instead be rearranged during train/predict. Users choose the axes of their
        # input images anyway, so a backend-specific model layout is acceptable.
        channels_last = backend_channels_last()
        if not present['C']:
            # add the channel axis at the position the backend expects
            axes = axes + 'C' if channels_last else 'C' + axes
        elif channels_last:
            if axes[-1] != 'C':
                _raise(ValueError('channel axis must be last for backend (%s).' % K.backend()))
        elif axes[0] != 'C':
            _raise(ValueError('channel axis must be first for backend (%s).' % K.backend()))

        self.n_dim = n_dim
        self.axes = axes
        self.n_channel_in = int(max(1, n_channel_in))
        self.n_channel_out = int(max(1, n_channel_out))

        # checkpoint file names used during training
        self.train_checkpoint = 'weights_best.h5'
        self.train_checkpoint_last = 'weights_last.h5'
        self.train_checkpoint_epoch = 'weights_now.h5'

        self.update_parameters(allow_new_parameters, **kwargs)

    def is_valid(self, return_invalid=False):
        """Accept every configuration; subclasses override this check.

        Returns ``True``, or ``(True, ())`` when ``return_invalid`` is set.
        """
        if return_invalid:
            return True, ()
        return True

    def update_parameters(self, allow_new=False, **kwargs):
        """Set configuration attributes from keyword arguments.

        Parameters
        ----------
        allow_new : bool
            If ``False``, every keyword must name an attribute that already exists.
        kwargs : dict
            Attribute names and their new values.

        Raises
        ------
        AttributeError
            If ``allow_new`` is ``False`` and unknown attributes are given; in
            that case no attribute is modified.
        """
        if not allow_new:
            unknown = [name for name in kwargs if not hasattr(self, name)]
            if unknown:
                raise AttributeError("Not allowed to add new parameters (%s)" % ', '.join(unknown))
        for name, value in kwargs.items():
            setattr(self, name, value)

    @classmethod
    def update_loaded_config(cls, config):
        """Hook for models to adjust a loaded config dictionary before the config object is created.

        Subclasses can modify, introduce or delete parameters here, e.g. to stay
        backwards compatible after new parameters were introduced. The base
        implementation returns ``config`` unchanged.

        Parameters
        ----------
        config : dict
            Dictionary of config parameters loaded from file.

        Returns
        -------
        updated_config: dict
            An updated version of the config parameter dictionary.
        """
        return config
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
class Config(BaseConfig):
    """Default configuration for a CARE model.

    This configuration is meant to be used with :class:`CARE`
    and related models (e.g., :class:`IsotropicCARE`).

    Parameters
    ----------
    axes : str
        Axes of the neural network (channel axis optional).
    n_channel_in : int
        Number of channels of given input image.
    n_channel_out : int
        Number of channels of predicted output image.
    probabilistic : bool
        Probabilistic prediction of per-pixel Laplace distributions or
        typical regression of per-pixel scalar values.
    allow_new_parameters : bool
        Allow adding new configuration attributes (i.e. not listed below).
    kwargs : dict
        Overwrite (or add) configuration attributes (see below).
        A given ``n_dim`` is ignored since it is derived from ``axes``.

    Raises
    ------
    ValueError
        If ``axes`` is rejected by :class:`BaseConfig` or contains both
        ``Z`` and ``T``.
    AttributeError
        If ``kwargs`` contains unknown attributes and
        ``allow_new_parameters`` is ``False``.

    Example
    -------
    >>> config = Config('YX', probabilistic=True, unet_n_depth=3)

    Attributes
    ----------
    n_dim : int
        Dimensionality of input images (2 or 3).
    unet_residual : bool
        Parameter `residual` of :func:`csbdeep.nets.common_unet`. Default: ``n_channel_in == n_channel_out``
    unet_n_depth : int
        Parameter `n_depth` of :func:`csbdeep.nets.common_unet`. Default: ``2``
    unet_kern_size : int
        Parameter `kern_size` of :func:`csbdeep.nets.common_unet`. Default: ``5 if n_dim==2 else 3``
    unet_n_first : int
        Parameter `n_first` of :func:`csbdeep.nets.common_unet`. Default: ``32``
    unet_last_activation : str
        Parameter `last_activation` of :func:`csbdeep.nets.common_unet`. Default: ``linear``
    unet_input_shape : tuple
        Network input shape without the batch dimension. Default:
        ``n_dim*(None,) + (n_channel_in,)`` for channels-last backends,
        ``(n_channel_in,) + n_dim*(None,)`` otherwise.
    train_loss : str
        Name of training loss. Default: ``'laplace' if probabilistic else 'mae'``
    train_epochs : int
        Number of training epochs. Default: ``100``
    train_steps_per_epoch : int
        Number of parameter update steps per epoch. Default: ``400``
    train_learning_rate : float
        Learning rate for training. Default: ``0.0004``
    train_batch_size : int
        Batch size for training. Default: ``16``
    train_tensorboard : bool
        Enable TensorBoard for monitoring training progress. Default: ``True``
    train_checkpoint : str
        Name of checkpoint file for model weights (only best are saved); set to ``None`` to disable. Default: ``weights_best.h5``
    train_reduce_lr : dict
        Parameter :class:`dict` of ReduceLROnPlateau_ callback; set to ``None`` to disable. Default: ``{'factor': 0.5, 'patience': 10, 'min_delta': 0}``

        .. _ReduceLROnPlateau: https://keras.io/callbacks/#reducelronplateau
    """

    def __init__(self, axes='YX', n_channel_in=1, n_channel_out=1, probabilistic=False, allow_new_parameters=False, **kwargs):
        """See class docstring."""

        super(Config, self).__init__(axes, n_channel_in, n_channel_out)
        if 'Z' in self.axes and 'T' in self.axes:
            _raise(ValueError('using Z and T axes together not supported.'))

        self.probabilistic = bool(probabilistic)

        # default config (can be overwritten by kwargs below)
        self.unet_residual = self.n_channel_in == self.n_channel_out
        self.unet_n_depth = 2
        self.unet_kern_size = 5 if self.n_dim==2 else 3
        self.unet_n_first = 32
        self.unet_last_activation = 'linear'
        # channel dimension position follows the backend's data format
        if backend_channels_last():
            self.unet_input_shape = self.n_dim*(None,) + (self.n_channel_in,)
        else:
            self.unet_input_shape = (self.n_channel_in,) + self.n_dim*(None,)

        self.train_loss = 'laplace' if self.probabilistic else 'mae'
        self.train_epochs = 100
        self.train_steps_per_epoch = 400
        self.train_learning_rate = 0.0004
        self.train_batch_size = 16
        self.train_tensorboard = True

        # the parameter 'min_delta' was called 'epsilon' for keras<=2.1.5
        # keras.__version__ was removed in tensorflow 2.13.0
        min_delta_key = 'epsilon' if Version(getattr(keras, '__version__', '9.9.9'))<=Version('2.1.5') else 'min_delta'
        self.train_reduce_lr = {'factor': 0.5, 'patience': 10, min_delta_key: 0}

        # 'n_dim' is derived from 'axes' and must not be set manually
        kwargs.pop('n_dim', None)

        self.update_parameters(allow_new_parameters, **kwargs)

    def is_valid(self, return_invalid=False):
        """Check if configuration is valid.

        Parameters
        ----------
        return_invalid : bool
            If ``True``, additionally return the names of all invalid parameters.

        Returns
        -------
        bool or (bool, tuple of str)
            Flag that indicates whether the current configuration values are valid,
            followed by the names of invalid parameters if ``return_invalid`` is set.
        """
        def _is_int(v,low=None,high=None):
            return (
                isinstance(v,int) and
                (True if low is None else low <= v) and
                (True if high is None else v <= high)
            )

        def _is_positive_scalar(v):
            # np.isscalar is also True for strings and complex numbers, whose
            # comparison with 0 raises TypeError: report those as invalid instead
            try:
                return bool(np.isscalar(v) and v > 0)
            except TypeError:
                return False

        ok = {}
        ok['n_dim'] = self.n_dim in (2,3)
        try:
            axes_check_and_normalize(self.axes,self.n_dim+1,disallowed='S')
            ok['axes'] = True
        except Exception:
            ok['axes'] = False
        ok['n_channel_in'] = _is_int(self.n_channel_in,1)
        ok['n_channel_out'] = _is_int(self.n_channel_out,1)
        ok['probabilistic'] = isinstance(self.probabilistic,bool)

        ok['unet_residual'] = (
            isinstance(self.unet_residual,bool) and
            (not self.unet_residual or (self.n_channel_in==self.n_channel_out))
        )
        ok['unet_n_depth'] = _is_int(self.unet_n_depth,1)
        ok['unet_kern_size'] = _is_int(self.unet_kern_size,1)
        ok['unet_n_first'] = _is_int(self.unet_n_first,1)
        ok['unet_last_activation'] = self.unet_last_activation in ('linear','relu')
        # __init__ puts the channel dimension last for channels-last backends and
        # first otherwise, so the check must look at the matching position
        channel_index = -1 if backend_channels_last() else 0
        ok['unet_input_shape'] = (
            isinstance(self.unet_input_shape,(list,tuple))
            and len(self.unet_input_shape) == self.n_dim+1
            and self.unet_input_shape[channel_index] == self.n_channel_in
        )
        ok['train_loss'] = (
            ( self.probabilistic and self.train_loss == 'laplace' ) or
            (not self.probabilistic and self.train_loss in ('mse','mae'))
        )
        ok['train_epochs'] = _is_int(self.train_epochs,1)
        ok['train_steps_per_epoch'] = _is_int(self.train_steps_per_epoch,1)
        ok['train_learning_rate'] = _is_positive_scalar(self.train_learning_rate)
        ok['train_batch_size'] = _is_int(self.train_batch_size,1)
        ok['train_tensorboard'] = isinstance(self.train_tensorboard,bool)
        ok['train_checkpoint'] = self.train_checkpoint is None or isinstance(self.train_checkpoint,string_types)
        ok['train_reduce_lr'] = self.train_reduce_lr is None or isinstance(self.train_reduce_lr,dict)

        if return_invalid:
            return all(ok.values()), tuple(k for (k,v) in ok.items() if not v)
        else:
            return all(ok.values())
|
|
@@ -0,0 +1,119 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
from __future__ import print_function, unicode_literals, absolute_import, division
|
|
3
|
+
|
|
4
|
+
from collections import OrderedDict
|
|
5
|
+
from warnings import warn
|
|
6
|
+
from ..utils import _raise
|
|
7
|
+
from ..utils.six import Path
|
|
8
|
+
|
|
9
|
+
from packaging.version import Version
|
|
10
|
+
from ..utils.tf import keras_import, v_keras
|
|
11
|
+
get_file = keras_import('utils', 'get_file')
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
# Registries of pretrained models, keyed by model class:
#   _MODELS[cls]  -> OrderedDict mapping model key -> dict(url=..., hash=...)
#   _ALIASES[cls] -> OrderedDict mapping alias -> model key
_MODELS = dict()
_ALIASES = dict()
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def clear_models_and_aliases(*cls):
    """Remove registered models and aliases.

    With no arguments, both registries are emptied entirely; otherwise only
    the entries of the given model classes are removed (unknown classes are
    ignored).
    """
    if not cls:
        _MODELS.clear()
        _ALIASES.clear()
        return
    for model_class in cls:
        _MODELS.pop(model_class, None)
        _ALIASES.pop(model_class, None)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def register_model(cls, key, url, hash):
    """Register a downloadable pretrained model for a model class.

    ``key`` must be a valid file/folder name, since it is used as a folder
    name in the download cache. Re-registering an existing key emits a
    warning and overwrites the previous entry.

    Example::

        register_model(StarDist2D, 'my_great_model', 'https://github.com/stardist/stardist-models/releases/download/v0.1/python_2D_versatile_fluo.zip', md5sum_as_astring)

    """
    registry = _MODELS.setdefault(cls, OrderedDict())
    if key in registry:
        warn("re-registering model '%s' (was already registered for '%s')" % (key, cls.__name__))
    registry[key] = {'url': url, 'hash': hash}
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def register_aliases(cls, key, *names):
    """Register alternative names (arbitrary strings) for a registered model.

    Does nothing if no names are given. Re-pointing an existing alias to a
    different model emits a warning.
    """
    if not names:
        return
    if key not in _MODELS.get(cls, {}):
        _raise(ValueError("model '%s' is not registered for '%s'" % (key, cls.__name__)))
    registry = _ALIASES.setdefault(cls, OrderedDict())
    for name in names:
        previous = registry.get(name, key)
        if previous != key:
            warn("alias '%s' was previously registered with model '%s' for '%s'" % (name, previous, cls.__name__))
        registry[name] = key
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def _print_registered_models(cls, model_keys, model_aliases):
    """Print a table of the registered models of ``cls`` and their aliases."""
    n = len(model_keys)
    print("There {is_are} {n} registered {model_s} for '{clazz}'{c}".format(
        n=n, clazz=cls.__name__,
        is_are='is' if n == 1 else 'are',
        model_s='model' if n == 1 else 'models',
        c=':' if n > 0 else ''))
    if n == 0:
        return
    print()
    # one gap used for both header and rows keeps the alias column aligned
    gap = ' ' * 3
    width = 2 + max(len(key) for key in model_keys)  # width of a quoted key
    header_pad = ' ' * (width - 4 + len(gap))         # 4 == len("Name")
    print("Name{s}Alias(es)".format(s=header_pad))
    print("────{s}─────────".format(s=header_pad))
    for key in model_keys:
        names = model_aliases[key]
        alias_str = ("'%s'" % "', '".join(names)) if names else "None"
        print("{key}{gap}{aliases}".format(key=("'%s'" % key).ljust(width), gap=gap, aliases=alias_str))


def get_registered_models(cls, return_aliases=True, verbose=False):
    """Return the keys (and aliases) of all models registered for ``cls``.

    Parameters
    ----------
    cls : type
        Model class to query.
    return_aliases : bool
        If ``True``, also return a dict mapping each key to a tuple of its aliases.
    verbose : bool
        If ``True``, print a table of models and aliases.

    Returns
    -------
    tuple or (tuple, dict)
        Registered model keys in registration order, optionally followed by
        the key -> aliases mapping.
    """
    models = _MODELS.get(cls, {})
    aliases = _ALIASES.get(cls, {})
    model_keys = tuple(models.keys())
    model_aliases = {key: tuple(name for name in aliases if aliases[name] == key) for key in models}
    if verbose:
        _print_registered_models(cls, model_keys, model_aliases)
    return (model_keys, model_aliases) if return_aliases else model_keys
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def get_model_details(cls, key_or_alias, verbose=False):
    """Resolve a model key or alias registered for ``cls``.

    Returns
    -------
    (key, alias, details)
        The model key, the alias used (``None`` if a key was given) and the
        registered details dict (``url`` and ``hash``).
    """
    models = _MODELS.get(cls, {})
    key, alias = key_or_alias, None
    if key_or_alias not in models:
        # not a key: must be an alias
        aliases = _ALIASES.get(cls, {})
        alias = key_or_alias
        if alias not in aliases:
            _raise(ValueError("'%s' is neither a key or alias for '%s'" % (alias, cls.__name__)))
        key = aliases[alias]
    if verbose:
        alias_note = '' if alias is None else " with alias '%s'" % alias
        print("Found model '{model}'{alias_str} for '{clazz}'.".format(
            model=key, clazz=cls.__name__, alias_str=alias_note))
    return key, alias, models[key]
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def get_model_folder(cls, key_or_alias):
    """Download (if necessary) a registered model and return its folder.

    Parameters
    ----------
    cls : type
        Model class the model was registered for.
    key_or_alias : str
        Registered model key or one of its aliases.

    Returns
    -------
    Path
        Folder that is used as the model directory.

    Raises
    ------
    ValueError
        If ``key_or_alias`` is not registered for ``cls`` (via :func:`get_model_details`).
    FileNotFoundError
        If the resolved model folder does not exist.
    """
    key, alias, m = get_model_details(cls, key_or_alias)
    target = str(Path('models') / cls.__name__ / key)
    path = Path(get_file(fname=key+'.zip', origin=m['url'], file_hash=m['hash'],
                         cache_subdir=target, extract=True))
    if v_keras >= Version("3.6.0"):
        # the returned path is used directly; if it is a directory named
        # '<name>_extracted', expose it as '<name>' through a relative symlink
        path_folder = path
        suffix = "_extracted"
        if path.is_dir() and path.name.endswith(suffix) and len(path.name) > len(suffix):
            link = path.with_name(path.name[:-len(suffix)])
            if not link.exists():
                try:
                    link.symlink_to(path.relative_to(path.parent))
                except OSError:
                    # best effort: symlinks may be unsupported/unprivileged (e.g. on
                    # Windows), created concurrently, or blocked by a dangling link;
                    # the extraction folder itself is used as fallback below
                    pass
            if link.exists():
                path_folder = link
    else:
        # older keras: the model folder is the parent of the returned archive path
        path_folder = path.parent
    if not path_folder.exists():
        raise FileNotFoundError("model folder '%s' does not exist" % path_folder)
    return path_folder
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def get_model_instance(cls, key_or_alias):
    """Create an instance of ``cls`` from a registered pretrained model.

    The model folder is obtained via :func:`get_model_folder`; the instance is
    constructed with ``config=None`` and afterwards gets ``basedir = None``.

    Parameters
    ----------
    cls : type
        Model class the model was registered for.
    key_or_alias : str
        Registered model key or one of its aliases.

    Returns
    -------
    cls
        The model instance.
    """
    path = get_model_folder(cls, key_or_alias)
    # use the full folder name so that basedir/name == path exactly
    # (Path.stem would drop everything after a '.' in the folder name)
    model = cls(config=None, name=path.name, basedir=path.parent)
    model.basedir = None # make read-only
    return model
|
|
File without changes
|
|
@@ -0,0 +1,180 @@
|
|
|
1
|
+
#!/usr/bin/env python
|
|
2
|
+
|
|
3
|
+
from __future__ import print_function, unicode_literals, absolute_import, division
|
|
4
|
+
|
|
5
|
+
import argparse
|
|
6
|
+
import sys
|
|
7
|
+
from pprint import pprint
|
|
8
|
+
|
|
9
|
+
import numpy as np
|
|
10
|
+
from tqdm import tqdm
|
|
11
|
+
|
|
12
|
+
from csbdeep.io import save_tiff_imagej_compatible
|
|
13
|
+
from csbdeep.utils import _raise, axes_check_and_normalize
|
|
14
|
+
from csbdeep.utils.six import Path
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def str2bool(v):
    """Convert a command-line string to a bool (argparse ``type=`` callback).

    Accepts yes/true/t/y/1 and no/false/f/n/0 (case-insensitive); anything
    else raises :class:`argparse.ArgumentTypeError`.
    """
    normalized = v.lower()
    if normalized in ('yes', 'true', 't', 'y', '1'):
        return True
    if normalized in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def parse_args():
|
|
27
|
+
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
|
28
|
+
|
|
29
|
+
parser.add_argument('--quiet', metavar='', type=str2bool, required=False, const=True, nargs='?', default=False, help="don't show status messages")
|
|
30
|
+
parser.add_argument('--gpu-memory-limit', metavar='', type=float, required=False, default=None, help="limit GPU memory to this fraction (0...1)")
|
|
31
|
+
|
|
32
|
+
data = parser.add_argument_group("input")
|
|
33
|
+
data.add_argument('--input-dir', metavar='', type=str, required=False, default=None, help="path to folder with input images")
|
|
34
|
+
data.add_argument('--input-pattern', metavar='', type=str, required=False, default='*.tif*', help="glob-style file name pattern of input images")
|
|
35
|
+
data.add_argument('--input-axes', metavar='', type=str, required=False, default=None, help="axes string of input images")
|
|
36
|
+
data.add_argument('--norm-pmin', metavar='', type=float, required=False, default=2, help="'pmin' for PercentileNormalizer")
|
|
37
|
+
data.add_argument('--norm-pmax', metavar='', type=float, required=False, default=99.8, help="'pmax' for PercentileNormalizer")
|
|
38
|
+
data.add_argument('--norm-undo', metavar='', type=str2bool, required=False, const=True, nargs='?', default=True, help="'do_after' for PercentileNormalizer")
|
|
39
|
+
data.add_argument('--n-tiles', metavar='', type=int, required=False, nargs='+', default=None, help="number of tiles for prediction")
|
|
40
|
+
|
|
41
|
+
model = parser.add_argument_group("model")
|
|
42
|
+
model.add_argument('--model-basedir', metavar='', type=str, required=False, default=None, help="path to folder that contains CARE model")
|
|
43
|
+
model.add_argument('--model-name', metavar='', type=str, required=False, default=None, help="name of CARE model")
|
|
44
|
+
model.add_argument('--model-weights', metavar='', type=str, required=False, default=None, help="specific name of weights file to load (located in model folder)")
|
|
45
|
+
|
|
46
|
+
output = parser.add_argument_group("output")
|
|
47
|
+
output.add_argument('--output-dir', metavar='', type=str, required=False, default=None, help="path to folder where restored images will be saved")
|
|
48
|
+
output.add_argument('--output-name', metavar='', type=str, required=False, default='{model_name}/{file_path}/{file_name}{file_ext}', help="name pattern of restored image (special tokens: {file_path}, {file_name}, {file_ext}, {model_name}, {model_weights})")
|
|
49
|
+
output.add_argument('--output-dtype', metavar='', type=str, required=False, default='float32', help="data type of the saved tiff file")
|
|
50
|
+
output.add_argument('--imagej-tiff', metavar='', type=str2bool, required=False, const=True, nargs='?', default=True, help="save restored image as ImageJ-compatible TIFF file")
|
|
51
|
+
output.add_argument('--dry-run', metavar='', type=str2bool, required=False, const=True, nargs='?', default=False, help="don't save restored images")
|
|
52
|
+
|
|
53
|
+
return parser, parser.parse_args()
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def main():
    """Command-line entry point: restore a folder of TIFF images with a CARE model.

    Parses the command line via ``parse_args()`` and checks the options that
    must be given. It then loads the CARE model ``--model-name`` from
    ``--model-basedir``, optionally with the specific weights file
    ``--model-weights``. Every file in ``--input-dir`` that matches
    ``--input-pattern`` is restored. Results are written below ``--output-dir``
    using the ``--output-name`` pattern, unless ``--dry-run`` is set.

    Exits with status 0 in three cases:

    - the function was defined interactively (no ``__file__``);
    - it was called without any arguments (help is printed);
    - no input files match.

    Exits with status 1 if a required option is missing.

    Raises:
        ValueError: if an input or output file is not a ``.tif``/``.tiff``
            file, or if the restored image's extra dimension cannot be
            interpreted as a new channel axis.
    """
    # A module loaded from a file has ``__file__`` in its globals; code defined
    # in an interactive session does not. (The function's own ``locals()`` can
    # never contain ``__file__``, so only the globals need checking.)
    if '__file__' not in globals():
        print('running interactively, exiting.')
        sys.exit(0)

    # parse arguments
    parser, args = parse_args()
    args_dict = vars(args)

    # exit and show help if no arguments provided at all
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(0)

    # check for required arguments manually (because of argparse issue)
    required = ('--input-dir', '--input-axes', '--norm-pmin', '--norm-pmax',
                '--model-basedir', '--model-name', '--output-dir')
    for option in required:
        dest = option[2:].replace('-', '_')
        if args_dict[dest] is None:
            parser.print_usage(file=sys.stderr)
            print("%s: error: the following arguments are required: %s" % (parser.prog, option), file=sys.stderr)
            sys.exit(1)

    # show effective arguments (including defaults)
    if not args.quiet:
        print('Arguments')
        print('---------')
        pprint(args_dict)
        print()
        sys.stdout.flush()

    # logging function: no-op when --quiet, otherwise tqdm.write so messages
    # are printed without breaking an active progress bar
    log = (lambda *a, **k: None) if args.quiet else tqdm.write

    # get list of input files and exit if there are none
    file_list = list(Path(args.input_dir).glob(args.input_pattern))
    if not file_list:
        log("No files to process in '%s' with pattern '%s'." % (args.input_dir, args.input_pattern))
        sys.exit(0)

    # delay imports until all required arguments are known to be provided
    # (presumably to avoid loading the deep-learning stack on early exits)
    from tifffile import imread
    try:
        from tifffile import imwrite as imsave
    except ImportError:
        # fallback for tifffile versions without ``imwrite``
        from tifffile import imsave
    from csbdeep.utils.tf import BACKEND as K
    from csbdeep.models import CARE
    from csbdeep.data import PercentileNormalizer
    sys.stdout.flush()
    sys.stderr.flush()

    # limit gpu memory
    if args.gpu_memory_limit is not None:
        from csbdeep.utils.tf import limit_gpu_memory
        limit_gpu_memory(args.gpu_memory_limit)

    # create CARE model and load weights, create normalizer
    K.clear_session()
    model = CARE(config=None, name=args.model_name, basedir=args.model_basedir)
    if args.model_weights is not None:
        # go through ``log`` (not ``print``) so that --quiet also silences this
        log("Loading network weights from '%s'." % args.model_weights)
        model.load_weights(args.model_weights)
    normalizer = PercentileNormalizer(pmin=args.norm_pmin, pmax=args.norm_pmax, do_after=args.norm_undo)

    # a single --n-tiles value is unwrapped to a scalar before passing it on
    n_tiles = args.n_tiles
    if n_tiles is not None and len(n_tiles) == 1:
        n_tiles = n_tiles[0]

    # values that are identical for every input file
    model_weights_stem = Path(args.model_weights).stem if args.model_weights is not None else None
    axes_in = axes_check_and_normalize(args.input_axes)

    processed = []

    # process all files; the per-file progress bar is hidden when tiling is
    # used (NOTE(review): presumably because model.predict shows its own bar)
    for file_in in tqdm(file_list, disable=args.quiet or (n_tiles is not None and np.prod(n_tiles) > 1)):
        # construct output file name
        file_out = Path(args.output_dir) / args.output_name.format(
            file_path=str(file_in.relative_to(args.input_dir).parent),
            file_name=file_in.stem,
            file_ext=file_in.suffix,
            model_name=args.model_name,
            model_weights=model_weights_stem,
        )

        # checks
        if not (file_in.suffix.lower() in ('.tif', '.tiff') and
                file_out.suffix.lower() in ('.tif', '.tiff')):
            raise ValueError('only tiff files supported.')

        # load and predict restored image
        img = imread(str(file_in))
        restored = model.predict(img, axes=args.input_axes, normalizer=normalizer, n_tiles=n_tiles)

        # restored image could be multi-channel even if input image is not;
        # explicit checks (not ``assert``) so they still run under ``python -O``
        axes_out = axes_in
        if restored.ndim > img.ndim:
            if restored.ndim != img.ndim + 1:
                raise ValueError(
                    "restored image has %d dimensions, but at most %d expected for input with %d dimensions."
                    % (restored.ndim, img.ndim + 1, img.ndim))
            if 'C' in axes_out:
                raise ValueError(
                    "restored image has an extra axis, but input axes '%s' already contain a channel axis 'C'."
                    % axes_out)
            axes_out += 'C'

        # convert data type (if necessary)
        restored = restored.astype(np.dtype(args.output_dtype), copy=False)

        # save to disk
        if not args.dry_run:
            file_out.parent.mkdir(parents=True, exist_ok=True)
            if args.imagej_tiff:
                save_tiff_imagej_compatible(str(file_out), restored, axes_out)
            else:
                imsave(str(file_out), restored)

        processed.append((file_in, file_out))

    # print summary of processed files
    if not args.quiet:
        sys.stdout.flush()
        sys.stderr.flush()
        n_processed = len(processed)
        len_processed = len(str(n_processed))
        header = 'Finished processing %d %s' % (n_processed, 'files' if n_processed > 1 else 'file')
        log(header)
        # underline exactly as wide as the header line
        log('-' * len(header))
        for i, (file_in, file_out) in enumerate(processed):
            len_file = max(len(str(file_in)), len(str(file_out)))
            log(('{:>%d}. in : {:>%d}' % (len_processed, len_file)).format(1 + i, str(file_in)))
            log(('{:>%d} out: {:>%d}' % (len_processed, len_file)).format('', str(file_out)))
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
if __name__ == "__main__":
    # Only run the command-line tool when executed as a script, not on import.
    main()
|