senoquant 1.0.0b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148)
  1. senoquant/__init__.py +6 -0
  2. senoquant/_reader.py +7 -0
  3. senoquant/_widget.py +33 -0
  4. senoquant/napari.yaml +83 -0
  5. senoquant/reader/__init__.py +5 -0
  6. senoquant/reader/core.py +369 -0
  7. senoquant/tabs/__init__.py +15 -0
  8. senoquant/tabs/batch/__init__.py +10 -0
  9. senoquant/tabs/batch/backend.py +641 -0
  10. senoquant/tabs/batch/config.py +270 -0
  11. senoquant/tabs/batch/frontend.py +1283 -0
  12. senoquant/tabs/batch/io.py +326 -0
  13. senoquant/tabs/batch/layers.py +86 -0
  14. senoquant/tabs/quantification/__init__.py +1 -0
  15. senoquant/tabs/quantification/backend.py +228 -0
  16. senoquant/tabs/quantification/features/__init__.py +80 -0
  17. senoquant/tabs/quantification/features/base.py +142 -0
  18. senoquant/tabs/quantification/features/marker/__init__.py +5 -0
  19. senoquant/tabs/quantification/features/marker/config.py +69 -0
  20. senoquant/tabs/quantification/features/marker/dialog.py +437 -0
  21. senoquant/tabs/quantification/features/marker/export.py +879 -0
  22. senoquant/tabs/quantification/features/marker/feature.py +119 -0
  23. senoquant/tabs/quantification/features/marker/morphology.py +285 -0
  24. senoquant/tabs/quantification/features/marker/rows.py +654 -0
  25. senoquant/tabs/quantification/features/marker/thresholding.py +46 -0
  26. senoquant/tabs/quantification/features/roi.py +346 -0
  27. senoquant/tabs/quantification/features/spots/__init__.py +5 -0
  28. senoquant/tabs/quantification/features/spots/config.py +62 -0
  29. senoquant/tabs/quantification/features/spots/dialog.py +477 -0
  30. senoquant/tabs/quantification/features/spots/export.py +1292 -0
  31. senoquant/tabs/quantification/features/spots/feature.py +112 -0
  32. senoquant/tabs/quantification/features/spots/morphology.py +279 -0
  33. senoquant/tabs/quantification/features/spots/rows.py +241 -0
  34. senoquant/tabs/quantification/frontend.py +815 -0
  35. senoquant/tabs/segmentation/__init__.py +1 -0
  36. senoquant/tabs/segmentation/backend.py +131 -0
  37. senoquant/tabs/segmentation/frontend.py +1009 -0
  38. senoquant/tabs/segmentation/models/__init__.py +5 -0
  39. senoquant/tabs/segmentation/models/base.py +146 -0
  40. senoquant/tabs/segmentation/models/cpsam/details.json +65 -0
  41. senoquant/tabs/segmentation/models/cpsam/model.py +150 -0
  42. senoquant/tabs/segmentation/models/default_2d/details.json +69 -0
  43. senoquant/tabs/segmentation/models/default_2d/model.py +664 -0
  44. senoquant/tabs/segmentation/models/default_3d/details.json +69 -0
  45. senoquant/tabs/segmentation/models/default_3d/model.py +682 -0
  46. senoquant/tabs/segmentation/models/hf.py +71 -0
  47. senoquant/tabs/segmentation/models/nuclear_dilation/__init__.py +1 -0
  48. senoquant/tabs/segmentation/models/nuclear_dilation/details.json +26 -0
  49. senoquant/tabs/segmentation/models/nuclear_dilation/model.py +96 -0
  50. senoquant/tabs/segmentation/models/perinuclear_rings/__init__.py +1 -0
  51. senoquant/tabs/segmentation/models/perinuclear_rings/details.json +34 -0
  52. senoquant/tabs/segmentation/models/perinuclear_rings/model.py +132 -0
  53. senoquant/tabs/segmentation/stardist_onnx_utils/__init__.py +2 -0
  54. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/__init__.py +3 -0
  55. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/__init__.py +6 -0
  56. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/generate.py +470 -0
  57. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/prepare.py +273 -0
  58. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/rawdata.py +112 -0
  59. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/transform.py +384 -0
  60. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/__init__.py +0 -0
  61. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/blocks.py +184 -0
  62. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/losses.py +79 -0
  63. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/nets.py +165 -0
  64. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/predict.py +467 -0
  65. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/probability.py +67 -0
  66. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/train.py +148 -0
  67. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/io/__init__.py +163 -0
  68. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/__init__.py +52 -0
  69. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/base_model.py +329 -0
  70. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_isotropic.py +160 -0
  71. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_projection.py +178 -0
  72. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_standard.py +446 -0
  73. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_upsampling.py +54 -0
  74. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/config.py +254 -0
  75. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/pretrained.py +119 -0
  76. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/scripts/__init__.py +0 -0
  77. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/scripts/care_predict.py +180 -0
  78. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/__init__.py +5 -0
  79. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/plot_utils.py +159 -0
  80. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/six.py +18 -0
  81. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/tf.py +644 -0
  82. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/utils.py +272 -0
  83. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/version.py +1 -0
  84. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/docs/source/conf.py +368 -0
  85. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/setup.py +68 -0
  86. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_datagen.py +169 -0
  87. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_models.py +462 -0
  88. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_utils.py +166 -0
  89. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tools/create_zip_contents.py +34 -0
  90. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/__init__.py +30 -0
  91. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/big.py +624 -0
  92. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/bioimageio_utils.py +494 -0
  93. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/data/__init__.py +39 -0
  94. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/__init__.py +10 -0
  95. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/geom2d.py +215 -0
  96. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/geom3d.py +349 -0
  97. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/matching.py +483 -0
  98. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/__init__.py +28 -0
  99. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/base.py +1217 -0
  100. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/model2d.py +594 -0
  101. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/model3d.py +696 -0
  102. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/nms.py +384 -0
  103. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/__init__.py +2 -0
  104. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/plot.py +74 -0
  105. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/render.py +298 -0
  106. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/rays3d.py +373 -0
  107. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/sample_patches.py +65 -0
  108. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/__init__.py +0 -0
  109. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/predict2d.py +90 -0
  110. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/predict3d.py +93 -0
  111. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/utils.py +408 -0
  112. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/version.py +1 -0
  113. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/__init__.py +45 -0
  114. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/__init__.py +17 -0
  115. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/cli.py +55 -0
  116. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/core.py +285 -0
  117. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/__init__.py +15 -0
  118. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/cli.py +36 -0
  119. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/divisibility.py +193 -0
  120. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/probe.py +100 -0
  121. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/receptive_field.py +182 -0
  122. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/rf_cli.py +48 -0
  123. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/valid_sizes.py +278 -0
  124. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/post/__init__.py +8 -0
  125. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/post/core.py +157 -0
  126. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/pre/__init__.py +17 -0
  127. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/pre/core.py +226 -0
  128. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/predict/__init__.py +5 -0
  129. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/predict/core.py +401 -0
  130. senoquant/tabs/settings/__init__.py +1 -0
  131. senoquant/tabs/settings/backend.py +29 -0
  132. senoquant/tabs/settings/frontend.py +19 -0
  133. senoquant/tabs/spots/__init__.py +1 -0
  134. senoquant/tabs/spots/backend.py +139 -0
  135. senoquant/tabs/spots/frontend.py +800 -0
  136. senoquant/tabs/spots/models/__init__.py +5 -0
  137. senoquant/tabs/spots/models/base.py +94 -0
  138. senoquant/tabs/spots/models/rmp/details.json +61 -0
  139. senoquant/tabs/spots/models/rmp/model.py +499 -0
  140. senoquant/tabs/spots/models/udwt/details.json +103 -0
  141. senoquant/tabs/spots/models/udwt/model.py +482 -0
  142. senoquant/utils.py +25 -0
  143. senoquant-1.0.0b1.dist-info/METADATA +193 -0
  144. senoquant-1.0.0b1.dist-info/RECORD +148 -0
  145. senoquant-1.0.0b1.dist-info/WHEEL +5 -0
  146. senoquant-1.0.0b1.dist-info/entry_points.txt +2 -0
  147. senoquant-1.0.0b1.dist-info/licenses/LICENSE +28 -0
  148. senoquant-1.0.0b1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,148 @@
