DiSTNet2D 0.2.0__tar.gz → 0.2.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. {distnet2d-0.2.0 → distnet2d-0.2.2}/DiSTNet2D.egg-info/PKG-INFO +4 -4
  2. {distnet2d-0.2.0 → distnet2d-0.2.2}/DiSTNet2D.egg-info/requires.txt +1 -1
  3. {distnet2d-0.2.0 → distnet2d-0.2.2}/PKG-INFO +4 -4
  4. {distnet2d-0.2.0 → distnet2d-0.2.2}/distnet_2d/data/dydx_iterator.py +173 -119
  5. {distnet2d-0.2.0 → distnet2d-0.2.2}/distnet_2d/data/swim1d.py +2 -2
  6. {distnet2d-0.2.0 → distnet2d-0.2.2}/distnet_2d/model/architectures.py +47 -19
  7. {distnet2d-0.2.0 → distnet2d-0.2.2}/distnet_2d/model/distnet_2d.py +187 -152
  8. {distnet2d-0.2.0 → distnet2d-0.2.2}/distnet_2d/model/distnet_2d_seg.py +16 -34
  9. {distnet2d-0.2.0 → distnet2d-0.2.2}/distnet_2d/model/layers.py +3 -3
  10. distnet2d-0.2.2/distnet_2d/model/spatial_attention.py +202 -0
  11. distnet2d-0.2.2/distnet_2d/utils/helpers.py +204 -0
  12. {distnet2d-0.2.0 → distnet2d-0.2.2}/distnet_2d/utils/losses.py +6 -6
  13. distnet2d-0.2.2/distnet_2d/utils/metrics_tf.py +134 -0
  14. {distnet2d-0.2.0 → distnet2d-0.2.2}/distnet_2d/utils/objectwise_computation_tf.py +48 -21
  15. {distnet2d-0.2.0 → distnet2d-0.2.2}/setup.py +4 -4
  16. distnet2d-0.2.0/distnet_2d/model/spatial_attention.py +0 -93
  17. distnet2d-0.2.0/distnet_2d/utils/helpers.py +0 -86
  18. distnet2d-0.2.0/distnet_2d/utils/metrics_tf.py +0 -89
  19. {distnet2d-0.2.0 → distnet2d-0.2.2}/DiSTNet2D.egg-info/SOURCES.txt +0 -0
  20. {distnet2d-0.2.0 → distnet2d-0.2.2}/DiSTNet2D.egg-info/dependency_links.txt +0 -0
  21. {distnet2d-0.2.0 → distnet2d-0.2.2}/DiSTNet2D.egg-info/top_level.txt +0 -0
  22. {distnet2d-0.2.0 → distnet2d-0.2.2}/LICENSE.txt +0 -0
  23. {distnet2d-0.2.0 → distnet2d-0.2.2}/README.md +0 -0
  24. {distnet2d-0.2.0 → distnet2d-0.2.2}/distnet_2d/__init__.py +0 -0
  25. {distnet2d-0.2.0 → distnet2d-0.2.2}/distnet_2d/data/__init__.py +0 -0
  26. {distnet2d-0.2.0 → distnet2d-0.2.2}/distnet_2d/data/center_edm.py +0 -0
  27. {distnet2d-0.2.0 → distnet2d-0.2.2}/distnet_2d/data/medoid.py +0 -0
  28. {distnet2d-0.2.0 → distnet2d-0.2.2}/distnet_2d/model/__init__.py +0 -0
  29. {distnet2d-0.2.0 → distnet2d-0.2.2}/distnet_2d/model/gradient_accumulator.py +0 -0
  30. {distnet2d-0.2.0 → distnet2d-0.2.2}/distnet_2d/utils/__init__.py +0 -0
  31. {distnet2d-0.2.0 → distnet2d-0.2.2}/distnet_2d/utils/agc.py +0 -0
  32. {distnet2d-0.2.0 → distnet2d-0.2.2}/distnet_2d/utils/image_derivatives_np.py +0 -0
  33. {distnet2d-0.2.0 → distnet2d-0.2.2}/distnet_2d/utils/image_derivatives_tf.py +0 -0
  34. {distnet2d-0.2.0 → distnet2d-0.2.2}/distnet_2d/utils/lovasz_loss.py +0 -0
  35. {distnet2d-0.2.0 → distnet2d-0.2.2}/setup.cfg +0 -0
@@ -1,11 +1,11 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: DiSTNet2D
3
- Version: 0.2.0
3
+ Version: 0.2.2
4
4
  Summary: tensorflow/keras implementation of DiSTNet 2D
5
5
  Home-page: https://github.com/jeanollion/distnet2d
6
- Download-URL: https://github.com/jeanollion/distnet2d/releases/download/v0.1.9/distnet2d-0.1.9.tar.gz
6
+ Download-URL: https://github.com/jeanollion/distnet2d/releases/download/v0.2.2/distnet2d-0.2.2.tar.gz
7
7
  Author: Jean Ollion
8
- Author-email: jean.ollion@polytechnique.org
8
+ Author-email: jean.ollion@sabilab.fr
9
9
  Keywords: Segmentation,Tracking,Cell,Tensorflow,Keras
10
10
  Classifier: Development Status :: 4 - Beta
11
11
  Classifier: Intended Audience :: Science/Research
@@ -22,7 +22,7 @@ Requires-Dist: tensorflow>=2.7.1
22
22
  Requires-Dist: edt>=2.0.2
23
23
  Requires-Dist: scikit-fmm
24
24
  Requires-Dist: numba
25
- Requires-Dist: dataset_iterator>=0.5.3
25
+ Requires-Dist: dataset_iterator>=0.5.5
26
26
  Requires-Dist: elasticdeform>=0.4.7
27
27
  Dynamic: author
28
28
  Dynamic: author-email
@@ -4,5 +4,5 @@ tensorflow>=2.7.1
4
4
  edt>=2.0.2
5
5
  scikit-fmm
6
6
  numba
7
- dataset_iterator>=0.5.3
7
+ dataset_iterator>=0.5.5
8
8
  elasticdeform>=0.4.7
@@ -1,11 +1,11 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: DiSTNet2D
3
- Version: 0.2.0
3
+ Version: 0.2.2
4
4
  Summary: tensorflow/keras implementation of DiSTNet 2D
5
5
  Home-page: https://github.com/jeanollion/distnet2d
6
- Download-URL: https://github.com/jeanollion/distnet2d/releases/download/v0.1.9/distnet2d-0.1.9.tar.gz
6
+ Download-URL: https://github.com/jeanollion/distnet2d/releases/download/v0.2.2/distnet2d-0.2.2.tar.gz
7
7
  Author: Jean Ollion
