nettracer3d 0.5.4__py3-none-any.whl → 0.5.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nettracer3d/segmenter.py CHANGED
@@ -57,6 +57,7 @@ class InteractiveSegmenter:
57
57
 
58
58
  self.feature_cache = None
59
59
  self.lock = threading.Lock()
60
+ self._currently_segmenting = None
60
61
 
61
62
  # Current position attributes
62
63
  self.current_z = None
@@ -66,6 +67,119 @@ class InteractiveSegmenter:
66
67
  self.realtimechunks = None
67
68
  self.current_speed = False
68
69
 
70
+ # Tracking if we're using 2d or 3d segs
71
+ self.use_two = False
72
+ self.two_slices = []
73
+ self.speed = True
74
+ self.cur_gpu = False
75
+ self.map_slice = None
76
+ self.prev_z = None
77
+ self.previewing = False
78
+
79
+ # flags to track state
80
+ self._currently_processing = False
81
+ self._skip_next_update = False
82
+ self._last_processed_slice = None
83
+
84
def segment_slice_chunked(self, slice_z, block_size=64):
    """
    Segment one z-slice tile by tile, yielding each tile's result when ready.

    This is a generator: no work happens until the caller iterates it. The
    slice's 2D feature map is taken from ``self.feature_cache`` when present,
    otherwise from ``self.map_slice`` when ``slice_z`` is the current slice,
    otherwise it is computed with ``self.get_feature_map_slice`` and cached.

    Args:
        slice_z: Index of the z-slice to segment.
        block_size: Edge length in pixels of each square tile that is
            predicted and yielded as one chunk.

    Yields:
        ``(foreground, background)``: two sets of ``(z, y, x)`` tuples
        covering one tile.

    ``self._currently_processing`` records the slice in progress. A second
    generator for the same slice exits immediately, and a running generator
    stops early as soon as another slice claims the flag.
    """
    # Idle is represented by False (initial value) or None (after a run);
    # anything else is a slice index. Test the sentinels explicitly: slice 0
    # is falsy, and False == 0 is True in Python.
    active = self._currently_processing
    if active is not None and active is not False and active == slice_z:
        return

    # Set processing flag with the slice we're processing
    self._currently_processing = slice_z

    try:
        cache = getattr(self, 'feature_cache', None)
        feature_map = None

        # feature_cache is None until a map has been computed; guard the lookup.
        if cache is not None and slice_z in cache:
            feature_map = cache[slice_z]
        elif getattr(self, 'map_slice', None) is not None and slice_z == self.current_z:
            feature_map = self.map_slice
        else:
            try:
                feature_map = self.get_feature_map_slice(slice_z, self.current_speed, False)
                # get_feature_map_slice returns None when it declines to run.
                # Publishing or caching that would pin this slice to "no map".
                if feature_map is not None:
                    self.map_slice = feature_map
                    if getattr(self, 'feature_cache', None) is None:
                        self.feature_cache = {}
                    self.feature_cache[slice_z] = feature_map
            except Exception as e:
                print(f"Error generating feature map: {e}")
                import traceback
                traceback.print_exc()
                return  # Exit if we can't generate the feature map

        if feature_map is None:
            return

        y_size, x_size = self.image_3d.shape[1], self.image_3d.shape[2]

        for y_start in range(0, y_size, block_size):
            if self._currently_processing != slice_z:
                return
            y_end = min(y_start + block_size, y_size)

            for x_start in range(0, x_size, block_size):
                if self._currently_processing != slice_z:
                    return
                x_end = min(x_start + block_size, x_size)

                coords = [(slice_z, y, x)
                          for y in range(y_start, y_end)
                          for x in range(x_start, x_end)]
                # Gather the tile's feature vectors with one fancy-indexing
                # call, in the same row-major order as `coords`. Out-of-range
                # pixels still raise IndexError, as per-pixel indexing would.
                tile_y, tile_x = np.mgrid[y_start:y_end, x_start:x_end]
                features = feature_map[tile_y.ravel(), tile_x.ravel()]

                try:
                    predictions = self.model.predict(features)

                    foreground = set()
                    background = set()
                    for coord, pred in zip(coords, predictions):
                        if pred:
                            foreground.add(coord)
                        else:
                            background.add(coord)

                    yield foreground, background

                except Exception as e:
                    print(f"Error processing chunk: {e}")
                    import traceback
                    traceback.print_exc()

    finally:
        # Only clear if we're still processing the same slice
        # (otherwise, another slice might have taken over).
        if self._currently_processing == slice_z:
            self._currently_processing = None