1
+ from __future__ import print_function, unicode_literals, absolute_import, division
2
+ from six.moves import range, zip, map, reduce, filter
3
+
4
+ from ..utils import _raise, move_channel_for_backend, axes_dict, axes_check_and_normalize, backend_channels_last
5
+ from ..internals.losses import loss_laplace, loss_mse, loss_mae, loss_thresh_weighted_decay
6
+
7
+ import numpy as np
8
+
9
+ from ..utils.tf import keras_import, BACKEND as K
10
# Keras classes are resolved through ``keras_import`` (from ..utils.tf) instead of
# being imported directly -- presumably so that the Keras flavour selected there
# (e.g. tf.keras vs. standalone keras) is used consistently; see keras_import to confirm.
Callback, TerminateOnNaN = keras_import('callbacks', 'Callback', 'TerminateOnNaN')
Sequence = keras_import('utils', 'Sequence')
Optimizer = keras_import('optimizers', 'Optimizer')
13
+
14
+
15
class ParameterDecayCallback(Callback):
    """Keras callback that shrinks a backend variable at the end of every epoch.

    After epoch ``epoch`` (0-based) the current value ``v`` of ``parameter`` is
    replaced by ``v / (1 + decay * (epoch + 1))``. The factor is applied to the
    already-decayed value, so the decay compounds from epoch to epoch.

    Parameters
    ----------
    parameter : backend variable
        Variable read and written via ``K.get_value`` / ``K.set_value``.
    decay : float
        Decay rate.
    name : str or None
        If given, the value *before* each update is stored in the epoch ``logs``
        under this key; it is also used in the verbose message.
    verbose : int or bool
        If truthy, print the new value after every update.
    """

    def __init__(self, parameter, decay, name=None, verbose=0):
        # Run the Keras base-class initialiser so that whatever state it sets up
        # exists before training starts (it was previously skipped).
        super(ParameterDecayCallback, self).__init__()
        self.parameter = parameter
        self.decay = decay
        self.name = name
        self.verbose = verbose

    def on_epoch_end(self, epoch, logs=None):
        old_val = K.get_value(self.parameter)
        # Record the pre-update value in the logs dict passed by Keras.
        # (``logs = logs or {}`` would replace an *empty* dict with a throw-away
        # one and silently drop the entry, so test for None explicitly.)
        if self.name and logs is not None:
            logs[self.name] = old_val
        new_val = old_val * (1. / (1. + self.decay * (epoch + 1)))
        K.set_value(self.parameter, new_val)
        if self.verbose:
            print("\n[ParameterDecayCallback] new %s: %s\n" % (self.name if self.name else 'parameter', new_val))
32
+
33
+
34
def _loss_by_name(name, **kwargs):
    """Return ``loss_<name>(**kwargs)`` for one of the supported loss names.

    Replaces a former ``eval('loss_%s()' % name)``, which executed a string built
    from caller-supplied arguments.

    Raises
    ------
    ValueError
        If ``name`` is not one of ``'laplace'``, ``'mae'``, ``'mse'``.
    """
    factories = {'laplace': loss_laplace, 'mse': loss_mse, 'mae': loss_mae}
    try:
        factory = factories[name]
    except KeyError:
        raise ValueError("Unknown loss '%s', must be one of: %s" % (name, ', '.join(sorted(factories))))
    return factory(**kwargs)


