napari-tmidas 0.2.1__py3-none-any.whl → 0.2.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. napari_tmidas/__init__.py +35 -5
  2. napari_tmidas/_crop_anything.py +1458 -499
  3. napari_tmidas/_env_manager.py +76 -0
  4. napari_tmidas/_file_conversion.py +1646 -1131
  5. napari_tmidas/_file_selector.py +1464 -223
  6. napari_tmidas/_label_inspection.py +83 -8
  7. napari_tmidas/_processing_worker.py +309 -0
  8. napari_tmidas/_reader.py +6 -10
  9. napari_tmidas/_registry.py +15 -14
  10. napari_tmidas/_roi_colocalization.py +1221 -84
  11. napari_tmidas/_tests/test_crop_anything.py +123 -0
  12. napari_tmidas/_tests/test_env_manager.py +89 -0
  13. napari_tmidas/_tests/test_file_selector.py +90 -0
  14. napari_tmidas/_tests/test_grid_view_overlay.py +193 -0
  15. napari_tmidas/_tests/test_init.py +98 -0
  16. napari_tmidas/_tests/test_intensity_label_filter.py +222 -0
  17. napari_tmidas/_tests/test_label_inspection.py +86 -0
  18. napari_tmidas/_tests/test_processing_basic.py +500 -0
  19. napari_tmidas/_tests/test_processing_worker.py +142 -0
  20. napari_tmidas/_tests/test_regionprops_analysis.py +547 -0
  21. napari_tmidas/_tests/test_registry.py +135 -0
  22. napari_tmidas/_tests/test_scipy_filters.py +168 -0
  23. napari_tmidas/_tests/test_skimage_filters.py +259 -0
  24. napari_tmidas/_tests/test_split_channels.py +217 -0
  25. napari_tmidas/_tests/test_spotiflow.py +87 -0
  26. napari_tmidas/_tests/test_tyx_display_fix.py +142 -0
  27. napari_tmidas/_tests/test_ui_utils.py +68 -0
  28. napari_tmidas/_tests/test_widget.py +30 -0
  29. napari_tmidas/_tests/test_windows_basic.py +66 -0
  30. napari_tmidas/_ui_utils.py +57 -0
  31. napari_tmidas/_version.py +16 -3
  32. napari_tmidas/_widget.py +41 -4
  33. napari_tmidas/processing_functions/basic.py +557 -20
  34. napari_tmidas/processing_functions/careamics_env_manager.py +72 -99
  35. napari_tmidas/processing_functions/cellpose_env_manager.py +415 -112
  36. napari_tmidas/processing_functions/cellpose_segmentation.py +132 -191
  37. napari_tmidas/processing_functions/colocalization.py +513 -56
  38. napari_tmidas/processing_functions/grid_view_overlay.py +703 -0
  39. napari_tmidas/processing_functions/intensity_label_filter.py +422 -0
  40. napari_tmidas/processing_functions/regionprops_analysis.py +1280 -0
  41. napari_tmidas/processing_functions/sam2_env_manager.py +53 -69
  42. napari_tmidas/processing_functions/sam2_mp4.py +274 -195
  43. napari_tmidas/processing_functions/scipy_filters.py +403 -8
  44. napari_tmidas/processing_functions/skimage_filters.py +424 -212
  45. napari_tmidas/processing_functions/spotiflow_detection.py +949 -0
  46. napari_tmidas/processing_functions/spotiflow_env_manager.py +591 -0
  47. napari_tmidas/processing_functions/timepoint_merger.py +334 -86
  48. napari_tmidas/processing_functions/trackastra_tracking.py +24 -5
  49. {napari_tmidas-0.2.1.dist-info → napari_tmidas-0.2.4.dist-info}/METADATA +92 -39
  50. napari_tmidas-0.2.4.dist-info/RECORD +63 -0
  51. napari_tmidas/_tests/__init__.py +0 -0
  52. napari_tmidas-0.2.1.dist-info/RECORD +0 -38
  53. {napari_tmidas-0.2.1.dist-info → napari_tmidas-0.2.4.dist-info}/WHEEL +0 -0
  54. {napari_tmidas-0.2.1.dist-info → napari_tmidas-0.2.4.dist-info}/entry_points.txt +0 -0
  55. {napari_tmidas-0.2.1.dist-info → napari_tmidas-0.2.4.dist-info}/licenses/LICENSE +0 -0
  56. {napari_tmidas-0.2.1.dist-info → napari_tmidas-0.2.4.dist-info}/top_level.txt +0 -0
@@ -8,42 +8,126 @@ The plugin supports both 2D (YX) and 3D (TYX/ZYX) data.
8
8
 
9
9
  import contextlib
10
10
  import os
11
-
12
- # Add this at the beginning of your plugin file
13
11
  import sys
12
+ from pathlib import Path
14
13
 
15
- sys.path.append("/opt/sam2")
16
14
  import numpy as np
17
- import requests
18
- import torch
19
- from magicgui import magicgui
20
- from napari.layers import Labels
21
- from napari.viewer import Viewer
22
- from qtpy.QtCore import Qt
23
- from qtpy.QtWidgets import (
24
- QCheckBox,
25
- QFileDialog,
26
- QHBoxLayout,
27
- QHeaderView,
28
- QLabel,
29
- QMessageBox,
30
- QPushButton,
31
- QScrollArea,
32
- QTableWidget,
33
- QTableWidgetItem,
34
- QVBoxLayout,
35
- QWidget,
36
- )
37
- from skimage.io import imread
38
- from skimage.transform import resize
39
- from tifffile import imwrite
40
15
 
16
+ # Lazy imports for optional heavy dependencies
17
+ try:
18
+ import requests
19
+
20
+ _HAS_REQUESTS = True
21
+ except ImportError:
22
+ requests = None
23
+ _HAS_REQUESTS = False
24
+
25
+ try:
26
+ import torch
27
+
28
+ _HAS_TORCH = True
29
+ except ImportError:
30
+ torch = None
31
+ _HAS_TORCH = False
32
+
33
+ try:
34
+ from magicgui import magicgui
35
+
36
+ _HAS_MAGICGUI = True
37
+ except ImportError:
38
+ # Create stub decorator
39
+ def magicgui(*args, **kwargs):
40
+ def decorator(func):
41
+ return func
42
+
43
+ if len(args) == 1 and callable(args[0]) and not kwargs:
44
+ return args[0]
45
+ return decorator
46
+
47
+ _HAS_MAGICGUI = False
48
+
49
+ try:
50
+ from napari.layers import Labels
51
+ from napari.viewer import Viewer
52
+
53
+ _HAS_NAPARI = True
54
+ except ImportError:
55
+ Labels = None
56
+ Viewer = None
57
+ _HAS_NAPARI = False
58
+
59
+ try:
60
+ from qtpy.QtCore import Qt
61
+ from qtpy.QtWidgets import (
62
+ QCheckBox,
63
+ QHBoxLayout,
64
+ QHeaderView,
65
+ QLabel,
66
+ QMessageBox,
67
+ QPushButton,
68
+ QScrollArea,
69
+ QTableWidget,
70
+ QTableWidgetItem,
71
+ QVBoxLayout,
72
+ QWidget,
73
+ )
74
+
75
+ _HAS_QTPY = True
76
+ except ImportError:
77
+ Qt = None
78
+ QCheckBox = QHBoxLayout = QHeaderView = QLabel = QMessageBox = None
79
+ QPushButton = QScrollArea = QTableWidget = QTableWidgetItem = None
80
+ QVBoxLayout = QWidget = None
81
+ _HAS_QTPY = False
82
+
83
+ try:
84
+ from skimage.io import imread
85
+ from skimage.transform import resize
86
+
87
+ _HAS_SKIMAGE = True
88
+ except ImportError:
89
+ imread = None
90
+ resize = None
91
+ _HAS_SKIMAGE = False
92
+
93
+ try:
94
+ from tifffile import imwrite
95
+
96
+ _HAS_TIFFFILE = True
97
+ except ImportError:
98
+ imwrite = None
99
+ _HAS_TIFFFILE = False
100
+
101
+ from napari_tmidas._file_selector import (
102
+ load_image_file as load_any_image,
103
+ )
104
+ from napari_tmidas._ui_utils import add_browse_button_to_folder_field
41
105
  from napari_tmidas.processing_functions.sam2_mp4 import tif_to_mp4
42
106
 
107
+ sam2_paths = [
108
+ os.environ.get("SAM2_PATH"),
109
+ "/opt/sam2",
110
+ os.path.expanduser("~/sam2"),
111
+ "./sam2",
112
+ ]
113
+
114
+ for path in sam2_paths:
115
+ if path and os.path.exists(path):
116
+ sys.path.append(path)
117
+ break
118
+ else:
119
+ print(
120
+ "Warning: SAM2 not found in common locations. Please set SAM2_PATH environment variable."
121
+ )
122
+
123
+
43
124
  def get_device():
44
125
  if sys.platform == "darwin":
45
126
  # MacOS: Only check for MPS
46
- if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
127
+ if (
128
+ hasattr(torch.backends, "mps")
129
+ and torch.backends.mps.is_available()
130
+ ):
47
131
  device = torch.device("mps")
48
132
  print("Using Apple Silicon GPU (MPS)")
49
133
  else:
@@ -60,8 +144,6 @@ def get_device():
60
144
  return device
61
145
 
62
146
 
63
-
64
-
65
147
  class BatchCropAnything:
66
148
  """Class for processing images with SAM2 and cropping selected objects."""
67
149
 
@@ -83,6 +165,7 @@ class BatchCropAnything:
83
165
  self.image_layer = None
84
166
  self.label_layer = None
85
167
  self.label_table_widget = None
168
+ self.shapes_layer = None
86
169
 
87
170
  # State tracking
88
171
  self.selected_labels = set()
@@ -91,6 +174,9 @@ class BatchCropAnything:
91
174
  # Segmentation parameters
92
175
  self.sensitivity = 50 # Default sensitivity (0-100 scale)
93
176
 
177
+ # Prompt mode: 'point' or 'box'
178
+ self.prompt_mode = "point"
179
+
94
180
  # Initialize the SAM2 model
95
181
  self._initialize_sam2()
96
182
 
@@ -104,7 +190,7 @@ class BatchCropAnything:
104
190
  filename = os.path.join(dest_folder, url.split("/")[-1])
105
191
  if not os.path.exists(filename):
106
192
  print(f"Downloading checkpoint to {filename}...")
107
- response = requests.get(url, stream=True)
193
+ response = requests.get(url, stream=True, timeout=30)
108
194
  response.raise_for_status()
109
195
  with open(filename, "wb") as f:
110
196
  for chunk in response.iter_content(chunk_size=8192):
@@ -116,17 +202,45 @@ class BatchCropAnything:
116
202
 
117
203
  try:
118
204
  # import torch
205
+ print("DEBUG: Starting SAM2 initialization...")
119
206
 
120
207
  self.device = get_device()
208
+ print(f"DEBUG: Device set to {self.device}")
121
209
 
122
210
  # Download checkpoint if needed
123
211
  checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt"
124
212
  checkpoint_path = download_checkpoint(
125
213
  checkpoint_url, "/opt/sam2/checkpoints/"
126
214
  )
215
+ print(f"DEBUG: Checkpoint path: {checkpoint_path}")
216
+
217
+ # Use relative config path for SAM2's Hydra config system
127
218
  model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
219
+ print(f"DEBUG: Model config: {model_cfg}")
220
+
221
+ # Verify the actual config file exists in the SAM2 installation
222
+ sam2_base_path = None
223
+ for path in sam2_paths:
224
+ if path and os.path.exists(path):
225
+ sam2_base_path = path
226
+ break
227
+
228
+ if sam2_base_path is not None:
229
+ full_config_path = os.path.join(
230
+ sam2_base_path, "sam2", model_cfg
231
+ )
232
+ if not os.path.exists(full_config_path):
233
+ raise FileNotFoundError(
234
+ f"SAM2 config file not found at: {full_config_path}"
235
+ )
236
+ print(f"DEBUG: Verified config exists at: {full_config_path}")
237
+ else:
238
+ print(
239
+ "DEBUG: Warning - could not verify config file exists, but proceeding with relative path"
240
+ )
128
241
 
129
242
  if self.use_3d:
243
+ print("DEBUG: Initializing SAM2 Video Predictor...")
130
244
  from sam2.build_sam import build_sam2_video_predictor
131
245
 
