DiSTNet2D 0.2.2 (tar.gz) → 0.2.4 (tar.gz)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32)
  1. {distnet2d-0.2.2 → distnet2d-0.2.4}/DiSTNet2D.egg-info/PKG-INFO +3 -3
  2. {distnet2d-0.2.2 → distnet2d-0.2.4}/DiSTNet2D.egg-info/requires.txt +1 -1
  3. {distnet2d-0.2.2 → distnet2d-0.2.4}/PKG-INFO +3 -3
  4. {distnet2d-0.2.2 → distnet2d-0.2.4}/distnet_2d/model/distnet_2d.py +46 -44
  5. {distnet2d-0.2.2 → distnet2d-0.2.4}/distnet_2d/utils/metrics_tf.py +9 -10
  6. {distnet2d-0.2.2 → distnet2d-0.2.4}/distnet_2d/utils/objectwise_computation_tf.py +62 -21
  7. {distnet2d-0.2.2 → distnet2d-0.2.4}/setup.py +3 -3
  8. {distnet2d-0.2.2 → distnet2d-0.2.4}/DiSTNet2D.egg-info/SOURCES.txt +0 -0
  9. {distnet2d-0.2.2 → distnet2d-0.2.4}/DiSTNet2D.egg-info/dependency_links.txt +0 -0
  10. {distnet2d-0.2.2 → distnet2d-0.2.4}/DiSTNet2D.egg-info/top_level.txt +0 -0
  11. {distnet2d-0.2.2 → distnet2d-0.2.4}/LICENSE.txt +0 -0
  12. {distnet2d-0.2.2 → distnet2d-0.2.4}/README.md +0 -0
  13. {distnet2d-0.2.2 → distnet2d-0.2.4}/distnet_2d/__init__.py +0 -0
  14. {distnet2d-0.2.2 → distnet2d-0.2.4}/distnet_2d/data/__init__.py +0 -0
  15. {distnet2d-0.2.2 → distnet2d-0.2.4}/distnet_2d/data/center_edm.py +0 -0
  16. {distnet2d-0.2.2 → distnet2d-0.2.4}/distnet_2d/data/dydx_iterator.py +0 -0
  17. {distnet2d-0.2.2 → distnet2d-0.2.4}/distnet_2d/data/medoid.py +0 -0
  18. {distnet2d-0.2.2 → distnet2d-0.2.4}/distnet_2d/data/swim1d.py +0 -0
  19. {distnet2d-0.2.2 → distnet2d-0.2.4}/distnet_2d/model/__init__.py +0 -0
  20. {distnet2d-0.2.2 → distnet2d-0.2.4}/distnet_2d/model/architectures.py +0 -0
  21. {distnet2d-0.2.2 → distnet2d-0.2.4}/distnet_2d/model/distnet_2d_seg.py +0 -0
  22. {distnet2d-0.2.2 → distnet2d-0.2.4}/distnet_2d/model/gradient_accumulator.py +0 -0
  23. {distnet2d-0.2.2 → distnet2d-0.2.4}/distnet_2d/model/layers.py +0 -0
  24. {distnet2d-0.2.2 → distnet2d-0.2.4}/distnet_2d/model/spatial_attention.py +0 -0
  25. {distnet2d-0.2.2 → distnet2d-0.2.4}/distnet_2d/utils/__init__.py +0 -0
  26. {distnet2d-0.2.2 → distnet2d-0.2.4}/distnet_2d/utils/agc.py +0 -0
  27. {distnet2d-0.2.2 → distnet2d-0.2.4}/distnet_2d/utils/helpers.py +0 -0
  28. {distnet2d-0.2.2 → distnet2d-0.2.4}/distnet_2d/utils/image_derivatives_np.py +0 -0
  29. {distnet2d-0.2.2 → distnet2d-0.2.4}/distnet_2d/utils/image_derivatives_tf.py +0 -0
  30. {distnet2d-0.2.2 → distnet2d-0.2.4}/distnet_2d/utils/losses.py +0 -0
  31. {distnet2d-0.2.2 → distnet2d-0.2.4}/distnet_2d/utils/lovasz_loss.py +0 -0
  32. {distnet2d-0.2.2 → distnet2d-0.2.4}/setup.cfg +0 -0