def prepare_model(model, optimizer, loss, metrics=('mse','mae'),
                  loss_bg_thresh=0, loss_bg_decay=0.06, Y=None):
    """Compile ``model`` with the requested loss and metrics.

    Parameters
    ----------
    model : Keras model
        Model to compile; ``model.compile`` is called on it.
    optimizer : Keras Optimizer
        Must be an instance of the Keras ``Optimizer`` class.
    loss : str
        Name of the training loss: ``'laplace'``, ``'mse'`` or ``'mae'``.
    metrics : iterable of str
        Names of losses (same choices as ``loss``) to report as metrics.
    loss_bg_thresh : float
        Value in ``[0, 1]``. If ``0``, the plain ``loss`` is used. Otherwise a
        per-pixel loss is weighted via ``loss_thresh_weighted_decay`` using the
        fraction of ``Y`` values above this threshold, together with a variable
        ``alpha`` that is decayed every epoch (see that function for details).
    loss_bg_decay : float
        Decay rate for ``alpha`` (only used if ``loss_bg_thresh > 0``).
    Y : numpy.ndarray or None
        Training targets; required if ``loss_bg_thresh > 0``.

    Returns
    -------
    list
        Keras callbacks to pass to training; always contains ``TerminateOnNaN``
        and, if ``loss_bg_thresh > 0``, the decay callback for ``alpha``.

    Raises
    ------
    ValueError
        If ``optimizer`` is not a Keras Optimizer, or ``loss``/``metrics``
        contains an unsupported loss name.
    """

    isinstance(optimizer,Optimizer) or _raise(ValueError(
        "optimizer must be a Keras Optimizer instance, got '%s'." % type(optimizer).__name__))

    loss_standard = _loss_by_name(loss)
    _metrics = [_loss_by_name(m) for m in metrics]
    callbacks = [TerminateOnNaN()]

    # checks
    assert 0 <= loss_bg_thresh <= 1
    assert loss_bg_thresh == 0 or Y is not None
    if loss == 'laplace':
        assert K.image_data_format() == "channels_last", "laplace loss requires 'channels_last' image data format"
        # the last (channel) axis must hold an even number (>= 2) of outputs
        assert list(model.output.shape)[-1] >= 2 and list(model.output.shape)[-1] % 2 == 0

    # loss
    if loss_bg_thresh == 0:
        _loss = loss_standard
    else:
        # fraction of target values above the threshold
        freq = np.mean(Y > loss_bg_thresh)
        alpha = K.variable(1.0)
        loss_per_pixel = _loss_by_name(loss, mean=False)
        _loss = loss_thresh_weighted_decay(loss_per_pixel, loss_bg_thresh,
                                           0.5 / (0.1 + (1 - freq)),
                                           0.5 / (0.1 + freq),
                                           alpha)
        callbacks.append(ParameterDecayCallback(alpha, loss_bg_decay, name='alpha'))
        # the training loss is now the weighted one: also report the standard loss
        if loss not in metrics:
            _metrics.append(loss_standard)

    # compile model
    model.compile(optimizer=optimizer, loss=_loss, metrics=_metrics)

    return callbacks
73
+
74
+
75
class RollingSequence(Sequence):
    """Keras ``Sequence`` that yields batches of data indices from an endless stream.

    The stream is built by concatenating one index array per pass over the data
    (a random permutation of ``range(data_size)`` if ``shuffle=True``, otherwise
    ``np.arange(data_size)``). Batch ``i`` is the slice
    ``[i*batch_size, (i+1)*batch_size)`` of that stream. Consumed sequentially,
    every data index therefore appears equally often, and ``batch_size`` may
    exceed ``data_size``.

    Each pass's index array is generated on first use and cached in
    ``index_map`` (never evicted), so repeated calls with the same ``i`` return
    the same batch. ``length`` only controls ``len()`` and iteration.
    """

    def __init__(self, data_size, batch_size, length=None, shuffle=True, rng=None, keras_kwargs=None):
        super(RollingSequence, self).__init__(**(keras_kwargs if keras_kwargs is not None else {}))
        rng = np.random if rng is None else rng
        self.data_size = int(data_size)
        self.batch_size = int(batch_size)
        # 2**63-1 is the largest value ``len()`` can report (sys.maxsize on 64-bit builds)
        self.length = int(length) if length is not None else 2**63-1
        self.shuffle = bool(shuffle)
        self.index_gen = np.arange if not self.shuffle else rng.permutation
        self.index_map = {}

    def __len__(self):
        return self.length

    def _index(self, loop):
        """Index array for pass ``loop``: generated once on demand, then cached."""
        try:
            return self.index_map[loop]
        except KeyError:
            fresh = self.index_gen(self.data_size)
            self.index_map[loop] = fresh
            return fresh

    def on_epoch_end(self):
        # Batches are addressed by absolute batch number; nothing to reset.
        pass

    def __iter__(self):
        for k in range(len(self)):
            yield self[k]

    def batch(self, i):
        loop, offset = divmod(i * self.batch_size, self.data_size)
        stop = offset + self.batch_size
        index = self._index(loop)
        next_loop = loop
        # The batch may run past the end of this pass: append the following passes.
        while len(index) < stop:
            next_loop += 1
            index = np.concatenate((index, self._index(next_loop)))
        return index[offset:stop]

    def __getitem__(self, i):
        return self.batch(i)
131
+
132
+
133
class DataWrapper(RollingSequence):
    """Keras ``Sequence`` of shuffled ``(X, Y)`` training batches.

    Parameters
    ----------
    X, Y : array-like
        Inputs and targets of equal length; indexed with integer index arrays
        (e.g. numpy arrays).
    batch_size : int
        Number of samples per batch.
    length : int or None
        Value reported by ``len()`` (see RollingSequence).
    augmenter : callable or None
        If given, called as ``augmenter(x, y)`` for every sample of a batch;
        must return the augmented ``(x, y)`` pair.
    keras_kwargs : dict or None
        Forwarded to the Keras ``Sequence`` constructor.
    rng : object or None
        Source of the shuffling permutations (must provide ``permutation(n)``),
        e.g. ``np.random.RandomState(seed)`` for a reproducible batch order.
        Defaults to the global ``np.random``.

    Raises
    ------
    ValueError
        If ``X`` and ``Y`` differ in length.
    """

    def __init__(self, X, Y, batch_size, length, augmenter=None, keras_kwargs=None, rng=None):
        # validate before doing any base-class setup
        len(X) == len(Y) or _raise(ValueError("X and Y must have same length"))
        super(DataWrapper, self).__init__(data_size=len(X), batch_size=batch_size, length=length,
                                          shuffle=True, rng=rng, keras_kwargs=keras_kwargs)
        self.X, self.Y = X, Y
        self.augmenter = augmenter

    def __getitem__(self, i):
        idx = self.batch(i)
        X, Y = self.X[idx], self.Y[idx]
        if self.augmenter is not None:
            # augment sample-wise, then re-assemble the batch arrays
            X, Y = tuple(zip(*tuple(self.augmenter(x, y) for x, y in zip(X, Y))))
            X, Y = np.stack(X), np.stack(Y)
        return X, Y