182
+
69
183
  def compute_deep_feature_maps_cpu(self):
70
184
  """Compute feature maps using CPU"""
71
185
  features = []
@@ -125,6 +239,67 @@ class InteractiveSegmenter:
125
239
 
126
240
  return np.stack(features, axis=-1)
127
241
 
242
def compute_deep_feature_maps_cpu_2d(self, z = None):
    """Compute the "deep" 2D feature stack for one z-slice on the CPU.

    Args:
        z: Index of the slice of ``self.image_3d`` to featurize.
            NOTE(review): the default ``None`` does not select a slice
            (``image_3d[None]`` inserts an axis); the visible callers always
            pass ``z``.

    Returns:
        ``np.ndarray`` of shape ``(Y, X, 9)`` stacking, in order: Gaussian
        smoothings at sigma 0.5, 1, 2 and 4; the 5x5 local mean; the 5x5
        local mean of squared deviation from the slice-wide mean; gradient
        magnitude; Laplacian; and Hessian determinant.

    Raises:
        ValueError: if a feature cannot be reconciled with the slice shape.
    """
    features = []

    image_2d = self.image_3d[z, :, :]
    # scipy.ndimage filters write into an array of the input dtype. On
    # integer images the signed Sobel derivatives wrap around and the
    # smoothings truncate, so compute in floating point.
    if not np.issubdtype(image_2d.dtype, np.floating):
        image_2d = image_2d.astype(np.float32)
    original_shape = image_2d.shape

    # Gaussian smoothing at several scales
    for sigma in [0.5, 1.0, 2.0, 4.0]:
        features.append(ndimage.gaussian_filter(image_2d, sigma))

    # Local statistics over a window_size x window_size neighbourhood
    window_size = 5
    kernel = np.ones((window_size, window_size)) / (window_size**2)

    # Local mean
    features.append(ndimage.convolve(image_2d, kernel, mode='reflect'))

    # Local mean of squared deviation from the slice-wide mean
    mean = np.mean(image_2d)
    features.append(ndimage.convolve((image_2d - mean)**2, kernel, mode='reflect'))

    # First-order Sobel gradients (axis 1 = x, axis 0 = y)
    gx = ndimage.sobel(image_2d, axis=1, mode='reflect')
    gy = ndimage.sobel(image_2d, axis=0, mode='reflect')
    features.append(np.sqrt(gx**2 + gy**2))

    # Second-order gradients
    gxx = ndimage.sobel(gx, axis=1, mode='reflect')
    gyy = ndimage.sobel(gy, axis=0, mode='reflect')

    # Laplacian (sum of the pure second derivatives)
    features.append(gxx + gyy)

    # Hessian determinant: gxx * gyy minus the product of the cross derivatives
    gxy = ndimage.sobel(gx, axis=0, mode='reflect')
    gyx = ndimage.sobel(gy, axis=1, mode='reflect')
    features.append(gxx * gyy - gxy * gyx)

    # Verify shapes; lower-rank features get leading axes added. The adjusted
    # array starts from `feat` on every path, so a same-rank mismatch reaches
    # the ValueError instead of an unbound variable.
    for i, feat in enumerate(features):
        if feat.shape == original_shape:
            continue
        feat_adjusted = feat
        for _ in range(len(original_shape) - len(feat.shape)):
            feat_adjusted = np.expand_dims(feat_adjusted, axis=0)
        if feat_adjusted.shape != original_shape:
            raise ValueError(f"Feature {i} has shape {feat.shape}, expected {original_shape}")
        features[i] = feat_adjusted

    return np.stack(features, axis=-1)
302
+
128
303
  def compute_feature_maps(self):
129
304
  """Compute all feature maps using GPU acceleration"""
130
305
  #if not self.use_gpu:
@@ -132,7 +307,10 @@ class InteractiveSegmenter:
132
307
 
