senoquant 1.0.0b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- senoquant/__init__.py +6 -0
- senoquant/_reader.py +7 -0
- senoquant/_widget.py +33 -0
- senoquant/napari.yaml +83 -0
- senoquant/reader/__init__.py +5 -0
- senoquant/reader/core.py +369 -0
- senoquant/tabs/__init__.py +15 -0
- senoquant/tabs/batch/__init__.py +10 -0
- senoquant/tabs/batch/backend.py +641 -0
- senoquant/tabs/batch/config.py +270 -0
- senoquant/tabs/batch/frontend.py +1283 -0
- senoquant/tabs/batch/io.py +326 -0
- senoquant/tabs/batch/layers.py +86 -0
- senoquant/tabs/quantification/__init__.py +1 -0
- senoquant/tabs/quantification/backend.py +228 -0
- senoquant/tabs/quantification/features/__init__.py +80 -0
- senoquant/tabs/quantification/features/base.py +142 -0
- senoquant/tabs/quantification/features/marker/__init__.py +5 -0
- senoquant/tabs/quantification/features/marker/config.py +69 -0
- senoquant/tabs/quantification/features/marker/dialog.py +437 -0
- senoquant/tabs/quantification/features/marker/export.py +879 -0
- senoquant/tabs/quantification/features/marker/feature.py +119 -0
- senoquant/tabs/quantification/features/marker/morphology.py +285 -0
- senoquant/tabs/quantification/features/marker/rows.py +654 -0
- senoquant/tabs/quantification/features/marker/thresholding.py +46 -0
- senoquant/tabs/quantification/features/roi.py +346 -0
- senoquant/tabs/quantification/features/spots/__init__.py +5 -0
- senoquant/tabs/quantification/features/spots/config.py +62 -0
- senoquant/tabs/quantification/features/spots/dialog.py +477 -0
- senoquant/tabs/quantification/features/spots/export.py +1292 -0
- senoquant/tabs/quantification/features/spots/feature.py +112 -0
- senoquant/tabs/quantification/features/spots/morphology.py +279 -0
- senoquant/tabs/quantification/features/spots/rows.py +241 -0
- senoquant/tabs/quantification/frontend.py +815 -0
- senoquant/tabs/segmentation/__init__.py +1 -0
- senoquant/tabs/segmentation/backend.py +131 -0
- senoquant/tabs/segmentation/frontend.py +1009 -0
- senoquant/tabs/segmentation/models/__init__.py +5 -0
- senoquant/tabs/segmentation/models/base.py +146 -0
- senoquant/tabs/segmentation/models/cpsam/details.json +65 -0
- senoquant/tabs/segmentation/models/cpsam/model.py +150 -0
- senoquant/tabs/segmentation/models/default_2d/details.json +69 -0
- senoquant/tabs/segmentation/models/default_2d/model.py +664 -0
- senoquant/tabs/segmentation/models/default_3d/details.json +69 -0
- senoquant/tabs/segmentation/models/default_3d/model.py +682 -0
- senoquant/tabs/segmentation/models/hf.py +71 -0
- senoquant/tabs/segmentation/models/nuclear_dilation/__init__.py +1 -0
- senoquant/tabs/segmentation/models/nuclear_dilation/details.json +26 -0
- senoquant/tabs/segmentation/models/nuclear_dilation/model.py +96 -0
- senoquant/tabs/segmentation/models/perinuclear_rings/__init__.py +1 -0
- senoquant/tabs/segmentation/models/perinuclear_rings/details.json +34 -0
- senoquant/tabs/segmentation/models/perinuclear_rings/model.py +132 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/__init__.py +2 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/__init__.py +3 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/__init__.py +6 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/generate.py +470 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/prepare.py +273 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/rawdata.py +112 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/transform.py +384 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/__init__.py +0 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/blocks.py +184 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/losses.py +79 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/nets.py +165 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/predict.py +467 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/probability.py +67 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/train.py +148 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/io/__init__.py +163 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/__init__.py +52 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/base_model.py +329 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_isotropic.py +160 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_projection.py +178 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_standard.py +446 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_upsampling.py +54 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/config.py +254 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/pretrained.py +119 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/scripts/__init__.py +0 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/scripts/care_predict.py +180 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/__init__.py +5 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/plot_utils.py +159 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/six.py +18 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/tf.py +644 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/utils.py +272 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/version.py +1 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/docs/source/conf.py +368 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/setup.py +68 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_datagen.py +169 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_models.py +462 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_utils.py +166 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tools/create_zip_contents.py +34 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/__init__.py +30 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/big.py +624 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/bioimageio_utils.py +494 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/data/__init__.py +39 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/__init__.py +10 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/geom2d.py +215 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/geom3d.py +349 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/matching.py +483 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/__init__.py +28 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/base.py +1217 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/model2d.py +594 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/model3d.py +696 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/nms.py +384 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/__init__.py +2 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/plot.py +74 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/render.py +298 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/rays3d.py +373 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/sample_patches.py +65 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/__init__.py +0 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/predict2d.py +90 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/predict3d.py +93 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/utils.py +408 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/version.py +1 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/__init__.py +45 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/__init__.py +17 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/cli.py +55 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/core.py +285 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/__init__.py +15 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/cli.py +36 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/divisibility.py +193 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/probe.py +100 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/receptive_field.py +182 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/rf_cli.py +48 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/valid_sizes.py +278 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/post/__init__.py +8 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/post/core.py +157 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/pre/__init__.py +17 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/pre/core.py +226 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/predict/__init__.py +5 -0
- senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/predict/core.py +401 -0
- senoquant/tabs/settings/__init__.py +1 -0
- senoquant/tabs/settings/backend.py +29 -0
- senoquant/tabs/settings/frontend.py +19 -0
- senoquant/tabs/spots/__init__.py +1 -0
- senoquant/tabs/spots/backend.py +139 -0
- senoquant/tabs/spots/frontend.py +800 -0
- senoquant/tabs/spots/models/__init__.py +5 -0
- senoquant/tabs/spots/models/base.py +94 -0
- senoquant/tabs/spots/models/rmp/details.json +61 -0
- senoquant/tabs/spots/models/rmp/model.py +499 -0
- senoquant/tabs/spots/models/udwt/details.json +103 -0
- senoquant/tabs/spots/models/udwt/model.py +482 -0
- senoquant/utils.py +25 -0
- senoquant-1.0.0b1.dist-info/METADATA +193 -0
- senoquant-1.0.0b1.dist-info/RECORD +148 -0
- senoquant-1.0.0b1.dist-info/WHEEL +5 -0
- senoquant-1.0.0b1.dist-info/entry_points.txt +2 -0
- senoquant-1.0.0b1.dist-info/licenses/LICENSE +28 -0
- senoquant-1.0.0b1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,644 @@
|
|
|
1
|
+
from __future__ import print_function, unicode_literals, absolute_import, division
from six.moves import range, zip, map, reduce, filter