@@ -0,0 +1,163 @@
1
+ # -*- coding: utf-8 -*-
2
+ from __future__ import print_function, unicode_literals, absolute_import, division
3
+ from six.moves import range, zip, map, reduce, filter
4
+ from six import string_types
5
+
6
+ import numpy as np
7
+ try:
8
+ from tifffile import imwrite as imsave
9
+ except ImportError:
10
+ from tifffile import imsave
11
+ import warnings
12
+
13
+ from ..utils import _raise, axes_check_and_normalize, axes_dict, move_image_axes, move_channel_for_backend, backend_channels_last
14
+ from ..utils.six import Path
15
+
16
+
17
+
18
def save_tiff_imagej_compatible(file, img, axes, **imsave_kwargs):
    """Save image in ImageJ-compatible TIFF format.

    The data type is converted to one that ImageJ supports; a warning is issued
    whenever it changes:

    - boolean          -> ``uint8`` (values 0/1)
    - floating point   -> ``float32``
    - unsigned integer -> ``uint8`` (1 byte) or ``uint16`` (2 or more bytes)
    - signed integer   -> ``int16``

    Narrowing integer conversions (e.g. ``uint32`` -> ``uint16``) are done with
    ``astype`` and therefore do not clip out-of-range values. The axes are
    reordered to ``TZCYX`` before writing.

    Parameters
    ----------
    file : str
        File name
    img : numpy.ndarray
        Image
    axes: str
        Axes of ``img`` (the ``S`` axis is not allowed).
    imsave_kwargs : dict, optional
        Keyword arguments for :func:`tifffile.imsave`; ``imagej`` is always set to ``True``.

    """
    axes = axes_check_and_normalize(axes,img.ndim,disallowed='S')

    # convert to imagej-compatible data type
    t = img.dtype
    if t.kind == 'b':
        # tifffile's ImageJ mode does not accept boolean data -> store as 0/1 bytes
        t_new = np.uint8
    elif 'float' in t.name:
        t_new = np.float32
    elif 'uint' in t.name:
        t_new = np.uint16 if t.itemsize >= 2 else np.uint8
    elif 'int' in t.name:
        t_new = np.int16
    else:
        t_new = t
    img = img.astype(t_new, copy=False)
    if t != t_new:
        warnings.warn("Converting data type from '%s' to ImageJ-compatible '%s'." % (t, np.dtype(t_new)))

    # move axes to correct positions for imagej
    img = move_image_axes(img, axes, 'TZCYX', True)

    imsave_kwargs['imagej'] = True
    imsave(file, img, **imsave_kwargs)
50
+
51
+
52
+
53
def load_training_data(file, validation_split=0, axes=None, n_images=None, verbose=False):
    """Load training data from file in ``.npz`` format.

    The data file is expected to have the keys:

    - ``X`` : Array of training input images.
    - ``Y`` : Array of corresponding target images.
    - ``axes`` : Axes of the training images.

    The first ``n_images`` images are used; if ``validation_split > 0`` the last
    ``round(n_images * validation_split)`` of them form the validation set. The
    channel axis of all returned arrays is moved to the position expected by
    the Keras backend, and the returned axes string is adjusted accordingly.

    Parameters
    ----------
    file : str
        File name
    validation_split : float
        Fraction of images to use as validation set during training.
    axes: str, optional
        Must be provided in case the loaded data does not contain ``axes`` information.
    n_images : int, optional
        Can be used to limit the number of images loaded from data.
    verbose : bool, optional
        Can be used to display information about the loaded images.

    Returns
    -------
    tuple( tuple(:class:`numpy.ndarray`, :class:`numpy.ndarray`), tuple(:class:`numpy.ndarray`, :class:`numpy.ndarray`), str )
        Returns two tuples (`X_train`, `Y_train`), (`X_val`, `Y_val`) of training and validation sets
        and the axes of the input images.
        The tuple of validation data will be ``None`` if ``validation_split = 0``.

    """

    # An NpzFile keeps the underlying file open until closed: read everything
    # needed inside the context manager so the handle is always released.
    with np.load(file) as f:
        X, Y = f['X'], f['Y']
        if axes is None:
            axes = f['axes']
    axes = axes_check_and_normalize(axes)

    # assert X.shape == Y.shape
    assert X.ndim == Y.ndim
    assert len(axes) == X.ndim
    assert 'C' in axes
    if n_images is None:
        n_images = X.shape[0]
    assert X.shape[0] == Y.shape[0]
    assert 0 < n_images <= X.shape[0]
    assert 0 <= validation_split < 1

    X, Y = X[:n_images], Y[:n_images]
    channel = axes_dict(axes)['C']

    if validation_split > 0:
        n_val   = int(round(n_images * validation_split))
        n_train = n_images - n_val
        assert 0 < n_val and 0 < n_train
        X_t, Y_t = X[-n_val:],  Y[-n_val:]
        X,   Y   = X[:n_train], Y[:n_train]
        assert X.shape[0] == n_train and X_t.shape[0] == n_val
        X_t = move_channel_for_backend(X_t,channel=channel)
        Y_t = move_channel_for_backend(Y_t,channel=channel)

    X = move_channel_for_backend(X,channel=channel)
    Y = move_channel_for_backend(Y,channel=channel)

    axes = axes.replace('C','') # remove channel
    if backend_channels_last():
        axes = axes+'C'
    else:
        axes = axes[:1]+'C'+axes[1:]

    data_val = (X_t,Y_t) if validation_split > 0 else None

    if verbose:
        ax = axes_dict(axes)
        n_train, n_val = len(X), (len(X_t) if validation_split > 0 else 0)
        image_size = tuple( X.shape[ax[a]] for a in axes if a in 'TZYX' )
        n_dim = len(image_size)
        n_channel_in, n_channel_out = X.shape[ax['C']], Y.shape[ax['C']]

        print('number of training images:\t', n_train)
        print('number of validation images:\t', n_val)
        print('image size (%dD):\t\t'%n_dim, image_size)
        print('axes:\t\t\t\t', axes)
        print('channels in / out:\t\t', n_channel_in, '/', n_channel_out)

    return (X,Y), data_val, axes