8
- Author-email: jean.ollion@polytechnique.org
8
+ Author-email: jean.ollion@sabilab.fr
9
9
  Keywords: Segmentation,Tracking,Cell,Tensorflow,Keras
10
10
  Classifier: Development Status :: 4 - Beta
11
11
  Classifier: Intended Audience :: Science/Research
@@ -22,7 +22,7 @@ Requires-Dist: tensorflow>=2.7.1
22
22
  Requires-Dist: edt>=2.0.2
23
23
  Requires-Dist: scikit-fmm
24
24
  Requires-Dist: numba
25
- Requires-Dist: dataset_iterator>=0.5.3
25
+ Requires-Dist: dataset_iterator>=0.5.5
26
26
  Requires-Dist: elasticdeform>=0.4.7
27
27
  Dynamic: author
28
28
  Dynamic: author-email
@@ -30,30 +30,43 @@ class DyDxIterator(TrackingIterator):
30
30
  frame_window:int,
31
31
  aug_frame_subsampling, # either int: frame number interval will be drawn uniformly in [frame_window,aug_frame_subsampling] or callable that generate an frame number interval (int)
32
32
  erase_edge_cell_size:int,
33
- next:bool = True,
33
+ next_frames:bool = True,
34
34
  allow_frame_subsampling_direct_neigh:bool = False,
35
35
  aug_remove_prob: float = 0.005,
36
36
  return_link_multiplicity:bool = True,
37
37
  channel_keywords:list=CHANNEL_KEYWORDS, # channel @1 must be label
38
- input_label_keywords:list = None, # additional labels that will be considered as input to the neural network
39
- array_keywords:list=ARRAY_KEYWORDS[:1], # if second array : category
38
+ input_label_keywords:list = None, # additional labels that will be considered as input to the neural network
39
+ array_keywords:list=ARRAY_KEYWORDS[:1], # if second array : category
40
40
  elasticdeform_parameters:dict = None,
41
41
  downscale_displacement_and_link_multiplicity=1,
42
42
  return_edm_derivatives: bool = False,
43
43
  return_center:bool = True,
44
- scale_edm:bool = False, # for each cell max edm value is 1
44
+ scale_edm:bool = False, # for each cell max edm value is 1
45
45
  center_mode:str = "MEDOID", # GEOMETRICAL, "EDM_MAX", "EDM_MEAN", "SKELETON", "MEDOID"
46
- center_distance_mode = "GEODESIC", # GEODESIC, EUCLIDEAN
46
+ center_distance_mode = "GEODESIC", # GEODESIC, EUCLIDEAN
47
+ input_label_center_idx:int = -1, # center of target label are centers of the input label of this index
47
48
  return_label_rank = False,
48
49
  long_term:bool = True,
50
+ tracking:bool = True,
49
51
  return_next_displacement:bool = True,
50
52
  output_float16=False,
51
53
  **kwargs):
52
54
  assert len(channel_keywords)>=2, 'keyword should contain at least 2 elements in this order: grayscale input images, object labels, [other grayscale input images]'
53
- assert 2 >= len(array_keywords) >= 1, 'array keyword first element should be links to previous objects. if 2 elements: second must be cateogry'
55
+ if frame_window == 0:
56
+ tracking = False
57
+ if tracking:
58
+ assert 2 >= len(array_keywords) >= 1, 'array keyword first element should be links to previous objects. if 2 elements: second must be category'
59
+ else:
60
+ assert 1 >= len(array_keywords) >= 0, 'only one array keyword allow (category)'
61
+
54
62
  assert center_mode.upper() in CENTER_MODE, f"invalid center mode: {center_mode} should be in {CENTER_MODE}"
55
63
  assert center_distance_mode.upper() in CENTER_DISTANCE_MODE, f"invalid center distance mode: {center_distance_mode} should be in {CENTER_DISTANCE_MODE}"
56
- self.return_category = len(array_keywords)>1
64
+ self.category_array_idx = next((i for i, s in enumerate(array_keywords) if "category" in s), -1)
65
+ self.tracking = tracking
66
+ if not tracking:
67
+ return_link_multiplicity = False
68
+ return_next_displacement = False
69
+ long_term = False
57
70
  self.return_link_multiplicity=return_link_multiplicity
58
71
  self.downscale=downscale_displacement_and_link_multiplicity
59
72
  self.erase_edge_cell_size=erase_edge_cell_size
@@ -65,13 +78,15 @@ class DyDxIterator(TrackingIterator):
65
78
  self.return_center=return_center
66
79
  self.center_mode=center_mode.upper()
67
80
  self.center_distance_mode = center_distance_mode.upper()
81
+ if input_label_center_idx >= 0:
82
+ assert input_label_center_idx < len(input_label_keywords), f"invalid input_label_center_idx={input_label_center_idx} must be < len(input_label_keywords)={len(input_label_keywords)}"
83
+ self.input_label_center_idx = input_label_center_idx
68
84
  self.return_label_rank=return_label_rank
69
- assert frame_window>=1, "frame_window must be >=1"
70
85
  self.frame_window = frame_window
71
86
  self.return_next_displacement=return_next_displacement
72
87
  self.n_label_max = kwargs.pop("n_label_max", 2000)
73
- self.long_term=long_term if self.frame_window>1 else False
74
- self.return_central_only = False
88
+ self.long_term=long_term
89
+ self.output_central_only = False
75
90
  nchan = len(channel_keywords)
76
91
  if input_label_keywords is not None:
77
92
  if not isinstance(input_label_keywords, list):
@@ -85,20 +100,20 @@ class DyDxIterator(TrackingIterator):
85
100
  else:
86
101
  self.label_input_channels = []
87
102
  super().__init__(dataset=dataset,
88
- channel_keywords=channel_keywords,
89
- array_keywords = array_keywords,
90
- input_channels=[0] + [i for i in range(2, nchan)] + self.label_input_channels,
91
- output_channels=[1],
92
- channels_prev=[True]*len(channel_keywords),
93
- channels_next=[next]*len(channel_keywords),
94
- mask_channels=[1] + self.label_input_channels,
95
- n_frames = self.frame_window,
96
- aug_remove_prob=aug_remove_prob,
97
- aug_all_frames=False,
98
- convert_masks_to_dtype=False,
99
- extract_tile_function=extract_tile_function,
100
- elasticdeform_parameters=elasticdeform_parameters,
101
- **kwargs)
103
+ channel_keywords=channel_keywords,
104
+ array_keywords = array_keywords,
105
+ input_channels=[0] + [i for i in range(2, nchan)] + self.label_input_channels,
106
+ output_channels=[1],
107
+ channels_prev=[True]*len(channel_keywords),
108
+ channels_next=[next_frames] * len(channel_keywords),
109
+ mask_channels=[1] + self.label_input_channels,
110
+ n_frames = self.frame_window,
111
+ aug_remove_prob=aug_remove_prob,
112
+ aug_all_frames=False,
113
+ convert_masks_to_dtype=False,
114
+ extract_tile_function=extract_tile_function,
115
+ elasticdeform_parameters=elasticdeform_parameters,
116
+ **kwargs)
102
117
 
