senoquant 1.0.0b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148)
  1. senoquant/__init__.py +6 -0
  2. senoquant/_reader.py +7 -0
  3. senoquant/_widget.py +33 -0
  4. senoquant/napari.yaml +83 -0
  5. senoquant/reader/__init__.py +5 -0
  6. senoquant/reader/core.py +369 -0
  7. senoquant/tabs/__init__.py +15 -0
  8. senoquant/tabs/batch/__init__.py +10 -0
  9. senoquant/tabs/batch/backend.py +641 -0
  10. senoquant/tabs/batch/config.py +270 -0
  11. senoquant/tabs/batch/frontend.py +1283 -0
  12. senoquant/tabs/batch/io.py +326 -0
  13. senoquant/tabs/batch/layers.py +86 -0
  14. senoquant/tabs/quantification/__init__.py +1 -0
  15. senoquant/tabs/quantification/backend.py +228 -0
  16. senoquant/tabs/quantification/features/__init__.py +80 -0
  17. senoquant/tabs/quantification/features/base.py +142 -0
  18. senoquant/tabs/quantification/features/marker/__init__.py +5 -0
  19. senoquant/tabs/quantification/features/marker/config.py +69 -0
  20. senoquant/tabs/quantification/features/marker/dialog.py +437 -0
  21. senoquant/tabs/quantification/features/marker/export.py +879 -0
  22. senoquant/tabs/quantification/features/marker/feature.py +119 -0
  23. senoquant/tabs/quantification/features/marker/morphology.py +285 -0
  24. senoquant/tabs/quantification/features/marker/rows.py +654 -0
  25. senoquant/tabs/quantification/features/marker/thresholding.py +46 -0
  26. senoquant/tabs/quantification/features/roi.py +346 -0
  27. senoquant/tabs/quantification/features/spots/__init__.py +5 -0
  28. senoquant/tabs/quantification/features/spots/config.py +62 -0
  29. senoquant/tabs/quantification/features/spots/dialog.py +477 -0
  30. senoquant/tabs/quantification/features/spots/export.py +1292 -0
  31. senoquant/tabs/quantification/features/spots/feature.py +112 -0
  32. senoquant/tabs/quantification/features/spots/morphology.py +279 -0
  33. senoquant/tabs/quantification/features/spots/rows.py +241 -0
  34. senoquant/tabs/quantification/frontend.py +815 -0
  35. senoquant/tabs/segmentation/__init__.py +1 -0
  36. senoquant/tabs/segmentation/backend.py +131 -0
  37. senoquant/tabs/segmentation/frontend.py +1009 -0
  38. senoquant/tabs/segmentation/models/__init__.py +5 -0
  39. senoquant/tabs/segmentation/models/base.py +146 -0
  40. senoquant/tabs/segmentation/models/cpsam/details.json +65 -0
  41. senoquant/tabs/segmentation/models/cpsam/model.py +150 -0
  42. senoquant/tabs/segmentation/models/default_2d/details.json +69 -0
  43. senoquant/tabs/segmentation/models/default_2d/model.py +664 -0
  44. senoquant/tabs/segmentation/models/default_3d/details.json +69 -0
  45. senoquant/tabs/segmentation/models/default_3d/model.py +682 -0
  46. senoquant/tabs/segmentation/models/hf.py +71 -0
  47. senoquant/tabs/segmentation/models/nuclear_dilation/__init__.py +1 -0
  48. senoquant/tabs/segmentation/models/nuclear_dilation/details.json +26 -0
  49. senoquant/tabs/segmentation/models/nuclear_dilation/model.py +96 -0
  50. senoquant/tabs/segmentation/models/perinuclear_rings/__init__.py +1 -0
  51. senoquant/tabs/segmentation/models/perinuclear_rings/details.json +34 -0
  52. senoquant/tabs/segmentation/models/perinuclear_rings/model.py +132 -0
  53. senoquant/tabs/segmentation/stardist_onnx_utils/__init__.py +2 -0
  54. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/__init__.py +3 -0
  55. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/__init__.py +6 -0
  56. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/generate.py +470 -0
  57. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/prepare.py +273 -0
  58. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/rawdata.py +112 -0
  59. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/transform.py +384 -0
  60. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/__init__.py +0 -0
  61. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/blocks.py +184 -0
  62. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/losses.py +79 -0
  63. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/nets.py +165 -0
  64. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/predict.py +467 -0
  65. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/probability.py +67 -0
  66. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/train.py +148 -0
  67. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/io/__init__.py +163 -0
  68. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/__init__.py +52 -0
  69. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/base_model.py +329 -0
  70. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_isotropic.py +160 -0
  71. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_projection.py +178 -0
  72. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_standard.py +446 -0
  73. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_upsampling.py +54 -0
  74. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/config.py +254 -0
  75. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/pretrained.py +119 -0
  76. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/scripts/__init__.py +0 -0
  77. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/scripts/care_predict.py +180 -0
  78. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/__init__.py +5 -0
  79. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/plot_utils.py +159 -0
  80. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/six.py +18 -0
  81. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/tf.py +644 -0
  82. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/utils.py +272 -0
  83. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/version.py +1 -0
  84. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/docs/source/conf.py +368 -0
  85. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/setup.py +68 -0
  86. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_datagen.py +169 -0
  87. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_models.py +462 -0
  88. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_utils.py +166 -0
  89. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tools/create_zip_contents.py +34 -0
  90. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/__init__.py +30 -0
  91. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/big.py +624 -0
  92. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/bioimageio_utils.py +494 -0
  93. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/data/__init__.py +39 -0
  94. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/__init__.py +10 -0
  95. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/geom2d.py +215 -0
  96. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/geom3d.py +349 -0
  97. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/matching.py +483 -0
  98. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/__init__.py +28 -0
  99. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/base.py +1217 -0
  100. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/model2d.py +594 -0
  101. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/model3d.py +696 -0
  102. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/nms.py +384 -0
  103. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/__init__.py +2 -0
  104. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/plot.py +74 -0
  105. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/render.py +298 -0
  106. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/rays3d.py +373 -0
  107. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/sample_patches.py +65 -0
  108. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/__init__.py +0 -0
  109. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/predict2d.py +90 -0
  110. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/predict3d.py +93 -0
  111. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/utils.py +408 -0
  112. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/version.py +1 -0
  113. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/__init__.py +45 -0
  114. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/__init__.py +17 -0
  115. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/cli.py +55 -0
  116. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/core.py +285 -0
  117. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/__init__.py +15 -0
  118. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/cli.py +36 -0
  119. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/divisibility.py +193 -0
  120. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/probe.py +100 -0
  121. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/receptive_field.py +182 -0
  122. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/rf_cli.py +48 -0
  123. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/valid_sizes.py +278 -0
  124. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/post/__init__.py +8 -0
  125. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/post/core.py +157 -0
  126. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/pre/__init__.py +17 -0
  127. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/pre/core.py +226 -0
  128. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/predict/__init__.py +5 -0
  129. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/predict/core.py +401 -0
  130. senoquant/tabs/settings/__init__.py +1 -0
  131. senoquant/tabs/settings/backend.py +29 -0
  132. senoquant/tabs/settings/frontend.py +19 -0
  133. senoquant/tabs/spots/__init__.py +1 -0
  134. senoquant/tabs/spots/backend.py +139 -0
  135. senoquant/tabs/spots/frontend.py +800 -0
  136. senoquant/tabs/spots/models/__init__.py +5 -0
  137. senoquant/tabs/spots/models/base.py +94 -0
  138. senoquant/tabs/spots/models/rmp/details.json +61 -0
  139. senoquant/tabs/spots/models/rmp/model.py +499 -0
  140. senoquant/tabs/spots/models/udwt/details.json +103 -0
  141. senoquant/tabs/spots/models/udwt/model.py +482 -0
  142. senoquant/utils.py +25 -0
  143. senoquant-1.0.0b1.dist-info/METADATA +193 -0
  144. senoquant-1.0.0b1.dist-info/RECORD +148 -0
  145. senoquant-1.0.0b1.dist-info/WHEEL +5 -0
  146. senoquant-1.0.0b1.dist-info/entry_points.txt +2 -0
  147. senoquant-1.0.0b1.dist-info/licenses/LICENSE +28 -0
  148. senoquant-1.0.0b1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,644 @@