139
+
140
+
141
+
142
def save_training_data(file, X, Y, axes):
    """Save training data in ``.npz`` format.

    Any suffix of ``file`` is replaced by ``.npz``, and missing parent
    directories are created.

    Parameters
    ----------
    file : str
        File name
    X : :class:`numpy.ndarray`
        Array of patches extracted from source images.
    Y : :class:`numpy.ndarray`
        Array of corresponding target patches.
    axes : str
        Axes of the extracted patches.

    Raises
    ------
    ValueError
        If ``file`` is not a str/Path, if ``axes`` does not match ``X.ndim``, or
        if ``X`` and ``Y`` differ in number of dimensions or number of samples
        (``load_training_data`` would reject such data).

    """
    isinstance(file,(Path,string_types)) or _raise(ValueError(
        "file must be a str or Path, got '%s'." % type(file).__name__))

    # validate everything before touching the file system
    axes = axes_check_and_normalize(axes)
    len(axes) == X.ndim or _raise(ValueError(
        "axes '%s' do not match X.ndim=%d." % (axes, X.ndim)))
    X.ndim == Y.ndim or _raise(ValueError(
        "X and Y must have the same number of dimensions, got %d and %d." % (X.ndim, Y.ndim)))
    X.shape[0] == Y.shape[0] or _raise(ValueError(
        "X and Y must contain the same number of samples, got %d and %d." % (X.shape[0], Y.shape[0])))

    file = Path(file).with_suffix('.npz')
    file.parent.mkdir(parents=True,exist_ok=True)
    np.savez(str(file), X=X, Y=Y, axes=axes)
@@ -0,0 +1,52 @@
1
+ from __future__ import absolute_import, print_function
2
+
3
+ # checks
4
# Fail early with an actionable message if TensorFlow is not installed at all.
try:
    import tensorflow
    del tensorflow
except ModuleNotFoundError as e:
    from six import raise_from
    raise_from(RuntimeError('Please install TensorFlow: https://www.tensorflow.org/install/'), e)


import tensorflow, sys
from ..utils.tf import keras_import, IS_TF_1, BACKEND as K

if IS_TF_1:
    # Under TensorFlow 1.x the standalone ``keras`` package is imported here. A
    # ModuleNotFoundError naming 'theano' or 'cntk' is turned into a RuntimeError
    # advising to switch the Keras backend to 'tensorflow'; any other missing
    # module is re-raised unchanged.
    try:
        import keras
        del keras
    except ModuleNotFoundError as e:
        if e.name in {'theano','cntk'}:
            from six import raise_from
            raise_from(RuntimeError(
                "Keras is configured to use the '%s' backend, which is not installed. "
                "Please change it to use 'tensorflow' instead: "
                "https://keras.io/getting-started/faq/#where-is-the-keras-configuration-file-stored" % e.name
            ), e)
        else:
            raise e

# Only the TensorFlow backend is supported by the models in this package.
if K.backend() != 'tensorflow':
    raise NotImplementedError(
        "Keras is configured to use the '%s' backend, which is currently not supported. "
        "Please configure Keras to use 'tensorflow' instead." % K.backend()
    )

# Only the 'channels_last' image data format is supported.
if K.image_data_format() != 'channels_last':
    raise NotImplementedError(
        "Keras is configured to use the '%s' image data format, which is currently not supported. "
        "Please change it to use 'channels_last' instead: "
        "https://keras.io/getting-started/faq/#where-is-the-keras-configuration-file-stored" % K.image_data_format()
    )
# Drop the helper names so they do not remain in this package's namespace.
del tensorflow, sys, keras_import, IS_TF_1, K
43
+
44
+
45
+ # imports
46
+ from .config import BaseConfig, Config
47
+ from .base_model import BaseModel
48
+ from .care_standard import CARE
49
+ from .care_upsampling import UpsamplingCARE
50
+ from .care_isotropic import IsotropicCARE
51
+ from .care_projection import ProjectionConfig, ProjectionCARE
52
+ from .pretrained import register_model, register_aliases, clear_models_and_aliases
@@ -0,0 +1,329 @@
1
+ # -*- coding: utf-8 -*-
2
+ from __future__ import print_function, unicode_literals, absolute_import, division
3
+
4
+ import datetime
5
+ import warnings
6
+ import sys
7
+
8
+ import numpy as np
9
+ from six import string_types, PY2
10
+ from functools import wraps
11
+
12
+ from .config import BaseConfig
13
+ from ..utils import _raise, load_json, save_json, axes_check_and_normalize, axes_dict, move_image_axes
14
+ from ..utils.six import Path, FileNotFoundError
15
+ from ..utils.tf import keras_import, IS_KERAS_3_PLUS
16
+ from ..data import Normalizer, NoNormalizer
17
+ from ..data import Resizer, NoResizer
18
+ from .pretrained import get_model_details, get_model_instance, get_registered_models
19
+
20
+ from six import add_metaclass
21
+ from abc import ABCMeta, abstractmethod, abstractproperty
22
+
23
+
24
+
25
def suppress_without_basedir(warn):
    """Build a decorator that skips a method when ``self.basedir`` is None.

    The decorated method is called normally if its first positional argument
    (the instance) has a non-None ``basedir``. Otherwise the method is not run
    and ``None`` is returned; a warning naming the method is emitted unless
    ``warn`` is exactly ``False``.
    """
    def decorate(func):
        @wraps(func)
        def guarded(*args, **kwargs):
            if args[0].basedir is not None:
                return func(*args, **kwargs)
            if warn is not False:
                warnings.warn("Suppressing call of '%s' (due to basedir=None)." % func.__name__)
            return None
        return guarded
    return decorate
