napari-tmidas 0.2.2__py3-none-any.whl → 0.2.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. napari_tmidas/__init__.py +35 -5
  2. napari_tmidas/_crop_anything.py +1520 -609
  3. napari_tmidas/_env_manager.py +76 -0
  4. napari_tmidas/_file_conversion.py +1646 -1131
  5. napari_tmidas/_file_selector.py +1455 -216
  6. napari_tmidas/_label_inspection.py +83 -8
  7. napari_tmidas/_processing_worker.py +309 -0
  8. napari_tmidas/_reader.py +6 -10
  9. napari_tmidas/_registry.py +2 -2
  10. napari_tmidas/_roi_colocalization.py +1221 -84
  11. napari_tmidas/_tests/test_crop_anything.py +123 -0
  12. napari_tmidas/_tests/test_env_manager.py +89 -0
  13. napari_tmidas/_tests/test_grid_view_overlay.py +193 -0
  14. napari_tmidas/_tests/test_init.py +98 -0
  15. napari_tmidas/_tests/test_intensity_label_filter.py +222 -0
  16. napari_tmidas/_tests/test_label_inspection.py +86 -0
  17. napari_tmidas/_tests/test_processing_basic.py +500 -0
  18. napari_tmidas/_tests/test_processing_worker.py +142 -0
  19. napari_tmidas/_tests/test_regionprops_analysis.py +547 -0
  20. napari_tmidas/_tests/test_registry.py +70 -2
  21. napari_tmidas/_tests/test_scipy_filters.py +168 -0
  22. napari_tmidas/_tests/test_skimage_filters.py +259 -0
  23. napari_tmidas/_tests/test_split_channels.py +217 -0
  24. napari_tmidas/_tests/test_spotiflow.py +87 -0
  25. napari_tmidas/_tests/test_tyx_display_fix.py +142 -0
  26. napari_tmidas/_tests/test_ui_utils.py +68 -0
  27. napari_tmidas/_tests/test_widget.py +30 -0
  28. napari_tmidas/_tests/test_windows_basic.py +66 -0
  29. napari_tmidas/_ui_utils.py +57 -0
  30. napari_tmidas/_version.py +16 -3
  31. napari_tmidas/_widget.py +41 -4
  32. napari_tmidas/processing_functions/basic.py +557 -20
  33. napari_tmidas/processing_functions/careamics_env_manager.py +72 -99
  34. napari_tmidas/processing_functions/cellpose_env_manager.py +415 -112
  35. napari_tmidas/processing_functions/cellpose_segmentation.py +132 -191
  36. napari_tmidas/processing_functions/colocalization.py +513 -56
  37. napari_tmidas/processing_functions/grid_view_overlay.py +703 -0
  38. napari_tmidas/processing_functions/intensity_label_filter.py +422 -0
  39. napari_tmidas/processing_functions/regionprops_analysis.py +1280 -0
  40. napari_tmidas/processing_functions/sam2_env_manager.py +53 -69
  41. napari_tmidas/processing_functions/sam2_mp4.py +274 -195
  42. napari_tmidas/processing_functions/scipy_filters.py +403 -8
  43. napari_tmidas/processing_functions/skimage_filters.py +424 -212
  44. napari_tmidas/processing_functions/spotiflow_detection.py +949 -0
  45. napari_tmidas/processing_functions/spotiflow_env_manager.py +591 -0
  46. napari_tmidas/processing_functions/timepoint_merger.py +334 -86
  47. {napari_tmidas-0.2.2.dist-info → napari_tmidas-0.2.4.dist-info}/METADATA +70 -30
  48. napari_tmidas-0.2.4.dist-info/RECORD +63 -0
  49. napari_tmidas/_tests/__init__.py +0 -0
  50. napari_tmidas-0.2.2.dist-info/RECORD +0 -40
  51. {napari_tmidas-0.2.2.dist-info → napari_tmidas-0.2.4.dist-info}/WHEEL +0 -0
  52. {napari_tmidas-0.2.2.dist-info → napari_tmidas-0.2.4.dist-info}/entry_points.txt +0 -0
  53. {napari_tmidas-0.2.2.dist-info → napari_tmidas-0.2.4.dist-info}/licenses/LICENSE +0 -0
  54. {napari_tmidas-0.2.2.dist-info → napari_tmidas-0.2.4.dist-info}/top_level.txt +0 -0
@@ -9,32 +9,99 @@ The plugin supports both 2D (YX) and 3D (TYX/ZYX) data.
9
9
  import contextlib
10
10
  import os
11
11
  import sys
12
+ from pathlib import Path
12
13
 
13
14
  import numpy as np
14
- import requests
15
- import torch
16
- from magicgui import magicgui
17
- from napari.layers import Labels
18
- from napari.viewer import Viewer
19
- from qtpy.QtCore import Qt
20
- from qtpy.QtWidgets import (
21
- QCheckBox,
22
- QFileDialog,
23
- QHBoxLayout,
24
- QHeaderView,
25
- QLabel,
26
- QMessageBox,
27
- QPushButton,
28
- QScrollArea,
29
- QTableWidget,
30
- QTableWidgetItem,
31
- QVBoxLayout,
32
- QWidget,
33
- )
34
- from skimage.io import imread
35
- from skimage.transform import resize
36
- from tifffile import imwrite
37
15
 
16
+ # Lazy imports for optional heavy dependencies
17
+ try:
18
+ import requests
19
+
20
+ _HAS_REQUESTS = True
21
+ except ImportError:
22
+ requests = None
23
+ _HAS_REQUESTS = False
24
+
25
+ try:
26
+ import torch
27
+
28
+ _HAS_TORCH = True
29
+ except ImportError:
30
+ torch = None
31
+ _HAS_TORCH = False
32
+
33
+ try:
34
+ from magicgui import magicgui
35
+
36
+ _HAS_MAGICGUI = True
37
+ except ImportError:
38
+ # Create stub decorator
39
+ def magicgui(*args, **kwargs):
40
+ def decorator(func):
41
+ return func
42
+
43
+ if len(args) == 1 and callable(args[0]) and not kwargs:
44
+ return args[0]
45
+ return decorator
46
+
47
+ _HAS_MAGICGUI = False
48
+
49
+ try:
50
+ from napari.layers import Labels
51
+ from napari.viewer import Viewer
52
+
53
+ _HAS_NAPARI = True
54
+ except ImportError:
55
+ Labels = None
56
+ Viewer = None
57
+ _HAS_NAPARI = False
58
+
59
+ try:
60
+ from qtpy.QtCore import Qt
61
+ from qtpy.QtWidgets import (
62
+ QCheckBox,
63
+ QHBoxLayout,
64
+ QHeaderView,
65
+ QLabel,
66
+ QMessageBox,
67
+ QPushButton,
68
+ QScrollArea,
69
+ QTableWidget,
70
+ QTableWidgetItem,
71
+ QVBoxLayout,
72
+ QWidget,
73
+ )
74
+
75
+ _HAS_QTPY = True
76
+ except ImportError:
77
+ Qt = None
78
+ QCheckBox = QHBoxLayout = QHeaderView = QLabel = QMessageBox = None
79
+ QPushButton = QScrollArea = QTableWidget = QTableWidgetItem = None
80
+ QVBoxLayout = QWidget = None
81
+ _HAS_QTPY = False
82
+
83
+ try:
84
+ from skimage.io import imread
85
+ from skimage.transform import resize
86
+
87
+ _HAS_SKIMAGE = True
88
+ except ImportError:
89
+ imread = None
90
+ resize = None
91
+ _HAS_SKIMAGE = False
92
+
93
+ try:
94
+ from tifffile import imwrite
95
+
96
+ _HAS_TIFFFILE = True
97
+ except ImportError:
98
+ imwrite = None
99
+ _HAS_TIFFFILE = False
100
+
101
+ from napari_tmidas._file_selector import (
102
+ load_image_file as load_any_image,
103
+ )
104
+ from napari_tmidas._ui_utils import add_browse_button_to_folder_field
38
105
  from napari_tmidas.processing_functions.sam2_mp4 import tif_to_mp4
39
106
 