1
+ from __future__ import print_function, unicode_literals, absolute_import, division
2
+ from six.moves import range, zip, map, reduce, filter
3
+
4
+ import numpy as np
5
+ import os
6
+ import warnings
7
+ import shutil
8
+ import datetime
9
+ from importlib import import_module
10
+ from packaging import version
11
+
12
+ from tensorflow import __version__ as _v_tf
13
# Parsed TensorFlow version and derived feature flags; used throughout this module to
# choose between TF 1.x / TF 2.x code paths and Keras 2 / Keras 3 APIs.
v_tf = version.parse(_v_tf)
IS_TF_1 = v_tf.major == 1
IS_TF_2_6_PLUS = v_tf >= version.parse('2.6')
IS_TF_2_16_PLUS = v_tf >= version.parse('2.16')
# Provisional value; the Keras version checks further down set it to True when a
# Keras 3+ package is detected.
IS_KERAS_3_PLUS = False
18
+
19
+
20
def err_old_keras(v_keras, package='keras', min_version=None):
    """Create (but do not raise) an error describing an outdated Keras package.

    Parameters
    ----------
    v_keras : packaging.version.Version
        Detected version of the Keras package.
    package : str, optional
        Name of the offending package (e.g. 'keras' or 'tf_keras').
    min_version : str, optional
        Minimum version to recommend. Defaults to the installed TensorFlow's
        "<major>.<minor>" (taken from the module-level ``v_tf``).

    Returns
    -------
    RuntimeError
        Exception instance for the caller to raise.
    """
    if min_version is None:
        min_version = "{v_tf.major}.{v_tf.minor}".format(v_tf=v_tf)
    template = """
Found version {v_keras} of '{package}', which is too old for the installed version {v_tf} of 'tensorflow'.
Please update '{package}': pip install "{package}>={min_version}"
"""
    message = template.format(v_keras=v_keras, v_tf=v_tf, package=package, min_version=min_version)
    return RuntimeError(message)
27
+
28
+
29
# Validate that a Keras package compatible with the installed TensorFlow is importable,
# store its parsed version in `v_keras` and (where relevant) set `IS_KERAS_3_PLUS`.
# The package that is checked depends on the TensorFlow version:
#   - TF 1.x:       stand-alone 'keras' (only importability is checked; the supported
#                   range is stated in the error message)
#   - TF 2.16+:     'keras' 3+, or 'tf_keras' (Keras 2) if TF_USE_LEGACY_KERAS=1
#   - TF 2.6-2.15:  stand-alone 'keras' (Keras 3+ only triggers a warning)
#   - older TF 2.x: the bundled 'tensorflow.keras'
if IS_TF_1:
    try:
        from keras import __version__ as _v_keras
        v_keras = version.parse(_v_keras)
    except ModuleNotFoundError:
        raise RuntimeError(\
"""
For 'tensorflow' 1.x (found version {v_tf}), the stand-alone 'keras' package (2.1.2 <= version < 2.4) is required.
When using the most recent version of 'tensorflow' 1.x, install 'keras' like this: pip install "keras>=2.1.2,<2.4"

If you must use an older version of 'tensorflow' 1.x, please see this file for compatible versions of 'keras':
https://github.com/CSBDeep/CSBDeep/blob/main/.github/workflows/tests_legacy.yml
""".format(v_tf=v_tf))