36
+
37
+
38
+
39
if IS_KERAS_3_PLUS:
    # Helpers to write weights in the legacy HDF5 layout under Keras 3.
    import h5py
    from packaging.version import Version
    from keras.src.legacy.saving import legacy_h5_format
    from keras.src import backend, __version__ as keras_version

    def save_subset_weights_to_hdf5_group(f, weights):
        """Write ``weights`` as datasets into HDF5 group ``f``, plus a ``weight_names`` attribute."""
        # FIX: use w.path instead of w.name to avoid name collisions (for "functional" layers)
        # -> has been fixed since keras 3.3.0: https://github.com/keras-team/keras/blob/v3.3.0/keras/src/legacy/saving/legacy_h5_format.py#L234
        weight_names = [w.path.encode("utf8") for w in weights]
        weight_values = [backend.convert_to_numpy(w) for w in weights]
        legacy_h5_format.save_attributes_to_hdf5_group(f, "weight_names", weight_names)
        for name, val in zip(weight_names, weight_values):
            param_dset = f.create_dataset(name, val.shape, dtype=val.dtype)
            # scalars cannot be assigned through a slice
            param_dset[() if not val.shape else slice(None)] = val

    if Version(keras_version) >= Version("3.3.0"):
        save_weights_to_hdf5_group = legacy_h5_format.save_weights_to_hdf5_group
    else:
        def save_weights_to_hdf5_group(f, model):
            """Write all weights of ``model`` into the open HDF5 file/group ``f`` (legacy layout)."""
            legacy_h5_format.save_attributes_to_hdf5_group(f, "layer_names", [layer.name.encode("utf8") for layer in model.layers])
            f.attrs["backend"] = backend.backend().encode("utf8")
            f.attrs["keras_version"] = str(keras_version).encode("utf8")
            # one group per layer, in name order
            for layer in sorted(model.layers, key=lambda x: x.name):
                g = f.create_group(layer.name)
                weights = legacy_h5_format._legacy_weights(layer)
                save_subset_weights_to_hdf5_group(g, weights)
            g = f.create_group("top_level_model_weights")
            weights = [v for v in model._trainable_variables + model._non_trainable_variables if v in model.weights]
            save_subset_weights_to_hdf5_group(g, weights)

    def _keras3_monkey_patch_legacy_weights(model):
        """Replace ``model.save_weights`` on this instance only.

        Filenames ending in ``.weights.h5`` are delegated to the original Keras 3
        method (with a warning that the result cannot be loaded by Keras 2.x);
        all other filenames are written with ``save_weights_to_hdf5_group``.
        """
        ref_save_weights = model.save_weights

        def save_weights(self, filepath, overwrite=True):
            p = Path(filepath)
            if not overwrite and p.exists():
                raise FileExistsError(f"Weights file already exists: {str(p.resolve())}")
            if p.name.endswith(".weights.h5"):
                warnings.warn("Detected filename suffix '.weights.h5', thus saving in newer Keras 3.x file format (cannot be loaded with Keras 2.x)")
                return ref_save_weights(filepath, overwrite=overwrite)
            with h5py.File(str(p), "w") as f:
                save_weights_to_hdf5_group(f, self)

        # bind as a method of this particular model instance
        model.save_weights = save_weights.__get__(model)
86
+
87
+
88
+
89
+ @add_metaclass(ABCMeta)
90
+ class BaseModel(object):
91
+ """Base model.
92
+
93
+ Subclasses must implement :func:`_build` and :func:`_config_class`.
94
+
95
+ Parameters
96
+ ----------
97
+ config : Subclass of :class:`csbdeep.models.BaseConfig` or None
98
+ Valid configuration of a model (see :func:`BaseConfig.is_valid`).
99
+ Will be saved to disk as JSON (``config.json``).
100
+ If set to ``None``, will be loaded from disk (must exist).
101
+ name : str or None
102
+ Model name. Uses a timestamp if set to ``None`` (default).
103
+ basedir : str
104
+ Directory that contains (or will contain) a folder with the given model name.
105
+ Use ``None`` to disable saving (or loading) any data to (or from) disk (regardless of other parameters).
106
+
107
+ Raises
108
+ ------
109
+ FileNotFoundError
110
+ If ``config=None`` and config cannot be loaded from disk.
111
+ ValueError
112
+ Illegal arguments, including invalid configuration.
113
+
114
+ Attributes
115
+ ----------
116
+ config : :class:`csbdeep.models.BaseConfig`
117
+ Configuration of the model, as provided during instantiation.
118
+ keras_model : `Keras model <https://keras.io/getting-started/functional-api-guide/>`_
119
+ Keras neural network model.
120
+ name : str
121
+ Model name.
122
+ logdir : :class:`pathlib.Path`
123
+ Path to model folder (which stores configuration, weights, etc.)
124
+ """
125
+
126
+ @classmethod
127
+ def from_pretrained(cls, name_or_alias=None):
128
+ try:
129
+ get_model_details(cls, name_or_alias, verbose=True)
130
+ return get_model_instance(cls, name_or_alias)
131
+ except ValueError as e:
132
+ if name_or_alias is not None:
133
+ print("Could not find model with name or alias '%s'" % (name_or_alias), file=sys.stderr)
134
+ sys.stderr.flush()
135
+ get_registered_models(cls, verbose=True)
136
+
137
+
138
    def __init__(self, config, name=None, basedir='.'):
        """See class docstring.

        Order of operations: validate arguments, set name/basedir, call
        ``_update_and_check_config`` (before the config is written when it was
        passed in), ``_set_logdir`` (saves or loads ``config.json``; skipped if
        ``basedir`` is None), ``_update_and_check_config`` again if the config
        was loaded from disk, build the Keras model, and finally load weights
        from disk if no config was passed.
        """

        # config must be None or an instance of the subclass's config class ...
        config is None or isinstance(config,self._config_class) or _raise (
            ValueError("Invalid configuration of type '%s', was expecting type '%s'." % (type(config).__name__, self._config_class.__name__))
        )
        # ... and, if given, must pass its own validity check
        if config is not None and not config.is_valid():
            invalid_attr = config.is_valid(True)[1]
            raise ValueError('Invalid configuration attributes: ' + ', '.join(invalid_attr))
        # without a config it has to be loaded from disk, which requires a basedir
        (not (config is None and basedir is None)) or _raise(ValueError("No config provided and cannot be loaded from disk since basedir=None."))

        name is None or (isinstance(name,string_types) and len(name)>0) or _raise(ValueError("No valid name: '%s'" % str(name)))
        basedir is None or isinstance(basedir,(string_types,Path)) or _raise(ValueError("No valid basedir: '%s'" % str(basedir)))
        self.config = config
        # default name: current local timestamp with microseconds
        self.name = name if name is not None else datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S.%f")
        self.basedir = Path(basedir) if basedir is not None else None
        if config is not None:
            # config was provided -> update before it is saved to disk
            self._update_and_check_config()
        # decorated with suppress_without_basedir: no-op (and no self.logdir) if basedir is None
        self._set_logdir()
        if config is None:
            # config was loaded from disk -> update it after loading
            self._update_and_check_config()
        self._model_prepared = False
        self.keras_model = self._build()
        if IS_KERAS_3_PLUS and isinstance(self.keras_model, keras_import('models', 'Model')):
            # monkey-patch keras model to save weights in legacy format if suffix is not '.weights.h5'
            _keras3_monkey_patch_legacy_weights(self.keras_model)
        if config is None:
            # existing model on disk: load its weights (also a no-op if basedir is None)
            self._find_and_load_weights()