40
107
  sam2_paths = [
@@ -98,6 +165,7 @@ class BatchCropAnything:
98
165
  self.image_layer = None
99
166
  self.label_layer = None
100
167
  self.label_table_widget = None
168
+ self.shapes_layer = None
101
169
 
102
170
  # State tracking
103
171
  self.selected_labels = set()
@@ -106,6 +174,9 @@ class BatchCropAnything:
106
174
  # Segmentation parameters
107
175
  self.sensitivity = 50 # Default sensitivity (0-100 scale)
108
176
 
177
+ # Prompt mode: 'point' or 'box'
178
+ self.prompt_mode = "point"
179
+
109
180
  # Initialize the SAM2 model
110
181
  self._initialize_sam2()
111
182
 
@@ -131,17 +202,45 @@ class BatchCropAnything:
131
202
 
132
203
  try:
133
204
  # import torch
205
+ print("DEBUG: Starting SAM2 initialization...")
134
206
 
135
207
  self.device = get_device()
208
+ print(f"DEBUG: Device set to {self.device}")
136
209
 
137
210
  # Download checkpoint if needed
138
211
  checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt"
139
212
  checkpoint_path = download_checkpoint(
140
213
  checkpoint_url, "/opt/sam2/checkpoints/"
141
214
  )
215
+ print(f"DEBUG: Checkpoint path: {checkpoint_path}")
216
+
217
+ # Use relative config path for SAM2's Hydra config system
142
218
  model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
219
+ print(f"DEBUG: Model config: {model_cfg}")
220
+
221
+ # Verify the actual config file exists in the SAM2 installation
222
+ sam2_base_path = None
223
+ for path in sam2_paths:
224
+ if path and os.path.exists(path):
225
+ sam2_base_path = path
226
+ break
227
+
228
+ if sam2_base_path is not None:
229
+ full_config_path = os.path.join(
230
+ sam2_base_path, "sam2", model_cfg
231
+ )
232
+ if not os.path.exists(full_config_path):
233
+ raise FileNotFoundError(
234
+ f"SAM2 config file not found at: {full_config_path}"
235
+ )
236
+ print(f"DEBUG: Verified config exists at: {full_config_path}")
237
+ else:
238
+ print(
239
+ "DEBUG: Warning - could not verify config file exists, but proceeding with relative path"
240
+ )
143
241
 
144
242
  if self.use_3d:
243
+ print("DEBUG: Initializing SAM2 Video Predictor...")
145
244
  from sam2.build_sam import build_sam2_video_predictor
146
245
 
147
246
  self.predictor = build_sam2_video_predictor(
@@ -150,7 +249,9 @@ class BatchCropAnything:
150
249
  self.viewer.status = (
151
250
  f"Initialized SAM2 Video Predictor on {self.device}"
152
251
  )
252
+ print("DEBUG: SAM2 Video Predictor initialized successfully")
153
253
  else:
254
+ print("DEBUG: Initializing SAM2 Image Predictor...")
154
255
  from sam2.build_sam import build_sam2
155
256
  from sam2.sam2_image_predictor import SAM2ImagePredictor
156
257
 
@@ -160,6 +261,7 @@ class BatchCropAnything:
160
261
  self.viewer.status = (
161
262
  f"Initialized SAM2 Image Predictor on {self.device}"
162
263
  )
264
+ print("DEBUG: SAM2 Image Predictor initialized successfully")
163
265
 
164
266
  except (
165
267
  ImportError,
@@ -167,37 +269,79 @@ class BatchCropAnything:
167
269
  ValueError,
168
270
  FileNotFoundError,
169
271
  requests.RequestException,
272
+ AttributeError,
273
+ ModuleNotFoundError,
170
274
  ) as e:
171
275
  import traceback
172
276
 
173
- self.viewer.status = f"Error initializing SAM2: {str(e)}"
277
+ error_msg = f"SAM2 initialization failed: {str(e)}"
278
+ error_type = type(e).__name__
279
+ self.viewer.status = (
280
+ f"{error_msg} - Images will load without segmentation"
281
+ )
174
282
  self.predictor = None
283
+ print(f"DEBUG: SAM2 Error ({error_type}): {error_msg}")
284
+ print("DEBUG: Full traceback:")
175
285
  print(traceback.format_exc())
286
+ print(
287
+ "DEBUG: Note: Images will still load, but automatic segmentation will not be available."
288
+ )
289
+
290
+ # Provide specific guidance based on error type
291
+ if isinstance(e, FileNotFoundError):
292
+ print(
293
+ "DEBUG: This appears to be a missing file issue. Check SAM2 installation and config paths."
294
+ )
295
+ elif isinstance(e, (ImportError, ModuleNotFoundError)):
296
+ print(
297
+ "DEBUG: This appears to be a SAM2 import issue. Check SAM2 installation."
298
+ )
299
+ elif isinstance(e, RuntimeError):
300
+ print(
301
+ "DEBUG: This appears to be a runtime issue, possibly GPU/CUDA related."
302
+ )
303
+ else:
304
+ print(f"DEBUG: Unexpected error type: {error_type}")
176
305
 
177
306
  def load_images(self, folder_path: str):
178
307
  """Load images from the specified folder path."""
308
+ print(f"DEBUG: Loading images from folder: {folder_path}")
179
309
  if not os.path.exists(folder_path):
180
310
  self.viewer.status = f"Folder not found: {folder_path}"
311
+ print(f"DEBUG: Folder does not exist: {folder_path}")
181
312
  return
182
313
 
183
314
  files = os.listdir(folder_path)
184
- self.images = [
185
- os.path.join(folder_path, file)
186
- for file in files
187
- if file.lower().endswith(".tif")
188
- or file.lower().endswith(".tiff")
189
- and "label" not in file.lower()
190
- and "cropped" not in file.lower()
191
- and "_labels_" not in file.lower()
192
- and "_cropped_" not in file.lower()
193
- ]
315
+ print(f"DEBUG: Found {len(files)} files in folder")
316
+ self.images = []
317
+ for file in files:
318
+ full = os.path.join(folder_path, file)
319
+ low = file.lower()
320
+ if (
321
+ low.endswith((".tif", ".tiff"))
322
+ or (os.path.isdir(full) and low.endswith(".zarr"))
323
+ ) and (
324
+ "label" not in low
325
+ and "_labels_" not in low
326
+ and "sam2"
327
+ not in low # Exclude any SAM2-related files (including output from this tool)
328
+ ):
329
+ self.images.append(full)
330
+ print(f"DEBUG: Added image: {file}")
331
+ else:
332
+ print(
333
+ f"DEBUG: Excluded file: {file} (reason: filtering criteria)"
334
+ )
194
335
 
195
336
  if not self.images:
196
337
  self.viewer.status = "No compatible images found in the folder."
338
+ print("DEBUG: No compatible images found")
197
339
  return
198
340
 
341
+ print(f"DEBUG: Total compatible images found: {len(self.images)}")
199
342
  self.viewer.status = f"Found {len(self.images)} .tif images."
200
343
  self.current_index = 0
344
+ print(f"DEBUG: About to load first image: {self.images[0]}")
201
345
  self._load_current_image()
202
346
 
203
347
  def next_image(self):
@@ -250,25 +394,69 @@ class BatchCropAnything:
250
394
 
251
395
  def _load_current_image(self):
252
396
  """Load the current image and generate segmentation."""
397
+ print("DEBUG: _load_current_image called")
253
398
  if not self.images:
254
399
  self.viewer.status = "No images to process."
255
- return
256
-
257
- if self.predictor is None:
258
- self.viewer.status = (
259
- "SAM2 model not initialized. Cannot segment images."
260
- )
400
+ print("DEBUG: No images to process")
261
401
  return
262
402
 
263
403
  image_path = self.images[self.current_index]
264
- self.viewer.status = f"Processing {os.path.basename(image_path)}"
404
+ print(f"DEBUG: Loading image at path: {image_path}")
405
+
406
+ if self.predictor is None:
407
+ self.viewer.status = f"Loading {os.path.basename(image_path)} (SAM2 model not initialized - no segmentation will be available)"
408
+ print("DEBUG: SAM2 predictor is None")
409
+ else:
410
+ self.viewer.status = f"Processing {os.path.basename(image_path)}"
411
+ print("DEBUG: SAM2 predictor is available")
265
412
 
266
413
  try:
414
+ print("DEBUG: About to clear viewer layers")
267
415
  # Clear existing layers
268
416
  self.viewer.layers.clear()
417
+ print("DEBUG: Viewer layers cleared")
269
418
 
419
+ print("DEBUG: About to load image file")
270
420
  # Load and process image
271
- self.original_image = imread(image_path)
421
+ if image_path.lower().endswith(".zarr") or (
422
+ os.path.isdir(image_path)
423
+ and image_path.lower().endswith(".zarr")
424
+ ):
425
+ print("DEBUG: Loading Zarr file")
426
+ data = load_any_image(image_path)
427
+ # If multiple layers returned, take first image layer
428
+ if isinstance(data, list):
429
+ img = None
430
+ for entry in data:
431
+ if isinstance(entry, tuple) and len(entry) == 3:
432
+ d, _kwargs, layer_type = entry
433
+ if layer_type == "image":
434
+ img = d
435
+ break
436
+ elif isinstance(entry, tuple) and len(entry) == 2:
437
+ d, _kwargs = entry
438
+ img = d
439
+ break
440
+ else:
441
+ img = entry
442
+ break
443
+ if img is None:
444
+ raise ValueError("No image layer found in Zarr store")
445
+ else:
446
+ img = data
447
+
448
+ # Compute dask arrays to numpy if needed
449
+ if hasattr(img, "compute"):
450
+ img = img.compute()
451
+
452
+ self.original_image = img
453
+ else:
454
+ print("DEBUG: Loading TIFF file")
455
+ self.original_image = imread(image_path)
456
+
457
+ print(
458
+ f"DEBUG: Image loaded, shape: {self.original_image.shape}, dtype: {self.original_image.dtype}"
459
+ )
272
460
 
273
461
  # For 3D/4D data, determine dimensions
274
462
  if self.use_3d and len(self.original_image.shape) >= 3:
@@ -284,10 +472,12 @@ class BatchCropAnything:
284
472
 
285
473
  if time_dim_idx == 0: # TZYX format
286
474
  # Keep as is, T is already the first dimension
475
+ print("DEBUG: Adding 4D image (TZYX format) to viewer")
287
476
  self.image_layer = self.viewer.add_image(
288
477
  self.original_image,
289
478
  name=f"Image ({os.path.basename(image_path)})",
290
479
  )
480
+ print(f"DEBUG: Added image layer: {self.image_layer}")
291
481
  # Store time dimension info
292
482
  self.time_dim_size = self.original_image.shape[0]
293
483
  self.has_z_dim = True
@@ -309,19 +499,23 @@ class BatchCropAnything:
309
499
  transposed_image # Replace with transposed version
310
500
  )
311
501
 
502
+ print("DEBUG: Adding transposed 4D image to viewer")
312
503
  self.image_layer = self.viewer.add_image(
313
504
  self.original_image,
314
505
  name=f"Image ({os.path.basename(image_path)})",
315
506
  )
507
+ print(f"DEBUG: Added image layer: {self.image_layer}")
316
508
  # Store time dimension info
317
509
  self.time_dim_size = self.original_image.shape[0]
318
510
  self.has_z_dim = True
319
511
  else:
320
512
  # No time dimension found, treat as ZYX
513
+ print("DEBUG: Adding 4D image (ZYX format) to viewer")
321
514
  self.image_layer = self.viewer.add_image(
322
515
  self.original_image,
323
516
  name=f"Image ({os.path.basename(image_path)})",
324
517
  )
518
+ print(f"DEBUG: Added image layer: {self.image_layer}")
325
519
  self.time_dim_size = 1
326
520
  self.has_z_dim = True
327
521
  elif (
@@ -330,30 +524,37 @@ class BatchCropAnything:
330
524
  # Check if first dimension is likely time (> 4, < 400)
331
525
  if 4 < self.original_image.shape[0] < 400:
332
526
  # Likely TYX format
527
+ print("DEBUG: Adding 3D image (TYX format) to viewer")
333
528
  self.image_layer = self.viewer.add_image(
334
529
  self.original_image,
335
530
  name=f"Image ({os.path.basename(image_path)})",
336
531
  )
532
+ print(f"DEBUG: Added image layer: {self.image_layer}")
337
533
  self.time_dim_size = self.original_image.shape[0]
338
534
  self.has_z_dim = False
339
535
  else:
340
536
  # Likely ZYX format or another 3D format
537
+ print("DEBUG: Adding 3D image (ZYX format) to viewer")
341
538
  self.image_layer = self.viewer.add_image(
342
539
  self.original_image,
343
540
  name=f"Image ({os.path.basename(image_path)})",
344
541
  )
542
+ print(f"DEBUG: Added image layer: {self.image_layer}")
345
543
  self.time_dim_size = 1
346
544
  self.has_z_dim = True
347
545
  else:
348
546
  # Should not reach here with use_3d=True, but just in case
547
+ print("DEBUG: Adding 3D image (fallback) to viewer")
349
548
  self.image_layer = self.viewer.add_image(
350
549
  self.original_image,
351
550
  name=f"Image ({os.path.basename(image_path)})",
352
551
  )
552
+ print(f"DEBUG: Added image layer: {self.image_layer}")
353
553
  self.time_dim_size = 1
354
554
  self.has_z_dim = False
355
555
  else:
356
556
  # Handle 2D data as before
557
+ print("DEBUG: Processing 2D image")
357
558
  if self.original_image.dtype != np.uint8:
358
559
  image_for_display = (
359
560
  self.original_image
@@ -364,18 +565,42 @@ class BatchCropAnything:
364
565
  image_for_display = self.original_image
365
566
 
366
567
  # Add image to viewer
568
+ print("DEBUG: Adding 2D image to viewer")
367
569
  self.image_layer = self.viewer.add_image(
368
570
  image_for_display,
369
571
  name=f"Image ({os.path.basename(image_path)})",
370
572
  )
573
+ print(f"DEBUG: Added image layer: {self.image_layer}")
574
+
575
+ # Generate segmentation only if predictor is available
576
+ if self.predictor is not None:
577
+ print("DEBUG: About to generate segmentation")
578
+ self._generate_segmentation(self.original_image, image_path)
579
+ print("DEBUG: Segmentation generation completed")
580
+ else:
581
+ print("DEBUG: Creating empty segmentation (no predictor)")
582
+ # Create empty segmentation when predictor is not available
583
+ if self.use_3d:
584
+ shape = self.original_image.shape
585
+ else:
586
+ shape = self.original_image.shape[:2]
587
+
588
+ self.segmentation_result = np.zeros(shape, dtype=np.uint32)
589
+ self.label_layer = self.viewer.add_labels(
590
+ self.segmentation_result,
591
+ name="No Segmentation (SAM2 not available)",
592
+ )
593
+ print(f"DEBUG: Added empty label layer: {self.label_layer}")
371
594
 
372
- # Generate segmentation
373
- self._generate_segmentation(self.original_image, image_path)
595
+ print("DEBUG: _load_current_image completed successfully")
374
596
 
375
597
  except (FileNotFoundError, ValueError, TypeError, OSError) as e:
376
598
  import traceback
377
599
 
378
- self.viewer.status = f"Error processing image: {str(e)}"
600
+ error_msg = f"Error processing image: {str(e)}"
601
+ self.viewer.status = error_msg
602
+ print(f"DEBUG: Exception in _load_current_image: {error_msg}")
603
+ print("DEBUG: Full traceback:")
379
604
  traceback.print_exc()
380
605
 
381
606
  # Create empty segmentation in case of error
@@ -392,6 +617,7 @@ class BatchCropAnything:
392
617
  self.label_layer = self.viewer.add_labels(
393
618
  self.segmentation_result, name="Error: No Segmentation"
394
619
  )
620
+ print(f"DEBUG: Added error label layer: {self.label_layer}")
395
621
 
396
622
  def _generate_segmentation(self, image, image_path: str):
397
623
  """Generate segmentation for the current image using SAM2."""
@@ -447,7 +673,8 @@ class BatchCropAnything:
447
673
  traceback.print_exc()
448
674
 
449
675
  def _generate_2d_segmentation(self, confidence_threshold):
450
- """Generate 2D segmentation using SAM2 Image Predictor."""
676
+ """Generate initial 2D segmentation - start with empty labels for interactive mode."""
677
+ device_type = "cuda" if self.device.type == "cuda" else "cpu"
451
678
  # Ensure image is in the correct format for SAM2
452
679
  image = self.current_image_for_segmentation
453
680
 
@@ -469,9 +696,7 @@ class BatchCropAnything:
469
696
  (new_height, new_width),
470
697
  anti_aliasing=True,
471
698
  preserve_range=True,
472
- ).astype(
473
- np.float32
474
- ) # Convert to float32
699
+ ).astype(np.float32)
475
700
 
476
701
  self.current_scale_factor = scale_factor
477
702
  else:
@@ -497,73 +722,54 @@ class BatchCropAnything:
497
722
  if resized_image.max() > 1.0:
498
723
  resized_image = resized_image / 255.0
499
724
 
500
- # Set SAM2 prediction parameters based on sensitivity
501
- with torch.inference_mode(), torch.autocast(
502
- "cuda", dtype=torch.float32
503
- ):
504
- # Set the image in the predictor
505
- self.predictor.set_image(resized_image)
725
+ # Store the prepared image for later use
726
+ self.prepared_sam2_image = resized_image
506
727
 
507
- # Use automatic points generation with confidence threshold
508
- masks, scores, _ = self.predictor.predict(
509
- point_coords=None,
510
- point_labels=None,
511
- box=None,
512
- multimask_output=True,
513
- )
728
+ # Initialize empty segmentation result
729
+ self.segmentation_result = np.zeros(orig_shape, dtype=np.uint32)
730
+ self.label_info = {}
514
731
 
515
- # Filter masks by confidence threshold
516
- valid_masks = scores > confidence_threshold
517
- masks = masks[valid_masks]
518
- scores = scores[valid_masks]
519
-
520
- # Convert masks to label image
521
- labels = np.zeros(resized_image.shape[:2], dtype=np.uint32)
522
- self.label_info = {} # Reset label info
523
-
524
- for i, mask in enumerate(masks):
525
- label_id = i + 1 # Start label IDs from 1
526
- labels[mask] = label_id
527
-
528
- # Calculate label information
529
- area = np.sum(mask)
530
- y_indices, x_indices = np.where(mask)
531
- center_y = np.mean(y_indices) if len(y_indices) > 0 else 0
532
- center_x = np.mean(x_indices) if len(x_indices) > 0 else 0
533
-
534
- # Store label info
535
- self.label_info[label_id] = {
536
- "area": area,
537
- "center_y": center_y,
538
- "center_x": center_x,
539
- "score": float(scores[i]),
540
- }
541
-
542
- # Handle upscaling if needed
543
- if self.current_scale_factor < 1.0:
544
- labels = resize(
545
- labels,
546
- orig_shape,
547
- order=0, # Nearest neighbor interpolation
548
- preserve_range=True,
549
- anti_aliasing=False,
550
- ).astype(np.uint32)
551
-
552
- # Sort labels by area (largest first)
553
- self.label_info = dict(
554
- sorted(
555
- self.label_info.items(),
556
- key=lambda item: item[1]["area"],
557
- reverse=True,
558
- )
732
+ # Initialize tracking for interactive segmentation
733
+ self.current_points = []
734
+ self.current_labels = []
735
+ self.current_obj_id = 1
736
+ self.next_obj_id = 1
737
+
738
+ # Initialize object tracking dictionaries
739
+ self.obj_points = {}
740
+ self.obj_labels = {}
741
+
742
+ # Reset SAM2-specific tracking dictionaries for 2D mode
743
+ self.sam2_points_by_obj = {}
744
+ self.sam2_labels_by_obj = {}
745
+ self._sam2_next_obj_id = 1
746
+ print(
747
+ "DEBUG: Reset _sam2_next_obj_id to 1 in _generate_2d_segmentation"
559
748
  )
560
749
 
561
- # Save segmentation result
562
- self.segmentation_result = labels
750
+ # Set the image in the predictor for later use (2D mode only)
751
+ device_type = "cuda" if self.device.type == "cuda" else "cpu"
752
+ if hasattr(self.predictor, "set_image"):
753
+ with (
754
+ torch.inference_mode(),
755
+ torch.autocast(device_type, dtype=torch.float32),
756
+ ):
757
+ self.predictor.set_image(resized_image)
758
+ else:
759
+ print(
760
+ "DEBUG: Skipping set_image - predictor doesn't support it (likely VideoPredictor)"
761
+ )
563
762
 
564
763
  # Update the label layer
565
764
  self._update_label_layer()
566
765
 
766
+ # Show instructions
767
+ self.viewer.status = (
768
+ "2D Mode: Click on the image to add objects. Use Shift+click for negative points to refine. "
769
+ "Click existing objects to select them for cropping. "
770
+ "Note: For stacks, interactive segmentation only works in 2D view mode."
771
+ )
772
+
567
773
  def _generate_3d_segmentation(self, confidence_threshold, image_path):
568
774
  """
569
775
  Initialize 3D segmentation using SAM2 Video Predictor.
@@ -584,9 +790,7 @@ class BatchCropAnything:
584
790
  import tempfile
585
791
 
586
792
  temp_dir = tempfile.gettempdir()
587
- mp4_path = os.path.join(
588
- temp_dir, f"temp_volume_{os.path.basename(image_path)}.mp4"
589
- )
793
+ mp4_path = None
590
794
 
591
795
  # If we need to save a modified version for MP4 conversion
592
796
  need_temp_tif = False
@@ -616,31 +820,72 @@ class BatchCropAnything:
616
820
  imwrite(temp_tif_path, projected_volume)
617
821
  need_temp_tif = True
618
822
 
619
- # Convert the projected TIF to MP4
620
- self.viewer.status = (
621
- "Converting projected 3D volume to MP4 format for SAM2..."
622
- )
623
- mp4_path = tif_to_mp4(temp_tif_path)
823
+ # Check if MP4 already exists
824
+ expected_mp4 = str(Path(temp_tif_path).with_suffix(".mp4"))
825
+ if os.path.exists(expected_mp4):
826
+ self.viewer.status = (
827
+ f"Using existing MP4: {os.path.basename(expected_mp4)}"
828
+ )
829
+ print(
830
+ f"DEBUG: MP4 already exists, skipping conversion: {expected_mp4}"
831
+ )
832
+ mp4_path = expected_mp4
833
+ else:
834
+ # Convert the projected TIF to MP4
835
+ self.viewer.status = "Converting projected 3D volume to MP4 format for SAM2..."
836
+ mp4_path = tif_to_mp4(temp_tif_path)
624
837
  else:
625
- # Convert original volume to video format for SAM2
626
- self.viewer.status = (
627
- "Converting 3D volume to MP4 format for SAM2..."
628
- )
629
- mp4_path = tif_to_mp4(image_path)
838
+ # Check if MP4 already exists for the original image
839
+ expected_mp4 = str(Path(image_path).with_suffix(".mp4"))
840
+ if os.path.exists(expected_mp4):
841
+ self.viewer.status = (
842
+ f"Using existing MP4: {os.path.basename(expected_mp4)}"
843
+ )
844
+ print(
845
+ f"DEBUG: MP4 already exists, skipping conversion: {expected_mp4}"
846
+ )
847
+ mp4_path = expected_mp4
848
+ else:
849
+ # Convert original volume to video format for SAM2
850
+ self.viewer.status = (
851
+ "Converting 3D volume to MP4 format for SAM2..."
852
+ )
853
+ mp4_path = tif_to_mp4(image_path)
630
854
 
631
855
  # Initialize SAM2 state with the video
632
856
  self.viewer.status = "Initializing SAM2 Video Predictor..."
633
- with torch.inference_mode(), torch.autocast(
634
- "cuda", dtype=torch.bfloat16
635
- ):
636
- self._sam2_state = self.predictor.init_state(mp4_path)
857
+ try:
858
+ device_type = "cuda" if self.device.type == "cuda" else "cpu"
859
+ with (
860
+ torch.inference_mode(),
861
+ torch.autocast(device_type, dtype=torch.float32),
862
+ ):
863
+ self._sam2_state = self.predictor.init_state(mp4_path)
864
+ except (
865
+ RuntimeError,
866
+ ValueError,
867
+ TypeError,
868
+ torch.cuda.OutOfMemoryError,
869
+ ) as e:
870
+ self.viewer.status = (
871
+ f"Error initializing SAM2 video predictor: {str(e)}"
872
+ )
873
+ print(f"SAM2 video predictor initialization failed: {e}")
874
+ return
637
875
 
638
876
  # Store needed state for 3D processing
639
877
  self._sam2_next_obj_id = 1
878
+ print(
879
+ "DEBUG: Reset _sam2_next_obj_id to 1 in _generate_3d_segmentation"
880
+ )
640
881
  self._sam2_prompts = (
641
882
  {}
642
883
  ) # Store prompts for each object (points, labels, box)
643
884
 
885
+ # Reset SAM2-specific tracking dictionaries for 3D mode
886
+ self.sam2_points_by_obj = {}
887
+ self.sam2_labels_by_obj = {}
888
+
644
889
  # Update the label layer with empty segmentation
645
890
  self._update_label_layer()
646
891
 
@@ -648,8 +893,10 @@ class BatchCropAnything:
648
893
  if self.label_layer is not None and hasattr(
649
894
  self.label_layer, "mouse_drag_callbacks"
650
895
  ):
896
+ # Safely remove all existing callbacks
651
897
  for callback in list(self.label_layer.mouse_drag_callbacks):
652
- self.label_layer.mouse_drag_callbacks.remove(callback)
898
+ with contextlib.suppress(ValueError):
899
+ self.label_layer.mouse_drag_callbacks.remove(callback)
653
900
 
654
901
  # Add 3D-specific click handler
655
902
  self.label_layer.mouse_drag_callbacks.append(
@@ -673,8 +920,8 @@ class BatchCropAnything:
673
920
 
674
921
  # Show instructions
675
922
  self.viewer.status = (
676
- "3D Mode active: Navigate to the first frame where object appears, then click. "
677
- "Use Shift+click for negative points (to remove areas). "
923
+ "3D Mode active: IMPORTANT - Navigate to the FIRST SLICE where object appears (using slider), "
924
+ "then click on object in 2D view (not 3D view). Use Shift+click for negative points. "
678
925
  "Segmentation will be propagated to all frames automatically."
679
926
  )
680
927
 
@@ -728,6 +975,9 @@ class BatchCropAnything:
728
975
  # Create new object for positive points on background
729
976
  ann_obj_id = self._sam2_next_obj_id
730
977
  if point_label > 0 and label_id == 0:
978
+ print(
979
+ f"DEBUG: Incrementing _sam2_next_obj_id from {self._sam2_next_obj_id} to {self._sam2_next_obj_id + 1}"
980
+ )
731
981
  self._sam2_next_obj_id += 1
732
982
 
733
983
  # Find or create points layer for this object
@@ -915,8 +1165,10 @@ class BatchCropAnything:
915
1165
  # Try to perform SAM2 propagation with error handling
916
1166
  try:
917
1167
  # Use torch.inference_mode() and torch.autocast to ensure consistent dtypes
918
- with torch.inference_mode(), torch.autocast(
919
- "cuda", dtype=torch.float32
1168
+ device_type = "cuda" if self.device.type == "cuda" else "cpu"
1169
+ with (
1170
+ torch.inference_mode(),
1171
+ torch.autocast(device_type, dtype=torch.float32),
920
1172
  ):
921
1173
  # Attempt to run SAM2 propagation - this will iterate through all frames
922
1174
  for (
@@ -1012,7 +1264,11 @@ class BatchCropAnything:
1012
1264
  time.sleep(2)
1013
1265
  for layer in list(self.viewer.layers):
1014
1266
  if "Propagation Progress" in layer.name:
1015
- self.viewer.layers.remove(layer)
1267
+ # Clean up callbacks before removing the layer to prevent cleanup issues
1268
+ if hasattr(layer, "mouse_drag_callbacks"):
1269
+ layer.mouse_drag_callbacks.clear()
1270
+ with contextlib.suppress(ValueError):
1271
+ self.viewer.layers.remove(layer)
1016
1272
 
1017
1273
  threading.Thread(target=remove_progress).start()
1018
1274
 
@@ -1035,6 +1291,7 @@ class BatchCropAnything:
1035
1291
  Given a 3D coordinate (x, y, z), run SAM2 video predictor to segment the object at that point,
1036
1292
  update the segmentation result and label layer.
1037
1293
  """
1294
+ device_type = "cuda" if self.device.type == "cuda" else "cpu"
1038
1295
  if not hasattr(self, "_sam2_state") or self._sam2_state is None:
1039
1296
  self.viewer.status = "SAM2 3D state not initialized."
1040
1297
  return
@@ -1048,8 +1305,9 @@ class BatchCropAnything:
1048
1305
  point_coords = np.array([[x, y, z]])
1049
1306
  point_labels = np.array([1]) # 1 = foreground
1050
1307
 
1051
- with torch.inference_mode(), torch.autocast(
1052
- "cuda", dtype=torch.bfloat16
1308
+ with (
1309
+ torch.inference_mode(),
1310
+ torch.autocast(device_type, dtype=torch.float32),
1053
1311
  ):
1054
1312
  masks, scores, _ = self.predictor.predict(
1055
1313
  state=self._sam2_state,
@@ -1103,7 +1361,11 @@ class BatchCropAnything:
1103
1361
  # Remove existing label layer if it exists
1104
1362
  for layer in list(self.viewer.layers):
1105
1363
  if isinstance(layer, Labels) and "Segmentation" in layer.name:
1106
- self.viewer.layers.remove(layer)
1364
+ # Clean up callbacks before removing the layer to prevent cleanup issues
1365
+ if hasattr(layer, "mouse_drag_callbacks"):
1366
+ layer.mouse_drag_callbacks.clear()
1367
+ with contextlib.suppress(ValueError):
1368
+ self.viewer.layers.remove(layer)
1107
1369
 
1108
1370
  # Add label layer to viewer
1109
1371
  self.label_layer = self.viewer.add_labels(
@@ -1112,10 +1374,36 @@ class BatchCropAnything:
1112
1374
  opacity=0.7,
1113
1375
  )
1114
1376
 
1115
- # Create points layer for interaction if it doesn't exist
1377
+ # Connect click handler to the label layer for selection and deletion
1378
+ if hasattr(self.label_layer, "mouse_drag_callbacks"):
1379
+ # Clear existing callbacks to avoid duplicates
1380
+ self.label_layer.mouse_drag_callbacks.clear()
1381
+ # Add our click handler
1382
+ self.label_layer.mouse_drag_callbacks.append(
1383
+ self._on_label_clicked
1384
+ )
1385
+
1386
+ # Create or update interaction layers based on mode
1387
+ if self.prompt_mode == "point":
1388
+ self._ensure_points_layer()
1389
+ self._remove_shapes_layer()
1390
+ else: # box mode
1391
+ self._ensure_shapes_layer()
1392
+ self._remove_points_layer()
1393
+
1394
+ # Update status
1395
+ n_labels = len(np.unique(self.segmentation_result)) - (
1396
+ 1 if 0 in np.unique(self.segmentation_result) else 0
1397
+ )
1398
+ self.viewer.status = f"Loaded image {self.current_index + 1}/{len(self.images)} - Found {n_labels} segments"
1399
+
1400
+ def _ensure_points_layer(self):
1401
+ """Ensure points layer exists and is properly configured."""
1116
1402
  points_layer = None
1117
1403
  for layer in list(self.viewer.layers):
1118
- if "Points" in layer.name:
1404
+ if (
1405
+ "Points" in layer.name and "Object" not in layer.name
1406
+ ): # Main points layer
1119
1407
  points_layer = layer
1120
1408
  break
1121
1409
 
@@ -1131,141 +1419,193 @@ class BatchCropAnything:
1131
1419
  opacity=0.8,
1132
1420
  )
1133
1421
 
1134
- with contextlib.suppress(AttributeError, ValueError):
1135
- points_layer.mouse_drag_callbacks.remove(
1136
- self._on_points_clicked
1137
- )
1422
+ # Connect points layer mouse click event
1423
+ if hasattr(points_layer, "mouse_drag_callbacks"):
1424
+ points_layer.mouse_drag_callbacks.clear()
1138
1425
  points_layer.mouse_drag_callbacks.append(
1139
1426
  self._on_points_clicked
1140
1427
  )
1141
1428
 
1142
- # Connect points layer mouse click event
1143
- points_layer.mouse_drag_callbacks.append(self._on_points_clicked)
1144
-
1145
1429
  # Make the points layer active to encourage interaction with it
1146
1430
  self.viewer.layers.selection.active = points_layer
1147
1431
 
1148
- # Update status
1149
- n_labels = len(np.unique(self.segmentation_result)) - (
1150
- 1 if 0 in np.unique(self.segmentation_result) else 0
1151
- )
1152
- self.viewer.status = f"Loaded image {self.current_index + 1}/{len(self.images)} - Found {n_labels} segments"
1432
+ def _ensure_shapes_layer(self):
1433
+ """Ensure shapes layer exists and is properly configured."""
1434
+ shapes_layer = None
1435
+ for layer in list(self.viewer.layers):
1436
+ if "Rectangles" in layer.name:
1437
+ shapes_layer = layer
1438
+ break
1153
1439
 
1154
- def _on_points_clicked(self, layer, event):
1155
- """Handle clicks on the points layer for adding/removing points."""
1156
- try:
1157
- # Only process clicks, not drags
1158
- if event.type != "mouse_press":
1440
+ if shapes_layer is None:
1441
+ # Initialize an empty shapes layer
1442
+ shapes_layer = self.viewer.add_shapes(
1443
+ None,
1444
+ shape_type="rectangle",
1445
+ edge_width=3,
1446
+ edge_color="green",
1447
+ face_color="transparent",
1448
+ name="Rectangles (Draw to Segment)",
1449
+ )
1450
+
1451
+ # Store reference
1452
+ self.shapes_layer = shapes_layer
1453
+
1454
+ # Initialize processing flag to prevent re-entry
1455
+ if not hasattr(self, "_processing_rectangle"):
1456
+ self._processing_rectangle = False
1457
+
1458
+ # Always ensure the event is connected (disconnect old ones first to avoid duplicates)
1459
+ # Remove any existing callbacks
1460
+ with contextlib.suppress(Exception):
1461
+ shapes_layer.events.data.disconnect()
1462
+
1463
+ # Connect shape added event
1464
+ @shapes_layer.events.data.connect
1465
+ def on_shape_added(event):
1466
+ print(
1467
+ f"DEBUG: Shape event triggered! Shapes: {len(shapes_layer.data)}, Processing: {self._processing_rectangle}"
1468
+ )
1469
+
1470
+ # Ignore if we're already processing or if there are no shapes
1471
+ if self._processing_rectangle:
1472
+ print("DEBUG: Already processing a rectangle, ignoring event")
1159
1473
  return
1160
1474
 
1161
- # Get coordinates of mouse click
1162
- coords = np.round(event.position).astype(int)
1475
+ if len(shapes_layer.data) == 0:
1476
+ print("DEBUG: No shapes present, ignoring event")
1477
+ return
1163
1478
 
1164
- # Check if Shift is pressed for negative points
1165
- is_negative = "Shift" in event.modifiers
1166
- point_label = -1 if is_negative else 1
1479
+ # Only process if we have exactly 1 shape (newly drawn)
1480
+ if len(shapes_layer.data) == 1:
1481
+ print("DEBUG: New shape detected, processing...")
1482
+ # Set flag to prevent re-entry
1483
+ self._processing_rectangle = True
1484
+ try:
1485
+ # Get the shape
1486
+ self._on_rectangle_added(shapes_layer.data[-1])
1487
+ finally:
1488
+ # Always reset flag
1489
+ self._processing_rectangle = False
1490
+ else:
1491
+ print(
1492
+ f"DEBUG: Multiple shapes present ({len(shapes_layer.data)}), skipping"
1493
+ )
1167
1494
 
1168
- # Handle 2D vs 3D coordinates
1169
- if self.use_3d:
1170
- if len(coords) == 3:
1171
- t, y, x = map(int, coords)
1172
- elif len(coords) == 2:
1173
- t = int(self.viewer.dims.current_step[0])
1174
- y, x = map(int, coords)
1175
- else:
1176
- self.viewer.status = (
1177
- f"Unexpected coordinate dimensions: {coords}"
1178
- )
1179
- return
1495
+ # Make the shapes layer active
1496
+ self.viewer.layers.selection.active = shapes_layer
1180
1497
 
1181
- # Add point to the layer immediately for visual feedback
1182
- new_point = np.array([[t, y, x]])
1183
- if len(layer.data) == 0:
1184
- layer.data = new_point
1185
- else:
1186
- layer.data = np.vstack([layer.data, new_point])
1498
+ def _remove_points_layer(self):
1499
+ """Remove points layer when not in point mode."""
1500
+ for layer in list(self.viewer.layers):
1501
+ if "Points" in layer.name and "Object" not in layer.name:
1502
+ if hasattr(layer, "mouse_drag_callbacks"):
1503
+ layer.mouse_drag_callbacks.clear()
1504
+ with contextlib.suppress(ValueError):
1505
+ self.viewer.layers.remove(layer)
1187
1506
 
1188
- # Update point colors
1189
- colors = layer.face_color
1190
- if isinstance(colors, list):
1191
- colors.append("red" if is_negative else "green")
1192
- else:
1193
- n_points = len(layer.data)
1194
- colors = ["green"] * (n_points - 1)
1195
- colors.append("red" if is_negative else "green")
1196
- layer.face_color = colors
1507
+ def _remove_shapes_layer(self):
1508
+ """Remove shapes layer when not in box mode."""
1509
+ for layer in list(self.viewer.layers):
1510
+ if "Rectangles" in layer.name:
1511
+ with contextlib.suppress(ValueError):
1512
+ self.viewer.layers.remove(layer)
1513
+ self.shapes_layer = None
1197
1514
 
1198
- # Get the object ID
1199
- # If clicking on existing segmentation with negative point
1200
- label_id = self.segmentation_result[t, y, x]
1201
- if is_negative and label_id > 0:
1202
- obj_id = label_id
1515
+ def _on_rectangle_added(self, rectangle_coords):
1516
+ """Handle rectangle selection for segmentation."""
1517
+ print("DEBUG: _on_rectangle_added called!")
1518
+ device_type = "cuda" if self.device.type == "cuda" else "cpu"
1519
+ try:
1520
+ # Rectangle coords are in the form of a 4x2 or 4x3 array (corners)
1521
+ # Convert to bounding box format [x_min, y_min, x_max, y_max]
1522
+
1523
+ # Debug info
1524
+ print(f"DEBUG: Rectangle coords: {rectangle_coords}")
1525
+ print(f"DEBUG: Rectangle coords shape: {rectangle_coords.shape}")
1526
+ print(f"DEBUG: use_3d flag: {self.use_3d}")
1527
+ print(
1528
+ f"DEBUG: Has predictor: {hasattr(self, 'predictor') and self.predictor is not None}"
1529
+ )
1530
+ if hasattr(self, "predictor") and self.predictor is not None:
1531
+ print(
1532
+ f"DEBUG: Predictor type: {type(self.predictor).__name__}"
1533
+ )
1534
+ else:
1535
+ print("DEBUG: No predictor available!")
1536
+ self.viewer.status = "Error: Predictor not initialized"
1537
+ return
1538
+
1539
+ # Check if we're in 3D mode (use the flag, not coordinate shape)
1540
+ # In 3D mode, even when drawing on a 2D slice, we get (4, 2) coords
1541
+ # but we need to treat it as 3D with propagation
1542
+ if (
1543
+ self.use_3d
1544
+ and len(rectangle_coords.shape) == 2
1545
+ and rectangle_coords.shape[0] == 4
1546
+ ):
1547
+ print("DEBUG: Processing as 3D rectangle (will propagate)")
1548
+
1549
+ # Get current frame/slice
1550
+ t = int(self.viewer.dims.current_step[0])
1551
+ print(f"DEBUG: Current frame/slice: {t}")
1552
+
1553
+ # Get Y and X bounds from 2D coordinates
1554
+ if rectangle_coords.shape[1] == 3:
1555
+ # If we somehow got 3D coords (T/Z, Y, X)
1556
+ y_coords = rectangle_coords[:, 1]
1557
+ x_coords = rectangle_coords[:, 2]
1558
+ elif rectangle_coords.shape[1] == 2:
1559
+ # More common: 2D coords (Y, X) when drawing on a slice
1560
+ y_coords = rectangle_coords[:, 0]
1561
+ x_coords = rectangle_coords[:, 1]
1203
1562
  else:
1204
- # For new objects or negative on background
1205
- if not hasattr(self, "_sam2_next_obj_id"):
1206
- self._sam2_next_obj_id = 1
1207
- obj_id = self._sam2_next_obj_id
1208
- if point_label > 0 and label_id == 0:
1209
- self._sam2_next_obj_id += 1
1563
+ print(
1564
+ f"DEBUG: Unexpected coordinate dimensions: {rectangle_coords.shape[1]}"
1565
+ )
1566
+ self.viewer.status = "Error: Unexpected rectangle format"
1567
+ return
1210
1568
 
1211
- # Store point information
1212
- if not hasattr(self, "points_data"):
1213
- self.points_data = {}
1214
- self.points_labels = {}
1569
+ y_min, y_max = int(min(y_coords)), int(max(y_coords))
1570
+ x_min, x_max = int(min(x_coords)), int(max(x_coords))
1215
1571
 
1216
- if obj_id not in self.points_data:
1217
- self.points_data[obj_id] = []
1218
- self.points_labels[obj_id] = []
1572
+ box = np.array([x_min, y_min, x_max, y_max], dtype=np.float32)
1573
+ print(f"DEBUG: Box coordinates: {box}")
1219
1574
 
1220
- self.points_data[obj_id].append(
1221
- [x, y]
1222
- ) # Note: SAM2 expects [x,y] format
1223
- self.points_labels[obj_id].append(point_label)
1575
+ # Use SAM2 with box prompt - use _sam2_next_obj_id for 3D mode
1576
+ if not hasattr(self, "_sam2_next_obj_id"):
1577
+ self._sam2_next_obj_id = 1
1578
+ obj_id = self._sam2_next_obj_id
1579
+ self._sam2_next_obj_id += 1
1580
+ print(
1581
+ f"DEBUG: Box mode - using object ID {obj_id}, next will be {self._sam2_next_obj_id}"
1582
+ )
1224
1583
 
1225
- # Perform segmentation
1584
+ # Store box for this object
1585
+ if not hasattr(self, "obj_boxes"):
1586
+ self.obj_boxes = {}
1587
+ self.obj_boxes[obj_id] = box
1588
+
1589
+ # Perform segmentation with 3D propagation
1226
1590
  if (
1227
1591
  hasattr(self, "_sam2_state")
1228
1592
  and self._sam2_state is not None
1229
1593
  ):
1230
- # Prepare points
1231
- points = np.array(
1232
- self.points_data[obj_id], dtype=np.float32
1233
- )
1234
- labels = np.array(
1235
- self.points_labels[obj_id], dtype=np.int32
1594
+ self.viewer.status = (
1595
+ f"Segmenting object {obj_id} with box at frame {t}..."
1236
1596
  )
1597
+ print(f"DEBUG: Starting segmentation for object {obj_id}")
1237
1598
 
1238
- # Create progress layer for visual feedback
1239
- progress_layer = None
1240
- for existing_layer in self.viewer.layers:
1241
- if "Propagation Progress" in existing_layer.name:
1242
- progress_layer = existing_layer
1243
- break
1244
-
1245
- if progress_layer is None:
1246
- progress_data = np.zeros_like(self.segmentation_result)
1247
- progress_layer = self.viewer.add_image(
1248
- progress_data,
1249
- name="Propagation Progress",
1250
- colormap="magma",
1251
- opacity=0.5,
1252
- visible=True,
1253
- )
1254
-
1255
- # First update the current frame immediately
1256
- self.viewer.status = f"Processing object at frame {t}..."
1257
-
1258
- # Run SAM2 on current frame
1259
1599
  _, out_obj_ids, out_mask_logits = (
1260
1600
  self.predictor.add_new_points_or_box(
1261
1601
  inference_state=self._sam2_state,
1262
1602
  frame_idx=t,
1263
1603
  obj_id=obj_id,
1264
- points=points,
1265
- labels=labels,
1604
+ box=box,
1266
1605
  )
1267
1606
  )
1268
1607
 
1608
+ print("DEBUG: Segmentation complete, processing mask")
1269
1609
  # Update current frame
1270
1610
  mask = (out_mask_logits[0] > 0.0).cpu().numpy()
1271
1611
  if mask.ndim > 2:
@@ -1283,21 +1623,380 @@ class BatchCropAnything:
1283
1623
  anti_aliasing=False,
1284
1624
  ).astype(bool)
1285
1625
 
1286
- # Update segmentation for this frame
1287
- if point_label < 0:
1288
- # For negative points, only remove from this object
1289
- self.segmentation_result[t][
1290
- (self.segmentation_result[t] == obj_id) & mask
1291
- ] = 0
1292
- else:
1293
- # For positive points, only replace background
1294
- self.segmentation_result[t][
1295
- mask & (self.segmentation_result[t] == 0)
1296
- ] = obj_id
1626
+ # Update segmentation
1627
+ self.segmentation_result[t][
1628
+ mask & (self.segmentation_result[t] == 0)
1629
+ ] = obj_id
1297
1630
 
1298
- # Update progress layer for this frame
1299
- progress_data = progress_layer.data
1300
- progress_data[t] = (
1631
+ print(f"DEBUG: Starting propagation for object {obj_id}")
1632
+ # Propagate to all frames
1633
+ self._propagate_mask_for_current_object(obj_id, t)
1634
+
1635
+ # Update UI
1636
+ print("DEBUG: Updating label layer")
1637
+ self._update_label_layer()
1638
+ if (
1639
+ hasattr(self, "label_table_widget")
1640
+ and self.label_table_widget is not None
1641
+ ):
1642
+ self._populate_label_table(self.label_table_widget)
1643
+
1644
+ self.viewer.status = (
1645
+ f"Segmented and propagated object {obj_id} from box"
1646
+ )
1647
+ print("DEBUG: Rectangle processing complete!")
1648
+
1649
+ # Keep the rectangle visible after processing
1650
+ # Users can manually delete it if needed
1651
+ # if self.shapes_layer is not None:
1652
+ # self.shapes_layer.data = []
1653
+ else:
1654
+ print("DEBUG: _sam2_state not available")
1655
+ self.viewer.status = (
1656
+ "Error: 3D segmentation state not initialized"
1657
+ )
1658
+
1659
+ elif (
1660
+ not self.use_3d
1661
+ and len(rectangle_coords.shape) == 2
1662
+ and rectangle_coords.shape[1] == 2
1663
+ ):
1664
+ # 2D case: rectangle_coords shape is (4, 2) for Y, X
1665
+ if rectangle_coords.shape[0] == 4:
1666
+ # Get Y and X bounds
1667
+ y_coords = rectangle_coords[:, 0]
1668
+ x_coords = rectangle_coords[:, 1]
1669
+ y_min, y_max = int(min(y_coords)), int(max(y_coords))
1670
+ x_min, x_max = int(min(x_coords)), int(max(x_coords))
1671
+
1672
+ box = np.array(
1673
+ [x_min, y_min, x_max, y_max], dtype=np.float32
1674
+ )
1675
+
1676
+ # Use SAM2 with box prompt - use next_obj_id for 2D mode
1677
+ if not hasattr(self, "next_obj_id"):
1678
+ self.next_obj_id = 1
1679
+ obj_id = self.next_obj_id
1680
+ self.next_obj_id += 1
1681
+ print(
1682
+ f"DEBUG: 2D Box mode - using object ID {obj_id}, next will be {self.next_obj_id}"
1683
+ )
1684
+
1685
+ # Store box for this object
1686
+ if not hasattr(self, "obj_boxes"):
1687
+ self.obj_boxes = {}
1688
+ self.obj_boxes[obj_id] = box
1689
+
1690
+ # Perform segmentation
1691
+ if (
1692
+ hasattr(self, "predictor")
1693
+ and self.predictor is not None
1694
+ ):
1695
+ # Make sure image is loaded
1696
+ if self.current_image_for_segmentation is None:
1697
+ self.viewer.status = (
1698
+ "No image loaded for segmentation"
1699
+ )
1700
+ return
1701
+
1702
+ # Prepare image for SAM2
1703
+ image = self.current_image_for_segmentation
1704
+ if len(image.shape) == 2:
1705
+ image = np.stack([image] * 3, axis=-1)
1706
+ elif len(image.shape) == 3 and image.shape[2] == 1:
1707
+ image = np.concatenate([image] * 3, axis=2)
1708
+ elif len(image.shape) == 3 and image.shape[2] > 3:
1709
+ image = image[:, :, :3]
1710
+
1711
+ if image.dtype != np.uint8:
1712
+ image = (image / np.max(image) * 255).astype(
1713
+ np.uint8
1714
+ )
1715
+
1716
+ # Set the image in the predictor (only for ImagePredictor, not VideoPredictor)
1717
+ if hasattr(self.predictor, "set_image"):
1718
+ self.predictor.set_image(image)
1719
+ else:
1720
+ self.viewer.status = "Error: Rectangle mode requires Image Predictor (2D mode)"
1721
+ return
1722
+
1723
+ self.viewer.status = (
1724
+ f"Segmenting object {obj_id} with box..."
1725
+ )
1726
+
1727
+ with (
1728
+ torch.inference_mode(),
1729
+ torch.autocast(device_type),
1730
+ ):
1731
+ masks, scores, _ = self.predictor.predict(
1732
+ box=box,
1733
+ multimask_output=False,
1734
+ )
1735
+
1736
+ # Get the mask
1737
+ if len(masks) > 0:
1738
+ best_mask = masks[0]
1739
+
1740
+ # Resize if needed
1741
+ if (
1742
+ best_mask.shape
1743
+ != self.segmentation_result.shape
1744
+ ):
1745
+ from skimage.transform import resize
1746
+
1747
+ best_mask = resize(
1748
+ best_mask.astype(float),
1749
+ self.segmentation_result.shape,
1750
+ order=0,
1751
+ preserve_range=True,
1752
+ anti_aliasing=False,
1753
+ ).astype(bool)
1754
+
1755
+ # Apply mask (only overwrite background)
1756
+ mask_condition = np.logical_and(
1757
+ best_mask, (self.segmentation_result == 0)
1758
+ )
1759
+ self.segmentation_result[mask_condition] = (
1760
+ obj_id
1761
+ )
1762
+
1763
+ # Update label info
1764
+ area = np.sum(
1765
+ self.segmentation_result == obj_id
1766
+ )
1767
+ y_indices, x_indices = np.where(
1768
+ self.segmentation_result == obj_id
1769
+ )
1770
+ center_y = (
1771
+ np.mean(y_indices)
1772
+ if len(y_indices) > 0
1773
+ else 0
1774
+ )
1775
+ center_x = (
1776
+ np.mean(x_indices)
1777
+ if len(x_indices) > 0
1778
+ else 0
1779
+ )
1780
+
1781
+ self.label_info[obj_id] = {
1782
+ "area": area,
1783
+ "center_y": center_y,
1784
+ "center_x": center_x,
1785
+ "score": float(scores[0]),
1786
+ }
1787
+
1788
+ self.viewer.status = (
1789
+ f"Segmented object {obj_id} from box"
1790
+ )
1791
+ else:
1792
+ self.viewer.status = "No valid mask produced"
1793
+
1794
+ # Update the UI
1795
+ self._update_label_layer()
1796
+ if (
1797
+ hasattr(self, "label_table_widget")
1798
+ and self.label_table_widget is not None
1799
+ ):
1800
+ self._populate_label_table(self.label_table_widget)
1801
+
1802
+ # Keep the rectangle visible after processing
1803
+ # Users can manually delete it if needed
1804
+ # if self.shapes_layer is not None:
1805
+ # self.shapes_layer.data = []
1806
+ else:
1807
+ # Unexpected shape dimensions
1808
+ print(
1809
+ f"DEBUG: Unexpected rectangle shape: {rectangle_coords.shape}"
1810
+ )
1811
+ self.viewer.status = f"Error: Unexpected rectangle dimensions {rectangle_coords.shape}. Expected (4,2) for 2D or (4,3) for 3D."
1812
+
1813
+ except (
1814
+ IndexError,
1815
+ KeyError,
1816
+ ValueError,
1817
+ RuntimeError,
1818
+ TypeError,
1819
+ ) as e:
1820
+ import traceback
1821
+
1822
+ self.viewer.status = f"Error in rectangle handling: {str(e)}"
1823
+ print("DEBUG: Exception in _on_rectangle_added:")
1824
+ traceback.print_exc()
1825
+
1826
+ def _on_points_clicked(self, layer, event):
1827
+ """Handle clicks on the points layer for adding/removing points."""
1828
+ device_type = "cuda" if self.device.type == "cuda" else "cpu"
1829
+ try:
1830
+ # Only process clicks, not drags
1831
+ if event.type != "mouse_press":
1832
+ return
1833
+
1834
+ # Check if segmentation result exists
1835
+ if self.segmentation_result is None:
1836
+ self.viewer.status = (
1837
+ "Segmentation not ready. Please wait for image to load."
1838
+ )
1839
+ return
1840
+
1841
+ # Get coordinates of mouse click
1842
+ coords = np.round(event.position).astype(int)
1843
+
1844
+ # Check if Shift is pressed for negative points
1845
+ is_negative = "Shift" in event.modifiers
1846
+ point_label = -1 if is_negative else 1
1847
+
1848
+ # Handle 2D vs 3D coordinates
1849
+ if self.use_3d:
1850
+ if len(coords) == 3:
1851
+ t, y, x = map(int, coords)
1852
+ elif len(coords) == 2:
1853
+ t = int(self.viewer.dims.current_step[0])
1854
+ y, x = map(int, coords)
1855
+ else:
1856
+ self.viewer.status = (
1857
+ f"Unexpected coordinate dimensions: {coords}"
1858
+ )
1859
+ return
1860
+
1861
+ # Add point to the layer immediately for visual feedback
1862
+ new_point = np.array([[t, y, x]])
1863
+ if len(layer.data) == 0:
1864
+ layer.data = new_point
1865
+ else:
1866
+ layer.data = np.vstack([layer.data, new_point])
1867
+
1868
+ # Update point colors
1869
+ colors = layer.face_color
1870
+ if isinstance(colors, list):
1871
+ colors.append("red" if is_negative else "green")
1872
+ else:
1873
+ n_points = len(layer.data)
1874
+ colors = ["green"] * (n_points - 1)
1875
+ colors.append("red" if is_negative else "green")
1876
+ layer.face_color = colors
1877
+
1878
+ # Validate coordinates are within segmentation bounds
1879
+ if (
1880
+ t < 0
1881
+ or t >= self.segmentation_result.shape[0]
1882
+ or y < 0
1883
+ or y >= self.segmentation_result.shape[1]
1884
+ or x < 0
1885
+ or x >= self.segmentation_result.shape[2]
1886
+ ):
1887
+ self.viewer.status = (
1888
+ f"Click at ({t}, {y}, {x}) is out of bounds for "
1889
+ f"segmentation shape {self.segmentation_result.shape}. "
1890
+ f"Please click within the image bounds."
1891
+ )
1892
+ # Remove the invalid point that was just added
1893
+ if len(layer.data) > 0:
1894
+ layer.data = layer.data[:-1]
1895
+ return
1896
+
1897
+ # Get the object ID
1898
+ # If clicking on existing segmentation with negative point
1899
+ label_id = self.segmentation_result[t, y, x]
1900
+ if is_negative and label_id > 0:
1901
+ obj_id = label_id
1902
+ else:
1903
+ # For new objects or negative on background
1904
+ if not hasattr(self, "_sam2_next_obj_id"):
1905
+ self._sam2_next_obj_id = 1
1906
+ obj_id = self._sam2_next_obj_id
1907
+ if point_label > 0 and label_id == 0:
1908
+ self._sam2_next_obj_id += 1
1909
+
1910
+ # Store point information
1911
+ if not hasattr(self, "points_data"):
1912
+ self.points_data = {}
1913
+ self.points_labels = {}
1914
+
1915
+ if obj_id not in self.points_data:
1916
+ self.points_data[obj_id] = []
1917
+ self.points_labels[obj_id] = []
1918
+
1919
+ self.points_data[obj_id].append(
1920
+ [x, y]
1921
+ ) # Note: SAM2 expects [x,y] format
1922
+ self.points_labels[obj_id].append(point_label)
1923
+
1924
+ # Perform segmentation
1925
+ if (
1926
+ hasattr(self, "_sam2_state")
1927
+ and self._sam2_state is not None
1928
+ ):
1929
+ # Prepare points
1930
+ points = np.array(
1931
+ self.points_data[obj_id], dtype=np.float32
1932
+ )
1933
+ labels = np.array(
1934
+ self.points_labels[obj_id], dtype=np.int32
1935
+ )
1936
+
1937
+ # Create progress layer for visual feedback
1938
+ progress_layer = None
1939
+ for existing_layer in self.viewer.layers:
1940
+ if "Propagation Progress" in existing_layer.name:
1941
+ progress_layer = existing_layer
1942
+ break
1943
+
1944
+ if progress_layer is None:
1945
+ progress_data = np.zeros_like(self.segmentation_result)
1946
+ progress_layer = self.viewer.add_image(
1947
+ progress_data,
1948
+ name="Propagation Progress",
1949
+ colormap="magma",
1950
+ opacity=0.5,
1951
+ visible=True,
1952
+ )
1953
+
1954
+ # First update the current frame immediately
1955
+ self.viewer.status = f"Processing object at frame {t}..."
1956
+
1957
+ # Run SAM2 on current frame
1958
+ _, out_obj_ids, out_mask_logits = (
1959
+ self.predictor.add_new_points_or_box(
1960
+ inference_state=self._sam2_state,
1961
+ frame_idx=t,
1962
+ obj_id=obj_id,
1963
+ points=points,
1964
+ labels=labels,
1965
+ )
1966
+ )
1967
+
1968
+ # Update current frame
1969
+ mask = (out_mask_logits[0] > 0.0).cpu().numpy()
1970
+ if mask.ndim > 2:
1971
+ mask = mask.squeeze()
1972
+
1973
+ # Resize if needed
1974
+ if mask.shape != self.segmentation_result[t].shape:
1975
+ from skimage.transform import resize
1976
+
1977
+ mask = resize(
1978
+ mask.astype(float),
1979
+ self.segmentation_result[t].shape,
1980
+ order=0,
1981
+ preserve_range=True,
1982
+ anti_aliasing=False,
1983
+ ).astype(bool)
1984
+
1985
+ # Update segmentation for this frame
1986
+ if point_label < 0:
1987
+ # For negative points, only remove from this object
1988
+ self.segmentation_result[t][
1989
+ (self.segmentation_result[t] == obj_id) & mask
1990
+ ] = 0
1991
+ else:
1992
+ # For positive points, only replace background
1993
+ self.segmentation_result[t][
1994
+ mask & (self.segmentation_result[t] == 0)
1995
+ ] = obj_id
1996
+
1997
+ # Update progress layer for this frame
1998
+ progress_data = progress_layer.data
1999
+ progress_data[t] = (
1301
2000
  mask.astype(float) * 0.5
1302
2001
  ) # Highlight current frame
1303
2002
  progress_layer.data = progress_data
@@ -1398,7 +2097,11 @@ class BatchCropAnything:
1398
2097
  time.sleep(2)
1399
2098
  for layer in list(self.viewer.layers):
1400
2099
  if "Propagation Progress" in layer.name:
1401
- self.viewer.layers.remove(layer)
2100
+ # Clean up callbacks before removing the layer to prevent cleanup issues
2101
+ if hasattr(layer, "mouse_drag_callbacks"):
2102
+ layer.mouse_drag_callbacks.clear()
2103
+ with contextlib.suppress(ValueError):
2104
+ self.viewer.layers.remove(layer)
1402
2105
 
1403
2106
  threading.Thread(target=remove_progress).start()
1404
2107
 
@@ -1439,6 +2142,23 @@ class BatchCropAnything:
1439
2142
  colors.append("red" if is_negative else "green")
1440
2143
  layer.face_color = colors
1441
2144
 
2145
+ # Validate coordinates are within segmentation bounds
2146
+ if (
2147
+ y < 0
2148
+ or y >= self.segmentation_result.shape[0]
2149
+ or x < 0
2150
+ or x >= self.segmentation_result.shape[1]
2151
+ ):
2152
+ self.viewer.status = (
2153
+ f"Click at ({y}, {x}) is out of bounds for "
2154
+ f"segmentation shape {self.segmentation_result.shape}. "
2155
+ f"Please click within the image bounds."
2156
+ )
2157
+ # Remove the invalid point that was just added
2158
+ if len(layer.data) > 0:
2159
+ layer.data = layer.data[:-1]
2160
+ return
2161
+
1442
2162
  # Get object ID
1443
2163
  label_id = self.segmentation_result[y, x]
1444
2164
  if is_negative and label_id > 0:
@@ -1483,8 +2203,14 @@ class BatchCropAnything:
1483
2203
  if image.dtype != np.uint8:
1484
2204
  image = (image / np.max(image) * 255).astype(np.uint8)
1485
2205
 
1486
- # Set the image in the predictor
1487
- self.predictor.set_image(image)
2206
+ # Set the image in the predictor (only for ImagePredictor, not VideoPredictor)
2207
+ if hasattr(self.predictor, "set_image"):
2208
+ self.predictor.set_image(image)
2209
+ else:
2210
+ self.viewer.status = (
2211
+ "Error: Point mode in 2D requires Image Predictor"
2212
+ )
2213
+ return
1488
2214
 
1489
2215
  # Use only points for current object
1490
2216
  points = np.array(
@@ -1494,7 +2220,7 @@ class BatchCropAnything:
1494
2220
 
1495
2221
  self.viewer.status = f"Segmenting object {obj_id} with {len(points)} points..."
1496
2222
 
1497
- with torch.inference_mode(), torch.autocast("cuda"):
2223
+ with torch.inference_mode(), torch.autocast(device_type):
1498
2224
  masks, scores, _ = self.predictor.predict(
1499
2225
  point_coords=points,
1500
2226
  point_labels=labels,
@@ -1583,16 +2309,23 @@ class BatchCropAnything:
1583
2309
  def _on_label_clicked(self, layer, event):
1584
2310
  """Handle label selection and user prompts on mouse click."""
1585
2311
  try:
1586
- # Only process clicks, not drags
2312
+ # Only process mouse press events
1587
2313
  if event.type != "mouse_press":
1588
2314
  return
1589
2315
 
2316
+ # Only handle left mouse button
2317
+ if event.button != 1:
2318
+ return
2319
+
1590
2320
  # Get coordinates of mouse click
1591
2321
  coords = np.round(event.position).astype(int)
1592
2322
 
1593
- # Check if Shift is pressed (negative point)
2323
+ # Check modifiers
1594
2324
  is_negative = "Shift" in event.modifiers
1595
- point_label = -1 if is_negative else 1
2325
+ is_control = (
2326
+ "Control" in event.modifiers or "Ctrl" in event.modifiers
2327
+ )
2328
+ # point_label = -1 if is_negative else 1
1596
2329
 
1597
2330
  # For 2D data
1598
2331
  if not self.use_3d:
@@ -1613,262 +2346,13 @@ class BatchCropAnything:
1613
2346
  # Get the label ID at the clicked position
1614
2347
  label_id = self.segmentation_result[y, x]
1615
2348
 
1616
- # Initialize a unique object ID for this click (if needed)
1617
- if not hasattr(self, "next_obj_id"):
1618
- # Start with highest existing ID + 1
1619
- if self.segmentation_result.max() > 0:
1620
- self.next_obj_id = (
1621
- int(self.segmentation_result.max()) + 1
1622
- )
1623
- else:
1624
- self.next_obj_id = 1
1625
-
1626
- # If clicking on background or using negative click, handle segmentation
1627
- if label_id == 0 or is_negative:
1628
- # Find or create points layer for the current object we're working on
1629
- current_obj_id = None
1630
-
1631
- # If negative point on existing label, use that label's ID
1632
- if is_negative and label_id > 0:
1633
- current_obj_id = label_id
1634
- # For positive clicks on background, create a new object
1635
- elif point_label > 0 and label_id == 0:
1636
- current_obj_id = self.next_obj_id
1637
- self.next_obj_id += 1
1638
- # For negative on background, try to find most recent object
1639
- elif point_label < 0 and label_id == 0:
1640
- # Use most recently created object if available
1641
- if hasattr(self, "obj_points") and self.obj_points:
1642
- current_obj_id = max(self.obj_points.keys())
1643
- else:
1644
- self.viewer.status = "No existing object to modify with negative point"
1645
- return
1646
-
1647
- if current_obj_id is None:
1648
- self.viewer.status = (
1649
- "Could not determine which object to modify"
1650
- )
1651
- return
1652
-
1653
- # Find or create points layer for this object
1654
- points_layer = None
1655
- for layer in list(self.viewer.layers):
1656
- if f"Points for Object {current_obj_id}" in layer.name:
1657
- points_layer = layer
1658
- break
1659
-
1660
- # Initialize object tracking if needed
1661
- if not hasattr(self, "obj_points"):
1662
- self.obj_points = {}
1663
- self.obj_labels = {}
1664
-
1665
- if current_obj_id not in self.obj_points:
1666
- self.obj_points[current_obj_id] = []
1667
- self.obj_labels[current_obj_id] = []
1668
-
1669
- # Create or update points layer for this object
1670
- if points_layer is None:
1671
- # First point for this object
1672
- points_layer = self.viewer.add_points(
1673
- np.array([[y, x]]),
1674
- name=f"Points for Object {current_obj_id}",
1675
- size=10,
1676
- face_color=["green" if point_label > 0 else "red"],
1677
- border_color="white",
1678
- border_width=1,
1679
- opacity=0.8,
1680
- )
1681
- with contextlib.suppress(AttributeError, ValueError):
1682
- points_layer.mouse_drag_callbacks.remove(
1683
- self._on_points_clicked
1684
- )
1685
- points_layer.mouse_drag_callbacks.append(
1686
- self._on_points_clicked
1687
- )
1688
-
1689
- self.obj_points[current_obj_id] = [[x, y]]
1690
- self.obj_labels[current_obj_id] = [point_label]
1691
- else:
1692
- # Add point to existing layer
1693
- current_points = points_layer.data
1694
- current_colors = points_layer.face_color
1695
-
1696
- # Add new point
1697
- new_points = np.vstack([current_points, [y, x]])
1698
- new_color = "green" if point_label > 0 else "red"
1699
-
1700
- # Update points layer
1701
- points_layer.data = new_points
1702
-
1703
- # Update colors
1704
- if isinstance(current_colors, list):
1705
- current_colors.append(new_color)
1706
- points_layer.face_color = current_colors
1707
- else:
1708
- # If it's an array, create a list of colors
1709
- colors = []
1710
- for i in range(len(new_points)):
1711
- if i < len(current_points):
1712
- colors.append(
1713
- "green" if point_label > 0 else "red"
1714
- )
1715
- else:
1716
- colors.append(new_color)
1717
- points_layer.face_color = colors
1718
-
1719
- # Update object tracking
1720
- self.obj_points[current_obj_id].append([x, y])
1721
- self.obj_labels[current_obj_id].append(point_label)
1722
-
1723
- # Now do the actual segmentation using SAM2
1724
- if (
1725
- hasattr(self, "predictor")
1726
- and self.predictor is not None
1727
- ):
1728
- try:
1729
- # Make sure image is loaded
1730
- if self.current_image_for_segmentation is None:
1731
- self.viewer.status = (
1732
- "No image loaded for segmentation"
1733
- )
1734
- return
1735
-
1736
- # Prepare image for SAM2
1737
- image = self.current_image_for_segmentation
1738
- if len(image.shape) == 2:
1739
- image = np.stack([image] * 3, axis=-1)
1740
- elif len(image.shape) == 3 and image.shape[2] == 1:
1741
- image = np.concatenate([image] * 3, axis=2)
1742
- elif len(image.shape) == 3 and image.shape[2] > 3:
1743
- image = image[:, :, :3]
1744
-
1745
- if image.dtype != np.uint8:
1746
- image = (image / np.max(image) * 255).astype(
1747
- np.uint8
1748
- )
1749
-
1750
- # Set the image in the predictor
1751
- self.predictor.set_image(image)
1752
-
1753
- # Only use the points for the current object being segmented
1754
- points = np.array(
1755
- self.obj_points[current_obj_id],
1756
- dtype=np.float32,
1757
- )
1758
- labels = np.array(
1759
- self.obj_labels[current_obj_id], dtype=np.int32
1760
- )
1761
-
1762
- self.viewer.status = f"Segmenting object {current_obj_id} with {len(points)} points..."
1763
-
1764
- with torch.inference_mode(), torch.autocast(
1765
- "cuda"
1766
- ):
1767
- masks, scores, _ = self.predictor.predict(
1768
- point_coords=points,
1769
- point_labels=labels,
1770
- multimask_output=True,
1771
- )
1772
-
1773
- # Get best mask
1774
- if len(masks) > 0:
1775
- best_mask = masks[0]
1776
-
1777
- # Update segmentation result
1778
- if (
1779
- best_mask.shape
1780
- != self.segmentation_result.shape
1781
- ):
1782
- from skimage.transform import resize
1783
-
1784
- best_mask = resize(
1785
- best_mask.astype(float),
1786
- self.segmentation_result.shape,
1787
- order=0,
1788
- preserve_range=True,
1789
- anti_aliasing=False,
1790
- ).astype(bool)
1791
-
1792
- # CRITICAL FIX: For negative points, only remove from this object's mask
1793
- # For positive points, add to this object's mask without removing other objects
1794
- if point_label < 0:
1795
- # Remove only from current object's mask
1796
- self.segmentation_result[
1797
- (
1798
- self.segmentation_result
1799
- == current_obj_id
1800
- )
1801
- & best_mask
1802
- ] = 0
1803
- else:
1804
- # Add to current object's mask without affecting other objects
1805
- # Only overwrite background (value 0)
1806
- self.segmentation_result[
1807
- best_mask
1808
- & (self.segmentation_result == 0)
1809
- ] = current_obj_id
1810
-
1811
- # Update label info
1812
- area = np.sum(
1813
- self.segmentation_result
1814
- == current_obj_id
1815
- )
1816
- y_indices, x_indices = np.where(
1817
- self.segmentation_result
1818
- == current_obj_id
1819
- )
1820
- center_y = (
1821
- np.mean(y_indices)
1822
- if len(y_indices) > 0
1823
- else 0
1824
- )
1825
- center_x = (
1826
- np.mean(x_indices)
1827
- if len(x_indices) > 0
1828
- else 0
1829
- )
1830
-
1831
- self.label_info[current_obj_id] = {
1832
- "area": area,
1833
- "center_y": center_y,
1834
- "center_x": center_x,
1835
- "score": float(scores[0]),
1836
- }
1837
-
1838
- self.viewer.status = (
1839
- f"Updated object {current_obj_id}"
1840
- )
1841
- else:
1842
- self.viewer.status = (
1843
- "No valid mask produced"
1844
- )
1845
-
1846
- # Update the UI
1847
- self._update_label_layer()
1848
- if (
1849
- hasattr(self, "label_table_widget")
1850
- and self.label_table_widget is not None
1851
- ):
1852
- self._populate_label_table(
1853
- self.label_table_widget
1854
- )
1855
-
1856
- except (
1857
- IndexError,
1858
- KeyError,
1859
- ValueError,
1860
- AttributeError,
1861
- TypeError,
1862
- ) as e:
1863
- import traceback
1864
-
1865
- self.viewer.status = (
1866
- f"Error in SAM2 processing: {str(e)}"
1867
- )
1868
- traceback.print_exc()
2349
+ # Handle Ctrl+Click to clear a single label
2350
+ if is_control and label_id > 0:
2351
+ self.clear_label_at_position(y, x)
2352
+ return
1869
2353
 
1870
- # If clicking on an existing label, toggle selection
1871
- elif label_id > 0:
2354
+ # If clicking on an existing label (and not using modifiers), toggle selection
2355
+ if label_id > 0 and not is_negative and not is_control:
1872
2356
  # Toggle the label selection
1873
2357
  if label_id in self.selected_labels:
1874
2358
  self.selected_labels.remove(label_id)
@@ -1880,8 +2364,14 @@ class BatchCropAnything:
1880
2364
  # Update table and preview
1881
2365
  self._update_label_table()
1882
2366
  self.preview_crop()
2367
+ return
2368
+
2369
+ # If clicking on background or using Shift (negative points), this should be handled by points layer
2370
+ # Don't process these clicks here to avoid conflicts
2371
+ if label_id == 0 or is_negative:
2372
+ return
1883
2373
 
1884
- # 3D case (handle differently)
2374
+ # 3D case
1885
2375
  else:
1886
2376
  if len(coords) == 3:
1887
2377
  t, y, x = map(int, coords)
@@ -1910,12 +2400,13 @@ class BatchCropAnything:
1910
2400
  # Get the label ID at the clicked position
1911
2401
  label_id = self.segmentation_result[t, y, x]
1912
2402
 
1913
- # If background or shift is pressed, handle in _on_3d_label_clicked
1914
- if label_id == 0 or is_negative:
1915
- # This will be handled by _on_3d_label_clicked already attached
1916
- pass
1917
- # If clicking on an existing label, handle selection
1918
- elif label_id > 0:
2403
+ # Handle Ctrl+Click to clear a single label
2404
+ if is_control and label_id > 0:
2405
+ self.clear_label_at_position_3d(t, y, x)
2406
+ return
2407
+
2408
+ # If clicking on an existing label and not using negative points, handle selection
2409
+ if label_id > 0 and not is_negative and not is_control:
1919
2410
  # Toggle the label selection
1920
2411
  if label_id in self.selected_labels:
1921
2412
  self.selected_labels.remove(label_id)
@@ -1926,9 +2417,12 @@ class BatchCropAnything:
1926
2417
 
1927
2418
  # Update table if it exists
1928
2419
  self._update_label_table()
1929
-
1930
- # Update preview after selection changes
1931
2420
  self.preview_crop()
2421
+ return
2422
+
2423
+ # For background clicks or negative points, let the 3D handler deal with it
2424
+ if label_id == 0 or is_negative:
2425
+ return
1932
2426
 
1933
2427
  except (
1934
2428
  IndexError,
@@ -1942,12 +2436,74 @@ class BatchCropAnything:
1942
2436
  self.viewer.status = f"Error in click handling: {str(e)}"
1943
2437
  traceback.print_exc()
1944
2438
 
2439
+ def _add_segmentation_point(self, x, y, event):
2440
+ """Add a point for segmentation."""
2441
+ is_negative = "Shift" in event.modifiers
2442
+
2443
+ # Initialize tracking if needed
2444
+ if not hasattr(self, "current_points"):
2445
+ self.current_points = []
2446
+ self.current_labels = []
2447
+ self.current_obj_id = 1
2448
+
2449
+ # Add point
2450
+ self.current_points.append([x, y])
2451
+ self.current_labels.append(0 if is_negative else 1)
2452
+
2453
+ # Run SAM2 prediction
2454
+ if self.predictor is not None:
2455
+ # Prepare image
2456
+ image = self._prepare_image_for_sam2()
2457
+
2458
+ # Set the image in the predictor (only for ImagePredictor, not VideoPredictor)
2459
+ if hasattr(self.predictor, "set_image"):
2460
+ self.predictor.set_image(image)
2461
+ else:
2462
+ self.viewer.status = (
2463
+ "Error: This operation requires Image Predictor (2D mode)"
2464
+ )
2465
+ return
2466
+
2467
+ # Predict
2468
+ device_type = "cuda" if self.device.type == "cuda" else "cpu"
2469
+ with torch.inference_mode(), torch.autocast(device_type):
2470
+ masks, scores, _ = self.predictor.predict(
2471
+ point_coords=np.array(
2472
+ self.current_points, dtype=np.float32
2473
+ ),
2474
+ point_labels=np.array(self.current_labels, dtype=np.int32),
2475
+ multimask_output=False,
2476
+ )
2477
+
2478
+ # Update segmentation
2479
+ if len(masks) > 0:
2480
+ mask = masks[0] > 0.5
2481
+ if self.current_scale_factor < 1.0:
2482
+ mask = resize(
2483
+ mask, self.segmentation_result.shape, order=0
2484
+ ).astype(bool)
2485
+
2486
+ # Update segmentation result
2487
+ self.segmentation_result[mask] = self.current_obj_id
2488
+
2489
+ # Move to next object if adding positive point
2490
+ if not is_negative:
2491
+ self.current_obj_id += 1
2492
+ self.current_points = []
2493
+ self.current_labels = []
2494
+
2495
+ self._update_label_layer()
2496
+
1945
2497
  def _add_point_marker(self, coords, label_type):
1946
2498
  """Add a visible marker for where the user clicked."""
1947
2499
  # Remove previous point markers
1948
2500
  for layer in list(self.viewer.layers):
1949
2501
  if "Point Prompt" in layer.name:
1950
- self.viewer.layers.remove(layer)
2502
+ # Clean up callbacks before removing the layer to prevent cleanup issues
2503
+ if hasattr(layer, "mouse_drag_callbacks"):
2504
+ layer.mouse_drag_callbacks.clear()
2505
+ with contextlib.suppress(ValueError):
2506
+ self.viewer.layers.remove(layer)
1951
2507
 
1952
2508
  # Create points layer
1953
2509
  color = (
@@ -2135,11 +2691,170 @@ class BatchCropAnything:
2135
2691
  self.viewer.status = f"Selected all {len(self.selected_labels)} labels"
2136
2692
 
2137
2693
  def clear_selection(self):
2138
- """Clear all selected labels."""
2694
+ """Clear all labels from the segmentation.
2695
+
2696
+ This removes all segmented objects from the label layer, resets all tracking data,
2697
+ and prepares the interface for new segmentations. Note: The method name is kept as
2698
+ 'clear_selection' for backwards compatibility, but it clears all labels, not just
2699
+ the selection.
2700
+ """
2701
+ if self.segmentation_result is None:
2702
+ self.viewer.status = "No segmentation available"
2703
+ return
2704
+
2705
+ # Get all unique label IDs (excluding background 0)
2706
+ unique_labels = np.unique(self.segmentation_result)
2707
+ label_ids = [label for label in unique_labels if label > 0]
2708
+
2709
+ if len(label_ids) == 0:
2710
+ self.viewer.status = "No labels to clear"
2711
+ return
2712
+
2713
+ # Clear the entire segmentation result
2714
+ self.segmentation_result[:] = 0
2715
+
2716
+ # Clear selected labels
2139
2717
  self.selected_labels = set()
2718
+
2719
+ # Clear label info
2720
+ self.label_info = {}
2721
+
2722
+ # Remove any object-specific point layers
2723
+ for layer in list(self.viewer.layers):
2724
+ if "Points for Object" in layer.name:
2725
+ # Clean up callbacks before removing the layer to prevent cleanup issues
2726
+ if hasattr(layer, "mouse_drag_callbacks"):
2727
+ layer.mouse_drag_callbacks.clear()
2728
+ with contextlib.suppress(ValueError):
2729
+ self.viewer.layers.remove(layer)
2730
+
2731
+ # Clean up object tracking data
2732
+ if hasattr(self, "obj_points"):
2733
+ self.obj_points = {}
2734
+ if hasattr(self, "obj_labels"):
2735
+ self.obj_labels = {}
2736
+ if hasattr(self, "points_data"):
2737
+ self.points_data = {}
2738
+ if hasattr(self, "points_labels"):
2739
+ self.points_labels = {}
2740
+
2741
+ # Reset object ID counters
2742
+ if hasattr(self, "next_obj_id"):
2743
+ self.next_obj_id = 1
2744
+ if hasattr(self, "_sam2_next_obj_id"):
2745
+ self._sam2_next_obj_id = 1
2746
+
2747
+ # Update UI
2748
+ self._update_label_layer()
2140
2749
  self._update_label_table()
2141
2750
  self.preview_crop()
2142
- self.viewer.status = "Cleared all selections"
2751
+
2752
+ self.viewer.status = (
2753
+ f"Cleared all {len(label_ids)} labels from segmentation"
2754
+ )
2755
+
2756
+ def clear_label_at_position(self, y, x):
2757
+ """Clear a single label at the specified 2D position."""
2758
+ if self.segmentation_result is None:
2759
+ self.viewer.status = "No segmentation available"
2760
+ return
2761
+
2762
+ label_id = self.segmentation_result[y, x]
2763
+ if label_id > 0:
2764
+ # Remove all pixels with this label ID
2765
+ self.segmentation_result[self.segmentation_result == label_id] = 0
2766
+
2767
+ # Remove from selected labels if it was selected
2768
+ self.selected_labels.discard(label_id)
2769
+
2770
+ # Remove from label info
2771
+ if label_id in self.label_info:
2772
+ del self.label_info[label_id]
2773
+
2774
+ # Remove any object-specific point layers for this label
2775
+ for layer in list(self.viewer.layers):
2776
+ if f"Points for Object {label_id}" in layer.name:
2777
+ # Clean up callbacks before removing the layer to prevent cleanup issues
2778
+ if hasattr(layer, "mouse_drag_callbacks"):
2779
+ layer.mouse_drag_callbacks.clear()
2780
+ with contextlib.suppress(ValueError):
2781
+ self.viewer.layers.remove(layer)
2782
+
2783
+ # Clean up object tracking data
2784
+ if hasattr(self, "obj_points") and label_id in self.obj_points:
2785
+ del self.obj_points[label_id]
2786
+ if hasattr(self, "obj_labels") and label_id in self.obj_labels:
2787
+ del self.obj_labels[label_id]
2788
+
2789
+ # Update UI
2790
+ self._update_label_layer()
2791
+ self._update_label_table()
2792
+ self.preview_crop()
2793
+
2794
+ self.viewer.status = f"Deleted label ID: {label_id}"
2795
+ else:
2796
+ self.viewer.status = "No label to delete at this position"
2797
+
2798
+ def clear_label_at_position_3d(self, t, y, x):
2799
+ """Clear a single label at the specified 3D position."""
2800
+ if self.segmentation_result is None:
2801
+ self.viewer.status = "No segmentation available"
2802
+ return
2803
+
2804
+ label_id = self.segmentation_result[t, y, x]
2805
+ if label_id > 0:
2806
+ # Remove all pixels with this label ID across all timeframes
2807
+ self.segmentation_result[self.segmentation_result == label_id] = 0
2808
+
2809
+ # Remove from selected labels if it was selected
2810
+ self.selected_labels.discard(label_id)
2811
+
2812
+ # Remove from label info
2813
+ if label_id in self.label_info:
2814
+ del self.label_info[label_id]
2815
+
2816
+ # Remove any object-specific point layers for this label
2817
+ for layer in list(self.viewer.layers):
2818
+ if f"Points for Object {label_id}" in layer.name:
2819
+ # Clean up callbacks before removing the layer to prevent cleanup issues
2820
+ if hasattr(layer, "mouse_drag_callbacks"):
2821
+ layer.mouse_drag_callbacks.clear()
2822
+ with contextlib.suppress(ValueError):
2823
+ self.viewer.layers.remove(layer)
2824
+
2825
+ # Clean up 3D object tracking data
2826
+ if (
2827
+ hasattr(self, "sam2_points_by_obj")
2828
+ and label_id in self.sam2_points_by_obj
2829
+ ):
2830
+ del self.sam2_points_by_obj[label_id]
2831
+ if (
2832
+ hasattr(self, "sam2_labels_by_obj")
2833
+ and label_id in self.sam2_labels_by_obj
2834
+ ):
2835
+ del self.sam2_labels_by_obj[label_id]
2836
+ if hasattr(self, "points_data") and label_id in self.points_data:
2837
+ del self.points_data[label_id]
2838
+ if (
2839
+ hasattr(self, "points_labels")
2840
+ and label_id in self.points_labels
2841
+ ):
2842
+ del self.points_labels[label_id]
2843
+
2844
+ # Update UI
2845
+ self._update_label_layer()
2846
+ if (
2847
+ hasattr(self, "label_table_widget")
2848
+ and self.label_table_widget is not None
2849
+ ):
2850
+ self._populate_label_table(self.label_table_widget)
2851
+ self.preview_crop()
2852
+
2853
+ self.viewer.status = (
2854
+ f"Deleted label ID: {label_id} from all timeframes"
2855
+ )
2856
+ else:
2857
+ self.viewer.status = "No label to delete at this position"
2143
2858
 
2144
2859
  def preview_crop(self, label_ids=None):
2145
2860
  """Preview the crop result with the selected label IDs."""
@@ -2159,7 +2874,11 @@ class BatchCropAnything:
2159
2874
  # Remove previous preview if exists
2160
2875
  for layer in list(self.viewer.layers):
2161
2876
  if "Preview" in layer.name:
2162
- self.viewer.layers.remove(layer)
2877
+ # Clean up callbacks before removing the layer to prevent cleanup issues
2878
+ if hasattr(layer, "mouse_drag_callbacks"):
2879
+ layer.mouse_drag_callbacks.clear()
2880
+ with contextlib.suppress(ValueError):
2881
+ self.viewer.layers.remove(layer)
2163
2882
 
2164
2883
  # Make sure the segmentation layer is active again
2165
2884
  if self.label_layer is not None:
@@ -2197,7 +2916,11 @@ class BatchCropAnything:
2197
2916
  # Remove previous preview if exists
2198
2917
  for layer in list(self.viewer.layers):
2199
2918
  if "Preview" in layer.name:
2200
- self.viewer.layers.remove(layer)
2919
+ # Clean up callbacks before removing the layer to prevent cleanup issues
2920
+ if hasattr(layer, "mouse_drag_callbacks"):
2921
+ layer.mouse_drag_callbacks.clear()
2922
+ with contextlib.suppress(ValueError):
2923
+ self.viewer.layers.remove(layer)
2201
2924
 
2202
2925
  # Add preview layer
2203
2926
  if label_ids:
@@ -2288,17 +3011,14 @@ class BatchCropAnything:
2288
3011
  # Save cropped image
2289
3012
  image_path = self.images[self.current_index]
2290
3013
  base_name, ext = os.path.splitext(image_path)
2291
- label_str = "_".join(
2292
- str(lid) for lid in sorted(self.selected_labels)
2293
- )
2294
- output_path = f"{base_name}_cropped_{label_str}.tif"
3014
+ output_path = f"{base_name}_sam2_cropped.tif"
2295
3015
 
2296
3016
  # Save using tifffile with explicit parameters for best compatibility
2297
3017
  imwrite(output_path, cropped_image, compression="zlib")
2298
3018
  self.viewer.status = f"Saved cropped image to {output_path}"
2299
3019
 
2300
3020
  # Save the label image with exact same dimensions as original
2301
- label_output_path = f"{base_name}_labels_{label_str}.tif"
3021
+ label_output_path = f"{base_name}_sam2_labels.tif"
2302
3022
  imwrite(label_output_path, label_image, compression="zlib")
2303
3023
  self.viewer.status += f"\nSaved label mask to {label_output_path}"
2304
3024
 
@@ -2312,6 +3032,27 @@ class BatchCropAnything:
2312
3032
  self.viewer.status = f"Error cropping image: {str(e)}"
2313
3033
  return False
2314
3034
 
3035
+ def reset_sam2_state(self):
3036
+ """Reset SAM2 predictor state for 2D segmentation."""
3037
+ if not self.use_3d and hasattr(self, "prepared_sam2_image"):
3038
+ # Re-set the image in the predictor (only for ImagePredictor)
3039
+ device_type = "cuda" if self.device.type == "cuda" else "cpu"
3040
+ try:
3041
+ if hasattr(self.predictor, "set_image"):
3042
+ with (
3043
+ torch.inference_mode(),
3044
+ torch.autocast(device_type, dtype=torch.float32),
3045
+ ):
3046
+ self.predictor.set_image(self.prepared_sam2_image)
3047
+ else:
3048
+ print(
3049
+ "DEBUG: reset_sam2_state - predictor doesn't have set_image method"
3050
+ )
3051
+ except (RuntimeError, AssertionError, TypeError, ValueError) as e:
3052
+ print(f"Error resetting SAM2 state: {e}")
3053
+ # If there's an error, try to reinitialize
3054
+ self._initialize_sam2()
3055
+
2315
3056
 
2316
3057
  def create_crop_widget(processor):
2317
3058
  """Create the crop control widget."""
@@ -2322,27 +3063,70 @@ def create_crop_widget(processor):
2322
3063
 
2323
3064
  # Instructions
2324
3065
  dimension_type = "3D (TYX/ZYX)" if processor.use_3d else "2D (YX)"
2325
- instructions_label = QLabel(
2326
- f"<b>Processing {dimension_type} data</b><br><br>"
2327
- "To create/edit objects:<br>"
2328
- "1. <b>Click on the POINTS layer</b> to add positive points<br>"
2329
- "2. Use Shift+click for negative points to refine segmentation<br>"
2330
- "3. Click on existing objects in the Segmentation layer to select them<br>"
2331
- "4. Press 'Crop' to save the selected objects to disk"
2332
- )
3066
+
3067
+ if processor.use_3d:
3068
+ instructions_text = (
3069
+ f"<b>Processing {dimension_type} data</b><br><br>"
3070
+ "<b>⚠️ IMPORTANT for 3D stacks:</b><br>"
3071
+ "<ul>"
3072
+ "<li><b>Navigate to the FIRST SLICE</b> where your object appears (use the time/Z slider)</li>"
3073
+ "<li><b>Switch to 2D view</b> (click 2D icon in napari, NOT 3D view)</li>"
3074
+ "<li><b>Point Mode:</b> Select Points layer and click on objects to segment them</li>"
3075
+ "<li><b>Rectangle Mode:</b> Draw rectangles around objects to segment them</li>"
3076
+ "<li>Segmentation will automatically propagate to all slices</li>"
3077
+ "</ul><br>"
3078
+ "<b>General Controls:</b><br>"
3079
+ "<ul>"
3080
+ "<li>Use <b>Shift+click</b> for negative points (remove areas from segmentation)</li>"
3081
+ "<li>Click on existing objects in <b>Segmentation layer</b> to select for cropping</li>"
3082
+ "<li>Press <b>CTRL+click</b> on labels in <b>Segmentation layer</b> to delete them</li>"
3083
+ "<li>Press <b>'Crop'</b> to save selected objects to disk</li>"
3084
+ "</ul>"
3085
+ )
3086
+ else:
3087
+ instructions_text = (
3088
+ f"<b>Processing {dimension_type} data</b><br><br>"
3089
+ "<b>Point Mode:</b> Click on objects to segment them. Use Shift+click for negative points.<br>"
3090
+ "<b>Rectangle Mode:</b> Draw rectangles around objects to segment them.<br><br>"
3091
+ "<ul>"
3092
+ "<li>Click on existing objects in <b>Segmentation layer</b> to select them for cropping</li>"
3093
+ "<li>Press <b>CTRL+click</b> on labels in <b>Segmentation layer</b> to delete them</li>"
3094
+ "<li>Press <b>'Crop'</b> to save selected objects to disk</li>"
3095
+ "</ul>"
3096
+ )
3097
+
3098
+ instructions_label = QLabel(instructions_text)
2333
3099
  instructions_label.setWordWrap(True)
2334
3100
  layout.addWidget(instructions_label)
2335
3101
 
2336
- # Add a button to ensure points layer is active
2337
- activate_button = QPushButton("Make Points Layer Active")
3102
+ # Add mode selector
3103
+ mode_layout = QHBoxLayout()
3104
+ mode_label = QLabel("<b>Prompt Mode:</b>")
3105
+ mode_layout.addWidget(mode_label)
3106
+
3107
+ point_mode_button = QPushButton("Points")
3108
+ point_mode_button.setCheckable(True)
3109
+ point_mode_button.setChecked(True)
3110
+ mode_layout.addWidget(point_mode_button)
3111
+
3112
+ box_mode_button = QPushButton("Rectangle")
3113
+ box_mode_button.setCheckable(True)
3114
+ box_mode_button.setChecked(False)
3115
+ mode_layout.addWidget(box_mode_button)
3116
+
3117
+ mode_layout.addStretch()
3118
+ layout.addLayout(mode_layout)
3119
+
3120
+ # Add a button to ensure active layer is correct
3121
+ activate_button = QPushButton("Make Prompt Layer Active")
2338
3122
  activate_button.clicked.connect(
2339
- lambda: processor._ensure_points_layer_active()
3123
+ lambda: processor._ensure_active_prompt_layer()
2340
3124
  )
2341
3125
  layout.addWidget(activate_button)
2342
3126
 
2343
- # Add a "Clear Points" button to reset prompts
2344
- clear_points_button = QPushButton("Clear Points")
2345
- layout.addWidget(clear_points_button)
3127
+ # Add a "Clear Prompts" button to reset prompts
3128
+ clear_prompts_button = QPushButton("Clear Prompts")
3129
+ layout.addWidget(clear_prompts_button)
2346
3130
 
2347
3131
  # Create label table
2348
3132
  label_table = processor.create_label_table(crop_widget)
@@ -2353,7 +3137,7 @@ def create_crop_widget(processor):
2353
3137
  # Selection buttons
2354
3138
  selection_layout = QHBoxLayout()
2355
3139
  select_all_button = QPushButton("Select All")
2356
- clear_selection_button = QPushButton("Clear Selection")
3140
+ clear_selection_button = QPushButton("Clear All Labels")
2357
3141
  selection_layout.addWidget(select_all_button)
2358
3142
  selection_layout.addWidget(clear_selection_button)
2359
3143
  layout.addLayout(selection_layout)
@@ -2391,51 +3175,152 @@ def create_crop_widget(processor):
2391
3175
  # Create new table
2392
3176
  label_table = processor.create_label_table(crop_widget)
2393
3177
  label_table.setMinimumHeight(200)
2394
- layout.insertWidget(3, label_table) # Insert after clear points button
3178
+ layout.insertWidget(
3179
+ 3, label_table
3180
+ ) # Insert after clear prompts button
2395
3181
  return label_table
2396
3182
 
2397
- # Add helper method to ensure points layer is active
2398
- def _ensure_points_layer_active():
2399
- points_layer = None
2400
- for layer in list(processor.viewer.layers):
2401
- if "Points" in layer.name:
2402
- points_layer = layer
2403
- break
3183
+ # Add helper method to ensure active prompt layer is selected based on mode
3184
+ def _ensure_active_prompt_layer():
3185
+ if processor.prompt_mode == "point":
3186
+ points_layer = None
3187
+ for layer in list(processor.viewer.layers):
3188
+ if "Points" in layer.name and "Object" not in layer.name:
3189
+ points_layer = layer
3190
+ break
2404
3191
 
2405
- if points_layer is not None:
2406
- processor.viewer.layers.selection.active = points_layer
2407
- status_label.setText(
2408
- "Points layer is now active - click to add points"
2409
- )
2410
- else:
2411
- status_label.setText(
2412
- "No points layer found. Please load an image first."
2413
- )
3192
+ if points_layer is not None:
3193
+ processor.viewer.layers.selection.active = points_layer
3194
+ if processor.use_3d:
3195
+ status_label.setText(
3196
+ "Points layer active - Navigate to FIRST SLICE of object, ensure 2D view, then click"
3197
+ )
3198
+ else:
3199
+ status_label.setText(
3200
+ "Points layer is now active - click to add points"
3201
+ )
3202
+ else:
3203
+ status_label.setText(
3204
+ "No points layer found. Please load an image first."
3205
+ )
3206
+ else: # box mode
3207
+ shapes_layer = None
3208
+ for layer in list(processor.viewer.layers):
3209
+ if "Rectangles" in layer.name:
3210
+ shapes_layer = layer
3211
+ break
3212
+
3213
+ if shapes_layer is not None:
3214
+ processor.viewer.layers.selection.active = shapes_layer
3215
+ status_label.setText(
3216
+ "Rectangles layer is now active - draw rectangles"
3217
+ )
3218
+ else:
3219
+ status_label.setText(
3220
+ "No rectangles layer found. Please load an image first."
3221
+ )
3222
+
3223
+ processor._ensure_active_prompt_layer = _ensure_active_prompt_layer
3224
+
3225
+ # Keep the old method for backward compatibility
3226
+ processor._ensure_points_layer_active = _ensure_active_prompt_layer
2414
3227
 
2415
- processor._ensure_points_layer_active = _ensure_points_layer_active
3228
+ def on_clear_prompts_clicked():
3229
+ # Find and clear/remove prompt layers based on mode
3230
+ main_points_layer = None
3231
+ object_points_layers = []
3232
+ shapes_layer = None
2416
3233
 
2417
- # Connect button signals
2418
- def on_clear_points_clicked():
2419
- # Remove all point layers
2420
3234
  for layer in list(processor.viewer.layers):
2421
3235
  if "Points" in layer.name:
3236
+ if "Object" in layer.name:
3237
+ object_points_layers.append(layer)
3238
+ else:
3239
+ main_points_layer = layer
3240
+ elif "Rectangles" in layer.name:
3241
+ shapes_layer = layer
3242
+
3243
+ # Remove object-specific point layers (these are created dynamically)
3244
+ for layer in object_points_layers:
3245
+ # Clean up callbacks before removing the layer to prevent cleanup issues
3246
+ if hasattr(layer, "mouse_drag_callbacks"):
3247
+ layer.mouse_drag_callbacks.clear()
3248
+ with contextlib.suppress(ValueError):
2422
3249
  processor.viewer.layers.remove(layer)
2423
3250
 
2424
- # Reset point tracking attributes
2425
- if hasattr(processor, "points_data"):
2426
- processor.points_data = {}
2427
- processor.points_labels = {}
3251
+ # Clear shapes layer
3252
+ if shapes_layer is not None:
3253
+ shapes_layer.data = []
2428
3254
 
2429
- if hasattr(processor, "obj_points"):
2430
- processor.obj_points = {}
2431
- processor.obj_labels = {}
3255
+ # Clear data from main points layer instead of removing it
3256
+ if main_points_layer is not None:
3257
+ # Clear the points data
3258
+ main_points_layer.data = np.zeros(
3259
+ (0, 2 if not processor.use_3d else 3)
3260
+ )
3261
+ main_points_layer.face_color = "green"
2432
3262
 
2433
- # Re-create empty points layer
2434
- processor._update_label_layer()
2435
- processor._ensure_points_layer_active()
3263
+ # Ensure the click callback is still connected
3264
+ if (
3265
+ hasattr(main_points_layer, "mouse_drag_callbacks")
3266
+ and processor._on_points_clicked
3267
+ not in main_points_layer.mouse_drag_callbacks
3268
+ ):
3269
+ main_points_layer.mouse_drag_callbacks.append(
3270
+ processor._on_points_clicked
3271
+ )
3272
+
3273
+ # Reset all tracking attributes for 2D
3274
+ if not processor.use_3d:
3275
+ # Reset current segmentation tracking
3276
+ if hasattr(processor, "current_points"):
3277
+ processor.current_points = []
3278
+ processor.current_labels = []
3279
+
3280
+ # Reset object tracking
3281
+ if hasattr(processor, "obj_points"):
3282
+ processor.obj_points = {}
3283
+ processor.obj_labels = {}
3284
+
3285
+ # Reset box tracking
3286
+ if hasattr(processor, "obj_boxes"):
3287
+ processor.obj_boxes = {}
3288
+
3289
+ # Reset object ID counters
3290
+ if hasattr(processor, "current_obj_id"):
3291
+ # Find the highest existing label ID
3292
+ if processor.segmentation_result is not None:
3293
+ max_label = processor.segmentation_result.max()
3294
+ processor.current_obj_id = max(int(max_label) + 1, 1)
3295
+ processor.next_obj_id = processor.current_obj_id
3296
+ else:
3297
+ processor.current_obj_id = 1
3298
+ processor.next_obj_id = 1
3299
+
3300
+ # Reset SAM2 predictor state
3301
+ processor.reset_sam2_state()
3302
+
3303
+ # For 3D, reset video-specific tracking
3304
+ else:
3305
+ if hasattr(processor, "sam2_points_by_obj"):
3306
+ processor.sam2_points_by_obj = {}
3307
+ processor.sam2_labels_by_obj = {}
3308
+
3309
+ # Reset box tracking
3310
+ if hasattr(processor, "obj_boxes"):
3311
+ processor.obj_boxes = {}
3312
+
3313
+ if hasattr(processor, "points_data"):
3314
+ processor.points_data = {}
3315
+ processor.points_labels = {}
3316
+
3317
+ # Note: We don't reset _sam2_state for 3D as it needs to maintain video state
3318
+
3319
+ # Make the appropriate prompt layer active based on mode
3320
+ _ensure_active_prompt_layer()
2436
3321
 
2437
3322
  status_label.setText(
2438
- "Cleared all points. Click on Points layer to add new points."
3323
+ "Cleared all prompts. Ready to add new segmentation prompts."
2439
3324
  )
2440
3325
 
2441
3326
  def on_select_all_clicked():
@@ -2459,8 +3344,14 @@ def create_crop_widget(processor):
2459
3344
  )
2460
3345
 
2461
3346
  def on_next_clicked():
2462
- # Clear points before moving to next image
2463
- on_clear_points_clicked()
3347
+ # Check if we can move to the next image before clearing prompts
3348
+ if processor.current_index >= len(processor.images) - 1:
3349
+ next_button.setEnabled(False)
3350
+ status_label.setText("No more images. Processing complete.")
3351
+ return
3352
+
3353
+ # Clear prompts before moving to next image
3354
+ on_clear_prompts_clicked()
2464
3355
 
2465
3356
  if not processor.next_image():
2466
3357
  next_button.setEnabled(False)
@@ -2470,11 +3361,17 @@ def create_crop_widget(processor):
2470
3361
  status_label.setText(
2471
3362
  f"Showing image {processor.current_index + 1}/{len(processor.images)}"
2472
3363
  )
2473
- processor._ensure_points_layer_active()
3364
+ processor._ensure_active_prompt_layer()
2474
3365
 
2475
3366
  def on_prev_clicked():
2476
- # Clear points before moving to previous image
2477
- on_clear_points_clicked()
3367
+ # Check if we can move to the previous image before clearing prompts
3368
+ if processor.current_index <= 0:
3369
+ prev_button.setEnabled(False)
3370
+ status_label.setText("Already at the first image.")
3371
+ return
3372
+
3373
+ # Clear prompts before moving to previous image
3374
+ on_clear_prompts_clicked()
2478
3375
 
2479
3376
  if not processor.previous_image():
2480
3377
  prev_button.setEnabled(False)
@@ -2484,15 +3381,33 @@ def create_crop_widget(processor):
2484
3381
  status_label.setText(
2485
3382
  f"Showing image {processor.current_index + 1}/{len(processor.images)}"
2486
3383
  )
2487
- processor._ensure_points_layer_active()
3384
+ processor._ensure_active_prompt_layer()
3385
+
3386
+ def on_point_mode_clicked():
3387
+ processor.prompt_mode = "point"
3388
+ point_mode_button.setChecked(True)
3389
+ box_mode_button.setChecked(False)
3390
+ processor._update_label_layer()
3391
+ status_label.setText("Point mode active - click on objects to segment")
2488
3392
 
2489
- clear_points_button.clicked.connect(on_clear_points_clicked)
3393
+ def on_box_mode_clicked():
3394
+ processor.prompt_mode = "box"
3395
+ point_mode_button.setChecked(False)
3396
+ box_mode_button.setChecked(True)
3397
+ processor._update_label_layer()
3398
+ status_label.setText(
3399
+ "Rectangle mode active - draw rectangles around objects"
3400
+ )
3401
+
3402
+ clear_prompts_button.clicked.connect(on_clear_prompts_clicked)
2490
3403
  select_all_button.clicked.connect(on_select_all_clicked)
2491
3404
  clear_selection_button.clicked.connect(on_clear_selection_clicked)
2492
3405
  crop_button.clicked.connect(on_crop_clicked)
2493
3406
  next_button.clicked.connect(on_next_clicked)
2494
3407
  prev_button.clicked.connect(on_prev_clicked)
2495
- activate_button.clicked.connect(_ensure_points_layer_active)
3408
+ activate_button.clicked.connect(_ensure_active_prompt_layer)
3409
+ point_mode_button.clicked.connect(on_point_mode_clicked)
3410
+ box_mode_button.clicked.connect(on_box_mode_clicked)
2496
3411
 
2497
3412
  return crop_widget
2498
3413
 
@@ -2511,6 +3426,19 @@ def batch_crop_anything(
2511
3426
  viewer: Viewer = None,
2512
3427
  ):
2513
3428
  """MagicGUI widget for starting Batch Crop Anything using SAM2."""
3429
+ # Check if torch is available
3430
+ if not _HAS_TORCH:
3431
+ QMessageBox.critical(
3432
+ None,
3433
+ "Missing Dependency",
3434
+ "PyTorch not found. Batch Crop Anything requires PyTorch and SAM2.\n\n"
3435
+ "To install the required dependencies, run:\n"
3436
+ "pip install 'napari-tmidas[deep-learning]'\n\n"
3437
+ "Then follow SAM2 installation instructions at:\n"
3438
+ "https://github.com/MercaderLabAnatomy/napari-tmidas#installation",
3439
+ )
3440
+ return
3441
+
2514
3442
  # Check if SAM2 is available
2515
3443
  try:
2516
3444
  import importlib.util
@@ -2521,15 +3449,15 @@ def batch_crop_anything(
2521
3449
  None,
2522
3450
  "Missing Dependency",
2523
3451
  "SAM2 not found. Please follow installation instructions at:\n"
2524
- "https://github.com/MercaderLabAnatomy/napari-tmidas?tab=readme-ov-file#dependencies\n",
3452
+ "https://github.com/MercaderLabAnatomy/napari-tmidas#installation\n",
2525
3453
  )
2526
3454
  return
2527
3455
  except ImportError:
2528
3456
  QMessageBox.critical(
2529
3457
  None,
2530
3458
  "Missing Dependency",
2531
- "SAM2 package cannot be imported. Please follow installation instructions at\n"
2532
- "https://github.com/MercaderLabAnatomy/napari-tmidas?tab=readme-ov-file#dependencies",
3459
+ "SAM2 package cannot be imported. Please follow installation instructions at:\n"
3460
+ "https://github.com/MercaderLabAnatomy/napari-tmidas#installation",
2533
3461
  )
2534
3462
  return
2535
3463
 
@@ -2557,24 +3485,7 @@ def batch_crop_anything_widget():
2557
3485
  # Create the magicgui widget
2558
3486
  widget = batch_crop_anything
2559
3487
 
2560
- # Create and add browse button for folder path
2561
- folder_browse_button = QPushButton("Browse...")
2562
-
2563
- def on_folder_browse_clicked():
2564
- folder = QFileDialog.getExistingDirectory(
2565
- None,
2566
- "Select Folder",
2567
- os.path.expanduser("~"),
2568
- QFileDialog.ShowDirsOnly | QFileDialog.DontResolveSymlinks,
2569
- )
2570
- if folder:
2571
- # Update the folder_path field
2572
- widget.folder_path.value = folder
2573
-
2574
- folder_browse_button.clicked.connect(on_folder_browse_clicked)
2575
-
2576
- # Insert the browse button next to the folder_path field
2577
- folder_layout = widget.folder_path.native.parent().layout()
2578
- folder_layout.addWidget(folder_browse_button)
3488
+ # Add browse button using common utility
3489
+ add_browse_button_to_folder_field(widget, "folder_path")
2579
3490
 
2580
3491
  return widget