elif IS_TF_2_16_PLUS:
    # https://keras.io/getting_started/#tensorflow--keras-2-backwards-compatibility
    # https://keras.io/keras_3/ -> "Moving from Keras 2 to Keras 3"
    if os.environ.get('TF_USE_LEGACY_KERAS','0') == '1':
        try:
            from tf_keras import __version__ as _v_keras
            v_keras = version.parse(_v_keras)
            assert v_keras.major < 3
            # the Keras 2 package must be at least as new as tensorflow (major.minor)
            if (v_tf.major,v_tf.minor) > (v_keras.major,v_keras.minor):
                raise err_old_keras(v_keras, package='tf_keras')
            warnings.warn("Using Keras 2 (via 'tf_keras' package) with Tensorflow 2.16+ might work, but is not regularly tested.")
        except ModuleNotFoundError:
            raise RuntimeError("Found environment variable 'TF_USE_LEGACY_KERAS=1' but 'tf_keras' package not installed.")
    else:
        try:
            from keras import __version__ as _v_keras
            v_keras = version.parse(_v_keras)
            IS_KERAS_3_PLUS = v_keras.major >= 3
            if not IS_KERAS_3_PLUS:
                raise err_old_keras(v_keras, min_version='3')
        except ModuleNotFoundError:
            raise RuntimeError("Package 'keras' is not installed.")

elif IS_TF_2_6_PLUS:
    # https://github.com/CSBDeep/CSBDeep/releases/tag/0.6.3
    try:
        from keras import __version__ as _v_keras
        v_keras = version.parse(_v_keras)
        IS_KERAS_3_PLUS = v_keras.major >= 3
        if IS_KERAS_3_PLUS:
            warnings.warn("Using Keras version 3+ with Tensorflow below version 2.16 might work, but has not been tested.")
        else:
            # the Keras 2 package must be at least as new as tensorflow (major.minor)
            if (v_tf.major,v_tf.minor) > (v_keras.major,v_keras.minor):
                raise err_old_keras(v_keras)
    except ModuleNotFoundError:
        # NOTE: the version specifier needs a comma between clauses (PEP 440),
        # otherwise the suggested pip command is invalid.
        raise RuntimeError(\
"""
Starting with 'tensorflow' version 2.6.0, a recent version of the stand-alone 'keras' package is required.
Please install/update 'keras': pip install "keras>={v_tf.major}.{v_tf.minor},<3"
""".format(v_tf=v_tf))

else:
    # TF 2.0 - 2.5: Keras 2 ships with tensorflow as 'tensorflow.keras'
    from tensorflow.keras import __version__ as _v_keras
    v_keras = version.parse(_v_keras)
    assert v_keras.major < 3
88
+
89
+
90
# Name of the Keras package to import from: the stand-alone 'keras' package for
# TF 1.x or Keras 3+, otherwise the Keras bundled with tensorflow.
if IS_TF_1 or IS_KERAS_3_PLUS:
    _KERAS = 'keras'
else:
    _KERAS = 'tensorflow.keras'
def keras_import(sub=None, *names):
    """Import the selected Keras package, one of its submodules, or attributes thereof.

    Parameters
    ----------
    sub : str, optional
        Submodule of the Keras package (e.g. 'layers'). If ``None``, the top-level
        package is returned and `names` is ignored.
    *names : str
        Attribute names to fetch from the submodule.

    Returns
    -------
    module, object or tuple
        The submodule if no `names` are given, the single attribute if exactly one
        name is given, otherwise a tuple of attributes in the order requested.
    """
    if sub is None:
        return import_module(_KERAS)
    submodule = import_module('{_KERAS}.{sub}'.format(_KERAS=_KERAS,sub=sub))
    if not names:
        return submodule
    attributes = tuple(getattr(submodule, attr_name) for attr_name in names)
    return attributes[0] if len(attributes) == 1 else attributes
101
+
102
+ import tensorflow as tf
103
+ # if IS_TF_1:
104
+ # import tensorflow as tf
105
+ # else:
106
+ # import tensorflow.compat.v1 as tf
107
+ # # tf.disable_v2_behavior()
108
+
109
# Module-level Keras handles, resolved through keras_import() and therefore taken
# from whichever package `_KERAS` names ('keras' or 'tensorflow.keras').
keras = keras_import()
K = keras_import('backend')
Callback = keras_import('callbacks', 'Callback')
Lambda = keras_import('layers', 'Lambda')
113
+
114
+ from .utils import _raise, is_tf_backend, save_json, backend_channels_last, normalize
115
+ from .six import tempfile
116
+
117
+
118
if IS_KERAS_3_PLUS:
    OPS = keras_import('ops')
    _remap = dict(tf=tf, int_shape=OPS.shape)

    class Backend(object):
        """Attribute proxy: `_remap` entries first, then keras.backend, then keras.ops."""
        def __getattr__(self, name):
            if name in _remap:
                return _remap[name]
            # keras.backend takes precedence over keras.ops when both define `name`
            provider = K if hasattr(K, name) else OPS
            return getattr(provider, name)
else:
    _remap = dict(tf=tf)

    class Backend(object):
        """Attribute proxy: `_remap` entries first, then keras.backend."""
        def __getattr__(self, name):
            return _remap[name] if name in _remap else getattr(K, name)