168
+
169
+
170
def __repr__(self):
    """Return a multi-line summary: class/name, axes mapping, directory, config."""
    header = "{}({}): {} → {}\n".format(self.__class__.__name__, self.name, self.config.axes, self._axes_out)
    directory = self.logdir.resolve() if self.basedir is not None else None
    pieces = [
        header,
        "├─ Directory: {}\n".format(directory),
        self._repr_extra(),
        "└─ {}".format(self.config),
    ]
    text = "".join(pieces)
    # Python 2 requires __repr__ to return a byte string.
    if PY2:
        return text.encode('utf-8')
    return text
176
+
177
+
178
def _repr_extra(self):
    """Return extra text that ``__repr__`` inserts before the config line.

    The base implementation contributes nothing. Subclasses may override it
    to add lines.
    """
    return ""
180
+
181
+
182
def _update_and_check_config(self):
    """Hook for subclasses to adjust or validate ``self.config``.

    The base implementation does nothing.
    """
    return None
184
+
185
+
186
@suppress_without_basedir(warn=False)
def _set_logdir(self):
    """Set ``self.logdir`` and synchronise the config with ``config.json`` there.

    If ``self.config`` is set, the model folder is created if needed and the
    config is written to ``logdir/config.json``. A warning is issued when the
    folder already exists. If ``self.config`` is ``None``, the config is
    loaded from that file instead.

    NOTE(review): presumably skipped silently when ``self.basedir`` is None,
    via ``suppress_without_basedir`` (not visible here); confirm.

    Raises
    ------
    FileNotFoundError
        If a config must be loaded but ``config.json`` does not exist.
    ValueError
        If the loaded config reports invalid attributes.
    """
    self.logdir = self.basedir / self.name
    config_file = self.logdir / 'config.json'

    if self.config is not None:
        # Save path: persist the given config, warning about a possible overwrite.
        if self.logdir.exists():
            warnings.warn('output path for model already exists, files may be overwritten: %s' % str(self.logdir.resolve()))
        self.logdir.mkdir(parents=True, exist_ok=True)
        save_json(vars(self.config), str(config_file))
        return

    # Load path: the config must come from disk.
    if not config_file.exists():
        raise FileNotFoundError("config file doesn't exist: %s" % str(config_file.resolve()))
    raw_config = load_json(str(config_file))
    config_dict = self._config_class.update_loaded_config(raw_config)
    self.config = self._config_class(**config_dict)
    if not self.config.is_valid():
        invalid_attr = self.config.is_valid(True)[1]
        raise ValueError('Invalid attributes in loaded config: ' + ', '.join(invalid_attr))
206
+
207
+
208
@suppress_without_basedir(warn=False)
def _find_and_load_weights(self,prefer='best'):
    """Load the most suitable ``*.h5`` / ``*.hdf5`` weight file from ``self.logdir``.

    Candidate files are ordered newest-first by modification time. The first
    one whose file name contains ``prefer`` is chosen; if none matches, the
    newest file overall is used. If there are no weight files at all, a
    warning is issued and nothing is loaded.

    Parameters
    ----------
    prefer : str
        Substring to look for in the weight file names.
    """
    from itertools import chain
    weights_ext = ('*.h5','*.hdf5')
    candidates = list(chain.from_iterable(self.logdir.glob(pattern) for pattern in weights_ext))
    # Ascending sort then slice-reverse gives newest-first, with the same
    # tie order as reversing the sorted sequence.
    candidates = sorted(candidates, key=lambda f: f.stat().st_mtime)[::-1]
    if not candidates:
        warnings.warn("Couldn't find any network weights (%s) to load." % ', '.join(weights_ext))
        return
    weights_chosen = next((f for f in candidates if prefer in f.name), candidates[0])
    print("Loading network weights from '%s'." % weights_chosen.name)
    self.load_weights(weights_chosen.name)
223
+
224
+
225
@abstractmethod
def _build(self):
    """Construct the underlying Keras model and return it.

    Must be implemented by concrete subclasses. ``__init__`` stores the
    result in ``self.keras_model``.
    """
228
+
229
+
230
@suppress_without_basedir(warn=True)
def load_weights(self, name='weights_best.h5'):
    """Load network weights from a file inside the model folder.

    Parameters
    ----------
    name : str
        File name, relative to ``self.logdir``, of the HDF5 weights
        (as saved during or after training).
    """
    weights_path = self.logdir / name
    self.keras_model.load_weights(str(weights_path))