133
308
  features = []
134
309
  image = self.image_gpu
310
+ image_3d = self.image_3d
135
311
  original_shape = self.image_3d.shape
312
+
313
+
136
314
 
137
315
  # Gaussian smoothing at different scales
138
316
  print("Obtaining gaussians")
@@ -150,7 +328,7 @@ class InteractiveSegmenter:
150
328
  features.append(dog)
151
329
 
152
330
  # Convert image to PyTorch tensor for gradient operations
153
- image_torch = torch.from_numpy(self.image_3d).cuda()
331
+ image_torch = torch.from_numpy(image_3d).cuda()
154
332
  image_torch = image_torch.float().unsqueeze(0).unsqueeze(0)
155
333
 
156
334
  # Calculate required padding
@@ -159,7 +337,7 @@ class InteractiveSegmenter:
159
337
 
160
338
  # Create a single padded version with same padding
161
339
  pad = torch.nn.functional.pad(image_torch, (padding, padding, padding, padding, padding, padding), mode='replicate')
162
-
340
+
163
341
  print("Computing sobel kernels")
164
342
 
165
343
  # Create sobel kernels
@@ -181,6 +359,8 @@ class InteractiveSegmenter:
181
359
  gradient_feature = gradient_magnitude.cpu().numpy().squeeze()
182
360
 
183
361
  features.append(gradient_feature)
362
+
363
+ print(features.shape)
184
364
 
185
365
  # Verify shapes
186
366
  for i, feat in enumerate(features):
@@ -194,6 +374,122 @@ class InteractiveSegmenter:
194
374
 
195
375
  return np.stack(features, axis=-1)
196
376
 
377
def compute_feature_maps_2d(self, z=None):
    """Compute the fast 2D feature stack for one z-slice using the GPU.

    Gaussian smoothings and differences of Gaussians are computed with
    ``self.gaussian_filter_gpu`` on ``self.image_gpu[z]`` and copied back with
    ``cp.asnumpy``. The gradient magnitude is computed with PyTorch on CUDA
    from ``self.image_3d[z]``.

    Args:
        z: Slice index into ``self.image_gpu`` / ``self.image_3d``.

    Returns:
        ``np.ndarray`` of shape ``(Y, X, 7)`` stacking, in order: Gaussian
        smoothings at sigma 0.5, 1, 2 and 4; differences of Gaussians
        (1 - 2) and (2 - 4); and the central-difference gradient magnitude.

    Raises:
        ValueError: if a feature cannot be reconciled with the slice shape.
    """
    features = []

    image = self.image_gpu[z, :, :]
    image_2d = self.image_3d[z, :, :]
    original_shape = image_2d.shape
    height, width = original_shape

    # Gaussian smoothing at different scales
    print("Obtaining gaussians")
    for sigma in [0.5, 1.0, 2.0, 4.0]:
        smooth = cp.asnumpy(self.gaussian_filter_gpu(image, sigma))
        features.append(smooth)

    print("Obtaining diff of gaussians")
    # Difference of Gaussians
    for (s1, s2) in [(1, 2), (2, 4)]:
        g1 = self.gaussian_filter_gpu(image, s1)
        g2 = self.gaussian_filter_gpu(image, s2)
        features.append(cp.asnumpy(g1 - g2))

    # torch.from_numpy rejects some integer dtypes (e.g. uint16). Convert to
    # float32 up front, which the later `.float()` cast produced anyway.
    image_torch = torch.from_numpy(np.ascontiguousarray(image_2d, dtype=np.float32)).cuda()
    image_torch = image_torch.unsqueeze(0).unsqueeze(0)  # (1, 1, Y, X) for conv2d

    # Calculate required padding
    kernel_size = 3
    padding = kernel_size // 2

    # One replicate-padded copy shared by both directional kernels
    pad = torch.nn.functional.pad(image_torch, (padding, padding, padding, padding), mode='replicate')

    print("Computing sobel kernels")
    # 1-D central-difference kernels along x and y
    sobel_x = torch.tensor([-1, 0, 1], device='cuda').float().view(1, 1, 1, 3)
    sobel_y = torch.tensor([-1, 0, 1], device='cuda').float().view(1, 1, 3, 1)

    print("Computing gradients")
    # Each 1-D kernel consumes the padding only along its own axis. The other
    # axis keeps its padding and must be cropped symmetrically, or the
    # gradient is shifted by one pixel relative to the input.
    gx = torch.nn.functional.conv2d(pad, sobel_x, padding=0)[:, :, padding:padding + height, :]
    gy = torch.nn.functional.conv2d(pad, sobel_y, padding=0)[:, :, :, padding:padding + width]

    # Compute gradient magnitude (no z component in 2D)
    print("Computing gradient mags")
    gradient_magnitude = torch.sqrt(gx**2 + gy**2)
    features.append(gradient_magnitude.cpu().numpy().squeeze())

    # Verify shapes; lower-rank features get leading axes added. The adjusted
    # array starts from `feat` on every path, so a same-rank mismatch reaches
    # the ValueError instead of an unbound variable.
    for i, feat in enumerate(features):
        if feat.shape == original_shape:
            continue
        feat_adjusted = feat
        for _ in range(len(original_shape) - len(feat.shape)):
            feat_adjusted = np.expand_dims(feat_adjusted, axis=0)
        if feat_adjusted.shape != original_shape:
            raise ValueError(f"Feature {i} has shape {feat.shape}, expected {original_shape}")
        features[i] = feat_adjusted

    return np.stack(features, axis=-1)