@@ -1,9 +1,9 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: DiSTNet2D
3
- Version: 0.2.2
3
+ Version: 0.2.4
4
4
  Summary: tensorflow/keras implementation of DiSTNet 2D
5
5
  Home-page: https://github.com/jeanollion/distnet2d
6
- Download-URL: https://github.com/jeanollion/distnet2d/releases/download/v0.2.2/distnet2d-0.2.2.tar.gz
6
+ Download-URL: https://github.com/jeanollion/distnet2d/releases/download/v0.2.4/distnet2d-0.2.4.tar.gz
7
7
  Author: Jean Ollion
8
8
  Author-email: jean.ollion@sabilab.fr
9
9
  Keywords: Segmentation,Tracking,Cell,Tensorflow,Keras
@@ -22,7 +22,7 @@ Requires-Dist: tensorflow>=2.7.1
22
22
  Requires-Dist: edt>=2.0.2
23
23
  Requires-Dist: scikit-fmm
24
24
  Requires-Dist: numba
25
- Requires-Dist: dataset_iterator>=0.5.5
25
+ Requires-Dist: dataset_iterator>=0.5.7
26
26
  Requires-Dist: elasticdeform>=0.4.7
27
27
  Dynamic: author
28
28
  Dynamic: author-email
@@ -4,5 +4,5 @@ tensorflow>=2.7.1
4
4
  edt>=2.0.2
5
5
  scikit-fmm
6
6
  numba
7
- dataset_iterator>=0.5.5
7
+ dataset_iterator>=0.5.7
8
8
  elasticdeform>=0.4.7
@@ -1,9 +1,9 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: DiSTNet2D
3
- Version: 0.2.2
3
+ Version: 0.2.4
4
4
  Summary: tensorflow/keras implementation of DiSTNet 2D
5
5
  Home-page: https://github.com/jeanollion/distnet2d
6
- Download-URL: https://github.com/jeanollion/distnet2d/releases/download/v0.2.2/distnet2d-0.2.2.tar.gz
6
+ Download-URL: https://github.com/jeanollion/distnet2d/releases/download/v0.2.4/distnet2d-0.2.4.tar.gz
7
7
  Author: Jean Ollion
8
8
  Author-email: jean.ollion@sabilab.fr
9
9
  Keywords: Segmentation,Tracking,Cell,Tensorflow,Keras
@@ -22,7 +22,7 @@ Requires-Dist: tensorflow>=2.7.1
22
22
  Requires-Dist: edt>=2.0.2
23
23
  Requires-Dist: scikit-fmm
24
24
  Requires-Dist: numba
25
- Requires-Dist: dataset_iterator>=0.5.5
25
+ Requires-Dist: dataset_iterator>=0.5.7
26
26
  Requires-Dist: elasticdeform>=0.4.7
27
27
  Dynamic: author
28
28
  Dynamic: author-email
@@ -1,3 +1,5 @@
1
+ import contextlib
2
+
1
3
  import tensorflow as tf
2
4
  from .layers import ker_size_to_string, Combine, ResConv2D, Conv2DBNDrop, Conv2DTransposeBNDrop, WSConv2D, \
3
5
  BatchToChannel, SplitBatch, ChannelToBatch, NConvToBatch2D, SelectFeature, StopGradient, Stack
@@ -31,6 +33,7 @@ class DiSTNetModel(tf.keras.Model):
31
33
  category_number:int=0, category_class_weights = None, category_max_class_weight=10,
32
34
  print_gradients:bool=False, # for optimization, available in eager mode only
33
35
  accum_steps=1, use_agc=False, agc_clip_factor=0.1, agc_eps=1e-3, agc_exclude_output=False, # lower clip factor clips more
36
+ perform_test_step:bool=False,
34
37
  **kwargs):
35
38
  super().__init__(*args, **kwargs)
36
39
  self.edm_weight = edm_loss_weight
@@ -87,41 +90,40 @@ class DiSTNetModel(tf.keras.Model):
87
90
  self.displacement_loss.reduction = tf.keras.losses.Reduction.NONE
88
91
 
