nettracer3d 0.5.3__py3-none-any.whl → 0.5.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nettracer3d/segmenter.py CHANGED
@@ -57,6 +57,7 @@ class InteractiveSegmenter:
57
57
 
58
58
  self.feature_cache = None
59
59
  self.lock = threading.Lock()
60
+ self._currently_segmenting = None
60
61
 
61
62
  # Current position attributes
62
63
  self.current_z = None
@@ -66,6 +67,119 @@ class InteractiveSegmenter:
66
67
  self.realtimechunks = None
67
68
  self.current_speed = False
68
69
 
70
+ # Tracking if we're using 2d or 3d segs
71
+ self.use_two = False
72
+ self.two_slices = []
73
+ self.speed = True
74
+ self.cur_gpu = False
75
+ self.map_slice = None
76
+ self.prev_z = None
77
+ self.previewing = False
78
+
79
+ # flags to track state
80
+ self._currently_processing = False
81
+ self._skip_next_update = False
82
+ self._last_processed_slice = None
83
+
84
def segment_slice_chunked(self, slice_z, block_size=64):
    """Segment one z-slice tile by tile, yielding each tile's result when ready.

    This is a generator, so no work happens until it is iterated. The slice's
    2D feature map is taken from ``self.feature_cache`` (only when that is a
    per-slice dict), else from ``self.map_slice`` (when ``slice_z`` is the
    current slice), else generated with ``self.get_feature_map_slice``. Each
    ``block_size`` x ``block_size`` tile is then classified with
    ``self.model.predict``.

    Iteration ends early, without error, if the same slice is already being
    processed, if no feature map can be obtained, or if
    ``self._currently_processing`` is switched to another slice meanwhile.

    Args:
        slice_z: Index of the z-slice to segment.
        block_size: Edge length of the square tiles classified per yield.

    Yields:
        tuple[set, set]: ``(foreground, background)`` sets of ``(z, y, x)``
        tuples for one tile.
    """
    import traceback

    # ``_currently_processing`` holds the slice in progress, or None/False
    # when idle. Test the idle markers by identity: slice 0 is falsy and
    # ``False == 0`` is True, so truthiness or equality alone misfire on 0.
    active = self._currently_processing
    if active is not None and active is not False and active == slice_z:
        return

    self._currently_processing = slice_z

    try:
        cache = self.feature_cache
        # Only the 2D path keeps a dict of per-slice maps. For the 3D feature
        # array (or None) a membership test is meaningless or raises, and
        # writing a 2D map into it would corrupt it.
        per_slice_cache = isinstance(cache, dict)

        if per_slice_cache and slice_z in cache:
            feature_map = cache[slice_z]
        elif getattr(self, 'map_slice', None) is not None and slice_z == self.current_z:
            feature_map = self.map_slice
        else:
            try:
                feature_map = self.get_feature_map_slice(slice_z, self.current_speed, False)
            except Exception as e:
                print(f"Error generating feature map: {e}")
                traceback.print_exc()
                return  # Exit if we can't generate the feature map
            if feature_map is not None:
                self.map_slice = feature_map
                # Never cache a None result: it would be served on every later
                # call and the slice could never be segmented.
                if per_slice_cache:
                    cache[slice_z] = feature_map

        if feature_map is None:
            return

        y_size, x_size = self.image_3d.shape[1], self.image_3d.shape[2]

        for y_start in range(0, y_size, block_size):
            y_stop = min(y_start + block_size, y_size)
            for x_start in range(0, x_size, block_size):
                # Another slice took over; abandon this one quietly.
                if self._currently_processing != slice_z:
                    return

                x_stop = min(x_start + block_size, x_size)
                coords = [(slice_z, y, x)
                          for y in range(y_start, y_stop)
                          for x in range(x_start, x_stop)]
                features = [feature_map[y, x] for _, y, x in coords]

                try:
                    predictions = self.model.predict(features)
                except Exception as e:
                    # Best effort: report and move on to the next tile.
                    print(f"Error processing chunk: {e}")
                    traceback.print_exc()
                    continue

                foreground = set()
                background = set()
                for coord, pred in zip(coords, predictions):
                    if pred:
                        foreground.add(coord)
                    else:
                        background.add(coord)

                yield foreground, background

    finally:
        # Release the marker only if no other slice has claimed it since.
        if self._currently_processing == slice_z:
            self._currently_processing = None