132
246
  self.predictor = build_sam2_video_predictor(
@@ -135,7 +249,9 @@ class BatchCropAnything:
135
249
  self.viewer.status = (
136
250
  f"Initialized SAM2 Video Predictor on {self.device}"
137
251
  )
252
+ print("DEBUG: SAM2 Video Predictor initialized successfully")
138
253
  else:
254
+ print("DEBUG: Initializing SAM2 Image Predictor...")
139
255
  from sam2.build_sam import build_sam2
140
256
  from sam2.sam2_image_predictor import SAM2ImagePredictor
141
257
 
@@ -145,6 +261,7 @@ class BatchCropAnything:
145
261
  self.viewer.status = (
146
262
  f"Initialized SAM2 Image Predictor on {self.device}"
147
263
  )
264
+ print("DEBUG: SAM2 Image Predictor initialized successfully")
148
265
 
149
266
  except (
150
267
  ImportError,
@@ -152,37 +269,79 @@ class BatchCropAnything:
152
269
  ValueError,
153
270
  FileNotFoundError,
154
271
  requests.RequestException,
272
+ AttributeError,
273
+ ModuleNotFoundError,
155
274
  ) as e:
156
275
  import traceback
157
276
 
158
- self.viewer.status = f"Error initializing SAM2: {str(e)}"
277
+ error_msg = f"SAM2 initialization failed: {str(e)}"
278
+ error_type = type(e).__name__
279
+ self.viewer.status = (
280
+ f"{error_msg} - Images will load without segmentation"
281
+ )
159
282
  self.predictor = None
283
+ print(f"DEBUG: SAM2 Error ({error_type}): {error_msg}")
284
+ print("DEBUG: Full traceback:")
160
285
  print(traceback.format_exc())
286
+ print(
287
+ "DEBUG: Note: Images will still load, but automatic segmentation will not be available."
288
+ )
289
+
290
+ # Provide specific guidance based on error type
291
+ if isinstance(e, FileNotFoundError):
292
+ print(
293
+ "DEBUG: This appears to be a missing file issue. Check SAM2 installation and config paths."
294
+ )
295
+ elif isinstance(e, (ImportError, ModuleNotFoundError)):
296
+ print(
297
+ "DEBUG: This appears to be a SAM2 import issue. Check SAM2 installation."
298
+ )
299
+ elif isinstance(e, RuntimeError):
300
+ print(
301
+ "DEBUG: This appears to be a runtime issue, possibly GPU/CUDA related."
302
+ )
303
+ else:
304
+ print(f"DEBUG: Unexpected error type: {error_type}")
161
305
 
162
306
  def load_images(self, folder_path: str):
163
307
  """Load images from the specified folder path."""
308
+ print(f"DEBUG: Loading images from folder: {folder_path}")
164
309
  if not os.path.exists(folder_path):
165
310
  self.viewer.status = f"Folder not found: {folder_path}"
311
+ print(f"DEBUG: Folder does not exist: {folder_path}")
166
312
  return
167
313
 
168
314
  files = os.listdir(folder_path)
169
- self.images = [
170
- os.path.join(folder_path, file)
171
- for file in files
172
- if file.lower().endswith(".tif")
173
- or file.lower().endswith(".tiff")
174
- and "label" not in file.lower()
175
- and "cropped" not in file.lower()
176
- and "_labels_" not in file.lower()
177
- and "_cropped_" not in file.lower()
178
- ]
315
+ print(f"DEBUG: Found {len(files)} files in folder")
316
+ self.images = []
317
+ for file in files:
318
+ full = os.path.join(folder_path, file)
319
+ low = file.lower()
320
+ if (
321
+ low.endswith((".tif", ".tiff"))
322
+ or (os.path.isdir(full) and low.endswith(".zarr"))
323
+ ) and (
324
+ "label" not in low
325
+ and "_labels_" not in low
326
+ and "sam2"
327
+ not in low # Exclude any SAM2-related files (including output from this tool)
328
+ ):
329
+ self.images.append(full)
330
+ print(f"DEBUG: Added image: {file}")
331
+ else:
332
+ print(
333
+ f"DEBUG: Excluded file: {file} (reason: filtering criteria)"
334
+ )
179
335
 
180
336
  if not self.images:
181
337
  self.viewer.status = "No compatible images found in the folder."
338
+ print("DEBUG: No compatible images found")
182
339
  return
183
340
 
341
+ print(f"DEBUG: Total compatible images found: {len(self.images)}")
184
342
  self.viewer.status = f"Found {len(self.images)} .tif images."
185
343
  self.current_index = 0
344
+ print(f"DEBUG: About to load first image: {self.images[0]}")
186
345
  self._load_current_image()
187
346
 
188
347
  def next_image(self):
@@ -235,25 +394,69 @@ class BatchCropAnything:
235
394
 
236
395
  def _load_current_image(self):
237
396
  """Load the current image and generate segmentation."""
397
+ print("DEBUG: _load_current_image called")
238
398
  if not self.images:
239
399
  self.viewer.status = "No images to process."
240
- return
241
-
242
- if self.predictor is None:
243
- self.viewer.status = (
244
- "SAM2 model not initialized. Cannot segment images."
245
- )
400
+ print("DEBUG: No images to process")
246
401
  return
247
402
 
248
403
  image_path = self.images[self.current_index]
249
- self.viewer.status = f"Processing {os.path.basename(image_path)}"
404
+ print(f"DEBUG: Loading image at path: {image_path}")
405
+
406
+ if self.predictor is None:
407
+ self.viewer.status = f"Loading {os.path.basename(image_path)} (SAM2 model not initialized - no segmentation will be available)"
408
+ print("DEBUG: SAM2 predictor is None")
409
+ else:
410
+ self.viewer.status = f"Processing {os.path.basename(image_path)}"
411
+ print("DEBUG: SAM2 predictor is available")
250
412
 
251
413
  try:
414
+ print("DEBUG: About to clear viewer layers")
252
415
  # Clear existing layers
253
416
  self.viewer.layers.clear()
417
+ print("DEBUG: Viewer layers cleared")
254
418
 
419
+ print("DEBUG: About to load image file")
255
420
  # Load and process image
256
- self.original_image = imread(image_path)
421
+ if image_path.lower().endswith(".zarr") or (
422
+ os.path.isdir(image_path)
423
+ and image_path.lower().endswith(".zarr")
424
+ ):
425
+ print("DEBUG: Loading Zarr file")
426
+ data = load_any_image(image_path)
427
+ # If multiple layers returned, take first image layer
428
+ if isinstance(data, list):
429
+ img = None
430
+ for entry in data:
431
+ if isinstance(entry, tuple) and len(entry) == 3:
432
+ d, _kwargs, layer_type = entry
433
+ if layer_type == "image":
434
+ img = d
435
+ break
436
+ elif isinstance(entry, tuple) and len(entry) == 2:
437
+ d, _kwargs = entry
438
+ img = d
439
+ break
440
+ else:
441
+ img = entry
442
+ break
443
+ if img is None:
444
+ raise ValueError("No image layer found in Zarr store")
445
+ else:
446
+ img = data
447
+
448
+ # Compute dask arrays to numpy if needed
449
+ if hasattr(img, "compute"):
450
+ img = img.compute()
451
+
452
+ self.original_image = img
453
+ else:
454
+ print("DEBUG: Loading TIFF file")
455
+ self.original_image = imread(image_path)
456
+
457
+ print(
458
+ f"DEBUG: Image loaded, shape: {self.original_image.shape}, dtype: {self.original_image.dtype}"
459
+ )
257
460
 
258
461
  # For 3D/4D data, determine dimensions
259
462
  if self.use_3d and len(self.original_image.shape) >= 3:
@@ -269,10 +472,12 @@ class BatchCropAnything:
269
472
 
270
473
  if time_dim_idx == 0: # TZYX format
271
474
  # Keep as is, T is already the first dimension
475
+ print("DEBUG: Adding 4D image (TZYX format) to viewer")
272
476
  self.image_layer = self.viewer.add_image(
273
477
  self.original_image,
274
478
  name=f"Image ({os.path.basename(image_path)})",
275
479
  )
480
+ print(f"DEBUG: Added image layer: {self.image_layer}")
276
481
  # Store time dimension info
277
482
  self.time_dim_size = self.original_image.shape[0]
278
483
  self.has_z_dim = True
@@ -294,19 +499,23 @@ class BatchCropAnything:
294
499
  transposed_image # Replace with transposed version
295
500
  )
296
501
 
502
+ print("DEBUG: Adding transposed 4D image to viewer")
297
503
  self.image_layer = self.viewer.add_image(
298
504
  self.original_image,
299
505
  name=f"Image ({os.path.basename(image_path)})",
300
506
  )
507
+ print(f"DEBUG: Added image layer: {self.image_layer}")
301
508
  # Store time dimension info
302
509
  self.time_dim_size = self.original_image.shape[0]
303
510
  self.has_z_dim = True
304
511
  else:
305
512
  # No time dimension found, treat as ZYX
513
+ print("DEBUG: Adding 4D image (ZYX format) to viewer")
306
514
  self.image_layer = self.viewer.add_image(
307
515
  self.original_image,
308
516
  name=f"Image ({os.path.basename(image_path)})",
309
517
  )
518
+ print(f"DEBUG: Added image layer: {self.image_layer}")
310
519
  self.time_dim_size = 1
311
520
  self.has_z_dim = True
312
521
  elif (
@@ -315,30 +524,37 @@ class BatchCropAnything:
315
524
  # Check if first dimension is likely time (> 4, < 400)
316
525
  if 4 < self.original_image.shape[0] < 400:
317
526
  # Likely TYX format
527
+ print("DEBUG: Adding 3D image (TYX format) to viewer")
318
528
  self.image_layer = self.viewer.add_image(
319
529
  self.original_image,
320
530
  name=f"Image ({os.path.basename(image_path)})",
321
531
  )
532
+ print(f"DEBUG: Added image layer: {self.image_layer}")
322
533
  self.time_dim_size = self.original_image.shape[0]
323
534
  self.has_z_dim = False
324
535
  else:
325
536
  # Likely ZYX format or another 3D format
537
+ print("DEBUG: Adding 3D image (ZYX format) to viewer")
326
538
  self.image_layer = self.viewer.add_image(
327
539
  self.original_image,
328
540
  name=f"Image ({os.path.basename(image_path)})",
329
541
  )
542
+ print(f"DEBUG: Added image layer: {self.image_layer}")
330
543
  self.time_dim_size = 1
331
544
  self.has_z_dim = True
332
545
  else:
333
546
  # Should not reach here with use_3d=True, but just in case
547
+ print("DEBUG: Adding 3D image (fallback) to viewer")
334
548
  self.image_layer = self.viewer.add_image(
335
549
  self.original_image,
336
550
  name=f"Image ({os.path.basename(image_path)})",
337
551
  )
552
+ print(f"DEBUG: Added image layer: {self.image_layer}")
338
553
  self.time_dim_size = 1
339
554
  self.has_z_dim = False
340
555
  else:
341
556
  # Handle 2D data as before
557
+ print("DEBUG: Processing 2D image")
342
558
  if self.original_image.dtype != np.uint8:
343
559
  image_for_display = (
344
560
  self.original_image
@@ -349,18 +565,42 @@ class BatchCropAnything:
349
565
  image_for_display = self.original_image
350
566
 
351
567
  # Add image to viewer
568
+ print("DEBUG: Adding 2D image to viewer")
352
569
  self.image_layer = self.viewer.add_image(
353
570
  image_for_display,
354
571
  name=f"Image ({os.path.basename(image_path)})",
355
572
  )
573
+ print(f"DEBUG: Added image layer: {self.image_layer}")
574
+
575
+ # Generate segmentation only if predictor is available
576
+ if self.predictor is not None:
577
+ print("DEBUG: About to generate segmentation")
578
+ self._generate_segmentation(self.original_image, image_path)
579
+ print("DEBUG: Segmentation generation completed")
580
+ else:
581
+ print("DEBUG: Creating empty segmentation (no predictor)")
582
+ # Create empty segmentation when predictor is not available
583
+ if self.use_3d:
584
+ shape = self.original_image.shape
585
+ else:
586
+ shape = self.original_image.shape[:2]
587
+
588
+ self.segmentation_result = np.zeros(shape, dtype=np.uint32)
589
+ self.label_layer = self.viewer.add_labels(
590
+ self.segmentation_result,
591
+ name="No Segmentation (SAM2 not available)",
592
+ )
593
+ print(f"DEBUG: Added empty label layer: {self.label_layer}")
356
594
 
357
- # Generate segmentation
358
- self._generate_segmentation(self.original_image, image_path)
595
+ print("DEBUG: _load_current_image completed successfully")
359
596
 
360
597
  except (FileNotFoundError, ValueError, TypeError, OSError) as e:
361
598
  import traceback
362
599
 
363
- self.viewer.status = f"Error processing image: {str(e)}"
600
+ error_msg = f"Error processing image: {str(e)}"
601
+ self.viewer.status = error_msg
602
+ print(f"DEBUG: Exception in _load_current_image: {error_msg}")
603
+ print("DEBUG: Full traceback:")
364
604
  traceback.print_exc()
365
605
 
366
606
  # Create empty segmentation in case of error
@@ -377,6 +617,7 @@ class BatchCropAnything:
377
617
  self.label_layer = self.viewer.add_labels(
378
618
  self.segmentation_result, name="Error: No Segmentation"
379
619
  )
620
+ print(f"DEBUG: Added error label layer: {self.label_layer}")
380
621
 
381
622
  def _generate_segmentation(self, image, image_path: str):
382
623
  """Generate segmentation for the current image using SAM2."""
@@ -432,7 +673,8 @@ class BatchCropAnything:
432
673
  traceback.print_exc()
433
674
 
434
675
  def _generate_2d_segmentation(self, confidence_threshold):
435
- """Generate 2D segmentation using SAM2 Image Predictor."""
676
+ """Generate initial 2D segmentation - start with empty labels for interactive mode."""
677
+ device_type = "cuda" if self.device.type == "cuda" else "cpu"
436
678
  # Ensure image is in the correct format for SAM2
437
679
  image = self.current_image_for_segmentation
438
680
 
@@ -454,9 +696,7 @@ class BatchCropAnything:
454
696
  (new_height, new_width),
455
697
  anti_aliasing=True,
456
698
  preserve_range=True,
457
- ).astype(
458
- np.float32
459
- ) # Convert to float32
699
+ ).astype(np.float32)
460
700
 
461
701
  self.current_scale_factor = scale_factor
462
702
  else:
@@ -482,73 +722,54 @@ class BatchCropAnything:
482
722
  if resized_image.max() > 1.0:
483
723
  resized_image = resized_image / 255.0
484
724
 
485
- # Set SAM2 prediction parameters based on sensitivity
486
- with torch.inference_mode(), torch.autocast(
487
- "cuda", dtype=torch.float32
488
- ):
489
- # Set the image in the predictor
490
- self.predictor.set_image(resized_image)
725
+ # Store the prepared image for later use
726
+ self.prepared_sam2_image = resized_image
491
727
 
492
- # Use automatic points generation with confidence threshold
493
- masks, scores, _ = self.predictor.predict(
494
- point_coords=None,
495
- point_labels=None,
496
- box=None,
497
- multimask_output=True,
498
- )
728
+ # Initialize empty segmentation result
729
+ self.segmentation_result = np.zeros(orig_shape, dtype=np.uint32)
730
+ self.label_info = {}
499
731
 