89
92
  # metrics associated to losses for to display accurate loss in a distributed setting
90
- self.edm_loss_metric = tf.keras.metrics.Mean(name="EDM")
91
- self.center_loss_metric = tf.keras.metrics.Mean(name="CDM")
92
- self.category_loss_metric = tf.keras.metrics.Mean(name="category") if self.category_weight > 0 else None
93
- self.dx_loss_metric = tf.keras.metrics.Mean(name="dX")
94
- self.dy_loss_metric = tf.keras.metrics.Mean(name="dY")
95
- self.link_multiplicity_loss_metric = tf.keras.metrics.Mean(name="link_multiplicity")
96
- self.loss_metric = tf.keras.metrics.Mean(name="loss")
97
-
98
- @property
99
- def metrics(self):
100
- metrics = []
101
93
  if self.edm_weight > 0:
102
- metrics.append(self.edm_loss_metric)
94
+ self.edm_loss_metric = tf.keras.metrics.Mean(name="EDM")
103
95
  if self.center_weight > 0:
104
- metrics.append(self.center_loss_metric)
96
+ self.center_loss_metric = tf.keras.metrics.Mean(name="CDM")
97
+ if self.category_weight > 0 and self.category_number > 1:
98
+ self.category_loss_metric = tf.keras.metrics.Mean(name="category")
105
99
  if self.displacement_weight > 0:
106
- metrics.append(self.dy_loss_metric)
107
- metrics.append(self.dx_loss_metric)
100
+ self.dx_loss_metric = tf.keras.metrics.Mean(name="dX")
101
+ self.dy_loss_metric = tf.keras.metrics.Mean(name="dY")
108
102
  if self.link_multiplicity_weight > 0:
109
- metrics.append(self.link_multiplicity_loss_metric)
110
- if self.category_weight > 0:
111
- metrics.append(self.category_loss_metric)
112
- metrics.append(self.loss_metric)
113
- if self._is_compiled:
114
- if self.compiled_metrics is not None:
115
- metrics += self.compiled_metrics.metrics
103
+ self.link_multiplicity_loss_metric = tf.keras.metrics.Mean(name="link_multiplicity")
104
+ self.loss_metric = tf.keras.metrics.Mean(name="loss")
105
+ self.perform_test_step=perform_test_step
116
106
 
117
- for l in self._flatten_layers():
118
- metrics.extend(l._metrics)
119
107
 
120
- return metrics
108
+ @staticmethod
109
+ @contextlib.contextmanager
110
+ def nullcontext():
111
+ yield
112
+
113
+ def maybe_gradient_tape(self, training):
114
+ if training:
115
+ return tf.GradientTape(persistent=self.print_gradients)
116
+ return self.nullcontext()
121
117
 
122
118
 
123
119
  def train_step(self, data):
124
- if self.use_grad_acc:
120
+ return self.step(data, training=True)
121
+
122
+ def test_step(self, data):
123
+ return self.step(data, False)
124
+
125
+ def step(self, data, training:bool):
126
+ if self.use_grad_acc and training:
125
127
  self.gradient_accumulator.init_train_step()
126
128
 
127
129
  fw = self.frame_window
@@ -144,7 +146,7 @@ class DiSTNetModel(tf.keras.Model):
144
146
  lm_idx = int(center_weight > 0) + 2 * int(displacement_weight > 0) + int(link_multiplicity_weight > 0)
145
147
  cat_idx = lm_idx + 1
146
148
 
147
- with tf.GradientTape(persistent=self.print_gradients) as tape:
149
+ with self.maybe_gradient_tape(training) as tape:
148
150
  y_pred = self(x, training=True) # Forward pass
149
151
  if self.predict_edm_derivatives:
150
152
  edm, edm_dy, edm_dx = tf.split(y_pred[0], num_or_size_splits=3, axis=-1)
@@ -225,7 +227,7 @@ class DiSTNetModel(tf.keras.Model):
225
227
  if num_replicas > 1:
226
228
  loss *= 1.0 / num_replicas
227
229
 