446
+
447
def compute_feature_maps_cpu_2d(self, z = None):
    """Compute the fast 2D feature stack for one z-slice on the CPU.

    Args:
        z: Index of the slice of ``self.image_3d`` to featurize.
            NOTE(review): the default ``None`` does not select a slice
            (``image_3d[None]`` inserts an axis); the visible callers always
            pass ``z``.

    Returns:
        ``np.ndarray`` of shape ``(Y, X, 7)`` stacking, in order: Gaussian
        smoothings at sigma 0.5, 1, 2 and 4; differences of Gaussians
        (1 - 2) and (2 - 4); and the Sobel gradient magnitude.

    Raises:
        ValueError: if a feature cannot be reconciled with the slice shape.
    """
    features = []

    image_2d = self.image_3d[z, :, :]
    # scipy.ndimage filters write into an array of the input dtype. On
    # integer images the Gaussians truncate, and the DoG differences and
    # signed Sobel derivatives wrap around, so compute in floating point.
    if not np.issubdtype(image_2d.dtype, np.floating):
        image_2d = image_2d.astype(np.float32)
    original_shape = image_2d.shape

    # Gaussian smoothing at different scales
    for sigma in [0.5, 1.0, 2.0, 4.0]:
        features.append(ndimage.gaussian_filter(image_2d, sigma))

    # Difference of Gaussians
    for (s1, s2) in [(1, 2), (2, 4)]:
        g1 = ndimage.gaussian_filter(image_2d, s1)
        g2 = ndimage.gaussian_filter(image_2d, s2)
        features.append(g1 - g2)

    # Sobel gradients (axis 1 = x, axis 0 = y); no z component in 2D
    gx = ndimage.sobel(image_2d, axis=1, mode='reflect')
    gy = ndimage.sobel(image_2d, axis=0, mode='reflect')
    features.append(np.sqrt(gx**2 + gy**2))

    # Verify shapes; lower-rank features get leading axes added. The adjusted
    # array starts from `feat` on every path, so a same-rank mismatch reaches
    # the ValueError instead of an unbound variable.
    for i, feat in enumerate(features):
        if feat.shape == original_shape:
            continue
        feat_adjusted = feat
        for _ in range(len(original_shape) - len(feat.shape)):
            feat_adjusted = np.expand_dims(feat_adjusted, axis=0)
        if feat_adjusted.shape != original_shape:
            raise ValueError(f"Feature {i} has shape {feat.shape}, expected {original_shape}")
        features[i] = feat_adjusted

    return np.stack(features, axis=-1)
492
+
197
493
  def compute_feature_maps_cpu(self):
198
494
  """Compute feature maps using CPU"""
199
495
  features = []
@@ -386,25 +682,112 @@ class InteractiveSegmenter:
386
682
 
387
683
  return foreground, background
388
684
 
