senoquant-1.0.0b1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- senoquant/__init__.py +6 -0
- senoquant/_reader.py +7 -0
- senoquant/_widget.py +33 -0
- senoquant/napari.yaml +83 -0
- senoquant/reader/__init__.py +5 -0
- senoquant/reader/core.py +369 -0
- senoquant/tabs/__init__.py +15 -0
- senoquant/tabs/batch/__init__.py +10 -0
- senoquant/tabs/batch/backend.py +641 -0
- senoquant/tabs/batch/config.py +270 -0
- senoquant/tabs/batch/frontend.py +1283 -0
- senoquant/tabs/batch/io.py +326 -0
- senoquant/tabs/batch/layers.py +86 -0
- senoquant/tabs/quantification/__init__.py +1 -0
- senoquant/tabs/quantification/backend.py +228 -0
- senoquant/tabs/quantification/features/__init__.py +80 -0
- senoquant/tabs/quantification/features/base.py +142 -0
- senoquant/tabs/quantification/features/marker/__init__.py +5 -0
- senoquant/tabs/quantification/features/marker/config.py +69 -0
- senoquant/tabs/quantification/features/marker/dialog.py +437 -0
- senoquant/tabs/quantification/features/marker/export.py +879 -0
- senoquant/tabs/quantification/features/marker/feature.py +119 -0
- senoquant/tabs/quantification/features/marker/morphology.py +285 -0
- senoquant/tabs/quantification/features/marker/rows.py +654 -0
- senoquant/tabs/quantification/features/marker/thresholding.py +46 -0
- senoquant/tabs/quantification/features/roi.py +346 -0
- senoquant/tabs/quantification/features/spots/__init__.py +5 -0
- senoquant/tabs/quantification/features/spots/config.py +62 -0
- senoquant/tabs/quantification/features/spots/dialog.py +477 -0
- senoquant/tabs/quantification/features/spots/export.py +1292 -0
- senoquant/tabs/quantification/features/spots/feature.py +112 -0
- senoquant/tabs/quantification/features/spots/morphology.py +279 -0
- senoquant/tabs/quantification/features/spots/rows.py +241 -0
- senoquant/tabs/quantification/frontend.py +815 -0
- senoquant/tabs/segmentation/__init__.py +1 -0
- senoquant/tabs/segmentation/backend.py +131 -0
- senoquant/tabs/segmentation/frontend.py +1009 -0
- senoquant/tabs/segmentation/models/__init__.py +5 -0
- senoquant/tabs/segmentation/models/base.py +146 -0
- senoquant/tabs/segmentation/models/cpsam/details.json +65 -0
- senoquant/tabs/segmentation/models/cpsam/model.py +150 -0
- senoquant/tabs/segmentation/models/default_2d/details.json +69 -0
- senoquant/tabs/segmentation/models/default_2d/model.py +664 -0
- senoquant/tabs/segmentation/models/default_3d/details.json +69 -0
- senoquant/tabs/segmentation/models/default_3d/model.py +682 -0
- senoquant/tabs/segmentation/models/hf.py +71 -0
- senoquant/tabs/segmentation/models/nuclear_dilation/__init__.py +1 -0
- senoquant/tabs/segmentation/models/nuclear_dilation/details.json +26 -0
- senoquant/tabs/segmentation/models/nuclear_dilation/model.py +96 -0
- senoquant/tabs/segmentation/models/perinuclear_rings/__init__.py +1 -0
- senoquant/tabs/segmentation/models/perinuclear_rings/details.json +34 -0
- senoquant/tabs/segmentation/models/perinuclear_rings/model.py +132 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/__init__.py +2 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/__init__.py +3 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/__init__.py +6 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/generate.py +470 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/prepare.py +273 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/rawdata.py +112 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/transform.py +384 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/__init__.py +0 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/blocks.py +184 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/losses.py +79 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/nets.py +165 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/predict.py +467 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/probability.py +67 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/train.py +148 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/io/__init__.py +163 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/__init__.py +52 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/base_model.py +329 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_isotropic.py +160 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_projection.py +178 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_standard.py +446 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_upsampling.py +54 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/config.py +254 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/pretrained.py +119 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/scripts/__init__.py +0 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/scripts/care_predict.py +180 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/__init__.py +5 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/plot_utils.py +159 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/six.py +18 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/tf.py +644 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/utils.py +272 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/version.py +1 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/docs/source/conf.py +368 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/setup.py +68 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_datagen.py +169 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_models.py +462 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_utils.py +166 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tools/create_zip_contents.py +34 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/__init__.py +30 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/big.py +624 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/bioimageio_utils.py +494 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/data/__init__.py +39 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/__init__.py +10 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/geom2d.py +215 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/geom3d.py +349 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/matching.py +483 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/__init__.py +28 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/base.py +1217 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/model2d.py +594 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/model3d.py +696 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/nms.py +384 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/__init__.py +2 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/plot.py +74 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/render.py +298 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/rays3d.py +373 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/sample_patches.py +65 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/__init__.py +0 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/predict2d.py +90 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/predict3d.py +93 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/utils.py +408 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/version.py +1 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/__init__.py +45 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/__init__.py +17 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/cli.py +55 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/core.py +285 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/__init__.py +15 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/cli.py +36 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/divisibility.py +193 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/probe.py +100 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/receptive_field.py +182 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/rf_cli.py +48 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/valid_sizes.py +278 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/post/__init__.py +8 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/post/core.py +157 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/pre/__init__.py +17 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/pre/core.py +226 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/predict/__init__.py +5 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/predict/core.py +401 -0
- senoquant/tabs/settings/__init__.py +1 -0
- senoquant/tabs/settings/backend.py +29 -0
- senoquant/tabs/settings/frontend.py +19 -0
- senoquant/tabs/spots/__init__.py +1 -0
- senoquant/tabs/spots/backend.py +139 -0
- senoquant/tabs/spots/frontend.py +800 -0
- senoquant/tabs/spots/models/__init__.py +5 -0
- senoquant/tabs/spots/models/base.py +94 -0
- senoquant/tabs/spots/models/rmp/details.json +61 -0
- senoquant/tabs/spots/models/rmp/model.py +499 -0
- senoquant/tabs/spots/models/udwt/details.json +103 -0
- senoquant/tabs/spots/models/udwt/model.py +482 -0
- senoquant/utils.py +25 -0
- senoquant-1.0.0b1.dist-info/METADATA +193 -0
- senoquant-1.0.0b1.dist-info/RECORD +148 -0
- senoquant-1.0.0b1.dist-info/WHEEL +5 -0
- senoquant-1.0.0b1.dist-info/entry_points.txt +2 -0
- senoquant-1.0.0b1.dist-info/licenses/LICENSE +28 -0
- senoquant-1.0.0b1.dist-info/top_level.txt +1 -0
senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/train.py (the +148-line entry above)
@@ -0,0 +1,148 @@
+from __future__ import print_function, unicode_literals, absolute_import, division
+from six.moves import range, zip, map, reduce, filter
+
+from ..utils import _raise, move_channel_for_backend, axes_dict, axes_check_and_normalize, backend_channels_last
+from ..internals.losses import loss_laplace, loss_mse, loss_mae, loss_thresh_weighted_decay
+
+import numpy as np
+
+from ..utils.tf import keras_import, BACKEND as K
+Callback, TerminateOnNaN = keras_import('callbacks', 'Callback', 'TerminateOnNaN')
+Sequence = keras_import('utils', 'Sequence')
+Optimizer = keras_import('optimizers', 'Optimizer')
+
+
+class ParameterDecayCallback(Callback):
+    """ TODO """
+    def __init__(self, parameter, decay, name=None, verbose=0):
+        self.parameter = parameter
+        self.decay = decay
+        self.name = name
+        self.verbose = verbose
+
+    def on_epoch_end(self, epoch, logs=None):
+        old_val = K.get_value(self.parameter)
+        if self.name:
+            logs = logs or {}
+            logs[self.name] = old_val
+        new_val = old_val * (1. / (1. + self.decay * (epoch + 1)))
+        K.set_value(self.parameter, new_val)
+        if self.verbose:
+            print("\n[ParameterDecayCallback] new %s: %s\n" % (self.name if self.name else 'parameter', new_val))
+
+
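For orientation: the callback above multiplies a Keras backend variable by 1/(1 + decay*(epoch+1)) at the end of every epoch. A minimal sketch of the resulting schedule (values rounded):

    alpha = K.variable(1.0)
    cb = ParameterDecayCallback(alpha, decay=0.06, name='alpha')
    # epoch 0: 1.0   * 1/(1 + 0.06*1) -> ~0.943
    # epoch 1: 0.943 * 1/(1 + 0.06*2) -> ~0.842
    # epoch 2: 0.842 * 1/(1 + 0.06*3) -> ~0.714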
+def prepare_model(model, optimizer, loss, metrics=('mse','mae'),
+                  loss_bg_thresh=0, loss_bg_decay=0.06, Y=None):
+    """ TODO """
+
+    isinstance(optimizer,Optimizer) or _raise(ValueError())
+
+
+    loss_standard = eval('loss_%s()'%loss)
+    _metrics = [eval('loss_%s()'%m) for m in metrics]
+    callbacks = [TerminateOnNaN()]
+
+    # checks
+    assert 0 <= loss_bg_thresh <= 1
+    assert loss_bg_thresh == 0 or Y is not None
+    if loss == 'laplace':
+        assert K.image_data_format() == "channels_last", "TODO"
+        assert list(model.output.shape)[-1] >= 2 and list(model.output.shape)[-1] % 2 == 0
+
+    # loss
+    if loss_bg_thresh == 0:
+        _loss = loss_standard
+    else:
+        freq = np.mean(Y > loss_bg_thresh)
+        # print("class frequency:", freq)
+        alpha = K.variable(1.0)
+        loss_per_pixel = eval('loss_{loss}(mean=False)'.format(loss=loss))
+        _loss = loss_thresh_weighted_decay(loss_per_pixel, loss_bg_thresh,
+                                           0.5 / (0.1 + (1 - freq)),
+                                           0.5 / (0.1 + freq),
+                                           alpha)
+        callbacks.append(ParameterDecayCallback(alpha, loss_bg_decay, name='alpha'))
+    if not loss in metrics:
+        _metrics.append(loss_standard)
+
+
+    # compile model
+    model.compile(optimizer=optimizer, loss=_loss, metrics=_metrics)
+
+    return callbacks
+
+
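A quick usage sketch for prepare_model, assuming `model` is an already-built Keras model (the optimizer must be a Keras Optimizer instance, per the isinstance check above):

    Adam = keras_import('optimizers', 'Adam')
    callbacks = prepare_model(model, Adam(), loss='mae')  # compiles the model in place
    # the returned callbacks (TerminateOnNaN, plus the alpha decay when
    # loss_bg_thresh > 0) are meant to be passed on to keras_model.fit(...)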
+class RollingSequence(Sequence):
+    """Helper class for creating batches for rolling sequence.
+
+    Create batches of size `batch_size` that contain indices in `range(data_size)`.
+    To that end, the data indices are repeated (rolling), either in ascending order or
+    shuffled if `shuffle=True`. If taking batches sequentially, all data indices will
+    appear equally often. All calls to `batch(i)` will return the same batch for same i.
+    Parameter `length` will only determine the result of `len`, it has no effect otherwise.
+    Note that batch_size is allowed to be larger than data_size.
+    """
+
+    def __init__(self, data_size, batch_size, length=None, shuffle=True, rng=None, keras_kwargs=None):
+        super(RollingSequence, self).__init__(**({} if keras_kwargs is None else keras_kwargs))
+        # print(f"### __init__", flush=True)
+        if rng is None: rng = np.random
+        self.data_size = int(data_size)
+        self.batch_size = int(batch_size)
+        self.length = 2**63-1 if length is None else int(length) # 2**63-1 is max possible value
+        self.shuffle = bool(shuffle)
+        self.index_gen = rng.permutation if self.shuffle else np.arange
+        self.index_map = {}
+
+    def __len__(self):
+        # print(f"### __len__ = {self.length}", flush=True)
+        return self.length
+
+    def _index(self, loop):
+        if loop in self.index_map:
+            return self.index_map[loop]
+        else:
+            return self.index_map.setdefault(loop, self.index_gen(self.data_size))
+
+    def on_epoch_end(self):
+        # print(f"### on_epoch_end", flush=True)
+        pass
+
+    def __iter__(self):
+        # print(f"### __iter__", flush=True)
+        for i in range(len(self)):
+            yield self[i]
+
+    def batch(self, i):
+        pos = i * self.batch_size
+        loop = pos // self.data_size
+        pos_loop = pos % self.data_size
+        sl = slice(pos_loop, pos_loop + self.batch_size)
+        index = self._index(loop)
+        _loop = loop
+        while sl.stop > len(index):
+            _loop += 1
+            index = np.concatenate((index, self._index(_loop)))
+        # print(f"### - batch({i:02}) -> {tuple(index[sl])}", flush=True)
+        return index[sl]
+
+    def __getitem__(self, i):
+        return self.batch(i)
+
+
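A worked example of the rolling index arithmetic in `batch(i)` (deterministic with shuffle=False):

    seq = RollingSequence(data_size=5, batch_size=2, shuffle=False)
    seq.batch(0)  # array([0, 1])
    seq.batch(1)  # array([2, 3])
    seq.batch(2)  # array([4, 0]): positions 4..6 roll over into the next pass
    seq.batch(2)  # identical result: per-pass index orders are cached in index_map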
+class DataWrapper(RollingSequence):
+
+    def __init__(self, X, Y, batch_size, length, augmenter=None, keras_kwargs=None):
+        super(DataWrapper, self).__init__(data_size=len(X), batch_size=batch_size, length=length, shuffle=True, keras_kwargs=keras_kwargs)
+        len(X) == len(Y) or _raise(ValueError("X and Y must have same length"))
+        self.X, self.Y = X, Y
+        self.augmenter = augmenter
+
+    def __getitem__(self, i):
+        idx = self.batch(i)
+        X, Y = self.X[idx], self.Y[idx]
+        if self.augmenter is not None:
+            X,Y = tuple(zip(*tuple(self.augmenter(x,y) for x,y in zip(X,Y))))
+        X,Y = np.stack(X), np.stack(Y)
+
+        return X,Y
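DataWrapper pairs the rolling indices with actual arrays; a minimal sketch (X, Y are hypothetical arrays of matching length):

    X = np.zeros((100, 64, 64, 1)); Y = np.zeros((100, 64, 64, 1))
    flip = lambda x, y: (x[::-1], y[::-1])          # toy per-sample augmenter
    data = DataWrapper(X, Y, batch_size=16, length=400, augmenter=flip)
    xb, yb = data[0]                                # one shuffled, augmented batch of 16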
senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/io/__init__.py (the +163-line entry above)
@@ -0,0 +1,163 @@
+# -*- coding: utf-8 -*-
+from __future__ import print_function, unicode_literals, absolute_import, division
+from six.moves import range, zip, map, reduce, filter
+from six import string_types
+
+import numpy as np
+try:
+    from tifffile import imwrite as imsave
+except ImportError:
+    from tifffile import imsave
+import warnings
+
+from ..utils import _raise, axes_check_and_normalize, axes_dict, move_image_axes, move_channel_for_backend, backend_channels_last
+from ..utils.six import Path
+
+
+
+def save_tiff_imagej_compatible(file, img, axes, **imsave_kwargs):
+    """Save image in ImageJ-compatible TIFF format.
+
+    Parameters
+    ----------
+    file : str
+        File name
+    img : numpy.ndarray
+        Image
+    axes: str
+        Axes of ``img``
+    imsave_kwargs : dict, optional
+        Keyword arguments for :func:`tifffile.imsave`
+
+    """
+    axes = axes_check_and_normalize(axes,img.ndim,disallowed='S')
+
+    # convert to imagej-compatible data type
+    t = img.dtype
+    if 'float' in t.name: t_new = np.float32
+    elif 'uint' in t.name: t_new = np.uint16 if t.itemsize >= 2 else np.uint8
+    elif 'int' in t.name: t_new = np.int16
+    else: t_new = t
+    img = img.astype(t_new, copy=False)
+    if t != t_new:
+        warnings.warn("Converting data type from '%s' to ImageJ-compatible '%s'." % (t, np.dtype(t_new)))
+
+    # move axes to correct positions for imagej
+    img = move_image_axes(img, axes, 'TZCYX', True)
+
+    imsave_kwargs['imagej'] = True
+    imsave(file, img, **imsave_kwargs)
+
+
+
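Usage sketch: a float64 stack is converted to float32 (with a warning) and the axes are rearranged into ImageJ's TZCYX order before writing:

    img = np.random.rand(32, 128, 128)                        # float64, axes ZYX
    save_tiff_imagej_compatible('stack.tif', img, axes='ZYX')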
+def load_training_data(file, validation_split=0, axes=None, n_images=None, verbose=False):
+    """Load training data from file in ``.npz`` format.
+
+    The data file is expected to have the keys:
+
+    - ``X`` : Array of training input images.
+    - ``Y`` : Array of corresponding target images.
+    - ``axes`` : Axes of the training images.
+
+
+    Parameters
+    ----------
+    file : str
+        File name
+    validation_split : float
+        Fraction of images to use as validation set during training.
+    axes: str, optional
+        Must be provided in case the loaded data does not contain ``axes`` information.
+    n_images : int, optional
+        Can be used to limit the number of images loaded from data.
+    verbose : bool, optional
+        Can be used to display information about the loaded images.
+
+    Returns
+    -------
+    tuple( tuple(:class:`numpy.ndarray`, :class:`numpy.ndarray`), tuple(:class:`numpy.ndarray`, :class:`numpy.ndarray`), str )
+        Returns two tuples (`X_train`, `Y_train`), (`X_val`, `Y_val`) of training and validation sets
+        and the axes of the input images.
+        The tuple of validation data will be ``None`` if ``validation_split = 0``.
+
+    """
+
+    f = np.load(file)
+    X, Y = f['X'], f['Y']
+    if axes is None:
+        axes = f['axes']
+    axes = axes_check_and_normalize(axes)
+
+    # assert X.shape == Y.shape
+    assert X.ndim == Y.ndim
+    assert len(axes) == X.ndim
+    assert 'C' in axes
+    if n_images is None:
+        n_images = X.shape[0]
+    assert X.shape[0] == Y.shape[0]
+    assert 0 < n_images <= X.shape[0]
+    assert 0 <= validation_split < 1
+
+    X, Y = X[:n_images], Y[:n_images]
+    channel = axes_dict(axes)['C']
+
+    if validation_split > 0:
+        n_val = int(round(n_images * validation_split))
+        n_train = n_images - n_val
+        assert 0 < n_val and 0 < n_train
+        X_t, Y_t = X[-n_val:], Y[-n_val:]
+        X, Y = X[:n_train], Y[:n_train]
+        assert X.shape[0] == n_train and X_t.shape[0] == n_val
+        X_t = move_channel_for_backend(X_t,channel=channel)
+        Y_t = move_channel_for_backend(Y_t,channel=channel)
+
+    X = move_channel_for_backend(X,channel=channel)
+    Y = move_channel_for_backend(Y,channel=channel)
+
+    axes = axes.replace('C','') # remove channel
+    if backend_channels_last():
+        axes = axes+'C'
+    else:
+        axes = axes[:1]+'C'+axes[1:]
+
+    data_val = (X_t,Y_t) if validation_split > 0 else None
+
+    if verbose:
+        ax = axes_dict(axes)
+        n_train, n_val = len(X), len(X_t) if validation_split>0 else 0
+        image_size = tuple( X.shape[ax[a]] for a in axes if a in 'TZYX' )
+        n_dim = len(image_size)
+        n_channel_in, n_channel_out = X.shape[ax['C']], Y.shape[ax['C']]
+
+        print('number of training images:\t', n_train)
+        print('number of validation images:\t', n_val)
+        print('image size (%dD):\t\t'%n_dim, image_size)
+        print('axes:\t\t\t\t', axes)
+        print('channels in / out:\t\t', n_channel_in, '/', n_channel_out)
+
+    return (X,Y), data_val, axes
+
+
+
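Usage sketch (file name hypothetical); with a non-zero split the last 10% of the images become the validation set:

    (X, Y), (X_val, Y_val), axes = load_training_data('data.npz', validation_split=0.1, verbose=True)
    # with validation_split=0 the middle return value is None instead of a tuple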
+def save_training_data(file, X, Y, axes):
+    """Save training data in ``.npz`` format.
+
+    Parameters
+    ----------
+    file : str
+        File name
+    X : :class:`numpy.ndarray`
+        Array of patches extracted from source images.
+    Y : :class:`numpy.ndarray`
+        Array of corresponding target patches.
+    axes : str
+        Axes of the extracted patches.
+
+    """
+    isinstance(file,(Path,string_types)) or _raise(ValueError())
+    file = Path(file).with_suffix('.npz')
+    file.parent.mkdir(parents=True,exist_ok=True)
+
+    axes = axes_check_and_normalize(axes)
+    len(axes) == X.ndim or _raise(ValueError())
+    np.savez(str(file), X=X, Y=Y, axes=axes)
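Round-trip sketch for the two functions above ('S' denotes the sample axis; the '.npz' suffix is added automatically):

    save_training_data('mydata', X, Y, axes='SYXC')   # writes mydata.npz with keys X, Y, axes
    (X2, Y2), _, axes2 = load_training_data('mydata.npz')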
senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/__init__.py (the +52-line entry above)
@@ -0,0 +1,52 @@
+from __future__ import absolute_import, print_function
+
+# checks
+try:
+    import tensorflow
+    del tensorflow
+except ModuleNotFoundError as e:
+    from six import raise_from
+    raise_from(RuntimeError('Please install TensorFlow: https://www.tensorflow.org/install/'), e)
+
+
+import tensorflow, sys
+from ..utils.tf import keras_import, IS_TF_1, BACKEND as K
+
+if IS_TF_1:
+    try:
+        import keras
+        del keras
+    except ModuleNotFoundError as e:
+        if e.name in {'theano','cntk'}:
+            from six import raise_from
+            raise_from(RuntimeError(
+                "Keras is configured to use the '%s' backend, which is not installed. "
+                "Please change it to use 'tensorflow' instead: "
+                "https://keras.io/getting-started/faq/#where-is-the-keras-configuration-file-stored" % e.name
+            ), e)
+        else:
+            raise e
+
+if K.backend() != 'tensorflow':
+    raise NotImplementedError(
+        "Keras is configured to use the '%s' backend, which is currently not supported. "
+        "Please configure Keras to use 'tensorflow' instead." % K.backend()
+    )
+
+if K.image_data_format() != 'channels_last':
+    raise NotImplementedError(
+        "Keras is configured to use the '%s' image data format, which is currently not supported. "
+        "Please change it to use 'channels_last' instead: "
+        "https://keras.io/getting-started/faq/#where-is-the-keras-configuration-file-stored" % K.image_data_format()
+    )
+del tensorflow, sys, keras_import, IS_TF_1, K
+
+
+# imports
+from .config import BaseConfig, Config
+from .base_model import BaseModel
+from .care_standard import CARE
+from .care_upsampling import UpsamplingCARE
+from .care_isotropic import IsotropicCARE
+from .care_projection import ProjectionConfig, ProjectionCARE
+from .pretrained import register_model, register_aliases, clear_models_and_aliases
senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/base_model.py (the +329-line entry above)
@@ -0,0 +1,329 @@
+# -*- coding: utf-8 -*-
+from __future__ import print_function, unicode_literals, absolute_import, division
+
+import datetime
+import warnings
+import sys
+
+import numpy as np
+from six import string_types, PY2
+from functools import wraps
+
+from .config import BaseConfig
+from ..utils import _raise, load_json, save_json, axes_check_and_normalize, axes_dict, move_image_axes
+from ..utils.six import Path, FileNotFoundError
+from ..utils.tf import keras_import, IS_KERAS_3_PLUS
+from ..data import Normalizer, NoNormalizer
+from ..data import Resizer, NoResizer
+from .pretrained import get_model_details, get_model_instance, get_registered_models
+
+from six import add_metaclass
+from abc import ABCMeta, abstractmethod, abstractproperty
+
+
+
+def suppress_without_basedir(warn):
+    def _suppress_without_basedir(f):
+        @wraps(f)
+        def wrapper(*args, **kwargs):
+            self = args[0]
+            if self.basedir is None:
+                warn is False or warnings.warn("Suppressing call of '%s' (due to basedir=None)." % f.__name__)
+            else:
+                return f(*args, **kwargs)
+        return wrapper
+    return _suppress_without_basedir
+
+
+
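The decorator turns a method into a no-op when the object was created with basedir=None; a minimal sketch:

    class Demo(object):
        def __init__(self, basedir):
            self.basedir = basedir
        @suppress_without_basedir(warn=True)
        def save(self):
            print('saving to', self.basedir)

    Demo(None).save()      # suppressed: emits a warning, returns None
    Demo('models').save()  # runs normally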
+if IS_KERAS_3_PLUS:
+    import h5py
+    from packaging.version import Version
+    from keras.src.legacy.saving import legacy_h5_format
+    from keras.src import backend, __version__ as keras_version
+
+    if Version(keras_version) >= Version("3.3.0"):
+        save_weights_to_hdf5_group = legacy_h5_format.save_weights_to_hdf5_group
+    else:
+        def save_weights_to_hdf5_group(f, model):
+            legacy_h5_format.save_attributes_to_hdf5_group(f, "layer_names", [layer.name.encode("utf8") for layer in model.layers])
+            f.attrs["backend"] = backend.backend().encode("utf8")
+            f.attrs["keras_version"] = str(keras_version).encode("utf8")
+            for layer in sorted(model.layers, key=lambda x: x.name):
+                g = f.create_group(layer.name)
+                weights = legacy_h5_format._legacy_weights(layer)
+                save_subset_weights_to_hdf5_group(g, weights)
+            g = f.create_group("top_level_model_weights")
+            weights = [v for v in model._trainable_variables + model._non_trainable_variables if v in model.weights]
+            save_subset_weights_to_hdf5_group(g, weights)
+
+        def save_subset_weights_to_hdf5_group(f, weights):
+            # FIX: use w.path instead of w.name to avoid name collisions (for "functional" layers)
+            # -> has been fixed since keras 3.3.0: https://github.com/keras-team/keras/blob/v3.3.0/keras/src/legacy/saving/legacy_h5_format.py#L234
+            weight_names = [w.path.encode("utf8") for w in weights]
+            weight_values = [backend.convert_to_numpy(w) for w in weights]
+            legacy_h5_format.save_attributes_to_hdf5_group(f, "weight_names", weight_names)
+            for name, val in zip(weight_names, weight_values):
+                param_dset = f.create_dataset(name, val.shape, dtype=val.dtype)
+                param_dset[() if not val.shape else slice(None)] = val
+
+    def _keras3_monkey_patch_legacy_weights(model):
+        ref_save_weights = model.save_weights
+
+        def save_weights(self, filepath, overwrite=True):
+            p = Path(filepath)
+            if not overwrite and p.exists():
+                raise FileExistsError(f"Weights file already exists: {str(p.resolve())}")
+            if p.name.endswith(".weights.h5"):
+                warnings.warn("Detected filename suffix '.weights.h5', thus saving in newer Keras 3.x file format (cannot be loaded with Keras 2.x)")
+            if not p.name.endswith(".weights.h5"):
+                with h5py.File(str(p), "w") as f:
+                    save_weights_to_hdf5_group(f, self)
+            else:
+                return ref_save_weights(filepath, overwrite=overwrite)
+
+        model.save_weights = save_weights.__get__(model)
+
+
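Effect of the monkey patch once applied to a model (Keras 3 only): the file suffix selects the weight format:

    model.save_weights('weights_best.h5')   # legacy Keras 2-style HDF5, via save_weights_to_hdf5_group
    model.save_weights('best.weights.h5')   # native Keras 3 format (warns it cannot be read by Keras 2)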
+@add_metaclass(ABCMeta)
+class BaseModel(object):
+    """Base model.
+
+    Subclasses must implement :func:`_build` and :func:`_config_class`.
+
+    Parameters
+    ----------
+    config : Subclass of :class:`csbdeep.models.BaseConfig` or None
+        Valid configuration of a model (see :func:`BaseConfig.is_valid`).
+        Will be saved to disk as JSON (``config.json``).
+        If set to ``None``, will be loaded from disk (must exist).
+    name : str or None
+        Model name. Uses a timestamp if set to ``None`` (default).
+    basedir : str
+        Directory that contains (or will contain) a folder with the given model name.
+        Use ``None`` to disable saving (or loading) any data to (or from) disk (regardless of other parameters).
+
+    Raises
+    ------
+    FileNotFoundError
+        If ``config=None`` and config cannot be loaded from disk.
+    ValueError
+        Illegal arguments, including invalid configuration.
+
+    Attributes
+    ----------
+    config : :class:`csbdeep.models.BaseConfig`
+        Configuration of the model, as provided during instantiation.
+    keras_model : `Keras model <https://keras.io/getting-started/functional-api-guide/>`_
+        Keras neural network model.
+    name : str
+        Model name.
+    logdir : :class:`pathlib.Path`
+        Path to model folder (which stores configuration, weights, etc.)
+    """
+
+    @classmethod
+    def from_pretrained(cls, name_or_alias=None):
+        try:
+            get_model_details(cls, name_or_alias, verbose=True)
+            return get_model_instance(cls, name_or_alias)
+        except ValueError as e:
+            if name_or_alias is not None:
+                print("Could not find model with name or alias '%s'" % (name_or_alias), file=sys.stderr)
+                sys.stderr.flush()
+            get_registered_models(cls, verbose=True)
+
+
+    def __init__(self, config, name=None, basedir='.'):
+        """See class docstring."""
+
+        config is None or isinstance(config,self._config_class) or _raise (
+            ValueError("Invalid configuration of type '%s', was expecting type '%s'." % (type(config).__name__, self._config_class.__name__))
+        )
+        if config is not None and not config.is_valid():
+            invalid_attr = config.is_valid(True)[1]
+            raise ValueError('Invalid configuration attributes: ' + ', '.join(invalid_attr))
+        (not (config is None and basedir is None)) or _raise(ValueError("No config provided and cannot be loaded from disk since basedir=None."))
+
+        name is None or (isinstance(name,string_types) and len(name)>0) or _raise(ValueError("No valid name: '%s'" % str(name)))
+        basedir is None or isinstance(basedir,(string_types,Path)) or _raise(ValueError("No valid basedir: '%s'" % str(basedir)))
+        self.config = config
+        self.name = name if name is not None else datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S.%f")
+        self.basedir = Path(basedir) if basedir is not None else None
+        if config is not None:
+            # config was provided -> update before it is saved to disk
+            self._update_and_check_config()
+        self._set_logdir()
+        if config is None:
+            # config was loaded from disk -> update it after loading
+            self._update_and_check_config()
+        self._model_prepared = False
+        self.keras_model = self._build()
+        if IS_KERAS_3_PLUS and isinstance(self.keras_model, keras_import('models', 'Model')):
+            # monkey-patch keras model to save weights in legacy format if suffix is not '.weights.h5'
+            _keras3_monkey_patch_legacy_weights(self.keras_model)
+        if config is None:
+            self._find_and_load_weights()
+
+
+    def __repr__(self):
+        s = ("{self.__class__.__name__}({self.name}): {self.config.axes} → {self._axes_out}\n".format(self=self) +
+             "├─ Directory: {}\n".format(self.logdir.resolve() if self.basedir is not None else None) +
+             self._repr_extra() +
+             "└─ {self.config}".format(self=self))
+        return s.encode('utf-8') if PY2 else s
+
+
+    def _repr_extra(self):
+        return ""
+
+
+    def _update_and_check_config(self):
+        pass
+
+
+    @suppress_without_basedir(warn=False)
+    def _set_logdir(self):
+        self.logdir = self.basedir / self.name
+
+        config_file = self.logdir / 'config.json'
+        if self.config is None:
+            if config_file.exists():
+                config_dict = load_json(str(config_file))
+                config_dict = self._config_class.update_loaded_config(config_dict)
+                self.config = self._config_class(**config_dict)
+                if not self.config.is_valid():
+                    invalid_attr = self.config.is_valid(True)[1]
+                    raise ValueError('Invalid attributes in loaded config: ' + ', '.join(invalid_attr))
+            else:
+                raise FileNotFoundError("config file doesn't exist: %s" % str(config_file.resolve()))
+        else:
+            if self.logdir.exists():
+                warnings.warn('output path for model already exists, files may be overwritten: %s' % str(self.logdir.resolve()))
+            self.logdir.mkdir(parents=True, exist_ok=True)
+            save_json(vars(self.config), str(config_file))
+
+
+    @suppress_without_basedir(warn=False)
+    def _find_and_load_weights(self,prefer='best'):
+        from itertools import chain
+        # get all weight files and sort by modification time descending (newest first)
+        weights_ext = ('*.h5','*.hdf5')
+        weights_files = chain(*(self.logdir.glob(ext) for ext in weights_ext))
+        weights_files = reversed(sorted(weights_files, key=lambda f: f.stat().st_mtime))
+        weights_files = list(weights_files)
+        if len(weights_files) == 0:
+            warnings.warn("Couldn't find any network weights (%s) to load." % ', '.join(weights_ext))
+            return
+        weights_preferred = list(filter(lambda f: prefer in f.name, weights_files))
+        weights_chosen = weights_preferred[0] if len(weights_preferred)>0 else weights_files[0]
+        print("Loading network weights from '%s'." % weights_chosen.name)
+        self.load_weights(weights_chosen.name)
+
+
+    @abstractmethod
+    def _build(self):
+        """ Create and return a Keras model. """
+
+
+    @suppress_without_basedir(warn=True)
+    def load_weights(self, name='weights_best.h5'):
+        """Load neural network weights from model folder.
+
+        Parameters
+        ----------
+        name : str
+            Name of HDF5 weight file (as saved during or after training).
+        """
+        self.keras_model.load_weights(str(self.logdir/name))
+
+
+    def _checkpoint_callbacks(self):
+        callbacks = []
+        if self.basedir is not None:
+            from ..utils.tf import keras_import
+            ModelCheckpoint = keras_import('callbacks', 'ModelCheckpoint')
+            # keras 3: need to add suffix to filename because ModelCheckpoint constructor throws error if it's missing
+            suffix = ".weights.h5" if IS_KERAS_3_PLUS else ""
+            if self.config.train_checkpoint is not None:
+                callbacks.append(ModelCheckpoint(str(self.logdir / self.config.train_checkpoint) + suffix, save_best_only=True, save_weights_only=True))
+                # keras3: remove suffix because patched model.save_weights can save in legacy format
+                if IS_KERAS_3_PLUS: callbacks[-1].filepath = callbacks[-1].filepath[:-len(suffix)]
+            if self.config.train_checkpoint_epoch is not None:
+                callbacks.append(ModelCheckpoint(str(self.logdir / self.config.train_checkpoint_epoch) + suffix, save_best_only=False, save_weights_only=True))
+                # keras3: remove suffix because patched model.save_weights can save in legacy format
+                if IS_KERAS_3_PLUS: callbacks[-1].filepath = callbacks[-1].filepath[:-len(suffix)]
+        return callbacks
+
+
+    def _training_finished(self):
+        if self.basedir is not None:
+            if self.config.train_checkpoint_last is not None:
+                self.keras_model.save_weights(str(self.logdir / self.config.train_checkpoint_last))
+            if self.config.train_checkpoint is not None:
+                print()
+                self._find_and_load_weights(self.config.train_checkpoint)
+            if self.config.train_checkpoint_epoch is not None:
+                try:
+                    # remove temporary weights
+                    (self.logdir / self.config.train_checkpoint_epoch).unlink()
+                except FileNotFoundError:
+                    pass
+
+
+    @suppress_without_basedir(warn=True)
+    def export_TF(self, fname=None):
+        raise NotImplementedError()
+
+
+    def _make_permute_axes(self, img_axes_in, net_axes_in, net_axes_out=None, img_axes_out=None):
+        # img_axes_in -> net_axes_in ---NN--> net_axes_out -> img_axes_out
+        if net_axes_out is None:
+            net_axes_out = net_axes_in
+        if img_axes_out is None:
+            img_axes_out = img_axes_in
+        assert 'C' in net_axes_in and 'C' in net_axes_out
+        assert not 'C' in img_axes_in or 'C' in img_axes_out
+
+        def _permute_axes(data,undo=False):
+            if data is None:
+                return None
+            if undo:
+                if 'C' in img_axes_in:
+                    return move_image_axes(data, net_axes_out, img_axes_out, True)
+                else:
+                    # input is single-channel and has no channel axis
+                    data = move_image_axes(data, net_axes_out, img_axes_out+'C', True)
+                    if data.shape[-1] == 1:
+                        # output is single-channel -> remove channel axis
+                        data = data[...,0]
+                    return data
+            else:
+                return move_image_axes(data, img_axes_in, net_axes_in, True)
+        return _permute_axes
+
+
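Illustration of the returned closure (calling the private helper directly, with hypothetical shapes): a single-channel 'YX' image gains a channel axis on the way into the network and loses it again on the way out:

    permute = model._make_permute_axes('YX', 'YXC')
    x = permute(img)                # (512, 512)    -> (512, 512, 1)
    y = permute(pred, undo=True)    # (512, 512, 1) -> (512, 512), singleton channel removed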
+    def _check_normalizer_resizer(self, normalizer, resizer):
+        if normalizer is None:
+            normalizer = NoNormalizer()
+        if resizer is None:
+            resizer = NoResizer()
+        isinstance(resizer,Resizer) or _raise(ValueError())
+        isinstance(normalizer,Normalizer) or _raise(ValueError())
+        if normalizer.do_after:
+            if self.config.n_channel_in != self.config.n_channel_out:
+                warnings.warn('skipping normalization step after prediction because ' +
+                              'number of input and output channels differ.')
+
+        return normalizer, resizer
+
+
+    @property
+    def _axes_out(self):
+        return self.config.axes
+
+
+    @abstractproperty
+    def _config_class(self):
+        """ Class of config to be used for this model. """
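A minimal subclass sketch (hypothetical; BaseModel requires `_build` and `_config_class`):

    class MyModel(BaseModel):
        @property
        def _config_class(self):
            return BaseConfig
        def _build(self):
            # build and return a Keras model from self.config
            ...

    m = MyModel(config=None, basedir='models', name='my_model')
    # config=None -> loads models/my_model/config.json and the newest weights from disk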