napari-tmidas 0.1.9__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,13 +1,20 @@
  """
  Batch Crop Anything - A Napari plugin for interactive image cropping

- This plugin combines Segment Anything Model (SAM) for automatic object detection with
+ This plugin combines SAM2 for automatic object detection with
  an interactive interface for selecting and cropping objects from images.
+ The plugin supports both 2D (YX) and 3D (TYX/ZYX) data.
  """

+ import contextlib
  import os

+ # Add this at the beginning of your plugin file
+ import sys
+
+ sys.path.append("/opt/sam2")
  import numpy as np
+ import requests
  import torch
  from magicgui import magicgui
  from napari.layers import Labels
@@ -22,34 +29,55 @@ from qtpy.QtWidgets import (
      QMessageBox,
      QPushButton,
      QScrollArea,
-     QSlider,
      QTableWidget,
      QTableWidgetItem,
      QVBoxLayout,
      QWidget,
  )
  from skimage.io import imread
- from skimage.transform import resize  # Added import for resize function
+ from skimage.transform import resize
  from tifffile import imwrite

+ from napari_tmidas.processing_functions.sam2_mp4 import tif_to_mp4
+
+ def get_device():
+     if sys.platform == "darwin":
+         # MacOS: Only check for MPS
+         if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+             device = torch.device("mps")
+             print("Using Apple Silicon GPU (MPS)")
+         else:
+             device = torch.device("cpu")
+             print("Using CPU")
+     else:
+         # Other platforms: check for CUDA, then CPU
+         if hasattr(torch, "cuda") and torch.cuda.is_available():
+             device = torch.device("cuda")
+             print(f"Using CUDA GPU: {torch.cuda.get_device_name()}")
+         else:
+             device = torch.device("cpu")
+             print("Using CPU")
+     return device
+
+
+

  class BatchCropAnything:
-     """
-     Class for processing images with Segment Anything and cropping selected objects.
-     """
+     """Class for processing images with SAM2 and cropping selected objects."""

-     def __init__(self, viewer: Viewer):
+     def __init__(self, viewer: Viewer, use_3d=False):
          """Initialize the BatchCropAnything processor."""
          # Core components
          self.viewer = viewer
          self.images = []
          self.current_index = 0
+         self.use_3d = use_3d

          # Image and segmentation data
          self.original_image = None
          self.segmentation_result = None
          self.current_image_for_segmentation = None
-         self.current_scale_factor = 1.0  # Added scale factor tracking
+         self.current_scale_factor = 1.0

          # UI references
          self.image_layer = None
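
Editor's note: the new `get_device()` helper replaces the old inline `"cuda" if torch.cuda.is_available() else "cpu"` check and adds Apple Silicon (MPS) support, and the constructor gains a `use_3d` flag. A minimal usage sketch — not part of the package, and the viewer and folder path here are hypothetical:

    import napari

    device = get_device()  # torch.device("mps"), ("cuda"), or ("cpu")
    viewer = napari.Viewer()
    processor = BatchCropAnything(viewer, use_3d=True)  # new use_3d flag
    processor.load_images("/path/to/tif/folder")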
@@ -63,101 +91,73 @@ class BatchCropAnything:
          # Segmentation parameters
          self.sensitivity = 50  # Default sensitivity (0-100 scale)

-         # Initialize the SAM model
-         self._initialize_sam()
-
-     # --------------------------------------------------
-     # Model Initialization
-     # --------------------------------------------------
-
-     def _initialize_sam(self):
-         """Initialize the Segment Anything Model."""
-         try:
-             # Import required modules
-             from mobile_sam import (
-                 SamAutomaticMaskGenerator,
-                 sam_model_registry,
-             )
+         # Initialize the SAM2 model
+         self._initialize_sam2()

-             # Setup device
-             self.device = "cuda" if torch.cuda.is_available() else "cpu"
-             model_type = "vit_t"
+     def _initialize_sam2(self):
+         """Initialize the SAM2 model based on dimension mode."""

-             # Find the model weights file
-             checkpoint_path = self._find_sam_checkpoint()
-             if checkpoint_path is None:
-                 self.mobile_sam = None
-                 self.mask_generator = None
-                 return
+         def download_checkpoint(url, dest_folder):
+             import os

-             # Initialize the model
-             self.mobile_sam = sam_model_registry[model_type](
-                 checkpoint=checkpoint_path
-             )
-             self.mobile_sam.to(device=self.device)
-             self.mobile_sam.eval()
+             os.makedirs(dest_folder, exist_ok=True)
+             filename = os.path.join(dest_folder, url.split("/")[-1])
+             if not os.path.exists(filename):
+                 print(f"Downloading checkpoint to {filename}...")
+                 response = requests.get(url, stream=True)
+                 response.raise_for_status()
+                 with open(filename, "wb") as f:
+                     for chunk in response.iter_content(chunk_size=8192):
+                         f.write(chunk)
+                 print("Download complete.")
+             else:
+                 print(f"Checkpoint already exists at {filename}.")
+             return filename

-             # Create mask generator with default parameters
-             self.mask_generator = SamAutomaticMaskGenerator(self.mobile_sam)
-             self.viewer.status = f"Initialized SAM model from {checkpoint_path} on {self.device}"
+         try:
+             # import torch

-         except (ImportError, Exception) as e:
-             self.viewer.status = f"Error initializing SAM: {str(e)}"
-             self.mobile_sam = None
-             self.mask_generator = None
+             self.device = get_device()

-     def _find_sam_checkpoint(self):
-         """Find the SAM model checkpoint file."""
-         try:
-             import importlib.util
-
-             # Find the mobile_sam package location
-             mobile_sam_spec = importlib.util.find_spec("mobile_sam")
-             if mobile_sam_spec is None:
-                 raise ImportError("mobile_sam package not found")
-
-             mobile_sam_path = os.path.dirname(mobile_sam_spec.origin)
-
-             # Check common locations for the model file
-             checkpoint_paths = [
-                 os.path.join(mobile_sam_path, "weights", "mobile_sam.pt"),
-                 os.path.join(mobile_sam_path, "mobile_sam.pt"),
-                 os.path.join(
-                     os.path.dirname(mobile_sam_path),
-                     "weights",
-                     "mobile_sam.pt",
-                 ),
-                 os.path.join(
-                     os.path.expanduser("~"), "models", "mobile_sam.pt"
-                 ),
-                 "/opt/T-MIDAS/models/mobile_sam.pt",
-                 os.path.join(os.getcwd(), "mobile_sam.pt"),
-             ]
-
-             for path in checkpoint_paths:
-                 if os.path.exists(path):
-                     return path
-
-             # If model not found, ask user
-             QMessageBox.information(
-                 None,
-                 "Model Not Found",
-                 "Mobile-SAM model weights not found. Please select the mobile_sam.pt file.",
+             # Download checkpoint if needed
+             checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt"
+             checkpoint_path = download_checkpoint(
+                 checkpoint_url, "/opt/sam2/checkpoints/"
              )
+             model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"

-             checkpoint_path, _ = QFileDialog.getOpenFileName(
-                 None, "Select Mobile-SAM model file", "", "Model Files (*.pt)"
-             )
+             if self.use_3d:
+                 from sam2.build_sam import build_sam2_video_predictor

-             return checkpoint_path if checkpoint_path else None
+                 self.predictor = build_sam2_video_predictor(
+                     model_cfg, checkpoint_path, device=self.device
+                 )
+                 self.viewer.status = (
+                     f"Initialized SAM2 Video Predictor on {self.device}"
+                 )
+             else:
+                 from sam2.build_sam import build_sam2
+                 from sam2.sam2_image_predictor import SAM2ImagePredictor
+
+                 self.predictor = SAM2ImagePredictor(
+                     build_sam2(model_cfg, checkpoint_path)
+                 )
+                 self.viewer.status = (
+                     f"Initialized SAM2 Image Predictor on {self.device}"
+                 )

-         except (ImportError, Exception) as e:
-             self.viewer.status = f"Error finding SAM checkpoint: {str(e)}"
-             return None
+         except (
+             ImportError,
+             RuntimeError,
+             ValueError,
+             FileNotFoundError,
+             requests.RequestException,
+         ) as e:
+             import traceback

-     # --------------------------------------------------
-     # Image Loading and Navigation
-     # --------------------------------------------------
+             self.viewer.status = f"Error initializing SAM2: {str(e)}"
+             self.predictor = None
+             print(traceback.format_exc())

      def load_images(self, folder_path: str):
          """Load images from the specified folder path."""
@@ -169,17 +169,19 @@ class BatchCropAnything:
          self.images = [
              os.path.join(folder_path, file)
              for file in files
-             if file.lower().endswith(
-                 (".tif", ".tiff", ".png", ".jpg", ".jpeg")
-             )
-             and not file.endswith(("_labels.tif", "_cropped.tif", "_cropped_"))
+             # `and` binds tighter than `or`: group the extension checks so
+             # the name filters below also apply to ".tif" files.
+             if (
+                 file.lower().endswith(".tif")
+                 or file.lower().endswith(".tiff")
+             )
+             and "label" not in file.lower()
+             and "cropped" not in file.lower()
+             and "_labels_" not in file.lower()
+             and "_cropped_" not in file.lower()
          ]

          if not self.images:
              self.viewer.status = "No compatible images found in the folder."
              return

-         self.viewer.status = f"Found {len(self.images)} images."
+         self.viewer.status = f"Found {len(self.images)} .tif images."
          self.current_index = 0
          self._load_current_image()

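Editor's note: in 0.2.1 `load_images` accepts only TIFF input and skips derived outputs by substring. With the extension checks grouped as above, the name filters apply to both extensions; a quick example of what survives the filter:

    files = ["cells.tif", "cells_labels.tif", "cells_cropped_1.tif", "photo.png"]
    kept = [
        f
        for f in files
        if (f.lower().endswith(".tif") or f.lower().endswith(".tiff"))
        and "label" not in f.lower()
        and "cropped" not in f.lower()
    ]
    # kept == ["cells.tif"]; PNG/JPEG inputs are no longer accepted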
@@ -237,9 +239,9 @@ class BatchCropAnything:
              self.viewer.status = "No images to process."
              return

-         if self.mobile_sam is None or self.mask_generator is None:
+         if self.predictor is None:
              self.viewer.status = (
-                 "SAM model not initialized. Cannot segment images."
+                 "SAM2 model not initialized. Cannot segment images."
              )
              return

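Editor's note: `self.predictor` now holds either a SAM2 image predictor (2D) or a video predictor (3D), created in `_initialize_sam2` earlier in this diff. A standalone sketch of that setup — assuming the `sam2` package is importable and using a hypothetical module-level copy of the nested `download_checkpoint` helper:

    from sam2.build_sam import build_sam2, build_sam2_video_predictor
    from sam2.sam2_image_predictor import SAM2ImagePredictor

    CKPT_URL = "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt"
    CFG = "configs/sam2.1/sam2.1_hiera_l.yaml"

    ckpt = download_checkpoint(CKPT_URL, "/opt/sam2/checkpoints/")

    image_predictor = SAM2ImagePredictor(build_sam2(CFG, ckpt))  # 2D path
    video_predictor = build_sam2_video_predictor(CFG, ckpt, device=get_device())  # 3D path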
@@ -253,66 +255,147 @@ class BatchCropAnything:
              # Load and process image
              self.original_image = imread(image_path)

-             # Ensure image is 8-bit for SAM display (keeping original for saving)
-             if self.original_image.dtype != np.uint8:
-                 image_for_display = (
-                     self.original_image / np.amax(self.original_image) * 255
-                 ).astype(np.uint8)
+             # For 3D/4D data, determine dimensions
+             if self.use_3d and len(self.original_image.shape) >= 3:
+                 # Check shape to identify dimensions
+                 if len(self.original_image.shape) == 4:  # TZYX or similar
+                     # Identify time dimension as first dim with size > 4 and < 400
+                     # This is a heuristic to differentiate time from channels/small Z stacks
+                     time_dim_idx = -1
+                     for i, dim_size in enumerate(self.original_image.shape):
+                         if 4 < dim_size < 400:
+                             time_dim_idx = i
+                             break
+
+                     if time_dim_idx == 0:  # TZYX format
+                         # Keep as is, T is already the first dimension
+                         self.image_layer = self.viewer.add_image(
+                             self.original_image,
+                             name=f"Image ({os.path.basename(image_path)})",
+                         )
+                         # Store time dimension info
+                         self.time_dim_size = self.original_image.shape[0]
+                         self.has_z_dim = True
+                     elif (
+                         time_dim_idx > 0
+                     ):  # Unusual format, we need to transpose
+                         # Transpose to move T to first dimension
+                         # Create permutation order that puts time_dim_idx first
+                         perm_order = list(
+                             range(len(self.original_image.shape))
+                         )
+                         perm_order.remove(time_dim_idx)
+                         perm_order.insert(0, time_dim_idx)
+
+                         transposed_image = np.transpose(
+                             self.original_image, perm_order
+                         )
+                         self.original_image = (
+                             transposed_image  # Replace with transposed version
+                         )
+
+                         self.image_layer = self.viewer.add_image(
+                             self.original_image,
+                             name=f"Image ({os.path.basename(image_path)})",
+                         )
+                         # Store time dimension info
+                         self.time_dim_size = self.original_image.shape[0]
+                         self.has_z_dim = True
+                     else:
+                         # No time dimension found, treat as ZYX
+                         self.image_layer = self.viewer.add_image(
+                             self.original_image,
+                             name=f"Image ({os.path.basename(image_path)})",
+                         )
+                         self.time_dim_size = 1
+                         self.has_z_dim = True
+                 elif (
+                     len(self.original_image.shape) == 3
+                 ):  # Could be TYX or ZYX
+                     # Check if first dimension is likely time (> 4, < 400)
+                     if 4 < self.original_image.shape[0] < 400:
+                         # Likely TYX format
+                         self.image_layer = self.viewer.add_image(
+                             self.original_image,
+                             name=f"Image ({os.path.basename(image_path)})",
+                         )
+                         self.time_dim_size = self.original_image.shape[0]
+                         self.has_z_dim = False
+                     else:
+                         # Likely ZYX format or another 3D format
+                         self.image_layer = self.viewer.add_image(
+                             self.original_image,
+                             name=f"Image ({os.path.basename(image_path)})",
+                         )
+                         self.time_dim_size = 1
+                         self.has_z_dim = True
+                 else:
+                     # Should not reach here with use_3d=True, but just in case
+                     self.image_layer = self.viewer.add_image(
+                         self.original_image,
+                         name=f"Image ({os.path.basename(image_path)})",
+                     )
+                     self.time_dim_size = 1
+                     self.has_z_dim = False
              else:
-                 image_for_display = self.original_image
-
-             # Add image to viewer
-             self.image_layer = self.viewer.add_image(
-                 image_for_display,
-                 name=f"Image ({os.path.basename(image_path)})",
-             )
+                 # Handle 2D data as before
+                 if self.original_image.dtype != np.uint8:
+                     image_for_display = (
+                         self.original_image
+                         / np.amax(self.original_image)
+                         * 255
+                     ).astype(np.uint8)
+                 else:
+                     image_for_display = self.original_image
+
+                 # Add image to viewer
+                 self.image_layer = self.viewer.add_image(
+                     image_for_display,
+                     name=f"Image ({os.path.basename(image_path)})",
+                 )

              # Generate segmentation
-             self._generate_segmentation(image_for_display)
+             self._generate_segmentation(self.original_image, image_path)

-         except (Exception, ValueError) as e:
+         except (FileNotFoundError, ValueError, TypeError, OSError) as e:
              import traceback

              self.viewer.status = f"Error processing image: {str(e)}"
              traceback.print_exc()
+
              # Create empty segmentation in case of error
              if (
                  hasattr(self, "original_image")
                  and self.original_image is not None
              ):
-                 self.segmentation_result = np.zeros(
-                     self.original_image.shape[:2], dtype=np.uint32
-                 )
+                 if self.use_3d:
+                     shape = self.original_image.shape
+                 else:
+                     shape = self.original_image.shape[:2]
+
+                 self.segmentation_result = np.zeros(shape, dtype=np.uint32)
                  self.label_layer = self.viewer.add_labels(
                      self.segmentation_result, name="Error: No Segmentation"
                  )

-     # --------------------------------------------------
-     # Segmentation Generation and Control
-     # --------------------------------------------------
-
-     def _generate_segmentation(self, image):
-         """Generate segmentation for the current image."""
-         # Prepare for SAM (add color channel if needed)
-         if len(image.shape) == 2:
-             image_for_sam = image[:, :, np.newaxis].repeat(3, axis=2)
-         else:
-             image_for_sam = image
-
-         # Store the current image for later regeneration if sensitivity changes
-         self.current_image_for_segmentation = image_for_sam
+     def _generate_segmentation(self, image, image_path: str):
+         """Generate segmentation for the current image using SAM2."""
+         # Store the current image for later processing
+         self.current_image_for_segmentation = image

          # Generate segmentation with current sensitivity
-         self.generate_segmentation_with_sensitivity()
+         self.generate_segmentation_with_sensitivity(image_path)

-     def generate_segmentation_with_sensitivity(self, sensitivity=None):
+     def generate_segmentation_with_sensitivity(
+         self, image_path: str, sensitivity=None
+     ):
          """Generate segmentation with the specified sensitivity."""
          if sensitivity is not None:
              self.sensitivity = sensitivity

-         if self.mobile_sam is None or self.mask_generator is None:
+         if self.predictor is None:
              self.viewer.status = (
-                 "SAM model not initialized. Cannot segment images."
+                 "SAM2 model not initialized. Cannot segment images."
              )
              return

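Editor's note: the dimension heuristic above treats the first axis with size strictly between 4 and 400 as time. A worked example showing both the intended behavior and its main caveat:

    shape = (10, 25, 512, 512)  # TZYX: 10 time points, 25 z-planes
    time_dim_idx = -1
    for i, dim_size in enumerate(shape):
        if 4 < dim_size < 400:
            time_dim_idx = i
            break
    # time_dim_idx == 0, so the stack is kept as TZYX.
    # Caveat: 25 (Z) also satisfies 4 < size < 400, so only the first match
    # wins; a (3, 120, 512, 512) CZYX stack would misidentify Z=120 as time.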
@@ -321,298 +404,723 @@ class BatchCropAnything:
              return

          try:
-             # Map sensitivity (0-100) to SAM parameters
-             # Higher sensitivity (100) = lower thresholds = more objects detected
-             # Lower sensitivity (0) = higher thresholds = fewer objects detected
+             # Map sensitivity (0-100) to SAM2 parameters
+             # For SAM2, adjust confidence threshold based on sensitivity
+             confidence_threshold = (
+                 0.9 - (self.sensitivity / 100) * 0.4
+             )  # Range from 0.9 to 0.5
+
+             # Process based on dimension mode
+             if self.use_3d:
+                 # Process 3D data
+                 self._generate_3d_segmentation(
+                     confidence_threshold, image_path
+                 )
+             else:
+                 # Process 2D data
+                 self._generate_2d_segmentation(confidence_threshold)
+
+         except (
+             ValueError,
+             RuntimeError,
+             torch.cuda.OutOfMemoryError,
+             TypeError,
+         ) as e:
+             import traceback

-             # pred_iou_thresh range: 0.92 (low sensitivity) to 0.75 (high sensitivity)
-             pred_iou = 0.92 - (self.sensitivity / 100) * 0.17
+             self.viewer.status = f"Error generating segmentation: {str(e)}"
+             traceback.print_exc()

-             # stability_score_thresh range: 0.97 (low sensitivity) to 0.85 (high sensitivity)
-             stability = 0.97 - (self.sensitivity / 100) * 0.12
+     def _generate_2d_segmentation(self, confidence_threshold):
+         """Generate 2D segmentation using SAM2 Image Predictor."""
+         # Ensure image is in the correct format for SAM2
+         image = self.current_image_for_segmentation
+
+         # Handle resizing for very large images
+         orig_shape = image.shape[:2]
+         image_mp = (orig_shape[0] * orig_shape[1]) / 1e6
+         max_mp = 2.0  # Maximum image size in megapixels
+
+         if image_mp > max_mp:
+             scale_factor = np.sqrt(max_mp / image_mp)
+             new_height = int(orig_shape[0] * scale_factor)
+             new_width = int(orig_shape[1] * scale_factor)
+
+             self.viewer.status = f"Downscaling image from {orig_shape} to {(new_height, new_width)} for processing"
+
+             # Resize image
+             resized_image = resize(
+                 image,
+                 (new_height, new_width),
+                 anti_aliasing=True,
+                 preserve_range=True,
+             ).astype(
+                 np.float32
+             )  # Convert to float32
+
+             self.current_scale_factor = scale_factor
+         else:
+             # Convert to float32 format
+             if image.dtype != np.float32:
+                 resized_image = image.astype(np.float32)
+             else:
+                 resized_image = image
+             self.current_scale_factor = 1.0
+
+         # Ensure image is in RGB format for SAM2
+         if len(resized_image.shape) == 2:
+             # Convert grayscale to RGB
+             resized_image = np.stack([resized_image] * 3, axis=-1)
+         elif len(resized_image.shape) == 3 and resized_image.shape[2] == 1:
+             # Convert single channel to RGB
+             resized_image = np.concatenate([resized_image] * 3, axis=2)
+         elif len(resized_image.shape) == 3 and resized_image.shape[2] > 3:
+             # Use first 3 channels
+             resized_image = resized_image[:, :, :3]
+
+         # Normalize the image to [0,1] range if it's not already
+         if resized_image.max() > 1.0:
+             resized_image = resized_image / 255.0
+
+         # Set SAM2 prediction parameters based on sensitivity
+         with torch.inference_mode(), torch.autocast(
+             "cuda", dtype=torch.float32
+         ):
+             # Set the image in the predictor
+             self.predictor.set_image(resized_image)
+
+             # Use automatic points generation with confidence threshold
+             masks, scores, _ = self.predictor.predict(
+                 point_coords=None,
+                 point_labels=None,
+                 box=None,
+                 multimask_output=True,
+             )

-             # min_mask_region_area range: 300 (low sensitivity) to 30 (high sensitivity)
-             min_area = 300 - (self.sensitivity / 100) * 270
+         # Filter masks by confidence threshold
+         valid_masks = scores > confidence_threshold
+         masks = masks[valid_masks]
+         scores = scores[valid_masks]

-             # Configure mask generator with sensitivity-adjusted parameters
-             self.mask_generator.pred_iou_thresh = pred_iou
-             self.mask_generator.stability_score_thresh = stability
-             self.mask_generator.min_mask_region_area = min_area
+         # Convert masks to label image
+         labels = np.zeros(resized_image.shape[:2], dtype=np.uint32)
+         self.label_info = {}  # Reset label info

-             # Apply gamma correction based on sensitivity
-             # Low sensitivity: gamma > 1 (brighten image)
-             # High sensitivity: gamma < 1 (darken image)
-             gamma = (
-                 1.5 - (self.sensitivity / 100) * 1.0
-             )  # Range from 1.5 to 0.5
+         for i, mask in enumerate(masks):
+             label_id = i + 1  # Start label IDs from 1
+             labels[mask] = label_id

-             # Apply gamma correction to the input image
-             image_for_processing = self.current_image_for_segmentation.copy()
+             # Calculate label information
+             area = np.sum(mask)
+             y_indices, x_indices = np.where(mask)
+             center_y = np.mean(y_indices) if len(y_indices) > 0 else 0
+             center_x = np.mean(x_indices) if len(x_indices) > 0 else 0

-             # Convert to float for proper gamma correction
-             image_float = image_for_processing.astype(np.float32) / 255.0
+             # Store label info
+             self.label_info[label_id] = {
+                 "area": area,
+                 "center_y": center_y,
+                 "center_x": center_x,
+                 "score": float(scores[i]),
+             }

-             # Apply gamma correction
-             image_gamma = np.power(image_float, gamma)
+         # Handle upscaling if needed
+         if self.current_scale_factor < 1.0:
+             labels = resize(
+                 labels,
+                 orig_shape,
+                 order=0,  # Nearest neighbor interpolation
+                 preserve_range=True,
+                 anti_aliasing=False,
+             ).astype(np.uint32)

-             # Convert back to uint8
-             image_gamma = (image_gamma * 255).astype(np.uint8)
+         # Sort labels by area (largest first)
+         self.label_info = dict(
+             sorted(
+                 self.label_info.items(),
+                 key=lambda item: item[1]["area"],
+                 reverse=True,
+             )
+         )

-             # Check if the image is very large and needs downscaling
-             orig_shape = image_gamma.shape[:2]  # (height, width)
+         # Save segmentation result
+         self.segmentation_result = labels
+
+         # Update the label layer
+         self._update_label_layer()

-             # Calculate image size in megapixels
-             image_mp = (orig_shape[0] * orig_shape[1]) / 1e6
+     def _generate_3d_segmentation(self, confidence_threshold, image_path):
+         """
+         Initialize 3D segmentation using SAM2 Video Predictor.
+         This correctly sets up interactive segmentation following SAM2's video approach.
+         """
+         try:
+             # Handle image_path - make sure it's a string
+             if not isinstance(image_path, str):
+                 image_path = self.images[self.current_index]

-             # If image is larger than 2 megapixels, downscale it
-             max_mp = 2.0  # Maximum image size in megapixels
-             scale_factor = 1.0
+             # Initialize empty segmentation
+             volume_shape = self.current_image_for_segmentation.shape
+             labels = np.zeros(volume_shape, dtype=np.uint32)
+             self.segmentation_result = labels

-             if image_mp > max_mp:
-                 scale_factor = np.sqrt(max_mp / image_mp)
-                 new_height = int(orig_shape[0] * scale_factor)
-                 new_width = int(orig_shape[1] * scale_factor)
+             # Create a temp directory for the MP4 conversion if needed
+             import os
+             import tempfile

-                 self.viewer.status = f"Downscaling image from {orig_shape} to {(new_height, new_width)} for processing (scale: {scale_factor:.2f})"
+             temp_dir = tempfile.gettempdir()
+             mp4_path = os.path.join(
+                 temp_dir, f"temp_volume_{os.path.basename(image_path)}.mp4"
+             )

-                 # Resize the image for processing
-                 image_gamma_resized = resize(
-                     image_gamma,
-                     (new_height, new_width),
-                     anti_aliasing=True,
-                     preserve_range=True,
-                 ).astype(np.uint8)
+             # If we need to save a modified version for MP4 conversion
+             need_temp_tif = False
+             temp_tif_path = None

-                 # Store scale factor for later use
-                 self.current_scale_factor = scale_factor
-             else:
-                 image_gamma_resized = image_gamma
-                 self.current_scale_factor = 1.0
+             # Check if we have a 4D volume with Z dimension
+             if (
+                 hasattr(self, "has_z_dim")
+                 and self.has_z_dim
+                 and len(self.current_image_for_segmentation.shape) == 4
+             ):
+                 # We need to convert the 4D TZYX to a 3D TYX for proper video conversion
+                 # by taking maximum intensity projection of Z for each time point
+                 self.viewer.status = (
+                     "Converting 4D TZYX volume to 3D TYX for SAM2..."
+                 )

-             self.viewer.status = f"Generating segmentation with sensitivity {self.sensitivity} (gamma={gamma:.2f})..."
+                 # Create maximum intensity projection along Z axis (axis 1 in TZYX)
+                 projected_volume = np.max(
+                     self.current_image_for_segmentation, axis=1
+                 )

-             # Generate masks with gamma-corrected and potentially resized image
-             masks = self.mask_generator.generate(image_gamma_resized)
-             self.viewer.status = f"Generated {len(masks)} masks"
+                 # Save this as a temporary TIF for MP4 conversion
+                 temp_tif_path = os.path.join(
+                     temp_dir, f"temp_projected_{os.path.basename(image_path)}"
+                 )
+                 imwrite(temp_tif_path, projected_volume)
+                 need_temp_tif = True

-             if not masks:
+                 # Convert the projected TIF to MP4
                  self.viewer.status = (
-                     "No segments detected. Try increasing the sensitivity."
+                     "Converting projected 3D volume to MP4 format for SAM2..."
                  )
-                 # Create empty label layer
-                 shape = self.current_image_for_segmentation.shape[:2]
-                 self.segmentation_result = np.zeros(shape, dtype=np.uint32)
+                 mp4_path = tif_to_mp4(temp_tif_path)
+             else:
+                 # Convert original volume to video format for SAM2
+                 self.viewer.status = (
+                     "Converting 3D volume to MP4 format for SAM2..."
+                 )
+                 mp4_path = tif_to_mp4(image_path)

-                 # Remove existing label layer if exists
-                 for layer in list(self.viewer.layers):
-                     if (
-                         isinstance(layer, Labels)
-                         and "Segmentation" in layer.name
-                     ):
-                         self.viewer.layers.remove(layer)
+             # Initialize SAM2 state with the video
+             self.viewer.status = "Initializing SAM2 Video Predictor..."
+             with torch.inference_mode(), torch.autocast(
+                 "cuda", dtype=torch.bfloat16
+             ):
+                 self._sam2_state = self.predictor.init_state(mp4_path)

-                 # Add new empty label layer
-                 self.label_layer = self.viewer.add_labels(
-                     self.segmentation_result,
-                     name=f"Segmentation ({os.path.basename(self.images[self.current_index])})",
-                     opacity=0.7,
+             # Store needed state for 3D processing
+             self._sam2_next_obj_id = 1
+             self._sam2_prompts = (
+                 {}
+             )  # Store prompts for each object (points, labels, box)
+
+             # Update the label layer with empty segmentation
+             self._update_label_layer()
+
+             # Replace the click handler for interactive 3D segmentation
+             if self.label_layer is not None and hasattr(
+                 self.label_layer, "mouse_drag_callbacks"
+             ):
+                 for callback in list(self.label_layer.mouse_drag_callbacks):
+                     self.label_layer.mouse_drag_callbacks.remove(callback)
+
+                 # Add 3D-specific click handler
+                 self.label_layer.mouse_drag_callbacks.append(
+                     self._on_3d_label_clicked
                  )

-                 # Make the label layer active
-                 self.viewer.layers.selection.active = self.label_layer
+             # Set the viewer to show the first frame
+             if hasattr(self.viewer, "dims") and self.viewer.dims.ndim > 2:
+                 self.viewer.dims.set_point(
+                     0, 0
+                 )  # Set the first dimension (typically time/z) to 0
+
+             # Clean up temporary file if we created one
+             if (
+                 need_temp_tif
+                 and temp_tif_path
+                 and os.path.exists(temp_tif_path)
+             ):
+                 with contextlib.suppress(Exception):
+                     os.remove(temp_tif_path)
+
+             # Show instructions
+             self.viewer.status = (
+                 "3D Mode active: Navigate to the first frame where object appears, then click. "
+                 "Use Shift+click for negative points (to remove areas). "
+                 "Segmentation will be propagated to all frames automatically."
+             )
+
+             return True
+
+         except (
+             FileNotFoundError,
+             RuntimeError,
+             torch.cuda.OutOfMemoryError,
+             ValueError,
+             OSError,
+         ) as e:
+             import traceback
+
+             self.viewer.status = f"Error in 3D segmentation setup: {str(e)}"
+             traceback.print_exc()
+             return False
+
+     def _on_3d_label_clicked(self, layer, event):
+         """Handle click on 3D label layer to add a prompt for segmentation."""
+         try:
+             if event.button != 1:
+                 return
+
+             coords = layer.world_to_data(event.position)
+             if len(coords) == 3:
+                 z, y, x = map(int, coords)
+             elif len(coords) == 2:
+                 z = int(self.viewer.dims.current_step[0])
+                 y, x = map(int, coords)
+             else:
+                 self.viewer.status = (
+                     f"Unexpected coordinate dimensions: {coords}"
+                 )
                  return

-             # Process segmentation masks
-             # If image was downscaled, we need to ensure masks are upscaled correctly
-             if self.current_scale_factor < 1.0:
-                 # Upscale the segmentation masks to match the original image dimensions
-                 self._process_segmentation_masks_with_scaling(
-                     masks, self.current_image_for_segmentation.shape[:2]
+             # Check if Shift key is pressed
+             is_negative = "Shift" in event.modifiers
+             point_label = -1 if is_negative else 1
+
+             # Initialize a unique object ID for this click
+             if not hasattr(self, "_sam2_next_obj_id"):
+                 self._sam2_next_obj_id = 1
+
+             # Get current object ID (or create new one)
+             label_id = self.segmentation_result[z, y, x]
+             if is_negative and label_id > 0:
+                 # Use existing object ID for negative points
+                 ann_obj_id = label_id
+             else:
+                 # Create new object for positive points on background
+                 ann_obj_id = self._sam2_next_obj_id
+                 if point_label > 0 and label_id == 0:
+                     self._sam2_next_obj_id += 1
+
+             # Find or create points layer for this object
+             points_layer = None
+             for layer in list(self.viewer.layers):
+                 if f"Points for Object {ann_obj_id}" in layer.name:
+                     points_layer = layer
+                     break
+
+             if points_layer is None:
+                 # Create new points layer for this object
+                 points_layer = self.viewer.add_points(
+                     np.array([[z, y, x]]),
+                     name=f"Points for Object {ann_obj_id}",
+                     size=10,
+                     face_color="green" if point_label > 0 else "red",
+                     border_color="white",
+                     border_width=1,
+                     opacity=0.8,
                  )
+                 # Initialize points for this object
+                 if not hasattr(self, "sam2_points_by_obj"):
+                     self.sam2_points_by_obj = {}
+                     self.sam2_labels_by_obj = {}
+
+                 self.sam2_points_by_obj[ann_obj_id] = [[x, y]]
+                 self.sam2_labels_by_obj[ann_obj_id] = [point_label]
              else:
-                 self._process_segmentation_masks(
-                     masks, self.current_image_for_segmentation.shape[:2]
+                 # Add to existing points layer
+                 current_points = points_layer.data
+                 new_points = np.vstack([current_points, [z, y, x]])
+                 points_layer.data = new_points
+
+                 # Add to existing point lists
+                 if not hasattr(self, "sam2_points_by_obj"):
+                     self.sam2_points_by_obj = {}
+                     self.sam2_labels_by_obj = {}
+
+                 if ann_obj_id not in self.sam2_points_by_obj:
+                     self.sam2_points_by_obj[ann_obj_id] = []
+                     self.sam2_labels_by_obj[ann_obj_id] = []
+
+                 self.sam2_points_by_obj[ann_obj_id].append([x, y])
+                 self.sam2_labels_by_obj[ann_obj_id].append(point_label)
+
+             # Perform SAM2 segmentation
+             if hasattr(self, "_sam2_state") and self._sam2_state is not None:
+                 points = np.array(
+                     self.sam2_points_by_obj[ann_obj_id], dtype=np.float32
+                 )
+                 labels = np.array(
+                     self.sam2_labels_by_obj[ann_obj_id], dtype=np.int32
                  )

-             # Clear selected labels since segmentation has changed
-             self.selected_labels = set()
+                 self.viewer.status = f"Processing object at frame {z}..."

-             # Update table if it exists
-             if self.label_table_widget:
-                 self._populate_label_table(self.label_table_widget)
+                 _, out_obj_ids, out_mask_logits = (
+                     self.predictor.add_new_points_or_box(
+                         inference_state=self._sam2_state,
+                         frame_idx=z,
+                         obj_id=ann_obj_id,
+                         points=points,
+                         labels=labels,
+                     )
+                 )

-         except (Exception, ValueError) as e:
+                 # Convert logits to mask and update segmentation
+                 mask = (out_mask_logits[0] > 0.0).cpu().numpy()
+
+                 # Fix mask dimensions if needed
+                 if mask.ndim > 2:
+                     mask = mask.squeeze()
+
+                 # Check mask dimensions and resize if needed
+                 if mask.shape != self.segmentation_result[z].shape:
+                     from skimage.transform import resize
+
+                     mask = resize(
+                         mask.astype(float),
+                         self.segmentation_result[z].shape,
+                         order=0,
+                         preserve_range=True,
+                         anti_aliasing=False,
+                     ).astype(bool)
+
+                 # Apply the mask to current frame
+                 # For negative points, only remove from the current object
+                 if point_label < 0:
+                     # Remove only from current object
+                     self.segmentation_result[z][
+                         (self.segmentation_result[z] == ann_obj_id) & mask
+                     ] = 0
+                 else:
+                     # Add to current object (only overwrite background)
+                     self.segmentation_result[z][
+                         mask & (self.segmentation_result[z] == 0)
+                     ] = ann_obj_id
+
+                 # Automatically propagate to other frames
+                 self._propagate_mask_for_current_object(ann_obj_id, z)
+
+                 # Update label layer
+                 self._update_label_layer()
+
+                 # Update label table if needed
+                 if (
+                     hasattr(self, "label_table_widget")
+                     and self.label_table_widget is not None
+                 ):
+                     self._populate_label_table(self.label_table_widget)
+
+                 self.viewer.status = (
+                     f"Updated 3D object {ann_obj_id} across all frames"
+                 )
+             else:
+                 self.viewer.status = "SAM2 3D state not initialized"
+
+         except (
+             IndexError,
+             KeyError,
+             ValueError,
+             RuntimeError,
+             torch.cuda.OutOfMemoryError,
+         ) as e:
              import traceback

-             self.viewer.status = f"Error generating segmentation: {str(e)}"
+             self.viewer.status = f"Error in 3D click handler: {str(e)}"
              traceback.print_exc()

-     def _process_segmentation_masks(self, masks, shape):
-         """Process segmentation masks and create label layer."""
-         # Create label image from masks
-         labels = np.zeros(shape, dtype=np.uint32)
-         self.label_info = {}  # Reset label info
+     def _propagate_mask_for_current_object(self, obj_id, current_frame_idx):
+         """
+         Propagate the mask for the current object from the given frame to all other frames.
+         Uses SAM2's video propagation with proper error handling.

-         for i, mask_data in enumerate(masks):
-             mask = mask_data["segmentation"]
-             label_id = i + 1  # Start label IDs from 1
-             labels[mask] = label_id
+         Parameters:
+             obj_id: The ID of the object to propagate
+             current_frame_idx: The frame index where the object was identified
+         """
+         try:
+             if not hasattr(self, "_sam2_state") or self._sam2_state is None:
+                 self.viewer.status = (
+                     "SAM2 3D state not initialized for propagation"
+                 )
+                 return

-             # Calculate label information
-             area = np.sum(mask)
-             y_indices, x_indices = np.where(mask)
-             center_y = np.mean(y_indices) if len(y_indices) > 0 else 0
-             center_x = np.mean(x_indices) if len(x_indices) > 0 else 0
+             total_frames = self.segmentation_result.shape[0]
+             self.viewer.status = f"Propagating object {obj_id} through all {total_frames} frames..."

-             # Store label info
-             self.label_info[label_id] = {
-                 "area": area,
-                 "center_y": center_y,
-                 "center_x": center_x,
-                 "score": mask_data.get("stability_score", 0),
-             }
+             # Create a progress layer for visualization
+             progress_layer = None
+             for layer in list(self.viewer.layers):
+                 if "Propagation Progress" in layer.name:
+                     progress_layer = layer
+                     break

-         # Sort labels by area (largest first)
-         self.label_info = dict(
-             sorted(
-                 self.label_info.items(),
-                 key=lambda item: item[1]["area"],
-                 reverse=True,
+             if progress_layer is None:
+                 progress_data = np.zeros_like(
+                     self.segmentation_result, dtype=float
+                 )
+                 progress_layer = self.viewer.add_image(
+                     progress_data,
+                     name="Propagation Progress",
+                     colormap="magma",
+                     opacity=0.3,
+                     visible=True,
+                 )
+
+             # Update current frame in the progress layer
+             progress_data = progress_layer.data
+             current_mask = (
+                 self.segmentation_result[current_frame_idx] == obj_id
              )
-         )
+             progress_data[current_frame_idx] = current_mask.astype(float) * 0.8
+             progress_layer.data = progress_data
+
+             # Try to perform SAM2 propagation with error handling
+             try:
+                 # Use torch.inference_mode() and torch.autocast to ensure consistent dtypes
+                 with torch.inference_mode(), torch.autocast(
+                     "cuda", dtype=torch.float32
+                 ):
+                     # Attempt to run SAM2 propagation - this will iterate through all frames
+                     for (
+                         frame_idx,
+                         object_ids,
+                         mask_logits,
+                     ) in self.predictor.propagate_in_video(self._sam2_state):
+                         if frame_idx >= total_frames:
+                             continue
+
+                         # Find our object ID in the results
+                         # obj_mask = None
+                         for i, prop_obj_id in enumerate(object_ids):
+                             if prop_obj_id == obj_id:
+                                 # Get the mask for our object
+                                 mask = (mask_logits[i] > 0.0).cpu().numpy()
+
+                                 # Fix dimensions if needed
+                                 if mask.ndim > 2:
+                                     mask = mask.squeeze()
+
+                                 # Resize if needed
+                                 if (
+                                     mask.shape
+                                     != self.segmentation_result[
+                                         frame_idx
+                                     ].shape
+                                 ):
+                                     from skimage.transform import resize
+
+                                     mask = resize(
+                                         mask.astype(float),
+                                         self.segmentation_result[
+                                             frame_idx
+                                         ].shape,
+                                         order=0,
+                                         preserve_range=True,
+                                         anti_aliasing=False,
+                                     ).astype(bool)
+
+                                 # Update segmentation - only replacing background pixels
+                                 self.segmentation_result[frame_idx][
+                                     mask
+                                     & (
+                                         self.segmentation_result[frame_idx]
+                                         == 0
+                                     )
+                                 ] = obj_id
+
+                                 # Update progress visualization
+                                 progress_data = progress_layer.data
+                                 progress_data[frame_idx] = (
+                                     mask.astype(float) * 0.8
+                                 )
+                                 progress_layer.data = progress_data
+
+                         # Update status occasionally
+                         if frame_idx % 10 == 0:
+                             self.viewer.status = f"Propagating: frame {frame_idx+1}/{total_frames}"
+
+             except RuntimeError as e:
+                 # If we get a dtype mismatch or other error, fall back to
+                 # copying the current frame's mask to the other frames
+                 self.viewer.status = f"SAM2 propagation failed with error: {str(e)}. Falling back to alternative method."
+
+                 # Use the current frame's mask for propagation
+                 for frame_idx in range(total_frames):
+                     if (
+                         frame_idx != current_frame_idx
+                     ):  # Skip current frame as it's already done
+                         # Only replace background pixels with the current frame's object
+                         self.segmentation_result[frame_idx][
+                             current_mask
+                             & (self.segmentation_result[frame_idx] == 0)
+                         ] = obj_id
+
+                         # Update progress layer
+                         progress_data = progress_layer.data
+                         progress_data[frame_idx] = (
+                             current_mask.astype(float) * 0.5
+                         )  # Different intensity to indicate fallback
+                         progress_layer.data = progress_data
+
+                         # Update status occasionally
+                         if frame_idx % 10 == 0:
+                             self.viewer.status = f"Fallback propagation: frame {frame_idx+1}/{total_frames}"
+
+             # Remove progress layer after 2 seconds
+             import threading
+
+             def remove_progress():
+                 import time
+
+                 time.sleep(2)
+                 for layer in list(self.viewer.layers):
+                     if "Propagation Progress" in layer.name:
+                         self.viewer.layers.remove(layer)

-         # Save segmentation result
-         self.segmentation_result = labels
+             threading.Thread(target=remove_progress).start()

-         # Remove existing label layer if exists
-         for layer in list(self.viewer.layers):
-             if isinstance(layer, Labels) and "Segmentation" in layer.name:
-                 self.viewer.layers.remove(layer)
+             self.viewer.status = f"Propagation of object {obj_id} complete"

-         # Add label layer to viewer
-         self.label_layer = self.viewer.add_labels(
-             labels,
-             name=f"Segmentation ({os.path.basename(self.images[self.current_index])})",
-             opacity=0.7,
-         )
+         except (
+             IndexError,
+             ValueError,
+             RuntimeError,
+             torch.cuda.OutOfMemoryError,
+             TypeError,
+         ) as e:
+             import traceback

-         # Make the label layer active by default
-         self.viewer.layers.selection.active = self.label_layer
+             self.viewer.status = f"Error in propagation: {str(e)}"
+             traceback.print_exc()

-         # Disconnect existing callbacks if any
-         if (
-             hasattr(self, "label_layer")
-             and self.label_layer is not None
-             and hasattr(self.label_layer, "mouse_drag_callbacks")
-         ):
-             # Remove old callbacks
-             for callback in list(self.label_layer.mouse_drag_callbacks):
-                 self.label_layer.mouse_drag_callbacks.remove(callback)
-
-         # Connect mouse click event to label selection
-         self.label_layer.mouse_drag_callbacks.append(self._on_label_clicked)
-
-         # image_name = os.path.basename(self.images[self.current_index])
-         self.viewer.status = f"Loaded image {self.current_index + 1}/{len(self.images)} - Found {len(masks)} segments"
-
-     # New method for handling scaled segmentation masks
-     def _process_segmentation_masks_with_scaling(self, masks, original_shape):
-         """Process segmentation masks with scaling to match the original image size."""
-         # Create label image from masks
-         # First determine the size of the mask predictions (which are at the downscaled resolution)
-         if not masks:
+     def _add_3d_prompt(self, prompt_coords):
+         """
+         Given a 3D coordinate (x, y, z), run SAM2 video predictor to segment the object at that point,
+         update the segmentation result and label layer.
+         """
+         if not hasattr(self, "_sam2_state") or self._sam2_state is None:
+             self.viewer.status = "SAM2 3D state not initialized."
              return

-         mask_shape = masks[0]["segmentation"].shape
-
-         # Create an empty label image at the downscaled resolution
-         downscaled_labels = np.zeros(mask_shape, dtype=np.uint32)
-         self.label_info = {}  # Reset label info
-
-         # Fill in the downscaled labels
-         for i, mask_data in enumerate(masks):
-             mask = mask_data["segmentation"]
-             label_id = i + 1  # Start label IDs from 1
-             downscaled_labels[mask] = label_id
-
-             # Store basic label info
-             area = np.sum(mask)
-             y_indices, x_indices = np.where(mask)
-             center_y = np.mean(y_indices) if len(y_indices) > 0 else 0
-             center_x = np.mean(x_indices) if len(x_indices) > 0 else 0
-
-             # Scale centers to original image coordinates
-             center_y_orig = center_y / self.current_scale_factor
-             center_x_orig = center_x / self.current_scale_factor
+         if self.predictor is None:
+             self.viewer.status = "SAM2 predictor not initialized."
+             return

-             # Store label info at original scale
-             self.label_info[label_id] = {
-                 "area": area
-                 / (
-                     self.current_scale_factor**2
-                 ),  # Approximate area in original scale
-                 "center_y": center_y_orig,
-                 "center_x": center_x_orig,
-                 "score": mask_data.get("stability_score", 0),
-             }
+         # Prepare prompt for SAM2: point_coords is [[x, y, t]], point_labels is [1]
+         x, y, z = prompt_coords
+         point_coords = np.array([[x, y, z]])
+         point_labels = np.array([1])  # 1 = foreground

-         # Upscale the labels to the original image size
-         upscaled_labels = resize(
-             downscaled_labels,
-             original_shape,
-             order=0,  # Nearest neighbor interpolation
-             preserve_range=True,
-             anti_aliasing=False,
-         ).astype(np.uint32)
+         with torch.inference_mode(), torch.autocast(
+             "cuda", dtype=torch.bfloat16
+         ):
+             masks, scores, _ = self.predictor.predict(
+                 state=self._sam2_state,
+                 point_coords=point_coords,
+                 point_labels=point_labels,
+                 multimask_output=True,
+             )

-         # Sort labels by area (largest first)
-         self.label_info = dict(
-             sorted(
-                 self.label_info.items(),
-                 key=lambda item: item[1]["area"],
-                 reverse=True,
+         # Pick the best mask (highest score)
+         if masks is not None and len(masks) > 0:
+             best_idx = np.argmax(scores)
+             mask = masks[best_idx]
+             obj_id = self._sam2_next_obj_id
+             self.segmentation_result[mask] = obj_id
+             self._sam2_next_obj_id += 1
+             self.viewer.status = (
+                 f"Added object {obj_id} at (x={x}, y={y}, z={z})"
              )
-         )
+             self._update_label_layer()
+         else:
+             self.viewer.status = "No mask found for this prompt."
+
+     def on_apply_propagate(self):
+         """Propagate masks across the video and update the segmentation layer."""
+         self.viewer.status = "Propagating masks across all frames..."
+         self.viewer.window._qt_window.setCursor(Qt.WaitCursor)
+
+         self.segmentation_result[:] = 0
+
+         for (
+             frame_idx,
+             object_ids,
+             mask_logits,
+         ) in self.predictor.propagate_in_video(self._sam2_state):
+             masks = (mask_logits > 0.0).cpu().numpy()
+             if frame_idx >= self.segmentation_result.shape[0]:
+                 print(
+                     f"Warning: frame_idx {frame_idx} out of bounds for segmentation_result with shape {self.segmentation_result.shape}"
+                 )
+                 continue
+             for i, obj_id in enumerate(object_ids):
+                 self.segmentation_result[frame_idx][masks[i]] = obj_id
+             self.viewer.status = f"Propagating: frame {frame_idx+1}"

-         # Save segmentation result
-         self.segmentation_result = upscaled_labels
+         self._update_label_layer()
+         self.viewer.status = "Propagation complete!"
+         self.viewer.window._qt_window.setCursor(Qt.ArrowCursor)

-         # Remove existing label layer if exists
+     def _update_label_layer(self):
+         """Update the label layer in the viewer."""
+         # Remove existing label layer if it exists
          for layer in list(self.viewer.layers):
              if isinstance(layer, Labels) and "Segmentation" in layer.name:
                  self.viewer.layers.remove(layer)

          # Add label layer to viewer
          self.label_layer = self.viewer.add_labels(
-             upscaled_labels,
+             self.segmentation_result,
              name=f"Segmentation ({os.path.basename(self.images[self.current_index])})",
              opacity=0.7,
          )

-         # Make the label layer active by default
-         self.viewer.layers.selection.active = self.label_layer
-
-         # Disconnect existing callbacks if any
-         if (
-             hasattr(self, "label_layer")
-             and self.label_layer is not None
-             and hasattr(self.label_layer, "mouse_drag_callbacks")
-         ):
-             # Remove old callbacks
-             for callback in list(self.label_layer.mouse_drag_callbacks):
-                 self.label_layer.mouse_drag_callbacks.remove(callback)
+         # Create points layer for interaction if it doesn't exist
+         points_layer = None
+         for layer in list(self.viewer.layers):
+             if "Points" in layer.name:
+                 points_layer = layer
+                 break
+
+         if points_layer is None:
+             # Initialize an empty points layer
+             points_layer = self.viewer.add_points(
+                 np.zeros((0, 2 if not self.use_3d else 3)),
+                 name="Points (Click to Add)",
+                 size=10,
+                 face_color="green",
+                 border_color="white",
+                 border_width=1,
+                 opacity=0.8,
+             )

-         # Connect mouse click event to label selection
-         self.label_layer.mouse_drag_callbacks.append(self._on_label_clicked)
+         # Connect points layer mouse click event
+         points_layer.mouse_drag_callbacks.append(self._on_points_clicked)

-         self.viewer.status = f"Loaded image {self.current_index + 1}/{len(self.images)} - Found {len(masks)} segments"
+         # Make the points layer active to encourage interaction with it
+         self.viewer.layers.selection.active = points_layer

-     # --------------------------------------------------
-     # Label Selection and UI Elements
-     # --------------------------------------------------
+         # Update status
+         n_labels = len(np.unique(self.segmentation_result)) - (
+             1 if 0 in np.unique(self.segmentation_result) else 0
+         )
+         self.viewer.status = f"Loaded image {self.current_index + 1}/{len(self.images)} - Found {n_labels} segments"

-     def _on_label_clicked(self, layer, event):
-         """Handle label selection on mouse click."""
+     def _on_points_clicked(self, layer, event):
+         """Handle clicks on the points layer for adding/removing points."""
          try:
              # Only process clicks, not drags
              if event.type != "mouse_press":
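
Editor's note: the sensitivity slider is mapped linearly onto a SAM2 score threshold in the hunk above, threshold = 0.9 - (sensitivity / 100) * 0.4. A quick check of the endpoints:

    for sensitivity in (0, 50, 100):
        threshold = 0.9 - (sensitivity / 100) * 0.4
        print(sensitivity, round(threshold, 2))
    # 0   0.9  -> strict, fewest masks kept
    # 50  0.7  -> default
    # 100 0.5  -> permissive, most masks kept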
@@ -621,39 +1129,799 @@ class BatchCropAnything:
              # Get coordinates of mouse click
              coords = np.round(event.position).astype(int)

-             # Make sure coordinates are within bounds
-             shape = self.segmentation_result.shape
-             if (
-                 coords[0] < 0
-                 or coords[1] < 0
-                 or coords[0] >= shape[0]
-                 or coords[1] >= shape[1]
-             ):
-                 return
+             # Check if Shift is pressed for negative points
+             is_negative = "Shift" in event.modifiers
+             point_label = -1 if is_negative else 1
+
+             # Handle 2D vs 3D coordinates
+             if self.use_3d:
+                 if len(coords) == 3:
+                     t, y, x = map(int, coords)
+                 elif len(coords) == 2:
+                     t = int(self.viewer.dims.current_step[0])
+                     y, x = map(int, coords)
+                 else:
+                     self.viewer.status = (
+                         f"Unexpected coordinate dimensions: {coords}"
+                     )
+                     return
+
+                 # Add point to the layer immediately for visual feedback
+                 new_point = np.array([[t, y, x]])
+                 if len(layer.data) == 0:
+                     layer.data = new_point
+                 else:
+                     layer.data = np.vstack([layer.data, new_point])
+
+                 # Update point colors
+                 colors = layer.face_color
+                 if isinstance(colors, list):
+                     colors.append("red" if is_negative else "green")
+                 else:
+                     n_points = len(layer.data)
+                     colors = ["green"] * (n_points - 1)
+                     colors.append("red" if is_negative else "green")
+                 layer.face_color = colors
+
+                 # Get the object ID
+                 # If clicking on existing segmentation with negative point
+                 label_id = self.segmentation_result[t, y, x]
+                 if is_negative and label_id > 0:
+                     obj_id = label_id
+                 else:
+                     # For new objects or negative on background
+                     if not hasattr(self, "_sam2_next_obj_id"):
+                         self._sam2_next_obj_id = 1
+                     obj_id = self._sam2_next_obj_id
+                     if point_label > 0 and label_id == 0:
+                         self._sam2_next_obj_id += 1
+
+                 # Store point information
+                 if not hasattr(self, "points_data"):
+                     self.points_data = {}
+                     self.points_labels = {}
+
+                 if obj_id not in self.points_data:
+                     self.points_data[obj_id] = []
+                     self.points_labels[obj_id] = []
+
+                 self.points_data[obj_id].append(
+                     [x, y]
+                 )  # Note: SAM2 expects [x,y] format
+                 self.points_labels[obj_id].append(point_label)
+
+                 # Perform segmentation
+                 if (
+                     hasattr(self, "_sam2_state")
+                     and self._sam2_state is not None
+                 ):
+                     # Prepare points
+                     points = np.array(
+                         self.points_data[obj_id], dtype=np.float32
+                     )
+                     labels = np.array(
+                         self.points_labels[obj_id], dtype=np.int32
+                     )
+
+                     # Create progress layer for visual feedback
+                     progress_layer = None
+                     for existing_layer in self.viewer.layers:
+                         if "Propagation Progress" in existing_layer.name:
+                             progress_layer = existing_layer
+                             break
+
+                     if progress_layer is None:
+                         progress_data = np.zeros_like(self.segmentation_result)
+                         progress_layer = self.viewer.add_image(
+                             progress_data,
+                             name="Propagation Progress",
+                             colormap="magma",
+                             opacity=0.5,
+                             visible=True,
+                         )
+
+                     # First update the current frame immediately
+                     self.viewer.status = f"Processing object at frame {t}..."
+
+                     # Run SAM2 on current frame
+                     _, out_obj_ids, out_mask_logits = (
+                         self.predictor.add_new_points_or_box(
+                             inference_state=self._sam2_state,
+                             frame_idx=t,
+                             obj_id=obj_id,
+                             points=points,
+                             labels=labels,
+                         )
+                     )
+
+                     # Update current frame
+                     mask = (out_mask_logits[0] > 0.0).cpu().numpy()
+                     if mask.ndim > 2:
+                         mask = mask.squeeze()
+
+                     # Resize if needed
+                     if mask.shape != self.segmentation_result[t].shape:
+                         from skimage.transform import resize
+
+                         mask = resize(
+                             mask.astype(float),
+                             self.segmentation_result[t].shape,
+                             order=0,
+                             preserve_range=True,
+                             anti_aliasing=False,
+                         ).astype(bool)
+
+                     # Update segmentation for this frame
+                     if point_label < 0:
+                         # For negative points, only remove from this object
+                         self.segmentation_result[t][
+                             (self.segmentation_result[t] == obj_id) & mask
+                         ] = 0
+                     else:
+                         # For positive points, only replace background
+                         self.segmentation_result[t][
+                             mask & (self.segmentation_result[t] == 0)
+                         ] = obj_id
+
+                     # Update progress layer for this frame
+                     progress_data = progress_layer.data
+                     progress_data[t] = (
+                         mask.astype(float) * 0.5
+                     )  # Highlight current frame
+                     progress_layer.data = progress_data
+
+                     # Now propagate to all frames with visual feedback
+                     self.viewer.status = "Propagating to all frames..."
+
+                     # Run propagation
+                     frame_count = self.segmentation_result.shape[0]
+                     for (
+                         frame_idx,
+                         prop_obj_ids,
+                         mask_logits,
+                     ) in self.predictor.propagate_in_video(self._sam2_state):
+                         if frame_idx >= frame_count:
+                             continue
+
+                         # Find our object
+                         obj_mask = None
+                         for i, prop_obj_id in enumerate(prop_obj_ids):
+                             if prop_obj_id == obj_id:
+                                 obj_mask = (mask_logits[i] > 0.0).cpu().numpy()
+                                 if obj_mask.ndim > 2:
+                                     obj_mask = obj_mask.squeeze()
+
+                                 # Resize if needed
+                                 if (
+                                     obj_mask.shape
+                                     != self.segmentation_result[
+                                         frame_idx
+                                     ].shape
+                                 ):
+                                     obj_mask = resize(
+                                         obj_mask.astype(float),
+                                         self.segmentation_result[
+                                             frame_idx
+                                         ].shape,
+                                         order=0,
+                                         preserve_range=True,
+                                         anti_aliasing=False,
+                                     ).astype(bool)
+
+                                 # Update segmentation
+                                 self.segmentation_result[frame_idx][
+                                     obj_mask
+                                     & (
+                                         self.segmentation_result[frame_idx]
+                                         == 0
+                                     )
+                                 ] = obj_id
+
+                                 # Update progress visualization
+                                 progress_data = progress_layer.data
+                                 progress_data[frame_idx] = (
+                                     obj_mask.astype(float) * 0.8
+                                 )  # Show as processed
+                                 progress_layer.data = progress_data
+
+                         # Update status
+                         if frame_idx % 5 == 0:
+                             self.viewer.status = f"Propagating: frame {frame_idx+1}/{frame_count}"
+                         # Remove the viewer.update() call as it's causing errors
+
+                     # Process any missing frames
+                     processed_frames = set(range(frame_count))
+                     for frame_idx in range(frame_count):
+                         if (
+                             progress_data[frame_idx].max() == 0
+                         ):  # Frame not processed yet
+                             # Use nearest processed frame's mask
+                             nearest_idx = min(
+                                 processed_frames,
+                                 key=lambda x: abs(x - frame_idx),
+                             )
+                             if progress_data[nearest_idx].max() > 0:
+                                 self.segmentation_result[frame_idx][
+                                     (self.segmentation_result[frame_idx] == 0)
+                                     & (
+                                         self.segmentation_result[nearest_idx]
+                                         == obj_id
+                                     )
+                                 ] = obj_id
+
+                                 # Update progress visualization
+                                 progress_data[frame_idx] = (
+                                     progress_data[nearest_idx] * 0.6
+                                 )  # Mark as copied
+
+                     # Final update of progress layer
+                     progress_layer.data = progress_data
+
+                     # Remove progress layer after 2 seconds
+                     import threading
+
+                     def remove_progress():
+                         import time
+
+                         time.sleep(2)
+                         for layer in list(self.viewer.layers):
+                             if "Propagation Progress" in layer.name:
+                                 self.viewer.layers.remove(layer)
+
+                     threading.Thread(target=remove_progress).start()
+
+                     # Update UI
+                     self._update_label_layer()
+                     if (
+                         hasattr(self, "label_table_widget")
+                         and self.label_table_widget is not None
+                     ):
+                         self._populate_label_table(self.label_table_widget)

-             # Get the label ID at the clicked position
-             label_id = self.segmentation_result[coords[0], coords[1]]
+                     self.viewer.status = f"Object {obj_id} segmented and propagated to all frames"

-             # Skip if background (0) is clicked
-             if label_id == 0:
+             else:
+                 # 2D case
+                 if len(coords) == 2:
+                     y, x = map(int, coords)
+                 else:
+                     self.viewer.status = (
+                         f"Unexpected coordinate dimensions: {coords}"
+                     )
+                     return
+
+                 # Add point to the layer immediately for visual feedback
+                 new_point = np.array([[y, x]])
+                 if len(layer.data) == 0:
+                     layer.data = new_point
+                 else:
+                     layer.data = np.vstack([layer.data, new_point])
+
+                 # Update point colors
+                 colors = layer.face_color
+                 if isinstance(colors, list):
+                     colors.append("red" if is_negative else "green")
+                 else:
+                     n_points = len(layer.data)
+                     colors = ["green"] * (n_points - 1)
+                     colors.append("red" if is_negative else "green")
+                 layer.face_color = colors
+
+                 # Get object ID
+                 label_id = self.segmentation_result[y, x]
+                 if is_negative and label_id > 0:
+                     obj_id = label_id
+                 else:
+                     if not hasattr(self, "next_obj_id"):
+                         self.next_obj_id = 1
+                     obj_id = self.next_obj_id
+                     if point_label > 0 and label_id == 0:
+                         self.next_obj_id += 1
+
+                 # Store point information
+                 if not hasattr(self, "obj_points"):
+                     self.obj_points = {}
+                     self.obj_labels = {}
+
+                 if obj_id not in self.obj_points:
+                     self.obj_points[obj_id] = []
+                     self.obj_labels[obj_id] = []
+
+                 self.obj_points[obj_id].append(
+                     [x, y]
+                 )  # SAM2 expects [x,y] format
+                 self.obj_labels[obj_id].append(point_label)
+
+                 # Perform segmentation
+                 if hasattr(self, "predictor") and self.predictor is not None:
+                     # Make sure image is loaded
+                     if self.current_image_for_segmentation is None:
+                         self.viewer.status = "No image loaded for segmentation"
+                         return
+
+                     # Prepare image for SAM2
+                     image = self.current_image_for_segmentation
+                     if len(image.shape) == 2:
+                         image = np.stack([image] * 3, axis=-1)
+                     elif len(image.shape) == 3 and image.shape[2] == 1:
+                         image = np.concatenate([image] * 3, axis=2)
+                     elif len(image.shape) == 3 and image.shape[2] > 3:
1449
+ image = image[:, :, :3]
1450
+
1451
+ if image.dtype != np.uint8:
1452
+ image = (image / np.max(image) * 255).astype(np.uint8)
1453
+
1454
+ # Set the image in the predictor
1455
+ self.predictor.set_image(image)
1456
+
1457
+ # Use only points for current object
1458
+ points = np.array(
1459
+ self.obj_points[obj_id], dtype=np.float32
1460
+ )
1461
+ labels = np.array(self.obj_labels[obj_id], dtype=np.int32)
1462
+
1463
+ self.viewer.status = f"Segmenting object {obj_id} with {len(points)} points..."
1464
+
1465
+ with torch.inference_mode(), torch.autocast("cuda"):
1466
+ masks, scores, _ = self.predictor.predict(
1467
+ point_coords=points,
1468
+ point_labels=labels,
1469
+ multimask_output=True,
1470
+ )
1471
+
1472
+ # Get best mask
1473
+ if len(masks) > 0:
1474
+ best_mask = masks[0]
1475
+
1476
+ # Update segmentation result
1477
+ if (
1478
+ best_mask.shape
1479
+ != self.segmentation_result.shape
1480
+ ):
1481
+ from skimage.transform import resize
1482
+
1483
+ best_mask = resize(
1484
+ best_mask.astype(float),
1485
+ self.segmentation_result.shape,
1486
+ order=0,
1487
+ preserve_range=True,
1488
+ anti_aliasing=False,
1489
+ ).astype(bool)
1490
+
1491
+ # Apply mask based on point type
1492
+ if point_label < 0:
1493
+ # Remove only from current object
1494
+ mask_condition = np.logical_and(
1495
+ self.segmentation_result == obj_id,
1496
+ best_mask,
1497
+ )
1498
+ self.segmentation_result[mask_condition] = 0
1499
+ else:
1500
+ # Add to current object (only overwrite background)
1501
+ mask_condition = np.logical_and(
1502
+ best_mask, (self.segmentation_result == 0)
1503
+ )
1504
+ self.segmentation_result[mask_condition] = (
1505
+ obj_id
1506
+ )
1507
+
1508
+ # Update label info
1509
+ area = np.sum(self.segmentation_result == obj_id)
1510
+ y_indices, x_indices = np.where(
1511
+ self.segmentation_result == obj_id
1512
+ )
1513
+ center_y = (
1514
+ np.mean(y_indices) if len(y_indices) > 0 else 0
1515
+ )
1516
+ center_x = (
1517
+ np.mean(x_indices) if len(x_indices) > 0 else 0
1518
+ )
1519
+
1520
+ self.label_info[obj_id] = {
1521
+ "area": area,
1522
+ "center_y": center_y,
1523
+ "center_x": center_x,
1524
+ "score": float(scores[0]),
1525
+ }
1526
+
1527
+ self.viewer.status = f"Updated object {obj_id}"
1528
+ else:
1529
+ self.viewer.status = "No valid mask produced"
1530
+
1531
+ # Update the UI
1532
+ self._update_label_layer()
1533
+ if (
1534
+ hasattr(self, "label_table_widget")
1535
+ and self.label_table_widget is not None
1536
+ ):
1537
+ self._populate_label_table(self.label_table_widget)
1538
+
1539
+ except (
1540
+ IndexError,
1541
+ KeyError,
1542
+ ValueError,
1543
+ RuntimeError,
1544
+ TypeError,
1545
+ ) as e:
1546
+ import traceback
1547
+
1548
+ self.viewer.status = f"Error in points handling: {str(e)}"
1549
+ traceback.print_exc()
1550
+
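The 3D path above drives SAM2's video predictor: point prompts are registered on a single frame with `add_new_points_or_box`, and `propagate_in_video` then streams mask logits for every frame. A minimal standalone sketch of that flow (the config/checkpoint names and the `frames_dir` folder are placeholders you would supply yourself):

```python
import numpy as np
import torch
from sam2.build_sam import build_sam2_video_predictor

# Placeholder config/checkpoint and a folder of extracted video frames
predictor = build_sam2_video_predictor("sam2_hiera_l.yaml", "sam2_hiera_large.pt")
state = predictor.init_state(video_path="frames_dir")

with torch.inference_mode():
    # One positive click ([x, y]) on frame 0 for object ID 1
    predictor.add_new_points_or_box(
        inference_state=state,
        frame_idx=0,
        obj_id=1,
        points=np.array([[120, 80]], dtype=np.float32),
        labels=np.array([1], dtype=np.int32),
    )
    # Stream masks for all frames
    for frame_idx, obj_ids, mask_logits in predictor.propagate_in_video(state):
        masks = (mask_logits > 0.0).cpu().numpy()  # one mask per tracked object
```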
+    def _on_label_clicked(self, layer, event):
+        """Handle label selection and user prompts on mouse click."""
+        try:
+            # Only process clicks, not drags
+            if event.type != "mouse_press":
                 return
 
-        # Toggle the label selection
-        if label_id in self.selected_labels:
-            self.selected_labels.remove(label_id)
-            self.viewer.status = f"Deselected label ID: {label_id} | Selected labels: {self.selected_labels}"
+            # Get the coordinates of the mouse click
+            coords = np.round(event.position).astype(int)
+
+            # Check if Shift is pressed (negative point)
+            is_negative = "Shift" in event.modifiers
+            point_label = -1 if is_negative else 1
+
+            # For 2D data
+            if not self.use_3d:
+                if len(coords) == 2:
+                    y, x = map(int, coords)
+                else:
+                    self.viewer.status = f"Unexpected coordinate dimensions: {coords}"
+                    return
+
+                # Check if within image bounds
+                shape = self.segmentation_result.shape
+                if y < 0 or x < 0 or y >= shape[0] or x >= shape[1]:
+                    self.viewer.status = "Click is outside image bounds"
+                    return
+
+                # Get the label ID at the clicked position
+                label_id = self.segmentation_result[y, x]
+
+                # Initialize a unique object ID for this click (if needed)
+                if not hasattr(self, "next_obj_id"):
+                    # Start with the highest existing ID + 1
+                    if self.segmentation_result.max() > 0:
+                        self.next_obj_id = int(self.segmentation_result.max()) + 1
+                    else:
+                        self.next_obj_id = 1
+
+                # A background click or a negative click triggers segmentation
+                if label_id == 0 or is_negative:
+                    # Determine which object this prompt belongs to
+                    current_obj_id = None
+
+                    # Negative point on an existing label: use that label's ID
+                    if is_negative and label_id > 0:
+                        current_obj_id = label_id
+                    # Positive click on background: create a new object
+                    elif point_label > 0 and label_id == 0:
+                        current_obj_id = self.next_obj_id
+                        self.next_obj_id += 1
+                    # Negative click on background: try the most recent object
+                    elif point_label < 0 and label_id == 0:
+                        # Use the most recently created object if available
+                        if hasattr(self, "obj_points") and self.obj_points:
+                            current_obj_id = max(self.obj_points.keys())
+                        else:
+                            self.viewer.status = "No existing object to modify with negative point"
+                            return
+
+                    if current_obj_id is None:
+                        self.viewer.status = "Could not determine which object to modify"
+                        return
+
+                    # Find or create a points layer for this object
+                    points_layer = None
+                    for layer in list(self.viewer.layers):
+                        if f"Points for Object {current_obj_id}" in layer.name:
+                            points_layer = layer
+                            break
+
+                    # Initialize object tracking if needed
+                    if not hasattr(self, "obj_points"):
+                        self.obj_points = {}
+                        self.obj_labels = {}
+
+                    if current_obj_id not in self.obj_points:
+                        self.obj_points[current_obj_id] = []
+                        self.obj_labels[current_obj_id] = []
+
+                    # Create or update the points layer for this object
+                    if points_layer is None:
+                        # First point for this object
+                        points_layer = self.viewer.add_points(
+                            np.array([[y, x]]),
+                            name=f"Points for Object {current_obj_id}",
+                            size=10,
+                            face_color=["green" if point_label > 0 else "red"],
+                            border_color="white",
+                            border_width=1,
+                            opacity=0.8,
+                        )
+                        self.obj_points[current_obj_id] = [[x, y]]
+                        self.obj_labels[current_obj_id] = [point_label]
+                    else:
+                        # Add the point to the existing layer
+                        current_points = points_layer.data
+                        current_colors = points_layer.face_color
+
+                        # Add the new point
+                        new_points = np.vstack([current_points, [y, x]])
+                        new_color = "green" if point_label > 0 else "red"
+
+                        # Update the points layer
+                        points_layer.data = new_points
+
+                        # Update the colors
+                        if isinstance(current_colors, list):
+                            current_colors.append(new_color)
+                            points_layer.face_color = current_colors
+                        else:
+                            # If it's an array, create a list of colors
+                            colors = []
+                            for i in range(len(new_points)):
+                                if i < len(current_points):
+                                    colors.append("green" if point_label > 0 else "red")
+                                else:
+                                    colors.append(new_color)
+                            points_layer.face_color = colors
+
+                        # Update object tracking
+                        self.obj_points[current_obj_id].append([x, y])
+                        self.obj_labels[current_obj_id].append(point_label)
+
+                    # Now do the actual segmentation using SAM2
+                    if hasattr(self, "predictor") and self.predictor is not None:
+                        try:
+                            # Make sure an image is loaded
+                            if self.current_image_for_segmentation is None:
+                                self.viewer.status = "No image loaded for segmentation"
+                                return
+
+                            # Prepare the image for SAM2 (3-channel uint8)
+                            image = self.current_image_for_segmentation
+                            if len(image.shape) == 2:
+                                image = np.stack([image] * 3, axis=-1)
+                            elif len(image.shape) == 3 and image.shape[2] == 1:
+                                image = np.concatenate([image] * 3, axis=2)
+                            elif len(image.shape) == 3 and image.shape[2] > 3:
+                                image = image[:, :, :3]
+
+                            if image.dtype != np.uint8:
+                                image = (image / np.max(image) * 255).astype(np.uint8)
+
+                            # Set the image in the predictor
+                            self.predictor.set_image(image)
+
+                            # Only use the points for the object being segmented
+                            points = np.array(self.obj_points[current_obj_id], dtype=np.float32)
+                            labels = np.array(self.obj_labels[current_obj_id], dtype=np.int32)
+
+                            self.viewer.status = f"Segmenting object {current_obj_id} with {len(points)} points..."
+
+                            # Note: autocast("cuda") assumes a CUDA device (cf. get_device())
+                            with torch.inference_mode(), torch.autocast("cuda"):
+                                masks, scores, _ = self.predictor.predict(
+                                    point_coords=points,
+                                    point_labels=labels,
+                                    multimask_output=True,
+                                )
+
+                            # Get the best mask
+                            if len(masks) > 0:
+                                best_mask = masks[0]
+
+                                # Match the segmentation result's shape
+                                if best_mask.shape != self.segmentation_result.shape:
+                                    best_mask = resize(
+                                        best_mask.astype(float),
+                                        self.segmentation_result.shape,
+                                        order=0,
+                                        preserve_range=True,
+                                        anti_aliasing=False,
+                                    ).astype(bool)
+
+                                # CRITICAL FIX: for negative points, only remove from this
+                                # object's mask; for positive points, add to this object's
+                                # mask without removing other objects
+                                if point_label < 0:
+                                    # Remove only from the current object's mask
+                                    self.segmentation_result[
+                                        (self.segmentation_result == current_obj_id) & best_mask
+                                    ] = 0
+                                else:
+                                    # Add to the current object's mask without affecting
+                                    # other objects (only overwrite background, value 0)
+                                    self.segmentation_result[
+                                        best_mask & (self.segmentation_result == 0)
+                                    ] = current_obj_id
+
+                                # Update the label info
+                                area = np.sum(self.segmentation_result == current_obj_id)
+                                y_indices, x_indices = np.where(
+                                    self.segmentation_result == current_obj_id
+                                )
+                                center_y = np.mean(y_indices) if len(y_indices) > 0 else 0
+                                center_x = np.mean(x_indices) if len(x_indices) > 0 else 0
+
+                                self.label_info[current_obj_id] = {
+                                    "area": area,
+                                    "center_y": center_y,
+                                    "center_x": center_x,
+                                    "score": float(scores[0]),
+                                }
+
+                                self.viewer.status = f"Updated object {current_obj_id}"
+                            else:
+                                self.viewer.status = "No valid mask produced"
+
+                            # Update the UI
+                            self._update_label_layer()
+                            if hasattr(self, "label_table_widget") and self.label_table_widget is not None:
+                                self._populate_label_table(self.label_table_widget)
+
+                        except (IndexError, KeyError, ValueError, AttributeError, TypeError) as e:
+                            import traceback
+
+                            self.viewer.status = f"Error in SAM2 processing: {str(e)}"
+                            traceback.print_exc()
+
+                # If clicking on an existing label, toggle selection
+                elif label_id > 0:
+                    # Toggle the label selection
+                    if label_id in self.selected_labels:
+                        self.selected_labels.remove(label_id)
+                        self.viewer.status = f"Deselected label ID: {label_id} | Selected labels: {self.selected_labels}"
+                    else:
+                        self.selected_labels.add(label_id)
+                        self.viewer.status = f"Selected label ID: {label_id} | Selected labels: {self.selected_labels}"
+
+                    # Update table and preview
+                    self._update_label_table()
+                    self.preview_crop()
+
+            # 3D case (handled differently)
             else:
-            self.selected_labels.add(label_id)
-            self.viewer.status = f"Selected label ID: {label_id} | Selected labels: {self.selected_labels}"
+                if len(coords) == 3:
+                    t, y, x = map(int, coords)
+                elif len(coords) == 2:
+                    t = int(self.viewer.dims.current_step[0])
+                    y, x = map(int, coords)
+                else:
+                    self.viewer.status = f"Unexpected coordinate dimensions: {coords}"
+                    return
+
+                # Check if within bounds
+                shape = self.segmentation_result.shape
+                if t < 0 or t >= shape[0] or y < 0 or y >= shape[1] or x < 0 or x >= shape[2]:
+                    self.viewer.status = "Click is outside volume bounds"
+                    return
+
+                # Get the label ID at the clicked position
+                label_id = self.segmentation_result[t, y, x]
+
+                # Background click or Shift pressed: prompt handling
+                if label_id == 0 or is_negative:
+                    # Already handled by the attached _on_3d_label_clicked
+                    pass
+                # If clicking on an existing label, handle selection
+                elif label_id > 0:
+                    # Toggle the label selection
+                    if label_id in self.selected_labels:
+                        self.selected_labels.remove(label_id)
+                        self.viewer.status = f"Deselected label ID: {label_id} | Selected labels: {self.selected_labels}"
+                    else:
+                        self.selected_labels.add(label_id)
+                        self.viewer.status = f"Selected label ID: {label_id} | Selected labels: {self.selected_labels}"
 
-        # Update table if it exists
-        self._update_label_table()
+                    # Update the table if it exists
+                    self._update_label_table()
 
-        # Update preview after selection changes
-        self.preview_crop()
+                    # Update the preview after selection changes
+                    self.preview_crop()
 
-        except (Exception, ValueError) as e:
-            self.viewer.status = f"Error selecting label: {str(e)}"
+        except (IndexError, KeyError, ValueError, AttributeError, TypeError) as e:
+            import traceback
+
+            self.viewer.status = f"Error in click handling: {str(e)}"
+            traceback.print_exc()
+
+    def _add_point_marker(self, coords, label_type):
+        """Add a visible marker for where the user clicked."""
+        # Remove previous point markers
+        for layer in list(self.viewer.layers):
+            if "Point Prompt" in layer.name:
+                self.viewer.layers.remove(layer)
+
+        # Create the points layer (red = negative, green = positive)
+        color = "red" if label_type < 0 else "green"
+        self.viewer.add_points(
+            [coords],
+            name="Point Prompt",
+            size=10,
+            face_color=color,
+            edge_color="white",
+            edge_width=2,
+            opacity=0.8,
+        )
1925
 
658
1926
  def create_label_table(self, parent_widget):
659
1927
  """Create a table widget displaying all detected labels."""
@@ -694,57 +1962,86 @@ class BatchCropAnything:
694
1962
 
695
1963
  def _populate_label_table(self, table):
696
1964
  """Populate the table with label information."""
697
- if not self.label_info:
698
- table.setRowCount(0)
699
- return
1965
+ try:
1966
+ # Get all unique non-zero labels from the segmentation result safely
1967
+ if self.segmentation_result is None:
1968
+ # No segmentation yet
1969
+ table.setRowCount(0)
1970
+ self.viewer.status = "No segmentation available"
1971
+ return
700
1972
 
701
- # Set row count
702
- table.setRowCount(len(self.label_info))
1973
+ # Get unique labels, safely handling None values
1974
+ unique_labels = []
1975
+ for val in np.unique(self.segmentation_result):
1976
+ if val is not None and val > 0:
1977
+ unique_labels.append(val)
703
1978
 
704
- # Sort labels by size (largest first)
705
- sorted_labels = sorted(
706
- self.label_info.items(),
707
- key=lambda item: item[1]["area"],
708
- reverse=True,
709
- )
1979
+ if len(unique_labels) == 0:
1980
+ table.setRowCount(0)
1981
+ self.viewer.status = "No labeled objects found"
1982
+ return
710
1983
 
711
- # Fill table with data
712
- for row, (label_id, _info) in enumerate(sorted_labels):
713
- # Checkbox for selection
714
- checkbox_widget = QWidget()
715
- checkbox_layout = QHBoxLayout(checkbox_widget)
716
- checkbox_layout.setContentsMargins(5, 0, 5, 0)
717
- checkbox_layout.setAlignment(Qt.AlignCenter)
718
-
719
- checkbox = QCheckBox()
720
- checkbox.setChecked(label_id in self.selected_labels)
721
-
722
- # Connect checkbox to label selection
723
- def make_checkbox_callback(lid):
724
- def callback(state):
725
- if state == Qt.Checked:
726
- self.selected_labels.add(lid)
727
- else:
728
- self.selected_labels.discard(lid)
729
- self.preview_crop()
1984
+ # Set row count
1985
+ table.setRowCount(len(unique_labels))
1986
+
1987
+ # Fill in label info for any missing labels
1988
+ for label_id in unique_labels:
1989
+ if label_id not in self.label_info:
1990
+ # Calculate basic info for this label
1991
+ mask = self.segmentation_result == label_id
1992
+ area = np.sum(mask)
1993
+
1994
+ # Add info to label_info dictionary
1995
+ self.label_info[label_id] = {
1996
+ "area": area,
1997
+ "score": 1.0, # Default score
1998
+ }
1999
+
2000
+ # Fill table with data
2001
+ for row, label_id in enumerate(unique_labels):
2002
+ # Checkbox for selection
2003
+ checkbox_widget = QWidget()
2004
+ checkbox_layout = QHBoxLayout(checkbox_widget)
2005
+ checkbox_layout.setContentsMargins(5, 0, 5, 0)
2006
+ checkbox_layout.setAlignment(Qt.AlignCenter)
2007
+
2008
+ checkbox = QCheckBox()
2009
+ checkbox.setChecked(label_id in self.selected_labels)
2010
+
2011
+ # Connect checkbox to label selection
2012
+ def make_checkbox_callback(lid):
2013
+ def callback(state):
2014
+ if state == Qt.Checked:
2015
+ self.selected_labels.add(lid)
2016
+ else:
2017
+ self.selected_labels.discard(lid)
2018
+ self.preview_crop()
730
2019
 
731
- return callback
2020
+ return callback
732
2021
 
733
- checkbox.stateChanged.connect(make_checkbox_callback(label_id))
2022
+ checkbox.stateChanged.connect(make_checkbox_callback(label_id))
734
2023
 
735
- checkbox_layout.addWidget(checkbox)
736
- table.setCellWidget(row, 0, checkbox_widget)
2024
+ checkbox_layout.addWidget(checkbox)
2025
+ table.setCellWidget(row, 0, checkbox_widget)
737
2026
 
738
- # Label ID as plain text with transparent background
739
- item = QTableWidgetItem(str(label_id))
740
- item.setTextAlignment(Qt.AlignCenter)
2027
+ # Label ID as plain text with transparent background
2028
+ item = QTableWidgetItem(str(label_id))
2029
+ item.setTextAlignment(Qt.AlignCenter)
741
2030
 
742
- # Set the background color to transparent
743
- brush = item.background()
744
- brush.setStyle(Qt.NoBrush)
745
- item.setBackground(brush)
2031
+ # Set the background color to transparent
2032
+ brush = item.background()
2033
+ brush.setStyle(Qt.NoBrush)
2034
+ item.setBackground(brush)
746
2035
 
747
- table.setItem(row, 1, item)
2036
+ table.setItem(row, 1, item)
2037
+
2038
+ except (KeyError, TypeError, ValueError, AttributeError) as e:
2039
+ import traceback
2040
+
2041
+ self.viewer.status = f"Error populating table: {str(e)}"
2042
+ traceback.print_exc()
2043
+ # Set empty table as fallback
2044
+ table.setRowCount(0)
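`make_checkbox_callback` exists because Python closures bind loop variables late: connecting a plain `lambda` that references `label_id` would leave every checkbox pointing at the last row's ID. The factory freezes the current value per iteration; a minimal demonstration of the difference:

```python
# Late binding: every callback sees the loop variable's final value
late = [lambda: i for i in range(3)]
print([f() for f in late])  # [2, 2, 2]

# A factory (like make_checkbox_callback) captures the value per call
def make(i):
    return lambda: i

bound = [make(i) for i in range(3)]
print([f() for f in bound])  # [0, 1, 2]
```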
 
     def _update_label_table(self):
         """Update the label selection table if it exists."""
@@ -754,6 +2051,9 @@ class BatchCropAnything:
         # Block signals during update
         self.label_table_widget.blockSignals(True)
 
+        # Completely repopulate the table to ensure it is up to date
+        self._populate_label_table(self.label_table_widget)
+
         # Update checkboxes
         for row in range(self.label_table_widget.rowCount()):
             # Get label ID from the visible column
@@ -793,10 +2093,6 @@ class BatchCropAnything:
         self.preview_crop()
         self.viewer.status = "Cleared all selections"
 
-    # --------------------------------------------------
-    # Image Processing and Export
-    # --------------------------------------------------
-
     def preview_crop(self, label_ids=None):
         """Preview the crop result with the selected label IDs."""
         if self.segmentation_result is None or self.image_layer is None:
@@ -826,20 +2122,29 @@ class BatchCropAnything:
         image = self.original_image.copy()
 
         # Create mask from selected label IDs
-        mask = np.zeros_like(self.segmentation_result, dtype=bool)
-        for label_id in label_ids:
-            mask |= self.segmentation_result == label_id
+        if self.use_3d:
+            # For 3D data
+            mask = np.zeros_like(self.segmentation_result, dtype=bool)
+            for label_id in label_ids:
+                mask |= self.segmentation_result == label_id
 
-        # Apply mask to image for preview (set everything outside mask to 0)
-        if len(image.shape) == 2:
-            # Grayscale image
+            # Apply the mask
             preview_image = image.copy()
             preview_image[~mask] = 0
         else:
-            # Color image
-            preview_image = image.copy()
-            for c in range(preview_image.shape[2]):
-                preview_image[:, :, c][~mask] = 0
+            # For 2D data
+            mask = np.zeros_like(self.segmentation_result, dtype=bool)
+            for label_id in label_ids:
+                mask |= self.segmentation_result == label_id
+
+            # Apply the mask
+            if len(image.shape) == 2:
+                preview_image = image.copy()
+                preview_image[~mask] = 0
+            else:
+                preview_image = image.copy()
+                for c in range(preview_image.shape[2]):
+                    preview_image[:, :, c][~mask] = 0
2148
 
844
2149
  # Remove previous preview if exists
845
2150
  for layer in list(self.viewer.layers):
@@ -879,20 +2184,58 @@ class BatchCropAnything:
879
2184
  image = self.original_image
880
2185
 
881
2186
  # Create mask from all selected label IDs
882
- mask = np.zeros_like(self.segmentation_result, dtype=bool)
883
- for label_id in self.selected_labels:
884
- mask |= self.segmentation_result == label_id
2187
+ if self.use_3d:
2188
+ # For 3D data, create a 3D mask
2189
+ mask = np.zeros_like(self.segmentation_result, dtype=bool)
2190
+ for label_id in self.selected_labels:
2191
+ mask |= self.segmentation_result == label_id
885
2192
 
886
- # Apply mask to image (set everything outside mask to 0)
887
- if len(image.shape) == 2:
888
- # Grayscale image
2193
+ # Apply mask to image (set everything outside mask to 0)
889
2194
  cropped_image = image.copy()
890
2195
  cropped_image[~mask] = 0
2196
+
2197
+ # Save label image with same dimensions as original
2198
+ label_image = np.zeros_like(
2199
+ self.segmentation_result, dtype=np.uint32
2200
+ )
2201
+ for label_id in self.selected_labels:
2202
+ label_image[self.segmentation_result == label_id] = (
2203
+ label_id
2204
+ )
891
2205
  else:
892
- # Color image
893
- cropped_image = image.copy()
894
- for c in range(cropped_image.shape[2]):
895
- cropped_image[:, :, c][~mask] = 0
2206
+ # For 2D data, handle as before
2207
+ mask = np.zeros_like(self.segmentation_result, dtype=bool)
2208
+ for label_id in self.selected_labels:
2209
+ mask |= self.segmentation_result == label_id
2210
+
2211
+ # Apply mask to image (set everything outside mask to 0)
2212
+ if len(image.shape) == 2:
2213
+ # Grayscale image
2214
+ cropped_image = image.copy()
2215
+ cropped_image[~mask] = 0
2216
+
2217
+ # Create label image with same dimensions
2218
+ label_image = np.zeros_like(
2219
+ self.segmentation_result, dtype=np.uint32
2220
+ )
2221
+ for label_id in self.selected_labels:
2222
+ label_image[self.segmentation_result == label_id] = (
2223
+ label_id
2224
+ )
2225
+ else:
2226
+ # Color image - mask must be expanded to match channel dimension
2227
+ cropped_image = image.copy()
2228
+ for c in range(cropped_image.shape[2]):
2229
+ cropped_image[:, :, c][~mask] = 0
2230
+
2231
+ # Create label image with 2D dimensions (without channels)
2232
+ label_image = np.zeros_like(
2233
+ self.segmentation_result, dtype=np.uint32
2234
+ )
2235
+ for label_id in self.selected_labels:
2236
+ label_image[self.segmentation_result == label_id] = (
2237
+ label_id
2238
+ )
896
2239
 
897
2240
  # Save cropped image
898
2241
  image_path = self.images[self.current_index]
@@ -900,18 +2243,17 @@ class BatchCropAnything:
900
2243
  label_str = "_".join(
901
2244
  str(lid) for lid in sorted(self.selected_labels)
902
2245
  )
903
- output_path = f"{base_name}_cropped_{label_str}{ext}"
904
-
905
- # Save using appropriate method based on file type
906
- if output_path.lower().endswith((".tif", ".tiff")):
907
- imwrite(output_path, cropped_image, compression="zlib")
908
- else:
909
- from skimage.io import imsave
910
-
911
- imsave(output_path, cropped_image)
2246
+ output_path = f"{base_name}_cropped_{label_str}.tif"
912
2247
 
2248
+ # Save using tifffile with explicit parameters for best compatibility
2249
+ imwrite(output_path, cropped_image, compression="zlib")
913
2250
  self.viewer.status = f"Saved cropped image to {output_path}"
914
2251
 
2252
+ # Save the label image with exact same dimensions as original
2253
+ label_output_path = f"{base_name}_labels_{label_str}.tif"
2254
+ imwrite(label_output_path, label_image, compression="zlib")
2255
+ self.viewer.status += f"\nSaved label mask to {label_output_path}"
2256
+
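With this change every result is written as zlib-compressed TIFF regardless of the input extension, and each crop is paired with a label mask of identical shape. As a usage sketch, assuming an input named `embryo.tif` with labels 2 and 7 selected, the two calls would amount to:

```python
from tifffile import imwrite

# cropped_image and label_image come from the method above
imwrite("embryo_cropped_2_7.tif", cropped_image, compression="zlib")  # masked image
imwrite("embryo_labels_2_7.tif", label_image, compression="zlib")     # uint32 labels
```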
         # Make sure the segmentation layer is active again
         if self.label_layer is not None:
             self.viewer.layers.selection.active = self.label_layer
@@ -923,76 +2265,44 @@
             return False
 
 
-    # --------------------------------------------------
-    # UI Creation Functions
-    # --------------------------------------------------
-
-
 def create_crop_widget(processor):
     """Create the crop control widget."""
     crop_widget = QWidget()
     layout = QVBoxLayout()
-    layout.setSpacing(10)  # Add more space between elements
-    layout.setContentsMargins(10, 10, 10, 10)  # Add margins around all elements
+    layout.setSpacing(10)
+    layout.setContentsMargins(10, 10, 10, 10)
 
     # Instructions
+    dimension_type = "3D (TYX/ZYX)" if processor.use_3d else "2D (YX)"
     instructions_label = QLabel(
-        "Select objects to keep in the cropped image.\n"
-        "You can select labels using the table below or by clicking directly on objects "
-        "in the image (make sure the Segmentation layer is active)."
+        f"<b>Processing {dimension_type} data</b><br><br>"
+        "To create/edit objects:<br>"
+        "1. <b>Click on the POINTS layer</b> to add positive points<br>"
+        "2. Use Shift+click for negative points to refine segmentation<br>"
+        "3. Click on existing objects in the Segmentation layer to select them<br>"
+        "4. Press 'Crop' to save the selected objects to disk"
     )
     instructions_label.setWordWrap(True)
     layout.addWidget(instructions_label)
 
-    # Sensitivity slider
-    sensitivity_layout = QVBoxLayout()
-
-    # Header label
-    sensitivity_header_layout = QHBoxLayout()
-    sensitivity_label = QLabel("Segmentation Sensitivity:")
-    sensitivity_value_label = QLabel(f"{processor.sensitivity}")
-    sensitivity_header_layout.addWidget(sensitivity_label)
-    sensitivity_header_layout.addStretch()
-    sensitivity_header_layout.addWidget(sensitivity_value_label)
-    sensitivity_layout.addLayout(sensitivity_header_layout)
-
-    # Slider
-    slider_layout = QHBoxLayout()
-    sensitivity_slider = QSlider(Qt.Horizontal)
-    sensitivity_slider.setMinimum(0)
-    sensitivity_slider.setMaximum(100)
-    sensitivity_slider.setValue(processor.sensitivity)
-    sensitivity_slider.setTickPosition(QSlider.TicksBelow)
-    sensitivity_slider.setTickInterval(10)
-    slider_layout.addWidget(sensitivity_slider)
-
-    apply_sensitivity_button = QPushButton("Apply")
-    apply_sensitivity_button.setToolTip(
-        "Apply sensitivity changes to regenerate segmentation"
+    # Add a button to ensure the points layer is active
+    activate_button = QPushButton("Make Points Layer Active")
+    activate_button.clicked.connect(
+        lambda: processor._ensure_points_layer_active()
     )
-    slider_layout.addWidget(apply_sensitivity_button)
-    sensitivity_layout.addLayout(slider_layout)
+    layout.addWidget(activate_button)
 
-    # Description label
-    sensitivity_description = QLabel(
-        "Medium sensitivity - Balanced detection (γ=1.00)"
-    )
-    sensitivity_description.setStyleSheet("font-style: italic; color: #666;")
-    sensitivity_layout.addWidget(sensitivity_description)
-
-    layout.addLayout(sensitivity_layout)
+    # Add a "Clear Points" button to reset prompts
+    clear_points_button = QPushButton("Clear Points")
+    layout.addWidget(clear_points_button)
 
     # Create label table
     label_table = processor.create_label_table(crop_widget)
-    label_table.setMinimumHeight(150)  # Reduce minimum height to save space
-    label_table.setMaximumHeight(300)  # Set maximum height to prevent taking too much space
+    label_table.setMinimumHeight(150)
+    label_table.setMaximumHeight(300)
     layout.addWidget(label_table)
 
-    # Remove "Focus on Segmentation Layer" button as it's now redundant
+    # Selection buttons
     selection_layout = QHBoxLayout()
     select_all_button = QPushButton("Select All")
     clear_selection_button = QPushButton("Clear Selection")
@@ -1014,7 +2324,7 @@ def create_crop_widget(processor):
 
     # Status label
     status_label = QLabel(
-        "Ready to process images. Select objects using the table or by clicking on them."
+        "Ready to process images. Click on POINTS layer to add segmentation points."
    )
     status_label.setWordWrap(True)
     layout.addWidget(status_label)
@@ -1033,36 +2343,51 @@ def create_crop_widget(processor):
         # Create new table
         label_table = processor.create_label_table(crop_widget)
         label_table.setMinimumHeight(200)
-        layout.insertWidget(3, label_table)  # Insert after sensitivity slider
+        layout.insertWidget(3, label_table)  # Insert after the clear points button
         return label_table
 
-    # Connect button signals
-    def on_sensitivity_changed(value):
-        sensitivity_value_label.setText(f"{value}")
-        # Update description based on sensitivity
-        if value < 25:
-            gamma = 1.5 - (value / 100) * 1.0  # Higher gamma for low sensitivity
-            description = f"Low sensitivity - Seeks large, distinct objects (γ={gamma:.2f})"
-        elif value < 75:
-            gamma = 1.5 - (value / 100) * 1.0
-            description = f"Medium sensitivity - Balanced detection (γ={gamma:.2f})"
+    # Add a helper to ensure the points layer is active
+    def _ensure_points_layer_active():
+        points_layer = None
+        for layer in list(processor.viewer.layers):
+            if "Points" in layer.name:
+                points_layer = layer
+                break
+
+        if points_layer is not None:
+            processor.viewer.layers.selection.active = points_layer
+            status_label.setText("Points layer is now active - click to add points")
         else:
-            gamma = 1.5 - (value / 100) * 1.0  # Lower gamma for high sensitivity
-            description = f"High sensitivity - Detects subtle, small objects (γ={gamma:.2f})"
-        sensitivity_description.setText(description)
-
-    def on_apply_sensitivity_clicked():
-        new_sensitivity = sensitivity_slider.value()
-        processor.generate_segmentation_with_sensitivity(new_sensitivity)
-        replace_table_widget()
+            status_label.setText("No points layer found. Please load an image first.")
+
+    processor._ensure_points_layer_active = _ensure_points_layer_active
+
+    # Connect button signals
+    def on_clear_points_clicked():
+        # Remove all point layers
+        for layer in list(processor.viewer.layers):
+            if "Points" in layer.name:
+                processor.viewer.layers.remove(layer)
+
+        # Reset point tracking attributes
+        if hasattr(processor, "points_data"):
+            processor.points_data = {}
+            processor.points_labels = {}
+
+        if hasattr(processor, "obj_points"):
+            processor.obj_points = {}
+            processor.obj_labels = {}
+
+        # Re-create an empty points layer
+        processor._update_label_layer()
+        processor._ensure_points_layer_active()
+
         status_label.setText(
-            f"Regenerated segmentation with sensitivity {new_sensitivity}"
+            "Cleared all points. Click on Points layer to add new points."
         )
 
     def on_select_all_clicked():
@@ -1086,117 +2411,83 @@
         )
 
     def on_next_clicked():
+        # Clear points before moving to the next image
+        on_clear_points_clicked()
+
         if not processor.next_image():
             next_button.setEnabled(False)
         else:
             prev_button.setEnabled(True)
             replace_table_widget()
-            # Reset sensitivity slider to default
-            sensitivity_slider.setValue(processor.sensitivity)
-            sensitivity_value_label.setText(f"{processor.sensitivity}")
             status_label.setText(
                 f"Showing image {processor.current_index + 1}/{len(processor.images)}"
             )
+            processor._ensure_points_layer_active()
 
     def on_prev_clicked():
+        # Clear points before moving to the previous image
+        on_clear_points_clicked()
+
         if not processor.previous_image():
             prev_button.setEnabled(False)
         else:
             next_button.setEnabled(True)
             replace_table_widget()
-            # Reset sensitivity slider to default
-            sensitivity_slider.setValue(processor.sensitivity)
-            sensitivity_value_label.setText(f"{processor.sensitivity}")
             status_label.setText(
                 f"Showing image {processor.current_index + 1}/{len(processor.images)}"
            )
+            processor._ensure_points_layer_active()
 
-    sensitivity_slider.valueChanged.connect(on_sensitivity_changed)
-    apply_sensitivity_button.clicked.connect(on_apply_sensitivity_clicked)
+    clear_points_button.clicked.connect(on_clear_points_clicked)
     select_all_button.clicked.connect(on_select_all_clicked)
     clear_selection_button.clicked.connect(on_clear_selection_clicked)
     crop_button.clicked.connect(on_crop_clicked)
     next_button.clicked.connect(on_next_clicked)
     prev_button.clicked.connect(on_prev_clicked)
+    activate_button.clicked.connect(_ensure_points_layer_active)
 
     return crop_widget
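Note the design choice inside `create_crop_widget`: `_ensure_points_layer_active` is defined as a closure over `status_label` and then attached to the processor instance at runtime so the navigation handlers can reach it. A more explicit alternative (a sketch under that assumption, not the plugin's code) would be a regular method on the class:

```python
class BatchCropAnything:
    def ensure_points_layer_active(self) -> bool:
        """Activate the first layer whose name contains 'Points'."""
        for layer in list(self.viewer.layers):
            if "Points" in layer.name:
                self.viewer.layers.selection.active = layer
                return True
        return False
```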
 
 
-    # --------------------------------------------------
-    # Napari Plugin Functions
-    # --------------------------------------------------
-
-
 @magicgui(
     call_button="Start Batch Crop Anything",
     folder_path={"label": "Folder Path", "widget_type": "LineEdit"},
+    data_dimensions={
+        "label": "Data Dimensions",
+        "choices": ["YX (2D)", "TYX/ZYX (3D)"],
+    },
 )
 def batch_crop_anything(
     folder_path: str,
+    data_dimensions: str,
     viewer: Viewer = None,
 ):
-    """MagicGUI widget for starting Batch Crop Anything."""
-    # Check if Mobile-SAM is available
+    """MagicGUI widget for starting Batch Crop Anything using SAM2."""
+    # Check if SAM2 is available
     try:
-        # import torch
-        # from mobile_sam import sam_model_registry
-
-        # Check if the required files are included with the package
-        try:
-            import importlib.util
-            import os
-
-            mobile_sam_spec = importlib.util.find_spec("mobile_sam")
-            if mobile_sam_spec is None:
-                raise ImportError("mobile_sam package not found")
-
-            mobile_sam_path = os.path.dirname(mobile_sam_spec.origin)
-
-            # Check for model file in package
-            model_found = False
-            checkpoint_paths = [
-                os.path.join(mobile_sam_path, "weights", "mobile_sam.pt"),
-                os.path.join(mobile_sam_path, "mobile_sam.pt"),
-                os.path.join(os.path.dirname(mobile_sam_path), "weights", "mobile_sam.pt"),
-                os.path.join(os.path.expanduser("~"), "models", "mobile_sam.pt"),
-                "/opt/T-MIDAS/models/mobile_sam.pt",
-                os.path.join(os.getcwd(), "mobile_sam.pt"),
-            ]
-
-            for path in checkpoint_paths:
-                if os.path.exists(path):
-                    model_found = True
-                    break
-
-            if not model_found:
-                QMessageBox.warning(
-                    None,
-                    "Model File Missing",
-                    "Mobile-SAM model weights (mobile_sam.pt) not found. You'll be prompted to locate it when starting the tool.\n\n"
-                    "You can download it from: https://github.com/ChaoningZhang/MobileSAM/tree/master/weights",
-                )
-        except (ImportError, AttributeError) as e:
-            print(f"Warning checking for model file: {str(e)}")
+        import importlib.util
 
+        sam2_spec = importlib.util.find_spec("sam2")
+        if sam2_spec is None:
+            QMessageBox.critical(
+                None,
+                "Missing Dependency",
+                "SAM2 not found. Please follow installation instructions at:\n"
+                "https://github.com/MercaderLabAnatomy/napari-tmidas?tab=readme-ov-file#dependencies\n",
+            )
+            return
     except ImportError:
         QMessageBox.critical(
             None,
             "Missing Dependency",
-            "Mobile-SAM not found. Please install with:\n"
-            "pip install git+https://github.com/ChaoningZhang/MobileSAM.git\n\n"
-            "You'll also need to download the model weights file (mobile_sam.pt) from:\n"
-            "https://github.com/ChaoningZhang/MobileSAM/tree/master/weights",
+            "SAM2 package cannot be imported. Please follow installation instructions at\n"
+            "https://github.com/MercaderLabAnatomy/napari-tmidas?tab=readme-ov-file#dependencies",
         )
         return
 
-    # Initialize processor and load images
-    processor = BatchCropAnything(viewer)
+    # Initialize the processor with the selected dimensions mode
+    use_3d = "TYX/ZYX" in data_dimensions
+    processor = BatchCropAnything(viewer, use_3d=use_3d)
     processor.load_images(folder_path)
 
     # Create UI
@@ -1205,13 +2496,9 @@ def batch_crop_anything(
 
     # Wrap the widget in a scroll area
     scroll_area = QScrollArea()
     scroll_area.setWidget(crop_widget)
-    scroll_area.setWidgetResizable(True)  # Allow the widget to resize with the scroll area
-    scroll_area.setFrameShape(QScrollArea.NoFrame)  # Hide the frame
-    scroll_area.setMinimumHeight(500)  # Set a minimum height to ensure visibility
+    scroll_area.setWidgetResizable(True)
+    scroll_area.setFrameShape(QScrollArea.NoFrame)
+    scroll_area.setMinimumHeight(500)
 
     # Add scroll area to viewer
     viewer.window.add_dock_widget(scroll_area, name="Crop Controls")
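As a closing usage sketch (the folder path is a placeholder), the magicgui-wrapped function can also be invoked programmatically, which is equivalent to filling in the form and pressing the call button:

```python
import napari

viewer = napari.Viewer()
batch_crop_anything(
    folder_path="/path/to/images",
    data_dimensions="TYX/ZYX (3D)",
    viewer=viewer,
)
napari.run()
```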