228
- if self.print_gradients:
230
+ if self.print_gradients and training:
229
231
  trainable_vars_tape = [t for t in self.trainable_variables if (t.name.startswith("DecoderSegEDM") or t.name.startswith("DecoderCenterCDM") or t.name.startswith("DecoderTrackY0") or t.name.startswith("DecoderTrackX0") or t.name.startswith("DecoderLinkMultiplicity0") or t.name.startswith("FeatureSequence_Op4") or t.name.startswith("Attention")) and ("/kernel" in t.name or "/wv" in t.name) ]
230
232
  for loss_name, loss_value in losses.items():
231
233
  if loss_name != "loss" :
@@ -246,31 +248,31 @@ class DiSTNetModel(tf.keras.Model):
246
248
  print(f"AGC: layer: {v.name}, loss: {loss_name}, grad: {tf.math.sqrt(tf.reduce_mean(tf.math.square(g))).numpy()}")
247
249
 
248
250
  # Compute gradients
249
- gradients = tape.gradient(loss, self.trainable_variables)
250
- if mixed_precision:
251
- gradients = self.optimizer.get_unscaled_gradients(gradients)
252
- if self.use_agc:
253
- gradients = adaptive_clip_grad(self.trainable_variables, gradients, clip_factor=self.agc_clip_factor, eps=self.agc_eps, exclude_keywords=self.agc_exclude_keywords)
254
- if not self.use_grad_acc:
255
- self.optimizer.apply_gradients(zip(gradients, self.trainable_variables)) #Update weights
256
- else:
257
- self.gradient_accumulator.accumulate_gradients(gradients)
258
- self.gradient_accumulator.apply_gradients()
251
+ if training:
252
+ gradients = tape.gradient(loss, self.trainable_variables)
253
+ if mixed_precision:
254
+ gradients = self.optimizer.get_unscaled_gradients(gradients)
255
+ if self.use_agc:
256
+ gradients = adaptive_clip_grad(self.trainable_variables, gradients, clip_factor=self.agc_clip_factor, eps=self.agc_eps, exclude_keywords=self.agc_exclude_keywords)
257
+ if not self.use_grad_acc:
258
+ self.optimizer.apply_gradients(zip(gradients, self.trainable_variables)) #Update weights
259
+ else:
260
+ self.gradient_accumulator.accumulate_gradients(gradients)
261
+ self.gradient_accumulator.apply_gradients()
259
262
 
260
263
  # Update metrics state
261
- if edm_weight > 0:
264
+ if self.edm_weight > 0:
262
265
  self.edm_loss_metric.update_state(losses["EDM"], sample_weight=batch_dim)
263
- if center_weight > 0:
266
+ if self.center_weight > 0:
264
267
  self.center_loss_metric.update_state(losses["CDM"], sample_weight=batch_dim)
265
- if displacement_weight>0:
268
+ if self.displacement_weight > 0:
266
269
  self.dx_loss_metric.update_state(losses["dX"], sample_weight=batch_dim)
267
270
  self.dy_loss_metric.update_state(losses["dY"], sample_weight=batch_dim)
268
- if link_multiplicity_weight >0:
271
+ if self.link_multiplicity_weight > 0:
269
272
  self.link_multiplicity_loss_metric.update_state(losses["link_multiplicity"], sample_weight=batch_dim)
270
- if self.category_loss_metric is not None:
273
+ if self.category_weight > 0 and self.category_number > 1:
271
274
  self.category_loss_metric.update_state(losses["category"], sample_weight=batch_dim)
272
275
  self.loss_metric.update_state(losses["loss"], sample_weight=batch_dim)
273
-
274
276
  return self.compute_metrics(x, y, y_pred, None)
275
277
 
276
278
  def _compute_displacement_loss(self, y, y_pred, cell_mask):
@@ -1,16 +1,16 @@
1
1
  import tensorflow as tf
2
2
  from .objectwise_computation_tf import get_max_by_object_fun, coord_distance_fun, get_argmax_2d_by_object_fun, \
3
3
  get_mean_by_object_fun, get_label_size, IoU, objectwise_compute, objectwise_compute_channel, reduce_pop_size, \
4
- FPR
4
+ FP
5
5
 
6
6
 
