senoquant-1.0.0b1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148)
  1. senoquant/__init__.py +6 -0
  2. senoquant/_reader.py +7 -0
  3. senoquant/_widget.py +33 -0
  4. senoquant/napari.yaml +83 -0
  5. senoquant/reader/__init__.py +5 -0
  6. senoquant/reader/core.py +369 -0
  7. senoquant/tabs/__init__.py +15 -0
  8. senoquant/tabs/batch/__init__.py +10 -0
  9. senoquant/tabs/batch/backend.py +641 -0
  10. senoquant/tabs/batch/config.py +270 -0
  11. senoquant/tabs/batch/frontend.py +1283 -0
  12. senoquant/tabs/batch/io.py +326 -0
  13. senoquant/tabs/batch/layers.py +86 -0
  14. senoquant/tabs/quantification/__init__.py +1 -0
  15. senoquant/tabs/quantification/backend.py +228 -0
  16. senoquant/tabs/quantification/features/__init__.py +80 -0
  17. senoquant/tabs/quantification/features/base.py +142 -0
  18. senoquant/tabs/quantification/features/marker/__init__.py +5 -0
  19. senoquant/tabs/quantification/features/marker/config.py +69 -0
  20. senoquant/tabs/quantification/features/marker/dialog.py +437 -0
  21. senoquant/tabs/quantification/features/marker/export.py +879 -0
  22. senoquant/tabs/quantification/features/marker/feature.py +119 -0
  23. senoquant/tabs/quantification/features/marker/morphology.py +285 -0
  24. senoquant/tabs/quantification/features/marker/rows.py +654 -0
  25. senoquant/tabs/quantification/features/marker/thresholding.py +46 -0
  26. senoquant/tabs/quantification/features/roi.py +346 -0
  27. senoquant/tabs/quantification/features/spots/__init__.py +5 -0
  28. senoquant/tabs/quantification/features/spots/config.py +62 -0
  29. senoquant/tabs/quantification/features/spots/dialog.py +477 -0
  30. senoquant/tabs/quantification/features/spots/export.py +1292 -0
  31. senoquant/tabs/quantification/features/spots/feature.py +112 -0
  32. senoquant/tabs/quantification/features/spots/morphology.py +279 -0
  33. senoquant/tabs/quantification/features/spots/rows.py +241 -0
  34. senoquant/tabs/quantification/frontend.py +815 -0
  35. senoquant/tabs/segmentation/__init__.py +1 -0
  36. senoquant/tabs/segmentation/backend.py +131 -0
  37. senoquant/tabs/segmentation/frontend.py +1009 -0
  38. senoquant/tabs/segmentation/models/__init__.py +5 -0
  39. senoquant/tabs/segmentation/models/base.py +146 -0
  40. senoquant/tabs/segmentation/models/cpsam/details.json +65 -0
  41. senoquant/tabs/segmentation/models/cpsam/model.py +150 -0
  42. senoquant/tabs/segmentation/models/default_2d/details.json +69 -0
  43. senoquant/tabs/segmentation/models/default_2d/model.py +664 -0
  44. senoquant/tabs/segmentation/models/default_3d/details.json +69 -0
  45. senoquant/tabs/segmentation/models/default_3d/model.py +682 -0
  46. senoquant/tabs/segmentation/models/hf.py +71 -0
  47. senoquant/tabs/segmentation/models/nuclear_dilation/__init__.py +1 -0
  48. senoquant/tabs/segmentation/models/nuclear_dilation/details.json +26 -0
  49. senoquant/tabs/segmentation/models/nuclear_dilation/model.py +96 -0
  50. senoquant/tabs/segmentation/models/perinuclear_rings/__init__.py +1 -0
  51. senoquant/tabs/segmentation/models/perinuclear_rings/details.json +34 -0
  52. senoquant/tabs/segmentation/models/perinuclear_rings/model.py +132 -0
  53. senoquant/tabs/segmentation/stardist_onnx_utils/__init__.py +2 -0
  54. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/__init__.py +3 -0
  55. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/__init__.py +6 -0
  56. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/generate.py +470 -0
  57. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/prepare.py +273 -0
  58. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/rawdata.py +112 -0
  59. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/transform.py +384 -0
  60. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/__init__.py +0 -0
  61. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/blocks.py +184 -0
  62. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/losses.py +79 -0
  63. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/nets.py +165 -0
  64. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/predict.py +467 -0
  65. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/probability.py +67 -0
  66. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/train.py +148 -0
  67. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/io/__init__.py +163 -0
  68. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/__init__.py +52 -0
  69. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/base_model.py +329 -0
  70. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_isotropic.py +160 -0
  71. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_projection.py +178 -0
  72. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_standard.py +446 -0
  73. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_upsampling.py +54 -0
  74. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/config.py +254 -0
  75. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/pretrained.py +119 -0
  76. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/scripts/__init__.py +0 -0
  77. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/scripts/care_predict.py +180 -0
  78. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/__init__.py +5 -0
  79. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/plot_utils.py +159 -0
  80. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/six.py +18 -0
  81. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/tf.py +644 -0
  82. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/utils.py +272 -0
  83. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/version.py +1 -0
  84. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/docs/source/conf.py +368 -0
  85. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/setup.py +68 -0
  86. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_datagen.py +169 -0
  87. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_models.py +462 -0
  88. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_utils.py +166 -0
  89. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tools/create_zip_contents.py +34 -0
  90. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/__init__.py +30 -0
  91. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/big.py +624 -0
  92. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/bioimageio_utils.py +494 -0
  93. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/data/__init__.py +39 -0
  94. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/__init__.py +10 -0
  95. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/geom2d.py +215 -0
  96. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/geom3d.py +349 -0
  97. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/matching.py +483 -0
  98. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/__init__.py +28 -0
  99. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/base.py +1217 -0
  100. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/model2d.py +594 -0
  101. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/model3d.py +696 -0
  102. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/nms.py +384 -0
  103. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/__init__.py +2 -0
  104. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/plot.py +74 -0
  105. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/render.py +298 -0
  106. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/rays3d.py +373 -0
  107. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/sample_patches.py +65 -0
  108. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/__init__.py +0 -0
  109. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/predict2d.py +90 -0
  110. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/predict3d.py +93 -0
  111. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/utils.py +408 -0
  112. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/version.py +1 -0
  113. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/__init__.py +45 -0
  114. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/__init__.py +17 -0
  115. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/cli.py +55 -0
  116. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/core.py +285 -0
  117. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/__init__.py +15 -0
  118. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/cli.py +36 -0
  119. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/divisibility.py +193 -0
  120. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/probe.py +100 -0
  121. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/receptive_field.py +182 -0
  122. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/rf_cli.py +48 -0
  123. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/valid_sizes.py +278 -0
  124. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/post/__init__.py +8 -0
  125. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/post/core.py +157 -0
  126. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/pre/__init__.py +17 -0
  127. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/pre/core.py +226 -0
  128. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/predict/__init__.py +5 -0
  129. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/predict/core.py +401 -0
  130. senoquant/tabs/settings/__init__.py +1 -0
  131. senoquant/tabs/settings/backend.py +29 -0
  132. senoquant/tabs/settings/frontend.py +19 -0
  133. senoquant/tabs/spots/__init__.py +1 -0
  134. senoquant/tabs/spots/backend.py +139 -0
  135. senoquant/tabs/spots/frontend.py +800 -0
  136. senoquant/tabs/spots/models/__init__.py +5 -0
  137. senoquant/tabs/spots/models/base.py +94 -0
  138. senoquant/tabs/spots/models/rmp/details.json +61 -0
  139. senoquant/tabs/spots/models/rmp/model.py +499 -0
  140. senoquant/tabs/spots/models/udwt/details.json +103 -0
  141. senoquant/tabs/spots/models/udwt/model.py +482 -0
  142. senoquant/utils.py +25 -0
  143. senoquant-1.0.0b1.dist-info/METADATA +193 -0
  144. senoquant-1.0.0b1.dist-info/RECORD +148 -0
  145. senoquant-1.0.0b1.dist-info/WHEEL +5 -0
  146. senoquant-1.0.0b1.dist-info/entry_points.txt +2 -0
  147. senoquant-1.0.0b1.dist-info/licenses/LICENSE +28 -0
  148. senoquant-1.0.0b1.dist-info/top_level.txt +1 -0
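
The listing above shows a napari plugin layout: senoquant/napari.yaml is the plugin manifest and senoquant-1.0.0b1.dist-info/entry_points.txt registers it with the plugin engine. The sketch below (not part of the package) installs the pre-release and prints the entry points the distribution declares; the expectation of a 'napari.manifest' entry is an assumption based on the files listed, not something confirmed by this diff.

import subprocess
import sys
from importlib import metadata

# Install the beta wheel named at the top of this diff (pre-releases need an
# explicit version or --pre).
subprocess.run([sys.executable, "-m", "pip", "install", "senoquant==1.0.0b1"], check=True)

# List the entry points the distribution declares; for a napari plugin this
# typically includes a 'napari.manifest' entry pointing at napari.yaml.
dist = metadata.distribution("senoquant")
for ep in dist.entry_points:
    print(ep.group, ep.name, ep.value)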
senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/config.py
@@ -0,0 +1,254 @@
+ from __future__ import print_function, unicode_literals, absolute_import, division
+ from six.moves import range, zip, map, reduce, filter
+ from six import string_types
+
+ import numpy as np
+ import argparse
+ import warnings
+
+ from packaging.version import Version
+
+ from ..utils.tf import keras_import, BACKEND as K
+ keras = keras_import()
+
+ from ..utils import _raise, axes_check_and_normalize, axes_dict, backend_channels_last
+
+
+ class BaseConfig(argparse.Namespace):
+
+     def __init__(self, axes='YX', n_channel_in=1, n_channel_out=1, allow_new_parameters=False, **kwargs):
+
+         # parse and check axes
+         axes = axes_check_and_normalize(axes)
+         ax = axes_dict(axes)
+         ax = {a: (ax[a] is not None) for a in ax}
+
+         (ax['X'] and ax['Y']) or _raise(ValueError('lateral axes X and Y must be present.'))
+         # not (ax['Z'] and ax['T']) or _raise(ValueError('using Z and T axes together not supported.'))
+
+         axes.startswith('S') or (not ax['S']) or _raise(ValueError('sample axis S must be first.'))
+         axes = axes.replace('S','') # remove sample axis if it exists
+
+         n_dim = len(axes.replace('C',''))
+
+         # TODO: Config not independent of backend. Problem?
+         # could move things around during train/predict as an alternative... good idea?
+         # otherwise, users can choose axes of input image anyhow, so doesn't matter if model is fixed to something else
+         if backend_channels_last():
+             if ax['C']:
+                 axes[-1] == 'C' or _raise(ValueError('channel axis must be last for backend (%s).' % K.backend()))
+             else:
+                 axes += 'C'
+         else:
+             if ax['C']:
+                 axes[0] == 'C' or _raise(ValueError('channel axis must be first for backend (%s).' % K.backend()))
+             else:
+                 axes = 'C'+axes
+
+         self.n_dim = n_dim
+         self.axes = axes
+         self.n_channel_in = int(max(1,n_channel_in))
+         self.n_channel_out = int(max(1,n_channel_out))
+
+         self.train_checkpoint = 'weights_best.h5'
+         self.train_checkpoint_last = 'weights_last.h5'
+         self.train_checkpoint_epoch = 'weights_now.h5'
+
+         self.update_parameters(allow_new_parameters, **kwargs)
+
+
+     def is_valid(self, return_invalid=False):
+         return (True, tuple()) if return_invalid else True
+
+
+     def update_parameters(self, allow_new=False, **kwargs):
+         if not allow_new:
+             attr_new = []
+             for k in kwargs:
+                 try:
+                     getattr(self, k)
+                 except AttributeError:
+                     attr_new.append(k)
+             if len(attr_new) > 0:
+                 raise AttributeError("Not allowed to add new parameters (%s)" % ', '.join(attr_new))
+         for k in kwargs:
+             setattr(self, k, kwargs[k])
+
+     @classmethod
+     def update_loaded_config(cls, config):
+         """Called by model to update loaded config dictionary before config object is created
+
+         Can be used to modify or introduce/delete parameters, e.g. to ensure
+         backwards compatibility after new parameters have been introduced.
+
+         Parameters
+         ----------
+         config : dict
+             dictionary of config parameters loaded from file
+
+         Returns
+         -------
+         updated_config: dict
+             an updated version of the config parameter dictionary
+         """
+         return config
+
+
+
+ class Config(BaseConfig):
+     """Default configuration for a CARE model.
+
+     This configuration is meant to be used with :class:`CARE`
+     and related models (e.g., :class:`IsotropicCARE`).
+
+     Parameters
+     ----------
+     axes : str
+         Axes of the neural network (channel axis optional).
+     n_channel_in : int
+         Number of channels of given input image.
+     n_channel_out : int
+         Number of channels of predicted output image.
+     probabilistic : bool
+         Probabilistic prediction of per-pixel Laplace distributions or
+         typical regression of per-pixel scalar values.
+     allow_new_parameters : bool
+         Allow adding new configuration attributes (i.e. not listed below).
+     kwargs : dict
+         Overwrite (or add) configuration attributes (see below).
+
+     Example
+     -------
+     >>> config = Config('YX', probabilistic=True, unet_n_depth=3)
+
+     Attributes
+     ----------
+     n_dim : int
+         Dimensionality of input images (2 or 3).
+     unet_residual : bool
+         Parameter `residual` of :func:`csbdeep.nets.common_unet`. Default: ``n_channel_in == n_channel_out``
+     unet_n_depth : int
+         Parameter `n_depth` of :func:`csbdeep.nets.common_unet`. Default: ``2``
+     unet_kern_size : int
+         Parameter `kern_size` of :func:`csbdeep.nets.common_unet`. Default: ``5 if n_dim==2 else 3``
+     unet_n_first : int
+         Parameter `n_first` of :func:`csbdeep.nets.common_unet`. Default: ``32``
+     unet_last_activation : str
+         Parameter `last_activation` of :func:`csbdeep.nets.common_unet`. Default: ``linear``
+     train_loss : str
+         Name of training loss. Default: ``'laplace' if probabilistic else 'mae'``
+     train_epochs : int
+         Number of training epochs. Default: ``100``
+     train_steps_per_epoch : int
+         Number of parameter update steps per epoch. Default: ``400``
+     train_learning_rate : float
+         Learning rate for training. Default: ``0.0004``
+     train_batch_size : int
+         Batch size for training. Default: ``16``
+     train_tensorboard : bool
+         Enable TensorBoard for monitoring training progress. Default: ``True``
+     train_checkpoint : str
+         Name of checkpoint file for model weights (only best are saved); set to ``None`` to disable. Default: ``weights_best.h5``
+     train_reduce_lr : dict
+         Parameter :class:`dict` of ReduceLROnPlateau_ callback; set to ``None`` to disable. Default: ``{'factor': 0.5, 'patience': 10, 'min_delta': 0}``
+
+     .. _ReduceLROnPlateau: https://keras.io/callbacks/#reducelronplateau
+     """
+
+     def __init__(self, axes='YX', n_channel_in=1, n_channel_out=1, probabilistic=False, allow_new_parameters=False, **kwargs):
+         """See class docstring."""
+
+         super(Config, self).__init__(axes, n_channel_in, n_channel_out)
+         not ('Z' in self.axes and 'T' in self.axes) or _raise(ValueError('using Z and T axes together not supported.'))
+
+         self.probabilistic = bool(probabilistic)
+
+         # default config (can be overwritten by kwargs below)
+         self.unet_residual = self.n_channel_in == self.n_channel_out
+         self.unet_n_depth = 2
+         self.unet_kern_size = 5 if self.n_dim==2 else 3
+         self.unet_n_first = 32
+         self.unet_last_activation = 'linear'
+         if backend_channels_last():
+             self.unet_input_shape = self.n_dim*(None,) + (self.n_channel_in,)
+         else:
+             self.unet_input_shape = (self.n_channel_in,) + self.n_dim*(None,)
+
+         self.train_loss = 'laplace' if self.probabilistic else 'mae'
+         self.train_epochs = 100
+         self.train_steps_per_epoch = 400
+         self.train_learning_rate = 0.0004
+         self.train_batch_size = 16
+         self.train_tensorboard = True
+
+         # the parameter 'min_delta' was called 'epsilon' for keras<=2.1.5
+         # keras.__version__ was removed in tensorflow 2.13.0
+         min_delta_key = 'epsilon' if Version(getattr(keras, '__version__', '9.9.9'))<=Version('2.1.5') else 'min_delta'
+         self.train_reduce_lr = {'factor': 0.5, 'patience': 10, min_delta_key: 0}
+
+         # disallow setting 'n_dim' manually
+         try:
+             del kwargs['n_dim']
+             # warnings.warn("ignoring parameter 'n_dim'")
+         except:
+             pass
+
+         self.update_parameters(allow_new_parameters, **kwargs)
+
+
+     def is_valid(self, return_invalid=False):
+         """Check if configuration is valid.
+
+         Returns
+         -------
+         bool
+             Flag that indicates whether the current configuration values are valid.
+         """
+         def _is_int(v,low=None,high=None):
+             return (
+                 isinstance(v,int) and
+                 (True if low is None else low <= v) and
+                 (True if high is None else v <= high)
+             )
+
+         ok = {}
+         ok['n_dim'] = self.n_dim in (2,3)
+         try:
+             axes_check_and_normalize(self.axes,self.n_dim+1,disallowed='S')
+             ok['axes'] = True
+         except:
+             ok['axes'] = False
+         ok['n_channel_in'] = _is_int(self.n_channel_in,1)
+         ok['n_channel_out'] = _is_int(self.n_channel_out,1)
+         ok['probabilistic'] = isinstance(self.probabilistic,bool)
+
+         ok['unet_residual'] = (
+             isinstance(self.unet_residual,bool) and
+             (not self.unet_residual or (self.n_channel_in==self.n_channel_out))
+         )
+         ok['unet_n_depth'] = _is_int(self.unet_n_depth,1)
+         ok['unet_kern_size'] = _is_int(self.unet_kern_size,1)
+         ok['unet_n_first'] = _is_int(self.unet_n_first,1)
+         ok['unet_last_activation'] = self.unet_last_activation in ('linear','relu')
+         ok['unet_input_shape'] = (
+             isinstance(self.unet_input_shape,(list,tuple))
+             and len(self.unet_input_shape) == self.n_dim+1
+             and self.unet_input_shape[-1] == self.n_channel_in
+             # and all((d is None or (_is_int(d) and d%(2**self.unet_n_depth)==0) for d in self.unet_input_shape[:-1]))
+         )
+         ok['train_loss'] = (
+             ( self.probabilistic and self.train_loss == 'laplace' ) or
+             (not self.probabilistic and self.train_loss in ('mse','mae'))
+         )
+         ok['train_epochs'] = _is_int(self.train_epochs,1)
+         ok['train_steps_per_epoch'] = _is_int(self.train_steps_per_epoch,1)
+         ok['train_learning_rate'] = np.isscalar(self.train_learning_rate) and self.train_learning_rate > 0
+         ok['train_batch_size'] = _is_int(self.train_batch_size,1)
+         ok['train_tensorboard'] = isinstance(self.train_tensorboard,bool)
+         ok['train_checkpoint'] = self.train_checkpoint is None or isinstance(self.train_checkpoint,string_types)
+         ok['train_reduce_lr'] = self.train_reduce_lr is None or isinstance(self.train_reduce_lr,dict)
+
+         if return_invalid:
+             return all(ok.values()), tuple(k for (k,v) in ok.items() if not v)
+         else:
+             return all(ok.values())
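
A minimal usage sketch for the Config class added above. The import path mirrors the wheel layout from the file listing and is an assumption; the module also imports Keras/TensorFlow utilities at import time, so a backend must be installed.

from senoquant.tabs.segmentation.stardist_onnx_utils._csbdeep.csbdeep.models.config import Config

# matches the docstring example, plus a validity check
config = Config('YX', n_channel_in=1, n_channel_out=1, probabilistic=True, unet_n_depth=3)
valid, invalid = config.is_valid(return_invalid=True)
if not valid:
    # is_valid(return_invalid=True) also returns the names of the offending fields
    raise ValueError("invalid config fields: " + ", ".join(invalid))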
senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/pretrained.py
@@ -0,0 +1,119 @@
1
+ # -*- coding: utf-8 -*-
2
+ from __future__ import print_function, unicode_literals, absolute_import, division
3
+
4
+ from collections import OrderedDict
5
+ from warnings import warn
6
+ from ..utils import _raise
7
+ from ..utils.six import Path
8
+
9
+ from packaging.version import Version
10
+ from ..utils.tf import keras_import, v_keras
11
+ get_file = keras_import('utils', 'get_file')
12
+
13
+
14
+ _MODELS = {}
15
+ _ALIASES = {}
16
+
17
+
18
+ def clear_models_and_aliases(*cls):
19
+ if len(cls) == 0:
20
+ _MODELS.clear()
21
+ _ALIASES.clear()
22
+ else:
23
+ for c in cls:
24
+ if c in _MODELS: del _MODELS[c]
25
+ if c in _ALIASES: del _ALIASES[c]
26
+
27
+
28
+ def register_model(cls, key, url, hash):
29
+ """ Example:
30
+
31
+ register_model(StarDist2D, 'my_great_model', 'https://github.com/stardist/stardist-models/releases/download/v0.1/python_2D_versatile_fluo.zip', md5sum_as_astring)
32
+
33
+ """
34
+ # key must be a valid file/folder name in the file system
35
+ models = _MODELS.setdefault(cls,OrderedDict())
36
+ key not in models or warn("re-registering model '%s' (was already registered for '%s')" % (key, cls.__name__))
37
+ models[key] = dict(url=url, hash=hash)
38
+
39
+
40
+ def register_aliases(cls, key, *names):
41
+ # aliases can be arbitrary strings
42
+ if len(names) == 0: return
43
+ models = _MODELS.get(cls,{})
44
+ key in models or _raise(ValueError("model '%s' is not registered for '%s'" % (key, cls.__name__)))
45
+ aliases = _ALIASES.setdefault(cls,OrderedDict())
46
+ for name in names:
47
+ aliases.get(name,key) == key or warn("alias '%s' was previously registered with model '%s' for '%s'" % (name, aliases[name], cls.__name__))
48
+ aliases[name] = key
49
+
50
+
51
+ def get_registered_models(cls, return_aliases=True, verbose=False):
52
+ models = _MODELS.get(cls,{})
53
+ aliases = _ALIASES.get(cls,{})
54
+ model_keys = tuple(models.keys())
55
+ model_aliases = {key: tuple(name for name in aliases if aliases[name] == key) for key in models}
56
+ if verbose:
57
+ # this code is very messy and should be refactored...
58
+ _n = len(models)
59
+ _str_model = 'model' if _n == 1 else 'models'
60
+ _str_is_are = 'is' if _n == 1 else 'are'
61
+ _str_colon = ':' if _n > 0 else ''
62
+ print("There {is_are} {n} registered {model_s} for '{clazz}'{c}".format(
63
+ n=_n, clazz=cls.__name__, is_are=_str_is_are, model_s=_str_model, c=_str_colon))
64
+ if _n > 0:
65
+ print()
66
+ _maxkeylen = 2 + max(len(key) for key in models)
67
+ print("Name{s}Alias(es)".format(s=' '*(_maxkeylen-4+3)))
68
+ print("────{s}─────────".format(s=' '*(_maxkeylen-4+3)))
69
+ for key in models:
70
+ _aliases = ' '
71
+ _m = len(model_aliases[key])
72
+ if _m > 0:
73
+ _aliases += "'%s'" % "', '".join(model_aliases[key])
74
+ else:
75
+ _aliases += "None"
76
+ _key = ("{s:%d}"%_maxkeylen).format(s="'%s'"%key)
77
+ print("{key}{aliases}".format(key=_key, aliases=_aliases))
78
+ return ((model_keys, model_aliases) if return_aliases else model_keys)
79
+
80
+
81
+ def get_model_details(cls, key_or_alias, verbose=False):
82
+ models = _MODELS.get(cls,{})
83
+ if key_or_alias in models:
84
+ key = key_or_alias
85
+ alias = None
86
+ else:
87
+ aliases = _ALIASES.get(cls,{})
88
+ alias = key_or_alias
89
+ alias in aliases or _raise(ValueError("'%s' is neither a key or alias for '%s'" % (alias, cls.__name__)))
90
+ key = aliases[alias]
91
+ if verbose:
92
+ print("Found model '{model}'{alias_str} for '{clazz}'.".format(
93
+ model=key, clazz=cls.__name__, alias_str=('' if alias is None else " with alias '%s'" % alias)))
94
+ return key, alias, models[key]
95
+
96
+
97
+ def get_model_folder(cls, key_or_alias):
98
+ key, alias, m = get_model_details(cls, key_or_alias)
99
+ target = str(Path('models') / cls.__name__ / key)
100
+ path = Path(get_file(fname=key+'.zip', origin=m['url'], file_hash=m['hash'],
101
+ cache_subdir=target, extract=True))
102
+ if v_keras >= Version("3.6.0"):
103
+ path_folder = path
104
+ suffix = "_extracted"
105
+ if path_folder.is_dir() and path_folder.name.endswith(suffix) and len(path_folder.name) > len(suffix):
106
+ path_folder = path_folder.with_name(path_folder.name[:-len(suffix)])
107
+ if not path_folder.exists():
108
+ path_folder.symlink_to(path.relative_to(path.parent))
109
+ else:
110
+ path_folder = path.parent
111
+ assert path_folder.exists()
112
+ return path_folder
113
+
114
+
115
+ def get_model_instance(cls, key_or_alias):
116
+ path = get_model_folder(cls, key_or_alias)
117
+ model = cls(config=None, name=path.stem, basedir=path.parent)
118
+ model.basedir = None # make read-only
119
+ return model
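
A minimal sketch of the registry API defined above. DemoModel is a hypothetical placeholder class and the URL/hash are illustrative only; importing the module requires a Keras/TensorFlow backend because it calls keras_import() at import time, and the import path mirrors the wheel layout from the file listing.

from senoquant.tabs.segmentation.stardist_onnx_utils._csbdeep.csbdeep.models.pretrained import (
    register_model, register_aliases, get_registered_models, get_model_details)

class DemoModel:
    # hypothetical model class used only to exercise the registry
    pass

# key must be a valid file/folder name; url and hash describe a downloadable zip
register_model(DemoModel, 'demo_2d', 'https://example.com/demo_2d.zip', 'md5-placeholder')
register_aliases(DemoModel, 'demo_2d', 'Demo (2D)')

# returns (keys, {key: aliases}) when return_aliases=True
keys, aliases = get_registered_models(DemoModel, return_aliases=True, verbose=True)

# aliases resolve back to the registered key and its url/hash details
key, alias, details = get_model_details(DemoModel, 'Demo (2D)', verbose=True)
print(key, details['url'])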
senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/scripts/care_predict.py
@@ -0,0 +1,180 @@
+ #!/usr/bin/env python
+
+ from __future__ import print_function, unicode_literals, absolute_import, division
+
+ import argparse
+ import sys
+ from pprint import pprint
+
+ import numpy as np
+ from tqdm import tqdm
+
+ from csbdeep.io import save_tiff_imagej_compatible
+ from csbdeep.utils import _raise, axes_check_and_normalize
+ from csbdeep.utils.six import Path
+
+
+ def str2bool(v):
+     if v.lower() in ('yes', 'true', 't', 'y', '1'):
+         return True
+     elif v.lower() in ('no', 'false', 'f', 'n', '0'):
+         return False
+     else:
+         raise argparse.ArgumentTypeError('Boolean value expected.')
+
+
+ def parse_args():
+     parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+
+     parser.add_argument('--quiet', metavar='', type=str2bool, required=False, const=True, nargs='?', default=False, help="don't show status messages")
+     parser.add_argument('--gpu-memory-limit', metavar='', type=float, required=False, default=None, help="limit GPU memory to this fraction (0...1)")
+
+     data = parser.add_argument_group("input")
+     data.add_argument('--input-dir', metavar='', type=str, required=False, default=None, help="path to folder with input images")
+     data.add_argument('--input-pattern', metavar='', type=str, required=False, default='*.tif*', help="glob-style file name pattern of input images")
+     data.add_argument('--input-axes', metavar='', type=str, required=False, default=None, help="axes string of input images")
+     data.add_argument('--norm-pmin', metavar='', type=float, required=False, default=2, help="'pmin' for PercentileNormalizer")
+     data.add_argument('--norm-pmax', metavar='', type=float, required=False, default=99.8, help="'pmax' for PercentileNormalizer")
+     data.add_argument('--norm-undo', metavar='', type=str2bool, required=False, const=True, nargs='?', default=True, help="'do_after' for PercentileNormalizer")
+     data.add_argument('--n-tiles', metavar='', type=int, required=False, nargs='+', default=None, help="number of tiles for prediction")
+
+     model = parser.add_argument_group("model")
+     model.add_argument('--model-basedir', metavar='', type=str, required=False, default=None, help="path to folder that contains CARE model")
+     model.add_argument('--model-name', metavar='', type=str, required=False, default=None, help="name of CARE model")
+     model.add_argument('--model-weights', metavar='', type=str, required=False, default=None, help="specific name of weights file to load (located in model folder)")
+
+     output = parser.add_argument_group("output")
+     output.add_argument('--output-dir', metavar='', type=str, required=False, default=None, help="path to folder where restored images will be saved")
+     output.add_argument('--output-name', metavar='', type=str, required=False, default='{model_name}/{file_path}/{file_name}{file_ext}', help="name pattern of restored image (special tokens: {file_path}, {file_name}, {file_ext}, {model_name}, {model_weights})")
+     output.add_argument('--output-dtype', metavar='', type=str, required=False, default='float32', help="data type of the saved tiff file")
+     output.add_argument('--imagej-tiff', metavar='', type=str2bool, required=False, const=True, nargs='?', default=True, help="save restored image as ImageJ-compatible TIFF file")
+     output.add_argument('--dry-run', metavar='', type=str2bool, required=False, const=True, nargs='?', default=False, help="don't save restored images")
+
+     return parser, parser.parse_args()
+
+
+ def main():
+     if not ('__file__' in locals() or '__file__' in globals()):
+         print('running interactively, exiting.')
+         sys.exit(0)
+
+     # parse arguments
+     parser, args = parse_args()
+     args_dict = vars(args)
+
+     # exit and show help if no arguments provided at all
+     if len(sys.argv) == 1:
+         parser.print_help()
+         sys.exit(0)
+
+     # check for required arguments manually (because of argparse issue)
+     required = ('--input-dir','--input-axes', '--norm-pmin', '--norm-pmax', '--model-basedir', '--model-name', '--output-dir')
+     for r in required:
+         dest = r[2:].replace('-','_')
+         if args_dict[dest] is None:
+             parser.print_usage(file=sys.stderr)
+             print("%s: error: the following arguments are required: %s" % (parser.prog,r), file=sys.stderr)
+             sys.exit(1)
+
+     # show effective arguments (including defaults)
+     if not args.quiet:
+         print('Arguments')
+         print('---------')
+         pprint(args_dict)
+         print()
+         sys.stdout.flush()
+
+     # logging function
+     log = (lambda *a,**k: None) if args.quiet else tqdm.write
+
+     # get list of input files and exit if there are none
+     file_list = list(Path(args.input_dir).glob(args.input_pattern))
+     if len(file_list) == 0:
+         log("No files to process in '%s' with pattern '%s'." % (args.input_dir,args.input_pattern))
+         sys.exit(0)
+
+     # delay imports after checking to all required arguments are provided
+     from tifffile import imread
+     try:
+         from tifffile import imwrite as imsave
+     except ImportError:
+         from tifffile import imsave
+     from csbdeep.utils.tf import BACKEND as K
+     from csbdeep.models import CARE
+     from csbdeep.data import PercentileNormalizer
+     sys.stdout.flush()
+     sys.stderr.flush()
+
+     # limit gpu memory
+     if args.gpu_memory_limit is not None:
+         from csbdeep.utils.tf import limit_gpu_memory
+         limit_gpu_memory(args.gpu_memory_limit)
+
+     # create CARE model and load weights, create normalizer
+     K.clear_session()
+     model = CARE(config=None, name=args.model_name, basedir=args.model_basedir)
+     if args.model_weights is not None:
+         print("Loading network weights from '%s'." % args.model_weights)
+         model.load_weights(args.model_weights)
+     normalizer = PercentileNormalizer(pmin=args.norm_pmin, pmax=args.norm_pmax, do_after=args.norm_undo)
+
+     n_tiles = args.n_tiles
+     if n_tiles is not None and len(n_tiles)==1:
+         n_tiles = n_tiles[0]
+
+     processed = []
+
+     # process all files
+     for file_in in tqdm(file_list, disable=args.quiet or (n_tiles is not None and np.prod(n_tiles)>1)):
+         # construct output file name
+         file_out = Path(args.output_dir) / args.output_name.format (
+             file_path = str(file_in.relative_to(args.input_dir).parent),
+             file_name = file_in.stem, file_ext = file_in.suffix,
+             model_name = args.model_name, model_weights = Path(args.model_weights).stem if args.model_weights is not None else None
+         )
+
+         # checks
+         (file_in.suffix.lower() in ('.tif','.tiff') and
+          file_out.suffix.lower() in ('.tif','.tiff')) or _raise(ValueError('only tiff files supported.'))
+
+         # load and predict restored image
+         img = imread(str(file_in))
+         restored = model.predict(img, axes=args.input_axes, normalizer=normalizer, n_tiles=n_tiles)
+
+         # restored image could be multi-channel even if input image is not
+         axes_out = axes_check_and_normalize(args.input_axes)
+         if restored.ndim > img.ndim:
+             assert restored.ndim == img.ndim + 1
+             assert 'C' not in axes_out
+             axes_out += 'C'
+
+         # convert data type (if necessary)
+         restored = restored.astype(np.dtype(args.output_dtype), copy=False)
+
+         # save to disk
+         if not args.dry_run:
+             file_out.parent.mkdir(parents=True, exist_ok=True)
+             if args.imagej_tiff:
+                 save_tiff_imagej_compatible(str(file_out), restored, axes_out)
+             else:
+                 imsave(str(file_out), restored)
+
+         processed.append((file_in,file_out))
+
+
+     # print summary of processed files
+     if not args.quiet:
+         sys.stdout.flush()
+         sys.stderr.flush()
+         n_processed = len(processed)
+         len_processed = len(str(n_processed))
+         log('Finished processing %d %s' % (n_processed, 'files' if n_processed > 1 else 'file'))
+         log('-' * (26+len_processed if n_processed > 1 else 26))
+         for i,(file_in,file_out) in enumerate(processed):
+             len_file = max(len(str(file_in)),len(str(file_out)))
+             log(('{:>%d}. in : {:>%d}'%(len_processed,len_file)).format(1+i,str(file_in)))
+             log(('{:>%d} out: {:>%d}'%(len_processed,len_file)).format('',str(file_out)))
+
+
+ if __name__ == '__main__':
+     main()
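
A minimal sketch of running the CLI above. All paths and the model name are placeholders; the script imports the top-level csbdeep package (not the vendored copy), so it is assumed to run in an environment where csbdeep and its TensorFlow backend are installed.

import subprocess
import sys

script = "senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/scripts/care_predict.py"
subprocess.run([
    sys.executable, script,
    "--input-dir", "raw_tiffs",      # folder of input .tif images
    "--input-axes", "YX",            # axes string of the inputs
    "--model-basedir", "models",     # folder containing the CARE model
    "--model-name", "my_model",      # model subfolder name
    "--output-dir", "restored",      # where restored images are written
    "--n-tiles", "2", "2",           # optional tiling for large images
    "--dry-run",                     # remove this flag to actually save output
], check=True)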
senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/__init__.py
@@ -0,0 +1,5 @@
+ from __future__ import print_function, unicode_literals, absolute_import, division
+
+ from .plot_utils import *
+ from .utils import *
+ from .utils import _raise