240
+
241
+
242
def _checkpoint_callbacks(self):
    """Return Keras ``ModelCheckpoint`` callbacks for training.

    Returns an empty list when ``self.basedir`` is None. Otherwise it adds one
    weights-only callback per configured checkpoint name:

    * ``config.train_checkpoint``: ``save_best_only=True``;
    * ``config.train_checkpoint_epoch``: ``save_best_only=False``.
    """
    if self.basedir is None:
        return []

    from ..utils.tf import keras_import
    ModelCheckpoint = keras_import('callbacks', 'ModelCheckpoint')
    # Keras 3: the ModelCheckpoint constructor rejects paths without this
    # suffix, so it is appended for construction and stripped afterwards
    # (the patched model.save_weights can write the legacy format).
    suffix = ".weights.h5" if IS_KERAS_3_PLUS else ""

    callbacks = []
    checkpoint_specs = (
        (self.config.train_checkpoint, True),
        (self.config.train_checkpoint_epoch, False),
    )
    for filename, best_only in checkpoint_specs:
        if filename is None:
            continue
        checkpoint = ModelCheckpoint(str(self.logdir / filename) + suffix, save_best_only=best_only, save_weights_only=True)
        if IS_KERAS_3_PLUS:
            checkpoint.filepath = checkpoint.filepath[:-len(suffix)]
        callbacks.append(checkpoint)
    return callbacks
258
+
259
+
260
def _training_finished(self):
    """Finalise the model folder after training.

    Does nothing when ``self.basedir`` is None. Otherwise, depending on which
    checkpoint names are set in ``self.config``:

    * ``train_checkpoint_last``: save the current weights under that name;
    * ``train_checkpoint``: reload weights, preferring files whose name
      contains that string;
    * ``train_checkpoint_epoch``: delete that per-epoch file if it exists.
    """
    if self.basedir is None:
        return
    if self.config.train_checkpoint_last is not None:
        self.keras_model.save_weights(str(self.logdir / self.config.train_checkpoint_last))
    if self.config.train_checkpoint is not None:
        print()
        self._find_and_load_weights(self.config.train_checkpoint)
    if self.config.train_checkpoint_epoch is not None:
        # Temporary per-epoch weights; a missing file is deliberately ignored.
        epoch_weights = self.logdir / self.config.train_checkpoint_epoch
        try:
            epoch_weights.unlink()
        except FileNotFoundError:
            pass
273
+
274
+
275
@suppress_without_basedir(warn=True)
def export_TF(self, fname=None):
    """Export the model; unsupported here and always raises ``NotImplementedError``.

    Subclasses may override this. NOTE(review): when ``self.basedir`` is
    None, ``suppress_without_basedir`` presumably intercepts the call
    (not visible here); confirm.
    """
    raise NotImplementedError()
278
+
279
+
280
def _make_permute_axes(self, img_axes_in, net_axes_in, net_axes_out=None, img_axes_out=None):
    """Build a function that moves data between image and network axis layouts.

    Data flow: ``img_axes_in -> net_axes_in --NN--> net_axes_out -> img_axes_out``.
    ``net_axes_out`` defaults to ``net_axes_in`` and ``img_axes_out`` to
    ``img_axes_in``. Both network layouts must contain a channel axis ``'C'``.
    If the input image has one, the output image must have one too.

    The returned ``_permute_axes(data, undo=False)`` maps ``None`` to
    ``None``. With ``undo=False`` it maps input image axes to network input
    axes. With ``undo=True`` it maps network output axes to output image axes.
    If the input image had no channel axis, a singleton output channel axis
    is dropped.
    """
    net_axes_out = net_axes_in if net_axes_out is None else net_axes_out
    img_axes_out = img_axes_in if img_axes_out is None else img_axes_out
    assert 'C' in net_axes_in and 'C' in net_axes_out
    assert 'C' not in img_axes_in or 'C' in img_axes_out

    def _permute_axes(data,undo=False):
        if data is None:
            return None
        if not undo:
            return move_image_axes(data, img_axes_in, net_axes_in, True)
        if 'C' in img_axes_in:
            return move_image_axes(data, net_axes_out, img_axes_out, True)
        # Input was single-channel without a 'C' axis: temporarily append
        # one, then remove it again if the output is single-channel too.
        restored = move_image_axes(data, net_axes_out, img_axes_out+'C', True)
        return restored[...,0] if restored.shape[-1] == 1 else restored

    return _permute_axes
305
+
306
+
307
def _check_normalizer_resizer(self, normalizer, resizer):
    """Fill in default normalizer/resizer objects and validate their types.

    Parameters
    ----------
    normalizer : Normalizer or None
        ``None`` is replaced by ``NoNormalizer()``.
    resizer : Resizer or None
        ``None`` is replaced by ``NoResizer()``.

    Returns
    -------
    tuple
        ``(normalizer, resizer)`` with defaults applied.

    Raises
    ------
    ValueError
        If ``resizer`` is not a ``Resizer`` instance or ``normalizer`` is not
        a ``Normalizer`` instance (checked in that order).
    """
    if normalizer is None:
        normalizer = NoNormalizer()
    if resizer is None:
        resizer = NoResizer()
    # Previously these raised a bare ValueError() with an empty message;
    # name the offending type so misuse is diagnosable.
    if not isinstance(resizer, Resizer):
        _raise(ValueError("Invalid resizer of type '%s', was expecting an instance of 'Resizer'." % type(resizer).__name__))
    if not isinstance(normalizer, Normalizer):
        _raise(ValueError("Invalid normalizer of type '%s', was expecting an instance of 'Normalizer'." % type(normalizer).__name__))
    if normalizer.do_after and self.config.n_channel_in != self.config.n_channel_out:
        # Only warns; the skipping described in the message is not done here
        # (presumably by the prediction code that uses normalizer.do_after; verify).
        warnings.warn('skipping normalization step after prediction because ' +
                      'number of input and output channels differ.')

    return normalizer, resizer
320
+
321
+
322
@property
def _axes_out(self):
    """Axes of the network output; by default the same as ``config.axes``."""
    config = self.config
    return config.axes
325
+
326
+
327
@abstractproperty
def _config_class(self):
    """Configuration class for this model; concrete subclasses must provide it.

    ``__init__`` uses it for ``isinstance`` validation of the given config,
    and ``_set_logdir`` uses it to rebuild a config loaded from disk.
    """