import numpy as np
import os
import warnings
import shutil
import datetime
from importlib import import_module
from packaging import version

# Parse the installed TensorFlow version once; the flags below steer every
# TF-1-vs-TF-2 and Keras-2-vs-Keras-3 compatibility branch in this module.
from tensorflow import __version__ as _v_tf
v_tf = version.parse(_v_tf)
IS_TF_1 = v_tf.major == 1                       # TensorFlow 1.x installed
IS_TF_2_6_PLUS = v_tf >= version.parse('2.6')   # TF >= 2.6: stand-alone 'keras' package expected
IS_TF_2_16_PLUS = v_tf >= version.parse('2.16') # TF >= 2.16: Keras-3-era handling below
# Set to True further below once the installed Keras version is known.
IS_KERAS_3_PLUS = False
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def err_old_keras(v_keras, package='keras', min_version=None):
    """Build (not raise) a RuntimeError for an outdated Keras installation.

    Parameters
    ----------
    v_keras : packaging.version.Version
        The detected Keras version.
    package : str, optional
        Name of the offending package (``'keras'`` or ``'tf_keras'``).
    min_version : str, optional
        Minimum acceptable version; defaults to the installed TF's
        ``major.minor`` (read from the module-level ``v_tf``).

    Returns
    -------
    RuntimeError
        Exception object with an actionable ``pip install`` hint.
    """
    if min_version is None:
        min_version = "{v_tf.major}.{v_tf.minor}".format(v_tf=v_tf)
    message = """
Found version {v_keras} of '{package}', which is too old for the installed version {v_tf} of 'tensorflow'.
Please update '{package}': pip install "{package}>={min_version}"
""".format(v_keras=v_keras, v_tf=v_tf, package=package, min_version=min_version)
    return RuntimeError(message)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
# Select and validate the Keras implementation matching the installed
# TensorFlow version. Defines `v_keras` and may set `IS_KERAS_3_PLUS`.
if IS_TF_1:
    # TF 1.x requires the stand-alone 'keras' package (not tf.keras).
    try:
        from keras import __version__ as _v_keras
        v_keras = version.parse(_v_keras)
    except ModuleNotFoundError:
        raise RuntimeError(\
"""
For 'tensorflow' 1.x (found version {v_tf}), the stand-alone 'keras' package (2.1.2 <= version < 2.4) is required.
When using the most recent version of 'tensorflow' 1.x, install 'keras' like this: pip install "keras>=2.1.2,<2.4"

If you must use an older version of 'tensorflow' 1.x, please see this file for compatible versions of 'keras':
https://github.com/CSBDeep/CSBDeep/blob/main/.github/workflows/tests_legacy.yml
""".format(v_tf=v_tf))

elif IS_TF_2_16_PLUS:
    # https://keras.io/getting_started/#tensorflow--keras-2-backwards-compatibility
    # https://keras.io/keras_3/ -> "Moving from Keras 2 to Keras 3"
    if os.environ.get('TF_USE_LEGACY_KERAS','0') == '1':
        # User opted into legacy Keras 2 behavior via the 'tf_keras' shim package.
        try:
            from tf_keras import __version__ as _v_keras
            v_keras = version.parse(_v_keras)
            assert v_keras.major < 3
            # tf_keras releases track TF's major.minor; older is considered incompatible.
            if (v_tf.major,v_tf.minor) > (v_keras.major,v_keras.minor):
                raise err_old_keras(v_keras, package='tf_keras')
            warnings.warn("Using Keras 2 (via 'tf_keras' package) with Tensorflow 2.16+ might work, but is not regularly tested.")
        except ModuleNotFoundError:
            raise RuntimeError("Found environment variable 'TF_USE_LEGACY_KERAS=1' but 'tf_keras' package not installed.")
    else:
        # Default path for TF 2.16+: Keras 3+ via the stand-alone 'keras' package.
        try:
            from keras import __version__ as _v_keras
            v_keras = version.parse(_v_keras)
            IS_KERAS_3_PLUS = v_keras.major >= 3
            if not IS_KERAS_3_PLUS:
                raise err_old_keras(v_keras, min_version='3')
        except ModuleNotFoundError:
            raise RuntimeError("Package 'keras' is not installed.")