103
118
  def disable_random_transforms(self, data_augmentation:bool=True, channels_postprocessing:bool=False):
104
119
  params = super().disable_random_transforms(data_augmentation, channels_postprocessing)
@@ -112,18 +127,21 @@ class DyDxIterator(TrackingIterator):
112
127
  self.aug_frame_subsampling = parameters["aug_frame_subsampling"]
113
128
 
114
129
  def _get_batch_by_channel(self, index_array, perform_augmentation, input_only=False, perform_elasticdeform=True, perform_tiling=True, **kwargs):
115
- if self.aug_remove_prob>0 and random() < self.aug_remove_prob:
116
- n_frames = 0 # flag that aug_remove = true
117
- else:
118
- if self.aug_frame_subsampling is not None :
119
- if callable(self.aug_frame_subsampling):
120
- n_frames = self.aug_frame_subsampling()
121
- elif self.aug_frame_subsampling > self.frame_window:
122
- n_frames = max(self.frame_window, np.random.randint(self.aug_frame_subsampling))
130
+ if self.frame_window > 0:
131
+ if self.aug_remove_prob>0 and random() < self.aug_remove_prob:
132
+ n_frames = 0 # flag that aug_remove = true
133
+ else:
134
+ if self.aug_frame_subsampling is not None :
135
+ if callable(self.aug_frame_subsampling):
136
+ n_frames = self.aug_frame_subsampling()
137
+ elif self.aug_frame_subsampling > self.frame_window:
138
+ n_frames = max(self.frame_window, np.random.randint(self.aug_frame_subsampling))
139
+ else:
140
+ n_frames = self.frame_window
123
141
  else:
124
142
  n_frames = self.frame_window
125
- else:
126
- n_frames = self.frame_window
143
+ else:
144
+ n_frames = 0
127
145
  kwargs.update({"n_frames":n_frames})
128
146
  batch_by_channel, aug_param_array, ref_channel = super()._get_batch_by_channel(index_array, perform_augmentation, input_only, perform_elasticdeform=False, perform_tiling=False, **kwargs)
129
147
  ref_shape = batch_by_channel[0].shape
@@ -132,25 +150,26 @@ class DyDxIterator(TrackingIterator):
132
150
  if not issubclass(batch_by_channel[1].dtype.type, np.integer): # label
133
151
  batch_by_channel[1] = batch_by_channel[1].astype(np.int32)
134
152
  # correction for oob @ previous labels : add identity links
135
- prevLinks = batch_by_channel['arrays'][0]
136
- for b in range(prevLinks.shape[0]):
137
- if n_frames > 0:
138
- for i in range(n_frames):
139
- inc = n_frames - i
140
- prev_inc = aug_param_array[b][ref_channel].get(f"prev_inc_{inc}", inc)
141
- if prev_inc!=inc:
142
- #print(f"oob prev: batch: {b} n_frames={n_frames}, frame_idx:{i} inc={inc} actual inc:{prev_inc} will replace at {i+1}")
143
- self._set_identity_link(prevLinks, b, i+1)
144
- if self.channels_next[1]:
145
- next_inc = aug_param_array[b][ref_channel].get(f"next_inc_{inc}", inc)
146
- if next_inc!=inc:
147
- #print(f"oob next: batch: {b} n_frames={n_frames}, frame_idx:{i} inc={inc} actual inc:{next_inc} will replace prev labels at {n_frames+inc}")
148
- self._set_identity_link(prevLinks, b, n_frames + inc)
149
- else: # n_frame == 0 ->
150
- for c in range(1, prevLinks.shape[-1]):
151
- self._set_identity_link(prevLinks, b, c)
152
- # get previous labels and store in batch_by_channel BEFORE applying tiling and elastic deform
153
- self._get_prev_label(batch_by_channel, n_frames)
153
+ if self.tracking:
154
+ prevLinks = batch_by_channel['arrays'][0]
155
+ for b in range(prevLinks.shape[0]):
156
+ if n_frames > 0:
157
+ for i in range(n_frames):
158
+ inc = n_frames - i
159
+ prev_inc = aug_param_array[b][ref_channel].get(f"prev_inc_{inc}", inc)
160
+ if prev_inc!=inc:
161
+ #print(f"oob prev: batch: {b} n_frames={n_frames}, frame_idx:{i} inc={inc} actual inc:{prev_inc} will replace at {i+1}")
162
+ self._set_identity_link(prevLinks, b, i+1)
163
+ if self.channels_next[1]:
164
+ next_inc = aug_param_array[b][ref_channel].get(f"next_inc_{inc}", inc)
165
+ if next_inc!=inc:
166
+ #print(f"oob next: batch: {b} n_frames={n_frames}, frame_idx:{i} inc={inc} actual inc:{next_inc} will replace prev labels at {n_frames+inc}")
167
+ self._set_identity_link(prevLinks, b, n_frames + inc)
168
+ else: # n_frame == 0 ->
169
+ for c in range(1, prevLinks.shape[-1]):
170
+ self._set_identity_link(prevLinks, b, c)
171
+ # get previous labels and store in batch_by_channel BEFORE applying tiling and elastic deform
172
+ self._get_prev_label(batch_by_channel, n_frames)
154
173
  batch_by_channel["batch_size"] = batch_by_channel[0].shape[0] # batch size is recorded here: it will be used in case of tiling
155
174
  if n_frames>1: # remove unused frames
156
175
  sel = self._get_end_points(n_frames, False)
@@ -168,7 +187,7 @@ class DyDxIterator(TrackingIterator):
168
187
  if perform_elasticdeform:
169
188
  self._apply_elasticdeform(batch_by_channel)
170
189
  if perform_tiling:
171
- self._apply_tiling(batch_by_channel)
190
+ self._apply_tiling(batch_by_channel, index_array)
172
191
  if perform_elasticdeform or perform_tiling:
173
192
  for c in converted_from_float16:
174
193
  batch_by_channel[c] = batch_by_channel[c].astype('float16')
@@ -181,7 +200,7 @@ class DyDxIterator(TrackingIterator):
181
200
  prev_label[b, :, 1, c] = prev_label[b, :, 0, c]
182
201
 
183
202
  def _get_frames_to_augment(self, img, chan_idx, aug_params):
184
- if self.aug_all_frames:
203
+ if self.aug_all_frames or self.frame_window == 0:
185
204
  return list(range(img.shape[-1]))
186
205
  n_frames = (img.shape[-1]-1)//2 if self.channels_prev[chan_idx] and self.channels_next[chan_idx] else img.shape[-1]-1
187
206
  return self._get_end_points(n_frames, False)
@@ -224,17 +243,24 @@ class DyDxIterator(TrackingIterator):
224
243
  def _get_output_batch(self, batch_by_channel, ref_chan_idx, aug_param_array): # compute edm, center, dx & dy edm, link_multiplicity
225
244
  ndisp = self.return_next_displacement
226
245
  labelIms = batch_by_channel[1]
227
- labels_map_prev = batch_by_channel["prev_label_map"]
246
+ labels_map_prev = batch_by_channel["prev_label_map"] if self.tracking else None
228
247
  return_next = self.channels_next[1]
229
248
  long_term = self.long_term
230
249
  frame_window = self.frame_window
231
- if self.return_central_only:
250
+ if self.output_central_only and self.frame_window > 0:
232
251
  assert self.channels_prev[1], "in return_central_only mode previous must be returned"
233
252
  assert return_next, "in return_central_only mode next must be returned"
234
- labelIms = labelIms[..., self.frame_window-1:self.frame_window+2] # only prev, central, and next frame
253
+ for midx in self.mask_channels:
254
+ batch_by_channel[midx] = batch_by_channel[midx][..., self.frame_window-1:self.frame_window+2] # only prev, central, and next frame
255
+ labelIms = batch_by_channel[1]
256
+ for aidx in range(len(batch_by_channel['arrays'])):
257
+ batch_by_channel['arrays'][aidx] = batch_by_channel['arrays'][aidx][..., self.frame_window-1:self.frame_window+2]
258
+ for icidx in range(len(batch_by_channel["input_centers"])): # remap : 0 -> fw-1, 1 -> fw, 2 -> fw+1
259
+ centers = batch_by_channel["input_centers"][icidx]
260
+ batch_by_channel["input_centers"][icidx] = {(b, c-(self.frame_window-1)):center for (b, c), center in centers.items() if self.frame_window-1<=c<=self.frame_window+1 }
235
261
  long_term = False
236
262
  frame_window = 1
237
- labels_map_prev = [lmp[self.frame_window-1:self.frame_window+2] for lmp in labels_map_prev]
263
+ labels_map_prev = [lmp[self.frame_window-1:self.frame_window+2] for lmp in labels_map_prev] if self.tracking else None
238
264
  # remove small object
239
265
  mask_to_erase_cur = [chan_idx for chan_idx in self.mask_channels if chan_idx!=1 and chan_idx in batch_by_channel]
240
266
  mask_to_erase_chan_cur = [frame_window if self.channels_prev[chan_idx] else 0 for chan_idx in mask_to_erase_cur]
@@ -262,23 +288,23 @@ class DyDxIterator(TrackingIterator):
262
288
  n_motion = 2 * frame_window if return_next else frame_window
263
289
  if long_term:
264
290
  n_motion = n_motion + (2 * ( frame_window - 1 ) if return_next else frame_window -1)
265
- dyIm = np.zeros(labelIms.shape[:-1]+(n_motion,), dtype=self.dtype)
266
- dxIm = np.zeros(labelIms.shape[:-1]+(n_motion,), dtype=self.dtype)
291
+ dyIm = np.zeros(labelIms.shape[:-1]+(n_motion,), dtype=self.dtype) if self.tracking else None
292
+ dxIm = np.zeros(labelIms.shape[:-1]+(n_motion,), dtype=self.dtype) if self.tracking else None
267
293
  if ndisp:
268
294
  dyImNext = np.zeros(labelIms.shape[:-1]+(n_motion,), dtype=self.dtype)
269
295
  dxImNext = np.zeros(labelIms.shape[:-1]+(n_motion,), dtype=self.dtype)
270
296
  if self.return_link_multiplicity:
271
297
  linkMultiplicityImNext = np.zeros(labelIms.shape[:-1]+(n_motion,), dtype=self.dtype)
272
298
  centerIm = np.zeros(labelIms.shape, dtype=self.dtype) if self.return_center else None
273
- categoryIm = np.zeros(labelIms.shape, dtype=self.dtype) if self.return_category else None
274
- cat_array = batch_by_channel['arrays'][1] if self.return_category else None
299
+ categoryIm = np.zeros(labelIms.shape, dtype=self.dtype) if self.category_array_idx>0 else None
300
+ cat_array = batch_by_channel['arrays'][self.category_array_idx] if self.category_array_idx>0 else None
275
301
  if cat_array is not None and len(cat_array.shape) == 4:
276
302
  cat_array = cat_array[:, :, 0]
277
303
  if self.return_label_rank:
278
304
  rankIm = np.zeros(labelIms.shape, dtype=np.int32)
279
- prevLabelArr = np.zeros(labelIms.shape[:1]+(n_motion, self.n_label_max), dtype=np.int32)
280
- nextLabelArr = np.zeros(labelIms.shape[:1] + (n_motion, self.n_label_max), dtype=np.int32)
281
- centerArr = np.zeros(labelIms.shape[:1]+labelIms.shape[-1:]+(self.n_label_max,2), dtype=np.float32)
305
+ prevLabelArr = np.zeros(labelIms.shape[:1]+(n_motion, self.n_label_max), dtype=np.int32) if self.tracking else None
306
+ nextLabelArr = np.zeros(labelIms.shape[:1] + (n_motion, self.n_label_max), dtype=np.int32) if self.tracking else None
307
+ centerArr = np.zeros(labelIms.shape[:1]+labelIms.shape[-1:]+(self.n_label_max,2), dtype=np.float32) # B, C, N, 2
282
308
  centerArr.fill(np.nan)
283
309
  if self.return_link_multiplicity:
284
310
  linkMultiplicityIm = np.zeros(labelIms.shape[:-1]+(n_motion,), dtype=self.dtype)
@@ -290,63 +316,71 @@ class DyDxIterator(TrackingIterator):
290
316
  get_idx = lambda x:x%batch_size # We assume here that the indexing order of tiling is tile x batch