182
+
69
183
  def compute_deep_feature_maps_cpu(self):
70
184
  """Compute feature maps using CPU"""
71
185
  features = []
@@ -125,6 +239,67 @@ class InteractiveSegmenter:
125
239
 
126
240
  return np.stack(features, axis=-1)
127
241
 
242
def compute_deep_feature_maps_cpu_2d(self, z = None):
    """Compute the "deep" 2D feature stack for one z-slice on the CPU.

    Channels, in order:
      0-3  Gaussian smoothing at sigma 0.5, 1, 2, 4
      4    mean over a 5x5 window
      5    5x5 windowed mean of the squared deviation from the slice's global mean
      6    Sobel gradient magnitude
      7    Laplacian approximation (sum of second Sobel derivatives)
      8    Hessian-determinant approximation from Sobel derivatives

    Args:
        z: Index of the slice in ``self.image_3d``. It must be given; the
            None default does not select a slice.

    Returns:
        np.ndarray of shape ``slice_shape + (9,)``.

    Raises:
        ValueError: if a feature cannot be matched to the slice shape.
    """
    image_2d = self.image_3d[z, :, :]
    # scipy.ndimage filters return the input dtype by default. On an integer
    # slice that truncates the window mean and wraps signed results (Sobel
    # derivatives, Laplacian, Hessian) around, so compute in float instead.
    # Floating-point slices are used unchanged.
    if not np.issubdtype(image_2d.dtype, np.floating):
        image_2d = image_2d.astype(np.float64)
    original_shape = image_2d.shape

    features = [ndimage.gaussian_filter(image_2d, sigma) for sigma in (0.5, 1.0, 2.0, 4.0)]

    window_size = 5
    kernel = np.ones((window_size, window_size)) / (window_size ** 2)

    features.append(ndimage.convolve(image_2d, kernel, mode='reflect'))

    # NOTE(review): deviation is taken from the *global* slice mean, not the
    # local mean; behaviour kept unchanged.
    global_mean = np.mean(image_2d)
    features.append(ndimage.convolve((image_2d - global_mean) ** 2, kernel, mode='reflect'))

    gx = ndimage.sobel(image_2d, axis=1, mode='reflect')  # x direction
    gy = ndimage.sobel(image_2d, axis=0, mode='reflect')  # y direction
    features.append(np.sqrt(gx ** 2 + gy ** 2))

    gxx = ndimage.sobel(gx, axis=1, mode='reflect')
    gyy = ndimage.sobel(gy, axis=0, mode='reflect')
    features.append(gxx + gyy)

    gxy = ndimage.sobel(gx, axis=0, mode='reflect')
    gyx = ndimage.sobel(gy, axis=1, mode='reflect')
    features.append(gxx * gyy - gxy * gyx)

    for i, feat in enumerate(features):
        if feat.shape == original_shape:
            continue
        # Start from this feature (never a value left over from an earlier
        # iteration) and prepend singleton axes if it lost dimensions.
        adjusted = feat
        for _ in range(len(original_shape) - feat.ndim):
            adjusted = np.expand_dims(adjusted, axis=0)
        if adjusted.shape != original_shape:
            raise ValueError(f"Feature {i} has shape {feat.shape}, expected {original_shape}")
        features[i] = adjusted

    return np.stack(features, axis=-1)
302
+
128
303
  def compute_feature_maps(self):
129
304
  """Compute all feature maps using GPU acceleration"""
130
305
  #if not self.use_gpu:
@@ -132,7 +307,10 @@ class InteractiveSegmenter:
132
307
 
133
308
  features = []