elif IS_TF_2_6_PLUS:
    # https://github.com/CSBDeep/CSBDeep/releases/tag/0.6.3
    try:
        from keras import __version__ as _v_keras
        v_keras = version.parse(_v_keras)
        IS_KERAS_3_PLUS = v_keras.major >= 3
        if IS_KERAS_3_PLUS:
            warnings.warn("Using Keras version 3+ with Tensorflow below version 2.16 might work, but has not been tested.")
        else:
            if (v_tf.major,v_tf.minor) > (v_keras.major,v_keras.minor):
                raise err_old_keras(v_keras)
    except ModuleNotFoundError:
        # BUGFIX: the pip requirement previously read "keras>=X.Y<3", which is
        # an invalid specifier; a comma is required between the two constraints.
        raise RuntimeError(\
"""
Starting with 'tensorflow' version 2.6.0, a recent version of the stand-alone 'keras' package is required.
Please install/update 'keras': pip install "keras>={v_tf.major}.{v_tf.minor},<3"
""".format(v_tf=v_tf))

else:
    # TF 2.0 - 2.5: Keras is bundled as tensorflow.keras.
    from tensorflow.keras import __version__ as _v_keras
    v_keras = version.parse(_v_keras)
    assert v_keras.major < 3
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
# Module path of the active Keras implementation: the stand-alone 'keras'
# package for TF 1.x and Keras 3+, otherwise the bundled 'tensorflow.keras'.
if IS_TF_1 or IS_KERAS_3_PLUS:
    _KERAS = 'keras'
else:
    _KERAS = 'tensorflow.keras'

def keras_import(sub=None, *names):
    """Import the active Keras implementation, a submodule, or attributes.

    - ``keras_import()`` returns the Keras top-level module.
    - ``keras_import('backend')`` returns the submodule.
    - ``keras_import('layers', 'Lambda')`` returns a single attribute.
    - ``keras_import('layers', 'Input', 'Lambda')`` returns a tuple of attributes.
    """
    if sub is None:
        return import_module(_KERAS)
    mod = import_module(_KERAS + '.' + sub)
    if not names:
        return mod
    attrs = tuple(getattr(mod, attr) for attr in names)
    return attrs[0] if len(attrs) == 1 else attrs
|
|
101
|
+
|
|
102
|
+
import tensorflow as tf
# if IS_TF_1:
#     import tensorflow as tf
# else:
#     import tensorflow.compat.v1 as tf
#     # tf.disable_v2_behavior()

# Resolve the Keras entry points once via keras_import so the rest of the
# module is agnostic to which Keras implementation is active.
keras = keras_import()                           # Keras top-level module
K = keras_import('backend')                      # Keras backend module
Callback = keras_import('callbacks', 'Callback')
Lambda = keras_import('layers', 'Lambda')

from .utils import _raise, is_tf_backend, save_json, backend_channels_last, normalize
from .six import tempfile
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
if IS_KERAS_3_PLUS:
    OPS = keras_import('ops')  # Keras 3 'keras.ops' namespace
    # Attribute resolution order: explicit remap -> backend module K -> OPS.
    # 'int_shape' was removed in Keras 3; map it to ops.shape.
    _remap = dict(tf=tf, int_shape=OPS.shape)
    class Backend(object):
        """Attribute proxy that unifies backend access across Keras versions."""
        def __getattr__(self, name):
            if name in _remap:
                return _remap[name]
            elif hasattr(K, name):
                return getattr(K, name)
            else:
                return getattr(OPS, name)
else:
    # Keras 2: everything except 'tf' itself comes from the backend module K.
    _remap = dict(tf=tf)
    class Backend(object):
        """Attribute proxy that unifies backend access across Keras versions."""
        def __getattr__(self, name):
            if name in _remap:
                return _remap[name]
            else:
                return getattr(K, name)