500
- # Filter masks by confidence threshold
501
- valid_masks = scores > confidence_threshold
502
- masks = masks[valid_masks]
503
- scores = scores[valid_masks]
504
-
505
- # Convert masks to label image
506
- labels = np.zeros(resized_image.shape[:2], dtype=np.uint32)
507
- self.label_info = {} # Reset label info
508
-
509
- for i, mask in enumerate(masks):
510
- label_id = i + 1 # Start label IDs from 1
511
- labels[mask] = label_id
512
-
513
- # Calculate label information
514
- area = np.sum(mask)
515
- y_indices, x_indices = np.where(mask)
516
- center_y = np.mean(y_indices) if len(y_indices) > 0 else 0
517
- center_x = np.mean(x_indices) if len(x_indices) > 0 else 0
518
-
519
- # Store label info
520
- self.label_info[label_id] = {
521
- "area": area,
522
- "center_y": center_y,
523
- "center_x": center_x,
524
- "score": float(scores[i]),
525
- }
526
-
527
- # Handle upscaling if needed
528
- if self.current_scale_factor < 1.0:
529
- labels = resize(
530
- labels,
531
- orig_shape,
532
- order=0, # Nearest neighbor interpolation
533
- preserve_range=True,
534
- anti_aliasing=False,
535
- ).astype(np.uint32)
536
-
537
- # Sort labels by area (largest first)
538
- self.label_info = dict(
539
- sorted(
540
- self.label_info.items(),
541
- key=lambda item: item[1]["area"],
542
- reverse=True,
543
- )
732
+ # Initialize tracking for interactive segmentation
733
+ self.current_points = []
734
+ self.current_labels = []
735
+ self.current_obj_id = 1
736
+ self.next_obj_id = 1
737
+
738
+ # Initialize object tracking dictionaries
739
+ self.obj_points = {}
740
+ self.obj_labels = {}
741
+
742
+ # Reset SAM2-specific tracking dictionaries for 2D mode
743
+ self.sam2_points_by_obj = {}
744
+ self.sam2_labels_by_obj = {}
745
+ self._sam2_next_obj_id = 1
746
+ print(
747
+ "DEBUG: Reset _sam2_next_obj_id to 1 in _generate_2d_segmentation"
544
748
  )
545
749
 
546
- # Save segmentation result
547
- self.segmentation_result = labels
750
+ # Set the image in the predictor for later use (2D mode only)
751
+ device_type = "cuda" if self.device.type == "cuda" else "cpu"
752
+ if hasattr(self.predictor, "set_image"):
753
+ with (
754
+ torch.inference_mode(),
755
+ torch.autocast(device_type, dtype=torch.float32),
756
+ ):
757
+ self.predictor.set_image(resized_image)
758
+ else:
759
+ print(
760
+ "DEBUG: Skipping set_image - predictor doesn't support it (likely VideoPredictor)"
761
+ )
548
762
 
549
763
  # Update the label layer
550
764
  self._update_label_layer()
551
765
 
766
+ # Show instructions
767
+ self.viewer.status = (
768
+ "2D Mode: Click on the image to add objects. Use Shift+click for negative points to refine. "
769
+ "Click existing objects to select them for cropping. "
770
+ "Note: For stacks, interactive segmentation only works in 2D view mode."
771
+ )
772
+
552
773
  def _generate_3d_segmentation(self, confidence_threshold, image_path):
553
774
  """
554
775
  Initialize 3D segmentation using SAM2 Video Predictor.
@@ -569,9 +790,7 @@ class BatchCropAnything:
569
790
  import tempfile
570
791
 
571
792
  temp_dir = tempfile.gettempdir()
572
- mp4_path = os.path.join(
573
- temp_dir, f"temp_volume_{os.path.basename(image_path)}.mp4"
574
- )
793
+ mp4_path = None
575
794
 
576
795
  # If we need to save a modified version for MP4 conversion
577
796
  need_temp_tif = False
@@ -601,31 +820,72 @@ class BatchCropAnything:
601
820
  imwrite(temp_tif_path, projected_volume)
602
821
  need_temp_tif = True
603
822
 
604
- # Convert the projected TIF to MP4
605
- self.viewer.status = (
606
- "Converting projected 3D volume to MP4 format for SAM2..."
607
- )
608
- mp4_path = tif_to_mp4(temp_tif_path)
823
+ # Check if MP4 already exists
824
+ expected_mp4 = str(Path(temp_tif_path).with_suffix(".mp4"))
825
+ if os.path.exists(expected_mp4):
826
+ self.viewer.status = (
827
+ f"Using existing MP4: {os.path.basename(expected_mp4)}"
828
+ )
829
+ print(
830
+ f"DEBUG: MP4 already exists, skipping conversion: {expected_mp4}"
831
+ )
832
+ mp4_path = expected_mp4
833
+ else:
834
+ # Convert the projected TIF to MP4
835
+ self.viewer.status = "Converting projected 3D volume to MP4 format for SAM2..."
836
+ mp4_path = tif_to_mp4(temp_tif_path)
609
837
  else:
610
- # Convert original volume to video format for SAM2
611
- self.viewer.status = (
612
- "Converting 3D volume to MP4 format for SAM2..."
613
- )
614
- mp4_path = tif_to_mp4(image_path)
838
+ # Check if MP4 already exists for the original image
839
+ expected_mp4 = str(Path(image_path).with_suffix(".mp4"))
840
+ if os.path.exists(expected_mp4):
841
+ self.viewer.status = (
842
+ f"Using existing MP4: {os.path.basename(expected_mp4)}"
843
+ )
844
+ print(
845
+ f"DEBUG: MP4 already exists, skipping conversion: {expected_mp4}"
846
+ )
847
+ mp4_path = expected_mp4
848
+ else:
849
+ # Convert original volume to video format for SAM2
850
+ self.viewer.status = (
851
+ "Converting 3D volume to MP4 format for SAM2..."
852
+ )
853
+ mp4_path = tif_to_mp4(image_path)
615
854
 
616
855
  # Initialize SAM2 state with the video
617
856
  self.viewer.status = "Initializing SAM2 Video Predictor..."
618
- with torch.inference_mode(), torch.autocast(
619
- "cuda", dtype=torch.bfloat16
620
- ):
621
- self._sam2_state = self.predictor.init_state(mp4_path)
857
+ try:
858
+ device_type = "cuda" if self.device.type == "cuda" else "cpu"
859
+ with (
860
+ torch.inference_mode(),
861
+ torch.autocast(device_type, dtype=torch.float32),
862
+ ):
863
+ self._sam2_state = self.predictor.init_state(mp4_path)
864
+ except (
865
+ RuntimeError,
866
+ ValueError,
867
+ TypeError,
868
+ torch.cuda.OutOfMemoryError,
869
+ ) as e:
870
+ self.viewer.status = (
871
+ f"Error initializing SAM2 video predictor: {str(e)}"
872
+ )
873
+ print(f"SAM2 video predictor initialization failed: {e}")
874
+ return
622
875
 
623
876
  # Store needed state for 3D processing
624
877
  self._sam2_next_obj_id = 1
878
+ print(
879
+ "DEBUG: Reset _sam2_next_obj_id to 1 in _generate_3d_segmentation"
880
+ )
625
881
  self._sam2_prompts = (
626
882
  {}
627
883
  ) # Store prompts for each object (points, labels, box)
628
884
 
885
+ # Reset SAM2-specific tracking dictionaries for 3D mode
886
+ self.sam2_points_by_obj = {}
887
+ self.sam2_labels_by_obj = {}
888
+
629
889
  # Update the label layer with empty segmentation
630
890
  self._update_label_layer()
631
891
 
@@ -633,8 +893,10 @@ class BatchCropAnything:
633
893
  if self.label_layer is not None and hasattr(
634
894
  self.label_layer, "mouse_drag_callbacks"
635
895
  ):
896
+ # Safely remove all existing callbacks
636
897
  for callback in list(self.label_layer.mouse_drag_callbacks):
637
- self.label_layer.mouse_drag_callbacks.remove(callback)
898
+ with contextlib.suppress(ValueError):
899
+ self.label_layer.mouse_drag_callbacks.remove(callback)
638
900
 
639
901
  # Add 3D-specific click handler
640
902
  self.label_layer.mouse_drag_callbacks.append(
@@ -658,8 +920,8 @@ class BatchCropAnything:
658
920
 
659
921
  # Show instructions
660
922
  self.viewer.status = (
661
- "3D Mode active: Navigate to the first frame where object appears, then click. "
662
- "Use Shift+click for negative points (to remove areas). "
923
+ "3D Mode active: IMPORTANT - Navigate to the FIRST SLICE where object appears (using slider), "
924
+ "then click on object in 2D view (not 3D view). Use Shift+click for negative points. "
663
925
  "Segmentation will be propagated to all frames automatically."
664
926
  )
665
927
 
@@ -713,6 +975,9 @@ class BatchCropAnything:
713
975
  # Create new object for positive points on background
714
976
  ann_obj_id = self._sam2_next_obj_id
715
977
  if point_label > 0 and label_id == 0:
978
+ print(
979
+ f"DEBUG: Incrementing _sam2_next_obj_id from {self._sam2_next_obj_id} to {self._sam2_next_obj_id + 1}"
980
+ )
716
981
  self._sam2_next_obj_id += 1
717
982
 
718
983
  # Find or create points layer for this object
@@ -733,6 +998,15 @@ class BatchCropAnything:
733
998
  border_width=1,
734
999
  opacity=0.8,
735
1000
  )
1001
+
1002
+ with contextlib.suppress(AttributeError, ValueError):
1003
+ points_layer.mouse_drag_callbacks.remove(
1004
+ self._on_points_clicked
1005
+ )
1006
+ points_layer.mouse_drag_callbacks.append(
1007
+ self._on_points_clicked
1008
+ )
1009
+
736
1010
  # Initialize points for this object
737
1011
  if not hasattr(self, "sam2_points_by_obj"):
738
1012
  self.sam2_points_by_obj = {}
@@ -891,8 +1165,10 @@ class BatchCropAnything:
891
1165
  # Try to perform SAM2 propagation with error handling
892
1166
  try:
893
1167
  # Use torch.inference_mode() and torch.autocast to ensure consistent dtypes
894
- with torch.inference_mode(), torch.autocast(
895
- "cuda", dtype=torch.float32
1168
+ device_type = "cuda" if self.device.type == "cuda" else "cpu"
1169
+ with (
1170
+ torch.inference_mode(),
1171
+ torch.autocast(device_type, dtype=torch.float32),
896
1172
  ):
897
1173
  # Attempt to run SAM2 propagation - this will iterate through all frames
898
1174
  for (
@@ -988,7 +1264,11 @@ class BatchCropAnything:
988
1264
  time.sleep(2)
989
1265
  for layer in list(self.viewer.layers):
990
1266
  if "Propagation Progress" in layer.name:
991
- self.viewer.layers.remove(layer)
1267
+ # Clean up callbacks before removing the layer to prevent cleanup issues
1268
+ if hasattr(layer, "mouse_drag_callbacks"):
1269
+ layer.mouse_drag_callbacks.clear()
1270
+ with contextlib.suppress(ValueError):
1271
+ self.viewer.layers.remove(layer)
992
1272
 
993
1273
  threading.Thread(target=remove_progress).start()
994
1274
 
@@ -1011,6 +1291,7 @@ class BatchCropAnything:
1011
1291
  Given a 3D coordinate (x, y, z), run SAM2 video predictor to segment the object at that point,
1012
1292
  update the segmentation result and label layer.
1013
1293
  """
1294
+ device_type = "cuda" if self.device.type == "cuda" else "cpu"
1014
1295
  if not hasattr(self, "_sam2_state") or self._sam2_state is None:
1015
1296
  self.viewer.status = "SAM2 3D state not initialized."
1016
1297
  return
@@ -1024,8 +1305,9 @@ class BatchCropAnything:
1024
1305
  point_coords = np.array([[x, y, z]])
1025
1306
  point_labels = np.array([1]) # 1 = foreground
1026
1307
 
1027
- with torch.inference_mode(), torch.autocast(
1028
- "cuda", dtype=torch.bfloat16
1308
+ with (
1309
+ torch.inference_mode(),
1310
+ torch.autocast(device_type, dtype=torch.float32),
1029
1311
  ):
1030
1312
  masks, scores, _ = self.predictor.predict(
1031
1313
  state=self._sam2_state,
@@ -1079,7 +1361,11 @@ class BatchCropAnything:
1079
1361
  # Remove existing label layer if it exists
1080
1362
  for layer in list(self.viewer.layers):
1081
1363
  if isinstance(layer, Labels) and "Segmentation" in layer.name:
1082
- self.viewer.layers.remove(layer)
1364
+ # Clean up callbacks before removing the layer to prevent cleanup issues
1365
+ if hasattr(layer, "mouse_drag_callbacks"):
1366
+ layer.mouse_drag_callbacks.clear()
1367
+ with contextlib.suppress(ValueError):
1368
+ self.viewer.layers.remove(layer)
1083
1369
 
1084
1370
  # Add label layer to viewer
1085
1371
  self.label_layer = self.viewer.add_labels(
@@ -1088,10 +1374,36 @@ class BatchCropAnything:
1088
1374
  opacity=0.7,
1089
1375
  )
1090
1376
 
1091
- # Create points layer for interaction if it doesn't exist
1377
+ # Connect click handler to the label layer for selection and deletion
1378
+ if hasattr(self.label_layer, "mouse_drag_callbacks"):
1379
+ # Clear existing callbacks to avoid duplicates
1380
+ self.label_layer.mouse_drag_callbacks.clear()
1381
+ # Add our click handler
1382
+ self.label_layer.mouse_drag_callbacks.append(
1383
+ self._on_label_clicked
1384
+ )
1385
+
1386
+ # Create or update interaction layers based on mode
1387
+ if self.prompt_mode == "point":
1388
+ self._ensure_points_layer()
1389
+ self._remove_shapes_layer()
1390
+ else: # box mode
1391
+ self._ensure_shapes_layer()
1392
+ self._remove_points_layer()
1393
+
1394
+ # Update status
1395
+ n_labels = len(np.unique(self.segmentation_result)) - (
1396
+ 1 if 0 in np.unique(self.segmentation_result) else 0
1397
+ )
1398
+ self.viewer.status = f"Loaded image {self.current_index + 1}/{len(self.images)} - Found {n_labels} segments"
1399
+
1400
+ def _ensure_points_layer(self):
1401
+ """Ensure points layer exists and is properly configured."""
1092
1402
  points_layer = None
1093
1403
  for layer in list(self.viewer.layers):
1094
- if "Points" in layer.name:
1404
+ if (
1405
+ "Points" in layer.name and "Object" not in layer.name
1406
+ ): # Main points layer
1095
1407
  points_layer = layer
1096
1408
  break
1097
1409
 
@@ -1108,24 +1420,424 @@ class BatchCropAnything:
1108
1420
  )
1109
1421
 
1110
1422
  # Connect points layer mouse click event
1111
- points_layer.mouse_drag_callbacks.append(self._on_points_clicked)
1423
+ if hasattr(points_layer, "mouse_drag_callbacks"):
1424
+ points_layer.mouse_drag_callbacks.clear()
1425
+ points_layer.mouse_drag_callbacks.append(
1426
+ self._on_points_clicked
1427
+ )
1112
1428
 
1113
1429
  # Make the points layer active to encourage interaction with it
1114
1430
  self.viewer.layers.selection.active = points_layer
1115
1431
 
1116
- # Update status
1117
- n_labels = len(np.unique(self.segmentation_result)) - (
1118
- 1 if 0 in np.unique(self.segmentation_result) else 0
1119
- )
1120
- self.viewer.status = f"Loaded image {self.current_index + 1}/{len(self.images)} - Found {n_labels} segments"
1432
+ def _ensure_shapes_layer(self):
1433
+ """Ensure shapes layer exists and is properly configured."""
1434
+ shapes_layer = None
1435
+ for layer in list(self.viewer.layers):
1436
+ if "Rectangles" in layer.name:
1437
+ shapes_layer = layer
1438
+ break
1439
+
1440
+ if shapes_layer is None:
1441
+ # Initialize an empty shapes layer
1442
+ shapes_layer = self.viewer.add_shapes(
1443
+ None,
1444
+ shape_type="rectangle",
1445
+ edge_width=3,
1446
+ edge_color="green",
1447
+ face_color="transparent",
1448
+ name="Rectangles (Draw to Segment)",
1449
+ )
1450
+
1451
+ # Store reference
1452
+ self.shapes_layer = shapes_layer
1453
+
1454
+ # Initialize processing flag to prevent re-entry
1455
+ if not hasattr(self, "_processing_rectangle"):
1456
+ self._processing_rectangle = False
1457
+
1458
+ # Always ensure the event is connected (disconnect old ones first to avoid duplicates)
1459
+ # Remove any existing callbacks
1460
+ with contextlib.suppress(Exception):
1461
+ shapes_layer.events.data.disconnect()
1462
+
1463
+ # Connect shape added event
1464
+ @shapes_layer.events.data.connect
1465
+ def on_shape_added(event):
1466
+ print(
1467
+ f"DEBUG: Shape event triggered! Shapes: {len(shapes_layer.data)}, Processing: {self._processing_rectangle}"
1468
+ )
1469
+
1470
+ # Ignore if we're already processing or if there are no shapes
1471
+ if self._processing_rectangle:
1472
+ print("DEBUG: Already processing a rectangle, ignoring event")
1473
+ return
1474
+
1475
+ if len(shapes_layer.data) == 0:
1476
+ print("DEBUG: No shapes present, ignoring event")
1477
+ return
1478
+
1479
+ # Only process if we have exactly 1 shape (newly drawn)
1480
+ if len(shapes_layer.data) == 1:
1481
+ print("DEBUG: New shape detected, processing...")
1482
+ # Set flag to prevent re-entry
1483
+ self._processing_rectangle = True
1484
+ try:
1485
+ # Get the shape
1486
+ self._on_rectangle_added(shapes_layer.data[-1])
1487
+ finally:
1488
+ # Always reset flag
1489
+ self._processing_rectangle = False
1490
+ else:
1491
+ print(
1492
+ f"DEBUG: Multiple shapes present ({len(shapes_layer.data)}), skipping"
1493
+ )
1494
+
1495
+ # Make the shapes layer active
1496
+ self.viewer.layers.selection.active = shapes_layer
1497
+
1498
+ def _remove_points_layer(self):
1499
+ """Remove points layer when not in point mode."""
1500
+ for layer in list(self.viewer.layers):
1501
+ if "Points" in layer.name and "Object" not in layer.name:
1502
+ if hasattr(layer, "mouse_drag_callbacks"):
1503
+ layer.mouse_drag_callbacks.clear()
1504
+ with contextlib.suppress(ValueError):
1505
+ self.viewer.layers.remove(layer)
1506
+
1507
+ def _remove_shapes_layer(self):
1508
+ """Remove shapes layer when not in box mode."""
1509
+ for layer in list(self.viewer.layers):
1510
+ if "Rectangles" in layer.name:
1511
+ with contextlib.suppress(ValueError):
1512
+ self.viewer.layers.remove(layer)
1513
+ self.shapes_layer = None
1514
+
1515
+ def _on_rectangle_added(self, rectangle_coords):
1516
+ """Handle rectangle selection for segmentation."""
1517
+ print("DEBUG: _on_rectangle_added called!")
1518
+ device_type = "cuda" if self.device.type == "cuda" else "cpu"
1519
+ try:
1520
+ # Rectangle coords are in the form of a 4x2 or 4x3 array (corners)
1521
+ # Convert to bounding box format [x_min, y_min, x_max, y_max]
1522
+
1523
+ # Debug info
1524
+ print(f"DEBUG: Rectangle coords: {rectangle_coords}")
1525
+ print(f"DEBUG: Rectangle coords shape: {rectangle_coords.shape}")
1526
+ print(f"DEBUG: use_3d flag: {self.use_3d}")
1527
+ print(
1528
+ f"DEBUG: Has predictor: {hasattr(self, 'predictor') and self.predictor is not None}"
1529
+ )
1530
+ if hasattr(self, "predictor") and self.predictor is not None:
1531
+ print(
1532
+ f"DEBUG: Predictor type: {type(self.predictor).__name__}"
1533
+ )
1534
+ else:
1535
+ print("DEBUG: No predictor available!")
1536
+ self.viewer.status = "Error: Predictor not initialized"
1537
+ return
1538
+
1539
+ # Check if we're in 3D mode (use the flag, not coordinate shape)
1540
+ # In 3D mode, even when drawing on a 2D slice, we get (4, 2) coords
1541
+ # but we need to treat it as 3D with propagation
1542
+ if (
1543
+ self.use_3d
1544
+ and len(rectangle_coords.shape) == 2
1545
+ and rectangle_coords.shape[0] == 4
1546
+ ):
1547
+ print("DEBUG: Processing as 3D rectangle (will propagate)")
1548
+
1549
+ # Get current frame/slice
1550
+ t = int(self.viewer.dims.current_step[0])
1551
+ print(f"DEBUG: Current frame/slice: {t}")
1552
+
1553
+ # Get Y and X bounds from 2D coordinates
1554
+ if rectangle_coords.shape[1] == 3:
1555
+ # If we somehow got 3D coords (T/Z, Y, X)
1556
+ y_coords = rectangle_coords[:, 1]
1557
+ x_coords = rectangle_coords[:, 2]
1558
+ elif rectangle_coords.shape[1] == 2:
1559
+ # More common: 2D coords (Y, X) when drawing on a slice
1560
+ y_coords = rectangle_coords[:, 0]
1561
+ x_coords = rectangle_coords[:, 1]
1562
+ else:
1563
+ print(
1564
+ f"DEBUG: Unexpected coordinate dimensions: {rectangle_coords.shape[1]}"
1565
+ )
1566
+ self.viewer.status = "Error: Unexpected rectangle format"
1567
+ return
1568
+
1569
+ y_min, y_max = int(min(y_coords)), int(max(y_coords))
1570
+ x_min, x_max = int(min(x_coords)), int(max(x_coords))
1571
+
1572
+ box = np.array([x_min, y_min, x_max, y_max], dtype=np.float32)
1573
+ print(f"DEBUG: Box coordinates: {box}")
1574
+
1575
+ # Use SAM2 with box prompt - use _sam2_next_obj_id for 3D mode
1576
+ if not hasattr(self, "_sam2_next_obj_id"):
1577
+ self._sam2_next_obj_id = 1
1578
+ obj_id = self._sam2_next_obj_id
1579
+ self._sam2_next_obj_id += 1
1580
+ print(
1581
+ f"DEBUG: Box mode - using object ID {obj_id}, next will be {self._sam2_next_obj_id}"
1582
+ )
1583
+
1584
+ # Store box for this object
1585
+ if not hasattr(self, "obj_boxes"):
1586
+ self.obj_boxes = {}
1587
+ self.obj_boxes[obj_id] = box
1588
+
1589
+ # Perform segmentation with 3D propagation
1590
+ if (
1591
+ hasattr(self, "_sam2_state")
1592
+ and self._sam2_state is not None
1593
+ ):
1594
+ self.viewer.status = (
1595
+ f"Segmenting object {obj_id} with box at frame {t}..."
1596
+ )
1597
+ print(f"DEBUG: Starting segmentation for object {obj_id}")
1598
+
1599
+ _, out_obj_ids, out_mask_logits = (
1600
+ self.predictor.add_new_points_or_box(
1601
+ inference_state=self._sam2_state,
1602
+ frame_idx=t,
1603
+ obj_id=obj_id,
1604
+ box=box,
1605
+ )
1606
+ )
1607
+
1608
+ print("DEBUG: Segmentation complete, processing mask")
1609
+ # Update current frame
1610
+ mask = (out_mask_logits[0] > 0.0).cpu().numpy()
1611
+ if mask.ndim > 2:
1612
+ mask = mask.squeeze()
1613
+
1614
+ # Resize if needed
1615
+ if mask.shape != self.segmentation_result[t].shape:
1616
+ from skimage.transform import resize
1617
+
1618
+ mask = resize(
1619
+ mask.astype(float),
1620
+ self.segmentation_result[t].shape,
1621
+ order=0,
1622
+ preserve_range=True,
1623
+ anti_aliasing=False,
1624
+ ).astype(bool)
1625
+
1626
+ # Update segmentation
1627
+ self.segmentation_result[t][
1628
+ mask & (self.segmentation_result[t] == 0)
1629
+ ] = obj_id
1630
+
1631
+ print(f"DEBUG: Starting propagation for object {obj_id}")
1632
+ # Propagate to all frames
1633
+ self._propagate_mask_for_current_object(obj_id, t)
1634
+
1635
+ # Update UI
1636
+ print("DEBUG: Updating label layer")
1637
+ self._update_label_layer()
1638
+ if (
1639
+ hasattr(self, "label_table_widget")
1640
+ and self.label_table_widget is not None
1641
+ ):
1642
+ self._populate_label_table(self.label_table_widget)
1643
+
1644
+ self.viewer.status = (
1645
+ f"Segmented and propagated object {obj_id} from box"
1646
+ )
1647
+ print("DEBUG: Rectangle processing complete!")
1648
+
1649
+ # Keep the rectangle visible after processing
1650
+ # Users can manually delete it if needed
1651
+ # if self.shapes_layer is not None:
1652
+ # self.shapes_layer.data = []
1653
+ else:
1654
+ print("DEBUG: _sam2_state not available")
1655
+ self.viewer.status = (
1656
+ "Error: 3D segmentation state not initialized"
1657
+ )
1658
+
1659
+ elif (
1660
+ not self.use_3d
1661
+ and len(rectangle_coords.shape) == 2
1662
+ and rectangle_coords.shape[1] == 2
1663
+ ):
1664
+ # 2D case: rectangle_coords shape is (4, 2) for Y, X
1665
+ if rectangle_coords.shape[0] == 4:
1666
+ # Get Y and X bounds
1667
+ y_coords = rectangle_coords[:, 0]
1668
+ x_coords = rectangle_coords[:, 1]
1669
+ y_min, y_max = int(min(y_coords)), int(max(y_coords))
1670
+ x_min, x_max = int(min(x_coords)), int(max(x_coords))
1671
+
1672
+ box = np.array(
1673
+ [x_min, y_min, x_max, y_max], dtype=np.float32
1674
+ )
1675
+
1676
+ # Use SAM2 with box prompt - use next_obj_id for 2D mode
1677
+ if not hasattr(self, "next_obj_id"):
1678
+ self.next_obj_id = 1
1679
+ obj_id = self.next_obj_id
1680
+ self.next_obj_id += 1
1681
+ print(
1682
+ f"DEBUG: 2D Box mode - using object ID {obj_id}, next will be {self.next_obj_id}"
1683
+ )
1684
+
1685
+ # Store box for this object
1686
+ if not hasattr(self, "obj_boxes"):
1687
+ self.obj_boxes = {}
1688
+ self.obj_boxes[obj_id] = box
1689
+
1690
+ # Perform segmentation
1691
+ if (
1692
+ hasattr(self, "predictor")
1693
+ and self.predictor is not None
1694
+ ):
1695
+ # Make sure image is loaded
1696
+ if self.current_image_for_segmentation is None:
1697
+ self.viewer.status = (
1698
+ "No image loaded for segmentation"
1699
+ )
1700
+ return
1701
+
1702
+ # Prepare image for SAM2
1703
+ image = self.current_image_for_segmentation
1704
+ if len(image.shape) == 2:
1705
+ image = np.stack([image] * 3, axis=-1)
1706
+ elif len(image.shape) == 3 and image.shape[2] == 1:
1707
+ image = np.concatenate([image] * 3, axis=2)
1708
+ elif len(image.shape) == 3 and image.shape[2] > 3:
1709
+ image = image[:, :, :3]
1710
+
1711
+ if image.dtype != np.uint8:
1712
+ image = (image / np.max(image) * 255).astype(
1713
+ np.uint8
1714
+ )
1715
+
1716
+ # Set the image in the predictor (only for ImagePredictor, not VideoPredictor)
1717
+ if hasattr(self.predictor, "set_image"):
1718
+ self.predictor.set_image(image)
1719
+ else:
1720
+ self.viewer.status = "Error: Rectangle mode requires Image Predictor (2D mode)"
1721
+ return
1722
+
1723
+ self.viewer.status = (
1724
+ f"Segmenting object {obj_id} with box..."
1725
+ )
1726
+
1727
+ with (
1728
+ torch.inference_mode(),
1729
+ torch.autocast(device_type),
1730
+ ):
1731
+ masks, scores, _ = self.predictor.predict(
1732
+ box=box,
1733
+ multimask_output=False,
1734
+ )
1735
+
1736
+ # Get the mask
1737
+ if len(masks) > 0:
1738
+ best_mask = masks[0]
1739
+
1740
+ # Resize if needed
1741
+ if (
1742
+ best_mask.shape
1743
+ != self.segmentation_result.shape
1744
+ ):
1745
+ from skimage.transform import resize
1746
+
1747
+ best_mask = resize(
1748
+ best_mask.astype(float),
1749
+ self.segmentation_result.shape,
1750
+ order=0,
1751
+ preserve_range=True,
1752
+ anti_aliasing=False,
1753
+ ).astype(bool)
1754
+
1755
+ # Apply mask (only overwrite background)
1756
+ mask_condition = np.logical_and(
1757
+ best_mask, (self.segmentation_result == 0)
1758
+ )
1759
+ self.segmentation_result[mask_condition] = (
1760
+ obj_id
1761
+ )
1762
+
1763
+ # Update label info
1764
+ area = np.sum(
1765
+ self.segmentation_result == obj_id
1766
+ )
1767
+ y_indices, x_indices = np.where(
1768
+ self.segmentation_result == obj_id
1769
+ )
1770
+ center_y = (
1771
+ np.mean(y_indices)
1772
+ if len(y_indices) > 0
1773
+ else 0
1774
+ )
1775
+ center_x = (
1776
+ np.mean(x_indices)
1777
+ if len(x_indices) > 0
1778
+ else 0
1779
+ )
1780
+
1781
+ self.label_info[obj_id] = {
1782
+ "area": area,
1783
+ "center_y": center_y,
1784
+ "center_x": center_x,
1785
+ "score": float(scores[0]),
1786
+ }
1787
+
1788
+ self.viewer.status = (
1789
+ f"Segmented object {obj_id} from box"
1790
+ )
1791
+ else:
1792
+ self.viewer.status = "No valid mask produced"
1793
+
1794
+ # Update the UI
1795
+ self._update_label_layer()
1796
+ if (
1797
+ hasattr(self, "label_table_widget")
1798
+ and self.label_table_widget is not None
1799
+ ):
1800
+ self._populate_label_table(self.label_table_widget)
1801
+
1802
+ # Keep the rectangle visible after processing
1803
+ # Users can manually delete it if needed
1804
+ # if self.shapes_layer is not None:
1805
+ # self.shapes_layer.data = []
1806
+ else:
1807
+ # Unexpected shape dimensions
1808
+ print(
1809
+ f"DEBUG: Unexpected rectangle shape: {rectangle_coords.shape}"
1810
+ )
1811
+ self.viewer.status = f"Error: Unexpected rectangle dimensions {rectangle_coords.shape}. Expected (4,2) for 2D or (4,3) for 3D."
1812
+
1813
+ except (
1814
+ IndexError,
1815
+ KeyError,
1816
+ ValueError,
1817
+ RuntimeError,
1818
+ TypeError,
1819
+ ) as e:
1820
+ import traceback
1821
+
1822
+ self.viewer.status = f"Error in rectangle handling: {str(e)}"
1823
+ print("DEBUG: Exception in _on_rectangle_added:")
1824
+ traceback.print_exc()
1121
1825
 
1122
1826
  def _on_points_clicked(self, layer, event):
1123
1827
  """Handle clicks on the points layer for adding/removing points."""
1828
+ device_type = "cuda" if self.device.type == "cuda" else "cpu"
1124
1829
  try:
1125
1830
  # Only process clicks, not drags
1126
1831
  if event.type != "mouse_press":
1127
1832
  return
1128
1833
 
1834
+ # Check if segmentation result exists
1835
+ if self.segmentation_result is None:
1836
+ self.viewer.status = (
1837
+ "Segmentation not ready. Please wait for image to load."
1838
+ )
1839
+ return
1840
+
1129
1841
  # Get coordinates of mouse click
1130
1842
  coords = np.round(event.position).astype(int)
1131
1843
 
@@ -1163,6 +1875,25 @@ class BatchCropAnything:
1163
1875
  colors.append("red" if is_negative else "green")
1164
1876
  layer.face_color = colors
1165
1877
 
1878
+ # Validate coordinates are within segmentation bounds
1879
+ if (
1880
+ t < 0
1881
+ or t >= self.segmentation_result.shape[0]
1882
+ or y < 0
1883
+ or y >= self.segmentation_result.shape[1]
1884
+ or x < 0
1885
+ or x >= self.segmentation_result.shape[2]
1886
+ ):
1887
+ self.viewer.status = (
1888
+ f"Click at ({t}, {y}, {x}) is out of bounds for "
1889
+ f"segmentation shape {self.segmentation_result.shape}. "
1890
+ f"Please click within the image bounds."
1891
+ )
1892
+ # Remove the invalid point that was just added
1893
+ if len(layer.data) > 0:
1894
+ layer.data = layer.data[:-1]
1895
+ return
1896
+
1166
1897
  # Get the object ID
1167
1898
  # If clicking on existing segmentation with negative point
1168
1899
  label_id = self.segmentation_result[t, y, x]
@@ -1366,7 +2097,11 @@ class BatchCropAnything:
1366
2097
  time.sleep(2)
1367
2098
  for layer in list(self.viewer.layers):
1368
2099
  if "Propagation Progress" in layer.name:
1369
- self.viewer.layers.remove(layer)
2100
+ # Clean up callbacks before removing the layer to prevent cleanup issues
2101
+ if hasattr(layer, "mouse_drag_callbacks"):
2102
+ layer.mouse_drag_callbacks.clear()
2103
+ with contextlib.suppress(ValueError):
2104
+ self.viewer.layers.remove(layer)
1370
2105
 
1371
2106
  threading.Thread(target=remove_progress).start()
1372
2107
 
@@ -1407,6 +2142,23 @@ class BatchCropAnything:
1407
2142
  colors.append("red" if is_negative else "green")
1408
2143
  layer.face_color = colors
1409
2144
 
2145
+ # Validate coordinates are within segmentation bounds
2146
+ if (
2147
+ y < 0
2148
+ or y >= self.segmentation_result.shape[0]
2149
+ or x < 0
2150
+ or x >= self.segmentation_result.shape[1]
2151
+ ):
2152
+ self.viewer.status = (
2153
+ f"Click at ({y}, {x}) is out of bounds for "
2154
+ f"segmentation shape {self.segmentation_result.shape}. "
2155
+ f"Please click within the image bounds."
2156
+ )
2157
+ # Remove the invalid point that was just added
2158
+ if len(layer.data) > 0:
2159
+ layer.data = layer.data[:-1]
2160
+ return
2161
+
1410
2162
  # Get object ID
1411
2163
  label_id = self.segmentation_result[y, x]
1412
2164
  if is_negative and label_id > 0:
@@ -1451,8 +2203,14 @@ class BatchCropAnything:
1451
2203
  if image.dtype != np.uint8:
1452
2204
  image = (image / np.max(image) * 255).astype(np.uint8)
1453
2205
 
1454
- # Set the image in the predictor
1455
- self.predictor.set_image(image)
2206
+ # Set the image in the predictor (only for ImagePredictor, not VideoPredictor)
2207
+ if hasattr(self.predictor, "set_image"):
2208
+ self.predictor.set_image(image)
2209
+ else:
2210
+ self.viewer.status = (
2211
+ "Error: Point mode in 2D requires Image Predictor"
2212
+ )
2213
+ return
1456
2214
 
1457
2215
  # Use only points for current object
1458
2216
  points = np.array(
@@ -1462,7 +2220,7 @@ class BatchCropAnything:
1462
2220
 
1463
2221
  self.viewer.status = f"Segmenting object {obj_id} with {len(points)} points..."
1464
2222
 
1465
- with torch.inference_mode(), torch.autocast("cuda"):
2223
+ with torch.inference_mode(), torch.autocast(device_type):
1466
2224
  masks, scores, _ = self.predictor.predict(
1467
2225
  point_coords=points,
1468
2226
  point_labels=labels,
@@ -1551,16 +2309,23 @@ class BatchCropAnything:
1551
2309
  def _on_label_clicked(self, layer, event):
1552
2310
  """Handle label selection and user prompts on mouse click."""
1553
2311
  try:
1554
- # Only process clicks, not drags
2312
+ # Only process mouse press events
1555
2313
  if event.type != "mouse_press":
1556
2314
  return
1557
2315
 
2316
+ # Only handle left mouse button
2317
+ if event.button != 1:
2318
+ return
2319
+
1558
2320
  # Get coordinates of mouse click
1559
2321
  coords = np.round(event.position).astype(int)
1560
2322
 
1561
- # Check if Shift is pressed (negative point)
2323
+ # Check modifiers
1562
2324
  is_negative = "Shift" in event.modifiers
1563
- point_label = -1 if is_negative else 1
2325
+ is_control = (
2326
+ "Control" in event.modifiers or "Ctrl" in event.modifiers
2327
+ )
2328
+ # point_label = -1 if is_negative else 1
1564
2329
 
1565
2330
  # For 2D data
1566
2331
  if not self.use_3d:
@@ -1581,254 +2346,13 @@ class BatchCropAnything:
1581
2346
  # Get the label ID at the clicked position
1582
2347
  label_id = self.segmentation_result[y, x]
1583
2348
 
1584
- # Initialize a unique object ID for this click (if needed)
1585
- if not hasattr(self, "next_obj_id"):
1586
- # Start with highest existing ID + 1
1587
- if self.segmentation_result.max() > 0:
1588
- self.next_obj_id = (
1589
- int(self.segmentation_result.max()) + 1
1590
- )
1591
- else:
1592
- self.next_obj_id = 1
1593
-
1594
- # If clicking on background or using negative click, handle segmentation
1595
- if label_id == 0 or is_negative:
1596
- # Find or create points layer for the current object we're working on
1597
- current_obj_id = None
1598
-
1599
- # If negative point on existing label, use that label's ID
1600
- if is_negative and label_id > 0:
1601
- current_obj_id = label_id
1602
- # For positive clicks on background, create a new object
1603
- elif point_label > 0 and label_id == 0:
1604
- current_obj_id = self.next_obj_id
1605
- self.next_obj_id += 1
1606
- # For negative on background, try to find most recent object
1607
- elif point_label < 0 and label_id == 0:
1608
- # Use most recently created object if available
1609
- if hasattr(self, "obj_points") and self.obj_points:
1610
- current_obj_id = max(self.obj_points.keys())
1611
- else:
1612
- self.viewer.status = "No existing object to modify with negative point"
1613
- return
1614
-
1615
- if current_obj_id is None:
1616
- self.viewer.status = (
1617
- "Could not determine which object to modify"
1618
- )
1619
- return
1620
-
1621
- # Find or create points layer for this object
1622
- points_layer = None
1623
- for layer in list(self.viewer.layers):
1624
- if f"Points for Object {current_obj_id}" in layer.name:
1625
- points_layer = layer
1626
- break
1627
-
1628
- # Initialize object tracking if needed
1629
- if not hasattr(self, "obj_points"):
1630
- self.obj_points = {}
1631
- self.obj_labels = {}
1632
-
1633
- if current_obj_id not in self.obj_points:
1634
- self.obj_points[current_obj_id] = []
1635
- self.obj_labels[current_obj_id] = []
1636
-
1637
- # Create or update points layer for this object
1638
- if points_layer is None:
1639
- # First point for this object
1640
- points_layer = self.viewer.add_points(
1641
- np.array([[y, x]]),
1642
- name=f"Points for Object {current_obj_id}",
1643
- size=10,
1644
- face_color=["green" if point_label > 0 else "red"],
1645
- border_color="white",
1646
- border_width=1,
1647
- opacity=0.8,
1648
- )
1649
- self.obj_points[current_obj_id] = [[x, y]]
1650
- self.obj_labels[current_obj_id] = [point_label]
1651
- else:
1652
- # Add point to existing layer
1653
- current_points = points_layer.data
1654
- current_colors = points_layer.face_color
1655
-
1656
- # Add new point
1657
- new_points = np.vstack([current_points, [y, x]])
1658
- new_color = "green" if point_label > 0 else "red"
1659
-
1660
- # Update points layer
1661
- points_layer.data = new_points
1662
-
1663
- # Update colors
1664
- if isinstance(current_colors, list):
1665
- current_colors.append(new_color)
1666
- points_layer.face_color = current_colors
1667
- else:
1668
- # If it's an array, create a list of colors
1669
- colors = []
1670
- for i in range(len(new_points)):
1671
- if i < len(current_points):
1672
- colors.append(
1673
- "green" if point_label > 0 else "red"
1674
- )
1675
- else:
1676
- colors.append(new_color)
1677
- points_layer.face_color = colors
1678
-
1679
- # Update object tracking
1680
- self.obj_points[current_obj_id].append([x, y])
1681
- self.obj_labels[current_obj_id].append(point_label)
1682
-
1683
- # Now do the actual segmentation using SAM2
1684
- if (
1685
- hasattr(self, "predictor")
1686
- and self.predictor is not None
1687
- ):
1688
- try:
1689
- # Make sure image is loaded
1690
- if self.current_image_for_segmentation is None:
1691
- self.viewer.status = (
1692
- "No image loaded for segmentation"
1693
- )
1694
- return
1695
-
1696
- # Prepare image for SAM2
1697
- image = self.current_image_for_segmentation
1698
- if len(image.shape) == 2:
1699
- image = np.stack([image] * 3, axis=-1)
1700
- elif len(image.shape) == 3 and image.shape[2] == 1:
1701
- image = np.concatenate([image] * 3, axis=2)
1702
- elif len(image.shape) == 3 and image.shape[2] > 3:
1703
- image = image[:, :, :3]
1704
-
1705
- if image.dtype != np.uint8:
1706
- image = (image / np.max(image) * 255).astype(
1707
- np.uint8
1708
- )
1709
-
1710
- # Set the image in the predictor
1711
- self.predictor.set_image(image)
1712
-
1713
- # Only use the points for the current object being segmented
1714
- points = np.array(
1715
- self.obj_points[current_obj_id],
1716
- dtype=np.float32,
1717
- )
1718
- labels = np.array(
1719
- self.obj_labels[current_obj_id], dtype=np.int32
1720
- )
1721
-
1722
- self.viewer.status = f"Segmenting object {current_obj_id} with {len(points)} points..."
1723
-
1724
- with torch.inference_mode(), torch.autocast(
1725
- "cuda"
1726
- ):
1727
- masks, scores, _ = self.predictor.predict(
1728
- point_coords=points,
1729
- point_labels=labels,
1730
- multimask_output=True,
1731
- )
1732
-
1733
- # Get best mask
1734
- if len(masks) > 0:
1735
- best_mask = masks[0]
1736
-
1737
- # Update segmentation result
1738
- if (
1739
- best_mask.shape
1740
- != self.segmentation_result.shape
1741
- ):
1742
- from skimage.transform import resize
1743
-
1744
- best_mask = resize(
1745
- best_mask.astype(float),
1746
- self.segmentation_result.shape,
1747
- order=0,
1748
- preserve_range=True,
1749
- anti_aliasing=False,
1750
- ).astype(bool)
1751
-
1752
- # CRITICAL FIX: For negative points, only remove from this object's mask
1753
- # For positive points, add to this object's mask without removing other objects
1754
- if point_label < 0:
1755
- # Remove only from current object's mask
1756
- self.segmentation_result[
1757
- (
1758
- self.segmentation_result
1759
- == current_obj_id
1760
- )
1761
- & best_mask
1762
- ] = 0
1763
- else:
1764
- # Add to current object's mask without affecting other objects
1765
- # Only overwrite background (value 0)
1766
- self.segmentation_result[
1767
- best_mask
1768
- & (self.segmentation_result == 0)
1769
- ] = current_obj_id
1770
-
1771
- # Update label info
1772
- area = np.sum(
1773
- self.segmentation_result
1774
- == current_obj_id
1775
- )
1776
- y_indices, x_indices = np.where(
1777
- self.segmentation_result
1778
- == current_obj_id
1779
- )
1780
- center_y = (
1781
- np.mean(y_indices)
1782
- if len(y_indices) > 0
1783
- else 0
1784
- )
1785
- center_x = (
1786
- np.mean(x_indices)
1787
- if len(x_indices) > 0
1788
- else 0
1789
- )
1790
-
1791
- self.label_info[current_obj_id] = {
1792
- "area": area,
1793
- "center_y": center_y,
1794
- "center_x": center_x,
1795
- "score": float(scores[0]),
1796
- }
1797
-
1798
- self.viewer.status = (
1799
- f"Updated object {current_obj_id}"
1800
- )
1801
- else:
1802
- self.viewer.status = (
1803
- "No valid mask produced"
1804
- )
1805
-
1806
- # Update the UI
1807
- self._update_label_layer()
1808
- if (
1809
- hasattr(self, "label_table_widget")
1810
- and self.label_table_widget is not None
1811
- ):
1812
- self._populate_label_table(
1813
- self.label_table_widget
1814
- )
1815
-
1816
- except (
1817
- IndexError,
1818
- KeyError,
1819
- ValueError,
1820
- AttributeError,
1821
- TypeError,
1822
- ) as e:
1823
- import traceback
1824
-
1825
- self.viewer.status = (
1826
- f"Error in SAM2 processing: {str(e)}"
1827
- )
1828
- traceback.print_exc()
2349
+ # Handle Ctrl+Click to clear a single label
2350
+ if is_control and label_id > 0:
2351
+ self.clear_label_at_position(y, x)
2352
+ return
1829
2353
 
1830
- # If clicking on an existing label, toggle selection
1831
- elif label_id > 0:
2354
+ # If clicking on an existing label (and not using modifiers), toggle selection
2355
+ if label_id > 0 and not is_negative and not is_control:
1832
2356
  # Toggle the label selection
1833
2357
  if label_id in self.selected_labels:
1834
2358
  self.selected_labels.remove(label_id)
@@ -1840,8 +2364,14 @@ class BatchCropAnything:
1840
2364
  # Update table and preview
1841
2365
  self._update_label_table()
1842
2366
  self.preview_crop()
2367
+ return
1843
2368
 
1844
- # 3D case (handle differently)
2369
+ # If clicking on background or using Shift (negative points), this should be handled by points layer
2370
+ # Don't process these clicks here to avoid conflicts
2371
+ if label_id == 0 or is_negative:
2372
+ return
2373
+
2374
+ # 3D case
1845
2375
  else:
1846
2376
  if len(coords) == 3:
1847
2377
  t, y, x = map(int, coords)
@@ -1870,12 +2400,13 @@ class BatchCropAnything:
1870
2400
  # Get the label ID at the clicked position
1871
2401
  label_id = self.segmentation_result[t, y, x]
1872
2402
 
1873
- # If background or shift is pressed, handle in _on_3d_label_clicked
1874
- if label_id == 0 or is_negative:
1875
- # This will be handled by _on_3d_label_clicked already attached
1876
- pass
1877
- # If clicking on an existing label, handle selection
1878
- elif label_id > 0:
2403
+ # Handle Ctrl+Click to clear a single label
2404
+ if is_control and label_id > 0:
2405
+ self.clear_label_at_position_3d(t, y, x)
2406
+ return
2407
+
2408
+ # If clicking on an existing label and not using negative points, handle selection
2409
+ if label_id > 0 and not is_negative and not is_control:
1879
2410
  # Toggle the label selection
1880
2411
  if label_id in self.selected_labels:
1881
2412
  self.selected_labels.remove(label_id)
@@ -1886,9 +2417,12 @@ class BatchCropAnything:
1886
2417
 
1887
2418
  # Update table if it exists
1888
2419
  self._update_label_table()
1889
-
1890
- # Update preview after selection changes
1891
2420
  self.preview_crop()
2421
+ return
2422
+
2423
+ # For background clicks or negative points, let the 3D handler deal with it
2424
+ if label_id == 0 or is_negative:
2425
+ return
1892
2426
 
1893
2427
  except (
1894
2428
  IndexError,
@@ -1902,12 +2436,74 @@ class BatchCropAnything:
1902
2436
  self.viewer.status = f"Error in click handling: {str(e)}"
1903
2437
  traceback.print_exc()
1904
2438
 
2439
+ def _add_segmentation_point(self, x, y, event):
2440
+ """Add a point for segmentation."""
2441
+ is_negative = "Shift" in event.modifiers
2442
+
2443
+ # Initialize tracking if needed
2444
+ if not hasattr(self, "current_points"):
2445
+ self.current_points = []
2446
+ self.current_labels = []
2447
+ self.current_obj_id = 1
2448
+
2449
+ # Add point
2450
+ self.current_points.append([x, y])
2451
+ self.current_labels.append(0 if is_negative else 1)
2452
+
2453
+ # Run SAM2 prediction
2454
+ if self.predictor is not None:
2455
+ # Prepare image
2456
+ image = self._prepare_image_for_sam2()
2457
+
2458
+ # Set the image in the predictor (only for ImagePredictor, not VideoPredictor)
2459
+ if hasattr(self.predictor, "set_image"):
2460
+ self.predictor.set_image(image)
2461
+ else:
2462
+ self.viewer.status = (
2463
+ "Error: This operation requires Image Predictor (2D mode)"
2464
+ )
2465
+ return
2466
+
2467
+ # Predict
2468
+ device_type = "cuda" if self.device.type == "cuda" else "cpu"
2469
+ with torch.inference_mode(), torch.autocast(device_type):
2470
+ masks, scores, _ = self.predictor.predict(
2471
+ point_coords=np.array(
2472
+ self.current_points, dtype=np.float32
2473
+ ),
2474
+ point_labels=np.array(self.current_labels, dtype=np.int32),
2475
+ multimask_output=False,
2476
+ )
2477
+
2478
+ # Update segmentation
2479
+ if len(masks) > 0:
2480
+ mask = masks[0] > 0.5
2481
+ if self.current_scale_factor < 1.0:
2482
+ mask = resize(
2483
+ mask, self.segmentation_result.shape, order=0
2484
+ ).astype(bool)
2485
+
2486
+ # Update segmentation result
2487
+ self.segmentation_result[mask] = self.current_obj_id
2488
+
2489
+ # Move to next object if adding positive point
2490
+ if not is_negative:
2491
+ self.current_obj_id += 1
2492
+ self.current_points = []
2493
+ self.current_labels = []
2494
+
2495
+ self._update_label_layer()
2496
+
1905
2497
  def _add_point_marker(self, coords, label_type):
1906
2498
  """Add a visible marker for where the user clicked."""
1907
2499
  # Remove previous point markers
1908
2500
  for layer in list(self.viewer.layers):
1909
2501
  if "Point Prompt" in layer.name:
1910
- self.viewer.layers.remove(layer)
2502
+ # Clean up callbacks before removing the layer to prevent cleanup issues
2503
+ if hasattr(layer, "mouse_drag_callbacks"):
2504
+ layer.mouse_drag_callbacks.clear()
2505
+ with contextlib.suppress(ValueError):
2506
+ self.viewer.layers.remove(layer)
1911
2507
 
1912
2508
  # Create points layer
1913
2509
  color = (
@@ -1923,6 +2519,14 @@ class BatchCropAnything:
1923
2519
  opacity=0.8,
1924
2520
  )
1925
2521
 
2522
+ with contextlib.suppress(AttributeError, ValueError):
2523
+ self.points_layer.mouse_drag_callbacks.remove(
2524
+ self._on_points_clicked
2525
+ )
2526
+ self.points_layer.mouse_drag_callbacks.append(
2527
+ self._on_points_clicked
2528
+ )
2529
+
1926
2530
  def create_label_table(self, parent_widget):
1927
2531
  """Create a table widget displaying all detected labels."""
1928
2532
  # Create table widget
@@ -2087,11 +2691,170 @@ class BatchCropAnything:
2087
2691
  self.viewer.status = f"Selected all {len(self.selected_labels)} labels"
2088
2692
 
2089
2693
  def clear_selection(self):
2090
- """Clear all selected labels."""
2694
+ """Clear all labels from the segmentation.
2695
+
2696
+ This removes all segmented objects from the label layer, resets all tracking data,
2697
+ and prepares the interface for new segmentations. Note: The method name is kept as
2698
+ 'clear_selection' for backwards compatibility, but it clears all labels, not just
2699
+ the selection.
2700
+ """
2701
+ if self.segmentation_result is None:
2702
+ self.viewer.status = "No segmentation available"
2703
+ return
2704
+
2705
+ # Get all unique label IDs (excluding background 0)
2706
+ unique_labels = np.unique(self.segmentation_result)
2707
+ label_ids = [label for label in unique_labels if label > 0]
2708
+
2709
+ if len(label_ids) == 0:
2710
+ self.viewer.status = "No labels to clear"
2711
+ return
2712
+
2713
+ # Clear the entire segmentation result
2714
+ self.segmentation_result[:] = 0
2715
+
2716
+ # Clear selected labels
2091
2717
  self.selected_labels = set()
2718
+
2719
+ # Clear label info
2720
+ self.label_info = {}
2721
+
2722
+ # Remove any object-specific point layers
2723
+ for layer in list(self.viewer.layers):
2724
+ if "Points for Object" in layer.name:
2725
+ # Clean up callbacks before removing the layer to prevent cleanup issues
2726
+ if hasattr(layer, "mouse_drag_callbacks"):
2727
+ layer.mouse_drag_callbacks.clear()
2728
+ with contextlib.suppress(ValueError):
2729
+ self.viewer.layers.remove(layer)
2730
+
2731
+ # Clean up object tracking data
2732
+ if hasattr(self, "obj_points"):
2733
+ self.obj_points = {}
2734
+ if hasattr(self, "obj_labels"):
2735
+ self.obj_labels = {}
2736
+ if hasattr(self, "points_data"):
2737
+ self.points_data = {}
2738
+ if hasattr(self, "points_labels"):
2739
+ self.points_labels = {}
2740
+
2741
+ # Reset object ID counters
2742
+ if hasattr(self, "next_obj_id"):
2743
+ self.next_obj_id = 1
2744
+ if hasattr(self, "_sam2_next_obj_id"):
2745
+ self._sam2_next_obj_id = 1
2746
+
2747
+ # Update UI
2748
+ self._update_label_layer()
2092
2749
  self._update_label_table()
2093
2750
  self.preview_crop()
2094
- self.viewer.status = "Cleared all selections"
2751
+
2752
+ self.viewer.status = (
2753
+ f"Cleared all {len(label_ids)} labels from segmentation"
2754
+ )
2755
+
2756
+ def clear_label_at_position(self, y, x):
2757
+ """Clear a single label at the specified 2D position."""
2758
+ if self.segmentation_result is None:
2759
+ self.viewer.status = "No segmentation available"
2760
+ return
2761
+
2762
+ label_id = self.segmentation_result[y, x]
2763
+ if label_id > 0:
2764
+ # Remove all pixels with this label ID
2765
+ self.segmentation_result[self.segmentation_result == label_id] = 0
2766
+
2767
+ # Remove from selected labels if it was selected
2768
+ self.selected_labels.discard(label_id)
2769
+
2770
+ # Remove from label info
2771
+ if label_id in self.label_info:
2772
+ del self.label_info[label_id]
2773
+
2774
+ # Remove any object-specific point layers for this label
2775
+ for layer in list(self.viewer.layers):
2776
+ if f"Points for Object {label_id}" in layer.name:
2777
+ # Clean up callbacks before removing the layer to prevent cleanup issues
2778
+ if hasattr(layer, "mouse_drag_callbacks"):
2779
+ layer.mouse_drag_callbacks.clear()
2780
+ with contextlib.suppress(ValueError):
2781
+ self.viewer.layers.remove(layer)
2782
+
2783
+ # Clean up object tracking data
2784
+ if hasattr(self, "obj_points") and label_id in self.obj_points:
2785
+ del self.obj_points[label_id]
2786
+ if hasattr(self, "obj_labels") and label_id in self.obj_labels:
2787
+ del self.obj_labels[label_id]
2788
+
2789
+ # Update UI
2790
+ self._update_label_layer()
2791
+ self._update_label_table()
2792
+ self.preview_crop()
2793
+
2794
+ self.viewer.status = f"Deleted label ID: {label_id}"
2795
+ else:
2796
+ self.viewer.status = "No label to delete at this position"
2797
+
2798
+ def clear_label_at_position_3d(self, t, y, x):
2799
+ """Clear a single label at the specified 3D position."""
2800
+ if self.segmentation_result is None:
2801
+ self.viewer.status = "No segmentation available"
2802
+ return
2803
+
2804
+ label_id = self.segmentation_result[t, y, x]
2805
+ if label_id > 0:
2806
+ # Remove all pixels with this label ID across all timeframes
2807
+ self.segmentation_result[self.segmentation_result == label_id] = 0
2808
+
2809
+ # Remove from selected labels if it was selected
2810
+ self.selected_labels.discard(label_id)
2811
+
2812
+ # Remove from label info
2813
+ if label_id in self.label_info:
2814
+ del self.label_info[label_id]
2815
+
2816
+ # Remove any object-specific point layers for this label
2817
+ for layer in list(self.viewer.layers):
2818
+ if f"Points for Object {label_id}" in layer.name:
2819
+ # Clean up callbacks before removing the layer to prevent cleanup issues
2820
+ if hasattr(layer, "mouse_drag_callbacks"):
2821
+ layer.mouse_drag_callbacks.clear()
2822
+ with contextlib.suppress(ValueError):
2823
+ self.viewer.layers.remove(layer)
2824
+
2825
+ # Clean up 3D object tracking data
2826
+ if (
2827
+ hasattr(self, "sam2_points_by_obj")
2828
+ and label_id in self.sam2_points_by_obj
2829
+ ):
2830
+ del self.sam2_points_by_obj[label_id]
2831
+ if (
2832
+ hasattr(self, "sam2_labels_by_obj")
2833
+ and label_id in self.sam2_labels_by_obj
2834
+ ):
2835
+ del self.sam2_labels_by_obj[label_id]
2836
+ if hasattr(self, "points_data") and label_id in self.points_data:
2837
+ del self.points_data[label_id]
2838
+ if (
2839
+ hasattr(self, "points_labels")
2840
+ and label_id in self.points_labels
2841
+ ):
2842
+ del self.points_labels[label_id]
2843
+
2844
+ # Update UI
2845
+ self._update_label_layer()
2846
+ if (
2847
+ hasattr(self, "label_table_widget")
2848
+ and self.label_table_widget is not None
2849
+ ):
2850
+ self._populate_label_table(self.label_table_widget)
2851
+ self.preview_crop()
2852
+
2853
+ self.viewer.status = (
2854
+ f"Deleted label ID: {label_id} from all timeframes"
2855
+ )
2856
+ else:
2857
+ self.viewer.status = "No label to delete at this position"
2095
2858
 
2096
2859
  def preview_crop(self, label_ids=None):
2097
2860
  """Preview the crop result with the selected label IDs."""
@@ -2111,7 +2874,11 @@ class BatchCropAnything:
2111
2874
  # Remove previous preview if exists
2112
2875
  for layer in list(self.viewer.layers):
2113
2876
  if "Preview" in layer.name:
2114
- self.viewer.layers.remove(layer)
2877
+ # Clean up callbacks before removing the layer to prevent cleanup issues
2878
+ if hasattr(layer, "mouse_drag_callbacks"):
2879
+ layer.mouse_drag_callbacks.clear()
2880
+ with contextlib.suppress(ValueError):
2881
+ self.viewer.layers.remove(layer)
2115
2882
 
2116
2883
  # Make sure the segmentation layer is active again
2117
2884
  if self.label_layer is not None:
@@ -2149,7 +2916,11 @@ class BatchCropAnything:
2149
2916
  # Remove previous preview if exists
2150
2917
  for layer in list(self.viewer.layers):
2151
2918
  if "Preview" in layer.name:
2152
- self.viewer.layers.remove(layer)
2919
+ # Clean up callbacks before removing the layer to prevent cleanup issues
2920
+ if hasattr(layer, "mouse_drag_callbacks"):
2921
+ layer.mouse_drag_callbacks.clear()
2922
+ with contextlib.suppress(ValueError):
2923
+ self.viewer.layers.remove(layer)
2153
2924
 
2154
2925
  # Add preview layer
2155
2926
  if label_ids:
@@ -2240,17 +3011,14 @@ class BatchCropAnything:
2240
3011
  # Save cropped image
2241
3012
  image_path = self.images[self.current_index]
2242
3013
  base_name, ext = os.path.splitext(image_path)
2243
- label_str = "_".join(
2244
- str(lid) for lid in sorted(self.selected_labels)
2245
- )
2246
- output_path = f"{base_name}_cropped_{label_str}.tif"
3014
+ output_path = f"{base_name}_sam2_cropped.tif"
2247
3015
 
2248
3016
  # Save using tifffile with explicit parameters for best compatibility
2249
3017
  imwrite(output_path, cropped_image, compression="zlib")
2250
3018
  self.viewer.status = f"Saved cropped image to {output_path}"
2251
3019
 
2252
3020
  # Save the label image with exact same dimensions as original
2253
- label_output_path = f"{base_name}_labels_{label_str}.tif"
3021
+ label_output_path = f"{base_name}_sam2_labels.tif"
2254
3022
  imwrite(label_output_path, label_image, compression="zlib")
2255
3023
  self.viewer.status += f"\nSaved label mask to {label_output_path}"
2256
3024
 
@@ -2264,6 +3032,27 @@ class BatchCropAnything:
2264
3032
  self.viewer.status = f"Error cropping image: {str(e)}"
2265
3033
  return False
2266
3034
 
3035
+ def reset_sam2_state(self):
3036
+ """Reset SAM2 predictor state for 2D segmentation."""
3037
+ if not self.use_3d and hasattr(self, "prepared_sam2_image"):
3038
+ # Re-set the image in the predictor (only for ImagePredictor)
3039
+ device_type = "cuda" if self.device.type == "cuda" else "cpu"
3040
+ try:
3041
+ if hasattr(self.predictor, "set_image"):
3042
+ with (
3043
+ torch.inference_mode(),
3044
+ torch.autocast(device_type, dtype=torch.float32),
3045
+ ):
3046
+ self.predictor.set_image(self.prepared_sam2_image)
3047
+ else:
3048
+ print(
3049
+ "DEBUG: reset_sam2_state - predictor doesn't have set_image method"
3050
+ )
3051
+ except (RuntimeError, AssertionError, TypeError, ValueError) as e:
3052
+ print(f"Error resetting SAM2 state: {e}")
3053
+ # If there's an error, try to reinitialize
3054
+ self._initialize_sam2()
3055
+
2267
3056
 
2268
3057
  def create_crop_widget(processor):
2269
3058
  """Create the crop control widget."""
@@ -2274,27 +3063,70 @@ def create_crop_widget(processor):
2274
3063
 
2275
3064
  # Instructions
2276
3065
  dimension_type = "3D (TYX/ZYX)" if processor.use_3d else "2D (YX)"
2277
- instructions_label = QLabel(
2278
- f"<b>Processing {dimension_type} data</b><br><br>"
2279
- "To create/edit objects:<br>"
2280
- "1. <b>Click on the POINTS layer</b> to add positive points<br>"
2281
- "2. Use Shift+click for negative points to refine segmentation<br>"
2282
- "3. Click on existing objects in the Segmentation layer to select them<br>"
2283
- "4. Press 'Crop' to save the selected objects to disk"
2284
- )
3066
+
3067
+ if processor.use_3d:
3068
+ instructions_text = (
3069
+ f"<b>Processing {dimension_type} data</b><br><br>"
3070
+ "<b>⚠️ IMPORTANT for 3D stacks:</b><br>"
3071
+ "<ul>"
3072
+ "<li><b>Navigate to the FIRST SLICE</b> where your object appears (use the time/Z slider)</li>"
3073
+ "<li><b>Switch to 2D view</b> (click 2D icon in napari, NOT 3D view)</li>"
3074
+ "<li><b>Point Mode:</b> Select Points layer and click on objects to segment them</li>"
3075
+ "<li><b>Rectangle Mode:</b> Draw rectangles around objects to segment them</li>"
3076
+ "<li>Segmentation will automatically propagate to all slices</li>"
3077
+ "</ul><br>"
3078
+ "<b>General Controls:</b><br>"
3079
+ "<ul>"
3080
+ "<li>Use <b>Shift+click</b> for negative points (remove areas from segmentation)</li>"
3081
+ "<li>Click on existing objects in <b>Segmentation layer</b> to select for cropping</li>"
3082
+ "<li>Press <b>CTRL+click</b> on labels in <b>Segmentation layer</b> to delete them</li>"
3083
+ "<li>Press <b>'Crop'</b> to save selected objects to disk</li>"
3084
+ "</ul>"
3085
+ )
3086
+ else:
3087
+ instructions_text = (
3088
+ f"<b>Processing {dimension_type} data</b><br><br>"
3089
+ "<b>Point Mode:</b> Click on objects to segment them. Use Shift+click for negative points.<br>"
3090
+ "<b>Rectangle Mode:</b> Draw rectangles around objects to segment them.<br><br>"
3091
+ "<ul>"
3092
+ "<li>Click on existing objects in <b>Segmentation layer</b> to select them for cropping</li>"
3093
+ "<li>Press <b>CTRL+click</b> on labels in <b>Segmentation layer</b> to delete them</li>"
3094
+ "<li>Press <b>'Crop'</b> to save selected objects to disk</li>"
3095
+ "</ul>"
3096
+ )
3097
+
3098
+ instructions_label = QLabel(instructions_text)
2285
3099
  instructions_label.setWordWrap(True)
2286
3100
  layout.addWidget(instructions_label)
2287
3101
 
2288
- # Add a button to ensure points layer is active
2289
- activate_button = QPushButton("Make Points Layer Active")
3102
+ # Add mode selector
3103
+ mode_layout = QHBoxLayout()
3104
+ mode_label = QLabel("<b>Prompt Mode:</b>")
3105
+ mode_layout.addWidget(mode_label)
3106
+
3107
+ point_mode_button = QPushButton("Points")
3108
+ point_mode_button.setCheckable(True)
3109
+ point_mode_button.setChecked(True)
3110
+ mode_layout.addWidget(point_mode_button)
3111
+
3112
+ box_mode_button = QPushButton("Rectangle")
3113
+ box_mode_button.setCheckable(True)
3114
+ box_mode_button.setChecked(False)
3115
+ mode_layout.addWidget(box_mode_button)
3116
+
3117
+ mode_layout.addStretch()
3118
+ layout.addLayout(mode_layout)
3119
+
3120
+ # Add a button to ensure active layer is correct
3121
+ activate_button = QPushButton("Make Prompt Layer Active")
2290
3122
  activate_button.clicked.connect(
2291
- lambda: processor._ensure_points_layer_active()
3123
+ lambda: processor._ensure_active_prompt_layer()
2292
3124
  )
2293
3125
  layout.addWidget(activate_button)
2294
3126
 
2295
- # Add a "Clear Points" button to reset prompts
2296
- clear_points_button = QPushButton("Clear Points")
2297
- layout.addWidget(clear_points_button)
3127
+ # Add a "Clear Prompts" button to reset prompts
3128
+ clear_prompts_button = QPushButton("Clear Prompts")
3129
+ layout.addWidget(clear_prompts_button)
2298
3130
 
2299
3131
  # Create label table
2300
3132
  label_table = processor.create_label_table(crop_widget)
@@ -2305,7 +3137,7 @@ def create_crop_widget(processor):
2305
3137
  # Selection buttons
2306
3138
  selection_layout = QHBoxLayout()
2307
3139
  select_all_button = QPushButton("Select All")
2308
- clear_selection_button = QPushButton("Clear Selection")
3140
+ clear_selection_button = QPushButton("Clear All Labels")
2309
3141
  selection_layout.addWidget(select_all_button)
2310
3142
  selection_layout.addWidget(clear_selection_button)
2311
3143
  layout.addLayout(selection_layout)
@@ -2343,51 +3175,152 @@ def create_crop_widget(processor):
2343
3175
  # Create new table
2344
3176
  label_table = processor.create_label_table(crop_widget)
2345
3177
  label_table.setMinimumHeight(200)
2346
- layout.insertWidget(3, label_table) # Insert after clear points button
3178
+ layout.insertWidget(
3179
+ 3, label_table
3180
+ ) # Insert after clear prompts button
2347
3181
  return label_table
2348
3182
 
2349
- # Add helper method to ensure points layer is active
2350
- def _ensure_points_layer_active():
2351
- points_layer = None
2352
- for layer in list(processor.viewer.layers):
2353
- if "Points" in layer.name:
2354
- points_layer = layer
2355
- break
3183
+ # Add helper method to ensure active prompt layer is selected based on mode
3184
+ def _ensure_active_prompt_layer():
3185
+ if processor.prompt_mode == "point":
3186
+ points_layer = None
3187
+ for layer in list(processor.viewer.layers):
3188
+ if "Points" in layer.name and "Object" not in layer.name:
3189
+ points_layer = layer
3190
+ break
2356
3191
 
2357
- if points_layer is not None:
2358
- processor.viewer.layers.selection.active = points_layer
2359
- status_label.setText(
2360
- "Points layer is now active - click to add points"
2361
- )
2362
- else:
2363
- status_label.setText(
2364
- "No points layer found. Please load an image first."
2365
- )
3192
+ if points_layer is not None:
3193
+ processor.viewer.layers.selection.active = points_layer
3194
+ if processor.use_3d:
3195
+ status_label.setText(
3196
+ "Points layer active - Navigate to FIRST SLICE of object, ensure 2D view, then click"
3197
+ )
3198
+ else:
3199
+ status_label.setText(
3200
+ "Points layer is now active - click to add points"
3201
+ )
3202
+ else:
3203
+ status_label.setText(
3204
+ "No points layer found. Please load an image first."
3205
+ )
3206
+ else: # box mode
3207
+ shapes_layer = None
3208
+ for layer in list(processor.viewer.layers):
3209
+ if "Rectangles" in layer.name:
3210
+ shapes_layer = layer
3211
+ break
3212
+
3213
+ if shapes_layer is not None:
3214
+ processor.viewer.layers.selection.active = shapes_layer
3215
+ status_label.setText(
3216
+ "Rectangles layer is now active - draw rectangles"
3217
+ )
3218
+ else:
3219
+ status_label.setText(
3220
+ "No rectangles layer found. Please load an image first."
3221
+ )
2366
3222
 
2367
- processor._ensure_points_layer_active = _ensure_points_layer_active
3223
+ processor._ensure_active_prompt_layer = _ensure_active_prompt_layer
3224
+
3225
+ # Keep the old method for backward compatibility
3226
+ processor._ensure_points_layer_active = _ensure_active_prompt_layer
3227
+
3228
+ def on_clear_prompts_clicked():
3229
+ # Find and clear/remove prompt layers based on mode
3230
+ main_points_layer = None
3231
+ object_points_layers = []
3232
+ shapes_layer = None
2368
3233
 
2369
- # Connect button signals
2370
- def on_clear_points_clicked():
2371
- # Remove all point layers
2372
3234
  for layer in list(processor.viewer.layers):
2373
3235
  if "Points" in layer.name:
3236
+ if "Object" in layer.name:
3237
+ object_points_layers.append(layer)
3238
+ else:
3239
+ main_points_layer = layer
3240
+ elif "Rectangles" in layer.name:
3241
+ shapes_layer = layer
3242
+
3243
+ # Remove object-specific point layers (these are created dynamically)
3244
+ for layer in object_points_layers:
3245
+ # Clean up callbacks before removing the layer to prevent cleanup issues
3246
+ if hasattr(layer, "mouse_drag_callbacks"):
3247
+ layer.mouse_drag_callbacks.clear()
3248
+ with contextlib.suppress(ValueError):
2374
3249
  processor.viewer.layers.remove(layer)
2375
3250
 
2376
- # Reset point tracking attributes
2377
- if hasattr(processor, "points_data"):
2378
- processor.points_data = {}
2379
- processor.points_labels = {}
3251
+ # Clear shapes layer
3252
+ if shapes_layer is not None:
3253
+ shapes_layer.data = []
2380
3254
 
2381
- if hasattr(processor, "obj_points"):
2382
- processor.obj_points = {}
2383
- processor.obj_labels = {}
3255
+ # Clear data from main points layer instead of removing it
3256
+ if main_points_layer is not None:
3257
+ # Clear the points data
3258
+ main_points_layer.data = np.zeros(
3259
+ (0, 2 if not processor.use_3d else 3)
3260
+ )
3261
+ main_points_layer.face_color = "green"
2384
3262
 
2385
- # Re-create empty points layer
2386
- processor._update_label_layer()
2387
- processor._ensure_points_layer_active()
3263
+ # Ensure the click callback is still connected
3264
+ if (
3265
+ hasattr(main_points_layer, "mouse_drag_callbacks")
3266
+ and processor._on_points_clicked
3267
+ not in main_points_layer.mouse_drag_callbacks
3268
+ ):
3269
+ main_points_layer.mouse_drag_callbacks.append(
3270
+ processor._on_points_clicked
3271
+ )
3272
+
3273
+ # Reset all tracking attributes for 2D
3274
+ if not processor.use_3d:
3275
+ # Reset current segmentation tracking
3276
+ if hasattr(processor, "current_points"):
3277
+ processor.current_points = []
3278
+ processor.current_labels = []
3279
+
3280
+ # Reset object tracking
3281
+ if hasattr(processor, "obj_points"):
3282
+ processor.obj_points = {}
3283
+ processor.obj_labels = {}
3284
+
3285
+ # Reset box tracking
3286
+ if hasattr(processor, "obj_boxes"):
3287
+ processor.obj_boxes = {}
3288
+
3289
+ # Reset object ID counters
3290
+ if hasattr(processor, "current_obj_id"):
3291
+ # Find the highest existing label ID
3292
+ if processor.segmentation_result is not None:
3293
+ max_label = processor.segmentation_result.max()
3294
+ processor.current_obj_id = max(int(max_label) + 1, 1)
3295
+ processor.next_obj_id = processor.current_obj_id
3296
+ else:
3297
+ processor.current_obj_id = 1
3298
+ processor.next_obj_id = 1
3299
+
3300
+ # Reset SAM2 predictor state
3301
+ processor.reset_sam2_state()
3302
+
3303
+ # For 3D, reset video-specific tracking
3304
+ else:
3305
+ if hasattr(processor, "sam2_points_by_obj"):
3306
+ processor.sam2_points_by_obj = {}
3307
+ processor.sam2_labels_by_obj = {}
3308
+
3309
+ # Reset box tracking
3310
+ if hasattr(processor, "obj_boxes"):
3311
+ processor.obj_boxes = {}
3312
+
3313
+ if hasattr(processor, "points_data"):
3314
+ processor.points_data = {}
3315
+ processor.points_labels = {}
3316
+
3317
+ # Note: We don't reset _sam2_state for 3D as it needs to maintain video state
3318
+
3319
+ # Make the appropriate prompt layer active based on mode
3320
+ _ensure_active_prompt_layer()
2388
3321
 
2389
3322
  status_label.setText(
2390
- "Cleared all points. Click on Points layer to add new points."
3323
+ "Cleared all prompts. Ready to add new segmentation prompts."
2391
3324
  )
2392
3325
 
2393
3326
  def on_select_all_clicked():
@@ -2411,8 +3344,14 @@ def create_crop_widget(processor):
2411
3344
  )
2412
3345
 
2413
3346
  def on_next_clicked():
2414
- # Clear points before moving to next image
2415
- on_clear_points_clicked()
3347
+ # Check if we can move to the next image before clearing prompts
3348
+ if processor.current_index >= len(processor.images) - 1:
3349
+ next_button.setEnabled(False)
3350
+ status_label.setText("No more images. Processing complete.")
3351
+ return
3352
+
3353
+ # Clear prompts before moving to next image
3354
+ on_clear_prompts_clicked()
2416
3355
 
2417
3356
  if not processor.next_image():
2418
3357
  next_button.setEnabled(False)
@@ -2422,11 +3361,17 @@ def create_crop_widget(processor):
2422
3361
  status_label.setText(
2423
3362
  f"Showing image {processor.current_index + 1}/{len(processor.images)}"
2424
3363
  )
2425
- processor._ensure_points_layer_active()
3364
+ processor._ensure_active_prompt_layer()
2426
3365
 
2427
3366
  def on_prev_clicked():
2428
- # Clear points before moving to previous image
2429
- on_clear_points_clicked()
3367
+ # Check if we can move to the previous image before clearing prompts
3368
+ if processor.current_index <= 0:
3369
+ prev_button.setEnabled(False)
3370
+ status_label.setText("Already at the first image.")
3371
+ return
3372
+
3373
+ # Clear prompts before moving to previous image
3374
+ on_clear_prompts_clicked()
2430
3375
 
2431
3376
  if not processor.previous_image():
2432
3377
  prev_button.setEnabled(False)
@@ -2436,15 +3381,33 @@ def create_crop_widget(processor):
2436
3381
  status_label.setText(
2437
3382
  f"Showing image {processor.current_index + 1}/{len(processor.images)}"
2438
3383
  )
2439
- processor._ensure_points_layer_active()
3384
+ processor._ensure_active_prompt_layer()
3385
+
3386
+ def on_point_mode_clicked():
3387
+ processor.prompt_mode = "point"
3388
+ point_mode_button.setChecked(True)
3389
+ box_mode_button.setChecked(False)
3390
+ processor._update_label_layer()
3391
+ status_label.setText("Point mode active - click on objects to segment")
3392
+
3393
+ def on_box_mode_clicked():
3394
+ processor.prompt_mode = "box"
3395
+ point_mode_button.setChecked(False)
3396
+ box_mode_button.setChecked(True)
3397
+ processor._update_label_layer()
3398
+ status_label.setText(
3399
+ "Rectangle mode active - draw rectangles around objects"
3400
+ )
2440
3401
 
2441
- clear_points_button.clicked.connect(on_clear_points_clicked)
3402
+ clear_prompts_button.clicked.connect(on_clear_prompts_clicked)
2442
3403
  select_all_button.clicked.connect(on_select_all_clicked)
2443
3404
  clear_selection_button.clicked.connect(on_clear_selection_clicked)
2444
3405
  crop_button.clicked.connect(on_crop_clicked)
2445
3406
  next_button.clicked.connect(on_next_clicked)
2446
3407
  prev_button.clicked.connect(on_prev_clicked)
2447
- activate_button.clicked.connect(_ensure_points_layer_active)
3408
+ activate_button.clicked.connect(_ensure_active_prompt_layer)
3409
+ point_mode_button.clicked.connect(on_point_mode_clicked)
3410
+ box_mode_button.clicked.connect(on_box_mode_clicked)
2448
3411
 
2449
3412
  return crop_widget
2450
3413
 
@@ -2463,6 +3426,19 @@ def batch_crop_anything(
2463
3426
  viewer: Viewer = None,
2464
3427
  ):
2465
3428
  """MagicGUI widget for starting Batch Crop Anything using SAM2."""
3429
+ # Check if torch is available
3430
+ if not _HAS_TORCH:
3431
+ QMessageBox.critical(
3432
+ None,
3433
+ "Missing Dependency",
3434
+ "PyTorch not found. Batch Crop Anything requires PyTorch and SAM2.\n\n"
3435
+ "To install the required dependencies, run:\n"
3436
+ "pip install 'napari-tmidas[deep-learning]'\n\n"
3437
+ "Then follow SAM2 installation instructions at:\n"
3438
+ "https://github.com/MercaderLabAnatomy/napari-tmidas#installation",
3439
+ )
3440
+ return
3441
+
2466
3442
  # Check if SAM2 is available
2467
3443
  try:
2468
3444
  import importlib.util
@@ -2473,15 +3449,15 @@ def batch_crop_anything(
2473
3449
  None,
2474
3450
  "Missing Dependency",
2475
3451
  "SAM2 not found. Please follow installation instructions at:\n"
2476
- "https://github.com/MercaderLabAnatomy/napari-tmidas?tab=readme-ov-file#dependencies\n",
3452
+ "https://github.com/MercaderLabAnatomy/napari-tmidas#installation\n",
2477
3453
  )
2478
3454
  return
2479
3455
  except ImportError:
2480
3456
  QMessageBox.critical(
2481
3457
  None,
2482
3458
  "Missing Dependency",
2483
- "SAM2 package cannot be imported. Please follow installation instructions at\n"
2484
- "https://github.com/MercaderLabAnatomy/napari-tmidas?tab=readme-ov-file#dependencies",
3459
+ "SAM2 package cannot be imported. Please follow installation instructions at:\n"
3460
+ "https://github.com/MercaderLabAnatomy/napari-tmidas#installation",
2485
3461
  )
2486
3462
  return
2487
3463
 
@@ -2509,24 +3485,7 @@ def batch_crop_anything_widget():
2509
3485
  # Create the magicgui widget
2510
3486
  widget = batch_crop_anything
2511
3487
 
2512
- # Create and add browse button for folder path
2513
- folder_browse_button = QPushButton("Browse...")
2514
-
2515
- def on_folder_browse_clicked():
2516
- folder = QFileDialog.getExistingDirectory(
2517
- None,
2518
- "Select Folder",
2519
- os.path.expanduser("~"),
2520
- QFileDialog.ShowDirsOnly | QFileDialog.DontResolveSymlinks,
2521
- )
2522
- if folder:
2523
- # Update the folder_path field
2524
- widget.folder_path.value = folder
2525
-
2526
- folder_browse_button.clicked.connect(on_folder_browse_clicked)
2527
-
2528
- # Insert the browse button next to the folder_path field
2529
- folder_layout = widget.folder_path.native.parent().layout()
2530
- folder_layout.addWidget(folder_browse_button)
3488
+ # Add browse button using common utility
3489
+ add_browse_button_to_folder_field(widget, "folder_path")
2531
3490
 
2532
3491
  return widget