134
309
  image = self.image_gpu
310
+ image_3d = self.image_3d
135
311
  original_shape = self.image_3d.shape
312
+
313
+
136
314
 
137
315
  # Gaussian smoothing at different scales
138
316
  print("Obtaining gaussians")
@@ -150,7 +328,7 @@ class InteractiveSegmenter:
150
328
  features.append(dog)
151
329
 
152
330
  # Convert image to PyTorch tensor for gradient operations
153
- image_torch = torch.from_numpy(self.image_3d).cuda()
331
+ image_torch = torch.from_numpy(image_3d).cuda()
154
332
  image_torch = image_torch.float().unsqueeze(0).unsqueeze(0)
155
333
 
156
334
  # Calculate required padding
@@ -159,7 +337,7 @@ class InteractiveSegmenter:
159
337
 
160
338
  # Create a single padded version with same padding
161
339
  pad = torch.nn.functional.pad(image_torch, (padding, padding, padding, padding, padding, padding), mode='replicate')
162
-
340
+
163
341
  print("Computing sobel kernels")
164
342
 
165
343
  # Create sobel kernels
@@ -181,6 +359,8 @@ class InteractiveSegmenter:
181
359
  gradient_feature = gradient_magnitude.cpu().numpy().squeeze()
182
360
 
183
361
  features.append(gradient_feature)
362
+
363
+ print(features.shape)
184
364
 
185
365
  # Verify shapes
186
366
  for i, feat in enumerate(features):
@@ -194,6 +374,122 @@ class InteractiveSegmenter:
194
374
 
195
375
  return np.stack(features, axis=-1)
196
376
 
377
def compute_feature_maps_2d(self, z=None):
    """Compute the fast 2D feature stack for one z-slice using the GPU.

    Channels, in order:
      0-3  Gaussian smoothing at sigma 0.5, 1, 2, 4 (``self.gaussian_filter_gpu``)
      4-5  differences of Gaussians for sigma pairs (1, 2) and (2, 4)
      6    gradient magnitude from central differences (computed with torch)

    Args:
        z: Index of the slice in ``self.image_gpu`` / ``self.image_3d``. It
            must be given; the None default does not select a slice.

    Returns:
        np.ndarray of shape ``slice_shape + (7,)``.

    Raises:
        ValueError: if a feature cannot be matched to the slice shape.
    """
    features = []

    image = self.image_gpu[z, :, :]
    image_2d = self.image_3d[z, :, :]
    original_shape = image_2d.shape

    # Gaussian smoothing at different scales
    print("Obtaining gaussians")
    for sigma in [0.5, 1.0, 2.0, 4.0]:
        smooth = cp.asnumpy(self.gaussian_filter_gpu(image, sigma))
        features.append(smooth)

    print("Obtaining diff of gaussians")
    # Difference of Gaussians
    for (s1, s2) in [(1, 2), (2, 4)]:
        g1 = self.gaussian_filter_gpu(image, s1)
        g2 = self.gaussian_filter_gpu(image, s2)
        features.append(cp.asnumpy(g1 - g2))

    # (1, 1, H, W) float tensor for the gradient convolutions
    image_torch = torch.from_numpy(image_2d).cuda()
    image_torch = image_torch.float().unsqueeze(0).unsqueeze(0)

    kernel_size = 3
    padding = kernel_size // 2

    # One replicate-padded copy, padded on both spatial axes
    pad = torch.nn.functional.pad(image_torch, (padding, padding, padding, padding), mode='replicate')

    print("Computing sobel kernels")
    # Central-difference kernels [-1, 0, 1] (no Sobel smoothing term).
    sobel_x = torch.tensor([-1, 0, 1], device='cuda').float().view(1, 1, 1, 3)
    sobel_y = torch.tensor([-1, 0, 1], device='cuda').float().view(1, 1, 3, 1)

    print("Computing gradients")
    # Each 1D kernel consumes the padding of only one axis, so the conv output
    # keeps a ``padding``-wide border on the other axis. Crop that border from
    # the start as well as the end so output pixel (i, j) lines up with input
    # pixel (i, j); cropping only the end shifted gx by a row and gy by a column.
    gx = torch.nn.functional.conv2d(pad, sobel_x, padding=0)[:, :, padding:padding + original_shape[0], :]
    gy = torch.nn.functional.conv2d(pad, sobel_y, padding=0)[:, :, :, padding:padding + original_shape[1]]

    # Gradient magnitude (no z component in 2D)
    print("Computing gradient mags")
    gradient_magnitude = torch.sqrt(gx ** 2 + gy ** 2)
    features.append(gradient_magnitude.cpu().numpy().squeeze())

    # Verify shapes, prepending singleton axes if a feature lost dimensions
    for i, feat in enumerate(features):
        if feat.shape == original_shape:
            continue
        feat_adjusted = feat
        for _ in range(len(original_shape) - feat.ndim):
            feat_adjusted = np.expand_dims(feat_adjusted, axis=0)
        if feat_adjusted.shape != original_shape:
            raise ValueError(f"Feature {i} has shape {feat.shape}, expected {original_shape}")
        features[i] = feat_adjusted

    return np.stack(features, axis=-1)