# Singleton proxy used by callers instead of importing K/OPS directly.
BACKEND = Backend()
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
def limit_gpu_memory(fraction, allow_growth=False, total_memory=None):
    """Limit GPU memory allocation for TensorFlow (TF) backend.

    Parameters
    ----------
    fraction : float
        Limit TF to use only a fraction (value between 0 and 1) of the available GPU memory.
        Reduced memory allocation can be disabled if fraction is set to ``None``.
    allow_growth : bool, optional
        If ``False`` (default), TF will allocate all designated (see `fraction`) memory all at once.
        If ``True``, TF will allocate memory as needed up to the limit imposed by `fraction`; this may
        incur a performance penalty due to memory fragmentation.
    total_memory : int or iterable of int
        Total amount of available GPU memory (in MB).
        Only required for TensorFlow 2 when `fraction` is not ``None``.

    Raises
    ------
    ValueError
        If `fraction` is not ``None`` or a float value between 0 and 1.
    NotImplementedError
        If TensorFlow is not used as the backend.
    """

    # Validate preconditions; `x or _raise(...)` is this module's assertion idiom.
    is_tf_backend() or _raise(NotImplementedError('Not using tensorflow backend.'))
    fraction is None or (np.isscalar(fraction) and 0<=fraction<=1) or _raise(ValueError('fraction must be between 0 and 1.'))

    if IS_TF_1:
        # TF 1.x: memory options live on the session config, so this only
        # works if no session has been created yet.
        _session = None
        try:
            _session = K.tensorflow_backend._SESSION
        except AttributeError:
            pass

        if _session is None:
            config = tf.ConfigProto()
            if fraction is not None:
                config.gpu_options.per_process_gpu_memory_fraction = fraction
            config.gpu_options.allow_growth = bool(allow_growth)
            session = tf.Session(config=config)
            K.tensorflow_backend.set_session(session)
            # print("[tf_limit]\t setting config.gpu_options.per_process_gpu_memory_fraction to ",config.gpu_options.per_process_gpu_memory_fraction)
        else:
            # A session already exists; its memory options cannot be changed.
            warnings.warn('Too late to limit GPU memory, can only be done once and before any computation.')
    else:
        # TF 2.x: configure each physical GPU via a virtual device with an
        # absolute memory limit (TF 2 has no fraction-based API, hence
        # `total_memory` is needed to convert the fraction to MB).
        gpus = tf.config.experimental.list_physical_devices('GPU')
        if gpus:
            if fraction is not None:
                np.isscalar(total_memory) or _raise(ValueError("'total_memory' must be provided when using TensorFlow 2."))
                vdc = tf.config.experimental.VirtualDeviceConfiguration(memory_limit=int(np.ceil(total_memory*fraction)))
            try:
                for gpu in gpus:
                    if fraction is not None:
                        tf.config.experimental.set_virtual_device_configuration(gpu,[vdc])
                    if allow_growth:
                        tf.config.experimental.set_memory_growth(gpu, True)
            except RuntimeError as e:
                # must be set before GPUs have been initialized
                print(e)
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
def export_SavedModel(model, outpath, meta=None, format='zip'):
    """Export Keras model in TensorFlow's SavedModel_ format.

    See `Your Model in Fiji`_ to learn how to use the exported model with our CSBDeep Fiji plugins.

    .. _SavedModel: https://www.tensorflow.org/programmers_guide/saved_model#structure_of_a_savedmodel_directory
    .. _`Your Model in Fiji`: https://github.com/CSBDeep/CSBDeep_website/wiki/Your-Model-in-Fiji

    Parameters
    ----------
    model : :class:`keras.models.Model`
        Keras model to be exported.
    outpath : str
        Path of the file/folder that the model will exported to.
    meta : dict, optional
        Metadata to be saved in an additional ``meta.json`` file.
        If ``None`` (default) or empty, no metadata file is written.
        (Changed from a mutable ``{}`` default to ``None`` — same behavior,
        avoids the shared-mutable-default pitfall.)
    format : str, optional
        Can be 'dir' to export as a directory or 'zip' (default) to export as a ZIP file.

    Raises
    ------
    ValueError
        Illegal arguments.
    NotImplementedError
        If Keras 3+ is installed (SavedModel export no longer supported).
    """
    if IS_KERAS_3_PLUS:
        raise NotImplementedError(\
"""Exporting to SavedModel is no longer supported with Keras 3+.

It is likely that the exported model *will only work* in associated ImageJ/Fiji
plugins (e.g. CSBDeep and StarDist) when using 'tensorflow' 1.x to export the model.

The current workaround is to load the trained model in a Python environment with
installed 'tensorflow' 1.x and then export it again. If you need help with this, please read:

https://gist.github.com/uschmidt83/4b747862fe307044c722d6d1009f6183
""")

    def export_to_dir(dirname):
        # Write the SavedModel (and optional meta.json) into `dirname`.
        if len(model.inputs) > 1 or len(model.outputs) > 1:
            warnings.warn('Found multiple input or output layers.')

        def _export(model):
            # Build the SavedModel via the TF1 builder API (through
            # tf.compat.v1 when running on TF 2.x).
            if IS_TF_1:
                from tensorflow import saved_model
                from keras.backend import get_session
            else:
                from tensorflow.compat.v1 import saved_model
                from tensorflow.compat.v1.keras.backend import get_session

            if not IS_TF_1:
                warnings.warn(\
"""
***IMPORTANT NOTE***

You are using 'tensorflow' 2.x, hence it is likely that the exported model *will not work*
in associated ImageJ/Fiji plugins (e.g. CSBDeep and StarDist).

If you indeed have problems loading the exported model in Fiji, the current workaround is
to load the trained model in a Python environment with installed 'tensorflow' 1.x and then
export it again. If you need help with this, please read:

https://gist.github.com/uschmidt83/4b747862fe307044c722d6d1009f6183
""")

            builder = saved_model.builder.SavedModelBuilder(dirname)
            # use name 'input'/'output' if there's just a single input/output layer
            inputs  = dict(zip(model.input_names,model.inputs))   if len(model.inputs)  > 1 else dict(input=model.input)
            outputs = dict(zip(model.output_names,model.outputs)) if len(model.outputs) > 1 else dict(output=model.output)
            signature = saved_model.signature_def_utils.predict_signature_def(inputs=inputs, outputs=outputs)
            signature_def_map = { saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature }
            builder.add_meta_graph_and_variables(get_session(), [saved_model.tag_constants.SERVING], signature_def_map=signature_def_map)
            builder.save()

        if IS_TF_1:
            _export(model)
        else:
            # TF 2.x: rebuild the model inside a fresh TF1-style graph so the
            # v1 SavedModelBuilder can capture it.
            from tensorflow.keras.models import clone_model
            weights = model.get_weights()
            with tf.Graph().as_default():
                # clone model in new graph and set weights
                _model = clone_model(model)
                _model.set_weights(weights)
                _export(_model)

        if meta is not None and len(meta) > 0:
            save_json(meta, os.path.join(dirname,'meta.json'))

    ## checks
    isinstance(model,keras.models.Model) or _raise(ValueError("'model' must be a Keras model."))
    # supported_formats = tuple(['dir']+[name for name,description in shutil.get_archive_formats()])
    supported_formats = 'dir','zip'
    format in supported_formats or _raise(ValueError("Unsupported format '%s', must be one of %s." % (format,str(supported_formats))))

    # remove '.zip' file name extension if necessary (shutil.make_archive adds it back)
    if format == 'zip' and outpath.endswith('.zip'):
        outpath = os.path.splitext(outpath)[0]

    if format == 'dir':
        export_to_dir(outpath)
    else:
        # Export into a temp directory, then archive it to `outpath`.
        with tempfile.TemporaryDirectory() as tmpdir:
            tmpsubdir = os.path.join(tmpdir,'model')
            export_to_dir(tmpsubdir)
            shutil.make_archive(outpath, format, tmpsubdir)
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
|
|
311
|
+
def tf_normalize(x, pmin=1, pmax=99.8, axis=None, clip=False):
    """Percentile-based normalization of a tensor.

    Rescales `x` so that its `pmin`-th percentile maps to 0 and its
    `pmax`-th percentile maps to 1 (percentiles computed along `axis`
    with reduced dimensions kept). If `clip`, the result is clipped
    to [0, 1]. NOTE(review): uses ``tf.contrib``, i.e. TF 1.x only.
    """
    assert pmin < pmax
    percentile = tf.contrib.distributions.percentile
    lo = percentile(x, pmin, axis=axis, keep_dims=True)
    hi = percentile(x, pmax, axis=axis, keep_dims=True)
    scaled = (x - lo) / (hi - lo + K.epsilon())
    return K.clip(scaled, 0, 1.0) if clip else scaled