7
- def get_metrics_fun(center_scale: float, max_objects_number: int = 0, category:bool = False, tracking:bool=True):
7
+ def get_metrics_fun(scale: float, max_objects_number: int = 0, category:bool = False, tracking:bool=True):
8
8
  """
9
9
  return metric function for disnet2D
10
10
  assumes iterator in return_central_only= True mode (thus framewindow = 1 and next = true)
11
11
  Parameters
12
12
  ----------
13
- center_scale
13
+ scale
14
14
  max_objects_number
15
15
  reduce
16
16
 
@@ -19,7 +19,6 @@ def get_metrics_fun(center_scale: float, max_objects_number: int = 0, category:b
19
19
 
20
20
  """
21
21
 
22
- scale = tf.cast(center_scale, tf.float32)
23
22
  coord_distance_function = coord_distance_fun(max=True, sqrt=True, pop_fraction=0.25)
24
23
  spa_max_fun = get_argmax_2d_by_object_fun()
25
24
  mean_fun = get_mean_by_object_fun()
@@ -41,7 +40,7 @@ def get_metrics_fun(center_scale: float, max_objects_number: int = 0, category:b
41
40
  labels = tf.transpose(labels, perm=[2, 0, 1]) # (1, Y, X)
42
41
  edm = tf.transpose(edm, perm=[2, 0, 1]) # (1, Y, X)
43
42
  gdcm = tf.transpose(gdcm, perm=[2, 0, 1]) # (1, Y, X)
44
- center_values = tf.math.exp(-tf.math.square(tf.math.divide(gdcm, scale)))
43
+ center_values = tf.math.exp(-tf.math.square(tf.math.divide(gdcm, tf.cast(scale/2., tf.float32))))
45
44
  ids, sizes, N = get_label_size(labels, max_objects_number) # (1, N), (1, N)
46
45
  ids = ids[0]
47
46
  sizes = sizes[0]
@@ -61,14 +60,14 @@ def get_metrics_fun(center_scale: float, max_objects_number: int = 0, category:b
61
60
  metrics = []
62
61
 
63
62
  # EDM : foreground/background IoU
64
- pred_foreground = tf.math.greater(edm, tf.cast(0.5, edm.dtype))
63
+ pred_foreground = tf.math.greater(edm, tf.cast(0, edm.dtype))
65
64
  true_foreground = tf.math.greater(labels, tf.cast(0, labels.dtype))
66
- edm_IoU = IoU(true_foreground, pred_foreground, tolerance=False)
65
+ edm_IoU = IoU(true_foreground, pred_foreground, tolerance_radius=scale / 8.) #
67
66
  metrics.append(edm_IoU)
68
67
 
69
- # Surface-based False Positive Rate (FPR) based on EDM
70
- #fpr = FPR(true_foreground, pred_foreground, tolerance=True)
71
- #metrics.append(-fpr)
68
+ # Surface-based False Positive Density (FPD) based on EDM
69
+ fp = FP(true_foreground, pred_foreground, rate=False, tolerance_radius=scale / 4.) #
70
+ metrics.append(-fp)
72
71
 
73
72
  # contour IoU : problem: true positive contours are usually not precise enough.
74
73
  #pred_contours = tf.math.logical_and(tf.math.greater(edm, tf.cast(0.5, edm.dtype)), tf.math.less_equal(edm, tf.cast(1.5, edm.dtype)))
@@ -211,32 +211,71 @@ def _generate_kernel(sizeY, sizeX, C=1, O=0):
211
211
  return kernel
212
212
 
213
213
 
214
- def IoU(true_foreground, pred_foreground, tolerance:bool=False):
215
- true_inter = _dilate_mask(true_foreground) if tolerance else true_foreground
214
+ def IoU(true_foreground, pred_foreground, tolerance_radius:float=0):
215
+ true_inter = _dilate_mask(true_foreground, radius=tolerance_radius, symmetric_padding=True) if tolerance_radius>=1 else true_foreground
216
216
  intersection = tf.math.count_nonzero(tf.math.logical_and(true_inter, pred_foreground), keepdims=False)
217
217
  union = tf.math.count_nonzero(tf.math.logical_or(true_foreground, pred_foreground), keepdims=False)
218
218
  return tf.cond(tf.math.equal(union, tf.cast(0, union.dtype)), lambda: tf.cast(1., tf.float32), lambda: tf.math.divide(tf.cast(intersection, tf.float32), tf.cast(union, tf.float32))) # if union is null -> metric is 1
219
219
 
220
220
 
221
- def FPR(true_foreground, pred_foreground, tolerance:bool=False):
221
+ def FP(true_foreground, pred_foreground, rate:bool = False, tolerance_radius:float=0):
222
222
  true_background = tf.math.logical_not(true_foreground)
223
223
  false_positives = tf.logical_and(pred_foreground, true_background)
224
- false_positives = _erode_mask(false_positives) if tolerance else false_positives
225
- num_fp = tf.reduce_sum(tf.cast(false_positives, tf.float32))
226
- num_tn = tf.reduce_sum(tf.cast(true_background, tf.float32))
227
- return tf.math.divide_no_nan(num_fp, num_tn)
224
+ false_positives = _erode_mask(false_positives, radius=tolerance_radius, symmetric_padding=False) if tolerance_radius>=1 else false_positives
225
+ fp = tf.math.count_nonzero(false_positives, keepdims=False)
226
+ if rate: # FPR
227
+ tn = tf.math.count_nonzero(true_background, keepdims=False) # for FRP
228
+ return tf.math.divide_no_nan(tf.cast(fp, tf.float32), tf.cast(tn, tf.float32))
229
+ else: # FPD
230
+ #return tf.cast(fp, tf.float32)
231
+ npix = tf.reduce_prod(tf.shape(true_background)) # for FPD
232
+ return tf.math.divide(tf.cast(fp, tf.float32), tf.cast(npix, tf.float32))
233
+
234
+
235
+ def _dilate_mask(maskBYX, radius:float=1.5, tolerance:float=0.25, symmetric_padding:bool=True):
236
+ assert 0<=tolerance<0.5
237
+ maskBYX = tf.cast(maskBYX, tf.int32)
238
+ ker, rad = circular_kernel(radius)
239
+ thld = tf.math.floor(tf.cast(tf.math.reduce_sum(ker), tf.float32) * tf.cast(tolerance, tf.float32))
240
+ conv = _convolve(maskBYX, ker, rad, symmetric_padding=symmetric_padding)
241
+ return tf.math.greater(conv, tf.cast(thld, tf.int32))
228
242
 
229
243
 
230
- def _dilate_mask(maskBYX):
244
+ def _erode_mask(maskBYX, radius:float=1.5, tolerance:float=0.25, symmetric_padding:bool=False):
245
+ assert 0 <= tolerance < 0.5
231
246
  maskBYX = tf.cast(maskBYX, tf.int32)
232
- conv = _convolve(maskBYX, tf.ones(shape=[3, 3], dtype=tf.int32))
233
- return tf.math.greater(conv, tf.cast(2, tf.int32))
247
+ ker, rad = circular_kernel(radius)
248
+ thld = tf.math.ceil(tf.cast(tf.math.reduce_sum(ker), tf.float32) * tf.cast(1 - tolerance, tf.float32))
249
+ conv = _convolve(maskBYX, ker, rad, symmetric_padding=symmetric_padding)
250
+ return tf.math.greater_equal(conv, tf.cast(thld, tf.int32))
234
251
 
235
252
 
236
- def _erode_mask(maskBYX):
237
- maskBYX = tf.cast(maskBYX, tf.int32)
238
- conv = _convolve(maskBYX, tf.ones(shape=[3, 3], dtype=tf.int32))
239
- return tf.math.greater_equal(conv, tf.cast(7, tf.int32))
253
+
254
+ def circular_kernel(radius: float) :
255
+ """
256
+ Create a circular 2D kernel of ones with a given float radius.
257
+ Args:
258
+ radius: The radius of the circle (float).
259
+ Returns:
260
+ A 2D TensorFlow tensor representing the circular kernel (dtype: tf.int32).
261
+ """
262
+ radius_int = tf.cast(radius, tf.int32)
263
+ diameter = tf.cast(2 * radius_int + 1, tf.int32)
264
+ center = diameter // 2
265
+
266
+ # Create a grid of coordinates
267
+ y = tf.range(-center, diameter - center, dtype=tf.float32)
268
+ x = tf.range(-center, diameter - center, dtype=tf.float32)
269
+ y_grid, x_grid = tf.meshgrid(y, x, indexing='ij')
270
+
271
+ # Compute the distance from the center
272
+ distance = tf.math.sqrt(x_grid**2 + y_grid**2)
273
+
274
+ # Create the circular kernel
275
+ kernel = tf.zeros((diameter, diameter), dtype=tf.int32)
276
+ kernel = tf.where(distance <= radius, tf.ones_like(kernel, dtype=tf.int32), kernel)
277
+
278
+ return kernel, radius_int
240
279
 
241
280
 
242
281
  def _contour_IoU_fun(pred_contour, mask, size):
@@ -245,13 +284,15 @@ def _contour_IoU_fun(pred_contour, mask, size):
245
284
 
246
285
 
247
286
  def _compute_contours(maskBYX):
248
- kernel = [[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]]
249
- conv = _convolve(maskBYX, kernel)
250
- return tf.math.greater(conv, 0) # detect at least one zero in the neighborhood
287
+ kernel = np.array([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]])
288
+ conv = _convolve(maskBYX, kernel, 1, symmetric_padding=True)
289
+ return tf.math.greater(conv, tf.cast(0, conv.dtype)) # detect at least one zero in the neighborhood
251
290
 
252
291
 
253
- def _convolve(imageBYX, kernel):
254
- padded = tf.pad(imageBYX, [[0, 0], [1, 1], [1, 1]], 'SYMMETRIC')
255
- input = padded[..., tf.newaxis]
256
- conv = tf.nn.conv2d(input, kernel[:, :, tf.newaxis, tf.newaxis], strides=1, padding='VALID')
292
+ def _convolve(imageBYX, kernel, radius, symmetric_padding:bool):
293
+ if symmetric_padding:
294
+ imageBYX = tf.pad(imageBYX, [[0, 0], [radius, radius], [radius, radius]], 'SYMMETRIC')
295
+ imageBYX = imageBYX[..., tf.newaxis]
296
+ kernel = tf.cast(kernel, imageBYX.dtype)
297
+ conv = tf.nn.conv2d(imageBYX, kernel[:, :, tf.newaxis, tf.newaxis], strides=1, padding='VALID' if symmetric_padding else "SAME")
257
298
  return conv[..., 0]
@@ -5,14 +5,14 @@ with open("README.md", "r") as fh:
5
5
 
6
6
  setuptools.setup(
7
7
  name="DiSTNet2D",
8
- version="0.2.2",
8
+ version="0.2.4",
9
9
  author="Jean Ollion",
10
10
  author_email="jean.ollion@sabilab.fr",
11
11
  description="tensorflow/keras implementation of DiSTNet 2D",
12
12
  long_description=long_description,
13
13
  long_description_content_type="text/markdown",
14
14
  url="https://github.com/jeanollion/distnet2d",
15
- download_url='https://github.com/jeanollion/distnet2d/releases/download/v0.2.2/distnet2d-0.2.2.tar.gz',
15
+ download_url='https://github.com/jeanollion/distnet2d/releases/download/v0.2.4/distnet2d-0.2.4.tar.gz',
16
16
  packages=setuptools.find_packages(),
17
17
  keywords=['Segmentation', 'Tracking', 'Cell', 'Tensorflow', 'Keras'],
18
18
  classifiers=[
@@ -24,5 +24,5 @@ setuptools.setup(
24
24
  'Programming Language :: Python :: 3',
25
25
  ],
26
26
  python_requires='>=3',
27
- install_requires=['numpy', 'scipy', 'tensorflow>=2.7.1', 'edt>=2.0.2', 'scikit-fmm', 'numba', 'dataset_iterator>=0.5.5', 'elasticdeform>=0.4.7']
27
+ install_requires=['numpy', 'scipy', 'tensorflow>=2.7.1', 'edt>=2.0.2', 'scikit-fmm', 'numba', 'dataset_iterator>=0.5.7', 'elasticdeform>=0.4.7']
28
28
  )
File without changes
File without changes
File without changes