# Shared proxy instance (e.g. ``BACKEND.tf`` is the tensorflow module).
BACKEND = Backend()
138
+
139
+
140
+
141
def limit_gpu_memory(fraction, allow_growth=False, total_memory=None):
    """Limit GPU memory allocation for TensorFlow (TF) backend.

    Parameters
    ----------
    fraction : float
        Limit TF to use only a fraction (value between 0 and 1) of the available GPU memory.
        Reduced memory allocation can be disabled if fraction is set to ``None``.
    allow_growth : bool, optional
        If ``False`` (default), TF will allocate all designated (see `fraction`) memory all at once.
        If ``True``, TF will allocate memory as needed up to the limit imposed by `fraction`; this may
        incur a performance penalty due to memory fragmentation.
    total_memory : int or iterable of int
        Total amount of available GPU memory (in MB). Only used (and then required) with
        TensorFlow 2 when `fraction` is not ``None``. A single number applies to every
        visible GPU; an iterable must provide exactly one value per visible GPU.

    Raises
    ------
    ValueError
        If `fraction` is not ``None`` or a float value between 0 and 1, or (TensorFlow 2 only)
        if `total_memory` is missing or does not provide one value per visible GPU.
    NotImplementedError
        If TensorFlow is not used as the backend.

    Notes
    -----
    With TensorFlow 1, nothing is changed (only a warning is issued) if a Keras session
    already exists. With TensorFlow 2, the ``RuntimeError`` that TF raises when the GPUs
    are already initialized is reported as a warning.
    """

    is_tf_backend() or _raise(NotImplementedError('Not using tensorflow backend.'))
    fraction is None or (np.isscalar(fraction) and 0<=fraction<=1) or _raise(ValueError('fraction must be between 0 and 1.'))

    if IS_TF_1:
        _session = None
        try:
            _session = K.tensorflow_backend._SESSION
        except AttributeError:
            pass

        if _session is None:
            config = tf.ConfigProto()
            if fraction is not None:
                config.gpu_options.per_process_gpu_memory_fraction = fraction
            config.gpu_options.allow_growth = bool(allow_growth)
            session = tf.Session(config=config)
            K.tensorflow_backend.set_session(session)
        else:
            warnings.warn('Too late to limit GPU memory, can only be done once and before any computation.')
    else:
        gpus = tf.config.experimental.list_physical_devices('GPU')
        if gpus:
            # one virtual device configuration per GPU (None = leave memory limit unchanged)
            vdcs = [None] * len(gpus)
            if fraction is not None:
                if np.isscalar(total_memory):
                    per_gpu_memory = [total_memory] * len(gpus)
                else:
                    try:
                        per_gpu_memory = list(total_memory)
                    except TypeError:
                        # None or any other non-iterable, non-scalar value
                        per_gpu_memory = None
                    per_gpu_memory is not None or _raise(ValueError("'total_memory' must be provided when using TensorFlow 2."))
                    len(per_gpu_memory) == len(gpus) or _raise(ValueError(
                        "'total_memory' must be a single value or provide one value per GPU (found %d GPUs)." % len(gpus)))
                vdcs = [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=int(np.ceil(memory*fraction)))
                        for memory in per_gpu_memory]
            try:
                for gpu, vdc in zip(gpus, vdcs):
                    if vdc is not None:
                        tf.config.experimental.set_virtual_device_configuration(gpu,[vdc])
                    if allow_growth:
                        tf.config.experimental.set_memory_growth(gpu, True)
            except RuntimeError as e:
                # must be set before GPUs have been initialized; keep going, but make it visible
                warnings.warn(str(e))
199
+
200
+
201
+
202
def export_SavedModel(model, outpath, meta=None, format='zip'):
    """Export Keras model in TensorFlow's SavedModel_ format.

    See `Your Model in Fiji`_ to learn how to use the exported model with our CSBDeep Fiji plugins.

    .. _SavedModel: https://www.tensorflow.org/programmers_guide/saved_model#structure_of_a_savedmodel_directory
    .. _`Your Model in Fiji`: https://github.com/CSBDeep/CSBDeep_website/wiki/Your-Model-in-Fiji

    Parameters
    ----------
    model : :class:`keras.models.Model`
        Keras model to be exported.
    outpath : str
        Path of the file/folder that the model will be exported to.
        For ``format='zip'``, a trailing ``.zip`` extension is optional.
    meta : dict, optional
        Metadata to be saved in an additional ``meta.json`` file
        (no file is written if ``None`` or empty).
    format : str, optional
        Can be 'dir' to export as a directory or 'zip' (default) to export as a ZIP file.

    Raises
    ------
    ValueError
        Illegal arguments.
    NotImplementedError
        If Keras 3+ is used.

    """
    if IS_KERAS_3_PLUS:
        raise NotImplementedError(\
"""Exporting to SavedModel is no longer supported with Keras 3+.

It is likely that the exported model *will only work* in associated ImageJ/Fiji
plugins (e.g. CSBDeep and StarDist) when using 'tensorflow' 1.x to export the model.

The current workaround is to load the trained model in a Python environment with
installed 'tensorflow' 1.x and then export it again. If you need help with this, please read:

https://gist.github.com/uschmidt83/4b747862fe307044c722d6d1009f6183
""")

    def export_to_dir(dirname):
        # Write the SavedModel (and optional meta.json) into directory `dirname`.
        if len(model.inputs) > 1 or len(model.outputs) > 1:
            warnings.warn('Found multiple input or output layers.')

        def _export(model):
            if IS_TF_1:
                from tensorflow import saved_model
                from keras.backend import get_session
            else:
                from tensorflow.compat.v1 import saved_model
                from tensorflow.compat.v1.keras.backend import get_session

                warnings.warn(\
"""
***IMPORTANT NOTE***

You are using 'tensorflow' 2.x, hence it is likely that the exported model *will not work*
in associated ImageJ/Fiji plugins (e.g. CSBDeep and StarDist).

If you indeed have problems loading the exported model in Fiji, the current workaround is
to load the trained model in a Python environment with installed 'tensorflow' 1.x and then
export it again. If you need help with this, please read:

https://gist.github.com/uschmidt83/4b747862fe307044c722d6d1009f6183
""")

            builder = saved_model.builder.SavedModelBuilder(dirname)
            # use name 'input'/'output' if there's just a single input/output layer
            inputs  = dict(zip(model.input_names,model.inputs))   if len(model.inputs)  > 1 else dict(input=model.input)
            outputs = dict(zip(model.output_names,model.outputs)) if len(model.outputs) > 1 else dict(output=model.output)
            signature = saved_model.signature_def_utils.predict_signature_def(inputs=inputs, outputs=outputs)
            signature_def_map = { saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature }
            builder.add_meta_graph_and_variables(get_session(), [saved_model.tag_constants.SERVING], signature_def_map=signature_def_map)
            builder.save()

        if IS_TF_1:
            _export(model)
        else:
            from tensorflow.keras.models import clone_model
            weights = model.get_weights()
            with tf.Graph().as_default():
                # clone model in new graph and set weights
                _model = clone_model(model)
                _model.set_weights(weights)
                _export(_model)

        if meta is not None and len(meta) > 0:
            save_json(meta, os.path.join(dirname,'meta.json'))


    ## checks
    isinstance(model,keras.models.Model) or _raise(ValueError("'model' must be a Keras model."))
    # supported_formats = tuple(['dir']+[name for name,description in shutil.get_archive_formats()])
    supported_formats = 'dir','zip'
    format in supported_formats or _raise(ValueError("Unsupported format '%s', must be one of %s." % (format,str(supported_formats))))

    # remove '.zip' file name extension if necessary (make_archive appends it)
    if format == 'zip' and outpath.endswith('.zip'):
        outpath = os.path.splitext(outpath)[0]

    if format == 'dir':
        export_to_dir(outpath)
    else:
        with tempfile.TemporaryDirectory() as tmpdir:
            tmpsubdir = os.path.join(tmpdir,'model')
            export_to_dir(tmpsubdir)
            shutil.make_archive(outpath, format, tmpsubdir)