291
317
  else:
292
318
  get_idx = lambda x:x
319
+
293
320
  labels_and_centers = {}
294
321
  for b,c in itertools.product(range(labelIms.shape[0]), range(labelIms.shape[-1])):
295
- labels_and_centers[(b, c)] = _get_labels_and_centers(labelIms[b][...,c], edm[b][...,c], self.center_mode)
322
+ input_centers = batch_by_channel["input_centers"][self.input_label_center_idx][(b, c)] if self.input_label_center_idx >= 0 else None
323
+ labels_and_centers[(b, c)] = _get_labels_and_centers(labelIms[b][...,c], edm[b][...,c], self.center_mode, input_centers=input_centers)
296
324
  for i in range(labelIms.shape[0]):
297
325
  bidx = get_idx(i)
298
- for c in range(0, frame_window):
299
- sel = [c, c+1]
300
- l_c = [labels_and_centers[(i,s)] for s in sel]
301
- o_s = [object_slices[(i, s)] for s in sel]
302
- _compute_outputs(l_c, labelIms[i][...,sel], labels_map_prev[bidx][c], o_s, dyIm[i,...,c], dxIm[i,...,c], dyImNext=dyImNext[i,...,c] if ndisp else None, dxImNext=dxImNext[i,...,c] if ndisp else None, cdmIm=centerIm[i,...,frame_window] if self.return_center and sel[1] == frame_window else None, cdmImPrev=centerIm[i,...,c] if self.return_center else None, edmIm = edm[i,...,frame_window] if self.scale_edm else None, edmImPrev = edm[i,...,c] if self.scale_edm else None, scale_edm=self.scale_edm, categoryIm=categoryIm[i,...,frame_window] if self.return_category and sel[1] == frame_window else None, categoryArray=cat_array[bidx, :, frame_window] if self.return_category and sel[1] == frame_window else None, categoryImPrev=categoryIm[i,...,c] if self.return_category else None, categoryArrayPrev=cat_array[bidx, :, c] if self.return_category else None, linkMultiplicityIm=linkMultiplicityIm[i,...,c] if self.return_link_multiplicity else None, linkMultiplicityImNext=linkMultiplicityImNext[i,...,c] if self.return_link_multiplicity and ndisp else None, rankIm=rankIm[i,...,frame_window] if self.return_label_rank and sel[1] == frame_window else None, rankImPrev=rankIm[i,...,c] if self.return_label_rank else None, prevLabelArr=prevLabelArr[i,c] if self.return_label_rank else None, nextLabelArr=nextLabelArr[i,c] if self.return_label_rank and ndisp else None, centerArr=centerArr[i,frame_window] if self.return_label_rank and sel[1] == frame_window else None, centerArrPrev=centerArr[i,c] if self.return_label_rank else None, center_distance_mode=self.center_distance_mode)
303
- if return_next:
304
- for c in range(frame_window, 2*frame_window):
326
+ if self.frame_window > 0:
327
+ for c in range(0, frame_window):
305
328
  sel = [c, c+1]
306
- l_c = [labels_and_centers[(i, s)] for s in sel]
307
- o_s = [object_slices[(i, s)] for s in sel]
308
- _compute_outputs(l_c, labelIms[i][...,sel], labels_map_prev[bidx][c], o_s, dyIm[i,...,c], dxIm[i,...,c], dyImNext=dyImNext[i,...,c] if ndisp else None, dxImNext=dxImNext[i,...,c] if ndisp else None, cdmIm=centerIm[i,..., c + 1] if self.return_center else None, edmIm = edm[i,...,c+1] if self.scale_edm else None, scale_edm=self.scale_edm, categoryIm=categoryIm[i,..., c + 1] if self.return_category else None, categoryArray=cat_array[bidx, :, c+1] if self.return_category else None, cdmImPrev=None, linkMultiplicityIm=linkMultiplicityIm[i,...,c] if self.return_link_multiplicity else None, linkMultiplicityImNext=linkMultiplicityImNext[i,...,c] if self.return_link_multiplicity and ndisp else None, rankIm=rankIm[i,..., c + 1] if self.return_label_rank else None, rankImPrev=None, prevLabelArr=prevLabelArr[i,c] if self.return_label_rank else None, nextLabelArr=nextLabelArr[i,c] if self.return_label_rank and ndisp else None, centerArr=centerArr[i, c + 1] if self.return_label_rank else None, center_distance_mode=self.center_distance_mode)
309
- if long_term:
310
- off = 2*frame_window if return_next else frame_window
311
- for c in range(0, frame_window-1):
312
- sel = [c, frame_window]
313
- l_c = [labels_and_centers[(i, s)] for s in sel]
329
+ l_c = [labels_and_centers[(i,s)] for s in sel]
314
330
  o_s = [object_slices[(i, s)] for s in sel]
315
- _compute_outputs(l_c, labelIms[i][...,sel], labels_map_prev[bidx][c + off], o_s, dyIm[i,..., c + off], dxIm[i,..., c + off], dyImNext=dyImNext[i,..., c + off] if ndisp else None, dxImNext=dxImNext[i,..., c + off] if ndisp else None, cdmIm=None, cdmImPrev=None, linkMultiplicityIm=linkMultiplicityIm[i,..., c + off] if self.return_link_multiplicity else None, linkMultiplicityImNext=linkMultiplicityImNext[i,..., c + off] if self.return_link_multiplicity and ndisp else None, rankIm=None, rankImPrev=None, prevLabelArr=prevLabelArr[i, c + off] if self.return_label_rank else None, nextLabelArr=nextLabelArr[i, c + off] if self.return_label_rank and ndisp else None, center_distance_mode=self.center_distance_mode)
331
+ _compute_outputs(l_c, labelIms[i][...,sel], labels_map_prev[bidx][c] if self.tracking else None, o_s, dyIm[i,...,c] if self.tracking else None, dxIm[i,...,c] if self.tracking else None, dyImNext=dyImNext[i,...,c] if ndisp else None, dxImNext=dxImNext[i,...,c] if ndisp else None, cdmIm=centerIm[i,...,frame_window] if self.return_center and sel[1] == frame_window else None, cdmImPrev=centerIm[i,...,c] if self.return_center else None, edmIm = edm[i,...,frame_window] if self.scale_edm else None, edmImPrev = edm[i,...,c] if self.scale_edm else None, scale_edm=self.scale_edm, categoryIm=categoryIm[i,...,frame_window] if self.category_array_idx>0 and sel[1] == frame_window else None, categoryArray=cat_array[bidx, :, frame_window] if self.category_array_idx>0 and sel[1] == frame_window else None, categoryImPrev=categoryIm[i,...,c] if self.category_array_idx>0 else None, categoryArrayPrev=cat_array[bidx, :, c] if self.category_array_idx>0 else None, linkMultiplicityIm=linkMultiplicityIm[i,...,c] if self.return_link_multiplicity else None, linkMultiplicityImNext=linkMultiplicityImNext[i,...,c] if self.return_link_multiplicity and ndisp else None, rankIm=rankIm[i,...,frame_window] if self.return_label_rank and sel[1] == frame_window else None, rankImPrev=rankIm[i,...,c] if self.return_label_rank else None, prevLabelArr=prevLabelArr[i,c] if self.return_label_rank and self.tracking else None, nextLabelArr=nextLabelArr[i,c] if self.return_label_rank and self.tracking and ndisp else None, centerArr=centerArr[i,frame_window] if self.return_label_rank and sel[1] == frame_window else None, centerArrPrev=centerArr[i,c] if self.return_label_rank else None, center_distance_mode=self.center_distance_mode)
316
332
  if return_next:
317
- for c in range(frame_window-1, 2*(frame_window-1)):
318
- sel = [frame_window, c+3]
333
+ for c in range(frame_window, 2*frame_window):
334
+ sel = [c, c+1]
335
+ l_c = [labels_and_centers[(i, s)] for s in sel]
336
+ o_s = [object_slices[(i, s)] for s in sel]
337
+ _compute_outputs(l_c, labelIms[i][...,sel], labels_map_prev[bidx][c] if self.tracking else None, o_s, dyIm[i,...,c] if self.tracking else None, dxIm[i,...,c] if self.tracking else None, dyImNext=dyImNext[i,...,c] if ndisp else None, dxImNext=dxImNext[i,...,c] if ndisp else None, cdmIm=centerIm[i,..., c + 1] if self.return_center else None, edmIm = edm[i,...,c+1] if self.scale_edm else None, scale_edm=self.scale_edm, categoryIm=categoryIm[i,..., c + 1] if self.category_array_idx>0 else None, categoryArray=cat_array[bidx, :, c+1] if self.category_array_idx>0 else None, cdmImPrev=None, linkMultiplicityIm=linkMultiplicityIm[i,...,c] if self.return_link_multiplicity else None, linkMultiplicityImNext=linkMultiplicityImNext[i,...,c] if self.return_link_multiplicity and ndisp else None, rankIm=rankIm[i,..., c + 1] if self.return_label_rank else None, rankImPrev=None, prevLabelArr=prevLabelArr[i,c] if self.return_label_rank and self.tracking else None, nextLabelArr=nextLabelArr[i,c] if self.return_label_rank and ndisp else None, centerArr=centerArr[i, c + 1] if self.return_label_rank else None, center_distance_mode=self.center_distance_mode)
338
+ if long_term:
339
+ off = 2*frame_window if return_next else frame_window
340
+ for c in range(0, frame_window-1):
341
+ sel = [c, frame_window]
319
342
  l_c = [labels_and_centers[(i, s)] for s in sel]
320
343
  o_s = [object_slices[(i, s)] for s in sel]
321
344
  _compute_outputs(l_c, labelIms[i][...,sel], labels_map_prev[bidx][c + off], o_s, dyIm[i,..., c + off], dxIm[i,..., c + off], dyImNext=dyImNext[i,..., c + off] if ndisp else None, dxImNext=dxImNext[i,..., c + off] if ndisp else None, cdmIm=None, cdmImPrev=None, linkMultiplicityIm=linkMultiplicityIm[i,..., c + off] if self.return_link_multiplicity else None, linkMultiplicityImNext=linkMultiplicityImNext[i,..., c + off] if self.return_link_multiplicity and ndisp else None, rankIm=None, rankImPrev=None, prevLabelArr=prevLabelArr[i, c + off] if self.return_label_rank else None, nextLabelArr=nextLabelArr[i, c + off] if self.return_label_rank and ndisp else None, center_distance_mode=self.center_distance_mode)
345
+ if return_next:
346
+ for c in range(frame_window-1, 2*(frame_window-1)):
347
+ sel = [frame_window, c+3]
348
+ l_c = [labels_and_centers[(i, s)] for s in sel]
349
+ o_s = [object_slices[(i, s)] for s in sel]
350
+ _compute_outputs(l_c, labelIms[i][...,sel], labels_map_prev[bidx][c + off], o_s, dyIm[i,..., c + off], dxIm[i,..., c + off], dyImNext=dyImNext[i,..., c + off] if ndisp else None, dxImNext=dxImNext[i,..., c + off] if ndisp else None, cdmIm=None, cdmImPrev=None, linkMultiplicityIm=linkMultiplicityIm[i,..., c + off] if self.return_link_multiplicity else None, linkMultiplicityImNext=linkMultiplicityImNext[i,..., c + off] if self.return_link_multiplicity and ndisp else None, rankIm=None, rankImPrev=None, prevLabelArr=prevLabelArr[i, c + off] if self.return_label_rank else None, nextLabelArr=nextLabelArr[i, c + off] if self.return_label_rank and ndisp else None, center_distance_mode=self.center_distance_mode)
351
+ else:
352
+ l_c = [labels_and_centers[(i, 0)]]
353
+ o_s = [object_slices[(i, 0)]]
354
+ _compute_outputs(l_c, labelIms[i][...,0:1], None, o_s, None, None, cdmIm=centerIm[i,...,0], edmIm = edm[i,...,0] if self.scale_edm else None, scale_edm=self.scale_edm, categoryIm=categoryIm[i,...,0] if self.category_array_idx>0 else None, categoryArray=cat_array[bidx, :, frame_window] if self.category_array_idx>0 else None, rankIm=rankIm[i,...,0] if self.return_label_rank else None, centerArr=centerArr[i,0] if self.return_label_rank else None, center_distance_mode=self.center_distance_mode)
322
355
 
323
356
  edm[edm == 0] = -1
324
357
  if self.return_edm_derivatives:
325
358
  der_y, der_x = np.zeros_like(edm), np.zeros_like(edm)
326
- for b, c in itertools.product(range(edm.shape[0]), range(edm.shape[-1])):
359
+ c_range = range(edm.shape[-1]) if not self.output_central_only else range(frame_window, frame_window - 1)
360
+ for b, c in itertools.product(range(edm.shape[0]), c_range):
327
361
  derivatives_labelwise(edm[b, ..., c], -1, der_y[b, ..., c], der_x[b, ..., c], labelIms[b, ..., c], object_slices[(b, c)])