685
def organize_by_z(self, coordinates):
    """
    Group ``(z, y, x)`` coordinates by their z value.

    Args:
        coordinates: Iterable of 3-element ``(z, y, x)`` sequences.

    Returns:
        Dict mapping each z (in order of first appearance) to the list of
        ``(y, x)`` tuples that share it, kept in input order. The pairs are
        tuples so they stay hashable.
    """
    grouped = {}
    for z_val, y_val, x_val in coordinates:
        grouped.setdefault(z_val, []).append((y_val, x_val))
    return grouped
707
+
389
708
  def process_chunk(self, chunk_coords):
390
709
  """Process a chunk of coordinates"""
391
- features = [self.feature_cache[z, y, x] for z, y, x in chunk_coords]
392
- predictions = self.model.predict(features)
393
-
710
+
394
711
  foreground = set()
395
712
  background = set()
396
- for coord, pred in zip(chunk_coords, predictions):
397
- if pred:
398
- foreground.add(coord)
399
- else:
400
- background.add(coord)
713
+
714
+ if not self.use_two:
715
+
716
+
717
+ features = [self.feature_cache[z, y, x] for z, y, x in chunk_coords]
718
+ predictions = self.model.predict(features)
719
+
720
+ for coord, pred in zip(chunk_coords, predictions):
721
+ if pred:
722
+ foreground.add(coord)
723
+ else:
724
+ background.add(coord)
725
+
726
+ else:
727
+ chunk_by_z = self.organize_by_z(chunk_coords)
728
+ for z, coords in chunk_by_z.items():
729
+
730
+ if z not in self.feature_cache and not self.previewing:
731
+ features = self.get_feature_map_slice(z, self.speed, self.cur_gpu)
732
+ features = [features[y, x] for y, x in coords]
733
+ elif z not in self.feature_cache and self.previewing:
734
+ features = self.map_slice
735
+ try:
736
+ features = [features[y, x] for y, x in coords]
737
+ except:
738
+ return [], []
739
+ else:
740
+ features = [self.feature_cache[z][y, x] for y, x in coords]
741
+
742
+ predictions = self.model.predict(features)
743
+
744
+ for (y, x), pred in zip(coords, predictions):
745
+ coord = (z, y, x) # Reconstruct the 3D coordinate as a tuple
746
+ if pred:
747
+ foreground.add(coord)
748
+ else:
749
+ background.add(coord)
401
750
 
402
751
  return foreground, background
403
752
 
753
+
754
+
404
755
  def segment_volume(self, chunk_size=None, gpu=False):
405
756
  """Segment volume using parallel processing of chunks with vectorized chunk creation"""
406
757
  #Change the above chunk size to None to have it auto-compute largest chunks (not sure which is faster, 64 seems reasonable in test cases)
407
758
 
759
def create_2d_chunks():
    """
    Create one chunk per z-slice for 2D processing.

    Each chunk lists every ``(z, y, x)`` coordinate of one slice of
    ``self.image_3d``, in row-major (y, then x) order.

    Returns:
        List of chunks (one per z-slice), each a list of coordinate tuples.
    """
    n_slices = self.image_3d.shape[0]
    height = self.image_3d.shape[1]
    width = self.image_3d.shape[2]
    coords_per_slice = height * width

    # The (y, x) grid is the same for every slice, so build it once instead
    # of once per z.
    y_coords, x_coords = np.meshgrid(
        np.arange(height),
        np.arange(width),
        indexing='ij'
    )
    y_flat = y_coords.ravel()
    x_flat = x_coords.ravel()

    chunks = []
    for z in range(n_slices):
        slice_coords = np.column_stack((
            np.full(coords_per_slice, z),
            y_flat,
            x_flat
        ))
        # Convert to list of tuples
        chunks.append(list(map(tuple, slice_coords)))

    return chunks
408
791
  try:
409
792
  from cuml.ensemble import RandomForestClassifier as cuRandomForestClassifier
410
793
  except:
@@ -418,52 +801,55 @@ class InteractiveSegmenter:
418
801
 
419
802
  print("Chunking data...")
420
803
 
