senoquant 1.0.0b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148):
  1. senoquant/__init__.py +6 -0
  2. senoquant/_reader.py +7 -0
  3. senoquant/_widget.py +33 -0
  4. senoquant/napari.yaml +83 -0
  5. senoquant/reader/__init__.py +5 -0
  6. senoquant/reader/core.py +369 -0
  7. senoquant/tabs/__init__.py +15 -0
  8. senoquant/tabs/batch/__init__.py +10 -0
  9. senoquant/tabs/batch/backend.py +641 -0
  10. senoquant/tabs/batch/config.py +270 -0
  11. senoquant/tabs/batch/frontend.py +1283 -0
  12. senoquant/tabs/batch/io.py +326 -0
  13. senoquant/tabs/batch/layers.py +86 -0
  14. senoquant/tabs/quantification/__init__.py +1 -0
  15. senoquant/tabs/quantification/backend.py +228 -0
  16. senoquant/tabs/quantification/features/__init__.py +80 -0
  17. senoquant/tabs/quantification/features/base.py +142 -0
  18. senoquant/tabs/quantification/features/marker/__init__.py +5 -0
  19. senoquant/tabs/quantification/features/marker/config.py +69 -0
  20. senoquant/tabs/quantification/features/marker/dialog.py +437 -0
  21. senoquant/tabs/quantification/features/marker/export.py +879 -0
  22. senoquant/tabs/quantification/features/marker/feature.py +119 -0
  23. senoquant/tabs/quantification/features/marker/morphology.py +285 -0
  24. senoquant/tabs/quantification/features/marker/rows.py +654 -0
  25. senoquant/tabs/quantification/features/marker/thresholding.py +46 -0
  26. senoquant/tabs/quantification/features/roi.py +346 -0
  27. senoquant/tabs/quantification/features/spots/__init__.py +5 -0
  28. senoquant/tabs/quantification/features/spots/config.py +62 -0
  29. senoquant/tabs/quantification/features/spots/dialog.py +477 -0
  30. senoquant/tabs/quantification/features/spots/export.py +1292 -0
  31. senoquant/tabs/quantification/features/spots/feature.py +112 -0
  32. senoquant/tabs/quantification/features/spots/morphology.py +279 -0
  33. senoquant/tabs/quantification/features/spots/rows.py +241 -0
  34. senoquant/tabs/quantification/frontend.py +815 -0
  35. senoquant/tabs/segmentation/__init__.py +1 -0
  36. senoquant/tabs/segmentation/backend.py +131 -0
  37. senoquant/tabs/segmentation/frontend.py +1009 -0
  38. senoquant/tabs/segmentation/models/__init__.py +5 -0
  39. senoquant/tabs/segmentation/models/base.py +146 -0
  40. senoquant/tabs/segmentation/models/cpsam/details.json +65 -0
  41. senoquant/tabs/segmentation/models/cpsam/model.py +150 -0
  42. senoquant/tabs/segmentation/models/default_2d/details.json +69 -0
  43. senoquant/tabs/segmentation/models/default_2d/model.py +664 -0
  44. senoquant/tabs/segmentation/models/default_3d/details.json +69 -0
  45. senoquant/tabs/segmentation/models/default_3d/model.py +682 -0
  46. senoquant/tabs/segmentation/models/hf.py +71 -0
  47. senoquant/tabs/segmentation/models/nuclear_dilation/__init__.py +1 -0
  48. senoquant/tabs/segmentation/models/nuclear_dilation/details.json +26 -0
  49. senoquant/tabs/segmentation/models/nuclear_dilation/model.py +96 -0
  50. senoquant/tabs/segmentation/models/perinuclear_rings/__init__.py +1 -0
  51. senoquant/tabs/segmentation/models/perinuclear_rings/details.json +34 -0
  52. senoquant/tabs/segmentation/models/perinuclear_rings/model.py +132 -0
  53. senoquant/tabs/segmentation/stardist_onnx_utils/__init__.py +2 -0
  54. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/__init__.py +3 -0
  55. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/__init__.py +6 -0
  56. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/generate.py +470 -0
  57. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/prepare.py +273 -0
  58. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/rawdata.py +112 -0
  59. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/data/transform.py +384 -0
  60. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/__init__.py +0 -0
  61. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/blocks.py +184 -0
  62. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/losses.py +79 -0
  63. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/nets.py +165 -0
  64. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/predict.py +467 -0
  65. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/probability.py +67 -0
  66. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/internals/train.py +148 -0
  67. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/io/__init__.py +163 -0
  68. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/__init__.py +52 -0
  69. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/base_model.py +329 -0
  70. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_isotropic.py +160 -0
  71. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_projection.py +178 -0
  72. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_standard.py +446 -0
  73. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/care_upsampling.py +54 -0
  74. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/config.py +254 -0
  75. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/models/pretrained.py +119 -0
  76. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/scripts/__init__.py +0 -0
  77. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/scripts/care_predict.py +180 -0
  78. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/__init__.py +5 -0
  79. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/plot_utils.py +159 -0
  80. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/six.py +18 -0
  81. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/tf.py +644 -0
  82. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/utils/utils.py +272 -0
  83. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/csbdeep/version.py +1 -0
  84. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/docs/source/conf.py +368 -0
  85. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/setup.py +68 -0
  86. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_datagen.py +169 -0
  87. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_models.py +462 -0
  88. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tests/test_utils.py +166 -0
  89. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tools/create_zip_contents.py +34 -0
  90. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/__init__.py +30 -0
  91. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/big.py +624 -0
  92. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/bioimageio_utils.py +494 -0
  93. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/data/__init__.py +39 -0
  94. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/__init__.py +10 -0
  95. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/geom2d.py +215 -0
  96. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/geometry/geom3d.py +349 -0
  97. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/matching.py +483 -0
  98. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/__init__.py +28 -0
  99. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/base.py +1217 -0
  100. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/model2d.py +594 -0
  101. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/models/model3d.py +696 -0
  102. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/nms.py +384 -0
  103. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/__init__.py +2 -0
  104. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/plot.py +74 -0
  105. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/plot/render.py +298 -0
  106. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/rays3d.py +373 -0
  107. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/sample_patches.py +65 -0
  108. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/__init__.py +0 -0
  109. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/predict2d.py +90 -0
  110. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/scripts/predict3d.py +93 -0
  111. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/utils.py +408 -0
  112. senoquant/tabs/segmentation/stardist_onnx_utils/_stardist/version.py +1 -0
  113. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/__init__.py +45 -0
  114. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/__init__.py +17 -0
  115. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/cli.py +55 -0
  116. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/convert/core.py +285 -0
  117. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/__init__.py +15 -0
  118. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/cli.py +36 -0
  119. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/divisibility.py +193 -0
  120. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/probe.py +100 -0
  121. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/receptive_field.py +182 -0
  122. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/rf_cli.py +48 -0
  123. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/valid_sizes.py +278 -0
  124. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/post/__init__.py +8 -0
  125. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/post/core.py +157 -0
  126. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/pre/__init__.py +17 -0
  127. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/pre/core.py +226 -0
  128. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/predict/__init__.py +5 -0
  129. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/predict/core.py +401 -0
  130. senoquant/tabs/settings/__init__.py +1 -0
  131. senoquant/tabs/settings/backend.py +29 -0
  132. senoquant/tabs/settings/frontend.py +19 -0
  133. senoquant/tabs/spots/__init__.py +1 -0
  134. senoquant/tabs/spots/backend.py +139 -0
  135. senoquant/tabs/spots/frontend.py +800 -0
  136. senoquant/tabs/spots/models/__init__.py +5 -0
  137. senoquant/tabs/spots/models/base.py +94 -0
  138. senoquant/tabs/spots/models/rmp/details.json +61 -0
  139. senoquant/tabs/spots/models/rmp/model.py +499 -0
  140. senoquant/tabs/spots/models/udwt/details.json +103 -0
  141. senoquant/tabs/spots/models/udwt/model.py +482 -0
  142. senoquant/utils.py +25 -0
  143. senoquant-1.0.0b1.dist-info/METADATA +193 -0
  144. senoquant-1.0.0b1.dist-info/RECORD +148 -0
  145. senoquant-1.0.0b1.dist-info/WHEEL +5 -0
  146. senoquant-1.0.0b1.dist-info/entry_points.txt +2 -0
  147. senoquant-1.0.0b1.dist-info/licenses/LICENSE +28 -0
  148. senoquant-1.0.0b1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,696 @@
