DiSTNet2D 0.2.0__tar.gz → 0.2.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {distnet2d-0.2.0 → distnet2d-0.2.2}/DiSTNet2D.egg-info/PKG-INFO +4 -4
- {distnet2d-0.2.0 → distnet2d-0.2.2}/DiSTNet2D.egg-info/requires.txt +1 -1
- {distnet2d-0.2.0 → distnet2d-0.2.2}/PKG-INFO +4 -4
- {distnet2d-0.2.0 → distnet2d-0.2.2}/distnet_2d/data/dydx_iterator.py +173 -119
- {distnet2d-0.2.0 → distnet2d-0.2.2}/distnet_2d/data/swim1d.py +2 -2
- {distnet2d-0.2.0 → distnet2d-0.2.2}/distnet_2d/model/architectures.py +47 -19
- {distnet2d-0.2.0 → distnet2d-0.2.2}/distnet_2d/model/distnet_2d.py +187 -152
- {distnet2d-0.2.0 → distnet2d-0.2.2}/distnet_2d/model/distnet_2d_seg.py +16 -34
- {distnet2d-0.2.0 → distnet2d-0.2.2}/distnet_2d/model/layers.py +3 -3
- distnet2d-0.2.2/distnet_2d/model/spatial_attention.py +202 -0
- distnet2d-0.2.2/distnet_2d/utils/helpers.py +204 -0
- {distnet2d-0.2.0 → distnet2d-0.2.2}/distnet_2d/utils/losses.py +6 -6
- distnet2d-0.2.2/distnet_2d/utils/metrics_tf.py +134 -0
- {distnet2d-0.2.0 → distnet2d-0.2.2}/distnet_2d/utils/objectwise_computation_tf.py +48 -21
- {distnet2d-0.2.0 → distnet2d-0.2.2}/setup.py +4 -4
- distnet2d-0.2.0/distnet_2d/model/spatial_attention.py +0 -93
- distnet2d-0.2.0/distnet_2d/utils/helpers.py +0 -86
- distnet2d-0.2.0/distnet_2d/utils/metrics_tf.py +0 -89
- {distnet2d-0.2.0 → distnet2d-0.2.2}/DiSTNet2D.egg-info/SOURCES.txt +0 -0
- {distnet2d-0.2.0 → distnet2d-0.2.2}/DiSTNet2D.egg-info/dependency_links.txt +0 -0
- {distnet2d-0.2.0 → distnet2d-0.2.2}/DiSTNet2D.egg-info/top_level.txt +0 -0
- {distnet2d-0.2.0 → distnet2d-0.2.2}/LICENSE.txt +0 -0
- {distnet2d-0.2.0 → distnet2d-0.2.2}/README.md +0 -0
- {distnet2d-0.2.0 → distnet2d-0.2.2}/distnet_2d/__init__.py +0 -0
- {distnet2d-0.2.0 → distnet2d-0.2.2}/distnet_2d/data/__init__.py +0 -0
- {distnet2d-0.2.0 → distnet2d-0.2.2}/distnet_2d/data/center_edm.py +0 -0
- {distnet2d-0.2.0 → distnet2d-0.2.2}/distnet_2d/data/medoid.py +0 -0
- {distnet2d-0.2.0 → distnet2d-0.2.2}/distnet_2d/model/__init__.py +0 -0
- {distnet2d-0.2.0 → distnet2d-0.2.2}/distnet_2d/model/gradient_accumulator.py +0 -0
- {distnet2d-0.2.0 → distnet2d-0.2.2}/distnet_2d/utils/__init__.py +0 -0
- {distnet2d-0.2.0 → distnet2d-0.2.2}/distnet_2d/utils/agc.py +0 -0
- {distnet2d-0.2.0 → distnet2d-0.2.2}/distnet_2d/utils/image_derivatives_np.py +0 -0
- {distnet2d-0.2.0 → distnet2d-0.2.2}/distnet_2d/utils/image_derivatives_tf.py +0 -0
- {distnet2d-0.2.0 → distnet2d-0.2.2}/distnet_2d/utils/lovasz_loss.py +0 -0
- {distnet2d-0.2.0 → distnet2d-0.2.2}/setup.cfg +0 -0
|
@@ -1,11 +1,11 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: DiSTNet2D
|
|
3
|
-
Version: 0.2.
|
|
3
|
+
Version: 0.2.2
|
|
4
4
|
Summary: tensorflow/keras implementation of DiSTNet 2D
|
|
5
5
|
Home-page: https://github.com/jeanollion/distnet2d
|
|
6
|
-
Download-URL: https://github.com/jeanollion/distnet2d/releases/download/v0.
|
|
6
|
+
Download-URL: https://github.com/jeanollion/distnet2d/releases/download/v0.2.2/distnet2d-0.2.2.tar.gz
|
|
7
7
|
Author: Jean Ollion
|
|
8
|
-
Author-email: jean.ollion@
|
|
8
|
+
Author-email: jean.ollion@sabilab.fr
|
|
9
9
|
Keywords: Segmentation,Tracking,Cell,Tensorflow,Keras
|
|
10
10
|
Classifier: Development Status :: 4 - Beta
|
|
11
11
|
Classifier: Intended Audience :: Science/Research
|
|
@@ -22,7 +22,7 @@ Requires-Dist: tensorflow>=2.7.1
|
|
|
22
22
|
Requires-Dist: edt>=2.0.2
|
|
23
23
|
Requires-Dist: scikit-fmm
|
|
24
24
|
Requires-Dist: numba
|
|
25
|
-
Requires-Dist: dataset_iterator>=0.5.
|
|
25
|
+
Requires-Dist: dataset_iterator>=0.5.5
|
|
26
26
|
Requires-Dist: elasticdeform>=0.4.7
|
|
27
27
|
Dynamic: author
|
|
28
28
|
Dynamic: author-email
|
|
@@ -1,11 +1,11 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: DiSTNet2D
|
|
3
|
-
Version: 0.2.
|
|
3
|
+
Version: 0.2.2
|
|
4
4
|
Summary: tensorflow/keras implementation of DiSTNet 2D
|
|
5
5
|
Home-page: https://github.com/jeanollion/distnet2d
|
|
6
|
-
Download-URL: https://github.com/jeanollion/distnet2d/releases/download/v0.
|
|
6
|
+
Download-URL: https://github.com/jeanollion/distnet2d/releases/download/v0.2.2/distnet2d-0.2.2.tar.gz
|
|
7
7
|
Author: Jean Ollion
|
|
8
|
-
Author-email: jean.ollion@
|
|
8
|
+
Author-email: jean.ollion@sabilab.fr
|
|
9
9
|
Keywords: Segmentation,Tracking,Cell,Tensorflow,Keras
|
|
10
10
|
Classifier: Development Status :: 4 - Beta
|
|
11
11
|
Classifier: Intended Audience :: Science/Research
|
|
@@ -22,7 +22,7 @@ Requires-Dist: tensorflow>=2.7.1
|
|
|
22
22
|
Requires-Dist: edt>=2.0.2
|
|
23
23
|
Requires-Dist: scikit-fmm
|
|
24
24
|
Requires-Dist: numba
|
|
25
|
-
Requires-Dist: dataset_iterator>=0.5.
|
|
25
|
+
Requires-Dist: dataset_iterator>=0.5.5
|
|
26
26
|
Requires-Dist: elasticdeform>=0.4.7
|
|
27
27
|
Dynamic: author
|
|
28
28
|
Dynamic: author-email
|
|
@@ -30,30 +30,43 @@ class DyDxIterator(TrackingIterator):
|
|
|
30
30
|
frame_window:int,
|
|
31
31
|
aug_frame_subsampling, # either int: frame number interval will be drawn uniformly in [frame_window,aug_frame_subsampling] or callable that generate an frame number interval (int)
|
|
32
32
|
erase_edge_cell_size:int,
|
|
33
|
-
|
|
33
|
+
next_frames:bool = True,
|
|
34
34
|
allow_frame_subsampling_direct_neigh:bool = False,
|
|
35
35
|
aug_remove_prob: float = 0.005,
|
|
36
36
|
return_link_multiplicity:bool = True,
|
|
37
37
|
channel_keywords:list=CHANNEL_KEYWORDS, # channel @1 must be label
|
|
38
|
-
input_label_keywords:list = None,
|
|
39
|
-
array_keywords:list=ARRAY_KEYWORDS[:1],
|
|
38
|
+
input_label_keywords:list = None, # additional labels that will be considered as input to the neural network
|
|
39
|
+
array_keywords:list=ARRAY_KEYWORDS[:1], # if second array : category
|
|
40
40
|
elasticdeform_parameters:dict = None,
|
|
41
41
|
downscale_displacement_and_link_multiplicity=1,
|
|
42
42
|
return_edm_derivatives: bool = False,
|
|
43
43
|
return_center:bool = True,
|
|
44
|
-
scale_edm:bool = False,
|
|
44
|
+
scale_edm:bool = False, # for each cell max edm value is 1
|
|
45
45
|
center_mode:str = "MEDOID", # GEOMETRICAL, "EDM_MAX", "EDM_MEAN", "SKELETON", "MEDOID"
|
|
46
|
-
center_distance_mode = "GEODESIC",
|
|
46
|
+
center_distance_mode = "GEODESIC", # GEODESIC, EUCLIDEAN
|
|
47
|
+
input_label_center_idx:int = -1, # center of target label are centers of the input label of this index
|
|
47
48
|
return_label_rank = False,
|
|
48
49
|
long_term:bool = True,
|
|
50
|
+
tracking:bool = True,
|
|
49
51
|
return_next_displacement:bool = True,
|
|
50
52
|
output_float16=False,
|
|
51
53
|
**kwargs):
|
|
52
54
|
assert len(channel_keywords)>=2, 'keyword should contain at least 2 elements in this order: grayscale input images, object labels, [other grayscale input images]'
|
|
53
|
-
|
|
55
|
+
if frame_window == 0:
|
|
56
|
+
tracking = False
|
|
57
|
+
if tracking:
|
|
58
|
+
assert 2 >= len(array_keywords) >= 1, 'array keyword first element should be links to previous objects. if 2 elements: second must be category'
|
|
59
|
+
else:
|
|
60
|
+
assert 1 >= len(array_keywords) >= 0, 'only one array keyword allow (category)'
|
|
61
|
+
|
|
54
62
|
assert center_mode.upper() in CENTER_MODE, f"invalid center mode: {center_mode} should be in {CENTER_MODE}"
|
|
55
63
|
assert center_distance_mode.upper() in CENTER_DISTANCE_MODE, f"invalid center distance mode: {center_distance_mode} should be in {CENTER_DISTANCE_MODE}"
|
|
56
|
-
self.
|
|
64
|
+
self.category_array_idx = next((i for i, s in enumerate(array_keywords) if "category" in s), -1)
|
|
65
|
+
self.tracking = tracking
|
|
66
|
+
if not tracking:
|
|
67
|
+
return_link_multiplicity = False
|
|
68
|
+
return_next_displacement = False
|
|
69
|
+
long_term = False
|
|
57
70
|
self.return_link_multiplicity=return_link_multiplicity
|
|
58
71
|
self.downscale=downscale_displacement_and_link_multiplicity
|
|
59
72
|
self.erase_edge_cell_size=erase_edge_cell_size
|
|
@@ -65,13 +78,15 @@ class DyDxIterator(TrackingIterator):
|
|
|
65
78
|
self.return_center=return_center
|
|
66
79
|
self.center_mode=center_mode.upper()
|
|
67
80
|
self.center_distance_mode = center_distance_mode.upper()
|
|
81
|
+
if input_label_center_idx >= 0:
|
|
82
|
+
assert input_label_center_idx < len(input_label_keywords), f"invalid input_label_center_idx={input_label_center_idx} must be < len(input_label_keywords)={len(input_label_keywords)}"
|
|
83
|
+
self.input_label_center_idx = input_label_center_idx
|
|
68
84
|
self.return_label_rank=return_label_rank
|
|
69
|
-
assert frame_window>=1, "frame_window must be >=1"
|
|
70
85
|
self.frame_window = frame_window
|
|
71
86
|
self.return_next_displacement=return_next_displacement
|
|
72
87
|
self.n_label_max = kwargs.pop("n_label_max", 2000)
|
|
73
|
-
self.long_term=long_term
|
|
74
|
-
self.
|
|
88
|
+
self.long_term=long_term
|
|
89
|
+
self.output_central_only = False
|
|
75
90
|
nchan = len(channel_keywords)
|
|
76
91
|
if input_label_keywords is not None:
|
|
77
92
|
if not isinstance(input_label_keywords, list):
|
|
@@ -85,20 +100,20 @@ class DyDxIterator(TrackingIterator):
|
|
|
85
100
|
else:
|
|
86
101
|
self.label_input_channels = []
|
|
87
102
|
super().__init__(dataset=dataset,
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
103
|
+
channel_keywords=channel_keywords,
|
|
104
|
+
array_keywords = array_keywords,
|
|
105
|
+
input_channels=[0] + [i for i in range(2, nchan)] + self.label_input_channels,
|
|
106
|
+
output_channels=[1],
|
|
107
|
+
channels_prev=[True]*len(channel_keywords),
|
|
108
|
+
channels_next=[next_frames] * len(channel_keywords),
|
|
109
|
+
mask_channels=[1] + self.label_input_channels,
|
|
110
|
+
n_frames = self.frame_window,
|
|
111
|
+
aug_remove_prob=aug_remove_prob,
|
|
112
|
+
aug_all_frames=False,
|
|
113
|
+
convert_masks_to_dtype=False,
|
|
114
|
+
extract_tile_function=extract_tile_function,
|
|
115
|
+
elasticdeform_parameters=elasticdeform_parameters,
|
|
116
|
+
**kwargs)
|
|
102
117
|
|
|
103
118
|
def disable_random_transforms(self, data_augmentation:bool=True, channels_postprocessing:bool=False):
|
|
104
119
|
params = super().disable_random_transforms(data_augmentation, channels_postprocessing)
|
|
@@ -112,18 +127,21 @@ class DyDxIterator(TrackingIterator):
|
|
|
112
127
|
self.aug_frame_subsampling = parameters["aug_frame_subsampling"]
|
|
113
128
|
|
|
114
129
|
def _get_batch_by_channel(self, index_array, perform_augmentation, input_only=False, perform_elasticdeform=True, perform_tiling=True, **kwargs):
|
|
115
|
-
if self.
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
if
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
130
|
+
if self.frame_window > 0:
|
|
131
|
+
if self.aug_remove_prob>0 and random() < self.aug_remove_prob:
|
|
132
|
+
n_frames = 0 # flag that aug_remove = true
|
|
133
|
+
else:
|
|
134
|
+
if self.aug_frame_subsampling is not None :
|
|
135
|
+
if callable(self.aug_frame_subsampling):
|
|
136
|
+
n_frames = self.aug_frame_subsampling()
|
|
137
|
+
elif self.aug_frame_subsampling > self.frame_window:
|
|
138
|
+
n_frames = max(self.frame_window, np.random.randint(self.aug_frame_subsampling))
|
|
139
|
+
else:
|
|
140
|
+
n_frames = self.frame_window
|
|
123
141
|
else:
|
|
124
142
|
n_frames = self.frame_window
|
|
125
|
-
|
|
126
|
-
|
|
143
|
+
else:
|
|
144
|
+
n_frames = 0
|
|
127
145
|
kwargs.update({"n_frames":n_frames})
|
|
128
146
|
batch_by_channel, aug_param_array, ref_channel = super()._get_batch_by_channel(index_array, perform_augmentation, input_only, perform_elasticdeform=False, perform_tiling=False, **kwargs)
|
|
129
147
|
ref_shape = batch_by_channel[0].shape
|
|
@@ -132,25 +150,26 @@ class DyDxIterator(TrackingIterator):
|
|
|
132
150
|
if not issubclass(batch_by_channel[1].dtype.type, np.integer): # label
|
|
133
151
|
batch_by_channel[1] = batch_by_channel[1].astype(np.int32)
|
|
134
152
|
# correction for oob @ previous labels : add identity links
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
153
|
+
if self.tracking:
|
|
154
|
+
prevLinks = batch_by_channel['arrays'][0]
|
|
155
|
+
for b in range(prevLinks.shape[0]):
|
|
156
|
+
if n_frames > 0:
|
|
157
|
+
for i in range(n_frames):
|
|
158
|
+
inc = n_frames - i
|
|
159
|
+
prev_inc = aug_param_array[b][ref_channel].get(f"prev_inc_{inc}", inc)
|
|
160
|
+
if prev_inc!=inc:
|
|
161
|
+
#print(f"oob prev: batch: {b} n_frames={n_frames}, frame_idx:{i} inc={inc} actual inc:{prev_inc} will replace at {i+1}")
|
|
162
|
+
self._set_identity_link(prevLinks, b, i+1)
|
|
163
|
+
if self.channels_next[1]:
|
|
164
|
+
next_inc = aug_param_array[b][ref_channel].get(f"next_inc_{inc}", inc)
|
|
165
|
+
if next_inc!=inc:
|
|
166
|
+
#print(f"oob next: batch: {b} n_frames={n_frames}, frame_idx:{i} inc={inc} actual inc:{next_inc} will replace prev labels at {n_frames+inc}")
|
|
167
|
+
self._set_identity_link(prevLinks, b, n_frames + inc)
|
|
168
|
+
else: # n_frame == 0 ->
|
|
169
|
+
for c in range(1, prevLinks.shape[-1]):
|
|
170
|
+
self._set_identity_link(prevLinks, b, c)
|
|
171
|
+
# get previous labels and store in batch_by_channel BEFORE applying tiling and elastic deform
|
|
172
|
+
self._get_prev_label(batch_by_channel, n_frames)
|
|
154
173
|
batch_by_channel["batch_size"] = batch_by_channel[0].shape[0] # batch size is recorded here: it will be used in case of tiling
|
|
155
174
|
if n_frames>1: # remove unused frames
|
|
156
175
|
sel = self._get_end_points(n_frames, False)
|
|
@@ -168,7 +187,7 @@ class DyDxIterator(TrackingIterator):
|
|
|
168
187
|
if perform_elasticdeform:
|
|
169
188
|
self._apply_elasticdeform(batch_by_channel)
|
|
170
189
|
if perform_tiling:
|
|
171
|
-
self._apply_tiling(batch_by_channel)
|
|
190
|
+
self._apply_tiling(batch_by_channel, index_array)
|
|
172
191
|
if perform_elasticdeform or perform_tiling:
|
|
173
192
|
for c in converted_from_float16:
|
|
174
193
|
batch_by_channel[c] = batch_by_channel[c].astype('float16')
|
|
@@ -181,7 +200,7 @@ class DyDxIterator(TrackingIterator):
|
|
|
181
200
|
prev_label[b, :, 1, c] = prev_label[b, :, 0, c]
|
|
182
201
|
|
|
183
202
|
def _get_frames_to_augment(self, img, chan_idx, aug_params):
|
|
184
|
-
if self.aug_all_frames:
|
|
203
|
+
if self.aug_all_frames or self.frame_window == 0:
|
|
185
204
|
return list(range(img.shape[-1]))
|
|
186
205
|
n_frames = (img.shape[-1]-1)//2 if self.channels_prev[chan_idx] and self.channels_next[chan_idx] else img.shape[-1]-1
|
|
187
206
|
return self._get_end_points(n_frames, False)
|
|
@@ -224,17 +243,24 @@ class DyDxIterator(TrackingIterator):
|
|
|
224
243
|
def _get_output_batch(self, batch_by_channel, ref_chan_idx, aug_param_array): # compute edm, center, dx & dy edm, link_multiplicity
|
|
225
244
|
ndisp = self.return_next_displacement
|
|
226
245
|
labelIms = batch_by_channel[1]
|
|
227
|
-
labels_map_prev = batch_by_channel["prev_label_map"]
|
|
246
|
+
labels_map_prev = batch_by_channel["prev_label_map"] if self.tracking else None
|
|
228
247
|
return_next = self.channels_next[1]
|
|
229
248
|
long_term = self.long_term
|
|
230
249
|
frame_window = self.frame_window
|
|
231
|
-
if self.
|
|
250
|
+
if self.output_central_only and self.frame_window > 0:
|
|
232
251
|
assert self.channels_prev[1], "in return_central_only mode previous must be returned"
|
|
233
252
|
assert return_next, "in return_central_only mode next must be returned"
|
|
234
|
-
|
|
253
|
+
for midx in self.mask_channels:
|
|
254
|
+
batch_by_channel[midx] = batch_by_channel[midx][..., self.frame_window-1:self.frame_window+2] # only prev, central, and next frame
|
|
255
|
+
labelIms = batch_by_channel[1]
|
|
256
|
+
for aidx in range(len(batch_by_channel['arrays'])):
|
|
257
|
+
batch_by_channel['arrays'][aidx] = batch_by_channel['arrays'][aidx][..., self.frame_window-1:self.frame_window+2]
|
|
258
|
+
for icidx in range(len(batch_by_channel["input_centers"])): # remap : 0 -> fw-1, 1 -> fw, 2 -> fw+1
|
|
259
|
+
centers = batch_by_channel["input_centers"][icidx]
|
|
260
|
+
batch_by_channel["input_centers"][icidx] = {(b, c-(self.frame_window-1)):center for (b, c), center in centers.items() if self.frame_window-1<=c<=self.frame_window+1 }
|
|
235
261
|
long_term = False
|
|
236
262
|
frame_window = 1
|
|
237
|
-
labels_map_prev = [lmp[self.frame_window-1:self.frame_window+2] for lmp in labels_map_prev]
|
|
263
|
+
labels_map_prev = [lmp[self.frame_window-1:self.frame_window+2] for lmp in labels_map_prev] if self.tracking else None
|
|
238
264
|
# remove small object
|
|
239
265
|
mask_to_erase_cur = [chan_idx for chan_idx in self.mask_channels if chan_idx!=1 and chan_idx in batch_by_channel]
|
|
240
266
|
mask_to_erase_chan_cur = [frame_window if self.channels_prev[chan_idx] else 0 for chan_idx in mask_to_erase_cur]
|
|
@@ -262,23 +288,23 @@ class DyDxIterator(TrackingIterator):
|
|
|
262
288
|
n_motion = 2 * frame_window if return_next else frame_window
|
|
263
289
|
if long_term:
|
|
264
290
|
n_motion = n_motion + (2 * ( frame_window - 1 ) if return_next else frame_window -1)
|
|
265
|
-
dyIm = np.zeros(labelIms.shape[:-1]+(n_motion,), dtype=self.dtype)
|
|
266
|
-
dxIm = np.zeros(labelIms.shape[:-1]+(n_motion,), dtype=self.dtype)
|
|
291
|
+
dyIm = np.zeros(labelIms.shape[:-1]+(n_motion,), dtype=self.dtype) if self.tracking else None
|
|
292
|
+
dxIm = np.zeros(labelIms.shape[:-1]+(n_motion,), dtype=self.dtype) if self.tracking else None
|
|
267
293
|
if ndisp:
|
|
268
294
|
dyImNext = np.zeros(labelIms.shape[:-1]+(n_motion,), dtype=self.dtype)
|
|
269
295
|
dxImNext = np.zeros(labelIms.shape[:-1]+(n_motion,), dtype=self.dtype)
|
|
270
296
|
if self.return_link_multiplicity:
|
|
271
297
|
linkMultiplicityImNext = np.zeros(labelIms.shape[:-1]+(n_motion,), dtype=self.dtype)
|
|
272
298
|
centerIm = np.zeros(labelIms.shape, dtype=self.dtype) if self.return_center else None
|
|
273
|
-
categoryIm = np.zeros(labelIms.shape, dtype=self.dtype) if self.
|
|
274
|
-
cat_array = batch_by_channel['arrays'][
|
|
299
|
+
categoryIm = np.zeros(labelIms.shape, dtype=self.dtype) if self.category_array_idx>0 else None
|
|
300
|
+
cat_array = batch_by_channel['arrays'][self.category_array_idx] if self.category_array_idx>0 else None
|
|
275
301
|
if cat_array is not None and len(cat_array.shape) == 4:
|
|
276
302
|
cat_array = cat_array[:, :, 0]
|
|
277
303
|
if self.return_label_rank:
|
|
278
304
|
rankIm = np.zeros(labelIms.shape, dtype=np.int32)
|
|
279
|
-
prevLabelArr = np.zeros(labelIms.shape[:1]+(n_motion, self.n_label_max), dtype=np.int32)
|
|
280
|
-
nextLabelArr = np.zeros(labelIms.shape[:1] + (n_motion, self.n_label_max), dtype=np.int32)
|
|
281
|
-
centerArr = np.zeros(labelIms.shape[:1]+labelIms.shape[-1:]+(self.n_label_max,2), dtype=np.float32)
|
|
305
|
+
prevLabelArr = np.zeros(labelIms.shape[:1]+(n_motion, self.n_label_max), dtype=np.int32) if self.tracking else None
|
|
306
|
+
nextLabelArr = np.zeros(labelIms.shape[:1] + (n_motion, self.n_label_max), dtype=np.int32) if self.tracking else None
|
|
307
|
+
centerArr = np.zeros(labelIms.shape[:1]+labelIms.shape[-1:]+(self.n_label_max,2), dtype=np.float32) # B, C, N, 2
|
|
282
308
|
centerArr.fill(np.nan)
|
|
283
309
|
if self.return_link_multiplicity:
|
|
284
310
|
linkMultiplicityIm = np.zeros(labelIms.shape[:-1]+(n_motion,), dtype=self.dtype)
|
|
@@ -290,63 +316,71 @@ class DyDxIterator(TrackingIterator):
|
|
|
290
316
|
get_idx = lambda x:x%batch_size # We assume here that the indexing order of tiling is tile x batch
|
|
291
317
|
else:
|
|
292
318
|
get_idx = lambda x:x
|
|
319
|
+
|
|
293
320
|
labels_and_centers = {}
|
|
294
321
|
for b,c in itertools.product(range(labelIms.shape[0]), range(labelIms.shape[-1])):
|
|
295
|
-
|
|
322
|
+
input_centers = batch_by_channel["input_centers"][self.input_label_center_idx][(b, c)] if self.input_label_center_idx >= 0 else None
|
|
323
|
+
labels_and_centers[(b, c)] = _get_labels_and_centers(labelIms[b][...,c], edm[b][...,c], self.center_mode, input_centers=input_centers)
|
|
296
324
|
for i in range(labelIms.shape[0]):
|
|
297
325
|
bidx = get_idx(i)
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
l_c = [labels_and_centers[(i,s)] for s in sel]
|
|
301
|
-
o_s = [object_slices[(i, s)] for s in sel]
|
|
302
|
-
_compute_outputs(l_c, labelIms[i][...,sel], labels_map_prev[bidx][c], o_s, dyIm[i,...,c], dxIm[i,...,c], dyImNext=dyImNext[i,...,c] if ndisp else None, dxImNext=dxImNext[i,...,c] if ndisp else None, cdmIm=centerIm[i,...,frame_window] if self.return_center and sel[1] == frame_window else None, cdmImPrev=centerIm[i,...,c] if self.return_center else None, edmIm = edm[i,...,frame_window] if self.scale_edm else None, edmImPrev = edm[i,...,c] if self.scale_edm else None, scale_edm=self.scale_edm, categoryIm=categoryIm[i,...,frame_window] if self.return_category and sel[1] == frame_window else None, categoryArray=cat_array[bidx, :, frame_window] if self.return_category and sel[1] == frame_window else None, categoryImPrev=categoryIm[i,...,c] if self.return_category else None, categoryArrayPrev=cat_array[bidx, :, c] if self.return_category else None, linkMultiplicityIm=linkMultiplicityIm[i,...,c] if self.return_link_multiplicity else None, linkMultiplicityImNext=linkMultiplicityImNext[i,...,c] if self.return_link_multiplicity and ndisp else None, rankIm=rankIm[i,...,frame_window] if self.return_label_rank and sel[1] == frame_window else None, rankImPrev=rankIm[i,...,c] if self.return_label_rank else None, prevLabelArr=prevLabelArr[i,c] if self.return_label_rank else None, nextLabelArr=nextLabelArr[i,c] if self.return_label_rank and ndisp else None, centerArr=centerArr[i,frame_window] if self.return_label_rank and sel[1] == frame_window else None, centerArrPrev=centerArr[i,c] if self.return_label_rank else None, center_distance_mode=self.center_distance_mode)
|
|
303
|
-
if return_next:
|
|
304
|
-
for c in range(frame_window, 2*frame_window):
|
|
326
|
+
if self.frame_window > 0:
|
|
327
|
+
for c in range(0, frame_window):
|
|
305
328
|
sel = [c, c+1]
|
|
306
|
-
l_c = [labels_and_centers[(i,
|
|
307
|
-
o_s = [object_slices[(i, s)] for s in sel]
|
|
308
|
-
_compute_outputs(l_c, labelIms[i][...,sel], labels_map_prev[bidx][c], o_s, dyIm[i,...,c], dxIm[i,...,c], dyImNext=dyImNext[i,...,c] if ndisp else None, dxImNext=dxImNext[i,...,c] if ndisp else None, cdmIm=centerIm[i,..., c + 1] if self.return_center else None, edmIm = edm[i,...,c+1] if self.scale_edm else None, scale_edm=self.scale_edm, categoryIm=categoryIm[i,..., c + 1] if self.return_category else None, categoryArray=cat_array[bidx, :, c+1] if self.return_category else None, cdmImPrev=None, linkMultiplicityIm=linkMultiplicityIm[i,...,c] if self.return_link_multiplicity else None, linkMultiplicityImNext=linkMultiplicityImNext[i,...,c] if self.return_link_multiplicity and ndisp else None, rankIm=rankIm[i,..., c + 1] if self.return_label_rank else None, rankImPrev=None, prevLabelArr=prevLabelArr[i,c] if self.return_label_rank else None, nextLabelArr=nextLabelArr[i,c] if self.return_label_rank and ndisp else None, centerArr=centerArr[i, c + 1] if self.return_label_rank else None, center_distance_mode=self.center_distance_mode)
|
|
309
|
-
if long_term:
|
|
310
|
-
off = 2*frame_window if return_next else frame_window
|
|
311
|
-
for c in range(0, frame_window-1):
|
|
312
|
-
sel = [c, frame_window]
|
|
313
|
-
l_c = [labels_and_centers[(i, s)] for s in sel]
|
|
329
|
+
l_c = [labels_and_centers[(i,s)] for s in sel]
|
|
314
330
|
o_s = [object_slices[(i, s)] for s in sel]
|
|
315
|
-
_compute_outputs(l_c, labelIms[i][...,sel], labels_map_prev[bidx][c
|
|
331
|
+
_compute_outputs(l_c, labelIms[i][...,sel], labels_map_prev[bidx][c] if self.tracking else None, o_s, dyIm[i,...,c] if self.tracking else None, dxIm[i,...,c] if self.tracking else None, dyImNext=dyImNext[i,...,c] if ndisp else None, dxImNext=dxImNext[i,...,c] if ndisp else None, cdmIm=centerIm[i,...,frame_window] if self.return_center and sel[1] == frame_window else None, cdmImPrev=centerIm[i,...,c] if self.return_center else None, edmIm = edm[i,...,frame_window] if self.scale_edm else None, edmImPrev = edm[i,...,c] if self.scale_edm else None, scale_edm=self.scale_edm, categoryIm=categoryIm[i,...,frame_window] if self.category_array_idx>0 and sel[1] == frame_window else None, categoryArray=cat_array[bidx, :, frame_window] if self.category_array_idx>0 and sel[1] == frame_window else None, categoryImPrev=categoryIm[i,...,c] if self.category_array_idx>0 else None, categoryArrayPrev=cat_array[bidx, :, c] if self.category_array_idx>0 else None, linkMultiplicityIm=linkMultiplicityIm[i,...,c] if self.return_link_multiplicity else None, linkMultiplicityImNext=linkMultiplicityImNext[i,...,c] if self.return_link_multiplicity and ndisp else None, rankIm=rankIm[i,...,frame_window] if self.return_label_rank and sel[1] == frame_window else None, rankImPrev=rankIm[i,...,c] if self.return_label_rank else None, prevLabelArr=prevLabelArr[i,c] if self.return_label_rank and self.tracking else None, nextLabelArr=nextLabelArr[i,c] if self.return_label_rank and self.tracking and ndisp else None, centerArr=centerArr[i,frame_window] if self.return_label_rank and sel[1] == frame_window else None, centerArrPrev=centerArr[i,c] if self.return_label_rank else None, center_distance_mode=self.center_distance_mode)
|
|
316
332
|
if return_next:
|
|
317
|
-
for c in range(frame_window
|
|
318
|
-
sel = [
|
|
333
|
+
for c in range(frame_window, 2*frame_window):
|
|
334
|
+
sel = [c, c+1]
|
|
335
|
+
l_c = [labels_and_centers[(i, s)] for s in sel]
|
|
336
|
+
o_s = [object_slices[(i, s)] for s in sel]
|
|
337
|
+
_compute_outputs(l_c, labelIms[i][...,sel], labels_map_prev[bidx][c] if self.tracking else None, o_s, dyIm[i,...,c] if self.tracking else None, dxIm[i,...,c] if self.tracking else None, dyImNext=dyImNext[i,...,c] if ndisp else None, dxImNext=dxImNext[i,...,c] if ndisp else None, cdmIm=centerIm[i,..., c + 1] if self.return_center else None, edmIm = edm[i,...,c+1] if self.scale_edm else None, scale_edm=self.scale_edm, categoryIm=categoryIm[i,..., c + 1] if self.category_array_idx>0 else None, categoryArray=cat_array[bidx, :, c+1] if self.category_array_idx>0 else None, cdmImPrev=None, linkMultiplicityIm=linkMultiplicityIm[i,...,c] if self.return_link_multiplicity else None, linkMultiplicityImNext=linkMultiplicityImNext[i,...,c] if self.return_link_multiplicity and ndisp else None, rankIm=rankIm[i,..., c + 1] if self.return_label_rank else None, rankImPrev=None, prevLabelArr=prevLabelArr[i,c] if self.return_label_rank and self.tracking else None, nextLabelArr=nextLabelArr[i,c] if self.return_label_rank and ndisp else None, centerArr=centerArr[i, c + 1] if self.return_label_rank else None, center_distance_mode=self.center_distance_mode)
|
|
338
|
+
if long_term:
|
|
339
|
+
off = 2*frame_window if return_next else frame_window
|
|
340
|
+
for c in range(0, frame_window-1):
|
|
341
|
+
sel = [c, frame_window]
|
|
319
342
|
l_c = [labels_and_centers[(i, s)] for s in sel]
|
|
320
343
|
o_s = [object_slices[(i, s)] for s in sel]
|
|
321
344
|
_compute_outputs(l_c, labelIms[i][...,sel], labels_map_prev[bidx][c + off], o_s, dyIm[i,..., c + off], dxIm[i,..., c + off], dyImNext=dyImNext[i,..., c + off] if ndisp else None, dxImNext=dxImNext[i,..., c + off] if ndisp else None, cdmIm=None, cdmImPrev=None, linkMultiplicityIm=linkMultiplicityIm[i,..., c + off] if self.return_link_multiplicity else None, linkMultiplicityImNext=linkMultiplicityImNext[i,..., c + off] if self.return_link_multiplicity and ndisp else None, rankIm=None, rankImPrev=None, prevLabelArr=prevLabelArr[i, c + off] if self.return_label_rank else None, nextLabelArr=nextLabelArr[i, c + off] if self.return_label_rank and ndisp else None, center_distance_mode=self.center_distance_mode)
|
|
345
|
+
if return_next:
|
|
346
|
+
for c in range(frame_window-1, 2*(frame_window-1)):
|
|
347
|
+
sel = [frame_window, c+3]
|
|
348
|
+
l_c = [labels_and_centers[(i, s)] for s in sel]
|
|
349
|
+
o_s = [object_slices[(i, s)] for s in sel]
|
|
350
|
+
_compute_outputs(l_c, labelIms[i][...,sel], labels_map_prev[bidx][c + off], o_s, dyIm[i,..., c + off], dxIm[i,..., c + off], dyImNext=dyImNext[i,..., c + off] if ndisp else None, dxImNext=dxImNext[i,..., c + off] if ndisp else None, cdmIm=None, cdmImPrev=None, linkMultiplicityIm=linkMultiplicityIm[i,..., c + off] if self.return_link_multiplicity else None, linkMultiplicityImNext=linkMultiplicityImNext[i,..., c + off] if self.return_link_multiplicity and ndisp else None, rankIm=None, rankImPrev=None, prevLabelArr=prevLabelArr[i, c + off] if self.return_label_rank else None, nextLabelArr=nextLabelArr[i, c + off] if self.return_label_rank and ndisp else None, center_distance_mode=self.center_distance_mode)
|
|
351
|
+
else:
|
|
352
|
+
l_c = [labels_and_centers[(i, 0)]]
|
|
353
|
+
o_s = [object_slices[(i, 0)]]
|
|
354
|
+
_compute_outputs(l_c, labelIms[i][...,0:1], None, o_s, None, None, cdmIm=centerIm[i,...,0], edmIm = edm[i,...,0] if self.scale_edm else None, scale_edm=self.scale_edm, categoryIm=categoryIm[i,...,0] if self.category_array_idx>0 else None, categoryArray=cat_array[bidx, :, frame_window] if self.category_array_idx>0 else None, rankIm=rankIm[i,...,0] if self.return_label_rank else None, centerArr=centerArr[i,0] if self.return_label_rank else None, center_distance_mode=self.center_distance_mode)
|
|
322
355
|
|
|
323
356
|
edm[edm == 0] = -1
|
|
324
357
|
if self.return_edm_derivatives:
|
|
325
358
|
der_y, der_x = np.zeros_like(edm), np.zeros_like(edm)
|
|
326
|
-
|
|
359
|
+
c_range = range(edm.shape[-1]) if not self.output_central_only else range(frame_window, frame_window - 1)
|
|
360
|
+
for b, c in itertools.product(range(edm.shape[0]), c_range):
|
|
327
361
|
derivatives_labelwise(edm[b, ..., c], -1, der_y[b, ..., c], der_x[b, ..., c], labelIms[b, ..., c], object_slices[(b, c)])
|
|
328
|
-
if self.
|
|
329
|
-
der_y = der_y[...,
|
|
330
|
-
der_x = der_x[...,
|
|
331
|
-
|
|
332
|
-
if self.
|
|
333
|
-
edm = edm[...,
|
|
334
|
-
centerIm = centerIm[...,
|
|
335
|
-
dyIm = dyIm[...,
|
|
336
|
-
dxIm = dxIm[...,
|
|
362
|
+
if self.output_central_only:
|
|
363
|
+
der_y = der_y[..., 1:-1]
|
|
364
|
+
der_x = der_x[..., 1:-1]
|
|
365
|
+
|
|
366
|
+
if self.output_central_only: # select only central frame for edm / center and only displacement / link multiplicity related to central frame
|
|
367
|
+
edm = edm[..., 1:-1]
|
|
368
|
+
centerIm = centerIm[..., 1:-1]
|
|
369
|
+
dyIm = dyIm[..., :1] if self.tracking else None
|
|
370
|
+
dxIm = dxIm[..., :1] if self.tracking else None
|
|
337
371
|
if self.return_link_multiplicity:
|
|
338
|
-
linkMultiplicityIm = linkMultiplicityIm[...,
|
|
339
|
-
if self.
|
|
340
|
-
categoryIm = categoryIm[...,
|
|
372
|
+
linkMultiplicityIm = linkMultiplicityIm[..., :1]
|
|
373
|
+
if self.category_array_idx>0:
|
|
374
|
+
categoryIm = categoryIm[..., 1:-1]
|
|
341
375
|
if ndisp:
|
|
342
|
-
dyImNext = dyImNext[...,
|
|
343
|
-
dxImNext = dxImNext[...,
|
|
376
|
+
dyImNext = dyImNext[..., 1:]
|
|
377
|
+
dxImNext = dxImNext[..., 1:]
|
|
344
378
|
if self.return_link_multiplicity:
|
|
345
|
-
linkMultiplicityImNext = linkMultiplicityImNext[...,
|
|
379
|
+
linkMultiplicityImNext = linkMultiplicityImNext[..., 1:]
|
|
346
380
|
if self.return_label_rank:
|
|
347
|
-
rankIm = rankIm[...,
|
|
348
|
-
centerArr = centerArr[: ,
|
|
349
|
-
prevLabelArr = prevLabelArr[:, :1]
|
|
381
|
+
rankIm = rankIm[..., 1:-1]
|
|
382
|
+
centerArr = centerArr[: , 1:-1] # B, C, N, 2
|
|
383
|
+
prevLabelArr = prevLabelArr[:, :1] if self.tracking else None
|
|
350
384
|
if ndisp:
|
|
351
385
|
nextLabelArr = nextLabelArr[:, 1:]
|
|
352
386
|
if self.return_edm_derivatives:
|
|
@@ -360,7 +394,7 @@ class DyDxIterator(TrackingIterator):
|
|
|
360
394
|
all_channels.append(centerIm)
|
|
361
395
|
downscale_factor = 1./self.downscale if self.downscale>1 else 0
|
|
362
396
|
scale = [1, downscale_factor, downscale_factor, 1]
|
|
363
|
-
if self.downscale>1:
|
|
397
|
+
if self.downscale>1 and self.tracking:
|
|
364
398
|
dyIm = rescale(dyIm, scale, anti_aliasing= False, order=0)
|
|
365
399
|
dxIm = rescale(dxIm, scale, anti_aliasing= False, order=0)
|
|
366
400
|
if ndisp:
|
|
@@ -369,11 +403,12 @@ class DyDxIterator(TrackingIterator):
|
|
|
369
403
|
if ndisp:
|
|
370
404
|
dyIm = np.concatenate([dyIm, dyImNext], -1)
|
|
371
405
|
dxIm = np.concatenate([dxIm, dxImNext], -1)
|
|
372
|
-
if self.output_float16:
|
|
406
|
+
if self.output_float16 and self.tracking:
|
|
373
407
|
dyIm = dyIm.astype('float16', copy=False)
|
|
374
408
|
dxIm = dxIm.astype('float16', copy=False)
|
|
375
|
-
|
|
376
|
-
|
|
409
|
+
if self.tracking:
|
|
410
|
+
all_channels.append(dyIm)
|
|
411
|
+
all_channels.append(dxIm)
|
|
377
412
|
if self.return_link_multiplicity:
|
|
378
413
|
if self.downscale>1:
|
|
379
414
|
linkMultiplicityIm = rescale(linkMultiplicityIm, scale, anti_aliasing= False, order=0)
|
|
@@ -382,20 +417,22 @@ class DyDxIterator(TrackingIterator):
|
|
|
382
417
|
if ndisp:
|
|
383
418
|
linkMultiplicityIm = np.concatenate([linkMultiplicityIm, linkMultiplicityImNext], -1)
|
|
384
419
|
all_channels.append(linkMultiplicityIm)
|
|
385
|
-
if self.
|
|
420
|
+
if self.category_array_idx>0:
|
|
386
421
|
if self.downscale > 1:
|
|
387
422
|
categoryIm = rescale(categoryIm, scale, anti_aliasing=False, order=0)
|
|
388
423
|
all_channels.append(categoryIm)
|
|
389
424
|
if self.return_label_rank:
|
|
425
|
+
all_channels.append(rankIm)
|
|
390
426
|
if ndisp:
|
|
391
427
|
prevLabelArr = np.concatenate([prevLabelArr, nextLabelArr], 1)
|
|
392
|
-
|
|
393
|
-
|
|
428
|
+
if self.tracking:
|
|
429
|
+
all_channels.append(prevLabelArr)
|
|
394
430
|
all_channels.append(centerArr)
|
|
395
431
|
return all_channels
|
|
396
432
|
|
|
397
433
|
def _get_input_batch(self, batch_by_channel, ref_chan_idx, aug_param_array):
|
|
398
434
|
inputs = super()._get_input_batch(batch_by_channel, ref_chan_idx, aug_param_array)
|
|
435
|
+
batch_by_channel["input_centers"] = []
|
|
399
436
|
if len(self.label_input_channels)>0 : # for each input channel compute EDM and GCDM
|
|
400
437
|
lidx = len(inputs) - len(self.label_input_channels)
|
|
401
438
|
labels = inputs[lidx:]
|
|
@@ -403,11 +440,14 @@ class DyDxIterator(TrackingIterator):
|
|
|
403
440
|
for labelIms in labels: # compute EDM and GDCM for each additional input label image
|
|
404
441
|
edm = np.zeros(shape=labelIms.shape, dtype=np.float32)
|
|
405
442
|
gdcm = np.zeros(shape=labelIms.shape, dtype=np.float32)
|
|
443
|
+
centers = dict()
|
|
444
|
+
batch_by_channel["input_centers"].append(centers)
|
|
406
445
|
for b, c in itertools.product(range(labelIms.shape[0]), range(labelIms.shape[-1])):
|
|
407
446
|
cur_labels = labelIms[b, ..., c].astype(np.int32)
|
|
408
447
|
object_slices = find_objects(cur_labels)
|
|
409
448
|
edm[b, ..., c] = edt_smooth(cur_labels, object_slices)
|
|
410
449
|
labels_and_centers = _get_labels_and_centers(cur_labels, edm[b, ..., c], "MEDOID")
|
|
450
|
+
centers[(b, c)] = labels_and_centers.values() # record for output batch
|
|
411
451
|
_draw_centers(gdcm[b,...,c], labels_and_centers, cur_labels, object_slices, "GEODESIC")
|
|
412
452
|
inputs.append(edm)
|
|
413
453
|
inputs.append(gdcm)
|
|
@@ -439,11 +479,21 @@ def _get_small_objects_at_edges_to_erase(labelIm, min_size):
|
|
|
439
479
|
|
|
440
480
|
# displacement computation utils
|
|
441
481
|
|
|
442
|
-
def _get_labels_and_centers(labelIm, edm, center_mode = "GEOMETRICAL"):
|
|
482
|
+
def _get_labels_and_centers(labelIm, edm, center_mode = "GEOMETRICAL", input_centers=None):
|
|
443
483
|
labels = np.unique(labelIm)
|
|
444
484
|
labels = [int(round(l)) for l in labels if l!=0]
|
|
445
485
|
if len(labels)==0:
|
|
446
486
|
return dict()
|
|
487
|
+
if input_centers is not None: # map input center to labels and get other centers if necessary
|
|
488
|
+
l_c = dict()
|
|
489
|
+
for center in input_centers:
|
|
490
|
+
l = labelIm[int(round(center[0])), int(round(center[1]))]
|
|
491
|
+
if l>0:
|
|
492
|
+
l_c[l] = center
|
|
493
|
+
# remove labels with known centers
|
|
494
|
+
labels = [l for l in labels if l not in l_c.keys()]
|
|
495
|
+
if len(labels) == 0: # no unknown centers
|
|
496
|
+
return l_c
|
|
447
497
|
if center_mode == "GEOMETRICAL":
|
|
448
498
|
centers = center_of_mass(labelIm, labelIm, labels)
|
|
449
499
|
elif center_mode == "EDM_MAX":
|
|
@@ -465,7 +515,12 @@ def _get_labels_and_centers(labelIm, edm, center_mode = "GEOMETRICAL"):
|
|
|
465
515
|
centers = [get_medoid(*np.asarray(labelIm == l).nonzero()) for l in labels]
|
|
466
516
|
else:
|
|
467
517
|
raise ValueError(f"Invalid center mode: {center_mode}")
|
|
468
|
-
|
|
518
|
+
new_l_c = dict(zip(labels, centers))
|
|
519
|
+
if input_centers is None:
|
|
520
|
+
return new_l_c
|
|
521
|
+
else:
|
|
522
|
+
l_c.update(new_l_c)
|
|
523
|
+
return l_c
|
|
469
524
|
|
|
470
525
|
# channel dimension = frames
|
|
471
526
|
def _compute_prev_label_map(labelIm, prevlabelArray, end_points):
|
|
@@ -568,14 +623,13 @@ def _get_link_multiplicity(n_neigh):
|
|
|
568
623
|
return 2
|
|
569
624
|
|
|
570
625
|
def _compute_outputs(labels_map_centers, labelIm, labels_map_prev, object_slices, dyIm, dxIm, dyImNext=None, dxImNext=None, cdmIm=None, cdmImPrev=None, edmIm=None, edmImPrev=None, scale_edm:bool=False, categoryIm=None, categoryArray=None, categoryImPrev=None, categoryArrayPrev=None, linkMultiplicityIm=None, linkMultiplicityImNext=None, rankIm=None, rankImPrev=None, prevLabelArr=None, nextLabelArr=None, centerArr=None, centerArrPrev=None, center_distance_mode:str= "GEODESIC"):
|
|
571
|
-
assert labelIm.shape[-1] == 2, f"invalid labelIm : {labelIm.shape[-1]} channels instead of 2"
|
|
572
626
|
assert (dxImNext is None) == (dyImNext is None)
|
|
573
627
|
curLabelIm = labelIm[...,-1]
|
|
574
628
|
labels_prev = labels_map_centers[0].keys()
|
|
575
629
|
labels_prev_rank = {l:r for r, l in enumerate(labels_prev)}
|
|
576
|
-
labels_map_prev = _subset_label_map_prev(labels_map_prev, labels_prev, labels_map_centers[-1].keys())
|
|
630
|
+
labels_map_prev = _subset_label_map_prev(labels_map_prev, labels_prev, labels_map_centers[-1].keys()) if labels_map_prev is not None else None
|
|
577
631
|
for rank, (label, center) in enumerate(labels_map_centers[-1].items()):
|
|
578
|
-
label_prevs = labels_map_prev.get(label, [])
|
|
632
|
+
label_prevs = labels_map_prev.get(label, []) if labels_map_prev is not None else []
|
|
579
633
|
mask = curLabelIm == label
|
|
580
634
|
if len(label_prevs)==1:
|
|
581
635
|
label_prev = next(iter(label_prevs))
|
|
@@ -624,10 +678,10 @@ def _compute_outputs(labels_map_centers, labelIm, labels_map_prev, object_slices
|
|
|
624
678
|
if rankImPrev is not None:
|
|
625
679
|
rankImPrev[mask] = rank + 1
|
|
626
680
|
if cdmIm is not None:
|
|
627
|
-
assert cdmIm.shape ==
|
|
628
|
-
_draw_centers(cdmIm, labels_map_centers[-1],
|
|
681
|
+
assert cdmIm.shape == curLabelIm.shape, "invalid shape for center image"
|
|
682
|
+
_draw_centers(cdmIm, labels_map_centers[-1], curLabelIm, object_slices[-1], center_distance_mode=center_distance_mode)
|
|
629
683
|
if cdmImPrev is not None:
|
|
630
|
-
assert cdmImPrev.shape ==
|
|
684
|
+
assert cdmImPrev.shape == curLabelIm.shape, "invalid shape for center image prev"
|
|
631
685
|
_draw_centers(cdmImPrev, labels_map_centers[0], labelIm[...,0], object_slices[0], center_distance_mode=center_distance_mode)
|
|
632
686
|
if centerArr is not None:
|
|
633
687
|
for rank, (label, center) in enumerate(labels_map_centers[-1].items()):
|
|
@@ -3,14 +3,14 @@ import numpy as np
|
|
|
3
3
|
from random import getrandbits, uniform, choice
|
|
4
4
|
import scipy.ndimage as ndi
|
|
5
5
|
import tensorflow as tf
|
|
6
|
-
def get_swim1d_function(mask_channels:list, distance:int=50, min_gap:int=3, closed_end:bool = True):
|
|
6
|
+
def get_swim1d_function(mask_channels:list, ref_mask_idx:int=0, distance:int=50, min_gap:int=3, closed_end:bool = True):
|
|
7
7
|
if not isinstance(mask_channels, (list, tuple)):
|
|
8
8
|
mask_channels = [mask_channels]
|
|
9
9
|
assert len(mask_channels)>=1, "at least one mask channel must be provided"
|
|
10
10
|
def fun(batch_by_channel):
|
|
11
11
|
if distance > 1:
|
|
12
12
|
channels = [c for c in batch_by_channel.keys() if not isinstance(c, str) and c>=0]
|
|
13
|
-
mask_batch = batch_by_channel[mask_channels[
|
|
13
|
+
mask_batch = batch_by_channel[mask_channels[ref_mask_idx]]
|
|
14
14
|
for b, c in itertools.product(range(mask_batch.shape[0]), range(mask_batch.shape[-1])):
|
|
15
15
|
mask_img = mask_batch[b,...,c]
|
|
16
16
|
# get y space between bacteria
|