|
|
319
|
+
|
|
320
|
+
|
|
321
|
+
def tf_normalize_layer(layer, pmin=1, pmax=99.8, clip=True):
    """Wrap `layer` with Lambda layers that normalize it for image summaries.

    Reduces the output to at most 4 dimensions and 1 or 3 channels so it can
    be written with ``tf.summary.image``, applying percentile normalization
    (see `tf_normalize`). Assumes channels-last layout.
    """
    def norm(x,axis):
        # percentile-normalize with the parameters bound at call time
        return tf_normalize(x, pmin=pmin, pmax=pmax, axis=axis, clip=clip)

    shape = K.int_shape(layer)
    n_channels_out = shape[-1]
    n_dim_out = len(shape)

    # collapse any extra leading spatial axes via max projection so that
    # the result is 4D (batch, height, width, channels)
    if n_dim_out > 4:
        layer = Lambda(lambda x: K.max(x, axis=tuple(range(1,1+n_dim_out-4))))(layer)

    assert 0 < n_channels_out

    if n_channels_out == 1:
        # single channel: normalize over spatial axes only
        out = Lambda(lambda x: norm(x, axis=(1,2)))(layer)
    elif n_channels_out == 2:
        # two channels: duplicate the first channel to obtain a 3-channel image
        out = Lambda(lambda x: norm(K.concatenate([x,x[...,:1]], axis=-1), axis=(1,2,3)))(layer)
    elif n_channels_out == 3:
        # three channels: treat as RGB, normalize jointly over all channels
        out = Lambda(lambda x: norm(x, axis=(1,2,3)))(layer)
    else:
        # many channels: max-project across channels down to a single channel
        out = Lambda(lambda x: norm(K.max(x, axis=-1, keepdims=True), axis=(1,2,3)))(layer)
    return out