308
+
309
+
310
+
311
def tf_normalize(x, pmin=1, pmax=99.8, axis=None, clip=False):
    """Percentile-normalize tensor `x`.

    The `pmin`/`pmax` percentiles of `x` (over `axis`, keeping reduced dims) are mapped
    to 0/1; ``K.epsilon()`` guards against division by zero. Uses
    ``tf.contrib.distributions.percentile``, which is not available in TensorFlow 2.

    Parameters
    ----------
    x : tensor
        Input tensor.
    pmin, pmax : float
        Low/high percentiles; `pmin` must be smaller than `pmax`.
    axis : int or tuple of int, optional
        Axes over which percentiles are computed (all axes if ``None``).
    clip : bool, optional
        If ``True``, clip the result to [0, 1].
    """
    assert pmin < pmax
    percentile = tf.contrib.distributions.percentile
    low, high = (percentile(x, p, axis=axis, keep_dims=True) for p in (pmin, pmax))
    scaled = (x - low) / (high - low + K.epsilon())
    return K.clip(scaled, 0, 1.0) if clip else scaled
319
+
320
+
321
def tf_normalize_layer(layer, pmin=1, pmax=99.8, clip=True):
    """Wrap `layer` in Keras ``Lambda`` layers that percentile-normalize it (see `tf_normalize`).

    Tensors with more than 4 dimensions are first max-projected over the extra axes
    following the batch axis. The channel count (last axis) then selects the output:
    1 channel is normalized over axes (1,2); 2 channels get a copy of the first channel
    appended (giving 3 channels); 3 channels are normalized jointly over (1,2,3);
    more channels are max-reduced to a single channel first.
    """
    def norm(x, axis):
        return tf_normalize(x, pmin=pmin, pmax=pmax, axis=axis, clip=clip)

    shape = K.int_shape(layer)
    n_dims, n_channels = len(shape), shape[-1]

    if n_dims > 4:
        # max-project the leading extra axes (after batch) down to a 4D tensor
        projected_axes = tuple(range(1, n_dims - 3))
        layer = Lambda(lambda x: K.max(x, axis=projected_axes))(layer)

    assert 0 < n_channels

    def one_channel(x):
        return norm(x, axis=(1,2))

    def two_channels(x):
        # append a copy of the first channel so the result has 3 channels
        return norm(K.concatenate([x,x[...,:1]], axis=-1), axis=(1,2,3))

    def three_channels(x):
        return norm(x, axis=(1,2,3))

    def many_channels(x):
        # collapse all channels via their maximum
        return norm(K.max(x, axis=-1, keepdims=True), axis=(1,2,3))

    by_channels = {1: one_channel, 2: two_channels, 3: three_channels}
    return Lambda(by_channels.get(n_channels, many_channels))(layer)