1
+ from __future__ import print_function, unicode_literals, absolute_import, division
2
+
3
+ import numpy as np
4
+ import warnings
5
+ import math
6
+ from tqdm import tqdm
7
+
8
+
9
+ from csbdeep.models import BaseConfig
10
+ from csbdeep.internals.blocks import conv_block3, unet_block, resnet_block
11
+ from csbdeep.utils import _raise, backend_channels_last, axes_check_and_normalize, axes_dict
12
+ from csbdeep.utils.tf import keras_import, IS_TF_1, CARETensorBoard, CARETensorBoardImage, IS_KERAS_3_PLUS, BACKEND as K
13
+ from packaging.version import Version
14
+ from scipy.ndimage import zoom
15
+ from skimage.measure import regionprops
16
+ keras = keras_import()
17
+ Input, Conv3D, MaxPooling3D, UpSampling3D, Add, Concatenate = keras_import('layers', 'Input', 'Conv3D', 'MaxPooling3D', 'UpSampling3D', 'Add', 'Concatenate')
18
+ Model = keras_import('models', 'Model')
19
+
20
+ from .base import StarDistBase, StarDistDataBase, _tf_version_at_least
21
+ from ..sample_patches import sample_patches
22
+ from ..utils import edt_prob, _normalize_grid, mask_to_categorical
23
+ from ..matching import relabel_sequential
24
+ from ..geometry import star_dist3D, polyhedron_to_label
25
+ from ..rays3d import Rays_GoldenSpiral, rays_from_json
26
+ from ..nms import non_maximum_suppression_3d, non_maximum_suppression_3d_sparse
27
+
28
+ _gen_rtype = list if IS_TF_1 else tuple
29
+
30
class StarDistData3D(StarDistDataBase):
    """Training-batch generator for :class:`StarDist3D`.

    Samples patches from the images/labels, applies the augmenter, and
    computes the network targets on the grid-subsampled patch: object
    probability (EDT-based), star-convex distances along ``rays``, and —
    when ``n_classes`` is set — per-class probability maps.
    """

    def __init__(self, X, Y, batch_size, rays, length,
                 n_classes=None, classes=None,
                 patch_size=(128,128,128), grid=(1,1,1), anisotropy=None, augmenter=None, foreground_prob=0, **kwargs):
        # TODO: support shape completion as in 2D?

        super().__init__(X=X, Y=Y, n_rays=len(rays), grid=grid,
                         classes=classes, n_classes=n_classes,
                         batch_size=batch_size, patch_size=patch_size, length=length,
                         augmenter=augmenter, foreground_prob=foreground_prob, **kwargs)

        self.rays = rays
        self.anisotropy = anisotropy
        # star-dist computation backend: OpenCL only when GPU use is enabled
        self.sd_mode = 'opencl' if self.use_gpu else 'cpp'
        # re-use arrays (pre-allocated once, filled via np.stack(..., out=...)
        # in __getitem__; only needed when batches hold more than one sample)
        if self.batch_size > 1:
            self.out_X = np.empty((self.batch_size,)+tuple(self.patch_size)+(() if self.n_channel is None else (self.n_channel,)), dtype=np.float32)
            # target maps live on the grid-subsampled patch
            patch_size_grid = tuple((p-1)//g+1 for p,g in zip(self.patch_size,self.grid))
            self.out_mask_neg_labels = np.empty((self.batch_size,)+patch_size_grid, dtype=bool)
            self.out_edt_prob = np.empty((self.batch_size,)+patch_size_grid, dtype=np.float32)
            self.out_star_dist3D = np.empty((self.batch_size,)+patch_size_grid+(len(self.rays),), dtype=np.float32)
            if self.n_classes is not None:
                self.out_prob_class = np.empty((self.batch_size,)+tuple(self.patch_size)+(self.n_classes+1,), dtype=np.float32)


    def __getitem__(self, i):
        """Return one training batch ``(inputs, targets)`` for batch index ``i``."""
        idx = self.batch(i)
        arrays = [sample_patches((self.Y[k],) + self.channels_as_tuple(self.X[k]),
                                 patch_size=self.patch_size, n_samples=1,
                                 valid_inds=self.get_valid_inds(k)) for k in idx]

        if self.n_channel is None:
            X, Y = list(zip(*[(x[0],y[0]) for y,x in arrays]))
        else:
            # multi-channel input: stack the per-channel patches along the last axis
            X, Y = list(zip(*[(np.stack([_x[0] for _x in x],axis=-1), y[0]) for y,*x in arrays]))

        # augmentation is applied to each (image, label) pair individually
        X, Y = tuple(zip(*tuple(self.augmenter(_x, _y) for _x, _y in zip(X,Y))))

        # negative label values mark "ignore" regions (loss disabled there)
        tmp = [y[self.ss_grid[1:]] < 0 for y in Y]
        has_neg_labels = any(m.any() for m in tmp)
        if has_neg_labels:
            if len(Y) == 1:
                mask_neg_labels = tmp[0][np.newaxis]
            else:
                mask_neg_labels = np.stack(tmp, out=self.out_mask_neg_labels[:len(Y)])
            # set negative label pixels to 0 (background)
            Y = tuple(np.maximum(y, 0) for y in Y)

        if len(Y) == 1:
            X = X[0][np.newaxis]
        else:
            X = np.stack(X, out=self.out_X[:len(Y)])
        if X.ndim == 4: # input image has no channel axis
            X = np.expand_dims(X,-1)

        # object probability target: normalized Euclidean distance transform,
        # evaluated on the subsampled grid
        tmp = [edt_prob(lbl, anisotropy=self.anisotropy)[self.ss_grid[1:]] for lbl in Y]
        if len(Y) == 1:
            prob = tmp[0][np.newaxis]
        else:
            prob = np.stack(tmp, out=self.out_edt_prob[:len(Y)])

        tmp = [star_dist3D(lbl, self.rays, mode=self.sd_mode, grid=self.grid) for lbl in Y]
        if len(Y) == 1:
            dist = tmp[0][np.newaxis]
        else:
            dist = np.stack(tmp, out=self.out_star_dist3D[:len(Y)])

        # NOTE: prob and dist_mask deliberately alias the same array, so the
        # -1 "ignore" assignment below affects both
        prob = dist_mask = np.expand_dims(prob, -1)

        # append dist_mask to dist as additional channel
        dist = np.concatenate([dist,dist_mask],axis=-1)

        if has_neg_labels:
            prob[mask_neg_labels] = -1 # set to -1 to disable loss

        # note: must return tuples in keras 3 (cf. https://stackoverflow.com/a/78158487)
        if self.n_classes is None:
            return _gen_rtype((X,)), _gen_rtype((prob,dist))
        else:
            tmp = [mask_to_categorical(y, self.n_classes, self.classes[k]) for y,k in zip(Y, idx)]
            # TODO: downsample here before stacking?
            if len(Y) == 1:
                prob_class = tmp[0][np.newaxis]
            else:
                prob_class = np.stack(tmp, out=self.out_prob_class[:len(Y)])

            # TODO: investigate downsampling via simple indexing vs. using 'zoom'
            # prob_class = prob_class[self.ss_grid]
            # 'zoom' might lead to better registered maps (especially if upscaled later)
            prob_class = zoom(prob_class, (1,)+tuple(1/g for g in self.grid)+(1,), order=0)

            if has_neg_labels:
                prob_class[mask_neg_labels] = -1 # set to -1 to disable loss

            return _gen_rtype((X,)), _gen_rtype((prob,dist, prob_class))
126
+
127
+
128
+
129
class Config3D(BaseConfig):
    """Configuration for a :class:`StarDist3D` model.

    Parameters
    ----------
    axes : str or None
        Axes of the input images.
    rays : Rays_Base, int, or None
        Ray factory (e.g. Ray_GoldenSpiral).
        If an integer then Ray_GoldenSpiral(rays) will be used
    n_channel_in : int
        Number of channels of given input image (default: 1).
    grid : (int,int,int)
        Subsampling factors (must be powers of 2) for each of the axes.
        Model will predict on a subsampled grid for increased efficiency and larger field of view.
    n_classes : None or int
        Number of object classes to use for multi-class prediction (use None to disable)
    anisotropy : (float,float,float)
        Anisotropy of objects along each of the axes.
        Use ``None`` to disable only for (nearly) isotropic objects shapes.
        Also see ``utils.calculate_extents``.
    backbone : str
        Name of the neural network architecture to be used as backbone.
    kwargs : dict
        Overwrite (or add) configuration attributes (see below).


    Attributes
    ----------
    unet_n_depth : int
        Number of U-Net resolution levels (down/up-sampling layers).
    unet_kernel_size : (int,int,int)
        Convolution kernel size for all (U-Net) convolution layers.
    unet_n_filter_base : int
        Number of convolution kernels (feature channels) for first U-Net layer.
        Doubled after each down-sampling layer.
    unet_pool : (int,int,int)
        Maxpooling size for all (U-Net) convolution layers.
    net_conv_after_unet : int
        Number of filters of the extra convolution layer after U-Net (0 to disable).
    unet_* : *
        Additional parameters for U-net backbone.
    resnet_n_blocks : int
        Number of ResNet blocks.
    resnet_kernel_size : (int,int,int)
        Convolution kernel size for all ResNet blocks.
    resnet_n_filter_base : int
        Number of convolution kernels (feature channels) for ResNet blocks.
        (Number is doubled after every downsampling, see ``grid``.)
    net_conv_after_resnet : int
        Number of filters of the extra convolution layer after ResNet (0 to disable).
    resnet_* : *
        Additional parameters for ResNet backbone.
    train_patch_size : (int,int,int)
        Size of patches to be cropped from provided training images.
    train_background_reg : float
        Regularizer to encourage distance predictions on background regions to be 0.
    train_foreground_only : float
        Fraction (0..1) of patches that will only be sampled from regions that contain foreground pixels.
    train_sample_cache : bool
        Activate caching of valid patch regions for all training images (disable to save memory for large datasets)
    train_dist_loss : str
        Training loss for star-convex polygon distances ('mse' or 'mae').
    train_loss_weights : tuple of float
        Weights for losses relating to (probability, distance)
    train_epochs : int
        Number of training epochs.
    train_steps_per_epoch : int
        Number of parameter update steps per epoch.
    train_learning_rate : float
        Learning rate for training.
    train_batch_size : int
        Batch size for training.
    train_tensorboard : bool
        Enable TensorBoard for monitoring training progress.
    train_n_val_patches : int
        Number of patches to be extracted from validation images (``None`` = one patch per image).
    train_reduce_lr : dict
        Parameter :class:`dict` of ReduceLROnPlateau_ callback; set to ``None`` to disable.
    use_gpu : bool
        Indicate that the data generator should use OpenCL to do computations on the GPU.

    .. _ReduceLROnPlateau: https://keras.io/api/callbacks/reduce_lr_on_plateau/
    """

    def __init__(self, axes='ZYX', rays=None, n_channel_in=1, grid=(1,1,1), n_classes=None, anisotropy=None, backbone='unet', **kwargs):

        # resolve the rays argument: explicit factory, json spec, count, or default
        if rays is None:
            if 'rays_json' in kwargs:
                rays = rays_from_json(kwargs['rays_json'])
            elif 'n_rays' in kwargs:
                rays = Rays_GoldenSpiral(kwargs['n_rays'])
            else:
                rays = Rays_GoldenSpiral(96)
        elif np.isscalar(rays):
            rays = Rays_GoldenSpiral(rays)

        # one output channel for prob + one per ray distance
        super().__init__(axes=axes, n_channel_in=n_channel_in, n_channel_out=1+len(rays))

        # directly set by parameters
        self.n_rays = len(rays)
        self.grid = _normalize_grid(grid,3)
        self.anisotropy = anisotropy if anisotropy is None else tuple(anisotropy)
        self.backbone = str(backbone).lower()
        self.rays_json = rays.to_json()
        self.n_classes = None if n_classes is None else int(n_classes)

        # keep the rays' anisotropy consistent with the config's anisotropy
        if 'anisotropy' in self.rays_json['kwargs']:
            if self.rays_json['kwargs']['anisotropy'] is None and self.anisotropy is not None:
                self.rays_json['kwargs']['anisotropy'] = self.anisotropy
                print("Changing 'anisotropy' of rays to %s" % str(anisotropy))
            elif self.rays_json['kwargs']['anisotropy'] != self.anisotropy:
                warnings.warn("Mismatch of 'anisotropy' of rays and 'anisotropy'.")

        # default config (can be overwritten by kwargs below)
        if self.backbone == 'unet':
            self.unet_n_depth = 2
            self.unet_kernel_size = 3,3,3
            self.unet_n_filter_base = 32
            self.unet_n_conv_per_depth = 2
            self.unet_pool = 2,2,2
            self.unet_activation = 'relu'
            self.unet_last_activation = 'relu'
            self.unet_batch_norm = False
            self.unet_dropout = 0.0
            self.unet_expansion = 2
            self.unet_prefix = ''
            self.net_conv_after_unet = 128
        elif self.backbone == 'resnet':
            self.resnet_n_blocks = 4
            self.resnet_kernel_size = 3,3,3
            self.resnet_kernel_init = 'he_normal'
            self.resnet_n_filter_base = 32
            self.resnet_n_conv_per_block = 3
            self.resnet_activation = 'relu'
            self.resnet_batch_norm = False
            self.net_conv_after_resnet = 128
        else:
            raise ValueError("backbone '%s' not supported." % self.backbone)

        # net_mask_shape not needed but kept for legacy reasons
        if backend_channels_last():
            self.net_input_shape = None,None,None,self.n_channel_in
            self.net_mask_shape = None,None,None,1
        else:
            self.net_input_shape = self.n_channel_in,None,None,None
            self.net_mask_shape = 1,None,None,None

        # self.train_shape_completion = False
        # self.train_completion_crop = 32
        self.train_patch_size = 128,128,128
        self.train_background_reg = 1e-4
        self.train_foreground_only = 0.9
        self.train_sample_cache = True

        self.train_dist_loss = 'mae'
        # multiclass training adds a third loss (class probabilities)
        self.train_loss_weights = (1,0.2) if self.n_classes is None else (1,0.2,1)
        self.train_class_weights = (1,1) if self.n_classes is None else (1,)*(self.n_classes+1)
        self.train_epochs = 400
        self.train_steps_per_epoch = 100
        self.train_learning_rate = 0.0003
        self.train_batch_size = 1
        self.train_n_val_patches = None
        self.train_tensorboard = True
        # the parameter 'min_delta' was called 'epsilon' for keras<=2.1.5
        # keras.__version__ was removed in tensorflow 2.13.0
        min_delta_key = 'epsilon' if Version(getattr(keras, '__version__', '9.9.9'))<=Version('2.1.5') else 'min_delta'
        self.train_reduce_lr = {'factor': 0.5, 'patience': 40, min_delta_key: 0}

        self.use_gpu = False

        # remove derived attributes that shouldn't be overwritten
        for k in ('n_dim', 'n_channel_out', 'n_rays', 'rays_json'):
            try: del kwargs[k]
            except KeyError: pass

        self.update_parameters(False, **kwargs)

        # FIXME: put into is_valid()
        if not len(self.train_loss_weights) == (2 if self.n_classes is None else 3):
            raise ValueError(f"train_loss_weights {self.train_loss_weights} not compatible with n_classes ({self.n_classes}): must be 3 weights if n_classes is not None, otherwise 2")

        if not len(self.train_class_weights) == (2 if self.n_classes is None else self.n_classes+1):
            raise ValueError(f"train_class_weights {self.train_class_weights} not compatible with n_classes ({self.n_classes}): must be 'n_classes + 1' weights if n_classes is not None, otherwise 2")
313
+
314
+
315
+ class StarDist3D(StarDistBase):
316
+ """StarDist3D model.
317
+
318
+ Parameters
319
+ ----------
320
+ config : :class:`Config` or None
321
+ Will be saved to disk as JSON (``config.json``).
322
+ If set to ``None``, will be loaded from disk (must exist).
323
+ name : str or None
324
+ Model name. Uses a timestamp if set to ``None`` (default).
325
+ basedir : str
326
+ Directory that contains (or will contain) a folder with the given model name.
327
+
328
+ Raises
329
+ ------
330
+ FileNotFoundError
331
+ If ``config=None`` and config cannot be loaded from disk.
332
+ ValueError
333
+ Illegal arguments, including invalid configuration.
334
+
335
+ Attributes
336
+ ----------
337
+ config : :class:`Config`
338
+ Configuration, as provided during instantiation.
339
+ keras_model : `Keras model <https://keras.io/getting-started/functional-api-guide/>`_
340
+ Keras neural network model.
341
+ name : str
342
+ Model name.
343
+ logdir : :class:`pathlib.Path`
344
+ Path to model folder (which stores configuration, weights, etc.)
345
+ """
346
+
347
    def __init__(self, config=Config3D(), name=None, basedir='.'):
        """See class docstring."""
        # NOTE(review): the default `Config3D()` is evaluated once at class
        # definition time and shared across calls (upstream StarDist API);
        # pass an explicit config (or None to load from disk) rather than
        # relying on the shared default instance.
        super().__init__(config, name=name, basedir=basedir)
350
+
351
+
352
+ def _build(self):
353
+ if self.config.backbone == "unet":
354
+ return self._build_unet()
355
+ elif self.config.backbone == "resnet":
356
+ return self._build_resnet()
357
+ else:
358
+ raise NotImplementedError(self.config.backbone)
359
+
360
+
361
    def _build_unet(self):
        """Build the U-Net backbone with prob/dist heads (plus an optional class head)."""
        assert self.config.backbone == 'unet'
        # gather all 'unet_*' config attributes (prefix stripped) as kwargs for unet_block
        unet_kwargs = {k[len('unet_'):]:v for (k,v) in vars(self.config).items() if k.startswith('unet_')}

        input_img = Input(self.config.net_input_shape, name='input')

        # maxpool input image to grid size
        pooled = np.array([1,1,1])
        pooled_img = input_img
        while tuple(pooled) != tuple(self.config.grid):
            # pool by 2 along each axis that has not yet reached its grid factor
            pool = 1 + (np.asarray(self.config.grid) > pooled)
            pooled *= pool
            for _ in range(self.config.unet_n_conv_per_depth):
                pooled_img = Conv3D(self.config.unet_n_filter_base, self.config.unet_kernel_size,
                                    padding='same', activation=self.config.unet_activation)(pooled_img)
            pooled_img = MaxPooling3D(pool)(pooled_img)

        unet_base = unet_block(**unet_kwargs)(pooled_img)

        # optional extra feature convolution after the U-Net (0 disables it)
        if self.config.net_conv_after_unet > 0:
            unet = Conv3D(self.config.net_conv_after_unet, self.config.unet_kernel_size,
                          name='features', padding='same', activation=self.config.unet_activation)(unet_base)
        else:
            unet = unet_base

        output_prob = Conv3D( 1, (1,1,1), name='prob', padding='same', activation='sigmoid')(unet)
        output_dist = Conv3D(self.config.n_rays, (1,1,1), name='dist', padding='same', activation='linear')(unet)

        # attach extra classification head when self.n_classes is given
        if self._is_multiclass():
            # class head branches from unet_base with its own feature convolution
            if self.config.net_conv_after_unet > 0:
                unet_class = Conv3D(self.config.net_conv_after_unet, self.config.unet_kernel_size,
                                    name='features_class', padding='same', activation=self.config.unet_activation)(unet_base)
            else:
                unet_class = unet_base

            output_prob_class = Conv3D(self.config.n_classes+1, (1,1,1), name='prob_class', padding='same', activation='softmax')(unet_class)
            return Model([input_img], [output_prob,output_dist,output_prob_class])
        else:
            return Model([input_img], [output_prob,output_dist])
401
+
402
+
403
    def _build_resnet(self):
        """Build the ResNet backbone with prob/dist heads (plus an optional class head)."""
        assert self.config.backbone == 'resnet'
        n_filter = self.config.resnet_n_filter_base
        resnet_kwargs = dict (
            kernel_size        = self.config.resnet_kernel_size,
            n_conv_per_block   = self.config.resnet_n_conv_per_block,
            batch_norm         = self.config.resnet_batch_norm,
            kernel_initializer = self.config.resnet_kernel_init,
            activation         = self.config.resnet_activation,
        )

        input_img = Input(self.config.net_input_shape, name='input')

        # stem: two plain convolutions before the residual blocks
        layer = input_img
        layer = Conv3D(n_filter, (7,7,7), padding="same", kernel_initializer=self.config.resnet_kernel_init)(layer)
        layer = Conv3D(n_filter, (3,3,3), padding="same", kernel_initializer=self.config.resnet_kernel_init)(layer)

        pooled = np.array([1,1,1])
        for n in range(self.config.resnet_n_blocks):
            # pool by 2 along each axis that has not yet reached its grid factor
            pool = 1 + (np.asarray(self.config.grid) > pooled)
            pooled *= pool
            if any(p > 1 for p in pool):
                # double the feature channels whenever resolution is reduced
                n_filter *= 2
            layer = resnet_block(n_filter, pool=tuple(pool), **resnet_kwargs)(layer)

        layer_base = layer

        # optional extra feature convolution after the ResNet (0 disables it)
        if self.config.net_conv_after_resnet > 0:
            layer = Conv3D(self.config.net_conv_after_resnet, self.config.resnet_kernel_size,
                           name='features', padding='same', activation=self.config.resnet_activation)(layer_base)

        output_prob = Conv3D( 1, (1,1,1), name='prob', padding='same', activation='sigmoid')(layer)
        output_dist = Conv3D(self.config.n_rays, (1,1,1), name='dist', padding='same', activation='linear')(layer)

        # attach extra classification head when self.n_classes is given
        if self._is_multiclass():
            # class head branches from layer_base with its own feature convolution
            if self.config.net_conv_after_resnet > 0:
                layer_class = Conv3D(self.config.net_conv_after_resnet, self.config.resnet_kernel_size,
                                     name='features_class', padding='same', activation=self.config.resnet_activation)(layer_base)
            else:
                layer_class = layer_base

            output_prob_class = Conv3D(self.config.n_classes+1, (1,1,1), name='prob_class', padding='same', activation='softmax')(layer_class)
            return Model([input_img], [output_prob,output_dist,output_prob_class])
        else:
            return Model([input_img], [output_prob,output_dist])
449
+
450
+
451
def train(self, X, Y, validation_data, classes='auto', augmenter=None, seed=None, epochs=None, steps_per_epoch=None, workers=1):
    """Train the neural network with the given data.

    Parameters
    ----------
    X : tuple, list, `numpy.ndarray`, `keras.utils.Sequence`
        Input images
    Y : tuple, list, `numpy.ndarray`, `keras.utils.Sequence`
        Label masks
        Positive pixel values denote object instance ids (0 for background).
        Negative values can be used to turn off all losses for the corresponding pixels (e.g. for regions that haven't been labeled).
    classes (optional): 'auto' or iterable of same length as X
        label id -> class id mapping for each label mask of Y if multiclass prediction is activated (n_classes > 0)
        list of dicts with label id -> class id (1,...,n_classes)
        'auto' -> all objects will be assigned to the first non-background class,
                  or will be ignored if config.n_classes is None
    validation_data : tuple(:class:`numpy.ndarray`, :class:`numpy.ndarray`) or triple (if multiclass)
        Tuple (triple if multiclass) of X,Y,[classes] validation data.
    augmenter : None or callable
        Function with expected signature ``xt, yt = augmenter(x, y)``
        that takes in a single pair of input/label image (x,y) and returns
        the transformed images (xt, yt) for the purpose of data augmentation
        during training. Not applied to validation images.
        Example:
        def simple_augmenter(x,y):
            x = x + 0.05*np.random.normal(0,1,x.shape)
            return x,y
    seed : int
        Convenience to set ``np.random.seed(seed)``. (To obtain reproducible validation patches, etc.)
    epochs : int
        Optional argument to use instead of the value from ``config``.
    steps_per_epoch : int
        Optional argument to use instead of the value from ``config``.
    workers : int
        Number of data-generation workers handed to Keras; multiprocessing
        is enabled whenever ``workers > 1``.

    Returns
    -------
    ``History`` object
        See `Keras training history <https://keras.io/models/model/#fit>`_.

    """
    if seed is not None:
        # https://keras.io/getting-started/faq/#how-can-i-obtain-reproducible-results-using-keras-during-development
        np.random.seed(seed)
    # fall back to the configured training schedule when not overridden
    if epochs is None:
        epochs = self.config.train_epochs
    if steps_per_epoch is None:
        steps_per_epoch = self.config.train_steps_per_epoch

    classes = self._parse_classes_arg(classes, len(X))

    if not self._is_multiclass() and classes is not None:
        warnings.warn("Ignoring given classes as n_classes is set to None")

    # validation_data must be a 2-tuple (X,Y), or 3-tuple (X,Y,classes) when multiclass;
    # a missing classes entry defaults to 'auto'
    isinstance(validation_data,(list,tuple)) or _raise(ValueError())
    if self._is_multiclass() and len(validation_data) == 2:
        validation_data = tuple(validation_data) + ('auto',)
    ((len(validation_data) == (3 if self._is_multiclass() else 2))
        or _raise(ValueError(f'len(validation_data) = {len(validation_data)}, but should be {3 if self._is_multiclass() else 2}')))

    # each patch axis must be divisible by the network's downsampling factor for that axis
    patch_size = self.config.train_patch_size
    axes = self.config.axes.replace('C','')
    div_by = self._axes_div_by(axes)
    [p % d == 0 or _raise(ValueError(
        "'train_patch_size' must be divisible by {d} along axis '{a}'".format(a=a,d=d)
     )) for p,d,a in zip(patch_size,div_by,axes)]

    if not self._model_prepared:
        self.prepare_for_training()

    # shared keyword arguments for both the training and validation data generators
    data_kwargs = dict (
        rays             = rays_from_json(self.config.rays_json),
        grid             = self.config.grid,
        patch_size       = self.config.train_patch_size,
        anisotropy       = self.config.anisotropy,
        use_gpu          = self.config.use_gpu,
        foreground_prob  = self.config.train_foreground_only,
        n_classes        = self.config.n_classes,
        sample_ind_cache = self.config.train_sample_cache,
    )
    # Keras 3 moved the workers/use_multiprocessing options from fit() to the data object
    worker_kwargs = dict(workers=workers, use_multiprocessing=workers>1)
    if IS_KERAS_3_PLUS:
        data_kwargs['keras_kwargs'] = worker_kwargs
        fit_kwargs = {}
    else:
        fit_kwargs = worker_kwargs

    # generate validation data and store in numpy arrays
    n_data_val = len(validation_data[0])
    classes_val = self._parse_classes_arg(validation_data[2], n_data_val) if self._is_multiclass() else None
    n_take = self.config.train_n_val_patches if self.config.train_n_val_patches is not None else n_data_val
    _data_val = StarDistData3D(validation_data[0],validation_data[1], classes=classes_val, batch_size=n_take, length=1, **data_kwargs)
    data_val = _data_val[0]

    # expose data generator as member for general diagnostics
    self.data_train = StarDistData3D(X, Y, classes=classes, batch_size=self.config.train_batch_size,
                                     augmenter=augmenter, length=epochs*steps_per_epoch, **data_kwargs)

    if self.config.train_tensorboard:
        # only show middle slice of 3D inputs/outputs
        input_slices, output_slices = [[slice(None)]*5], [[slice(None)]*5,[slice(None)]*5]
        i = axes_dict(self.config.axes)['Z']
        channel = axes_dict(self.config.axes)['C']
        # middle Z slice of the input; output is downsampled by the grid factor along Z
        _n_in  = _data_val.patch_size[i] // 2
        _n_out = _data_val.patch_size[i] // (2 * (self.config.grid[i] if self.config.grid is not None else 1))
        input_slices[0][1+i] = _n_in
        output_slices[0][1+i] = _n_out
        output_slices[1][1+i] = _n_out
        # show dist for three rays
        _n = min(3, self.config.n_rays)
        output_slices[1][1+channel] = slice(0,(self.config.n_rays//_n)*_n, self.config.n_rays//_n)
        if self._is_multiclass():
            # likewise only show (up to) three class-probability channels, skipping background (channel 0)
            _n = min(3, self.config.n_classes)
            output_slices += [[slice(None)]*5]
            output_slices[2][1+channel] = slice(1,1+(self.config.n_classes//_n)*_n, self.config.n_classes//_n)

        if IS_TF_1:
            for cb in self.callbacks:
                if isinstance(cb,CARETensorBoard):
                    cb.input_slices = input_slices
                    cb.output_slices = output_slices
                    # target image for dist includes dist_mask and thus has more channels than dist output
                    cb.output_target_shapes = [None,[None]*5,None]
                    cb.output_target_shapes[1][1+channel] = data_val[1][1].shape[1+channel]
        elif self.basedir is not None and not any(isinstance(cb,CARETensorBoardImage) for cb in self.callbacks):
            self.callbacks.append(CARETensorBoardImage(model=self.keras_model, data=data_val, log_dir=str(self.logdir/'logs'/'images'),
                                                       n_images=3, prob_out=False, input_slices=input_slices, output_slices=output_slices))

    # fit_generator is only needed (and only exists) for TF1-era Keras
    fit = self.keras_model.fit_generator if (IS_TF_1 and not IS_KERAS_3_PLUS) else self.keras_model.fit
    history = fit(iter(self.data_train), validation_data=data_val,
                  epochs=epochs, steps_per_epoch=steps_per_epoch,
                  **fit_kwargs,
                  callbacks=self.callbacks, verbose=1,
                  # set validation batchsize to training batchsize (only works in tf 2.x)
                  **(dict(validation_batch_size = self.config.train_batch_size) if _tf_version_at_least("2.2.0") else {}))
    self._training_finished()

    return history
588
+
589
+
590
def _instances_from_prediction(self, img_shape, prob, dist, points=None, prob_class=None, prob_thresh=None, nms_thresh=None, overlap_label=None, return_labels=True, scale=None, **nms_kwargs):
    """Convert raw network output (prob/dist) into labeled object instances via NMS.

    if points is None     -> dense prediction
    if points is not None -> sparse prediction

    if prob_class is None     -> single class prediction
    if prob_class is not None -> multi class prediction

    Parameters
    ----------
    img_shape : tuple
        Shape of the (3D) image the label image should be rendered at.
    prob, dist :
        Object probability and radial distances as predicted by the network.
    points :
        Candidate object centers (sparse prediction) or None (dense prediction).
    prob_class :
        Per-candidate class probabilities when multiclass is enabled, else None.
    prob_thresh, nms_thresh : float or None
        Thresholds for candidate selection and non-maximum suppression;
        default to ``self.thresholds`` when None.
    overlap_label : int or None
        If given, pixels covered by multiple objects are set to this value.
    return_labels : bool
        If False, skip rendering the label image (``labels`` is None).
    scale : dict or None
        Per-axis scaling (keys 'X','Y','Z') applied to the input image that
        must be undone here for points and ray vectors.

    Returns
    -------
    (labels, res_dict)
        Label image (or None) and a dict with per-instance dist/points/prob/rays
        (plus class_prob/class_id for multiclass).
    """
    if prob_thresh is None: prob_thresh = self.thresholds.prob
    if nms_thresh  is None: nms_thresh  = self.thresholds.nms

    rays = rays_from_json(self.config.rays_json)

    # sparse prediction
    if points is not None:
        points, probi, disti, indsi = non_maximum_suppression_3d_sparse(dist, prob, points, rays, nms_thresh=nms_thresh, **nms_kwargs)
        if prob_class is not None:
            prob_class = prob_class[indsi]

    # dense prediction
    else:
        points, probi, disti = non_maximum_suppression_3d(dist, prob, rays, grid=self.config.grid,
                                                          prob_thresh=prob_thresh, nms_thresh=nms_thresh, **nms_kwargs)
        if prob_class is not None:
            # prob_class lives on the (downsampled) grid -> convert point coords to grid indices
            inds = tuple(p//g for p,g in zip(points.T, self.config.grid))
            prob_class = prob_class[inds]

    verbose = nms_kwargs.get('verbose',False)
    verbose and print("render polygons...")

    if scale is not None:
        # need to undo the scaling given by the scale dict, e.g. scale = dict(X=0.5,Y=0.5,Z=1.0):
        # 1. re-scale points (origins of polyhedra)
        # 2. re-scale vectors of rays object (computed from distances)
        if not (isinstance(scale,dict) and 'X' in scale and 'Y' in scale and 'Z' in scale):
            raise ValueError("scale must be a dictionary with entries for 'X', 'Y', and 'Z'")
        rescale = (1/scale['Z'],1/scale['Y'],1/scale['X'])
        points = points * np.array(rescale).reshape(1,3)
        rays = rays.copy(scale=rescale)
    else:
        rescale = (1,1,1)

    if return_labels:
        labels = polyhedron_to_label(disti, points, rays=rays, prob=probi, shape=img_shape, overlap_label=overlap_label, verbose=verbose)

        # map the overlap_label to something positive and back
        # (as relabel_sequential doesn't like negative values)
        if overlap_label is not None and overlap_label<0 and (overlap_label in labels):
            overlap_mask = (labels == overlap_label)
            # smallest positive id not yet used by any object
            overlap_label2 = max(set(np.unique(labels))-{overlap_label})+1
            labels[overlap_mask] = overlap_label2
            labels, fwd, bwd = relabel_sequential(labels)
            labels[labels == fwd[overlap_label2]] = overlap_label
        else:
            # TODO relabel_sequential necessary?
            # print(np.unique(labels))
            labels, _,_ = relabel_sequential(labels)
            # print(np.unique(labels))
    else:
        labels = None

    res_dict = dict(dist=disti, points=points, prob=probi, rays=rays, rays_vertices=rays.vertices, rays_faces=rays.faces)

    if prob_class is not None:
        # build the list of class ids per label via majority vote
        # zoom prob_class to img_shape
        # prob_class_up = zoom(prob_class,
        #                      tuple(s2/s1 for s1, s2 in zip(prob_class.shape[:3], img_shape))+(1,),
        #                      order=0)
        # class_id, label_ids = [], []
        # for reg in regionprops(labels):
        #     m = labels[reg.slice]==reg.label
        #     cls_id = np.argmax(np.mean(prob_class_up[reg.slice][m], axis = 0))
        #     class_id.append(cls_id)
        #     label_ids.append(reg.label)
        # # just a sanity check whether labels where in sorted order
        # assert all(x <= y for x,y in zip(label_ids, label_ids[1:]))
        # res_dict.update(dict(classes = class_id))
        # res_dict.update(dict(labels = label_ids))
        # self.p = prob_class_up

        # class id per instance = argmax over class probabilities (last axis)
        prob_class = np.asarray(prob_class)
        class_id = np.argmax(prob_class, axis=-1)
        res_dict.update(dict(class_prob=prob_class, class_id=class_id))

    return labels, res_dict
676
+
677
+
678
+ def _axes_div_by(self, query_axes):
679
+ if self.config.backbone == "unet":
680
+ query_axes = axes_check_and_normalize(query_axes)
681
+ assert len(self.config.unet_pool) == len(self.config.grid)
682
+ div_by = dict(zip(
683
+ self.config.axes.replace('C',''),
684
+ tuple(p**self.config.unet_n_depth * g for p,g in zip(self.config.unet_pool,self.config.grid))
685
+ ))
686
+ return tuple(div_by.get(a,1) for a in query_axes)
687
+ elif self.config.backbone == "resnet":
688
+ grid_dict = dict(zip(self.config.axes.replace('C',''), self.config.grid))
689
+ return tuple(grid_dict.get(a,1) for a in query_axes)
690
+ else:
691
+ raise NotImplementedError()
692
+
693
+
694
@property
def _config_class(self):
    """Configuration class associated with this model type (``Config3D``)."""
    return Config3D