|
|
343
|
+
|
|
344
|
+
|
|
345
|
+
class CARETensorBoard(Callback):
|
|
346
|
+
""" TODO """
|
|
347
|
+
def __init__(self, log_dir='./logs',
             freq=1,
             compute_histograms=False,
             n_images=3,
             prob_out=False,
             write_graph=False,
             prefix_with_timestamp=True,
             write_images=False,
             image_for_inputs=None, # write images for only these input indices
             image_for_outputs=None, # write images for only these output indices
             input_slices=None, # list (of list) of slices to apply to `image_for_inputs` layers before writing image
             output_slices=None, # list (of list) of slices to apply to `image_for_outputs` layers before writing image
             output_target_shapes=None): # list of shapes of the target/gt images that correspond to all model outputs
    """Configure the TensorBoard callback (TensorFlow 1.x only).

    `log_dir` is optionally suffixed with a microsecond-resolution
    timestamp directory so repeated runs do not overwrite each other.
    Raises via `_raise` if the backend is not TensorFlow, if channels
    are not last, or if TensorFlow 2 is installed.
    """
    super(CARETensorBoard, self).__init__()
    is_tf_backend() or _raise(RuntimeError('TensorBoard callback only works with the TensorFlow backend.'))
    backend_channels_last() or _raise(NotImplementedError())
    IS_TF_1 or _raise(NotImplementedError("Not supported with TensorFlow 2"))

    self.freq = freq
    self.image_freq = freq  # images are written at the same frequency as scalars
    self.prob_out = prob_out
    self.merged = None      # merged summary op, built lazily in set_model
    self.gt_outputs = None  # placeholders for ground-truth targets, built in set_model
    self.write_graph = write_graph
    self.write_images = write_images
    self.n_images = n_images
    self.image_for_inputs = image_for_inputs
    self.image_for_outputs = image_for_outputs
    self.input_slices = input_slices
    self.output_slices = output_slices
    self.output_target_shapes = output_target_shapes
    self.compute_histograms = compute_histograms

    if prefix_with_timestamp:
        log_dir = os.path.join(log_dir, datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S.%f"))

    self.log_dir = log_dir
|
|
384
|
+
|
|
385
|
+
def set_model(self, model):
|
|
386
|
+
self.model = model
|
|
387
|
+
self.sess = K.get_session()
|
|
388
|
+
tf_sums = []
|
|
389
|
+
|
|
390
|
+
if self.compute_histograms and self.freq and self.merged is None:
|
|
391
|
+
for layer in self.model.layers:
|
|
392
|
+
for weight in layer.weights:
|
|
393
|
+
tf_sums.append(tf.summary.histogram(weight.name, weight))
|
|
394
|
+
|
|
395
|
+
if hasattr(layer, 'output'):
|
|
396
|
+
tf_sums.append(tf.summary.histogram('{}_out'.format(layer.name),
|
|
397
|
+
layer.output))
|
|
398
|
+
|
|
399
|
+
def _gt_shape(output_shape,target_shape):
|
|
400
|
+
if target_shape is not None:
|
|
401
|
+
output_shape = target_shape
|
|
402
|
+
if not self.prob_out: return output_shape
|
|
403
|
+
output_shape[-1] % 2 == 0 or _raise(ValueError())
|
|
404
|
+
return list(output_shape[:-1]) + [output_shape[-1] // 2]
|
|
405
|
+
|
|
406
|
+
n_inputs, n_outputs = len(self.model.inputs), len(self.model.outputs)
|
|
407
|
+
image_for_inputs = np.arange(n_inputs) if self.image_for_inputs is None else self.image_for_inputs
|
|
408
|
+
image_for_outputs = np.arange(n_outputs) if self.image_for_outputs is None else self.image_for_outputs
|
|
409
|
+
|
|
410
|
+
output_target_shapes = [None]*n_outputs if self.output_target_shapes is None else self.output_target_shapes
|
|
411
|
+
self.gt_outputs = [K.placeholder(shape=_gt_shape(K.int_shape(x),sh)) for x,sh in zip(self.model.outputs,output_target_shapes)]
|
|
412
|
+
|
|
413
|
+
input_slices = (slice(None),) if self.input_slices is None else self.input_slices
|
|
414
|
+
output_slices = (slice(None),) if self.output_slices is None else self.output_slices
|
|
415
|
+
if isinstance(input_slices[0],slice): # apply same slices to all inputs
|
|
416
|
+
input_slices = [input_slices]*len(image_for_inputs)
|
|
417
|
+
if isinstance(output_slices[0],slice): # apply same slices to all outputs
|
|
418
|
+
output_slices = [output_slices]*len(image_for_outputs)
|
|
419
|
+
len(input_slices) == len(image_for_inputs) or _raise(ValueError())
|
|
420
|
+
len(output_slices) == len(image_for_outputs) or _raise(ValueError())
|
|
421
|
+
|
|
422
|
+
def _name(prefix, layer, i, n, show_layer_names=False):
|
|
423
|
+
return '{prefix}{i}{name}'.format (
|
|
424
|
+
prefix = prefix,
|
|
425
|
+
i = (i if n > 1 else ''),
|
|
426
|
+
name = '' if (layer is None or not show_layer_names) else '_'+''.join(layer.name.split(':')[:-1]),
|
|
427
|
+
)
|
|
428
|
+
|
|
429
|
+
# inputs
|
|
430
|
+
for i,sl in zip(image_for_inputs,input_slices):
|
|
431
|
+
# print('input', self.model.inputs[i], tuple(sl))
|
|
432
|
+
layer_name = _name('net_input', self.model.inputs[i], i, n_inputs)
|
|
433
|
+
input_layer = tf_normalize_layer(self.model.inputs[i][tuple(sl)])
|
|
434
|
+
tf_sums.append(tf.summary.image(layer_name, input_layer, max_outputs=self.n_images))
|
|
435
|
+
|
|
436
|
+
# outputs
|
|
437
|
+
for i,sl in zip(image_for_outputs,output_slices):
|
|
438
|
+
# print('output', self.model.outputs[i], tuple(sl))
|
|
439
|
+
output_shape = self.model.output_shape if n_outputs==1 else self.model.output_shape[i]
|
|
440
|
+
# target
|
|
441
|
+
output_layer = tf_normalize_layer(self.gt_outputs[i][tuple(sl)])
|
|
442
|
+
layer_name = _name('net_target', self.model.outputs[i], i, n_outputs)
|
|
443
|
+
tf_sums.append(tf.summary.image(layer_name, output_layer, max_outputs=self.n_images))
|
|
444
|
+
# prediction
|
|
445
|
+
n_channels_out = sep = output_shape[-1]
|
|
446
|
+
if self.prob_out: # first half of output channels is mean, second half scale
|
|
447
|
+
n_channels_out % 2 == 0 or _raise(ValueError())
|
|
448
|
+
sep = sep // 2
|
|
449
|
+
output_layer = tf_normalize_layer(self.model.outputs[i][...,:sep][tuple(sl)])
|
|
450
|
+
if self.prob_out:
|
|
451
|
+
scale_layer = tf_normalize_layer(self.model.outputs[i][...,sep:][tuple(sl)], pmin=0, pmax=100)
|
|
452
|
+
mean_name = _name('net_output_mean', self.model.outputs[i], i, n_outputs)
|
|
453
|
+
scale_name = _name('net_output_scale', self.model.outputs[i], i, n_outputs)
|
|
454
|
+
tf_sums.append(tf.summary.image(mean_name, output_layer, max_outputs=self.n_images))
|
|
455
|
+
tf_sums.append(tf.summary.image(scale_name, scale_layer, max_outputs=self.n_images))
|
|
456
|
+
else:
|
|
457
|
+
layer_name = _name('net_output', self.model.outputs[i], i, n_outputs)
|
|
458
|
+
tf_sums.append(tf.summary.image(layer_name, output_layer, max_outputs=self.n_images))
|
|
459
|
+
|
|
460
|
+
|
|
461
|
+
with tf.name_scope('merged'):
|
|
462
|
+
self.merged = tf.summary.merge(tf_sums)
|
|
463
|
+
|
|
464
|
+
with tf.name_scope('summary_writer'):
|
|
465
|
+
if self.write_graph:
|
|
466
|
+
self.writer = tf.summary.FileWriter(self.log_dir,
|
|
467
|
+
self.sess.graph)
|
|
468
|
+
else:
|
|
469
|
+
self.writer = tf.summary.FileWriter(self.log_dir)
|
|
470
|
+
|
|
471
|
+
def on_epoch_end(self, epoch, logs=None):
    """Write the merged image/histogram summaries and all scalar metrics.

    Runs the merged TF1 summary op on (a slice of) the validation data every
    ``self.freq`` epochs, then writes every scalar metric from ``logs`` as a
    TF1 ``Summary`` protobuf.

    Parameters
    ----------
    epoch : int
        Index of the finished epoch; used as the global step of all summaries.
    logs : dict, optional
        Metric name -> value mapping supplied by Keras; entries named
        'batch' and 'size' are skipped.
    """
    logs = logs or {}

    # Image/histogram summaries are only evaluated every `freq` epochs and
    # only when validation data is available.
    if self.validation_data and self.freq:
        if epoch % self.freq == 0:
            # TODO: implement batched calls to sess.run
            # (current call will likely go OOM on GPU)

            # Placeholders the merged summary op is fed through; order must
            # match the layout of `self.validation_data`.
            tensors = self.model.inputs + self.gt_outputs + self.model.sample_weights

            if self.model.uses_learning_phase:
                tensors += [K.learning_phase()]
                # Last element of validation_data is the learning-phase flag;
                # slice only the array entries down to `n_images` samples.
                val_data = list(v[:self.n_images] for v in self.validation_data[:-1])
                val_data += self.validation_data[-1:]
            else:
                val_data = list(v[:self.n_images] for v in self.validation_data)

            # NOTE(review): assumes len(tensors) == len(val_data); `zip`
            # silently truncates on mismatch — verify against how Keras
            # populates `validation_data` for this model.
            feed_dict = dict(zip(tensors, val_data))
            result = self.sess.run([self.merged], feed_dict=feed_dict)
            summary_str = result[0]

            self.writer.add_summary(summary_str, epoch)

    # Scalar metrics: one TF1 Summary protobuf per logged value.
    for name, value in logs.items():
        if name in ['batch', 'size']:
            continue
        summary = tf.Summary()
        summary_value = summary.value.add()
        summary_value.simple_value = float(value)
        summary_value.tag = name
        self.writer.add_summary(summary, epoch)
    self.writer.flush()
|
|
503
|
+
|
|
504
|
+
def on_train_end(self, _):
    """Close the TF1 summary writer once training has finished."""
    writer = self.writer
    writer.close()
|
|
506
|
+
|
|
507
|
+
|
|
508
|
+
|
|
509
|
+
class CARETensorBoardImage(Callback):
|
|
510
|
+
|
|
511
|
+
def __init__(self, model, data, log_dir,
             n_images=3,
             batch_size=1,
             prob_out=False,
             image_for_inputs=None,   # write images for only these input indices
             image_for_outputs=None,  # write images for only these output indices
             input_slices=None,       # list (of list) of slices to apply to `image_for_inputs` layers before writing image
             output_slices=None):     # list (of list) of slices to apply to `image_for_outputs` layers before writing image
    """Keras callback that writes input/target/prediction images to TensorBoard (TF2 only)."""
    super().__init__()
    if not backend_channels_last():
        _raise(NotImplementedError())
    if IS_TF_1:
        _raise(NotImplementedError("Not supported with TensorFlow 1"))

    ## model
    from ..models import BaseModel
    if isinstance(model, BaseModel):
        model = model.keras_model
    if not isinstance(model, keras.Model):
        _raise(ValueError())
    # self.model is already used by keras.callbacks.Callback in keras 3+
    self.model_ = model
    self.n_inputs = len(self.model_.inputs)
    self.n_outputs = len(self.model_.outputs)

    ## data: either an (X, Y) pair directly, or an indexable yielding such pairs
    if isinstance(data, (list, tuple)):
        X, Y = data
    else:
        X, Y = data[0]
    if self.n_inputs == 1 and isinstance(X, np.ndarray):
        X = (X,)
    if self.n_outputs == 1 and isinstance(Y, np.ndarray):
        Y = (Y,)
    if not (self.n_inputs == len(X) and self.n_outputs == len(Y)):
        _raise(ValueError())
    # all(len(v)>=n_images for v in X) or _raise(ValueError())
    # all(len(v)>=n_images for v in Y) or _raise(ValueError())
    # keep only the first `n_images` examples of every input/target array
    self.X = [v[:n_images] for v in X]
    self.Y = [v[:n_images] for v in Y]

    self.log_dir = log_dir
    self.batch_size = batch_size
    self.prob_out = prob_out

    ## I/O choices: default to logging all inputs/outputs
    if image_for_inputs is None:
        image_for_inputs = np.arange(self.n_inputs)
    if image_for_outputs is None:
        image_for_outputs = np.arange(self.n_outputs)

    if input_slices is None:
        input_slices = (slice(None),)
    if output_slices is None:
        output_slices = (slice(None),)
    if isinstance(input_slices[0], slice):   # apply same slices to all inputs
        input_slices = [input_slices] * len(image_for_inputs)
    if isinstance(output_slices[0], slice):  # apply same slices to all outputs
        output_slices = [output_slices] * len(image_for_outputs)

    if len(input_slices) != len(image_for_inputs):
        _raise(ValueError())
    if len(output_slices) != len(image_for_outputs):
        _raise(ValueError())

    self.image_for_inputs = image_for_inputs
    self.image_for_outputs = image_for_outputs
    self.input_slices = input_slices
    self.output_slices = output_slices

    self.file_writer = tf.summary.create_file_writer(str(self.log_dir))
|
|
571
|
+
|
|
572
|
+
|
|
573
|
+
def _name(self, prefix, layer, i, n, show_layer_names=False):
|
|
574
|
+
return '{prefix}{i}{name}'.format (
|
|
575
|
+
prefix = prefix,
|
|
576
|
+
i = (i if n > 1 else ''),
|
|
577
|
+
name = '' if (layer is None or not show_layer_names) else '_'+''.join(layer.name.split(':')[:-1]),
|
|
578
|
+
)
|
|
579
|
+
|
|
580
|
+
|
|
581
|
+
def _normalize_image(self, x, pmin=1, pmax=99.8, clip=True):
    """Percentile-normalize an image batch for TensorBoard display.

    Extra leading (non-batch) dimensions beyond 4-D are max-projected away,
    and the channel count is mapped to something displayable: 1 channel kept
    as grayscale, 2 channels padded to 3 by repeating the first, 3 channels
    kept as-is, and more than 3 channels max-projected to a single channel.

    Parameters
    ----------
    x : ndarray
        Image batch; assumed layout is (batch, ..., channels), channels last
        — TODO confirm against callers, which pass `self.X`/`self.Y`/`Yhat`.
    pmin, pmax : float
        Percentiles passed through to `normalize`.
    clip : bool
        Clip flag passed through to `normalize`.

    Returns
    -------
    ndarray
        4-D normalized image batch suitable for `tf.summary.image`.
    """
    def norm(x, axis):
        return normalize(x, pmin=pmin, pmax=pmax, axis=axis, clip=clip)

    n_channels_out = x.shape[-1]
    n_dim_out = len(x.shape)
    # explicit validation instead of `assert`, which is stripped under `python -O`
    n_channels_out > 0 or _raise(ValueError("image must have at least one channel"))

    if n_dim_out > 4:
        # max-project every dimension between the batch axis and the last 3 axes
        x = np.max(x, axis=tuple(range(1,1+n_dim_out-4)))

    if n_channels_out == 1:
        out = norm(x, axis=(1,2))
    elif n_channels_out == 2:
        # duplicate the first channel to obtain a displayable 3-channel image
        out = norm(np.concatenate([x,x[...,:1]], axis=-1), axis=(1,2,3))
    elif n_channels_out == 3:
        out = norm(x, axis=(1,2,3))
    else:
        # too many channels for RGB: max-project over the channel axis
        out = norm(np.max(x, axis=-1, keepdims=True), axis=(1,2,3))
    return out
|
|
601
|
+
|
|
602
|
+
|
|
603
|
+
def on_epoch_end(self, epoch, logs=None):
    """Log input images (once), prediction images (every epoch), and target
    images (once) to TensorBoard.

    See https://www.tensorflow.org/tensorboard/image_summaries
    """
    writer = self.file_writer

    # network inputs: static across training, hence written only at epoch 0
    if epoch == 0:
        with writer.as_default():
            for idx, slc in zip(self.image_for_inputs, self.input_slices):
                tag = self._name('net_input', self.model_.inputs[idx], idx, self.n_inputs)
                image = self._normalize_image(self.X[idx][tuple(slc)])
                tf.summary.image(tag, image, step=epoch)

    # network predictions: recomputed and written every epoch
    predictions = self.model_.predict(self.X, batch_size=self.batch_size, verbose=0)
    if self.n_outputs == 1 and isinstance(predictions, np.ndarray):
        predictions = (predictions,)
    with writer.as_default():
        for idx, slc in zip(self.image_for_outputs, self.output_slices):
            if self.n_outputs == 1:
                shape = self.model_.output_shape
            else:
                shape = self.model_.output_shape[idx]
            n_channels = sep = shape[-1]
            if self.prob_out:
                # first half of output channels is mean, second half scale
                n_channels % 2 == 0 or _raise(ValueError())
                sep = sep // 2
            mean_image = self._normalize_image(predictions[idx][..., :sep][tuple(slc)])
            if self.prob_out:
                scale_image = self._normalize_image(predictions[idx][..., sep:][tuple(slc)], pmin=0, pmax=100)
                scale_tag = self._name('net_output_scale', self.model_.outputs[idx], idx, self.n_outputs)
                mean_tag = self._name('net_output_mean', self.model_.outputs[idx], idx, self.n_outputs)
                tf.summary.image(mean_tag, mean_image, step=epoch)
                tf.summary.image(scale_tag, scale_image, step=epoch)
            else:
                tag = self._name('net_output', self.model_.outputs[idx], idx, self.n_outputs)
                tf.summary.image(tag, mean_image, step=epoch)

    # ground-truth targets: static across training, hence written only at epoch 0
    if epoch == 0:
        with writer.as_default():
            for idx, slc in zip(self.image_for_outputs, self.output_slices):
                tag = self._name('net_target', self.model_.outputs[idx], idx, self.n_outputs)
                image = self._normalize_image(self.Y[idx][tuple(slc)])
                tf.summary.image(tag, image, step=epoch)
|