446
+
447
def compute_feature_maps_cpu_2d(self, z = None):
    """Compute the fast 2D feature stack for one z-slice on the CPU.

    Channels, in order:
      0-3  Gaussian smoothing at sigma 0.5, 1, 2, 4
      4-5  differences of Gaussians for sigma pairs (1, 2) and (2, 4)
      6    Sobel gradient magnitude

    Args:
        z: Index of the slice in ``self.image_3d``. It must be given; the
            None default does not select a slice.

    Returns:
        np.ndarray of shape ``slice_shape + (7,)``.

    Raises:
        ValueError: if a feature cannot be matched to the slice shape.
    """
    image_2d = self.image_3d[z, :, :]
    # scipy.ndimage filters return the input dtype by default. On an integer
    # slice the DoG subtraction and the Sobel derivatives would wrap around,
    # so compute in float instead. Floating-point slices are used unchanged.
    if not np.issubdtype(image_2d.dtype, np.floating):
        image_2d = image_2d.astype(np.float64)
    original_shape = image_2d.shape

    # Each blur is computed once and reused for the differences of Gaussians.
    blurred = {sigma: ndimage.gaussian_filter(image_2d, sigma) for sigma in (0.5, 1.0, 2.0, 4.0)}
    features = list(blurred.values())
    for s1, s2 in ((1.0, 2.0), (2.0, 4.0)):
        features.append(blurred[s1] - blurred[s2])

    gx = ndimage.sobel(image_2d, axis=1, mode='reflect')  # x direction
    gy = ndimage.sobel(image_2d, axis=0, mode='reflect')  # y direction
    features.append(np.sqrt(gx ** 2 + gy ** 2))

    for i, feat in enumerate(features):
        if feat.shape == original_shape:
            continue
        # Start from this feature (never a value left over from an earlier
        # iteration) and prepend singleton axes if it lost dimensions.
        adjusted = feat
        for _ in range(len(original_shape) - feat.ndim):
            adjusted = np.expand_dims(adjusted, axis=0)
        if adjusted.shape != original_shape:
            raise ValueError(f"Feature {i} has shape {feat.shape}, expected {original_shape}")
        features[i] = adjusted

    return np.stack(features, axis=-1)
492
+
197
493
  def compute_feature_maps_cpu(self):
198
494
  """Compute feature maps using CPU"""
199
495
  features = []
@@ -386,25 +682,108 @@ class InteractiveSegmenter:
386
682
 
387
683
  return foreground, background
388
684
 
685
def organize_by_z(self, coordinates):
    """Group ``(z, y, x)`` coordinates by their z value.

    Args:
        coordinates: Iterable of 3-element ``(z, y, x)`` sequences.

    Returns:
        dict mapping each z value (in first-seen order) to the list of
        ``(y, x)`` tuples that share it, in their original order. Tuples are
        used so the pairs are hashable.
    """
    grouped = {}
    for z_val, y_val, x_val in coordinates:
        grouped.setdefault(z_val, []).append((y_val, x_val))
    return grouped