328
- if self.return_central_only:
329
- der_y = der_y[..., frame_window:frame_window+1]
330
- der_x = der_x[..., frame_window:frame_window+1]
331
-
332
- if self.return_central_only: # select only central frame for edm / center and only displacement / link multiplicity related to central frame
333
- edm = edm[..., frame_window:frame_window+1]
334
- centerIm = centerIm[..., frame_window:frame_window+1]
335
- dyIm = dyIm[..., frame_window-1:frame_window]
336
- dxIm = dxIm[..., frame_window-1:frame_window]
362
+ if self.output_central_only:
363
+ der_y = der_y[..., 1:-1]
364
+ der_x = der_x[..., 1:-1]
365
+
366
+ if self.output_central_only: # select only central frame for edm / center and only displacement / link multiplicity related to central frame
367
+ edm = edm[..., 1:-1]
368
+ centerIm = centerIm[..., 1:-1]
369
+ dyIm = dyIm[..., :1] if self.tracking else None
370
+ dxIm = dxIm[..., :1] if self.tracking else None
337
371
  if self.return_link_multiplicity:
338
- linkMultiplicityIm = linkMultiplicityIm[..., frame_window-1:frame_window]
339
- if self.return_category:
340
- categoryIm = categoryIm[..., frame_window:frame_window+1]
372
+ linkMultiplicityIm = linkMultiplicityIm[..., :1]
373
+ if self.category_array_idx>0:
374
+ categoryIm = categoryIm[..., 1:-1]
341
375
  if ndisp:
342
- dyImNext = dyImNext[..., frame_window-1:frame_window]
343
- dxImNext = dxImNext[..., frame_window-1:frame_window]
376
+ dyImNext = dyImNext[..., 1:]
377
+ dxImNext = dxImNext[..., 1:]
344
378
  if self.return_link_multiplicity:
345
- linkMultiplicityImNext = linkMultiplicityImNext[..., frame_window-1:frame_window]
379
+ linkMultiplicityImNext = linkMultiplicityImNext[..., 1:]
346
380
  if self.return_label_rank:
347
- rankIm = rankIm[..., frame_window:frame_window+1]
348
- centerArr = centerArr[: , frame_window:frame_window+1]
349
- prevLabelArr = prevLabelArr[:, :1]
381
+ rankIm = rankIm[..., 1:-1]
382
+ centerArr = centerArr[: , 1:-1] # B, C, N, 2
383
+ prevLabelArr = prevLabelArr[:, :1] if self.tracking else None
350
384
  if ndisp:
351
385
  nextLabelArr = nextLabelArr[:, 1:]
352
386
  if self.return_edm_derivatives:
@@ -360,7 +394,7 @@ class DyDxIterator(TrackingIterator):
360
394
  all_channels.append(centerIm)
361
395
  downscale_factor = 1./self.downscale if self.downscale>1 else 0
362
396
  scale = [1, downscale_factor, downscale_factor, 1]
363
- if self.downscale>1:
397
+ if self.downscale>1 and self.tracking:
364
398
  dyIm = rescale(dyIm, scale, anti_aliasing= False, order=0)
365
399
  dxIm = rescale(dxIm, scale, anti_aliasing= False, order=0)
366
400
  if ndisp:
@@ -369,11 +403,12 @@ class DyDxIterator(TrackingIterator):
369
403
  if ndisp:
370
404
  dyIm = np.concatenate([dyIm, dyImNext], -1)
371
405
  dxIm = np.concatenate([dxIm, dxImNext], -1)
372
- if self.output_float16:
406
+ if self.output_float16 and self.tracking:
373
407
  dyIm = dyIm.astype('float16', copy=False)
374
408
  dxIm = dxIm.astype('float16', copy=False)
375
- all_channels.append(dyIm)
376
- all_channels.append(dxIm)
409
+ if self.tracking:
410
+ all_channels.append(dyIm)
411
+ all_channels.append(dxIm)
377
412
  if self.return_link_multiplicity:
378
413
  if self.downscale>1:
379
414
  linkMultiplicityIm = rescale(linkMultiplicityIm, scale, anti_aliasing= False, order=0)
@@ -382,20 +417,22 @@ class DyDxIterator(TrackingIterator):
382
417
  if ndisp:
383
418
  linkMultiplicityIm = np.concatenate([linkMultiplicityIm, linkMultiplicityImNext], -1)
384
419
  all_channels.append(linkMultiplicityIm)
385
- if self.return_category:
420
+ if self.category_array_idx>0:
386
421
  if self.downscale > 1:
387
422
  categoryIm = rescale(categoryIm, scale, anti_aliasing=False, order=0)
388
423
  all_channels.append(categoryIm)
389
424
  if self.return_label_rank:
425
+ all_channels.append(rankIm)
390
426
  if ndisp:
391
427
  prevLabelArr = np.concatenate([prevLabelArr, nextLabelArr], 1)
392
- all_channels.append(rankIm)
393
- all_channels.append(prevLabelArr)
428
+ if self.tracking:
429
+ all_channels.append(prevLabelArr)
394
430
  all_channels.append(centerArr)
395
431
  return all_channels
396
432
 
397
433
  def _get_input_batch(self, batch_by_channel, ref_chan_idx, aug_param_array):
398
434
  inputs = super()._get_input_batch(batch_by_channel, ref_chan_idx, aug_param_array)
435
+ batch_by_channel["input_centers"] = []
399
436
  if len(self.label_input_channels)>0 : # for each input channel compute EDM and GCDM
400
437
  lidx = len(inputs) - len(self.label_input_channels)
401
438
  labels = inputs[lidx:]
@@ -403,11 +440,14 @@ class DyDxIterator(TrackingIterator):
403
440
  for labelIms in labels: # compute EDM and GDCM for each additional input label image
404
441
  edm = np.zeros(shape=labelIms.shape, dtype=np.float32)
405
442
  gdcm = np.zeros(shape=labelIms.shape, dtype=np.float32)
443
+ centers = dict()
444
+ batch_by_channel["input_centers"].append(centers)
406
445
  for b, c in itertools.product(range(labelIms.shape[0]), range(labelIms.shape[-1])):
407
446
  cur_labels = labelIms[b, ..., c].astype(np.int32)
408
447
  object_slices = find_objects(cur_labels)
409
448
  edm[b, ..., c] = edt_smooth(cur_labels, object_slices)
410
449
  labels_and_centers = _get_labels_and_centers(cur_labels, edm[b, ..., c], "MEDOID")
