napari-tmidas 0.2.0__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,13 +1,17 @@
  """
  Batch Crop Anything - A Napari plugin for interactive image cropping

- This plugin combines Segment Anything Model (SAM) for automatic object detection with
+ This plugin combines SAM2 for automatic object detection with
  an interactive interface for selecting and cropping objects from images.
+ The plugin supports both 2D (YX) and 3D (TYX/ZYX) data.
  """

+ import contextlib
  import os
+ import sys

  import numpy as np
+ import requests
  import torch
  from magicgui import magicgui
  from napari.layers import Labels
@@ -22,34 +26,73 @@ from qtpy.QtWidgets import (
      QMessageBox,
      QPushButton,
      QScrollArea,
-     QSlider,
      QTableWidget,
      QTableWidgetItem,
      QVBoxLayout,
      QWidget,
  )
  from skimage.io import imread
- from skimage.transform import resize  # Added import for resize function
+ from skimage.transform import resize
  from tifffile import imwrite

+ from napari_tmidas.processing_functions.sam2_mp4 import tif_to_mp4
+
+ sam2_paths = [
+     os.environ.get("SAM2_PATH"),
+     "/opt/sam2",
+     os.path.expanduser("~/sam2"),
+     "./sam2",
+ ]
+
+ for path in sam2_paths:
+     if path and os.path.exists(path):
+         sys.path.append(path)
+         break
+ else:
+     print(
+         "Warning: SAM2 not found in common locations. Please set SAM2_PATH environment variable."
+     )
+
+
+ def get_device():
+     if sys.platform == "darwin":
+         # MacOS: Only check for MPS
+         if (
+             hasattr(torch.backends, "mps")
+             and torch.backends.mps.is_available()
+         ):
+             device = torch.device("mps")
+             print("Using Apple Silicon GPU (MPS)")
+         else:
+             device = torch.device("cpu")
+             print("Using CPU")
+     else:
+         # Other platforms: check for CUDA, then CPU
+         if hasattr(torch, "cuda") and torch.cuda.is_available():
+             device = torch.device("cuda")
+             print(f"Using CUDA GPU: {torch.cuda.get_device_name()}")
+         else:
+             device = torch.device("cpu")
+             print("Using CPU")
+     return device
+

  class BatchCropAnything:
-     """
-     Class for processing images with Segment Anything and cropping selected objects.
-     """
+     """Class for processing images with SAM2 and cropping selected objects."""

-     def __init__(self, viewer: Viewer):
+     def __init__(self, viewer: Viewer, use_3d=False):
          """Initialize the BatchCropAnything processor."""
          # Core components
          self.viewer = viewer
          self.images = []
          self.current_index = 0
+         self.use_3d = use_3d

          # Image and segmentation data
          self.original_image = None
          self.segmentation_result = None
          self.current_image_for_segmentation = None
-         self.current_scale_factor = 1.0  # Added scale factor tracking
+         self.current_scale_factor = 1.0

          # UI references
          self.image_layer = None
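The module-level get_device() helper added above prefers Apple's MPS backend on macOS and CUDA elsewhere, falling back to CPU in both cases. A minimal usage sketch under the same assumptions (the config and checkpoint paths come from the initialization code later in this diff; build_sam2 accepts a device argument just like the build_sam2_video_predictor call below):

    from sam2.build_sam import build_sam2

    device = get_device()  # torch.device("mps"), ("cuda") or ("cpu")
    model = build_sam2(
        "configs/sam2.1/sam2.1_hiera_l.yaml",
        "/opt/sam2/checkpoints/sam2.1_hiera_large.pt",
        device=device,
    )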
@@ -63,101 +106,73 @@ class BatchCropAnything:
          # Segmentation parameters
          self.sensitivity = 50  # Default sensitivity (0-100 scale)

-         # Initialize the SAM model
-         self._initialize_sam()
-
-     # --------------------------------------------------
-     # Model Initialization
-     # --------------------------------------------------
-
-     def _initialize_sam(self):
-         """Initialize the Segment Anything Model."""
-         try:
-             # Import required modules
-             from mobile_sam import (
-                 SamAutomaticMaskGenerator,
-                 sam_model_registry,
-             )
+         # Initialize the SAM2 model
+         self._initialize_sam2()

-             # Setup device
-             self.device = "cuda" if torch.cuda.is_available() else "cpu"
-             model_type = "vit_t"
+     def _initialize_sam2(self):
+         """Initialize the SAM2 model based on dimension mode."""

-             # Find the model weights file
-             checkpoint_path = self._find_sam_checkpoint()
-             if checkpoint_path is None:
-                 self.mobile_sam = None
-                 self.mask_generator = None
-                 return
+         def download_checkpoint(url, dest_folder):
+             import os

-             # Initialize the model
-             self.mobile_sam = sam_model_registry[model_type](
-                 checkpoint=checkpoint_path
-             )
-             self.mobile_sam.to(device=self.device)
-             self.mobile_sam.eval()
+             os.makedirs(dest_folder, exist_ok=True)
+             filename = os.path.join(dest_folder, url.split("/")[-1])
+             if not os.path.exists(filename):
+                 print(f"Downloading checkpoint to {filename}...")
+                 response = requests.get(url, stream=True, timeout=30)
+                 response.raise_for_status()
+                 with open(filename, "wb") as f:
+                     for chunk in response.iter_content(chunk_size=8192):
+                         f.write(chunk)
+                 print("Download complete.")
+             else:
+                 print(f"Checkpoint already exists at {filename}.")
+             return filename

-             # Create mask generator with default parameters
-             self.mask_generator = SamAutomaticMaskGenerator(self.mobile_sam)
-             self.viewer.status = f"Initialized SAM model from {checkpoint_path} on {self.device}"
+         try:
+             # import torch

-         except (ImportError, Exception) as e:
-             self.viewer.status = f"Error initializing SAM: {str(e)}"
-             self.mobile_sam = None
-             self.mask_generator = None
+             self.device = get_device()

-     def _find_sam_checkpoint(self):
-         """Find the SAM model checkpoint file."""
-         try:
-             import importlib.util
-
-             # Find the mobile_sam package location
-             mobile_sam_spec = importlib.util.find_spec("mobile_sam")
-             if mobile_sam_spec is None:
-                 raise ImportError("mobile_sam package not found")
-
-             mobile_sam_path = os.path.dirname(mobile_sam_spec.origin)
-
-             # Check common locations for the model file
-             checkpoint_paths = [
-                 os.path.join(mobile_sam_path, "weights", "mobile_sam.pt"),
-                 os.path.join(mobile_sam_path, "mobile_sam.pt"),
-                 os.path.join(
-                     os.path.dirname(mobile_sam_path),
-                     "weights",
-                     "mobile_sam.pt",
-                 ),
-                 os.path.join(
-                     os.path.expanduser("~"), "models", "mobile_sam.pt"
-                 ),
-                 "/opt/T-MIDAS/models/mobile_sam.pt",
-                 os.path.join(os.getcwd(), "mobile_sam.pt"),
-             ]
-
-             for path in checkpoint_paths:
-                 if os.path.exists(path):
-                     return path
-
-             # If model not found, ask user
-             QMessageBox.information(
-                 None,
-                 "Model Not Found",
-                 "Mobile-SAM model weights not found. Please select the mobile_sam.pt file.",
+             # Download checkpoint if needed
+             checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt"
+             checkpoint_path = download_checkpoint(
+                 checkpoint_url, "/opt/sam2/checkpoints/"
              )
+             model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"

-             checkpoint_path, _ = QFileDialog.getOpenFileName(
-                 None, "Select Mobile-SAM model file", "", "Model Files (*.pt)"
-             )
+             if self.use_3d:
+                 from sam2.build_sam import build_sam2_video_predictor

-             return checkpoint_path if checkpoint_path else None
+                 self.predictor = build_sam2_video_predictor(
+                     model_cfg, checkpoint_path, device=self.device
+                 )
+                 self.viewer.status = (
+                     f"Initialized SAM2 Video Predictor on {self.device}"
+                 )
+             else:
+                 from sam2.build_sam import build_sam2
+                 from sam2.sam2_image_predictor import SAM2ImagePredictor
+
+                 self.predictor = SAM2ImagePredictor(
+                     build_sam2(model_cfg, checkpoint_path)
+                 )
+                 self.viewer.status = (
+                     f"Initialized SAM2 Image Predictor on {self.device}"
+                 )

-         except (ImportError, Exception) as e:
-             self.viewer.status = f"Error finding SAM checkpoint: {str(e)}"
-             return None
+         except (
+             ImportError,
+             RuntimeError,
+             ValueError,
+             FileNotFoundError,
+             requests.RequestException,
+         ) as e:
+             import traceback

-     # --------------------------------------------------
-     # Image Loading and Navigation
-     # --------------------------------------------------
+             self.viewer.status = f"Error initializing SAM2: {str(e)}"
+             self.predictor = None
+             print(traceback.format_exc())

      def load_images(self, folder_path: str):
          """Load images from the specified folder path."""
@@ -169,17 +184,19 @@ class BatchCropAnything:
          self.images = [
              os.path.join(folder_path, file)
              for file in files
-             if file.lower().endswith(
-                 (".tif", ".tiff", ".png", ".jpg", ".jpeg")
-             )
-             and not file.endswith(("_labels.tif", "_cropped.tif", "_cropped_"))
+             if file.lower().endswith(".tif")
+             or file.lower().endswith(".tiff")
+             and "label" not in file.lower()
+             and "cropped" not in file.lower()
+             and "_labels_" not in file.lower()
+             and "_cropped_" not in file.lower()
          ]

          if not self.images:
              self.viewer.status = "No compatible images found in the folder."
              return

-         self.viewer.status = f"Found {len(self.images)} images."
+         self.viewer.status = f"Found {len(self.images)} .tif images."
          self.current_index = 0
          self._load_current_image()

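One subtlety in the rewritten filter: Python parses `A or B and C` as `A or (B and C)`, so the `label`/`cropped` exclusions only guard the `.tiff` branch, and a plain `.tif` file whose name contains `label` or `cropped` still passes. A fully parenthesized sketch of the (assumed) intended filter:

    self.images = [
        os.path.join(folder_path, file)
        for file in files
        if file.lower().endswith((".tif", ".tiff"))  # both extensions
        and "label" not in file.lower()              # exclusions apply to both
        and "cropped" not in file.lower()
    ]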
@@ -237,9 +254,9 @@ class BatchCropAnything:
              self.viewer.status = "No images to process."
              return

-         if self.mobile_sam is None or self.mask_generator is None:
+         if self.predictor is None:
              self.viewer.status = (
-                 "SAM model not initialized. Cannot segment images."
+                 "SAM2 model not initialized. Cannot segment images."
              )
              return

@@ -253,66 +270,147 @@ class BatchCropAnything:
              # Load and process image
              self.original_image = imread(image_path)

-             # Ensure image is 8-bit for SAM display (keeping original for saving)
-             if self.original_image.dtype != np.uint8:
-                 image_for_display = (
-                     self.original_image / np.amax(self.original_image) * 255
-                 ).astype(np.uint8)
+             # For 3D/4D data, determine dimensions
+             if self.use_3d and len(self.original_image.shape) >= 3:
+                 # Check shape to identify dimensions
+                 if len(self.original_image.shape) == 4:  # TZYX or similar
+                     # Identify time dimension as first dim with size > 4 and < 400
+                     # This is a heuristic to differentiate time from channels/small Z stacks
+                     time_dim_idx = -1
+                     for i, dim_size in enumerate(self.original_image.shape):
+                         if 4 < dim_size < 400:
+                             time_dim_idx = i
+                             break
+
+                     if time_dim_idx == 0:  # TZYX format
+                         # Keep as is, T is already the first dimension
+                         self.image_layer = self.viewer.add_image(
+                             self.original_image,
+                             name=f"Image ({os.path.basename(image_path)})",
+                         )
+                         # Store time dimension info
+                         self.time_dim_size = self.original_image.shape[0]
+                         self.has_z_dim = True
+                     elif (
+                         time_dim_idx > 0
+                     ):  # Unusual format, we need to transpose
+                         # Transpose to move T to first dimension
+                         # Create permutation order that puts time_dim_idx first
+                         perm_order = list(
+                             range(len(self.original_image.shape))
+                         )
+                         perm_order.remove(time_dim_idx)
+                         perm_order.insert(0, time_dim_idx)
+
+                         transposed_image = np.transpose(
+                             self.original_image, perm_order
+                         )
+                         self.original_image = (
+                             transposed_image  # Replace with transposed version
+                         )
+
+                         self.image_layer = self.viewer.add_image(
+                             self.original_image,
+                             name=f"Image ({os.path.basename(image_path)})",
+                         )
+                         # Store time dimension info
+                         self.time_dim_size = self.original_image.shape[0]
+                         self.has_z_dim = True
+                     else:
+                         # No time dimension found, treat as ZYX
+                         self.image_layer = self.viewer.add_image(
+                             self.original_image,
+                             name=f"Image ({os.path.basename(image_path)})",
+                         )
+                         self.time_dim_size = 1
+                         self.has_z_dim = True
+                 elif (
+                     len(self.original_image.shape) == 3
+                 ):  # Could be TYX or ZYX
+                     # Check if first dimension is likely time (> 4, < 400)
+                     if 4 < self.original_image.shape[0] < 400:
+                         # Likely TYX format
+                         self.image_layer = self.viewer.add_image(
+                             self.original_image,
+                             name=f"Image ({os.path.basename(image_path)})",
+                         )
+                         self.time_dim_size = self.original_image.shape[0]
+                         self.has_z_dim = False
+                     else:
+                         # Likely ZYX format or another 3D format
+                         self.image_layer = self.viewer.add_image(
+                             self.original_image,
+                             name=f"Image ({os.path.basename(image_path)})",
+                         )
+                         self.time_dim_size = 1
+                         self.has_z_dim = True
+                 else:
+                     # Should not reach here with use_3d=True, but just in case
+                     self.image_layer = self.viewer.add_image(
+                         self.original_image,
+                         name=f"Image ({os.path.basename(image_path)})",
+                     )
+                     self.time_dim_size = 1
+                     self.has_z_dim = False
              else:
-                 image_for_display = self.original_image
-
-             # Add image to viewer
-             self.image_layer = self.viewer.add_image(
-                 image_for_display,
-                 name=f"Image ({os.path.basename(image_path)})",
-             )
+                 # Handle 2D data as before
+                 if self.original_image.dtype != np.uint8:
+                     image_for_display = (
+                         self.original_image
+                         / np.amax(self.original_image)
+                         * 255
+                     ).astype(np.uint8)
+                 else:
+                     image_for_display = self.original_image
+
+                 # Add image to viewer
+                 self.image_layer = self.viewer.add_image(
+                     image_for_display,
+                     name=f"Image ({os.path.basename(image_path)})",
+                 )

              # Generate segmentation
-             self._generate_segmentation(image_for_display)
+             self._generate_segmentation(self.original_image, image_path)

-         except (Exception, ValueError) as e:
+         except (FileNotFoundError, ValueError, TypeError, OSError) as e:
              import traceback

              self.viewer.status = f"Error processing image: {str(e)}"
              traceback.print_exc()
+
              # Create empty segmentation in case of error
              if (
                  hasattr(self, "original_image")
                  and self.original_image is not None
              ):
-                 self.segmentation_result = np.zeros(
-                     self.original_image.shape[:2], dtype=np.uint32
-                 )
+                 if self.use_3d:
+                     shape = self.original_image.shape
+                 else:
+                     shape = self.original_image.shape[:2]
+
+                 self.segmentation_result = np.zeros(shape, dtype=np.uint32)
                  self.label_layer = self.viewer.add_labels(
                      self.segmentation_result, name="Error: No Segmentation"
                  )

-     # --------------------------------------------------
-     # Segmentation Generation and Control
-     # --------------------------------------------------
-
-     def _generate_segmentation(self, image):
-         """Generate segmentation for the current image."""
-         # Prepare for SAM (add color channel if needed)
-         if len(image.shape) == 2:
-             image_for_sam = image[:, :, np.newaxis].repeat(3, axis=2)
-         else:
-             image_for_sam = image
-
-         # Store the current image for later regeneration if sensitivity changes
-         self.current_image_for_segmentation = image_for_sam
+     def _generate_segmentation(self, image, image_path: str):
+         """Generate segmentation for the current image using SAM2."""
+         # Store the current image for later processing
+         self.current_image_for_segmentation = image

          # Generate segmentation with current sensitivity
-         self.generate_segmentation_with_sensitivity()
+         self.generate_segmentation_with_sensitivity(image_path)

-     def generate_segmentation_with_sensitivity(self, sensitivity=None):
+     def generate_segmentation_with_sensitivity(
+         self, image_path: str, sensitivity=None
+     ):
          """Generate segmentation with the specified sensitivity."""
          if sensitivity is not None:
              self.sensitivity = sensitivity

-         if self.mobile_sam is None or self.mask_generator is None:
+         if self.predictor is None:
              self.viewer.status = (
-                 "SAM model not initialized. Cannot segment images."
+                 "SAM2 model not initialized. Cannot segment images."
              )
              return

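The dimension handling in `_load_current_image` applies one heuristic in several places: the time axis is taken to be the first axis whose size lies strictly between 4 and 400, which separates time from channel axes and small Z stacks. Factored out as a sketch (the helper name is illustrative, not part of the plugin):

    def guess_time_axis(shape, lo=4, hi=400):
        """Return the index of the first axis with lo < size < hi, or -1 if none."""
        for i, n in enumerate(shape):
            if lo < n < hi:
                return i
        return -1

    guess_time_axis((50, 12, 512, 512))  # -> 0: treated as TZYX
    guess_time_axis((512, 512, 512))     # -> -1: no time axis, treated as ZYX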
@@ -321,298 +419,740 @@ class BatchCropAnything:
              return

          try:
-             # Map sensitivity (0-100) to SAM parameters
-             # Higher sensitivity (100) = lower thresholds = more objects detected
-             # Lower sensitivity (0) = higher thresholds = fewer objects detected
+             # Map sensitivity (0-100) to SAM2 parameters
+             # For SAM2, adjust confidence threshold based on sensitivity
+             confidence_threshold = (
+                 0.9 - (self.sensitivity / 100) * 0.4
+             )  # Range from 0.9 to 0.5
+
+             # Process based on dimension mode
+             if self.use_3d:
+                 # Process 3D data
+                 self._generate_3d_segmentation(
+                     confidence_threshold, image_path
+                 )
+             else:
+                 # Process 2D data
+                 self._generate_2d_segmentation(confidence_threshold)
+
+         except (
+             ValueError,
+             RuntimeError,
+             torch.cuda.OutOfMemoryError,
+             TypeError,
+         ) as e:
+             import traceback

-             # pred_iou_thresh range: 0.92 (low sensitivity) to 0.75 (high sensitivity)
-             pred_iou = 0.92 - (self.sensitivity / 100) * 0.17
+             self.viewer.status = f"Error generating segmentation: {str(e)}"
+             traceback.print_exc()

-             # stability_score_thresh range: 0.97 (low sensitivity) to 0.85 (high sensitivity)
-             stability = 0.97 - (self.sensitivity / 100) * 0.12
+     def _generate_2d_segmentation(self, confidence_threshold):
+         """Generate 2D segmentation using SAM2 Image Predictor."""
+         # Ensure image is in the correct format for SAM2
+         image = self.current_image_for_segmentation
+
+         # Handle resizing for very large images
+         orig_shape = image.shape[:2]
+         image_mp = (orig_shape[0] * orig_shape[1]) / 1e6
+         max_mp = 2.0  # Maximum image size in megapixels
+
+         if image_mp > max_mp:
+             scale_factor = np.sqrt(max_mp / image_mp)
+             new_height = int(orig_shape[0] * scale_factor)
+             new_width = int(orig_shape[1] * scale_factor)
+
+             self.viewer.status = f"Downscaling image from {orig_shape} to {(new_height, new_width)} for processing"
+
+             # Resize image
+             resized_image = resize(
+                 image,
+                 (new_height, new_width),
+                 anti_aliasing=True,
+                 preserve_range=True,
+             ).astype(
+                 np.float32
+             )  # Convert to float32
+
+             self.current_scale_factor = scale_factor
+         else:
+             # Convert to float32 format
+             if image.dtype != np.float32:
+                 resized_image = image.astype(np.float32)
+             else:
+                 resized_image = image
+             self.current_scale_factor = 1.0
+
+         # Ensure image is in RGB format for SAM2
+         if len(resized_image.shape) == 2:
+             # Convert grayscale to RGB
+             resized_image = np.stack([resized_image] * 3, axis=-1)
+         elif len(resized_image.shape) == 3 and resized_image.shape[2] == 1:
+             # Convert single channel to RGB
+             resized_image = np.concatenate([resized_image] * 3, axis=2)
+         elif len(resized_image.shape) == 3 and resized_image.shape[2] > 3:
+             # Use first 3 channels
+             resized_image = resized_image[:, :, :3]
+
+         # Normalize the image to [0,1] range if it's not already
+         if resized_image.max() > 1.0:
+             resized_image = resized_image / 255.0
+
+         # Set SAM2 prediction parameters based on sensitivity
+         with torch.inference_mode(), torch.autocast(
+             "cuda", dtype=torch.float32
+         ):
+             # Set the image in the predictor
+             self.predictor.set_image(resized_image)
+
+             # Use automatic points generation with confidence threshold
+             masks, scores, _ = self.predictor.predict(
+                 point_coords=None,
+                 point_labels=None,
+                 box=None,
+                 multimask_output=True,
+             )

-             # min_mask_region_area range: 300 (low sensitivity) to 30 (high sensitivity)
-             min_area = 300 - (self.sensitivity / 100) * 270
+         # Filter masks by confidence threshold
+         valid_masks = scores > confidence_threshold
+         masks = masks[valid_masks]
+         scores = scores[valid_masks]

-             # Configure mask generator with sensitivity-adjusted parameters
-             self.mask_generator.pred_iou_thresh = pred_iou
-             self.mask_generator.stability_score_thresh = stability
-             self.mask_generator.min_mask_region_area = min_area
+         # Convert masks to label image
+         labels = np.zeros(resized_image.shape[:2], dtype=np.uint32)
+         self.label_info = {}  # Reset label info

-             # Apply gamma correction based on sensitivity
-             # Low sensitivity: gamma > 1 (brighten image)
-             # High sensitivity: gamma < 1 (darken image)
-             gamma = (
-                 1.5 - (self.sensitivity / 100) * 1.0
-             )  # Range from 1.5 to 0.5
+         for i, mask in enumerate(masks):
+             label_id = i + 1  # Start label IDs from 1
+             labels[mask] = label_id

-             # Apply gamma correction to the input image
-             image_for_processing = self.current_image_for_segmentation.copy()
+             # Calculate label information
+             area = np.sum(mask)
+             y_indices, x_indices = np.where(mask)
+             center_y = np.mean(y_indices) if len(y_indices) > 0 else 0
+             center_x = np.mean(x_indices) if len(x_indices) > 0 else 0

-             # Convert to float for proper gamma correction
-             image_float = image_for_processing.astype(np.float32) / 255.0
+             # Store label info
+             self.label_info[label_id] = {
+                 "area": area,
+                 "center_y": center_y,
+                 "center_x": center_x,
+                 "score": float(scores[i]),
+             }

-             # Apply gamma correction
-             image_gamma = np.power(image_float, gamma)
+         # Handle upscaling if needed
+         if self.current_scale_factor < 1.0:
+             labels = resize(
+                 labels,
+                 orig_shape,
+                 order=0,  # Nearest neighbor interpolation
+                 preserve_range=True,
+                 anti_aliasing=False,
+             ).astype(np.uint32)

-             # Convert back to uint8
-             image_gamma = (image_gamma * 255).astype(np.uint8)
+         # Sort labels by area (largest first)
+         self.label_info = dict(
+             sorted(
+                 self.label_info.items(),
+                 key=lambda item: item[1]["area"],
+                 reverse=True,
+             )
+         )

-             # Check if the image is very large and needs downscaling
-             orig_shape = image_gamma.shape[:2]  # (height, width)
+         # Save segmentation result
+         self.segmentation_result = labels

-             # Calculate image size in megapixels
-             image_mp = (orig_shape[0] * orig_shape[1]) / 1e6
+         # Update the label layer
+         self._update_label_layer()

-             # If image is larger than 2 megapixels, downscale it
-             max_mp = 2.0  # Maximum image size in megapixels
-             scale_factor = 1.0
+     def _generate_3d_segmentation(self, confidence_threshold, image_path):
+         """
+         Initialize 3D segmentation using SAM2 Video Predictor.
+         This correctly sets up interactive segmentation following SAM2's video approach.
+         """
+         try:
+             # Handle image_path - make sure it's a string
+             if not isinstance(image_path, str):
+                 image_path = self.images[self.current_index]

-             if image_mp > max_mp:
-                 scale_factor = np.sqrt(max_mp / image_mp)
-                 new_height = int(orig_shape[0] * scale_factor)
-                 new_width = int(orig_shape[1] * scale_factor)
+             # Initialize empty segmentation
+             volume_shape = self.current_image_for_segmentation.shape
+             labels = np.zeros(volume_shape, dtype=np.uint32)
+             self.segmentation_result = labels

-                 self.viewer.status = f"Downscaling image from {orig_shape} to {(new_height, new_width)} for processing (scale: {scale_factor:.2f})"
+             # Create a temp directory for the MP4 conversion if needed
+             import os
+             import tempfile

-                 # Resize the image for processing
-                 image_gamma_resized = resize(
-                     image_gamma,
-                     (new_height, new_width),
-                     anti_aliasing=True,
-                     preserve_range=True,
-                 ).astype(np.uint8)
+             temp_dir = tempfile.gettempdir()
+             mp4_path = os.path.join(
+                 temp_dir, f"temp_volume_{os.path.basename(image_path)}.mp4"
+             )

-                 # Store scale factor for later use
-                 self.current_scale_factor = scale_factor
-             else:
-                 image_gamma_resized = image_gamma
-                 self.current_scale_factor = 1.0
+             # If we need to save a modified version for MP4 conversion
+             need_temp_tif = False
+             temp_tif_path = None

-             self.viewer.status = f"Generating segmentation with sensitivity {self.sensitivity} (gamma={gamma:.2f})..."
+             # Check if we have a 4D volume with Z dimension
+             if (
+                 hasattr(self, "has_z_dim")
+                 and self.has_z_dim
+                 and len(self.current_image_for_segmentation.shape) == 4
+             ):
+                 # We need to convert the 4D TZYX to a 3D TYX for proper video conversion
+                 # by taking maximum intensity projection of Z for each time point
+                 self.viewer.status = (
+                     "Converting 4D TZYX volume to 3D TYX for SAM2..."
+                 )
+
+                 # Create maximum intensity projection along Z axis (axis 1 in TZYX)
+                 projected_volume = np.max(
+                     self.current_image_for_segmentation, axis=1
+                 )

-             # Generate masks with gamma-corrected and potentially resized image
-             masks = self.mask_generator.generate(image_gamma_resized)
-             self.viewer.status = f"Generated {len(masks)} masks"
+                 # Save this as a temporary TIF for MP4 conversion
+                 temp_tif_path = os.path.join(
+                     temp_dir, f"temp_projected_{os.path.basename(image_path)}"
+                 )
+                 imwrite(temp_tif_path, projected_volume)
+                 need_temp_tif = True

-             if not masks:
+                 # Convert the projected TIF to MP4
                  self.viewer.status = (
-                     "No segments detected. Try increasing the sensitivity."
+                     "Converting projected 3D volume to MP4 format for SAM2..."
                  )
-                 # Create empty label layer
-                 shape = self.current_image_for_segmentation.shape[:2]
-                 self.segmentation_result = np.zeros(shape, dtype=np.uint32)
+                 mp4_path = tif_to_mp4(temp_tif_path)
+             else:
+                 # Convert original volume to video format for SAM2
+                 self.viewer.status = (
+                     "Converting 3D volume to MP4 format for SAM2..."
+                 )
+                 mp4_path = tif_to_mp4(image_path)

-                 # Remove existing label layer if exists
-                 for layer in list(self.viewer.layers):
-                     if (
-                         isinstance(layer, Labels)
-                         and "Segmentation" in layer.name
-                     ):
-                         self.viewer.layers.remove(layer)
+             # Initialize SAM2 state with the video
+             self.viewer.status = "Initializing SAM2 Video Predictor..."
+             with torch.inference_mode(), torch.autocast(
+                 "cuda", dtype=torch.bfloat16
+             ):
+                 self._sam2_state = self.predictor.init_state(mp4_path)

-                 # Add new empty label layer
-                 self.label_layer = self.viewer.add_labels(
-                     self.segmentation_result,
-                     name=f"Segmentation ({os.path.basename(self.images[self.current_index])})",
-                     opacity=0.7,
+             # Store needed state for 3D processing
+             self._sam2_next_obj_id = 1
+             self._sam2_prompts = (
+                 {}
+             )  # Store prompts for each object (points, labels, box)
+
+             # Update the label layer with empty segmentation
+             self._update_label_layer()
+
+             # Replace the click handler for interactive 3D segmentation
+             if self.label_layer is not None and hasattr(
+                 self.label_layer, "mouse_drag_callbacks"
+             ):
+                 for callback in list(self.label_layer.mouse_drag_callbacks):
+                     self.label_layer.mouse_drag_callbacks.remove(callback)
+
+                 # Add 3D-specific click handler
+                 self.label_layer.mouse_drag_callbacks.append(
+                     self._on_3d_label_clicked
                  )

-                 # Make the label layer active
-                 self.viewer.layers.selection.active = self.label_layer
+             # Set the viewer to show the first frame
+             if hasattr(self.viewer, "dims") and self.viewer.dims.ndim > 2:
+                 self.viewer.dims.set_point(
+                     0, 0
+                 )  # Set the first dimension (typically time/z) to 0
+
+             # Clean up temporary file if we created one
+             if (
+                 need_temp_tif
+                 and temp_tif_path
+                 and os.path.exists(temp_tif_path)
+             ):
+                 with contextlib.suppress(Exception):
+                     os.remove(temp_tif_path)
+
+             # Show instructions
+             self.viewer.status = (
+                 "3D Mode active: Navigate to the first frame where object appears, then click. "
+                 "Use Shift+click for negative points (to remove areas). "
+                 "Segmentation will be propagated to all frames automatically."
+             )
+
+             return True
+
+         except (
+             FileNotFoundError,
+             RuntimeError,
+             torch.cuda.OutOfMemoryError,
+             ValueError,
+             OSError,
+         ) as e:
+             import traceback
+
+             self.viewer.status = f"Error in 3D segmentation setup: {str(e)}"
+             traceback.print_exc()
+             return False
+
+     def _on_3d_label_clicked(self, layer, event):
+         """Handle click on 3D label layer to add a prompt for segmentation."""
+         try:
+             if event.button != 1:
                  return

-             # Process segmentation masks
-             # If image was downscaled, we need to ensure masks are upscaled correctly
-             if self.current_scale_factor < 1.0:
-                 # Upscale the segmentation masks to match the original image dimensions
-                 self._process_segmentation_masks_with_scaling(
-                     masks, self.current_image_for_segmentation.shape[:2]
+             coords = layer.world_to_data(event.position)
+             if len(coords) == 3:
+                 z, y, x = map(int, coords)
+             elif len(coords) == 2:
+                 z = int(self.viewer.dims.current_step[0])
+                 y, x = map(int, coords)
+             else:
+                 self.viewer.status = (
+                     f"Unexpected coordinate dimensions: {coords}"
                  )
+                 return
+
+             # Check if Shift key is pressed
+             is_negative = "Shift" in event.modifiers
+             point_label = -1 if is_negative else 1
+
+             # Initialize a unique object ID for this click
+             if not hasattr(self, "_sam2_next_obj_id"):
+                 self._sam2_next_obj_id = 1
+
+             # Get current object ID (or create new one)
+             label_id = self.segmentation_result[z, y, x]
+             if is_negative and label_id > 0:
+                 # Use existing object ID for negative points
+                 ann_obj_id = label_id
+             else:
+                 # Create new object for positive points on background
+                 ann_obj_id = self._sam2_next_obj_id
+                 if point_label > 0 and label_id == 0:
+                     self._sam2_next_obj_id += 1
+
+             # Find or create points layer for this object
+             points_layer = None
+             for layer in list(self.viewer.layers):
+                 if f"Points for Object {ann_obj_id}" in layer.name:
+                     points_layer = layer
+                     break
+
+             if points_layer is None:
+                 # Create new points layer for this object
+                 points_layer = self.viewer.add_points(
+                     np.array([[z, y, x]]),
+                     name=f"Points for Object {ann_obj_id}",
+                     size=10,
+                     face_color="green" if point_label > 0 else "red",
+                     border_color="white",
+                     border_width=1,
+                     opacity=0.8,
+                 )
+
+                 with contextlib.suppress(AttributeError, ValueError):
+                     points_layer.mouse_drag_callbacks.remove(
+                         self._on_points_clicked
+                     )
+                 points_layer.mouse_drag_callbacks.append(
+                     self._on_points_clicked
+                 )
+
+                 # Initialize points for this object
+                 if not hasattr(self, "sam2_points_by_obj"):
+                     self.sam2_points_by_obj = {}
+                     self.sam2_labels_by_obj = {}
+
+                 self.sam2_points_by_obj[ann_obj_id] = [[x, y]]
+                 self.sam2_labels_by_obj[ann_obj_id] = [point_label]
              else:
-                 self._process_segmentation_masks(
-                     masks, self.current_image_for_segmentation.shape[:2]
+                 # Add to existing points layer
+                 current_points = points_layer.data
+                 new_points = np.vstack([current_points, [z, y, x]])
+                 points_layer.data = new_points
+
+                 # Add to existing point lists
+                 if not hasattr(self, "sam2_points_by_obj"):
+                     self.sam2_points_by_obj = {}
+                     self.sam2_labels_by_obj = {}
+
+                 if ann_obj_id not in self.sam2_points_by_obj:
+                     self.sam2_points_by_obj[ann_obj_id] = []
+                     self.sam2_labels_by_obj[ann_obj_id] = []
+
+                 self.sam2_points_by_obj[ann_obj_id].append([x, y])
+                 self.sam2_labels_by_obj[ann_obj_id].append(point_label)
+
+             # Perform SAM2 segmentation
+             if hasattr(self, "_sam2_state") and self._sam2_state is not None:
+                 points = np.array(
+                     self.sam2_points_by_obj[ann_obj_id], dtype=np.float32
+                 )
+                 labels = np.array(
+                     self.sam2_labels_by_obj[ann_obj_id], dtype=np.int32
                  )

-             # Clear selected labels since segmentation has changed
-             self.selected_labels = set()
+                 self.viewer.status = f"Processing object at frame {z}..."

-             # Update table if it exists
-             if self.label_table_widget:
-                 self._populate_label_table(self.label_table_widget)
+                 _, out_obj_ids, out_mask_logits = (
+                     self.predictor.add_new_points_or_box(
+                         inference_state=self._sam2_state,
+                         frame_idx=z,
+                         obj_id=ann_obj_id,
+                         points=points,
+                         labels=labels,
+                     )
+                 )

-         except (Exception, ValueError) as e:
+                 # Convert logits to mask and update segmentation
+                 mask = (out_mask_logits[0] > 0.0).cpu().numpy()
+
+                 # Fix mask dimensions if needed
+                 if mask.ndim > 2:
+                     mask = mask.squeeze()
+
+                 # Check mask dimensions and resize if needed
+                 if mask.shape != self.segmentation_result[z].shape:
+                     from skimage.transform import resize
+
+                     mask = resize(
+                         mask.astype(float),
+                         self.segmentation_result[z].shape,
+                         order=0,
+                         preserve_range=True,
+                         anti_aliasing=False,
+                     ).astype(bool)
+
+                 # Apply the mask to current frame
+                 # For negative points, only remove from the current object
+                 if point_label < 0:
+                     # Remove only from current object
+                     self.segmentation_result[z][
+                         (self.segmentation_result[z] == ann_obj_id) & mask
+                     ] = 0
+                 else:
+                     # Add to current object (only overwrite background)
+                     self.segmentation_result[z][
+                         mask & (self.segmentation_result[z] == 0)
+                     ] = ann_obj_id
+
+                 # Automatically propagate to other frames
+                 self._propagate_mask_for_current_object(ann_obj_id, z)
+
+                 # Update label layer
+                 self._update_label_layer()
+
+                 # Update label table if needed
+                 if (
+                     hasattr(self, "label_table_widget")
+                     and self.label_table_widget is not None
+                 ):
+                     self._populate_label_table(self.label_table_widget)
+
+                 self.viewer.status = (
+                     f"Updated 3D object {ann_obj_id} across all frames"
+                 )
+             else:
+                 self.viewer.status = "SAM2 3D state not initialized"
+
+         except (
+             IndexError,
+             KeyError,
+             ValueError,
+             RuntimeError,
+             torch.cuda.OutOfMemoryError,
+         ) as e:
              import traceback

-             self.viewer.status = f"Error generating segmentation: {str(e)}"
+             self.viewer.status = f"Error in 3D click handler: {str(e)}"
              traceback.print_exc()

-     def _process_segmentation_masks(self, masks, shape):
-         """Process segmentation masks and create label layer."""
-         # Create label image from masks
-         labels = np.zeros(shape, dtype=np.uint32)
-         self.label_info = {}  # Reset label info
+     def _propagate_mask_for_current_object(self, obj_id, current_frame_idx):
+         """
+         Propagate the mask for the current object from the given frame to all other frames.
+         Uses SAM2's video propagation with proper error handling.

-         for i, mask_data in enumerate(masks):
-             mask = mask_data["segmentation"]
-             label_id = i + 1  # Start label IDs from 1
-             labels[mask] = label_id
+         Parameters:
+             obj_id: The ID of the object to propagate
+             current_frame_idx: The frame index where the object was identified
+         """
+         try:
+             if not hasattr(self, "_sam2_state") or self._sam2_state is None:
+                 self.viewer.status = (
+                     "SAM2 3D state not initialized for propagation"
+                 )
+                 return

-             # Calculate label information
-             area = np.sum(mask)
-             y_indices, x_indices = np.where(mask)
-             center_y = np.mean(y_indices) if len(y_indices) > 0 else 0
-             center_x = np.mean(x_indices) if len(x_indices) > 0 else 0
+             total_frames = self.segmentation_result.shape[0]
+             self.viewer.status = f"Propagating object {obj_id} through all {total_frames} frames..."

-             # Store label info
-             self.label_info[label_id] = {
-                 "area": area,
-                 "center_y": center_y,
-                 "center_x": center_x,
-                 "score": mask_data.get("stability_score", 0),
-             }
+             # Create a progress layer for visualization
+             progress_layer = None
+             for layer in list(self.viewer.layers):
+                 if "Propagation Progress" in layer.name:
+                     progress_layer = layer
+                     break

-         # Sort labels by area (largest first)
-         self.label_info = dict(
-             sorted(
-                 self.label_info.items(),
-                 key=lambda item: item[1]["area"],
-                 reverse=True,
+             if progress_layer is None:
+                 progress_data = np.zeros_like(
+                     self.segmentation_result, dtype=float
+                 )
+                 progress_layer = self.viewer.add_image(
+                     progress_data,
+                     name="Propagation Progress",
+                     colormap="magma",
+                     opacity=0.3,
+                     visible=True,
+                 )
+
+             # Update current frame in the progress layer
+             progress_data = progress_layer.data
+             current_mask = (
+                 self.segmentation_result[current_frame_idx] == obj_id
              )
-         )
+             progress_data[current_frame_idx] = current_mask.astype(float) * 0.8
+             progress_layer.data = progress_data
+
+             # Try to perform SAM2 propagation with error handling
+             try:
+                 # Use torch.inference_mode() and torch.autocast to ensure consistent dtypes
+                 with torch.inference_mode(), torch.autocast(
+                     "cuda", dtype=torch.float32
+                 ):
+                     # Attempt to run SAM2 propagation - this will iterate through all frames
+                     for (
+                         frame_idx,
+                         object_ids,
+                         mask_logits,
+                     ) in self.predictor.propagate_in_video(self._sam2_state):
+                         if frame_idx >= total_frames:
+                             continue
+
+                         # Find our object ID in the results
+                         # obj_mask = None
+                         for i, prop_obj_id in enumerate(object_ids):
+                             if prop_obj_id == obj_id:
+                                 # Get the mask for our object
+                                 mask = (mask_logits[i] > 0.0).cpu().numpy()
+
+                                 # Fix dimensions if needed
+                                 if mask.ndim > 2:
+                                     mask = mask.squeeze()
+
+                                 # Resize if needed
+                                 if (
+                                     mask.shape
+                                     != self.segmentation_result[
+                                         frame_idx
+                                     ].shape
+                                 ):
+                                     from skimage.transform import resize
+
+                                     mask = resize(
+                                         mask.astype(float),
+                                         self.segmentation_result[
+                                             frame_idx
+                                         ].shape,
+                                         order=0,
+                                         preserve_range=True,
+                                         anti_aliasing=False,
+                                     ).astype(bool)
+
+                                 # Update segmentation - only replacing background pixels
+                                 self.segmentation_result[frame_idx][
+                                     mask
+                                     & (
+                                         self.segmentation_result[frame_idx]
+                                         == 0
+                                     )
+                                 ] = obj_id
+
+                                 # Update progress visualization
+                                 progress_data = progress_layer.data
+                                 progress_data[frame_idx] = (
+                                     mask.astype(float) * 0.8
+                                 )
+                                 progress_layer.data = progress_data
+
+                                 # Update status occasionally
+                                 if frame_idx % 10 == 0:
+                                     self.viewer.status = f"Propagating: frame {frame_idx+1}/{total_frames}"
+
+             except RuntimeError as e:
+                 # If we get a dtype mismatch or other error, copy the current frame's mask to other frames
+                 self.viewer.status = f"SAM2 propagation failed with error: {str(e)}. Falling back to alternative method."
+
+                 # Use the current frame's mask for propagation
+                 for frame_idx in range(total_frames):
+                     if (
+                         frame_idx != current_frame_idx
+                     ):  # Skip current frame as it's already done
+                         # Only replace background pixels with the current frame's object
+                         self.segmentation_result[frame_idx][
+                             current_mask
+                             & (self.segmentation_result[frame_idx] == 0)
+                         ] = obj_id
+
+                         # Update progress layer
+                         progress_data = progress_layer.data
+                         progress_data[frame_idx] = (
+                             current_mask.astype(float) * 0.5
+                         )  # Different intensity to indicate fallback
+                         progress_layer.data = progress_data
+
+                         # Update status occasionally
+                         if frame_idx % 10 == 0:
+                             self.viewer.status = f"Fallback propagation: frame {frame_idx+1}/{total_frames}"
+
+             # Remove progress layer after 2 seconds
+             import threading
+
+             def remove_progress():
+                 import time
+
+                 time.sleep(2)
+                 for layer in list(self.viewer.layers):
+                     if "Propagation Progress" in layer.name:
+                         self.viewer.layers.remove(layer)

-         # Save segmentation result
-         self.segmentation_result = labels
+             threading.Thread(target=remove_progress).start()

-         # Remove existing label layer if exists
-         for layer in list(self.viewer.layers):
-             if isinstance(layer, Labels) and "Segmentation" in layer.name:
-                 self.viewer.layers.remove(layer)
+             self.viewer.status = f"Propagation of object {obj_id} complete"

-         # Add label layer to viewer
-         self.label_layer = self.viewer.add_labels(
-             labels,
-             name=f"Segmentation ({os.path.basename(self.images[self.current_index])})",
-             opacity=0.7,
-         )
+         except (
+             IndexError,
+             ValueError,
+             RuntimeError,
+             torch.cuda.OutOfMemoryError,
+             TypeError,
+         ) as e:
+             import traceback

-         # Make the label layer active by default
-         self.viewer.layers.selection.active = self.label_layer
+             self.viewer.status = f"Error in propagation: {str(e)}"
+             traceback.print_exc()

-         # Disconnect existing callbacks if any
-         if (
-             hasattr(self, "label_layer")
-             and self.label_layer is not None
-             and hasattr(self.label_layer, "mouse_drag_callbacks")
-         ):
-             # Remove old callbacks
-             for callback in list(self.label_layer.mouse_drag_callbacks):
-                 self.label_layer.mouse_drag_callbacks.remove(callback)
-
-         # Connect mouse click event to label selection
-         self.label_layer.mouse_drag_callbacks.append(self._on_label_clicked)
-
-         # image_name = os.path.basename(self.images[self.current_index])
-         self.viewer.status = f"Loaded image {self.current_index + 1}/{len(self.images)} - Found {len(masks)} segments"
-
-     # New method for handling scaled segmentation masks
-     def _process_segmentation_masks_with_scaling(self, masks, original_shape):
-         """Process segmentation masks with scaling to match the original image size."""
-         # Create label image from masks
-         # First determine the size of the mask predictions (which are at the downscaled resolution)
-         if not masks:
+     def _add_3d_prompt(self, prompt_coords):
+         """
+         Given a 3D coordinate (x, y, z), run SAM2 video predictor to segment the object at that point,
+         update the segmentation result and label layer.
+         """
+         if not hasattr(self, "_sam2_state") or self._sam2_state is None:
+             self.viewer.status = "SAM2 3D state not initialized."
              return

-         mask_shape = masks[0]["segmentation"].shape
-
-         # Create an empty label image at the downscaled resolution
-         downscaled_labels = np.zeros(mask_shape, dtype=np.uint32)
-         self.label_info = {}  # Reset label info
-
-         # Fill in the downscaled labels
-         for i, mask_data in enumerate(masks):
-             mask = mask_data["segmentation"]
-             label_id = i + 1  # Start label IDs from 1
-             downscaled_labels[mask] = label_id
-
-             # Store basic label info
-             area = np.sum(mask)
-             y_indices, x_indices = np.where(mask)
-             center_y = np.mean(y_indices) if len(y_indices) > 0 else 0
-             center_x = np.mean(x_indices) if len(x_indices) > 0 else 0
-
-             # Scale centers to original image coordinates
-             center_y_orig = center_y / self.current_scale_factor
-             center_x_orig = center_x / self.current_scale_factor
+         if self.predictor is None:
+             self.viewer.status = "SAM2 predictor not initialized."
+             return

-             # Store label info at original scale
-             self.label_info[label_id] = {
-                 "area": area
-                 / (
-                     self.current_scale_factor**2
-                 ),  # Approximate area in original scale
-                 "center_y": center_y_orig,
-                 "center_x": center_x_orig,
-                 "score": mask_data.get("stability_score", 0),
-             }
+         # Prepare prompt for SAM2: point_coords is [[x, y, t]], point_labels is [1]
+         x, y, z = prompt_coords
+         point_coords = np.array([[x, y, z]])
+         point_labels = np.array([1])  # 1 = foreground

-         # Upscale the labels to the original image size
-         upscaled_labels = resize(
-             downscaled_labels,
-             original_shape,
-             order=0,  # Nearest neighbor interpolation
-             preserve_range=True,
-             anti_aliasing=False,
-         ).astype(np.uint32)
+         with torch.inference_mode(), torch.autocast(
+             "cuda", dtype=torch.bfloat16
+         ):
+             masks, scores, _ = self.predictor.predict(
+                 state=self._sam2_state,
+                 point_coords=point_coords,
+                 point_labels=point_labels,
+                 multimask_output=True,
+             )

-         # Sort labels by area (largest first)
-         self.label_info = dict(
-             sorted(
-                 self.label_info.items(),
-                 key=lambda item: item[1]["area"],
-                 reverse=True,
+         # Pick the best mask (highest score)
+         if masks is not None and len(masks) > 0:
+             best_idx = np.argmax(scores)
+             mask = masks[best_idx]
+             obj_id = self._sam2_next_obj_id
+             self.segmentation_result[mask] = obj_id
+             self._sam2_next_obj_id += 1
+             self.viewer.status = (
+                 f"Added object {obj_id} at (x={x}, y={y}, z={z})"
              )
-         )
+             self._update_label_layer()
+         else:
+             self.viewer.status = "No mask found for this prompt."
+
+     def on_apply_propagate(self):
+         """Propagate masks across the video and update the segmentation layer."""
+         self.viewer.status = "Propagating masks across all frames..."
+         self.viewer.window._qt_window.setCursor(Qt.WaitCursor)
+
+         self.segmentation_result[:] = 0
+
+         for (
+             frame_idx,
+             object_ids,
+             mask_logits,
+         ) in self.predictor.propagate_in_video(self._sam2_state):
+             masks = (mask_logits > 0.0).cpu().numpy()
+             if frame_idx >= self.segmentation_result.shape[0]:
+                 print(
+                     f"Warning: frame_idx {frame_idx} out of bounds for segmentation_result with shape {self.segmentation_result.shape}"
+                 )
+                 continue
+             for i, obj_id in enumerate(object_ids):
+                 self.segmentation_result[frame_idx][masks[i]] = obj_id
+             self.viewer.status = f"Propagating: frame {frame_idx+1}"

-         # Save segmentation result
-         self.segmentation_result = upscaled_labels
+         self._update_label_layer()
+         self.viewer.status = "Propagation complete!"
+         self.viewer.window._qt_window.setCursor(Qt.ArrowCursor)

-         # Remove existing label layer if exists
+     def _update_label_layer(self):
+         """Update the label layer in the viewer."""
+         # Remove existing label layer if it exists
          for layer in list(self.viewer.layers):
              if isinstance(layer, Labels) and "Segmentation" in layer.name:
                  self.viewer.layers.remove(layer)

          # Add label layer to viewer
          self.label_layer = self.viewer.add_labels(
-             upscaled_labels,
+             self.segmentation_result,
              name=f"Segmentation ({os.path.basename(self.images[self.current_index])})",
              opacity=0.7,
          )

-         # Make the label layer active by default
-         self.viewer.layers.selection.active = self.label_layer
+         # Create points layer for interaction if it doesn't exist
+         points_layer = None
+         for layer in list(self.viewer.layers):
+             if "Points" in layer.name:
+                 points_layer = layer
+                 break
+
+         if points_layer is None:
+             # Initialize an empty points layer
+             points_layer = self.viewer.add_points(
+                 np.zeros((0, 2 if not self.use_3d else 3)),
+                 name="Points (Click to Add)",
+                 size=10,
+                 face_color="green",
+                 border_color="white",
+                 border_width=1,
+                 opacity=0.8,
+             )

-         # Disconnect existing callbacks if any
-         if (
-             hasattr(self, "label_layer")
-             and self.label_layer is not None
-             and hasattr(self.label_layer, "mouse_drag_callbacks")
-         ):
-             # Remove old callbacks
-             for callback in list(self.label_layer.mouse_drag_callbacks):
-                 self.label_layer.mouse_drag_callbacks.remove(callback)
+         with contextlib.suppress(AttributeError, ValueError):
+             points_layer.mouse_drag_callbacks.remove(
+                 self._on_points_clicked
+             )
+             points_layer.mouse_drag_callbacks.append(
+                 self._on_points_clicked
+             )

-         # Connect mouse click event to label selection
-         self.label_layer.mouse_drag_callbacks.append(self._on_label_clicked)
+         # Connect points layer mouse click event
+         points_layer.mouse_drag_callbacks.append(self._on_points_clicked)

-         self.viewer.status = f"Loaded image {self.current_index + 1}/{len(self.images)} - Found {len(masks)} segments"
+         # Make the points layer active to encourage interaction with it
+         self.viewer.layers.selection.active = points_layer

-     # --------------------------------------------------
-     # Label Selection and UI Elements
-     # --------------------------------------------------
+         # Update status
+         n_labels = len(np.unique(self.segmentation_result)) - (
+             1 if 0 in np.unique(self.segmentation_result) else 0
+         )
+         self.viewer.status = f"Loaded image {self.current_index + 1}/{len(self.images)} - Found {n_labels} segments"

-     def _on_label_clicked(self, layer, event):
-         """Handle label selection on mouse click."""
+     def _on_points_clicked(self, layer, event):
+         """Handle clicks on the points layer for adding/removing points."""
          try:
              # Only process clicks, not drags
              if event.type != "mouse_press":
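The 3D workflow in this hunk reduces to three SAM2 video-predictor calls: `init_state` on the MP4 rendition of the volume, `add_new_points_or_box` for each user click, and `propagate_in_video` to carry masks across frames. A condensed sketch of that loop, using only calls that appear in the diff (the MP4 path and the click coordinates are illustrative):

    import numpy as np
    import torch

    with torch.inference_mode():
        state = predictor.init_state("/tmp/temp_volume_example.mp4")  # illustrative path
        # One positive click on frame 0 for object 1; SAM2 points are [x, y]:
        _, obj_ids, mask_logits = predictor.add_new_points_or_box(
            inference_state=state,
            frame_idx=0,
            obj_id=1,
            points=np.array([[120.0, 80.0]], dtype=np.float32),
            labels=np.array([1], dtype=np.int32),
        )
        # Carry the prompt through the whole video, frame by frame:
        for frame_idx, obj_ids, mask_logits in predictor.propagate_in_video(state):
            frame_masks = (mask_logits > 0.0).cpu().numpy()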
@@ -621,39 +1161,815 @@ class BatchCropAnything:
621
1161
  # Get coordinates of mouse click
622
1162
  coords = np.round(event.position).astype(int)
623
1163
 
624
- # Make sure coordinates are within bounds
625
- shape = self.segmentation_result.shape
626
- if (
627
- coords[0] < 0
628
- or coords[1] < 0
629
- or coords[0] >= shape[0]
630
- or coords[1] >= shape[1]
631
- ):
632
- return
1164
+ # Check if Shift is pressed for negative points
1165
+ is_negative = "Shift" in event.modifiers
1166
+ point_label = -1 if is_negative else 1
1167
+
1168
+ # Handle 2D vs 3D coordinates
1169
+ if self.use_3d:
1170
+ if len(coords) == 3:
1171
+ t, y, x = map(int, coords)
1172
+ elif len(coords) == 2:
1173
+ t = int(self.viewer.dims.current_step[0])
1174
+ y, x = map(int, coords)
1175
+ else:
1176
+ self.viewer.status = (
1177
+ f"Unexpected coordinate dimensions: {coords}"
1178
+ )
1179
+ return
1180
+
1181
+ # Add point to the layer immediately for visual feedback
1182
+ new_point = np.array([[t, y, x]])
1183
+ if len(layer.data) == 0:
1184
+ layer.data = new_point
1185
+ else:
1186
+ layer.data = np.vstack([layer.data, new_point])
1187
+
1188
+ # Update point colors
1189
+ colors = layer.face_color
1190
+ if isinstance(colors, list):
1191
+ colors.append("red" if is_negative else "green")
1192
+ else:
1193
+ n_points = len(layer.data)
1194
+ colors = ["green"] * (n_points - 1)
1195
+ colors.append("red" if is_negative else "green")
1196
+ layer.face_color = colors
1197
+
1198
+ # Get the object ID
1199
+ # If clicking on existing segmentation with negative point
1200
+ label_id = self.segmentation_result[t, y, x]
1201
+ if is_negative and label_id > 0:
1202
+ obj_id = label_id
1203
+ else:
1204
+ # For new objects or negative on background
1205
+ if not hasattr(self, "_sam2_next_obj_id"):
1206
+ self._sam2_next_obj_id = 1
1207
+ obj_id = self._sam2_next_obj_id
1208
+ if point_label > 0 and label_id == 0:
1209
+ self._sam2_next_obj_id += 1
1210
+
1211
+ # Store point information
1212
+ if not hasattr(self, "points_data"):
1213
+ self.points_data = {}
1214
+ self.points_labels = {}
1215
+
1216
+ if obj_id not in self.points_data:
1217
+ self.points_data[obj_id] = []
1218
+ self.points_labels[obj_id] = []
1219
+
1220
+ self.points_data[obj_id].append(
1221
+ [x, y]
1222
+ ) # Note: SAM2 expects [x,y] format
1223
+ self.points_labels[obj_id].append(point_label)
1224
+
1225
+ # Perform segmentation
1226
+ if (
1227
+ hasattr(self, "_sam2_state")
1228
+ and self._sam2_state is not None
1229
+ ):
1230
+ # Prepare points
1231
+ points = np.array(
1232
+ self.points_data[obj_id], dtype=np.float32
1233
+ )
1234
+ labels = np.array(
1235
+ self.points_labels[obj_id], dtype=np.int32
1236
+ )
1237
+
1238
+ # Create progress layer for visual feedback
1239
+ progress_layer = None
1240
+ for existing_layer in self.viewer.layers:
1241
+ if "Propagation Progress" in existing_layer.name:
1242
+ progress_layer = existing_layer
1243
+ break
1244
+
1245
+ if progress_layer is None:
1246
+ progress_data = np.zeros_like(self.segmentation_result)
1247
+ progress_layer = self.viewer.add_image(
1248
+ progress_data,
1249
+ name="Propagation Progress",
1250
+ colormap="magma",
1251
+ opacity=0.5,
1252
+ visible=True,
1253
+ )
1254
+
1255
+ # First update the current frame immediately
1256
+ self.viewer.status = f"Processing object at frame {t}..."
1257
+
1258
+ # Run SAM2 on current frame
1259
+ _, out_obj_ids, out_mask_logits = (
1260
+ self.predictor.add_new_points_or_box(
1261
+ inference_state=self._sam2_state,
1262
+ frame_idx=t,
1263
+ obj_id=obj_id,
1264
+ points=points,
1265
+ labels=labels,
1266
+ )
1267
+ )
1268
+
1269
+ # Update current frame
1270
+ mask = (out_mask_logits[0] > 0.0).cpu().numpy()
1271
+ if mask.ndim > 2:
1272
+ mask = mask.squeeze()
1273
+
1274
+ # Resize if needed
1275
+ if mask.shape != self.segmentation_result[t].shape:
1276
+ from skimage.transform import resize
1277
+
1278
+ mask = resize(
1279
+ mask.astype(float),
1280
+ self.segmentation_result[t].shape,
1281
+ order=0,
1282
+ preserve_range=True,
1283
+ anti_aliasing=False,
1284
+ ).astype(bool)
1285
+
1286
+ # Update segmentation for this frame
1287
+ if point_label < 0:
1288
+ # For negative points, only remove from this object
1289
+ self.segmentation_result[t][
1290
+ (self.segmentation_result[t] == obj_id) & mask
1291
+ ] = 0
1292
+ else:
1293
+ # For positive points, only replace background
1294
+ self.segmentation_result[t][
1295
+ mask & (self.segmentation_result[t] == 0)
1296
+ ] = obj_id
1297
+
1298
+ # Update progress layer for this frame
1299
+ progress_data = progress_layer.data
1300
+ progress_data[t] = (
1301
+ mask.astype(float) * 0.5
1302
+ ) # Highlight current frame
1303
+ progress_layer.data = progress_data
1304
+
1305
+ # Now propagate to all frames with visual feedback
1306
+ self.viewer.status = "Propagating to all frames..."
1307
+
1308
+ # Run propagation
1309
+ frame_count = self.segmentation_result.shape[0]
1310
+ for (
1311
+ frame_idx,
1312
+ prop_obj_ids,
1313
+ mask_logits,
1314
+ ) in self.predictor.propagate_in_video(self._sam2_state):
1315
+ if frame_idx >= frame_count:
1316
+ continue
1317
+
1318
+ # Find our object
1319
+ obj_mask = None
1320
+ for i, prop_obj_id in enumerate(prop_obj_ids):
1321
+ if prop_obj_id == obj_id:
1322
+ obj_mask = (mask_logits[i] > 0.0).cpu().numpy()
1323
+ if obj_mask.ndim > 2:
1324
+ obj_mask = obj_mask.squeeze()
1325
+
1326
+ # Resize if needed
1327
+ if (
1328
+ obj_mask.shape
1329
+ != self.segmentation_result[
1330
+ frame_idx
1331
+ ].shape
1332
+ ):
1333
+ obj_mask = resize(
1334
+ obj_mask.astype(float),
1335
+ self.segmentation_result[
1336
+ frame_idx
1337
+ ].shape,
1338
+ order=0,
1339
+ preserve_range=True,
1340
+ anti_aliasing=False,
1341
+ ).astype(bool)
1342
+
1343
+ # Update segmentation
1344
+ self.segmentation_result[frame_idx][
1345
+ obj_mask
1346
+ & (
1347
+ self.segmentation_result[frame_idx]
1348
+ == 0
1349
+ )
1350
+ ] = obj_id
1351
+
1352
+ # Update progress visualization
1353
+ progress_data = progress_layer.data
1354
+ progress_data[frame_idx] = (
1355
+ obj_mask.astype(float) * 0.8
1356
+ ) # Show as processed
1357
+ progress_layer.data = progress_data
1358
+
1359
+ # Update status
1360
+ if frame_idx % 5 == 0:
1361
+ self.viewer.status = f"Propagating: frame {frame_idx+1}/{frame_count}"
1362
+ # Updating layer data above already triggers a redraw;
+ # no explicit viewer refresh is needed here
1363
+
1364
+ # Fill in any frames the propagation did not reach
+ processed_frames = {
+ idx
+ for idx in range(frame_count)
+ if progress_data[idx].max() > 0
+ }
1366
+ for frame_idx in range(frame_count):
1367
+ if (
1368
+ progress_data[frame_idx].max() == 0
1369
+ ): # Frame not processed yet
1370
+ # Use nearest processed frame's mask
1371
+ if not processed_frames:
+ continue
+ nearest_idx = min(
1372
+ processed_frames,
1373
+ key=lambda x: abs(x - frame_idx),
1374
+ )
1375
+ if progress_data[nearest_idx].max() > 0:
1376
+ self.segmentation_result[frame_idx][
1377
+ (self.segmentation_result[frame_idx] == 0)
1378
+ & (
1379
+ self.segmentation_result[nearest_idx]
1380
+ == obj_id
1381
+ )
1382
+ ] = obj_id
1383
+
1384
+ # Update progress visualization
1385
+ progress_data[frame_idx] = (
1386
+ progress_data[nearest_idx] * 0.6
1387
+ ) # Mark as copied
1388
+
1389
+ # Final update of progress layer
1390
+ progress_layer.data = progress_data
1391
+
1392
+ # Remove the progress layer after 2 seconds via the Qt
+ # event loop; GUI objects must not be touched from a
+ # worker thread
+ from qtpy.QtCore import QTimer
+
+ def remove_progress():
+ for layer in list(self.viewer.layers):
+ if "Propagation Progress" in layer.name:
+ self.viewer.layers.remove(layer)
+
+ QTimer.singleShot(2000, remove_progress)
1404
+
1405
+ # Update UI
1406
+ self._update_label_layer()
1407
+ if (
1408
+ hasattr(self, "label_table_widget")
1409
+ and self.label_table_widget is not None
1410
+ ):
1411
+ self._populate_label_table(self.label_table_widget)
1412
+
1413
+ self.viewer.status = f"Object {obj_id} segmented and propagated to all frames"
1414
+
1415
+ else:
1416
+ # 2D case
1417
+ if len(coords) == 2:
1418
+ y, x = map(int, coords)
1419
+ else:
1420
+ self.viewer.status = (
1421
+ f"Unexpected coordinate dimensions: {coords}"
1422
+ )
1423
+ return
1424
+
1425
+ # Add point to the layer immediately for visual feedback
1426
+ new_point = np.array([[y, x]])
1427
+ if len(layer.data) == 0:
1428
+ layer.data = new_point
1429
+ else:
1430
+ layer.data = np.vstack([layer.data, new_point])
1431
+
1432
+ # Update point colors, preserving the colors of earlier
+ # points (face_color is an N x 4 RGBA array in napari)
+ colors = [list(c) for c in np.atleast_2d(layer.face_color)]
+ colors[-1] = (
+ [1.0, 0.0, 0.0, 1.0]
+ if is_negative
+ else [0.0, 0.5, 0.0, 1.0]
+ )
+ layer.face_color = colors
1441
+
1442
+ # Get object ID
1443
+ label_id = self.segmentation_result[y, x]
1444
+ if is_negative and label_id > 0:
1445
+ obj_id = label_id
1446
+ else:
1447
+ if not hasattr(self, "next_obj_id"):
1448
+ self.next_obj_id = 1
1449
+ obj_id = self.next_obj_id
1450
+ if point_label > 0 and label_id == 0:
1451
+ self.next_obj_id += 1
1452
+
1453
+ # Store point information
1454
+ if not hasattr(self, "obj_points"):
1455
+ self.obj_points = {}
1456
+ self.obj_labels = {}
1457
+
1458
+ if obj_id not in self.obj_points:
1459
+ self.obj_points[obj_id] = []
1460
+ self.obj_labels[obj_id] = []
1461
+
1462
+ self.obj_points[obj_id].append(
1463
+ [x, y]
1464
+ ) # SAM2 expects [x,y] format
1465
+ self.obj_labels[obj_id].append(point_label)
1466
+
1467
+ # Perform segmentation
1468
+ if hasattr(self, "predictor") and self.predictor is not None:
1469
+ # Make sure image is loaded
1470
+ if self.current_image_for_segmentation is None:
1471
+ self.viewer.status = "No image loaded for segmentation"
1472
+ return
1473
+
1474
+ # Prepare image for SAM2
1475
+ image = self.current_image_for_segmentation
1476
+ if len(image.shape) == 2:
1477
+ image = np.stack([image] * 3, axis=-1)
1478
+ elif len(image.shape) == 3 and image.shape[2] == 1:
1479
+ image = np.concatenate([image] * 3, axis=2)
1480
+ elif len(image.shape) == 3 and image.shape[2] > 3:
1481
+ image = image[:, :, :3]
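+ # SAM2's image predictor expects an RGB uint8 array (H x W x 3)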
1482
+
1483
+ if image.dtype != np.uint8:
1484
+ image = (image / np.max(image) * 255).astype(np.uint8)
1485
+
1486
+ # Set the image in the predictor
1487
+ self.predictor.set_image(image)
1488
+
1489
+ # Use only points for current object
1490
+ points = np.array(
1491
+ self.obj_points[obj_id], dtype=np.float32
1492
+ )
1493
+ # Map -1 negative markers to SAM2's background label 0
+ labels = np.array(
+ [1 if lbl > 0 else 0 for lbl in self.obj_labels[obj_id]],
+ dtype=np.int32,
+ )
1494
+
1495
+ self.viewer.status = f"Segmenting object {obj_id} with {len(points)} points..."
1496
+
1497
+ with torch.inference_mode(), torch.autocast("cuda"):
1498
+ masks, scores, _ = self.predictor.predict(
1499
+ point_coords=points,
1500
+ point_labels=labels,
1501
+ multimask_output=True,
1502
+ )
1503
+
1504
+ # Get the best mask (highest predicted quality score)
+ if len(masks) > 0:
+ best_idx = int(np.argmax(scores))
+ best_mask = masks[best_idx]
1507
+
1508
+ # Update segmentation result
1509
+ if (
1510
+ best_mask.shape
1511
+ != self.segmentation_result.shape
1512
+ ):
1513
+ from skimage.transform import resize
1514
+
1515
+ best_mask = resize(
1516
+ best_mask.astype(float),
1517
+ self.segmentation_result.shape,
1518
+ order=0,
1519
+ preserve_range=True,
1520
+ anti_aliasing=False,
1521
+ ).astype(bool)
1522
+
1523
+ # Apply mask based on point type
1524
+ if point_label < 0:
1525
+ # Remove only from current object
1526
+ mask_condition = np.logical_and(
1527
+ self.segmentation_result == obj_id,
1528
+ best_mask,
1529
+ )
1530
+ self.segmentation_result[mask_condition] = 0
1531
+ else:
1532
+ # Add to current object (only overwrite background)
1533
+ mask_condition = np.logical_and(
1534
+ best_mask, (self.segmentation_result == 0)
1535
+ )
1536
+ self.segmentation_result[mask_condition] = (
1537
+ obj_id
1538
+ )
1539
+
1540
+ # Update label info
1541
+ area = np.sum(self.segmentation_result == obj_id)
1542
+ y_indices, x_indices = np.where(
1543
+ self.segmentation_result == obj_id
1544
+ )
1545
+ center_y = (
1546
+ np.mean(y_indices) if len(y_indices) > 0 else 0
1547
+ )
1548
+ center_x = (
1549
+ np.mean(x_indices) if len(x_indices) > 0 else 0
1550
+ )
1551
+
1552
+ self.label_info[obj_id] = {
1553
+ "area": area,
1554
+ "center_y": center_y,
1555
+ "center_x": center_x,
1556
+ "score": float(scores[0]),
1557
+ }
1558
+
1559
+ self.viewer.status = f"Updated object {obj_id}"
1560
+ else:
1561
+ self.viewer.status = "No valid mask produced"
1562
+
1563
+ # Update the UI
1564
+ self._update_label_layer()
1565
+ if (
1566
+ hasattr(self, "label_table_widget")
1567
+ and self.label_table_widget is not None
1568
+ ):
1569
+ self._populate_label_table(self.label_table_widget)
1570
+
1571
+ except (
1572
+ IndexError,
1573
+ KeyError,
1574
+ ValueError,
1575
+ RuntimeError,
1576
+ TypeError,
1577
+ ) as e:
1578
+ import traceback
633
1579
 
634
- # Get the label ID at the clicked position
635
- label_id = self.segmentation_result[coords[0], coords[1]]
1580
+ self.viewer.status = f"Error in points handling: {str(e)}"
1581
+ traceback.print_exc()
636
1582
 
637
- # Skip if background (0) is clicked
638
- if label_id == 0:
1583
+ def _on_label_clicked(self, layer, event):
1584
+ """Handle label selection and user prompts on mouse click."""
1585
+ try:
1586
+ # Only process clicks, not drags
1587
+ if event.type != "mouse_press":
639
1588
  return
640
1589
 
641
- # Toggle the label selection
642
- if label_id in self.selected_labels:
643
- self.selected_labels.remove(label_id)
644
- self.viewer.status = f"Deselected label ID: {label_id} | Selected labels: {self.selected_labels}"
1590
+ # Get coordinates of mouse click
1591
+ coords = np.round(event.position).astype(int)
1592
+
1593
+ # Check if Shift is pressed (negative point)
1594
+ is_negative = "Shift" in event.modifiers
1595
+ point_label = -1 if is_negative else 1
1596
+
1597
+ # For 2D data
1598
+ if not self.use_3d:
1599
+ if len(coords) == 2:
1600
+ y, x = map(int, coords)
1601
+ else:
1602
+ self.viewer.status = (
1603
+ f"Unexpected coordinate dimensions: {coords}"
1604
+ )
1605
+ return
1606
+
1607
+ # Check if within image bounds
1608
+ shape = self.segmentation_result.shape
1609
+ if y < 0 or x < 0 or y >= shape[0] or x >= shape[1]:
1610
+ self.viewer.status = "Click is outside image bounds"
1611
+ return
1612
+
1613
+ # Get the label ID at the clicked position
1614
+ label_id = self.segmentation_result[y, x]
1615
+
1616
+ # Initialize a unique object ID for this click (if needed)
1617
+ if not hasattr(self, "next_obj_id"):
1618
+ # Start with highest existing ID + 1
1619
+ if self.segmentation_result.max() > 0:
1620
+ self.next_obj_id = (
1621
+ int(self.segmentation_result.max()) + 1
1622
+ )
1623
+ else:
1624
+ self.next_obj_id = 1
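+ # Continuing from the current maximum keeps new object
+ # ids from colliding with labels of existing objects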
1625
+
1626
+ # If clicking on background or using negative click, handle segmentation
1627
+ if label_id == 0 or is_negative:
1628
+ # Find or create points layer for the current object we're working on
1629
+ current_obj_id = None
1630
+
1631
+ # If negative point on existing label, use that label's ID
1632
+ if is_negative and label_id > 0:
1633
+ current_obj_id = label_id
1634
+ # For positive clicks on background, create a new object
1635
+ elif point_label > 0 and label_id == 0:
1636
+ current_obj_id = self.next_obj_id
1637
+ self.next_obj_id += 1
1638
+ # For negative on background, try to find most recent object
1639
+ elif point_label < 0 and label_id == 0:
1640
+ # Use most recently created object if available
1641
+ if hasattr(self, "obj_points") and self.obj_points:
1642
+ current_obj_id = max(self.obj_points.keys())
1643
+ else:
1644
+ self.viewer.status = "No existing object to modify with negative point"
1645
+ return
1646
+
1647
+ if current_obj_id is None:
1648
+ self.viewer.status = (
1649
+ "Could not determine which object to modify"
1650
+ )
1651
+ return
1652
+
1653
+ # Find or create points layer for this object
1654
+ points_layer = None
1655
+ for layer in list(self.viewer.layers):
1656
+ if f"Points for Object {current_obj_id}" in layer.name:
1657
+ points_layer = layer
1658
+ break
1659
+
1660
+ # Initialize object tracking if needed
1661
+ if not hasattr(self, "obj_points"):
1662
+ self.obj_points = {}
1663
+ self.obj_labels = {}
1664
+
1665
+ if current_obj_id not in self.obj_points:
1666
+ self.obj_points[current_obj_id] = []
1667
+ self.obj_labels[current_obj_id] = []
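+ # obj_points/obj_labels accumulate every prompt per object,
+ # so each new click re-runs SAM2 with the full prompt history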
1668
+
1669
+ # Create or update points layer for this object
1670
+ if points_layer is None:
1671
+ # First point for this object
1672
+ points_layer = self.viewer.add_points(
1673
+ np.array([[y, x]]),
1674
+ name=f"Points for Object {current_obj_id}",
1675
+ size=10,
1676
+ face_color=["green" if point_label > 0 else "red"],
1677
+ border_color="white",
1678
+ border_width=1,
1679
+ opacity=0.8,
1680
+ )
1681
+ with contextlib.suppress(AttributeError, ValueError):
1682
+ points_layer.mouse_drag_callbacks.remove(
1683
+ self._on_points_clicked
1684
+ )
1685
+ points_layer.mouse_drag_callbacks.append(
1686
+ self._on_points_clicked
1687
+ )
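+ # remove-then-append (errors suppressed) ensures the click
+ # handler is registered exactly once on this layer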
1688
+
1689
+ self.obj_points[current_obj_id] = [[x, y]]
1690
+ self.obj_labels[current_obj_id] = [point_label]
1691
+ else:
1692
+ # Add point to existing layer
+ current_points = points_layer.data
+
+ # Add new point
+ new_points = np.vstack([current_points, [y, x]])
+ new_color = "green" if point_label > 0 else "red"
+
+ # Update points layer
+ points_layer.data = new_points
+
+ # Update colors: rebuild the list from the stored point
+ # labels so earlier points keep their original color
+ colors = [
+ "green" if lbl > 0 else "red"
+ for lbl in self.obj_labels[current_obj_id]
+ ]
+ colors.append(new_color)
+ points_layer.face_color = colors
1718
+
1719
+ # Update object tracking
1720
+ self.obj_points[current_obj_id].append([x, y])
1721
+ self.obj_labels[current_obj_id].append(point_label)
1722
+
1723
+ # Now do the actual segmentation using SAM2
1724
+ if (
1725
+ hasattr(self, "predictor")
1726
+ and self.predictor is not None
1727
+ ):
1728
+ try:
1729
+ # Make sure image is loaded
1730
+ if self.current_image_for_segmentation is None:
1731
+ self.viewer.status = (
1732
+ "No image loaded for segmentation"
1733
+ )
1734
+ return
1735
+
1736
+ # Prepare image for SAM2
1737
+ image = self.current_image_for_segmentation
1738
+ if len(image.shape) == 2:
1739
+ image = np.stack([image] * 3, axis=-1)
1740
+ elif len(image.shape) == 3 and image.shape[2] == 1:
1741
+ image = np.concatenate([image] * 3, axis=2)
1742
+ elif len(image.shape) == 3 and image.shape[2] > 3:
1743
+ image = image[:, :, :3]
1744
+
1745
+ if image.dtype != np.uint8:
1746
+ image = (image / np.max(image) * 255).astype(
1747
+ np.uint8
1748
+ )
1749
+
1750
+ # Set the image in the predictor
1751
+ self.predictor.set_image(image)
1752
+
1753
+ # Only use the points for the current object being segmented
1754
+ points = np.array(
1755
+ self.obj_points[current_obj_id],
1756
+ dtype=np.float32,
1757
+ )
1758
+ # Map -1 negative markers to SAM2's background label 0
+ labels = np.array(
+ [
+ 1 if lbl > 0 else 0
+ for lbl in self.obj_labels[current_obj_id]
+ ],
+ dtype=np.int32,
+ )
1761
+
1762
+ self.viewer.status = f"Segmenting object {current_obj_id} with {len(points)} points..."
1763
+
1764
+ # Autocast only when CUDA is available
+ with torch.inference_mode(), torch.autocast(
+ "cuda", enabled=torch.cuda.is_available()
+ ):
1767
+ masks, scores, _ = self.predictor.predict(
1768
+ point_coords=points,
1769
+ point_labels=labels,
1770
+ multimask_output=True,
1771
+ )
1772
+
1773
+ # Get the best mask (highest predicted quality score)
+ if len(masks) > 0:
+ best_idx = int(np.argmax(scores))
+ best_mask = masks[best_idx]
1776
+
1777
+ # Update segmentation result
1778
+ if (
1779
+ best_mask.shape
1780
+ != self.segmentation_result.shape
1781
+ ):
1782
+ from skimage.transform import resize
1783
+
1784
+ best_mask = resize(
1785
+ best_mask.astype(float),
1786
+ self.segmentation_result.shape,
1787
+ order=0,
1788
+ preserve_range=True,
1789
+ anti_aliasing=False,
1790
+ ).astype(bool)
1791
+
1792
+ # Negative points remove pixels only from this object's mask;
+ # positive points add to it without overwriting other objects
1794
+ if point_label < 0:
1795
+ # Remove only from current object's mask
1796
+ self.segmentation_result[
1797
+ (
1798
+ self.segmentation_result
1799
+ == current_obj_id
1800
+ )
1801
+ & best_mask
1802
+ ] = 0
1803
+ else:
1804
+ # Add to current object's mask without affecting other objects
1805
+ # Only overwrite background (value 0)
1806
+ self.segmentation_result[
1807
+ best_mask
1808
+ & (self.segmentation_result == 0)
1809
+ ] = current_obj_id
1810
+
1811
+ # Update label info
1812
+ area = np.sum(
1813
+ self.segmentation_result
1814
+ == current_obj_id
1815
+ )
1816
+ y_indices, x_indices = np.where(
1817
+ self.segmentation_result
1818
+ == current_obj_id
1819
+ )
1820
+ center_y = (
1821
+ np.mean(y_indices)
1822
+ if len(y_indices) > 0
1823
+ else 0
1824
+ )
1825
+ center_x = (
1826
+ np.mean(x_indices)
1827
+ if len(x_indices) > 0
1828
+ else 0
1829
+ )
1830
+
1831
+ self.label_info[current_obj_id] = {
1832
+ "area": area,
1833
+ "center_y": center_y,
1834
+ "center_x": center_x,
1835
+ "score": float(scores[0]),
1836
+ }
1837
+
1838
+ self.viewer.status = (
1839
+ f"Updated object {current_obj_id}"
1840
+ )
1841
+ else:
1842
+ self.viewer.status = (
1843
+ "No valid mask produced"
1844
+ )
1845
+
1846
+ # Update the UI
1847
+ self._update_label_layer()
1848
+ if (
1849
+ hasattr(self, "label_table_widget")
1850
+ and self.label_table_widget is not None
1851
+ ):
1852
+ self._populate_label_table(
1853
+ self.label_table_widget
1854
+ )
1855
+
1856
+ except (
1857
+ IndexError,
1858
+ KeyError,
1859
+ ValueError,
1860
+ AttributeError,
1861
+ TypeError,
1862
+ ) as e:
1863
+ import traceback
1864
+
1865
+ self.viewer.status = (
1866
+ f"Error in SAM2 processing: {str(e)}"
1867
+ )
1868
+ traceback.print_exc()
1869
+
1870
+ # If clicking on an existing label, toggle selection
1871
+ elif label_id > 0:
1872
+ # Toggle the label selection
1873
+ if label_id in self.selected_labels:
1874
+ self.selected_labels.remove(label_id)
1875
+ self.viewer.status = f"Deselected label ID: {label_id} | Selected labels: {self.selected_labels}"
1876
+ else:
1877
+ self.selected_labels.add(label_id)
1878
+ self.viewer.status = f"Selected label ID: {label_id} | Selected labels: {self.selected_labels}"
1879
+
1880
+ # Update table and preview
1881
+ self._update_label_table()
1882
+ self.preview_crop()
1883
+
1884
+ # 3D case (handle differently)
645
1885
  else:
646
- self.selected_labels.add(label_id)
647
- self.viewer.status = f"Selected label ID: {label_id} | Selected labels: {self.selected_labels}"
1886
+ if len(coords) == 3:
1887
+ t, y, x = map(int, coords)
1888
+ elif len(coords) == 2:
1889
+ t = int(self.viewer.dims.current_step[0])
1890
+ y, x = map(int, coords)
1891
+ else:
1892
+ self.viewer.status = (
1893
+ f"Unexpected coordinate dimensions: {coords}"
1894
+ )
1895
+ return
1896
+
1897
+ # Check if within bounds
1898
+ shape = self.segmentation_result.shape
1899
+ if (
1900
+ t < 0
1901
+ or t >= shape[0]
1902
+ or y < 0
1903
+ or y >= shape[1]
1904
+ or x < 0
1905
+ or x >= shape[2]
1906
+ ):
1907
+ self.viewer.status = "Click is outside volume bounds"
1908
+ return
1909
+
1910
+ # Get the label ID at the clicked position
1911
+ label_id = self.segmentation_result[t, y, x]
1912
+
1913
+ # Background clicks and negative points are handled by the
+ # _on_3d_label_clicked callback that is already attached
+ if label_id == 0 or is_negative:
+ pass
1917
+ # If clicking on an existing label, handle selection
1918
+ elif label_id > 0:
1919
+ # Toggle the label selection
1920
+ if label_id in self.selected_labels:
1921
+ self.selected_labels.remove(label_id)
1922
+ self.viewer.status = f"Deselected label ID: {label_id} | Selected labels: {self.selected_labels}"
1923
+ else:
1924
+ self.selected_labels.add(label_id)
1925
+ self.viewer.status = f"Selected label ID: {label_id} | Selected labels: {self.selected_labels}"
648
1926
 
649
- # Update table if it exists
650
- self._update_label_table()
1927
+ # Update table if it exists
1928
+ self._update_label_table()
651
1929
 
652
- # Update preview after selection changes
653
- self.preview_crop()
1930
+ # Update preview after selection changes
1931
+ self.preview_crop()
654
1932
 
655
- except (Exception, ValueError) as e:
656
- self.viewer.status = f"Error selecting label: {str(e)}"
1933
+ except (
1934
+ IndexError,
1935
+ KeyError,
1936
+ ValueError,
1937
+ AttributeError,
1938
+ TypeError,
1939
+ ) as e:
1940
+ import traceback
1941
+
1942
+ self.viewer.status = f"Error in click handling: {str(e)}"
1943
+ traceback.print_exc()
1944
+
1945
+ def _add_point_marker(self, coords, label_type):
1946
+ """Add a visible marker for where the user clicked."""
1947
+ # Remove previous point markers
1948
+ for layer in list(self.viewer.layers):
1949
+ if "Point Prompt" in layer.name:
1950
+ self.viewer.layers.remove(layer)
1951
+
1952
+ # Create points layer
1953
+ color = (
1954
+ "red" if label_type < 0 else "green"
1955
+ ) # Red for negative, green for positive
1956
+ # Keep a reference so the callback wiring below targets
+ # the layer that was just created
+ self.points_layer = self.viewer.add_points(
+ [coords],
+ name="Point Prompt",
+ size=10,
+ face_color=color,
+ border_color="white",
+ border_width=1,
1963
+ opacity=0.8,
1964
+ )
1965
+
1966
+ with contextlib.suppress(AttributeError, ValueError):
1967
+ self.points_layer.mouse_drag_callbacks.remove(
1968
+ self._on_points_clicked
1969
+ )
1970
+ self.points_layer.mouse_drag_callbacks.append(
1971
+ self._on_points_clicked
1972
+ )
657
1973
 
658
1974
  def create_label_table(self, parent_widget):
659
1975
  """Create a table widget displaying all detected labels."""
@@ -694,57 +2010,86 @@ class BatchCropAnything:
694
2010
 
695
2011
  def _populate_label_table(self, table):
696
2012
  """Populate the table with label information."""
697
- if not self.label_info:
698
- table.setRowCount(0)
699
- return
2013
+ try:
2014
+ # Get all unique non-zero labels from the segmentation result safely
2015
+ if self.segmentation_result is None:
2016
+ # No segmentation yet
2017
+ table.setRowCount(0)
2018
+ self.viewer.status = "No segmentation available"
2019
+ return
700
2020
 
701
- # Set row count
702
- table.setRowCount(len(self.label_info))
2021
+ # Collect unique non-zero labels (0 is background;
+ # np.unique never yields None)
+ unique_labels = [
+ val
+ for val in np.unique(self.segmentation_result)
+ if val > 0
+ ]
703
2026
 
704
- # Sort labels by size (largest first)
705
- sorted_labels = sorted(
706
- self.label_info.items(),
707
- key=lambda item: item[1]["area"],
708
- reverse=True,
709
- )
2027
+ if len(unique_labels) == 0:
2028
+ table.setRowCount(0)
2029
+ self.viewer.status = "No labeled objects found"
2030
+ return
710
2031
 
711
- # Fill table with data
712
- for row, (label_id, _info) in enumerate(sorted_labels):
713
- # Checkbox for selection
714
- checkbox_widget = QWidget()
715
- checkbox_layout = QHBoxLayout(checkbox_widget)
716
- checkbox_layout.setContentsMargins(5, 0, 5, 0)
717
- checkbox_layout.setAlignment(Qt.AlignCenter)
718
-
719
- checkbox = QCheckBox()
720
- checkbox.setChecked(label_id in self.selected_labels)
721
-
722
- # Connect checkbox to label selection
723
- def make_checkbox_callback(lid):
724
- def callback(state):
725
- if state == Qt.Checked:
726
- self.selected_labels.add(lid)
727
- else:
728
- self.selected_labels.discard(lid)
729
- self.preview_crop()
2032
+ # Set row count
2033
+ table.setRowCount(len(unique_labels))
2034
+
2035
+ # Fill in label info for any missing labels
2036
+ for label_id in unique_labels:
2037
+ if label_id not in self.label_info:
2038
+ # Calculate basic info for this label
2039
+ mask = self.segmentation_result == label_id
2040
+ area = np.sum(mask)
2041
+
2042
+ # Add info to label_info dictionary
2043
+ self.label_info[label_id] = {
2044
+ "area": area,
2045
+ "score": 1.0, # Default score
2046
+ }
2047
+
2048
+ # Fill table with data
2049
+ for row, label_id in enumerate(unique_labels):
2050
+ # Checkbox for selection
2051
+ checkbox_widget = QWidget()
2052
+ checkbox_layout = QHBoxLayout(checkbox_widget)
2053
+ checkbox_layout.setContentsMargins(5, 0, 5, 0)
2054
+ checkbox_layout.setAlignment(Qt.AlignCenter)
2055
+
2056
+ checkbox = QCheckBox()
2057
+ checkbox.setChecked(label_id in self.selected_labels)
2058
+
2059
+ # Connect checkbox to label selection
2060
+ def make_checkbox_callback(lid):
2061
+ def callback(state):
2062
+ if state == Qt.Checked:
2063
+ self.selected_labels.add(lid)
2064
+ else:
2065
+ self.selected_labels.discard(lid)
2066
+ self.preview_crop()
2067
+
2068
+ return callback
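+ # The factory ensures each checkbox closure captures its
+ # own label id rather than the shared loop variable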
730
2069
 
731
- return callback
2070
+ checkbox.stateChanged.connect(make_checkbox_callback(label_id))
732
2071
 
733
- checkbox.stateChanged.connect(make_checkbox_callback(label_id))
2072
+ checkbox_layout.addWidget(checkbox)
2073
+ table.setCellWidget(row, 0, checkbox_widget)
734
2074
 
735
- checkbox_layout.addWidget(checkbox)
736
- table.setCellWidget(row, 0, checkbox_widget)
2075
+ # Label ID as plain text with transparent background
2076
+ item = QTableWidgetItem(str(label_id))
2077
+ item.setTextAlignment(Qt.AlignCenter)
737
2078
 
738
- # Label ID as plain text with transparent background
739
- item = QTableWidgetItem(str(label_id))
740
- item.setTextAlignment(Qt.AlignCenter)
2079
+ # Set the background color to transparent
2080
+ brush = item.background()
2081
+ brush.setStyle(Qt.NoBrush)
2082
+ item.setBackground(brush)
741
2083
 
742
- # Set the background color to transparent
743
- brush = item.background()
744
- brush.setStyle(Qt.NoBrush)
745
- item.setBackground(brush)
2084
+ table.setItem(row, 1, item)
746
2085
 
747
- table.setItem(row, 1, item)
2086
+ except (KeyError, TypeError, ValueError, AttributeError) as e:
2087
+ import traceback
2088
+
2089
+ self.viewer.status = f"Error populating table: {str(e)}"
2090
+ traceback.print_exc()
2091
+ # Set empty table as fallback
2092
+ table.setRowCount(0)
748
2093
 
749
2094
  def _update_label_table(self):
750
2095
  """Update the label selection table if it exists."""
@@ -754,6 +2099,9 @@ class BatchCropAnything:
754
2099
  # Block signals during update
755
2100
  self.label_table_widget.blockSignals(True)
756
2101
 
2102
+ # Completely repopulate the table to ensure it's up to date
2103
+ self._populate_label_table(self.label_table_widget)
2104
+
757
2105
  # Update checkboxes
758
2106
  for row in range(self.label_table_widget.rowCount()):
759
2107
  # Get label ID from the visible column
@@ -793,10 +2141,6 @@ class BatchCropAnything:
793
2141
  self.preview_crop()
794
2142
  self.viewer.status = "Cleared all selections"
795
2143
 
796
- # --------------------------------------------------
797
- # Image Processing and Export
798
- # --------------------------------------------------
799
-
800
2144
  def preview_crop(self, label_ids=None):
801
2145
  """Preview the crop result with the selected label IDs."""
802
2146
  if self.segmentation_result is None or self.image_layer is None:
@@ -826,20 +2170,29 @@ class BatchCropAnything:
826
2170
  image = self.original_image.copy()
827
2171
 
828
2172
  # Create mask from selected label IDs
829
- mask = np.zeros_like(self.segmentation_result, dtype=bool)
830
- for label_id in label_ids:
831
- mask |= self.segmentation_result == label_id
2173
+ if self.use_3d:
2174
+ # For 3D data
2175
+ mask = np.zeros_like(self.segmentation_result, dtype=bool)
2176
+ for label_id in label_ids:
2177
+ mask |= self.segmentation_result == label_id
832
2178
 
833
- # Apply mask to image for preview (set everything outside mask to 0)
834
- if len(image.shape) == 2:
835
- # Grayscale image
2179
+ # Apply mask
836
2180
  preview_image = image.copy()
837
2181
  preview_image[~mask] = 0
838
2182
  else:
839
- # Color image
840
- preview_image = image.copy()
841
- for c in range(preview_image.shape[2]):
842
- preview_image[:, :, c][~mask] = 0
2183
+ # For 2D data
2184
+ mask = np.zeros_like(self.segmentation_result, dtype=bool)
2185
+ for label_id in label_ids:
2186
+ mask |= self.segmentation_result == label_id
2187
+
2188
+ # Apply mask
2189
+ if len(image.shape) == 2:
2190
+ preview_image = image.copy()
2191
+ preview_image[~mask] = 0
2192
+ else:
2193
+ preview_image = image.copy()
2194
+ for c in range(preview_image.shape[2]):
2195
+ preview_image[:, :, c][~mask] = 0
843
2196
 
844
2197
  # Remove previous preview if exists
845
2198
  for layer in list(self.viewer.layers):
@@ -879,20 +2232,58 @@ class BatchCropAnything:
879
2232
  image = self.original_image
880
2233
 
881
2234
  # Create mask from all selected label IDs
882
- mask = np.zeros_like(self.segmentation_result, dtype=bool)
883
- for label_id in self.selected_labels:
884
- mask |= self.segmentation_result == label_id
2235
+ if self.use_3d:
2236
+ # For 3D data, create a 3D mask
2237
+ mask = np.zeros_like(self.segmentation_result, dtype=bool)
2238
+ for label_id in self.selected_labels:
2239
+ mask |= self.segmentation_result == label_id
885
2240
 
886
- # Apply mask to image (set everything outside mask to 0)
887
- if len(image.shape) == 2:
888
- # Grayscale image
2241
+ # Apply mask to image (set everything outside mask to 0)
889
2242
  cropped_image = image.copy()
890
2243
  cropped_image[~mask] = 0
2244
+
2245
+ # Save label image with same dimensions as original
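+ # (uint32 preserves label ids even for very many objects)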
2246
+ label_image = np.zeros_like(
2247
+ self.segmentation_result, dtype=np.uint32
2248
+ )
2249
+ for label_id in self.selected_labels:
2250
+ label_image[self.segmentation_result == label_id] = (
2251
+ label_id
2252
+ )
891
2253
  else:
892
- # Color image
893
- cropped_image = image.copy()
894
- for c in range(cropped_image.shape[2]):
895
- cropped_image[:, :, c][~mask] = 0
2254
+ # For 2D data, handle as before
2255
+ mask = np.zeros_like(self.segmentation_result, dtype=bool)
2256
+ for label_id in self.selected_labels:
2257
+ mask |= self.segmentation_result == label_id
2258
+
2259
+ # Apply mask to image (set everything outside mask to 0)
2260
+ if len(image.shape) == 2:
2261
+ # Grayscale image
2262
+ cropped_image = image.copy()
2263
+ cropped_image[~mask] = 0
2264
+
2265
+ # Create label image with same dimensions
2266
+ label_image = np.zeros_like(
2267
+ self.segmentation_result, dtype=np.uint32
2268
+ )
2269
+ for label_id in self.selected_labels:
2270
+ label_image[self.segmentation_result == label_id] = (
2271
+ label_id
2272
+ )
2273
+ else:
2274
+ # Color image: apply the 2D mask to each channel
2275
+ cropped_image = image.copy()
2276
+ for c in range(cropped_image.shape[2]):
2277
+ cropped_image[:, :, c][~mask] = 0
2278
+
2279
+ # Create label image with 2D dimensions (without channels)
2280
+ label_image = np.zeros_like(
2281
+ self.segmentation_result, dtype=np.uint32
2282
+ )
2283
+ for label_id in self.selected_labels:
2284
+ label_image[self.segmentation_result == label_id] = (
2285
+ label_id
2286
+ )
896
2287
 
897
2288
  # Save cropped image
898
2289
  image_path = self.images[self.current_index]
@@ -900,18 +2291,17 @@ class BatchCropAnything:
900
2291
  label_str = "_".join(
901
2292
  str(lid) for lid in sorted(self.selected_labels)
902
2293
  )
903
- output_path = f"{base_name}_cropped_{label_str}{ext}"
904
-
905
- # Save using appropriate method based on file type
906
- if output_path.lower().endswith((".tif", ".tiff")):
907
- imwrite(output_path, cropped_image, compression="zlib")
908
- else:
909
- from skimage.io import imsave
910
-
911
- imsave(output_path, cropped_image)
2294
+ output_path = f"{base_name}_cropped_{label_str}.tif"
912
2295
 
2296
+ # Save as TIFF via tifffile with lossless zlib compression
2297
+ imwrite(output_path, cropped_image, compression="zlib")
913
2298
  self.viewer.status = f"Saved cropped image to {output_path}"
914
2299
 
2300
+ # Save the label image with exact same dimensions as original
2301
+ label_output_path = f"{base_name}_labels_{label_str}.tif"
2302
+ imwrite(label_output_path, label_image, compression="zlib")
2303
+ self.viewer.status += f"\nSaved label mask to {label_output_path}"
2304
+
915
2305
  # Make sure the segmentation layer is active again
916
2306
  if self.label_layer is not None:
917
2307
  self.viewer.layers.selection.active = self.label_layer
@@ -923,76 +2313,44 @@ class BatchCropAnything:
923
2313
  return False
924
2314
 
925
2315
 
926
- # --------------------------------------------------
927
- # UI Creation Functions
928
- # --------------------------------------------------
929
-
930
-
931
2316
  def create_crop_widget(processor):
932
2317
  """Create the crop control widget."""
933
2318
  crop_widget = QWidget()
934
2319
  layout = QVBoxLayout()
935
- layout.setSpacing(10) # Add more space between elements
936
- layout.setContentsMargins(
937
- 10, 10, 10, 10
938
- ) # Add margins around all elements
2320
+ layout.setSpacing(10)
2321
+ layout.setContentsMargins(10, 10, 10, 10)
939
2322
 
940
2323
  # Instructions
2324
+ dimension_type = "3D (TYX/ZYX)" if processor.use_3d else "2D (YX)"
941
2325
  instructions_label = QLabel(
942
- "Select objects to keep in the cropped image.\n"
943
- "You can select labels using the table below or by clicking directly on objects "
944
- "in the image (make sure the Segmentation layer is active)."
2326
+ f"<b>Processing {dimension_type} data</b><br><br>"
2327
+ "To create/edit objects:<br>"
2328
+ "1. <b>Click on the POINTS layer</b> to add positive points<br>"
2329
+ "2. Use Shift+click for negative points to refine segmentation<br>"
2330
+ "3. Click on existing objects in the Segmentation layer to select them<br>"
2331
+ "4. Press 'Crop' to save the selected objects to disk"
945
2332
  )
946
2333
  instructions_label.setWordWrap(True)
947
2334
  layout.addWidget(instructions_label)
948
2335
 
949
- # Sensitivity slider
950
- sensitivity_layout = QVBoxLayout()
951
-
952
- # Header label
953
- sensitivity_header_layout = QHBoxLayout()
954
- sensitivity_label = QLabel("Segmentation Sensitivity:")
955
- sensitivity_value_label = QLabel(f"{processor.sensitivity}")
956
- sensitivity_header_layout.addWidget(sensitivity_label)
957
- sensitivity_header_layout.addStretch()
958
- sensitivity_header_layout.addWidget(sensitivity_value_label)
959
- sensitivity_layout.addLayout(sensitivity_header_layout)
960
-
961
- # Slider
962
- slider_layout = QHBoxLayout()
963
- sensitivity_slider = QSlider(Qt.Horizontal)
964
- sensitivity_slider.setMinimum(0)
965
- sensitivity_slider.setMaximum(100)
966
- sensitivity_slider.setValue(processor.sensitivity)
967
- sensitivity_slider.setTickPosition(QSlider.TicksBelow)
968
- sensitivity_slider.setTickInterval(10)
969
- slider_layout.addWidget(sensitivity_slider)
970
-
971
- apply_sensitivity_button = QPushButton("Apply")
972
- apply_sensitivity_button.setToolTip(
973
- "Apply sensitivity changes to regenerate segmentation"
974
- )
975
- slider_layout.addWidget(apply_sensitivity_button)
976
- sensitivity_layout.addLayout(slider_layout)
977
-
978
- # Description label
979
- sensitivity_description = QLabel(
980
- "Medium sensitivity - Balanced detection (γ=1.00)"
2336
+ # Add a button to ensure points layer is active
2337
+ activate_button = QPushButton("Make Points Layer Active")
2338
+ activate_button.clicked.connect(
2339
+ lambda: processor._ensure_points_layer_active()
981
2340
  )
982
- sensitivity_description.setStyleSheet("font-style: italic; color: #666;")
983
- sensitivity_layout.addWidget(sensitivity_description)
2341
+ layout.addWidget(activate_button)
984
2342
 
985
- layout.addLayout(sensitivity_layout)
2343
+ # Add a "Clear Points" button to reset prompts
2344
+ clear_points_button = QPushButton("Clear Points")
2345
+ layout.addWidget(clear_points_button)
986
2346
 
987
2347
  # Create label table
988
2348
  label_table = processor.create_label_table(crop_widget)
989
- label_table.setMinimumHeight(150) # Reduce minimum height to save space
990
- label_table.setMaximumHeight(
991
- 300
992
- ) # Set maximum height to prevent taking too much space
2349
+ label_table.setMinimumHeight(150)
2350
+ label_table.setMaximumHeight(300)
993
2351
  layout.addWidget(label_table)
994
2352
 
995
- # Remove "Focus on Segmentation Layer" button as it's now redundant
2353
+ # Selection buttons
996
2354
  selection_layout = QHBoxLayout()
997
2355
  select_all_button = QPushButton("Select All")
998
2356
  clear_selection_button = QPushButton("Clear Selection")
@@ -1014,7 +2372,7 @@ def create_crop_widget(processor):
1014
2372
 
1015
2373
  # Status label
1016
2374
  status_label = QLabel(
1017
- "Ready to process images. Select objects using the table or by clicking on them."
2375
+ "Ready to process images. Click on POINTS layer to add segmentation points."
1018
2376
  )
1019
2377
  status_label.setWordWrap(True)
1020
2378
  layout.addWidget(status_label)
@@ -1033,36 +2391,51 @@ def create_crop_widget(processor):
1033
2391
  # Create new table
1034
2392
  label_table = processor.create_label_table(crop_widget)
1035
2393
  label_table.setMinimumHeight(200)
1036
- layout.insertWidget(3, label_table) # Insert after sensitivity slider
2394
+ layout.insertWidget(3, label_table) # Insert after clear points button
1037
2395
  return label_table
1038
2396
 
1039
- # Connect button signals
1040
- def on_sensitivity_changed(value):
1041
- sensitivity_value_label.setText(f"{value}")
1042
- # Update description based on sensitivity
1043
- if value < 25:
1044
- gamma = (
1045
- 1.5 - (value / 100) * 1.0
1046
- ) # Higher gamma for low sensitivity
1047
- description = f"Low sensitivity - Seeks large, distinct objects (γ={gamma:.2f})"
1048
- elif value < 75:
1049
- gamma = 1.5 - (value / 100) * 1.0
1050
- description = (
1051
- f"Medium sensitivity - Balanced detection (γ={gamma:.2f})"
2397
+ # Add helper method to ensure points layer is active
2398
+ def _ensure_points_layer_active():
2399
+ points_layer = None
2400
+ for layer in list(processor.viewer.layers):
2401
+ if "Points" in layer.name:
2402
+ points_layer = layer
2403
+ break
2404
+
2405
+ if points_layer is not None:
2406
+ processor.viewer.layers.selection.active = points_layer
2407
+ status_label.setText(
2408
+ "Points layer is now active - click to add points"
1052
2409
  )
1053
2410
  else:
1054
- gamma = (
1055
- 1.5 - (value / 100) * 1.0
1056
- ) # Lower gamma for high sensitivity
1057
- description = f"High sensitivity - Detects subtle, small objects (γ={gamma:.2f})"
1058
- sensitivity_description.setText(description)
1059
-
1060
- def on_apply_sensitivity_clicked():
1061
- new_sensitivity = sensitivity_slider.value()
1062
- processor.generate_segmentation_with_sensitivity(new_sensitivity)
1063
- replace_table_widget()
2411
+ status_label.setText(
2412
+ "No points layer found. Please load an image first."
2413
+ )
2414
+
2415
+ processor._ensure_points_layer_active = _ensure_points_layer_active
2416
+
2417
+ # Connect button signals
2418
+ def on_clear_points_clicked():
2419
+ # Remove all point layers
2420
+ for layer in list(processor.viewer.layers):
2421
+ if "Points" in layer.name:
2422
+ processor.viewer.layers.remove(layer)
2423
+
2424
+ # Reset point tracking attributes
2425
+ if hasattr(processor, "points_data"):
2426
+ processor.points_data = {}
2427
+ processor.points_labels = {}
2428
+
2429
+ if hasattr(processor, "obj_points"):
2430
+ processor.obj_points = {}
2431
+ processor.obj_labels = {}
2432
+
2433
+ # Re-create empty points layer
2434
+ processor._update_label_layer()
2435
+ processor._ensure_points_layer_active()
2436
+
1064
2437
  status_label.setText(
1065
- f"Regenerated segmentation with sensitivity {new_sensitivity}"
2438
+ "Cleared all points. Click on Points layer to add new points."
1066
2439
  )
1067
2440
 
1068
2441
  def on_select_all_clicked():
@@ -1086,117 +2459,83 @@ def create_crop_widget(processor):
1086
2459
  )
1087
2460
 
1088
2461
  def on_next_clicked():
2462
+ # Clear points before moving to next image
2463
+ on_clear_points_clicked()
2464
+
1089
2465
  if not processor.next_image():
1090
2466
  next_button.setEnabled(False)
1091
2467
  else:
1092
2468
  prev_button.setEnabled(True)
1093
2469
  replace_table_widget()
1094
- # Reset sensitivity slider to default
1095
- sensitivity_slider.setValue(processor.sensitivity)
1096
- sensitivity_value_label.setText(f"{processor.sensitivity}")
1097
2470
  status_label.setText(
1098
2471
  f"Showing image {processor.current_index + 1}/{len(processor.images)}"
1099
2472
  )
2473
+ processor._ensure_points_layer_active()
1100
2474
 
1101
2475
  def on_prev_clicked():
2476
+ # Clear points before moving to previous image
2477
+ on_clear_points_clicked()
2478
+
1102
2479
  if not processor.previous_image():
1103
2480
  prev_button.setEnabled(False)
1104
2481
  else:
1105
2482
  next_button.setEnabled(True)
1106
2483
  replace_table_widget()
1107
- # Reset sensitivity slider to default
1108
- sensitivity_slider.setValue(processor.sensitivity)
1109
- sensitivity_value_label.setText(f"{processor.sensitivity}")
1110
2484
  status_label.setText(
1111
2485
  f"Showing image {processor.current_index + 1}/{len(processor.images)}"
1112
2486
  )
2487
+ processor._ensure_points_layer_active()
1113
2488
 
1114
- sensitivity_slider.valueChanged.connect(on_sensitivity_changed)
1115
- apply_sensitivity_button.clicked.connect(on_apply_sensitivity_clicked)
2489
+ clear_points_button.clicked.connect(on_clear_points_clicked)
1116
2490
  select_all_button.clicked.connect(on_select_all_clicked)
1117
2491
  clear_selection_button.clicked.connect(on_clear_selection_clicked)
1118
2492
  crop_button.clicked.connect(on_crop_clicked)
1119
2493
  next_button.clicked.connect(on_next_clicked)
1120
2494
  prev_button.clicked.connect(on_prev_clicked)
1121
2496
 
1122
2497
  return crop_widget
1123
2498
 
1124
2499
 
1125
- # --------------------------------------------------
1126
- # Napari Plugin Functions
1127
- # --------------------------------------------------
1128
-
1129
-
1130
2500
  @magicgui(
1131
2501
  call_button="Start Batch Crop Anything",
1132
2502
  folder_path={"label": "Folder Path", "widget_type": "LineEdit"},
2503
+ data_dimensions={
2504
+ "label": "Data Dimensions",
2505
+ "choices": ["YX (2D)", "TYX/ZYX (3D)"],
2506
+ },
1133
2507
  )
1134
2508
  def batch_crop_anything(
1135
2509
  folder_path: str,
2510
+ data_dimensions: str,
1136
2511
  viewer: Viewer = None,
1137
2512
  ):
1138
- """MagicGUI widget for starting Batch Crop Anything."""
1139
- # Check if Mobile-SAM is available
2513
+ """MagicGUI widget for starting Batch Crop Anything using SAM2."""
2514
+ # Check if SAM2 is available
1140
2515
  try:
1141
- # import torch
1142
- # from mobile_sam import sam_model_registry
1143
-
1144
- # Check if the required files are included with the package
1145
- try:
1146
- import importlib.util
1147
- import os
1148
-
1149
- mobile_sam_spec = importlib.util.find_spec("mobile_sam")
1150
- if mobile_sam_spec is None:
1151
- raise ImportError("mobile_sam package not found")
1152
-
1153
- mobile_sam_path = os.path.dirname(mobile_sam_spec.origin)
1154
-
1155
- # Check for model file in package
1156
- model_found = False
1157
- checkpoint_paths = [
1158
- os.path.join(mobile_sam_path, "weights", "mobile_sam.pt"),
1159
- os.path.join(mobile_sam_path, "mobile_sam.pt"),
1160
- os.path.join(
1161
- os.path.dirname(mobile_sam_path),
1162
- "weights",
1163
- "mobile_sam.pt",
1164
- ),
1165
- os.path.join(
1166
- os.path.expanduser("~"), "models", "mobile_sam.pt"
1167
- ),
1168
- "/opt/T-MIDAS/models/mobile_sam.pt",
1169
- os.path.join(os.getcwd(), "mobile_sam.pt"),
1170
- ]
1171
-
1172
- for path in checkpoint_paths:
1173
- if os.path.exists(path):
1174
- model_found = True
1175
- break
1176
-
1177
- if not model_found:
1178
- QMessageBox.warning(
1179
- None,
1180
- "Model File Missing",
1181
- "Mobile-SAM model weights (mobile_sam.pt) not found. You'll be prompted to locate it when starting the tool.\n\n"
1182
- "You can download it from: https://github.com/ChaoningZhang/MobileSAM/tree/master/weights",
1183
- )
1184
- except (ImportError, AttributeError) as e:
1185
- print(f"Warning checking for model file: {str(e)}")
2516
+ import importlib.util
1186
2517
 
2518
+ sam2_spec = importlib.util.find_spec("sam2")
2519
+ if sam2_spec is None:
2520
+ QMessageBox.critical(
2521
+ None,
2522
+ "Missing Dependency",
2523
+ "SAM2 not found. Please follow installation instructions at:\n"
2524
+ "https://github.com/MercaderLabAnatomy/napari-tmidas?tab=readme-ov-file#dependencies\n",
2525
+ )
2526
+ return
1187
2527
  except ImportError:
1188
2528
  QMessageBox.critical(
1189
2529
  None,
1190
2530
  "Missing Dependency",
1191
- "Mobile-SAM not found. Please install with:\n"
1192
- "pip install git+https://github.com/ChaoningZhang/MobileSAM.git\n\n"
1193
- "You'll also need to download the model weights file (mobile_sam.pt) from:\n"
1194
- "https://github.com/ChaoningZhang/MobileSAM/tree/master/weights",
2531
+ "SAM2 package cannot be imported. Please follow installation instructions at\n"
2532
+ "https://github.com/MercaderLabAnatomy/napari-tmidas?tab=readme-ov-file#dependencies",
1195
2533
  )
1196
2534
  return
1197
2535
 
1198
- # Initialize processor and load images
1199
- processor = BatchCropAnything(viewer)
2536
+ # Initialize processor with the selected dimensions mode
2537
+ use_3d = "TYX/ZYX" in data_dimensions
2538
+ processor = BatchCropAnything(viewer, use_3d=use_3d)
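+ # use_3d switches the processor between single-image SAM2
+ # prediction and video-style propagation across the T/Z axis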
1200
2539
  processor.load_images(folder_path)
1201
2540
 
1202
2541
  # Create UI
@@ -1205,13 +2544,9 @@ def batch_crop_anything(
1205
2544
  # Wrap the widget in a scroll area
1206
2545
  scroll_area = QScrollArea()
1207
2546
  scroll_area.setWidget(crop_widget)
1208
- scroll_area.setWidgetResizable(
1209
- True
1210
- ) # This allows the widget to resize with the scroll area
1211
- scroll_area.setFrameShape(QScrollArea.NoFrame) # Hide the frame
1212
- scroll_area.setMinimumHeight(
1213
- 500
1214
- ) # Set a minimum height to ensure visibility
2547
+ scroll_area.setWidgetResizable(True)
2548
+ scroll_area.setFrameShape(QScrollArea.NoFrame)
2549
+ scroll_area.setMinimumHeight(500)
1215
2550
 
1216
2551
  # Add scroll area to viewer
1217
2552
  viewer.window.add_dock_widget(scroll_area, name="Crop Controls")