343
+
344
+
345
+ class CARETensorBoard(Callback):
346
+ """ TODO """
347
+ def __init__(self, log_dir='./logs',
348
+ freq=1,
349
+ compute_histograms=False,
350
+ n_images=3,
351
+ prob_out=False,
352
+ write_graph=False,
353
+ prefix_with_timestamp=True,
354
+ write_images=False,
355
+ image_for_inputs=None, # write images for only these input indices
356
+ image_for_outputs=None, # write images for only these output indices
357
+ input_slices=None, # list (of list) of slices to apply to `image_for_inputs` layers before writing image
358
+ output_slices=None, # list (of list) of slices to apply to `image_for_outputs` layers before writing image
359
+ output_target_shapes=None): # list of shapes of the target/gt images that correspond to all model outputs
360
+ super(CARETensorBoard, self).__init__()
361
+ is_tf_backend() or _raise(RuntimeError('TensorBoard callback only works with the TensorFlow backend.'))
362
+ backend_channels_last() or _raise(NotImplementedError())
363
+ IS_TF_1 or _raise(NotImplementedError("Not supported with TensorFlow 2"))
364
+
365
+ self.freq = freq
366
+ self.image_freq = freq
367
+ self.prob_out = prob_out
368
+ self.merged = None
369
+ self.gt_outputs = None
370
+ self.write_graph = write_graph
371
+ self.write_images = write_images
372
+ self.n_images = n_images
373
+ self.image_for_inputs = image_for_inputs
374
+ self.image_for_outputs = image_for_outputs
375
+ self.input_slices = input_slices
376
+ self.output_slices = output_slices
377
+ self.output_target_shapes = output_target_shapes
378
+ self.compute_histograms = compute_histograms
379
+
380
+ if prefix_with_timestamp:
381
+ log_dir = os.path.join(log_dir, datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S.%f"))
382
+
383
+ self.log_dir = log_dir
384
+
385
+ def set_model(self, model):
386
+ self.model = model
387
+ self.sess = K.get_session()
388
+ tf_sums = []
389
+
390
+ if self.compute_histograms and self.freq and self.merged is None:
391
+ for layer in self.model.layers:
392
+ for weight in layer.weights:
393
+ tf_sums.append(tf.summary.histogram(weight.name, weight))
394
+
395
+ if hasattr(layer, 'output'):
396
+ tf_sums.append(tf.summary.histogram('{}_out'.format(layer.name),
397
+ layer.output))
398
+
399
+ def _gt_shape(output_shape,target_shape):
400
+ if target_shape is not None:
401
+ output_shape = target_shape
402
+ if not self.prob_out: return output_shape
403
+ output_shape[-1] % 2 == 0 or _raise(ValueError())
404
+ return list(output_shape[:-1]) + [output_shape[-1] // 2]
405
+
406
+ n_inputs, n_outputs = len(self.model.inputs), len(self.model.outputs)
407
+ image_for_inputs = np.arange(n_inputs) if self.image_for_inputs is None else self.image_for_inputs
408
+ image_for_outputs = np.arange(n_outputs) if self.image_for_outputs is None else self.image_for_outputs
409
+
410
+ output_target_shapes = [None]*n_outputs if self.output_target_shapes is None else self.output_target_shapes
411
+ self.gt_outputs = [K.placeholder(shape=_gt_shape(K.int_shape(x),sh)) for x,sh in zip(self.model.outputs,output_target_shapes)]
412
+
413
+ input_slices = (slice(None),) if self.input_slices is None else self.input_slices
414
+ output_slices = (slice(None),) if self.output_slices is None else self.output_slices
415
+ if isinstance(input_slices[0],slice): # apply same slices to all inputs
416
+ input_slices = [input_slices]*len(image_for_inputs)
417
+ if isinstance(output_slices[0],slice): # apply same slices to all outputs
418
+ output_slices = [output_slices]*len(image_for_outputs)
419
+ len(input_slices) == len(image_for_inputs) or _raise(ValueError())
420
+ len(output_slices) == len(image_for_outputs) or _raise(ValueError())
421
+
422
+ def _name(prefix, layer, i, n, show_layer_names=False):
423
+ return '{prefix}{i}{name}'.format (
424
+ prefix = prefix,
425
+ i = (i if n > 1 else ''),
426
+ name = '' if (layer is None or not show_layer_names) else '_'+''.join(layer.name.split(':')[:-1]),
427
+ )
428
+
429
+ # inputs
430
+ for i,sl in zip(image_for_inputs,input_slices):
431
+ # print('input', self.model.inputs[i], tuple(sl))
432
+ layer_name = _name('net_input', self.model.inputs[i], i, n_inputs)
433
+ input_layer = tf_normalize_layer(self.model.inputs[i][tuple(sl)])
434
+ tf_sums.append(tf.summary.image(layer_name, input_layer, max_outputs=self.n_images))
435
+
436
+ # outputs
437
+ for i,sl in zip(image_for_outputs,output_slices):
438
+ # print('output', self.model.outputs[i], tuple(sl))
439
+ output_shape = self.model.output_shape if n_outputs==1 else self.model.output_shape[i]
440
+ # target
441
+ output_layer = tf_normalize_layer(self.gt_outputs[i][tuple(sl)])
442
+ layer_name = _name('net_target', self.model.outputs[i], i, n_outputs)
443
+ tf_sums.append(tf.summary.image(layer_name, output_layer, max_outputs=self.n_images))
444
+ # prediction
445
+ n_channels_out = sep = output_shape[-1]
446
+ if self.prob_out: # first half of output channels is mean, second half scale
447
+ n_channels_out % 2 == 0 or _raise(ValueError())
448
+ sep = sep // 2
449
+ output_layer = tf_normalize_layer(self.model.outputs[i][...,:sep][tuple(sl)])
450
+ if self.prob_out:
451
+ scale_layer = tf_normalize_layer(self.model.outputs[i][...,sep:][tuple(sl)], pmin=0, pmax=100)
452
+ mean_name = _name('net_output_mean', self.model.outputs[i], i, n_outputs)
453
+ scale_name = _name('net_output_scale', self.model.outputs[i], i, n_outputs)
454
+ tf_sums.append(tf.summary.image(mean_name, output_layer, max_outputs=self.n_images))
455
+ tf_sums.append(tf.summary.image(scale_name, scale_layer, max_outputs=self.n_images))
456
+ else:
457
+ layer_name = _name('net_output', self.model.outputs[i], i, n_outputs)
458
+ tf_sums.append(tf.summary.image(layer_name, output_layer, max_outputs=self.n_images))
459
+
460
+
461
+ with tf.name_scope('merged'):
462
+ self.merged = tf.summary.merge(tf_sums)
463
+
464
+ with tf.name_scope('summary_writer'):
465
+ if self.write_graph:
466
+ self.writer = tf.summary.FileWriter(self.log_dir,
467
+ self.sess.graph)
468
+ else:
469
+ self.writer = tf.summary.FileWriter(self.log_dir)
470
+
471
def on_epoch_end(self, epoch, logs=None):
    """Write validation image summaries and scalar logs for ``epoch``.

    Images are written only when ``self.validation_data`` and ``self.freq`` are
    both truthy and ``epoch`` is a multiple of ``self.freq``. In that case the
    first ``self.n_images`` samples of each validation array are fed through
    ``self.sess`` to evaluate ``self.merged``. The result is added to
    ``self.writer`` at step ``epoch``.

    Every entry of ``logs`` except ``'batch'`` and ``'size'`` is then written as
    a scalar summary, and the writer is flushed.

    Parameters
    ----------
    epoch : int
        Epoch index, used as the summary step.
    logs : dict or None
        Metric name -> value; each value must be convertible with ``float``.
    """
    logs = logs or {}

    if self.validation_data and self.freq and epoch % self.freq == 0:
        # TODO: implement batched calls to sess.run
        # (current call will likely go OOM on GPU)
        # NOTE(review): assumes validation_data is ordered like the feed list
        # below (inputs, targets, sample weights[, learning phase]) - confirm.
        feed_tensors = self.model.inputs + self.gt_outputs + self.model.sample_weights
        n = self.n_images
        if self.model.uses_learning_phase:
            feed_tensors += [K.learning_phase()]
            # The trailing learning-phase entry is passed through untruncated.
            feed_values = [arr[:n] for arr in self.validation_data[:-1]]
            feed_values += self.validation_data[-1:]
        else:
            feed_values = [arr[:n] for arr in self.validation_data]

        fetched = self.sess.run([self.merged], feed_dict=dict(zip(feed_tensors, feed_values)))
        self.writer.add_summary(fetched[0], epoch)

    for tag, value in logs.items():
        if tag in ('batch', 'size'):
            continue
        scalar = tf.Summary()
        entry = scalar.value.add()
        entry.simple_value = float(value)
        entry.tag = tag
        self.writer.add_summary(scalar, epoch)
    self.writer.flush()
503
+
504
def on_train_end(self, _=None):
    """Close ``self.writer`` once training has finished.

    The logs argument is unused. It defaults to ``None``, mirroring the
    ``logs=None`` default of the Keras ``Callback.on_train_end`` hook, so the
    method may also be called without arguments.
    """
    self.writer.close()
506
+
507
+
508
+
509
+ class CARETensorBoardImage(Callback):
510
+
511
def __init__(self, model, data, log_dir,
             n_images=3,
             batch_size=1,
             prob_out=False,
             image_for_inputs=None,   # write images for only these input indices
             image_for_outputs=None,  # write images for only these output indices
             input_slices=None,       # list (of list) of slices to apply to `image_for_inputs` layers before writing image
             output_slices=None):     # list (of list) of slices to apply to `image_for_outputs` layers before writing image
    """Set up a TF2 callback that logs input, prediction and target images.

    Parameters
    ----------
    model : BaseModel or keras.Model
        A ``BaseModel`` is unwrapped to its ``keras_model``. The result is
        stored as ``self.model_``, because ``self.model`` is reserved by the
        Keras callback base class in Keras 3+.
    data : (X, Y) list/tuple, or an indexable whose element 0 is (X, Y)
        ``X`` and ``Y`` hold one array per model input and output. A bare
        ndarray is accepted for a single-input or single-output model. Only
        the first ``n_images`` entries of each array are kept.
    log_dir : path-like
        Directory for ``tf.summary.create_file_writer`` (converted with ``str``).
    n_images : int
        Number of leading samples kept from every array.
    batch_size : int
        Stored as ``self.batch_size``.
    prob_out : bool
        Stored as ``self.prob_out``.
    image_for_inputs, image_for_outputs : sequence of int or None
        Indices of the inputs/outputs to log; ``None`` selects all of them.
    input_slices, output_slices : tuple of slice, list of such tuples, or None
        Indexing applied before an image is written. A single tuple of slices
        (``None`` means ``(slice(None),)``) is reused for every selected index.
        Otherwise exactly one entry per selected index is required.

    Raises
    ------
    NotImplementedError
        (through ``_raise``) if ``backend_channels_last()`` is falsy or
        ``IS_TF_1`` is truthy.
    ValueError
        (through ``_raise``) if ``model`` is not a ``keras.Model``, or the
        number of arrays/slices does not match the model or selected indices.
    """
    super(CARETensorBoardImage, self).__init__()
    if not backend_channels_last():
        _raise(NotImplementedError())
    if IS_TF_1:
        _raise(NotImplementedError("Not supported with TensorFlow 1"))

    # --- model ---
    from ..models import BaseModel
    if isinstance(model, BaseModel):
        model = model.keras_model
    if not isinstance(model, keras.Model):
        _raise(ValueError())
    # self.model is already used by keras.callbacks.Callback in keras 3+
    self.model_ = model
    self.n_inputs = len(self.model_.inputs)
    self.n_outputs = len(self.model_.outputs)

    # --- data ---
    X, Y = data if isinstance(data, (list, tuple)) else data[0]
    if self.n_inputs == 1 and isinstance(X, np.ndarray):
        X = (X,)
    if self.n_outputs == 1 and isinstance(Y, np.ndarray):
        Y = (Y,)
    if not (len(X) == self.n_inputs and len(Y) == self.n_outputs):
        _raise(ValueError())
    self.X = [arr[:n_images] for arr in X]
    self.Y = [arr[:n_images] for arr in Y]

    self.log_dir = log_dir
    self.batch_size = batch_size
    self.prob_out = prob_out

    # --- which inputs/outputs to write, and how to slice them ---
    if image_for_inputs is None:
        image_for_inputs = np.arange(self.n_inputs)
    if image_for_outputs is None:
        image_for_outputs = np.arange(self.n_outputs)

    if input_slices is None:
        input_slices = (slice(None),)
    if output_slices is None:
        output_slices = (slice(None),)
    # A bare tuple of slices is broadcast to every selected index.
    if isinstance(input_slices[0], slice):
        input_slices = [input_slices] * len(image_for_inputs)
    if isinstance(output_slices[0], slice):
        output_slices = [output_slices] * len(image_for_outputs)

    if len(input_slices) != len(image_for_inputs):
        _raise(ValueError())
    if len(output_slices) != len(image_for_outputs):
        _raise(ValueError())

    self.image_for_inputs = image_for_inputs
    self.image_for_outputs = image_for_outputs
    self.input_slices = input_slices
    self.output_slices = output_slices

    self.file_writer = tf.summary.create_file_writer(str(self.log_dir))
571
+
572
+
573
def _name(self, prefix, layer, i, n, show_layer_names=False):
    """Return the summary tag ``<prefix>[<i>][_<layer name>]``.

    The index ``i`` is appended only when there are several layers (``n > 1``).
    When ``show_layer_names`` is true and ``layer`` is not ``None``, ``'_'``
    followed by the layer's name is appended.

    For a name containing ``':'``, such as ``'conv/BiasAdd:0'``, the last
    ``':'``-separated part is dropped and the remaining parts are concatenated,
    giving ``'conv/BiasAdd'``. A name without ``':'`` is used in full. That case
    previously collapsed to an empty string, leaving a dangling ``'_'``.
    """
    index = i if n > 1 else ''
    suffix = ''
    if layer is not None and show_layer_names:
        parts = layer.name.split(':')
        base = ''.join(parts[:-1]) if len(parts) > 1 else parts[0]
        suffix = '_' + base
    return '{prefix}{i}{name}'.format(prefix=prefix, i=index, name=suffix)
579
+
580
+
581
def _normalize_image(self, x, pmin=1, pmax=99.8, clip=True):
    """Convert a batch array into an image batch for ``tf.summary.image``.

    If ``x`` has more than 4 dimensions, every axis strictly between axis 0 and
    the last three axes is max-projected away. The channel axis (last) is then
    mapped as follows:

    - 1 channel is kept.
    - 2 channels get the first channel appended as a third.
    - 3 channels are kept.
    - any other count is max-projected to a single channel.

    The result is rescaled by the ``normalize`` helper (defined elsewhere;
    ``pmin``/``pmax`` are presumably percentiles - confirm) with ``clip``. For a
    single-channel input this is done over axes (1, 2); otherwise over axes
    (1, 2, 3), i.e. separately for each entry along axis 0.
    """
    n_channels = x.shape[-1]
    ndim = len(x.shape)
    assert 0 < n_channels

    if ndim > 4:
        # collapse the extra leading spatial axes (axes 1 .. ndim-4)
        x = np.max(x, axis=tuple(range(1, ndim - 3)))

    if n_channels == 2:
        # pad to three channels by repeating the first one
        x = np.concatenate([x, x[..., :1]], axis=-1)
    elif n_channels not in (1, 3):
        # too many channels to display: max-project to a single channel
        x = np.max(x, axis=-1, keepdims=True)

    axes = (1, 2) if n_channels == 1 else (1, 2, 3)
    return normalize(x, pmin=pmin, pmax=pmax, axis=axes, clip=clip)
601
+
602
+
603
def on_epoch_end(self, epoch, logs=None):
    """Write input, prediction and target images for ``epoch`` to TensorBoard.

    Inputs (``self.X``) and targets (``self.Y``) are written only when
    ``epoch == 0``. Predictions of ``self.model_`` on ``self.X`` are recomputed
    and written every epoch. With ``self.prob_out``, the first half of each
    output's channels is logged as ``net_output_mean`` and the second half as
    ``net_output_scale``.

    The file writer is flushed at the end. This makes the images available
    immediately instead of waiting for the writer's periodic flush.

    Parameters
    ----------
    epoch : int
        Epoch index, used as the summary step.
    logs : dict or None
        Unused.

    Raises
    ------
    ValueError
        (through ``_raise``) if ``self.prob_out`` is set and an output has an
        odd number of channels.
    """
    # https://www.tensorflow.org/tensorboard/image_summaries

    # inputs (only written at epoch 0)
    if epoch == 0:
        with self.file_writer.as_default():
            for i, sl in zip(self.image_for_inputs, self.input_slices):
                input_name = self._name('net_input', self.model_.inputs[i], i, self.n_inputs)
                input_image = self._normalize_image(self.X[i][tuple(sl)])
                tf.summary.image(input_name, input_image, step=epoch)

    # outputs (re-predicted every epoch)
    Yhat = self.model_.predict(self.X, batch_size=self.batch_size, verbose=0)
    if self.n_outputs == 1 and isinstance(Yhat, np.ndarray):
        Yhat = (Yhat,)
    with self.file_writer.as_default():
        for i, sl in zip(self.image_for_outputs, self.output_slices):
            output_shape = self.model_.output_shape if self.n_outputs == 1 else self.model_.output_shape[i]
            n_channels_out = sep = output_shape[-1]
            if self.prob_out:  # first half of output channels is mean, second half scale
                n_channels_out % 2 == 0 or _raise(ValueError())
                sep = sep // 2
            output_image = self._normalize_image(Yhat[i][..., :sep][tuple(sl)])
            if self.prob_out:
                scale_image = self._normalize_image(Yhat[i][..., sep:][tuple(sl)], pmin=0, pmax=100)
                scale_name = self._name('net_output_scale', self.model_.outputs[i], i, self.n_outputs)
                output_name = self._name('net_output_mean', self.model_.outputs[i], i, self.n_outputs)
                tf.summary.image(output_name, output_image, step=epoch)
                tf.summary.image(scale_name, scale_image, step=epoch)
            else:
                output_name = self._name('net_output', self.model_.outputs[i], i, self.n_outputs)
                tf.summary.image(output_name, output_image, step=epoch)

    # targets (only written at epoch 0)
    if epoch == 0:
        with self.file_writer.as_default():
            for i, sl in zip(self.image_for_outputs, self.output_slices):
                target_name = self._name('net_target', self.model_.outputs[i], i, self.n_outputs)
                target_image = self._normalize_image(self.Y[i][tuple(sl)])
                tf.summary.image(target_name, target_image, step=epoch)

    # The summary writer buffers events; push this epoch's images to disk now.
    self.file_writer.flush()