450
+ centers[(b, c)] = labels_and_centers.values() # record for output batch
411
451
  _draw_centers(gdcm[b,...,c], labels_and_centers, cur_labels, object_slices, "GEODESIC")
412
452
  inputs.append(edm)
413
453
  inputs.append(gdcm)
@@ -439,11 +479,21 @@ def _get_small_objects_at_edges_to_erase(labelIm, min_size):
439
479
 
440
480
  # displacement computation utils
441
481
 
442
- def _get_labels_and_centers(labelIm, edm, center_mode = "GEOMETRICAL"):
482
+ def _get_labels_and_centers(labelIm, edm, center_mode = "GEOMETRICAL", input_centers=None):
443
483
  labels = np.unique(labelIm)
444
484
  labels = [int(round(l)) for l in labels if l!=0]
445
485
  if len(labels)==0:
446
486
  return dict()
487
+ if input_centers is not None: # map input center to labels and get other centers if necessary
488
+ l_c = dict()
489
+ for center in input_centers:
490
+ l = labelIm[int(round(center[0])), int(round(center[1]))]
491
+ if l>0:
492
+ l_c[l] = center
493
+ # remove labels with known centers
494
+ labels = [l for l in labels if l not in l_c.keys()]
495
+ if len(labels) == 0: # no unknown centers
496
+ return l_c
447
497
  if center_mode == "GEOMETRICAL":
448
498
  centers = center_of_mass(labelIm, labelIm, labels)
449
499
  elif center_mode == "EDM_MAX":
@@ -465,7 +515,12 @@ def _get_labels_and_centers(labelIm, edm, center_mode = "GEOMETRICAL"):
465
515
  centers = [get_medoid(*np.asarray(labelIm == l).nonzero()) for l in labels]
466
516
  else:
467
517
  raise ValueError(f"Invalid center mode: {center_mode}")
468
- return dict(zip(labels, centers))
518
+ new_l_c = dict(zip(labels, centers))
519
+ if input_centers is None:
520
+ return new_l_c
521
+ else:
522
+ l_c.update(new_l_c)
523
+ return l_c
469
524
 
470
525
  # channel dimension = frames
471
526
  def _compute_prev_label_map(labelIm, prevlabelArray, end_points):
@@ -568,14 +623,13 @@ def _get_link_multiplicity(n_neigh):
568
623
  return 2
569
624
 
570
625
  def _compute_outputs(labels_map_centers, labelIm, labels_map_prev, object_slices, dyIm, dxIm, dyImNext=None, dxImNext=None, cdmIm=None, cdmImPrev=None, edmIm=None, edmImPrev=None, scale_edm:bool=False, categoryIm=None, categoryArray=None, categoryImPrev=None, categoryArrayPrev=None, linkMultiplicityIm=None, linkMultiplicityImNext=None, rankIm=None, rankImPrev=None, prevLabelArr=None, nextLabelArr=None, centerArr=None, centerArrPrev=None, center_distance_mode:str= "GEODESIC"):
571
- assert labelIm.shape[-1] == 2, f"invalid labelIm : {labelIm.shape[-1]} channels instead of 2"
572
626
  assert (dxImNext is None) == (dyImNext is None)
573
627
  curLabelIm = labelIm[...,-1]
574
628
  labels_prev = labels_map_centers[0].keys()
575
629
  labels_prev_rank = {l:r for r, l in enumerate(labels_prev)}
576
- labels_map_prev = _subset_label_map_prev(labels_map_prev, labels_prev, labels_map_centers[-1].keys())
630
+ labels_map_prev = _subset_label_map_prev(labels_map_prev, labels_prev, labels_map_centers[-1].keys()) if labels_map_prev is not None else None
577
631
  for rank, (label, center) in enumerate(labels_map_centers[-1].items()):
578
- label_prevs = labels_map_prev.get(label, [])
632
+ label_prevs = labels_map_prev.get(label, []) if labels_map_prev is not None else []
579
633
  mask = curLabelIm == label
580
634
  if len(label_prevs)==1:
581
635
  label_prev = next(iter(label_prevs))
@@ -624,10 +678,10 @@ def _compute_outputs(labels_map_centers, labelIm, labels_map_prev, object_slices
624
678
  if rankImPrev is not None:
625
679
  rankImPrev[mask] = rank + 1
626
680
  if cdmIm is not None:
627
- assert cdmIm.shape == dyIm.shape, "invalid shape for center image"
628
- _draw_centers(cdmIm, labels_map_centers[-1], labelIm[...,1], object_slices[1], center_distance_mode=center_distance_mode)
681
+ assert cdmIm.shape == curLabelIm.shape, "invalid shape for center image"
682
+ _draw_centers(cdmIm, labels_map_centers[-1], curLabelIm, object_slices[-1], center_distance_mode=center_distance_mode)
629
683
  if cdmImPrev is not None:
630
- assert cdmImPrev.shape == dyIm.shape, "invalid shape for center image prev"
684
+ assert cdmImPrev.shape == curLabelIm.shape, "invalid shape for center image prev"
631
685
  _draw_centers(cdmImPrev, labels_map_centers[0], labelIm[...,0], object_slices[0], center_distance_mode=center_distance_mode)
632
686
  if centerArr is not None:
633
687
  for rank, (label, center) in enumerate(labels_map_centers[-1].items()):
@@ -3,14 +3,14 @@ import numpy as np
3
3
  from random import getrandbits, uniform, choice
4
4
  import scipy.ndimage as ndi
5
5
  import tensorflow as tf
6
- def get_swim1d_function(mask_channels:list, distance:int=50, min_gap:int=3, closed_end:bool = True):
6
+ def get_swim1d_function(mask_channels:list, ref_mask_idx:int=0, distance:int=50, min_gap:int=3, closed_end:bool = True):
7
7
  if not isinstance(mask_channels, (list, tuple)):
8
8
  mask_channels = [mask_channels]
9
9
  assert len(mask_channels)>=1, "at least one mask channel must be provided"
10
10
  def fun(batch_by_channel):
11
11
  if distance > 1:
12
12
  channels = [c for c in batch_by_channel.keys() if not isinstance(c, str) and c>=0]
13
- mask_batch = batch_by_channel[mask_channels[0]]
13
+ mask_batch = batch_by_channel[mask_channels[ref_mask_idx]]
14
14
  for b, c in itertools.product(range(mask_batch.shape[0]), range(mask_batch.shape[-1])):
15
15
  mask_img = mask_batch[b,...,c]
16
16
  # get y space between bacteria