707
+
def process_chunk(self, chunk_coords):
    """Classify every coordinate of one chunk with ``self.model``.

    3D mode (``self.use_two`` false): features are read directly from the
    feature volume as ``self.feature_cache[z, y, x]``.

    2D mode: coordinates are grouped by z. Each slice's map is taken from
    the per-slice ``self.feature_cache`` dict if present, else from
    ``self.map_slice`` while ``self.previewing``, else it is computed with
    ``self.get_feature_map_slice`` (not cached). A slice whose map is
    unavailable, or does not fit its coordinates, is skipped. Its
    coordinates land in neither set, and the other slices of the chunk are
    still classified.

    Args:
        chunk_coords: Sequence of ``(z, y, x)`` tuples.

    Returns:
        tuple[set, set]: ``(foreground, background)`` coordinate sets.
    """
    foreground = set()
    background = set()

    if not self.use_two:
        features = [self.feature_cache[z, y, x] for z, y, x in chunk_coords]
        predictions = self.model.predict(features)
        for coord, pred in zip(chunk_coords, predictions):
            if pred:
                foreground.add(coord)
            else:
                background.add(coord)
        return foreground, background

    cache = self.feature_cache
    for z, yx_coords in self.organize_by_z(chunk_coords).items():
        if isinstance(cache, dict) and z in cache:
            slice_map = cache[z]
        elif self.previewing:
            # Fall back to the preview map held in ``map_slice``.
            slice_map = self.map_slice
        else:
            slice_map = self.get_feature_map_slice(z, self.speed, self.cur_gpu)

        if slice_map is None:
            # No map available (preview not built, or get_feature_map_slice
            # declined); skip this slice rather than crash.
            continue
        try:
            features = [slice_map[y, x] for y, x in yx_coords]
        except (IndexError, TypeError):
            # The map does not fit these coordinates; skip this slice.
            continue

        predictions = self.model.predict(features)
        for (y, x), pred in zip(yx_coords, predictions):
            coord = (z, y, x)  # Reconstruct the 3D coordinate as a tuple
            if pred:
                foreground.add(coord)
            else:
                background.add(coord)

    return foreground, background
403
752
 
404
- def segment_volume(self, chunk_size=64, gpu=False):
753
+
754
+
755
+ def segment_volume(self, chunk_size=None, gpu=False):
405
756
  """Segment volume using parallel processing of chunks with vectorized chunk creation"""
406
757
  #Change the above chunk size to None to have it auto-compute largest chunks (not sure which is faster, 64 seems reasonable in test cases)
407
758
 
759
+ def create_2d_chunks():
760
+ """
761
+ Create chunks by z-slices for 2D processing.
762
+ Each chunk is a complete z-slice with all y,x coordinates.
763
+
764
+ Returns:
765
+ List of chunks, where each chunk contains the coordinates for one z-slice
766
+ """
767
+ chunks = []
768
+
769
+ # Process one z-slice at a time
770
+ for z in range(self.image_3d.shape[0]):
771
+ # For each z-slice, gather all y,x coordinates
772
+ y_coords, x_coords = np.meshgrid(
773
+ np.arange(self.image_3d.shape[1]),
774
+ np.arange(self.image_3d.shape[2]),
775
+ indexing='ij'
776
+ )
777
+
778
+ # Create the z-slice coordinates
779
+ z_array = np.full_like(y_coords, z)
780
+ coords = np.stack([z_array, y_coords, x_coords]).reshape(3, -1).T
781
+
782
+ # Convert to list of tuples and add as a chunk
783
+ chunks.append(list(map(tuple, coords)))
784
+
785
+ return chunks
786
+
408
787
  try:
409
788
  from cuml.ensemble import RandomForestClassifier as cuRandomForestClassifier