421
- # Determine optimal chunk size based on number of cores if not specified
422
- if chunk_size is None:
423
- total_cores = multiprocessing.cpu_count()
424
-
425
- # Calculate total volume and target volume per core
426
- total_volume = np.prod(self.image_3d.shape)
427
- target_volume_per_chunk = total_volume / total_cores
428
-
429
- # Calculate chunk size that would give us roughly one chunk per core
430
- # Using cube root since we want roughly equal sizes in all dimensions
431
- chunk_size = int(np.cbrt(target_volume_per_chunk))
432
-
433
- # Ensure chunk size is at least 32 (minimum reasonable size) and not larger than smallest dimension
434
- chunk_size = max(32, min(chunk_size, min(self.image_3d.shape)))
804
+ if not self.use_two:
805
+ # Determine optimal chunk size based on number of cores if not specified
806
+ if chunk_size is None:
807
+ total_cores = multiprocessing.cpu_count()
808
+
809
+ # Calculate total volume and target volume per core
810
+ total_volume = np.prod(self.image_3d.shape)
811
+ target_volume_per_chunk = total_volume / total_cores
812
+
813
+ # Calculate chunk size that would give us roughly one chunk per core
814
+ # Using cube root since we want roughly equal sizes in all dimensions
815
+ chunk_size = int(np.cbrt(target_volume_per_chunk))
816
+
817
+ # Ensure chunk size is at least 32 (minimum reasonable size) and not larger than smallest dimension
818
+ chunk_size = max(32, min(chunk_size, min(self.image_3d.shape)))
819
+
820
+ # Round to nearest multiple of 32 for better memory alignment
821
+ chunk_size = ((chunk_size + 15) // 32) * 32
435
822
 
436
- # Round to nearest multiple of 32 for better memory alignment
437
- chunk_size = ((chunk_size + 15) // 32) * 32
438
-
439
- # Calculate number of chunks in each dimension
440
- z_chunks = (self.image_3d.shape[0] + chunk_size - 1) // chunk_size
441
- y_chunks = (self.image_3d.shape[1] + chunk_size - 1) // chunk_size
442
- x_chunks = (self.image_3d.shape[2] + chunk_size - 1) // chunk_size
443
-
444
- # Create start indices for all chunks at once
445
- chunk_starts = np.array(np.meshgrid(
446
- np.arange(z_chunks) * chunk_size,
447
- np.arange(y_chunks) * chunk_size,
448
- np.arange(x_chunks) * chunk_size,
449
- indexing='ij'
450
- )).reshape(3, -1).T
451
-
452
- chunks = []
453
- for z_start, y_start, x_start in chunk_starts:
454
- z_end = min(z_start + chunk_size, self.image_3d.shape[0])
455
- y_end = min(y_start + chunk_size, self.image_3d.shape[1])
456
- x_end = min(x_start + chunk_size, self.image_3d.shape[2])
823
+ # Calculate number of chunks in each dimension
824
+ z_chunks = (self.image_3d.shape[0] + chunk_size - 1) // chunk_size
825
+ y_chunks = (self.image_3d.shape[1] + chunk_size - 1) // chunk_size
826
+ x_chunks = (self.image_3d.shape[2] + chunk_size - 1) // chunk_size
457
827
 
458
- # Create coordinates for this chunk efficiently
459
- coords = np.stack(np.meshgrid(
460
- np.arange(z_start, z_end),
461
- np.arange(y_start, y_end),
462
- np.arange(x_start, x_end),
828
+ # Create start indices for all chunks at once
829
+ chunk_starts = np.array(np.meshgrid(
830
+ np.arange(z_chunks) * chunk_size,
831
+ np.arange(y_chunks) * chunk_size,
832
+ np.arange(x_chunks) * chunk_size,
463
833
  indexing='ij'
464
834
  )).reshape(3, -1).T
465
835
 
466
- chunks.append(list(map(tuple, coords)))
836
+ chunks = []
837
+ for z_start, y_start, x_start in chunk_starts:
838
+ z_end = min(z_start + chunk_size, self.image_3d.shape[0])
839
+ y_end = min(y_start + chunk_size, self.image_3d.shape[1])
840
+ x_end = min(x_start + chunk_size, self.image_3d.shape[2])
841
+
842
+ # Create coordinates for this chunk efficiently
843
+ coords = np.stack(np.meshgrid(
844
+ np.arange(z_start, z_end),
845
+ np.arange(y_start, y_end),
846
+ np.arange(x_start, x_end),
847
+ indexing='ij'
848
+ )).reshape(3, -1).T
849
+
850
+ chunks.append(list(map(tuple, coords)))
851
+ else:
852
+ chunks = create_2d_chunks()
467
853
 
468
854
  foreground_coords = set()
469
855
  background_coords = set()
@@ -471,7 +857,7 @@ class InteractiveSegmenter:
471
857
  print("Segmenting chunks...")
472
858
 
473
859
 
474
- with ThreadPoolExecutor() as executor:
860
+ with ThreadPoolExecutor(max_workers=multiprocessing.cpu_count()) as executor:
475
861
  if gpu:
476
862
  try:
477
863
  futures = [executor.submit(self.process_chunk_GPU, chunk) for chunk in chunks]
@@ -491,11 +877,39 @@ class InteractiveSegmenter:
491
877
  return foreground_coords, background_coords
492
878
 
493
879
  def update_position(self, z=None, x=None, y=None):
494
- """Update current position for chunk prioritization"""
880
+ """Update current position for chunk prioritization with safeguards"""
881
+
882
+ # Check if we should skip this update
883
+ if hasattr(self, '_skip_next_update') and self._skip_next_update:
884
+ self._skip_next_update = False
885
+ return
886
+
887
+ # Store the previous z-position if not set
888
+ if not hasattr(self, 'prev_z') or self.prev_z is None:
889
+ self.prev_z = z
890
+
891
+ # Check if currently processing - if so, only update position but don't trigger map_slice changes
892
+ if hasattr(self, '_currently_processing') and self._currently_processing:
893
+ self.current_z = z
894
+ self.current_x = x
895
+ self.current_y = y
896
+ self.prev_z = z
897
+ return
898
+
899
+ # Update current positions
495
900
  self.current_z = z
496
901
  self.current_x = x
497
902
  self.current_y = y
498
-
903
+
904
+ # Only clear map_slice if z changes and we're not already generating a new one
905
+ if self.current_z != self.prev_z:
906
+ # Instead of setting to None, check if we already have it in the cache
907
+ if hasattr(self, 'feature_cache') and self.current_z not in self.feature_cache:
908
+ self.map_slice = None
909
+ self._currently_segmenting = None
910
+
911
+ # Update previous z
912
+ self.prev_z = z
499
913
 
500
914
  def get_realtime_chunks(self, chunk_size = 64):
501
915
  print("Computing some overhead...")
@@ -572,6 +986,7 @@ class InteractiveSegmenter:
572
986
  gpu = False
573
987
 
574
988
 
989
+
575
990
  if self.realtimechunks is None:
576
991
  self.get_realtime_chunks()
577
992
  else:
@@ -658,55 +1073,141 @@ class InteractiveSegmenter:
658
1073
  del futures[future]
659
1074
  yield fore, back
660
1075
 
1076
+
661
1077
  def cleanup(self):
662
1078
  """Clean up GPU memory"""
663
1079
  if self.use_gpu:
664
1080
  cp.get_default_memory_pool().free_all_blocks()
665
1081
  torch.cuda.empty_cache()
666
1082
 
667
def train_batch(self, foreground_array, speed = True, use_gpu = False, use_two = False):
    """Train ``self.model`` directly on a label volume.

    Args:
        foreground_array: Label array indexed as ``(z, y, x)``. Voxels equal
            to 1 are foreground examples and voxels equal to 2 are background
            examples; other values are ignored. ``None`` means "no examples":
            features are still prepared, but nothing is trained.
        speed: True selects the fast feature set, False the "deep" one.
        use_gpu: Selects the GPU feature routines in 3D mode; passed to
            ``get_feature_map_slice`` in 2D mode.
        use_two: True trains on per-slice 2D feature maps kept in the
            ``self.feature_cache`` dict; False uses one dense 3D feature
            volume for the whole image.

    Side effects:
        Updates ``speed``, ``cur_gpu``, ``use_two``, ``two_slices``,
        ``feature_cache`` and ``current_speed``. Fits ``self.model`` when at
        least one voxel is labelled.
    """

    def _slice_features(zs, ys, xs, maps, n_features):
        # One feature vector per labelled voxel, in np.where order. The
        # explicit reshape keeps an empty selection at (0, n_features) so it
        # still stacks with the other class.
        rows = [maps[z][y, x] for z, y, x in zip(zs, ys, xs)]
        return np.asarray(rows).reshape(len(rows), n_features)

    self.speed = speed
    self.cur_gpu = use_gpu

    # Cached features were computed for the other feature set.
    if self.current_speed != speed:
        self.feature_cache = None

    if foreground_array is None:
        no_voxels = np.empty(0, dtype=np.intp)
        z_fore = y_fore = x_fore = z_back = y_back = x_back = no_voxels
    else:
        # Foreground (label 1) and background (label 2) coordinates
        z_fore, y_fore, x_fore = np.where(foreground_array == 1)
        z_back, y_back, x_back = np.where(foreground_array == 2)

    has_examples = len(z_fore) + len(z_back) > 0

    if use_two:
        if not self.use_two:
            # Leaving 3D mode: a dense 3D cache cannot serve per-slice lookups.
            self.feature_cache = None
            self.use_two = True

        if self.feature_cache is None:
            self.feature_cache = {}

        maps = {}
        for z in np.unique(np.concatenate([z_fore, z_back])):
            # Check the cache itself rather than two_slices. The cache may
            # have been reset (e.g. by a speed change) while two_slices was
            # not. Each labelled slice's map is computed at most once per call.
            feature_map = self.feature_cache.get(z)
            if feature_map is None:
                feature_map = self.get_feature_map_slice(z, speed, use_gpu)
                if feature_map is not None:
                    self.feature_cache[z] = feature_map
            maps[z] = feature_map
            if z not in self.two_slices:
                self.two_slices.append(z)  # Tracks slices that have feature maps

        if has_examples:
            n_features = next(iter(maps.values())).shape[-1]
            foreground_features = _slice_features(z_fore, y_fore, x_fore, maps, n_features)
            background_features = _slice_features(z_back, y_back, x_back, maps, n_features)

    else:
        self.two_slices = []

        if self.use_two:
            # Leaving 2D mode: the per-slice dict cannot be indexed as (z, y, x).
            self.feature_cache = None
            self.use_two = False

        if self.feature_cache is None:
            with self.lock:
                # Re-check under the lock in case another thread filled it.
                if self.feature_cache is None:
                    if speed:
                        if use_gpu:
                            self.feature_cache = self.compute_feature_maps()
                        else:
                            self.feature_cache = self.compute_feature_maps_cpu()
                    else:
                        if use_gpu:
                            self.feature_cache = self.compute_deep_feature_maps()
                        else:
                            self.feature_cache = self.compute_deep_feature_maps_cpu()

        foreground_features = self.feature_cache[z_fore, y_fore, x_fore]
        background_features = self.feature_cache[z_back, y_back, x_back]

    # The cache now holds features for `speed`, whether or not we train below.
    self.current_speed = speed

    if not has_examples:
        print("Features maps computed, but no segmentation examples were provided so the model was not trained")
    else:
        # Combine features and labels, then train the model
        X = np.vstack([foreground_features, background_features])
        labels = np.hstack([np.ones(len(z_fore)), np.zeros(len(z_back))])
        self.model.fit(X, labels)

    print("Done")
711
1197
 
1198
def get_feature_map_slice(self, z, speed, use_gpu):
    """Compute the 2D feature map for slice ``z`` on the CPU.

    Args:
        z: Slice index.
        speed: Truthy selects ``compute_feature_maps_cpu_2d`` (fast
            features); falsy selects ``compute_deep_feature_maps_cpu_2d``.
        use_gpu: Accepted for signature symmetry; not used by this method.

    Returns:
        The selected routine's feature map, or ``None`` (without computing
        anything) while ``self._currently_segmenting`` is set.
    """
    if self._currently_segmenting is not None:
        return None

    # Computation is serialized through the shared instance lock.
    with self.lock:
        compute = (self.compute_feature_maps_cpu_2d if speed
                   else self.compute_deep_feature_maps_cpu_2d)
        return compute(z=z)
712
1213