410
789
  except:
@@ -418,52 +797,55 @@ class InteractiveSegmenter:
418
797
 
419
798
  print("Chunking data...")
420
799
 
421
- # Determine optimal chunk size based on number of cores if not specified
422
- if chunk_size is None:
423
- total_cores = multiprocessing.cpu_count()
424
-
425
- # Calculate total volume and target volume per core
426
- total_volume = np.prod(self.image_3d.shape)
427
- target_volume_per_chunk = total_volume / total_cores
428
-
429
- # Calculate chunk size that would give us roughly one chunk per core
430
- # Using cube root since we want roughly equal sizes in all dimensions
431
- chunk_size = int(np.cbrt(target_volume_per_chunk))
432
-
433
- # Ensure chunk size is at least 32 (minimum reasonable size) and not larger than smallest dimension
434
- chunk_size = max(32, min(chunk_size, min(self.image_3d.shape)))
800
+ if not self.use_two:
801
+ # Determine optimal chunk size based on number of cores if not specified
802
+ if chunk_size is None:
803
+ total_cores = multiprocessing.cpu_count()
804
+
805
+ # Calculate total volume and target volume per core
806
+ total_volume = np.prod(self.image_3d.shape)
807
+ target_volume_per_chunk = total_volume / total_cores
808
+
809
+ # Calculate chunk size that would give us roughly one chunk per core
810
+ # Using cube root since we want roughly equal sizes in all dimensions
811
+ chunk_size = int(np.cbrt(target_volume_per_chunk))
812
+
813
+ # Ensure chunk size is at least 32 (minimum reasonable size) and not larger than smallest dimension
814
+ chunk_size = max(32, min(chunk_size, min(self.image_3d.shape)))
815
+
816
+ # Round to nearest multiple of 32 for better memory alignment
817
+ chunk_size = ((chunk_size + 15) // 32) * 32
435
818
 
436
- # Round to nearest multiple of 32 for better memory alignment
437
- chunk_size = ((chunk_size + 15) // 32) * 32
438
-
439
- # Calculate number of chunks in each dimension
440
- z_chunks = (self.image_3d.shape[0] + chunk_size - 1) // chunk_size
441
- y_chunks = (self.image_3d.shape[1] + chunk_size - 1) // chunk_size
442
- x_chunks = (self.image_3d.shape[2] + chunk_size - 1) // chunk_size
443
-
444
- # Create start indices for all chunks at once
445
- chunk_starts = np.array(np.meshgrid(
446
- np.arange(z_chunks) * chunk_size,
447
- np.arange(y_chunks) * chunk_size,
448
- np.arange(x_chunks) * chunk_size,
449
- indexing='ij'
450
- )).reshape(3, -1).T
451
-
452
- chunks = []
453
- for z_start, y_start, x_start in chunk_starts:
454
- z_end = min(z_start + chunk_size, self.image_3d.shape[0])
455
- y_end = min(y_start + chunk_size, self.image_3d.shape[1])
456
- x_end = min(x_start + chunk_size, self.image_3d.shape[2])
819
+ # Calculate number of chunks in each dimension
820
+ z_chunks = (self.image_3d.shape[0] + chunk_size - 1) // chunk_size
821
+ y_chunks = (self.image_3d.shape[1] + chunk_size - 1) // chunk_size
822
+ x_chunks = (self.image_3d.shape[2] + chunk_size - 1) // chunk_size
457
823
 
458
- # Create coordinates for this chunk efficiently
459
- coords = np.stack(np.meshgrid(
460
- np.arange(z_start, z_end),
461
- np.arange(y_start, y_end),
462
- np.arange(x_start, x_end),
824
+ # Create start indices for all chunks at once
825
+ chunk_starts = np.array(np.meshgrid(
826
+ np.arange(z_chunks) * chunk_size,
827
+ np.arange(y_chunks) * chunk_size,
828
+ np.arange(x_chunks) * chunk_size,
463
829
  indexing='ij'
464
830
  )).reshape(3, -1).T
465
831
 
466
- chunks.append(list(map(tuple, coords)))
832
+ chunks = []
833
+ for z_start, y_start, x_start in chunk_starts:
834
+ z_end = min(z_start + chunk_size, self.image_3d.shape[0])
835
+ y_end = min(y_start + chunk_size, self.image_3d.shape[1])
836
+ x_end = min(x_start + chunk_size, self.image_3d.shape[2])
837
+
838
+ # Create coordinates for this chunk efficiently
839
+ coords = np.stack(np.meshgrid(
840
+ np.arange(z_start, z_end),
841
+ np.arange(y_start, y_end),
842
+ np.arange(x_start, x_end),
843
+ indexing='ij'
844
+ )).reshape(3, -1).T
845
+
846
+ chunks.append(list(map(tuple, coords)))
847
+ else:
848
+ chunks = create_2d_chunks()
467
849
 
468
850
  foreground_coords = set()
469
851
  background_coords = set()
@@ -491,11 +873,39 @@ class InteractiveSegmenter:
491
873
  return foreground_coords, background_coords
492
874
 
def update_position(self, z=None, x=None, y=None):
    """Record the current viewer position, used for chunk prioritization.

    Safeguards:
      * If ``self._skip_next_update`` is set, this call only clears that flag.
      * While a slice is being processed (``self._currently_processing``
        holds a slice index), only the position fields and ``prev_z`` are
        updated.
      * Otherwise, when z moves to a slice with no cached per-slice feature
        map, ``self.map_slice`` and ``self._currently_segmenting`` are reset
        to None.

    Args:
        z, x, y: New position, stored as ``current_z``/``current_x``/``current_y``.
    """
    if getattr(self, '_skip_next_update', False):
        self._skip_next_update = False
        return

    # Seed the previous z on first use.
    if getattr(self, 'prev_z', None) is None:
        self.prev_z = z

    self.current_z = z
    self.current_x = x
    self.current_y = y

    # Slice 0 is a valid index, so compare against the idle markers by
    # identity (``False == 0`` is True) instead of testing truthiness.
    active = getattr(self, '_currently_processing', None)
    if active is not None and active is not False:
        self.prev_z = z
        return

    if self.current_z != self.prev_z:
        cache = getattr(self, 'feature_cache', None)
        # Only the 2D path keeps a per-slice dict. For the 3D array (or no
        # cache yet) there is no entry for this slice, and ``in`` on an
        # ndarray would scan every feature value instead.
        if not isinstance(cache, dict) or self.current_z not in cache:
            self.map_slice = None
            self._currently_segmenting = None

    self.prev_z = z
499
909
 
500
910
  def get_realtime_chunks(self, chunk_size = 64):
501
911
  print("Computing some overhead...")
@@ -572,6 +982,7 @@ class InteractiveSegmenter:
572
982
  gpu = False
573
983
 
574
984
 
985
+
575
986
  if self.realtimechunks is None:
576
987
  self.get_realtime_chunks()
577
988
  else:
@@ -658,55 +1069,141 @@ class InteractiveSegmenter:
658
1069
  del futures[future]
659
1070
  yield fore, back
660
1071
 
1072
+
def cleanup(self):
    """Release cached GPU memory held by the CuPy and PyTorch allocators.

    Does nothing unless ``self.use_gpu`` is truthy.
    """
    if not self.use_gpu:
        return
    cp.get_default_memory_pool().free_all_blocks()
    torch.cuda.empty_cache()
666
1078
 
667
def train_batch(self, foreground_array, speed = True, use_gpu = False, use_two = False):
    """Train ``self.model`` directly on a label volume.

    Voxels equal to 1 in ``foreground_array`` are foreground examples and
    voxels equal to 2 are background examples; other values are ignored.
    The array is indexed with the same ``(z, y, x)`` coordinates as the
    feature maps (presumably it has the shape of ``self.image_3d`` --
    confirm with callers).

    Args:
        foreground_array: Label volume (1 = foreground, 2 = background).
        speed: True for the fast feature set, False for the deep one. A
            change since the last training discards the feature cache.
        use_gpu: Use the GPU feature extractors in 3D mode; recorded in
            ``self.cur_gpu``.
        use_two: Use per-slice 2D feature maps (computed only for slices that
            carry labels) instead of one 3D feature volume. Switching between
            modes discards the feature cache.

    Raises:
        RuntimeError: in 2D mode, if a labelled slice's feature map cannot be
            computed.

    If no labelled voxels are present, a message is printed and the model is
    left untrained.
    """
    self.speed = speed
    self.cur_gpu = use_gpu

    if self.current_speed != speed:
        self.feature_cache = None

    z_fore, y_fore, x_fore = np.where(foreground_array == 1)
    z_back, y_back, x_back = np.where(foreground_array == 2)

    if use_two:
        if not self.use_two:
            # Coming from 3D mode: the cached volume cannot serve per-slice lookups.
            self.feature_cache = None
            self.use_two = True
        if self.feature_cache is None:
            self.feature_cache = {}

        # Make sure every labelled slice has a feature map, computing each one
        # at most once. Cache membership (not ``two_slices``) decides, so maps
        # are rebuilt after the cache has been reset (e.g. by a speed change).
        for z in np.unique(np.concatenate([z_fore, z_back])):
            if z not in self.two_slices:
                self.two_slices.append(z)
            if z not in self.feature_cache:
                feature_map = self.get_feature_map_slice(z, speed, use_gpu)
                if feature_map is None:
                    raise RuntimeError(f"Could not compute a feature map for slice {z}")
                self.feature_cache[z] = feature_map

        rows = [self.feature_cache[z][y, x] for z, y, x in zip(z_fore, y_fore, x_fore)]
        rows += [self.feature_cache[z][y, x] for z, y, x in zip(z_back, y_back, x_back)]
        X = np.asarray(rows)

    else:
        self.two_slices = []

        if self.use_two:
            # Coming from 2D mode: per-slice maps cannot serve as a 3D volume.
            self.feature_cache = None
            self.use_two = False

        if self.feature_cache is None:
            with self.lock:
                # Re-check under the lock in case another thread built it.
                if self.feature_cache is None:
                    if speed:
                        if use_gpu:
                            self.feature_cache = self.compute_feature_maps()
                        else:
                            self.feature_cache = self.compute_feature_maps_cpu()
                    else:
                        if use_gpu:
                            self.feature_cache = self.compute_deep_feature_maps()
                        else:
                            self.feature_cache = self.compute_deep_feature_maps_cpu()

        X = np.vstack([self.feature_cache[z_fore, y_fore, x_fore],
                       self.feature_cache[z_back, y_back, x_back]])

    # The cache now reflects ``speed``.
    self.current_speed = speed

    if len(z_fore) + len(z_back) == 0:
        print("Features maps computed, but no segmentation examples were provided so the model was not trained")
    else:
        y = np.hstack([np.ones(len(z_fore)), np.zeros(len(z_back))])
        self.model.fit(X, y)

    print("Done")
711
1193
 
1194
def get_feature_map_slice(self, z, speed, use_gpu):
    """Compute the 2D feature map for slice ``z`` with the CPU extractors.

    Returns None, without computing anything, while
    ``self._currently_segmenting`` is not None. ``use_gpu`` is accepted for
    interface compatibility but ignored: both feature sets use the CPU
    implementations. The computation runs while holding ``self.lock``.

    Args:
        z: Slice index.
        speed: Truthy for ``compute_feature_maps_cpu_2d``, falsy for
            ``compute_deep_feature_maps_cpu_2d``.
        use_gpu: Unused.

    Returns:
        The selected extractor's output, or None as described above.
    """
    if self._currently_segmenting is not None:
        return None

    extractor = self.compute_feature_maps_cpu_2d if speed else self.compute_deep_feature_maps_cpu_2d
    with self.lock:
        return extractor(z = z)
712
1209