napari-tmidas 0.2.2__py3-none-any.whl → 0.2.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- napari_tmidas/__init__.py +35 -5
- napari_tmidas/_crop_anything.py +1520 -609
- napari_tmidas/_env_manager.py +76 -0
- napari_tmidas/_file_conversion.py +1646 -1131
- napari_tmidas/_file_selector.py +1455 -216
- napari_tmidas/_label_inspection.py +83 -8
- napari_tmidas/_processing_worker.py +309 -0
- napari_tmidas/_reader.py +6 -10
- napari_tmidas/_registry.py +2 -2
- napari_tmidas/_roi_colocalization.py +1221 -84
- napari_tmidas/_tests/test_crop_anything.py +123 -0
- napari_tmidas/_tests/test_env_manager.py +89 -0
- napari_tmidas/_tests/test_grid_view_overlay.py +193 -0
- napari_tmidas/_tests/test_init.py +98 -0
- napari_tmidas/_tests/test_intensity_label_filter.py +222 -0
- napari_tmidas/_tests/test_label_inspection.py +86 -0
- napari_tmidas/_tests/test_processing_basic.py +500 -0
- napari_tmidas/_tests/test_processing_worker.py +142 -0
- napari_tmidas/_tests/test_regionprops_analysis.py +547 -0
- napari_tmidas/_tests/test_registry.py +70 -2
- napari_tmidas/_tests/test_scipy_filters.py +168 -0
- napari_tmidas/_tests/test_skimage_filters.py +259 -0
- napari_tmidas/_tests/test_split_channels.py +217 -0
- napari_tmidas/_tests/test_spotiflow.py +87 -0
- napari_tmidas/_tests/test_tyx_display_fix.py +142 -0
- napari_tmidas/_tests/test_ui_utils.py +68 -0
- napari_tmidas/_tests/test_widget.py +30 -0
- napari_tmidas/_tests/test_windows_basic.py +66 -0
- napari_tmidas/_ui_utils.py +57 -0
- napari_tmidas/_version.py +16 -3
- napari_tmidas/_widget.py +41 -4
- napari_tmidas/processing_functions/basic.py +557 -20
- napari_tmidas/processing_functions/careamics_env_manager.py +72 -99
- napari_tmidas/processing_functions/cellpose_env_manager.py +415 -112
- napari_tmidas/processing_functions/cellpose_segmentation.py +132 -191
- napari_tmidas/processing_functions/colocalization.py +513 -56
- napari_tmidas/processing_functions/grid_view_overlay.py +703 -0
- napari_tmidas/processing_functions/intensity_label_filter.py +422 -0
- napari_tmidas/processing_functions/regionprops_analysis.py +1280 -0
- napari_tmidas/processing_functions/sam2_env_manager.py +53 -69
- napari_tmidas/processing_functions/sam2_mp4.py +274 -195
- napari_tmidas/processing_functions/scipy_filters.py +403 -8
- napari_tmidas/processing_functions/skimage_filters.py +424 -212
- napari_tmidas/processing_functions/spotiflow_detection.py +949 -0
- napari_tmidas/processing_functions/spotiflow_env_manager.py +591 -0
- napari_tmidas/processing_functions/timepoint_merger.py +334 -86
- {napari_tmidas-0.2.2.dist-info → napari_tmidas-0.2.4.dist-info}/METADATA +70 -30
- napari_tmidas-0.2.4.dist-info/RECORD +63 -0
- napari_tmidas/_tests/__init__.py +0 -0
- napari_tmidas-0.2.2.dist-info/RECORD +0 -40
- {napari_tmidas-0.2.2.dist-info → napari_tmidas-0.2.4.dist-info}/WHEEL +0 -0
- {napari_tmidas-0.2.2.dist-info → napari_tmidas-0.2.4.dist-info}/entry_points.txt +0 -0
- {napari_tmidas-0.2.2.dist-info → napari_tmidas-0.2.4.dist-info}/licenses/LICENSE +0 -0
- {napari_tmidas-0.2.2.dist-info → napari_tmidas-0.2.4.dist-info}/top_level.txt +0 -0
napari_tmidas/_crop_anything.py
CHANGED
|
@@ -9,32 +9,99 @@ The plugin supports both 2D (YX) and 3D (TYX/ZYX) data.
|
|
|
9
9
|
import contextlib
|
|
10
10
|
import os
|
|
11
11
|
import sys
|
|
12
|
+
from pathlib import Path
|
|
12
13
|
|
|
13
14
|
import numpy as np
|
|
14
|
-
import requests
|
|
15
|
-
import torch
|
|
16
|
-
from magicgui import magicgui
|
|
17
|
-
from napari.layers import Labels
|
|
18
|
-
from napari.viewer import Viewer
|
|
19
|
-
from qtpy.QtCore import Qt
|
|
20
|
-
from qtpy.QtWidgets import (
|
|
21
|
-
QCheckBox,
|
|
22
|
-
QFileDialog,
|
|
23
|
-
QHBoxLayout,
|
|
24
|
-
QHeaderView,
|
|
25
|
-
QLabel,
|
|
26
|
-
QMessageBox,
|
|
27
|
-
QPushButton,
|
|
28
|
-
QScrollArea,
|
|
29
|
-
QTableWidget,
|
|
30
|
-
QTableWidgetItem,
|
|
31
|
-
QVBoxLayout,
|
|
32
|
-
QWidget,
|
|
33
|
-
)
|
|
34
|
-
from skimage.io import imread
|
|
35
|
-
from skimage.transform import resize
|
|
36
|
-
from tifffile import imwrite
|
|
37
15
|
|
|
16
|
+
# Lazy imports for optional heavy dependencies
|
|
17
|
+
try:
|
|
18
|
+
import requests
|
|
19
|
+
|
|
20
|
+
_HAS_REQUESTS = True
|
|
21
|
+
except ImportError:
|
|
22
|
+
requests = None
|
|
23
|
+
_HAS_REQUESTS = False
|
|
24
|
+
|
|
25
|
+
try:
|
|
26
|
+
import torch
|
|
27
|
+
|
|
28
|
+
_HAS_TORCH = True
|
|
29
|
+
except ImportError:
|
|
30
|
+
torch = None
|
|
31
|
+
_HAS_TORCH = False
|
|
32
|
+
|
|
33
|
+
try:
|
|
34
|
+
from magicgui import magicgui
|
|
35
|
+
|
|
36
|
+
_HAS_MAGICGUI = True
|
|
37
|
+
except ImportError:
|
|
38
|
+
# Create stub decorator
|
|
39
|
+
def magicgui(*args, **kwargs):
|
|
40
|
+
def decorator(func):
|
|
41
|
+
return func
|
|
42
|
+
|
|
43
|
+
if len(args) == 1 and callable(args[0]) and not kwargs:
|
|
44
|
+
return args[0]
|
|
45
|
+
return decorator
|
|
46
|
+
|
|
47
|
+
_HAS_MAGICGUI = False
|
|
48
|
+
|
|
49
|
+
try:
|
|
50
|
+
from napari.layers import Labels
|
|
51
|
+
from napari.viewer import Viewer
|
|
52
|
+
|
|
53
|
+
_HAS_NAPARI = True
|
|
54
|
+
except ImportError:
|
|
55
|
+
Labels = None
|
|
56
|
+
Viewer = None
|
|
57
|
+
_HAS_NAPARI = False
|
|
58
|
+
|
|
59
|
+
try:
|
|
60
|
+
from qtpy.QtCore import Qt
|
|
61
|
+
from qtpy.QtWidgets import (
|
|
62
|
+
QCheckBox,
|
|
63
|
+
QHBoxLayout,
|
|
64
|
+
QHeaderView,
|
|
65
|
+
QLabel,
|
|
66
|
+
QMessageBox,
|
|
67
|
+
QPushButton,
|
|
68
|
+
QScrollArea,
|
|
69
|
+
QTableWidget,
|
|
70
|
+
QTableWidgetItem,
|
|
71
|
+
QVBoxLayout,
|
|
72
|
+
QWidget,
|
|
73
|
+
)
|
|
74
|
+
|
|
75
|
+
_HAS_QTPY = True
|
|
76
|
+
except ImportError:
|
|
77
|
+
Qt = None
|
|
78
|
+
QCheckBox = QHBoxLayout = QHeaderView = QLabel = QMessageBox = None
|
|
79
|
+
QPushButton = QScrollArea = QTableWidget = QTableWidgetItem = None
|
|
80
|
+
QVBoxLayout = QWidget = None
|
|
81
|
+
_HAS_QTPY = False
|
|
82
|
+
|
|
83
|
+
try:
|
|
84
|
+
from skimage.io import imread
|
|
85
|
+
from skimage.transform import resize
|
|
86
|
+
|
|
87
|
+
_HAS_SKIMAGE = True
|
|
88
|
+
except ImportError:
|
|
89
|
+
imread = None
|
|
90
|
+
resize = None
|
|
91
|
+
_HAS_SKIMAGE = False
|
|
92
|
+
|
|
93
|
+
try:
|
|
94
|
+
from tifffile import imwrite
|
|
95
|
+
|
|
96
|
+
_HAS_TIFFFILE = True
|
|
97
|
+
except ImportError:
|
|
98
|
+
imwrite = None
|
|
99
|
+
_HAS_TIFFFILE = False
|
|
100
|
+
|
|
101
|
+
from napari_tmidas._file_selector import (
|
|
102
|
+
load_image_file as load_any_image,
|
|
103
|
+
)
|
|
104
|
+
from napari_tmidas._ui_utils import add_browse_button_to_folder_field
|
|
38
105
|
from napari_tmidas.processing_functions.sam2_mp4 import tif_to_mp4
|
|
39
106
|
|
|
40
107
|
sam2_paths = [
|
|
@@ -98,6 +165,7 @@ class BatchCropAnything:
|
|
|
98
165
|
self.image_layer = None
|
|
99
166
|
self.label_layer = None
|
|
100
167
|
self.label_table_widget = None
|
|
168
|
+
self.shapes_layer = None
|
|
101
169
|
|
|
102
170
|
# State tracking
|
|
103
171
|
self.selected_labels = set()
|
|
@@ -106,6 +174,9 @@ class BatchCropAnything:
|
|
|
106
174
|
# Segmentation parameters
|
|
107
175
|
self.sensitivity = 50 # Default sensitivity (0-100 scale)
|
|
108
176
|
|
|
177
|
+
# Prompt mode: 'point' or 'box'
|
|
178
|
+
self.prompt_mode = "point"
|
|
179
|
+
|
|
109
180
|
# Initialize the SAM2 model
|
|
110
181
|
self._initialize_sam2()
|
|
111
182
|
|
|
@@ -131,17 +202,45 @@ class BatchCropAnything:
|
|
|
131
202
|
|
|
132
203
|
try:
|
|
133
204
|
# import torch
|
|
205
|
+
print("DEBUG: Starting SAM2 initialization...")
|
|
134
206
|
|
|
135
207
|
self.device = get_device()
|
|
208
|
+
print(f"DEBUG: Device set to {self.device}")
|
|
136
209
|
|
|
137
210
|
# Download checkpoint if needed
|
|
138
211
|
checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt"
|
|
139
212
|
checkpoint_path = download_checkpoint(
|
|
140
213
|
checkpoint_url, "/opt/sam2/checkpoints/"
|
|
141
214
|
)
|
|
215
|
+
print(f"DEBUG: Checkpoint path: {checkpoint_path}")
|
|
216
|
+
|
|
217
|
+
# Use relative config path for SAM2's Hydra config system
|
|
142
218
|
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
|
|
219
|
+
print(f"DEBUG: Model config: {model_cfg}")
|
|
220
|
+
|
|
221
|
+
# Verify the actual config file exists in the SAM2 installation
|
|
222
|
+
sam2_base_path = None
|
|
223
|
+
for path in sam2_paths:
|
|
224
|
+
if path and os.path.exists(path):
|
|
225
|
+
sam2_base_path = path
|
|
226
|
+
break
|
|
227
|
+
|
|
228
|
+
if sam2_base_path is not None:
|
|
229
|
+
full_config_path = os.path.join(
|
|
230
|
+
sam2_base_path, "sam2", model_cfg
|
|
231
|
+
)
|
|
232
|
+
if not os.path.exists(full_config_path):
|
|
233
|
+
raise FileNotFoundError(
|
|
234
|
+
f"SAM2 config file not found at: {full_config_path}"
|
|
235
|
+
)
|
|
236
|
+
print(f"DEBUG: Verified config exists at: {full_config_path}")
|
|
237
|
+
else:
|
|
238
|
+
print(
|
|
239
|
+
"DEBUG: Warning - could not verify config file exists, but proceeding with relative path"
|
|
240
|
+
)
|
|
143
241
|
|
|
144
242
|
if self.use_3d:
|
|
243
|
+
print("DEBUG: Initializing SAM2 Video Predictor...")
|
|
145
244
|
from sam2.build_sam import build_sam2_video_predictor
|
|
146
245
|
|
|
147
246
|
self.predictor = build_sam2_video_predictor(
|
|
@@ -150,7 +249,9 @@ class BatchCropAnything:
|
|
|
150
249
|
self.viewer.status = (
|
|
151
250
|
f"Initialized SAM2 Video Predictor on {self.device}"
|
|
152
251
|
)
|
|
252
|
+
print("DEBUG: SAM2 Video Predictor initialized successfully")
|
|
153
253
|
else:
|
|
254
|
+
print("DEBUG: Initializing SAM2 Image Predictor...")
|
|
154
255
|
from sam2.build_sam import build_sam2
|
|
155
256
|
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
|
156
257
|
|
|
@@ -160,6 +261,7 @@ class BatchCropAnything:
|
|
|
160
261
|
self.viewer.status = (
|
|
161
262
|
f"Initialized SAM2 Image Predictor on {self.device}"
|
|
162
263
|
)
|
|
264
|
+
print("DEBUG: SAM2 Image Predictor initialized successfully")
|
|
163
265
|
|
|
164
266
|
except (
|
|
165
267
|
ImportError,
|
|
@@ -167,37 +269,79 @@ class BatchCropAnything:
|
|
|
167
269
|
ValueError,
|
|
168
270
|
FileNotFoundError,
|
|
169
271
|
requests.RequestException,
|
|
272
|
+
AttributeError,
|
|
273
|
+
ModuleNotFoundError,
|
|
170
274
|
) as e:
|
|
171
275
|
import traceback
|
|
172
276
|
|
|
173
|
-
|
|
277
|
+
error_msg = f"SAM2 initialization failed: {str(e)}"
|
|
278
|
+
error_type = type(e).__name__
|
|
279
|
+
self.viewer.status = (
|
|
280
|
+
f"{error_msg} - Images will load without segmentation"
|
|
281
|
+
)
|
|
174
282
|
self.predictor = None
|
|
283
|
+
print(f"DEBUG: SAM2 Error ({error_type}): {error_msg}")
|
|
284
|
+
print("DEBUG: Full traceback:")
|
|
175
285
|
print(traceback.format_exc())
|
|
286
|
+
print(
|
|
287
|
+
"DEBUG: Note: Images will still load, but automatic segmentation will not be available."
|
|
288
|
+
)
|
|
289
|
+
|
|
290
|
+
# Provide specific guidance based on error type
|
|
291
|
+
if isinstance(e, FileNotFoundError):
|
|
292
|
+
print(
|
|
293
|
+
"DEBUG: This appears to be a missing file issue. Check SAM2 installation and config paths."
|
|
294
|
+
)
|
|
295
|
+
elif isinstance(e, (ImportError, ModuleNotFoundError)):
|
|
296
|
+
print(
|
|
297
|
+
"DEBUG: This appears to be a SAM2 import issue. Check SAM2 installation."
|
|
298
|
+
)
|
|
299
|
+
elif isinstance(e, RuntimeError):
|
|
300
|
+
print(
|
|
301
|
+
"DEBUG: This appears to be a runtime issue, possibly GPU/CUDA related."
|
|
302
|
+
)
|
|
303
|
+
else:
|
|
304
|
+
print(f"DEBUG: Unexpected error type: {error_type}")
|
|
176
305
|
|
|
177
306
|
def load_images(self, folder_path: str):
|
|
178
307
|
"""Load images from the specified folder path."""
|
|
308
|
+
print(f"DEBUG: Loading images from folder: {folder_path}")
|
|
179
309
|
if not os.path.exists(folder_path):
|
|
180
310
|
self.viewer.status = f"Folder not found: {folder_path}"
|
|
311
|
+
print(f"DEBUG: Folder does not exist: {folder_path}")
|
|
181
312
|
return
|
|
182
313
|
|
|
183
314
|
files = os.listdir(folder_path)
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
and
|
|
193
|
-
|
|
315
|
+
print(f"DEBUG: Found {len(files)} files in folder")
|
|
316
|
+
self.images = []
|
|
317
|
+
for file in files:
|
|
318
|
+
full = os.path.join(folder_path, file)
|
|
319
|
+
low = file.lower()
|
|
320
|
+
if (
|
|
321
|
+
low.endswith((".tif", ".tiff"))
|
|
322
|
+
or (os.path.isdir(full) and low.endswith(".zarr"))
|
|
323
|
+
) and (
|
|
324
|
+
"label" not in low
|
|
325
|
+
and "_labels_" not in low
|
|
326
|
+
and "sam2"
|
|
327
|
+
not in low # Exclude any SAM2-related files (including output from this tool)
|
|
328
|
+
):
|
|
329
|
+
self.images.append(full)
|
|
330
|
+
print(f"DEBUG: Added image: {file}")
|
|
331
|
+
else:
|
|
332
|
+
print(
|
|
333
|
+
f"DEBUG: Excluded file: {file} (reason: filtering criteria)"
|
|
334
|
+
)
|
|
194
335
|
|
|
195
336
|
if not self.images:
|
|
196
337
|
self.viewer.status = "No compatible images found in the folder."
|
|
338
|
+
print("DEBUG: No compatible images found")
|
|
197
339
|
return
|
|
198
340
|
|
|
341
|
+
print(f"DEBUG: Total compatible images found: {len(self.images)}")
|
|
199
342
|
self.viewer.status = f"Found {len(self.images)} .tif images."
|
|
200
343
|
self.current_index = 0
|
|
344
|
+
print(f"DEBUG: About to load first image: {self.images[0]}")
|
|
201
345
|
self._load_current_image()
|
|
202
346
|
|
|
203
347
|
def next_image(self):
|
|
@@ -250,25 +394,69 @@ class BatchCropAnything:
|
|
|
250
394
|
|
|
251
395
|
def _load_current_image(self):
|
|
252
396
|
"""Load the current image and generate segmentation."""
|
|
397
|
+
print("DEBUG: _load_current_image called")
|
|
253
398
|
if not self.images:
|
|
254
399
|
self.viewer.status = "No images to process."
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
if self.predictor is None:
|
|
258
|
-
self.viewer.status = (
|
|
259
|
-
"SAM2 model not initialized. Cannot segment images."
|
|
260
|
-
)
|
|
400
|
+
print("DEBUG: No images to process")
|
|
261
401
|
return
|
|
262
402
|
|
|
263
403
|
image_path = self.images[self.current_index]
|
|
264
|
-
|
|
404
|
+
print(f"DEBUG: Loading image at path: {image_path}")
|
|
405
|
+
|
|
406
|
+
if self.predictor is None:
|
|
407
|
+
self.viewer.status = f"Loading {os.path.basename(image_path)} (SAM2 model not initialized - no segmentation will be available)"
|
|
408
|
+
print("DEBUG: SAM2 predictor is None")
|
|
409
|
+
else:
|
|
410
|
+
self.viewer.status = f"Processing {os.path.basename(image_path)}"
|
|
411
|
+
print("DEBUG: SAM2 predictor is available")
|
|
265
412
|
|
|
266
413
|
try:
|
|
414
|
+
print("DEBUG: About to clear viewer layers")
|
|
267
415
|
# Clear existing layers
|
|
268
416
|
self.viewer.layers.clear()
|
|
417
|
+
print("DEBUG: Viewer layers cleared")
|
|
269
418
|
|
|
419
|
+
print("DEBUG: About to load image file")
|
|
270
420
|
# Load and process image
|
|
271
|
-
|
|
421
|
+
if image_path.lower().endswith(".zarr") or (
|
|
422
|
+
os.path.isdir(image_path)
|
|
423
|
+
and image_path.lower().endswith(".zarr")
|
|
424
|
+
):
|
|
425
|
+
print("DEBUG: Loading Zarr file")
|
|
426
|
+
data = load_any_image(image_path)
|
|
427
|
+
# If multiple layers returned, take first image layer
|
|
428
|
+
if isinstance(data, list):
|
|
429
|
+
img = None
|
|
430
|
+
for entry in data:
|
|
431
|
+
if isinstance(entry, tuple) and len(entry) == 3:
|
|
432
|
+
d, _kwargs, layer_type = entry
|
|
433
|
+
if layer_type == "image":
|
|
434
|
+
img = d
|
|
435
|
+
break
|
|
436
|
+
elif isinstance(entry, tuple) and len(entry) == 2:
|
|
437
|
+
d, _kwargs = entry
|
|
438
|
+
img = d
|
|
439
|
+
break
|
|
440
|
+
else:
|
|
441
|
+
img = entry
|
|
442
|
+
break
|
|
443
|
+
if img is None:
|
|
444
|
+
raise ValueError("No image layer found in Zarr store")
|
|
445
|
+
else:
|
|
446
|
+
img = data
|
|
447
|
+
|
|
448
|
+
# Compute dask arrays to numpy if needed
|
|
449
|
+
if hasattr(img, "compute"):
|
|
450
|
+
img = img.compute()
|
|
451
|
+
|
|
452
|
+
self.original_image = img
|
|
453
|
+
else:
|
|
454
|
+
print("DEBUG: Loading TIFF file")
|
|
455
|
+
self.original_image = imread(image_path)
|
|
456
|
+
|
|
457
|
+
print(
|
|
458
|
+
f"DEBUG: Image loaded, shape: {self.original_image.shape}, dtype: {self.original_image.dtype}"
|
|
459
|
+
)
|
|
272
460
|
|
|
273
461
|
# For 3D/4D data, determine dimensions
|
|
274
462
|
if self.use_3d and len(self.original_image.shape) >= 3:
|
|
@@ -284,10 +472,12 @@ class BatchCropAnything:
|
|
|
284
472
|
|
|
285
473
|
if time_dim_idx == 0: # TZYX format
|
|
286
474
|
# Keep as is, T is already the first dimension
|
|
475
|
+
print("DEBUG: Adding 4D image (TZYX format) to viewer")
|
|
287
476
|
self.image_layer = self.viewer.add_image(
|
|
288
477
|
self.original_image,
|
|
289
478
|
name=f"Image ({os.path.basename(image_path)})",
|
|
290
479
|
)
|
|
480
|
+
print(f"DEBUG: Added image layer: {self.image_layer}")
|
|
291
481
|
# Store time dimension info
|
|
292
482
|
self.time_dim_size = self.original_image.shape[0]
|
|
293
483
|
self.has_z_dim = True
|
|
@@ -309,19 +499,23 @@ class BatchCropAnything:
|
|
|
309
499
|
transposed_image # Replace with transposed version
|
|
310
500
|
)
|
|
311
501
|
|
|
502
|
+
print("DEBUG: Adding transposed 4D image to viewer")
|
|
312
503
|
self.image_layer = self.viewer.add_image(
|
|
313
504
|
self.original_image,
|
|
314
505
|
name=f"Image ({os.path.basename(image_path)})",
|
|
315
506
|
)
|
|
507
|
+
print(f"DEBUG: Added image layer: {self.image_layer}")
|
|
316
508
|
# Store time dimension info
|
|
317
509
|
self.time_dim_size = self.original_image.shape[0]
|
|
318
510
|
self.has_z_dim = True
|
|
319
511
|
else:
|
|
320
512
|
# No time dimension found, treat as ZYX
|
|
513
|
+
print("DEBUG: Adding 4D image (ZYX format) to viewer")
|
|
321
514
|
self.image_layer = self.viewer.add_image(
|
|
322
515
|
self.original_image,
|
|
323
516
|
name=f"Image ({os.path.basename(image_path)})",
|
|
324
517
|
)
|
|
518
|
+
print(f"DEBUG: Added image layer: {self.image_layer}")
|
|
325
519
|
self.time_dim_size = 1
|
|
326
520
|
self.has_z_dim = True
|
|
327
521
|
elif (
|
|
@@ -330,30 +524,37 @@ class BatchCropAnything:
|
|
|
330
524
|
# Check if first dimension is likely time (> 4, < 400)
|
|
331
525
|
if 4 < self.original_image.shape[0] < 400:
|
|
332
526
|
# Likely TYX format
|
|
527
|
+
print("DEBUG: Adding 3D image (TYX format) to viewer")
|
|
333
528
|
self.image_layer = self.viewer.add_image(
|
|
334
529
|
self.original_image,
|
|
335
530
|
name=f"Image ({os.path.basename(image_path)})",
|
|
336
531
|
)
|
|
532
|
+
print(f"DEBUG: Added image layer: {self.image_layer}")
|
|
337
533
|
self.time_dim_size = self.original_image.shape[0]
|
|
338
534
|
self.has_z_dim = False
|
|
339
535
|
else:
|
|
340
536
|
# Likely ZYX format or another 3D format
|
|
537
|
+
print("DEBUG: Adding 3D image (ZYX format) to viewer")
|
|
341
538
|
self.image_layer = self.viewer.add_image(
|
|
342
539
|
self.original_image,
|
|
343
540
|
name=f"Image ({os.path.basename(image_path)})",
|
|
344
541
|
)
|
|
542
|
+
print(f"DEBUG: Added image layer: {self.image_layer}")
|
|
345
543
|
self.time_dim_size = 1
|
|
346
544
|
self.has_z_dim = True
|
|
347
545
|
else:
|
|
348
546
|
# Should not reach here with use_3d=True, but just in case
|
|
547
|
+
print("DEBUG: Adding 3D image (fallback) to viewer")
|
|
349
548
|
self.image_layer = self.viewer.add_image(
|
|
350
549
|
self.original_image,
|
|
351
550
|
name=f"Image ({os.path.basename(image_path)})",
|
|
352
551
|
)
|
|
552
|
+
print(f"DEBUG: Added image layer: {self.image_layer}")
|
|
353
553
|
self.time_dim_size = 1
|
|
354
554
|
self.has_z_dim = False
|
|
355
555
|
else:
|
|
356
556
|
# Handle 2D data as before
|
|
557
|
+
print("DEBUG: Processing 2D image")
|
|
357
558
|
if self.original_image.dtype != np.uint8:
|
|
358
559
|
image_for_display = (
|
|
359
560
|
self.original_image
|
|
@@ -364,18 +565,42 @@ class BatchCropAnything:
|
|
|
364
565
|
image_for_display = self.original_image
|
|
365
566
|
|
|
366
567
|
# Add image to viewer
|
|
568
|
+
print("DEBUG: Adding 2D image to viewer")
|
|
367
569
|
self.image_layer = self.viewer.add_image(
|
|
368
570
|
image_for_display,
|
|
369
571
|
name=f"Image ({os.path.basename(image_path)})",
|
|
370
572
|
)
|
|
573
|
+
print(f"DEBUG: Added image layer: {self.image_layer}")
|
|
574
|
+
|
|
575
|
+
# Generate segmentation only if predictor is available
|
|
576
|
+
if self.predictor is not None:
|
|
577
|
+
print("DEBUG: About to generate segmentation")
|
|
578
|
+
self._generate_segmentation(self.original_image, image_path)
|
|
579
|
+
print("DEBUG: Segmentation generation completed")
|
|
580
|
+
else:
|
|
581
|
+
print("DEBUG: Creating empty segmentation (no predictor)")
|
|
582
|
+
# Create empty segmentation when predictor is not available
|
|
583
|
+
if self.use_3d:
|
|
584
|
+
shape = self.original_image.shape
|
|
585
|
+
else:
|
|
586
|
+
shape = self.original_image.shape[:2]
|
|
587
|
+
|
|
588
|
+
self.segmentation_result = np.zeros(shape, dtype=np.uint32)
|
|
589
|
+
self.label_layer = self.viewer.add_labels(
|
|
590
|
+
self.segmentation_result,
|
|
591
|
+
name="No Segmentation (SAM2 not available)",
|
|
592
|
+
)
|
|
593
|
+
print(f"DEBUG: Added empty label layer: {self.label_layer}")
|
|
371
594
|
|
|
372
|
-
|
|
373
|
-
self._generate_segmentation(self.original_image, image_path)
|
|
595
|
+
print("DEBUG: _load_current_image completed successfully")
|
|
374
596
|
|
|
375
597
|
except (FileNotFoundError, ValueError, TypeError, OSError) as e:
|
|
376
598
|
import traceback
|
|
377
599
|
|
|
378
|
-
|
|
600
|
+
error_msg = f"Error processing image: {str(e)}"
|
|
601
|
+
self.viewer.status = error_msg
|
|
602
|
+
print(f"DEBUG: Exception in _load_current_image: {error_msg}")
|
|
603
|
+
print("DEBUG: Full traceback:")
|
|
379
604
|
traceback.print_exc()
|
|
380
605
|
|
|
381
606
|
# Create empty segmentation in case of error
|
|
@@ -392,6 +617,7 @@ class BatchCropAnything:
|
|
|
392
617
|
self.label_layer = self.viewer.add_labels(
|
|
393
618
|
self.segmentation_result, name="Error: No Segmentation"
|
|
394
619
|
)
|
|
620
|
+
print(f"DEBUG: Added error label layer: {self.label_layer}")
|
|
395
621
|
|
|
396
622
|
def _generate_segmentation(self, image, image_path: str):
|
|
397
623
|
"""Generate segmentation for the current image using SAM2."""
|
|
@@ -447,7 +673,8 @@ class BatchCropAnything:
|
|
|
447
673
|
traceback.print_exc()
|
|
448
674
|
|
|
449
675
|
def _generate_2d_segmentation(self, confidence_threshold):
|
|
450
|
-
"""Generate 2D segmentation
|
|
676
|
+
"""Generate initial 2D segmentation - start with empty labels for interactive mode."""
|
|
677
|
+
device_type = "cuda" if self.device.type == "cuda" else "cpu"
|
|
451
678
|
# Ensure image is in the correct format for SAM2
|
|
452
679
|
image = self.current_image_for_segmentation
|
|
453
680
|
|
|
@@ -469,9 +696,7 @@ class BatchCropAnything:
|
|
|
469
696
|
(new_height, new_width),
|
|
470
697
|
anti_aliasing=True,
|
|
471
698
|
preserve_range=True,
|
|
472
|
-
).astype(
|
|
473
|
-
np.float32
|
|
474
|
-
) # Convert to float32
|
|
699
|
+
).astype(np.float32)
|
|
475
700
|
|
|
476
701
|
self.current_scale_factor = scale_factor
|
|
477
702
|
else:
|
|
@@ -497,73 +722,54 @@ class BatchCropAnything:
|
|
|
497
722
|
if resized_image.max() > 1.0:
|
|
498
723
|
resized_image = resized_image / 255.0
|
|
499
724
|
|
|
500
|
-
#
|
|
501
|
-
|
|
502
|
-
"cuda", dtype=torch.float32
|
|
503
|
-
):
|
|
504
|
-
# Set the image in the predictor
|
|
505
|
-
self.predictor.set_image(resized_image)
|
|
725
|
+
# Store the prepared image for later use
|
|
726
|
+
self.prepared_sam2_image = resized_image
|
|
506
727
|
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
point_labels=None,
|
|
511
|
-
box=None,
|
|
512
|
-
multimask_output=True,
|
|
513
|
-
)
|
|
728
|
+
# Initialize empty segmentation result
|
|
729
|
+
self.segmentation_result = np.zeros(orig_shape, dtype=np.uint32)
|
|
730
|
+
self.label_info = {}
|
|
514
731
|
|
|
515
|
-
|
|
516
|
-
|
|
517
|
-
|
|
518
|
-
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
self.
|
|
523
|
-
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
|
|
527
|
-
|
|
528
|
-
|
|
529
|
-
|
|
530
|
-
|
|
531
|
-
center_y = np.mean(y_indices) if len(y_indices) > 0 else 0
|
|
532
|
-
center_x = np.mean(x_indices) if len(x_indices) > 0 else 0
|
|
533
|
-
|
|
534
|
-
# Store label info
|
|
535
|
-
self.label_info[label_id] = {
|
|
536
|
-
"area": area,
|
|
537
|
-
"center_y": center_y,
|
|
538
|
-
"center_x": center_x,
|
|
539
|
-
"score": float(scores[i]),
|
|
540
|
-
}
|
|
541
|
-
|
|
542
|
-
# Handle upscaling if needed
|
|
543
|
-
if self.current_scale_factor < 1.0:
|
|
544
|
-
labels = resize(
|
|
545
|
-
labels,
|
|
546
|
-
orig_shape,
|
|
547
|
-
order=0, # Nearest neighbor interpolation
|
|
548
|
-
preserve_range=True,
|
|
549
|
-
anti_aliasing=False,
|
|
550
|
-
).astype(np.uint32)
|
|
551
|
-
|
|
552
|
-
# Sort labels by area (largest first)
|
|
553
|
-
self.label_info = dict(
|
|
554
|
-
sorted(
|
|
555
|
-
self.label_info.items(),
|
|
556
|
-
key=lambda item: item[1]["area"],
|
|
557
|
-
reverse=True,
|
|
558
|
-
)
|
|
732
|
+
# Initialize tracking for interactive segmentation
|
|
733
|
+
self.current_points = []
|
|
734
|
+
self.current_labels = []
|
|
735
|
+
self.current_obj_id = 1
|
|
736
|
+
self.next_obj_id = 1
|
|
737
|
+
|
|
738
|
+
# Initialize object tracking dictionaries
|
|
739
|
+
self.obj_points = {}
|
|
740
|
+
self.obj_labels = {}
|
|
741
|
+
|
|
742
|
+
# Reset SAM2-specific tracking dictionaries for 2D mode
|
|
743
|
+
self.sam2_points_by_obj = {}
|
|
744
|
+
self.sam2_labels_by_obj = {}
|
|
745
|
+
self._sam2_next_obj_id = 1
|
|
746
|
+
print(
|
|
747
|
+
"DEBUG: Reset _sam2_next_obj_id to 1 in _generate_2d_segmentation"
|
|
559
748
|
)
|
|
560
749
|
|
|
561
|
-
#
|
|
562
|
-
self.
|
|
750
|
+
# Set the image in the predictor for later use (2D mode only)
|
|
751
|
+
device_type = "cuda" if self.device.type == "cuda" else "cpu"
|
|
752
|
+
if hasattr(self.predictor, "set_image"):
|
|
753
|
+
with (
|
|
754
|
+
torch.inference_mode(),
|
|
755
|
+
torch.autocast(device_type, dtype=torch.float32),
|
|
756
|
+
):
|
|
757
|
+
self.predictor.set_image(resized_image)
|
|
758
|
+
else:
|
|
759
|
+
print(
|
|
760
|
+
"DEBUG: Skipping set_image - predictor doesn't support it (likely VideoPredictor)"
|
|
761
|
+
)
|
|
563
762
|
|
|
564
763
|
# Update the label layer
|
|
565
764
|
self._update_label_layer()
|
|
566
765
|
|
|
766
|
+
# Show instructions
|
|
767
|
+
self.viewer.status = (
|
|
768
|
+
"2D Mode: Click on the image to add objects. Use Shift+click for negative points to refine. "
|
|
769
|
+
"Click existing objects to select them for cropping. "
|
|
770
|
+
"Note: For stacks, interactive segmentation only works in 2D view mode."
|
|
771
|
+
)
|
|
772
|
+
|
|
567
773
|
def _generate_3d_segmentation(self, confidence_threshold, image_path):
|
|
568
774
|
"""
|
|
569
775
|
Initialize 3D segmentation using SAM2 Video Predictor.
|
|
@@ -584,9 +790,7 @@ class BatchCropAnything:
|
|
|
584
790
|
import tempfile
|
|
585
791
|
|
|
586
792
|
temp_dir = tempfile.gettempdir()
|
|
587
|
-
mp4_path =
|
|
588
|
-
temp_dir, f"temp_volume_{os.path.basename(image_path)}.mp4"
|
|
589
|
-
)
|
|
793
|
+
mp4_path = None
|
|
590
794
|
|
|
591
795
|
# If we need to save a modified version for MP4 conversion
|
|
592
796
|
need_temp_tif = False
|
|
@@ -616,31 +820,72 @@ class BatchCropAnything:
|
|
|
616
820
|
imwrite(temp_tif_path, projected_volume)
|
|
617
821
|
need_temp_tif = True
|
|
618
822
|
|
|
619
|
-
#
|
|
620
|
-
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
|
|
823
|
+
# Check if MP4 already exists
|
|
824
|
+
expected_mp4 = str(Path(temp_tif_path).with_suffix(".mp4"))
|
|
825
|
+
if os.path.exists(expected_mp4):
|
|
826
|
+
self.viewer.status = (
|
|
827
|
+
f"Using existing MP4: {os.path.basename(expected_mp4)}"
|
|
828
|
+
)
|
|
829
|
+
print(
|
|
830
|
+
f"DEBUG: MP4 already exists, skipping conversion: {expected_mp4}"
|
|
831
|
+
)
|
|
832
|
+
mp4_path = expected_mp4
|
|
833
|
+
else:
|
|
834
|
+
# Convert the projected TIF to MP4
|
|
835
|
+
self.viewer.status = "Converting projected 3D volume to MP4 format for SAM2..."
|
|
836
|
+
mp4_path = tif_to_mp4(temp_tif_path)
|
|
624
837
|
else:
|
|
625
|
-
#
|
|
626
|
-
|
|
627
|
-
|
|
628
|
-
|
|
629
|
-
|
|
838
|
+
# Check if MP4 already exists for the original image
|
|
839
|
+
expected_mp4 = str(Path(image_path).with_suffix(".mp4"))
|
|
840
|
+
if os.path.exists(expected_mp4):
|
|
841
|
+
self.viewer.status = (
|
|
842
|
+
f"Using existing MP4: {os.path.basename(expected_mp4)}"
|
|
843
|
+
)
|
|
844
|
+
print(
|
|
845
|
+
f"DEBUG: MP4 already exists, skipping conversion: {expected_mp4}"
|
|
846
|
+
)
|
|
847
|
+
mp4_path = expected_mp4
|
|
848
|
+
else:
|
|
849
|
+
# Convert original volume to video format for SAM2
|
|
850
|
+
self.viewer.status = (
|
|
851
|
+
"Converting 3D volume to MP4 format for SAM2..."
|
|
852
|
+
)
|
|
853
|
+
mp4_path = tif_to_mp4(image_path)
|
|
630
854
|
|
|
631
855
|
# Initialize SAM2 state with the video
|
|
632
856
|
self.viewer.status = "Initializing SAM2 Video Predictor..."
|
|
633
|
-
|
|
634
|
-
"cuda"
|
|
635
|
-
|
|
636
|
-
|
|
857
|
+
try:
|
|
858
|
+
device_type = "cuda" if self.device.type == "cuda" else "cpu"
|
|
859
|
+
with (
|
|
860
|
+
torch.inference_mode(),
|
|
861
|
+
torch.autocast(device_type, dtype=torch.float32),
|
|
862
|
+
):
|
|
863
|
+
self._sam2_state = self.predictor.init_state(mp4_path)
|
|
864
|
+
except (
|
|
865
|
+
RuntimeError,
|
|
866
|
+
ValueError,
|
|
867
|
+
TypeError,
|
|
868
|
+
torch.cuda.OutOfMemoryError,
|
|
869
|
+
) as e:
|
|
870
|
+
self.viewer.status = (
|
|
871
|
+
f"Error initializing SAM2 video predictor: {str(e)}"
|
|
872
|
+
)
|
|
873
|
+
print(f"SAM2 video predictor initialization failed: {e}")
|
|
874
|
+
return
|
|
637
875
|
|
|
638
876
|
# Store needed state for 3D processing
|
|
639
877
|
self._sam2_next_obj_id = 1
|
|
878
|
+
print(
|
|
879
|
+
"DEBUG: Reset _sam2_next_obj_id to 1 in _generate_3d_segmentation"
|
|
880
|
+
)
|
|
640
881
|
self._sam2_prompts = (
|
|
641
882
|
{}
|
|
642
883
|
) # Store prompts for each object (points, labels, box)
|
|
643
884
|
|
|
885
|
+
# Reset SAM2-specific tracking dictionaries for 3D mode
|
|
886
|
+
self.sam2_points_by_obj = {}
|
|
887
|
+
self.sam2_labels_by_obj = {}
|
|
888
|
+
|
|
644
889
|
# Update the label layer with empty segmentation
|
|
645
890
|
self._update_label_layer()
|
|
646
891
|
|
|
@@ -648,8 +893,10 @@ class BatchCropAnything:
|
|
|
648
893
|
if self.label_layer is not None and hasattr(
|
|
649
894
|
self.label_layer, "mouse_drag_callbacks"
|
|
650
895
|
):
|
|
896
|
+
# Safely remove all existing callbacks
|
|
651
897
|
for callback in list(self.label_layer.mouse_drag_callbacks):
|
|
652
|
-
|
|
898
|
+
with contextlib.suppress(ValueError):
|
|
899
|
+
self.label_layer.mouse_drag_callbacks.remove(callback)
|
|
653
900
|
|
|
654
901
|
# Add 3D-specific click handler
|
|
655
902
|
self.label_layer.mouse_drag_callbacks.append(
|
|
@@ -673,8 +920,8 @@ class BatchCropAnything:
|
|
|
673
920
|
|
|
674
921
|
# Show instructions
|
|
675
922
|
self.viewer.status = (
|
|
676
|
-
"3D Mode active: Navigate to the
|
|
677
|
-
"Use Shift+click for negative points
|
|
923
|
+
"3D Mode active: IMPORTANT - Navigate to the FIRST SLICE where object appears (using slider), "
|
|
924
|
+
"then click on object in 2D view (not 3D view). Use Shift+click for negative points. "
|
|
678
925
|
"Segmentation will be propagated to all frames automatically."
|
|
679
926
|
)
|
|
680
927
|
|
|
@@ -728,6 +975,9 @@ class BatchCropAnything:
|
|
|
728
975
|
# Create new object for positive points on background
|
|
729
976
|
ann_obj_id = self._sam2_next_obj_id
|
|
730
977
|
if point_label > 0 and label_id == 0:
|
|
978
|
+
print(
|
|
979
|
+
f"DEBUG: Incrementing _sam2_next_obj_id from {self._sam2_next_obj_id} to {self._sam2_next_obj_id + 1}"
|
|
980
|
+
)
|
|
731
981
|
self._sam2_next_obj_id += 1
|
|
732
982
|
|
|
733
983
|
# Find or create points layer for this object
|
|
@@ -915,8 +1165,10 @@ class BatchCropAnything:
|
|
|
915
1165
|
# Try to perform SAM2 propagation with error handling
|
|
916
1166
|
try:
|
|
917
1167
|
# Use torch.inference_mode() and torch.autocast to ensure consistent dtypes
|
|
918
|
-
|
|
919
|
-
|
|
1168
|
+
device_type = "cuda" if self.device.type == "cuda" else "cpu"
|
|
1169
|
+
with (
|
|
1170
|
+
torch.inference_mode(),
|
|
1171
|
+
torch.autocast(device_type, dtype=torch.float32),
|
|
920
1172
|
):
|
|
921
1173
|
# Attempt to run SAM2 propagation - this will iterate through all frames
|
|
922
1174
|
for (
|
|
@@ -1012,7 +1264,11 @@ class BatchCropAnything:
|
|
|
1012
1264
|
time.sleep(2)
|
|
1013
1265
|
for layer in list(self.viewer.layers):
|
|
1014
1266
|
if "Propagation Progress" in layer.name:
|
|
1015
|
-
|
|
1267
|
+
# Clean up callbacks before removing the layer to prevent cleanup issues
|
|
1268
|
+
if hasattr(layer, "mouse_drag_callbacks"):
|
|
1269
|
+
layer.mouse_drag_callbacks.clear()
|
|
1270
|
+
with contextlib.suppress(ValueError):
|
|
1271
|
+
self.viewer.layers.remove(layer)
|
|
1016
1272
|
|
|
1017
1273
|
threading.Thread(target=remove_progress).start()
|
|
1018
1274
|
|
|
@@ -1035,6 +1291,7 @@ class BatchCropAnything:
|
|
|
1035
1291
|
Given a 3D coordinate (x, y, z), run SAM2 video predictor to segment the object at that point,
|
|
1036
1292
|
update the segmentation result and label layer.
|
|
1037
1293
|
"""
|
|
1294
|
+
device_type = "cuda" if self.device.type == "cuda" else "cpu"
|
|
1038
1295
|
if not hasattr(self, "_sam2_state") or self._sam2_state is None:
|
|
1039
1296
|
self.viewer.status = "SAM2 3D state not initialized."
|
|
1040
1297
|
return
|
|
@@ -1048,8 +1305,9 @@ class BatchCropAnything:
|
|
|
1048
1305
|
point_coords = np.array([[x, y, z]])
|
|
1049
1306
|
point_labels = np.array([1]) # 1 = foreground
|
|
1050
1307
|
|
|
1051
|
-
with
|
|
1052
|
-
|
|
1308
|
+
with (
|
|
1309
|
+
torch.inference_mode(),
|
|
1310
|
+
torch.autocast(device_type, dtype=torch.float32),
|
|
1053
1311
|
):
|
|
1054
1312
|
masks, scores, _ = self.predictor.predict(
|
|
1055
1313
|
state=self._sam2_state,
|
|
@@ -1103,7 +1361,11 @@ class BatchCropAnything:
|
|
|
1103
1361
|
# Remove existing label layer if it exists
|
|
1104
1362
|
for layer in list(self.viewer.layers):
|
|
1105
1363
|
if isinstance(layer, Labels) and "Segmentation" in layer.name:
|
|
1106
|
-
|
|
1364
|
+
# Clean up callbacks before removing the layer to prevent cleanup issues
|
|
1365
|
+
if hasattr(layer, "mouse_drag_callbacks"):
|
|
1366
|
+
layer.mouse_drag_callbacks.clear()
|
|
1367
|
+
with contextlib.suppress(ValueError):
|
|
1368
|
+
self.viewer.layers.remove(layer)
|
|
1107
1369
|
|
|
1108
1370
|
# Add label layer to viewer
|
|
1109
1371
|
self.label_layer = self.viewer.add_labels(
|
|
@@ -1112,10 +1374,36 @@ class BatchCropAnything:
|
|
|
1112
1374
|
opacity=0.7,
|
|
1113
1375
|
)
|
|
1114
1376
|
|
|
1115
|
-
#
|
|
1377
|
+
# Connect click handler to the label layer for selection and deletion
|
|
1378
|
+
if hasattr(self.label_layer, "mouse_drag_callbacks"):
|
|
1379
|
+
# Clear existing callbacks to avoid duplicates
|
|
1380
|
+
self.label_layer.mouse_drag_callbacks.clear()
|
|
1381
|
+
# Add our click handler
|
|
1382
|
+
self.label_layer.mouse_drag_callbacks.append(
|
|
1383
|
+
self._on_label_clicked
|
|
1384
|
+
)
|
|
1385
|
+
|
|
1386
|
+
# Create or update interaction layers based on mode
|
|
1387
|
+
if self.prompt_mode == "point":
|
|
1388
|
+
self._ensure_points_layer()
|
|
1389
|
+
self._remove_shapes_layer()
|
|
1390
|
+
else: # box mode
|
|
1391
|
+
self._ensure_shapes_layer()
|
|
1392
|
+
self._remove_points_layer()
|
|
1393
|
+
|
|
1394
|
+
# Update status
|
|
1395
|
+
n_labels = len(np.unique(self.segmentation_result)) - (
|
|
1396
|
+
1 if 0 in np.unique(self.segmentation_result) else 0
|
|
1397
|
+
)
|
|
1398
|
+
self.viewer.status = f"Loaded image {self.current_index + 1}/{len(self.images)} - Found {n_labels} segments"
|
|
1399
|
+
|
|
1400
|
+
def _ensure_points_layer(self):
|
|
1401
|
+
"""Ensure points layer exists and is properly configured."""
|
|
1116
1402
|
points_layer = None
|
|
1117
1403
|
for layer in list(self.viewer.layers):
|
|
1118
|
-
if
|
|
1404
|
+
if (
|
|
1405
|
+
"Points" in layer.name and "Object" not in layer.name
|
|
1406
|
+
): # Main points layer
|
|
1119
1407
|
points_layer = layer
|
|
1120
1408
|
break
|
|
1121
1409
|
|
|
@@ -1131,141 +1419,193 @@ class BatchCropAnything:
|
|
|
1131
1419
|
opacity=0.8,
|
|
1132
1420
|
)
|
|
1133
1421
|
|
|
1134
|
-
|
|
1135
|
-
|
|
1136
|
-
|
|
1137
|
-
)
|
|
1422
|
+
# Connect points layer mouse click event
|
|
1423
|
+
if hasattr(points_layer, "mouse_drag_callbacks"):
|
|
1424
|
+
points_layer.mouse_drag_callbacks.clear()
|
|
1138
1425
|
points_layer.mouse_drag_callbacks.append(
|
|
1139
1426
|
self._on_points_clicked
|
|
1140
1427
|
)
|
|
1141
1428
|
|
|
1142
|
-
# Connect points layer mouse click event
|
|
1143
|
-
points_layer.mouse_drag_callbacks.append(self._on_points_clicked)
|
|
1144
|
-
|
|
1145
1429
|
# Make the points layer active to encourage interaction with it
|
|
1146
1430
|
self.viewer.layers.selection.active = points_layer
|
|
1147
1431
|
|
|
1148
|
-
|
|
1149
|
-
|
|
1150
|
-
|
|
1151
|
-
)
|
|
1152
|
-
|
|
1432
|
+
def _ensure_shapes_layer(self):
|
|
1433
|
+
"""Ensure shapes layer exists and is properly configured."""
|
|
1434
|
+
shapes_layer = None
|
|
1435
|
+
for layer in list(self.viewer.layers):
|
|
1436
|
+
if "Rectangles" in layer.name:
|
|
1437
|
+
shapes_layer = layer
|
|
1438
|
+
break
|
|
1153
1439
|
|
|
1154
|
-
|
|
1155
|
-
|
|
1156
|
-
|
|
1157
|
-
|
|
1158
|
-
|
|
1440
|
+
if shapes_layer is None:
|
|
1441
|
+
# Initialize an empty shapes layer
|
|
1442
|
+
shapes_layer = self.viewer.add_shapes(
|
|
1443
|
+
None,
|
|
1444
|
+
shape_type="rectangle",
|
|
1445
|
+
edge_width=3,
|
|
1446
|
+
edge_color="green",
|
|
1447
|
+
face_color="transparent",
|
|
1448
|
+
name="Rectangles (Draw to Segment)",
|
|
1449
|
+
)
|
|
1450
|
+
|
|
1451
|
+
# Store reference
|
|
1452
|
+
self.shapes_layer = shapes_layer
|
|
1453
|
+
|
|
1454
|
+
# Initialize processing flag to prevent re-entry
|
|
1455
|
+
if not hasattr(self, "_processing_rectangle"):
|
|
1456
|
+
self._processing_rectangle = False
|
|
1457
|
+
|
|
1458
|
+
# Always ensure the event is connected (disconnect old ones first to avoid duplicates)
|
|
1459
|
+
# Remove any existing callbacks
|
|
1460
|
+
with contextlib.suppress(Exception):
|
|
1461
|
+
shapes_layer.events.data.disconnect()
|
|
1462
|
+
|
|
1463
|
+
# Connect shape added event
|
|
1464
|
+
@shapes_layer.events.data.connect
|
|
1465
|
+
def on_shape_added(event):
|
|
1466
|
+
print(
|
|
1467
|
+
f"DEBUG: Shape event triggered! Shapes: {len(shapes_layer.data)}, Processing: {self._processing_rectangle}"
|
|
1468
|
+
)
|
|
1469
|
+
|
|
1470
|
+
# Ignore if we're already processing or if there are no shapes
|
|
1471
|
+
if self._processing_rectangle:
|
|
1472
|
+
print("DEBUG: Already processing a rectangle, ignoring event")
|
|
1159
1473
|
return
|
|
1160
1474
|
|
|
1161
|
-
|
|
1162
|
-
|
|
1475
|
+
if len(shapes_layer.data) == 0:
|
|
1476
|
+
print("DEBUG: No shapes present, ignoring event")
|
|
1477
|
+
return
|
|
1163
1478
|
|
|
1164
|
-
#
|
|
1165
|
-
|
|
1166
|
-
|
|
1479
|
+
# Only process if we have exactly 1 shape (newly drawn)
|
|
1480
|
+
if len(shapes_layer.data) == 1:
|
|
1481
|
+
print("DEBUG: New shape detected, processing...")
|
|
1482
|
+
# Set flag to prevent re-entry
|
|
1483
|
+
self._processing_rectangle = True
|
|
1484
|
+
try:
|
|
1485
|
+
# Get the shape
|
|
1486
|
+
self._on_rectangle_added(shapes_layer.data[-1])
|
|
1487
|
+
finally:
|
|
1488
|
+
# Always reset flag
|
|
1489
|
+
self._processing_rectangle = False
|
|
1490
|
+
else:
|
|
1491
|
+
print(
|
|
1492
|
+
f"DEBUG: Multiple shapes present ({len(shapes_layer.data)}), skipping"
|
|
1493
|
+
)
|
|
1167
1494
|
|
|
1168
|
-
|
|
1169
|
-
|
|
1170
|
-
if len(coords) == 3:
|
|
1171
|
-
t, y, x = map(int, coords)
|
|
1172
|
-
elif len(coords) == 2:
|
|
1173
|
-
t = int(self.viewer.dims.current_step[0])
|
|
1174
|
-
y, x = map(int, coords)
|
|
1175
|
-
else:
|
|
1176
|
-
self.viewer.status = (
|
|
1177
|
-
f"Unexpected coordinate dimensions: {coords}"
|
|
1178
|
-
)
|
|
1179
|
-
return
|
|
1495
|
+
# Make the shapes layer active
|
|
1496
|
+
self.viewer.layers.selection.active = shapes_layer
|
|
1180
1497
|
|
|
1181
|
-
|
|
1182
|
-
|
|
1183
|
-
|
|
1184
|
-
|
|
1185
|
-
|
|
1186
|
-
layer.
|
|
1498
|
+
def _remove_points_layer(self):
|
|
1499
|
+
"""Remove points layer when not in point mode."""
|
|
1500
|
+
for layer in list(self.viewer.layers):
|
|
1501
|
+
if "Points" in layer.name and "Object" not in layer.name:
|
|
1502
|
+
if hasattr(layer, "mouse_drag_callbacks"):
|
|
1503
|
+
layer.mouse_drag_callbacks.clear()
|
|
1504
|
+
with contextlib.suppress(ValueError):
|
|
1505
|
+
self.viewer.layers.remove(layer)
|
|
1187
1506
|
|
|
1188
|
-
|
|
1189
|
-
|
|
1190
|
-
|
|
1191
|
-
|
|
1192
|
-
|
|
1193
|
-
|
|
1194
|
-
|
|
1195
|
-
colors.append("red" if is_negative else "green")
|
|
1196
|
-
layer.face_color = colors
|
|
1507
|
+
def _remove_shapes_layer(self):
|
|
1508
|
+
"""Remove shapes layer when not in box mode."""
|
|
1509
|
+
for layer in list(self.viewer.layers):
|
|
1510
|
+
if "Rectangles" in layer.name:
|
|
1511
|
+
with contextlib.suppress(ValueError):
|
|
1512
|
+
self.viewer.layers.remove(layer)
|
|
1513
|
+
self.shapes_layer = None
|
|
1197
1514
|
|
|
1198
|
-
|
|
1199
|
-
|
|
1200
|
-
|
|
1201
|
-
|
|
1202
|
-
|
|
1515
|
+
def _on_rectangle_added(self, rectangle_coords):
|
|
1516
|
+
"""Handle rectangle selection for segmentation."""
|
|
1517
|
+
print("DEBUG: _on_rectangle_added called!")
|
|
1518
|
+
device_type = "cuda" if self.device.type == "cuda" else "cpu"
|
|
1519
|
+
try:
|
|
1520
|
+
# Rectangle coords are in the form of a 4x2 or 4x3 array (corners)
|
|
1521
|
+
# Convert to bounding box format [x_min, y_min, x_max, y_max]
|
|
1522
|
+
|
|
1523
|
+
# Debug info
|
|
1524
|
+
print(f"DEBUG: Rectangle coords: {rectangle_coords}")
|
|
1525
|
+
print(f"DEBUG: Rectangle coords shape: {rectangle_coords.shape}")
|
|
1526
|
+
print(f"DEBUG: use_3d flag: {self.use_3d}")
|
|
1527
|
+
print(
|
|
1528
|
+
f"DEBUG: Has predictor: {hasattr(self, 'predictor') and self.predictor is not None}"
|
|
1529
|
+
)
|
|
1530
|
+
if hasattr(self, "predictor") and self.predictor is not None:
|
|
1531
|
+
print(
|
|
1532
|
+
f"DEBUG: Predictor type: {type(self.predictor).__name__}"
|
|
1533
|
+
)
|
|
1534
|
+
else:
|
|
1535
|
+
print("DEBUG: No predictor available!")
|
|
1536
|
+
self.viewer.status = "Error: Predictor not initialized"
|
|
1537
|
+
return
|
|
1538
|
+
|
|
1539
|
+
# Check if we're in 3D mode (use the flag, not coordinate shape)
|
|
1540
|
+
# In 3D mode, even when drawing on a 2D slice, we get (4, 2) coords
|
|
1541
|
+
# but we need to treat it as 3D with propagation
|
|
1542
|
+
if (
|
|
1543
|
+
self.use_3d
|
|
1544
|
+
and len(rectangle_coords.shape) == 2
|
|
1545
|
+
and rectangle_coords.shape[0] == 4
|
|
1546
|
+
):
|
|
1547
|
+
print("DEBUG: Processing as 3D rectangle (will propagate)")
|
|
1548
|
+
|
|
1549
|
+
# Get current frame/slice
|
|
1550
|
+
t = int(self.viewer.dims.current_step[0])
|
|
1551
|
+
print(f"DEBUG: Current frame/slice: {t}")
|
|
1552
|
+
|
|
1553
|
+
# Get Y and X bounds from 2D coordinates
|
|
1554
|
+
if rectangle_coords.shape[1] == 3:
|
|
1555
|
+
# If we somehow got 3D coords (T/Z, Y, X)
|
|
1556
|
+
y_coords = rectangle_coords[:, 1]
|
|
1557
|
+
x_coords = rectangle_coords[:, 2]
|
|
1558
|
+
elif rectangle_coords.shape[1] == 2:
|
|
1559
|
+
# More common: 2D coords (Y, X) when drawing on a slice
|
|
1560
|
+
y_coords = rectangle_coords[:, 0]
|
|
1561
|
+
x_coords = rectangle_coords[:, 1]
|
|
1203
1562
|
else:
|
|
1204
|
-
|
|
1205
|
-
|
|
1206
|
-
|
|
1207
|
-
|
|
1208
|
-
|
|
1209
|
-
self._sam2_next_obj_id += 1
|
|
1563
|
+
print(
|
|
1564
|
+
f"DEBUG: Unexpected coordinate dimensions: {rectangle_coords.shape[1]}"
|
|
1565
|
+
)
|
|
1566
|
+
self.viewer.status = "Error: Unexpected rectangle format"
|
|
1567
|
+
return
|
|
1210
1568
|
|
|
1211
|
-
|
|
1212
|
-
|
|
1213
|
-
self.points_data = {}
|
|
1214
|
-
self.points_labels = {}
|
|
1569
|
+
y_min, y_max = int(min(y_coords)), int(max(y_coords))
|
|
1570
|
+
x_min, x_max = int(min(x_coords)), int(max(x_coords))
|
|
1215
1571
|
|
|
1216
|
-
|
|
1217
|
-
|
|
1218
|
-
self.points_labels[obj_id] = []
|
|
1572
|
+
box = np.array([x_min, y_min, x_max, y_max], dtype=np.float32)
|
|
1573
|
+
print(f"DEBUG: Box coordinates: {box}")
|
|
1219
1574
|
|
|
1220
|
-
|
|
1221
|
-
|
|
1222
|
-
|
|
1223
|
-
self.
|
|
1575
|
+
# Use SAM2 with box prompt - use _sam2_next_obj_id for 3D mode
|
|
1576
|
+
if not hasattr(self, "_sam2_next_obj_id"):
|
|
1577
|
+
self._sam2_next_obj_id = 1
|
|
1578
|
+
obj_id = self._sam2_next_obj_id
|
|
1579
|
+
self._sam2_next_obj_id += 1
|
|
1580
|
+
print(
|
|
1581
|
+
f"DEBUG: Box mode - using object ID {obj_id}, next will be {self._sam2_next_obj_id}"
|
|
1582
|
+
)
|
|
1224
1583
|
|
|
1225
|
-
#
|
|
1584
|
+
# Store box for this object
|
|
1585
|
+
if not hasattr(self, "obj_boxes"):
|
|
1586
|
+
self.obj_boxes = {}
|
|
1587
|
+
self.obj_boxes[obj_id] = box
|
|
1588
|
+
|
|
1589
|
+
# Perform segmentation with 3D propagation
|
|
1226
1590
|
if (
|
|
1227
1591
|
hasattr(self, "_sam2_state")
|
|
1228
1592
|
and self._sam2_state is not None
|
|
1229
1593
|
):
|
|
1230
|
-
|
|
1231
|
-
|
|
1232
|
-
self.points_data[obj_id], dtype=np.float32
|
|
1233
|
-
)
|
|
1234
|
-
labels = np.array(
|
|
1235
|
-
self.points_labels[obj_id], dtype=np.int32
|
|
1594
|
+
self.viewer.status = (
|
|
1595
|
+
f"Segmenting object {obj_id} with box at frame {t}..."
|
|
1236
1596
|
)
|
|
1597
|
+
print(f"DEBUG: Starting segmentation for object {obj_id}")
|
|
1237
1598
|
|
|
1238
|
-
# Create progress layer for visual feedback
|
|
1239
|
-
progress_layer = None
|
|
1240
|
-
for existing_layer in self.viewer.layers:
|
|
1241
|
-
if "Propagation Progress" in existing_layer.name:
|
|
1242
|
-
progress_layer = existing_layer
|
|
1243
|
-
break
|
|
1244
|
-
|
|
1245
|
-
if progress_layer is None:
|
|
1246
|
-
progress_data = np.zeros_like(self.segmentation_result)
|
|
1247
|
-
progress_layer = self.viewer.add_image(
|
|
1248
|
-
progress_data,
|
|
1249
|
-
name="Propagation Progress",
|
|
1250
|
-
colormap="magma",
|
|
1251
|
-
opacity=0.5,
|
|
1252
|
-
visible=True,
|
|
1253
|
-
)
|
|
1254
|
-
|
|
1255
|
-
# First update the current frame immediately
|
|
1256
|
-
self.viewer.status = f"Processing object at frame {t}..."
|
|
1257
|
-
|
|
1258
|
-
# Run SAM2 on current frame
|
|
1259
1599
|
_, out_obj_ids, out_mask_logits = (
|
|
1260
1600
|
self.predictor.add_new_points_or_box(
|
|
1261
1601
|
inference_state=self._sam2_state,
|
|
1262
1602
|
frame_idx=t,
|
|
1263
1603
|
obj_id=obj_id,
|
|
1264
|
-
|
|
1265
|
-
labels=labels,
|
|
1604
|
+
box=box,
|
|
1266
1605
|
)
|
|
1267
1606
|
)
|
|
1268
1607
|
|
|
1608
|
+
print("DEBUG: Segmentation complete, processing mask")
|
|
1269
1609
|
# Update current frame
|
|
1270
1610
|
mask = (out_mask_logits[0] > 0.0).cpu().numpy()
|
|
1271
1611
|
if mask.ndim > 2:
|
|
@@ -1283,21 +1623,380 @@ class BatchCropAnything:
|
|
|
1283
1623
|
anti_aliasing=False,
|
|
1284
1624
|
).astype(bool)
|
|
1285
1625
|
|
|
1286
|
-
# Update segmentation
|
|
1287
|
-
|
|
1288
|
-
|
|
1289
|
-
|
|
1290
|
-
(self.segmentation_result[t] == obj_id) & mask
|
|
1291
|
-
] = 0
|
|
1292
|
-
else:
|
|
1293
|
-
# For positive points, only replace background
|
|
1294
|
-
self.segmentation_result[t][
|
|
1295
|
-
mask & (self.segmentation_result[t] == 0)
|
|
1296
|
-
] = obj_id
|
|
1626
|
+
# Update segmentation
|
|
1627
|
+
self.segmentation_result[t][
|
|
1628
|
+
mask & (self.segmentation_result[t] == 0)
|
|
1629
|
+
] = obj_id
|
|
1297
1630
|
|
|
1298
|
-
|
|
1299
|
-
|
|
1300
|
-
|
|
1631
|
+
print(f"DEBUG: Starting propagation for object {obj_id}")
|
|
1632
|
+
# Propagate to all frames
|
|
1633
|
+
self._propagate_mask_for_current_object(obj_id, t)
|
|
1634
|
+
|
|
1635
|
+
# Update UI
|
|
1636
|
+
print("DEBUG: Updating label layer")
|
|
1637
|
+
self._update_label_layer()
|
|
1638
|
+
if (
|
|
1639
|
+
hasattr(self, "label_table_widget")
|
|
1640
|
+
and self.label_table_widget is not None
|
|
1641
|
+
):
|
|
1642
|
+
self._populate_label_table(self.label_table_widget)
|
|
1643
|
+
|
|
1644
|
+
self.viewer.status = (
|
|
1645
|
+
f"Segmented and propagated object {obj_id} from box"
|
|
1646
|
+
)
|
|
1647
|
+
print("DEBUG: Rectangle processing complete!")
|
|
1648
|
+
|
|
1649
|
+
# Keep the rectangle visible after processing
|
|
1650
|
+
# Users can manually delete it if needed
|
|
1651
|
+
# if self.shapes_layer is not None:
|
|
1652
|
+
# self.shapes_layer.data = []
|
|
1653
|
+
else:
|
|
1654
|
+
print("DEBUG: _sam2_state not available")
|
|
1655
|
+
self.viewer.status = (
|
|
1656
|
+
"Error: 3D segmentation state not initialized"
|
|
1657
|
+
)
|
|
1658
|
+
|
|
1659
|
+
elif (
|
|
1660
|
+
not self.use_3d
|
|
1661
|
+
and len(rectangle_coords.shape) == 2
|
|
1662
|
+
and rectangle_coords.shape[1] == 2
|
|
1663
|
+
):
|
|
1664
|
+
# 2D case: rectangle_coords shape is (4, 2) for Y, X
|
|
1665
|
+
if rectangle_coords.shape[0] == 4:
|
|
1666
|
+
# Get Y and X bounds
|
|
1667
|
+
y_coords = rectangle_coords[:, 0]
|
|
1668
|
+
x_coords = rectangle_coords[:, 1]
|
|
1669
|
+
y_min, y_max = int(min(y_coords)), int(max(y_coords))
|
|
1670
|
+
x_min, x_max = int(min(x_coords)), int(max(x_coords))
|
|
1671
|
+
|
|
1672
|
+
box = np.array(
|
|
1673
|
+
[x_min, y_min, x_max, y_max], dtype=np.float32
|
|
1674
|
+
)
|
|
1675
|
+
|
|
1676
|
+
# Use SAM2 with box prompt - use next_obj_id for 2D mode
|
|
1677
|
+
if not hasattr(self, "next_obj_id"):
|
|
1678
|
+
self.next_obj_id = 1
|
|
1679
|
+
obj_id = self.next_obj_id
|
|
1680
|
+
self.next_obj_id += 1
|
|
1681
|
+
print(
|
|
1682
|
+
f"DEBUG: 2D Box mode - using object ID {obj_id}, next will be {self.next_obj_id}"
|
|
1683
|
+
)
|
|
1684
|
+
|
|
1685
|
+
# Store box for this object
|
|
1686
|
+
if not hasattr(self, "obj_boxes"):
|
|
1687
|
+
self.obj_boxes = {}
|
|
1688
|
+
self.obj_boxes[obj_id] = box
|
|
1689
|
+
|
|
1690
|
+
# Perform segmentation
|
|
1691
|
+
if (
|
|
1692
|
+
hasattr(self, "predictor")
|
|
1693
|
+
and self.predictor is not None
|
|
1694
|
+
):
|
|
1695
|
+
# Make sure image is loaded
|
|
1696
|
+
if self.current_image_for_segmentation is None:
|
|
1697
|
+
self.viewer.status = (
|
|
1698
|
+
"No image loaded for segmentation"
|
|
1699
|
+
)
|
|
1700
|
+
return
|
|
1701
|
+
|
|
1702
|
+
# Prepare image for SAM2
|
|
1703
|
+
image = self.current_image_for_segmentation
|
|
1704
|
+
if len(image.shape) == 2:
|
|
1705
|
+
image = np.stack([image] * 3, axis=-1)
|
|
1706
|
+
elif len(image.shape) == 3 and image.shape[2] == 1:
|
|
1707
|
+
image = np.concatenate([image] * 3, axis=2)
|
|
1708
|
+
elif len(image.shape) == 3 and image.shape[2] > 3:
|
|
1709
|
+
image = image[:, :, :3]
|
|
1710
|
+
|
|
1711
|
+
if image.dtype != np.uint8:
|
|
1712
|
+
image = (image / np.max(image) * 255).astype(
|
|
1713
|
+
np.uint8
|
|
1714
|
+
)
|
|
1715
|
+
|
|
1716
|
+
# Set the image in the predictor (only for ImagePredictor, not VideoPredictor)
|
|
1717
|
+
if hasattr(self.predictor, "set_image"):
|
|
1718
|
+
self.predictor.set_image(image)
|
|
1719
|
+
else:
|
|
1720
|
+
self.viewer.status = "Error: Rectangle mode requires Image Predictor (2D mode)"
|
|
1721
|
+
return
|
|
1722
|
+
|
|
1723
|
+
self.viewer.status = (
|
|
1724
|
+
f"Segmenting object {obj_id} with box..."
|
|
1725
|
+
)
|
|
1726
|
+
|
|
1727
|
+
with (
|
|
1728
|
+
torch.inference_mode(),
|
|
1729
|
+
torch.autocast(device_type),
|
|
1730
|
+
):
|
|
1731
|
+
masks, scores, _ = self.predictor.predict(
|
|
1732
|
+
box=box,
|
|
1733
|
+
multimask_output=False,
|
|
1734
|
+
)
|
|
1735
|
+
|
|
1736
|
+
# Get the mask
|
|
1737
|
+
if len(masks) > 0:
|
|
1738
|
+
best_mask = masks[0]
|
|
1739
|
+
|
|
1740
|
+
# Resize if needed
|
|
1741
|
+
if (
|
|
1742
|
+
best_mask.shape
|
|
1743
|
+
!= self.segmentation_result.shape
|
|
1744
|
+
):
|
|
1745
|
+
from skimage.transform import resize
|
|
1746
|
+
|
|
1747
|
+
best_mask = resize(
|
|
1748
|
+
best_mask.astype(float),
|
|
1749
|
+
self.segmentation_result.shape,
|
|
1750
|
+
order=0,
|
|
1751
|
+
preserve_range=True,
|
|
1752
|
+
anti_aliasing=False,
|
|
1753
|
+
).astype(bool)
|
|
1754
|
+
|
|
1755
|
+
# Apply mask (only overwrite background)
|
|
1756
|
+
mask_condition = np.logical_and(
|
|
1757
|
+
best_mask, (self.segmentation_result == 0)
|
|
1758
|
+
)
|
|
1759
|
+
self.segmentation_result[mask_condition] = (
|
|
1760
|
+
obj_id
|
|
1761
|
+
)
|
|
1762
|
+
|
|
1763
|
+
# Update label info
|
|
1764
|
+
area = np.sum(
|
|
1765
|
+
self.segmentation_result == obj_id
|
|
1766
|
+
)
|
|
1767
|
+
y_indices, x_indices = np.where(
|
|
1768
|
+
self.segmentation_result == obj_id
|
|
1769
|
+
)
|
|
1770
|
+
center_y = (
|
|
1771
|
+
np.mean(y_indices)
|
|
1772
|
+
if len(y_indices) > 0
|
|
1773
|
+
else 0
|
|
1774
|
+
)
|
|
1775
|
+
center_x = (
|
|
1776
|
+
np.mean(x_indices)
|
|
1777
|
+
if len(x_indices) > 0
|
|
1778
|
+
else 0
|
|
1779
|
+
)
|
|
1780
|
+
|
|
1781
|
+
self.label_info[obj_id] = {
|
|
1782
|
+
"area": area,
|
|
1783
|
+
"center_y": center_y,
|
|
1784
|
+
"center_x": center_x,
|
|
1785
|
+
"score": float(scores[0]),
|
|
1786
|
+
}
|
|
1787
|
+
|
|
1788
|
+
self.viewer.status = (
|
|
1789
|
+
f"Segmented object {obj_id} from box"
|
|
1790
|
+
)
|
|
1791
|
+
else:
|
|
1792
|
+
self.viewer.status = "No valid mask produced"
|
|
1793
|
+
|
|
1794
|
+
# Update the UI
|
|
1795
|
+
self._update_label_layer()
|
|
1796
|
+
if (
|
|
1797
|
+
hasattr(self, "label_table_widget")
|
|
1798
|
+
and self.label_table_widget is not None
|
|
1799
|
+
):
|
|
1800
|
+
self._populate_label_table(self.label_table_widget)
|
|
1801
|
+
|
|
1802
|
+
# Keep the rectangle visible after processing
|
|
1803
|
+
# Users can manually delete it if needed
|
|
1804
|
+
# if self.shapes_layer is not None:
|
|
1805
|
+
# self.shapes_layer.data = []
|
|
1806
|
+
else:
|
|
1807
|
+
# Unexpected shape dimensions
|
|
1808
|
+
print(
|
|
1809
|
+
f"DEBUG: Unexpected rectangle shape: {rectangle_coords.shape}"
|
|
1810
|
+
)
|
|
1811
|
+
self.viewer.status = f"Error: Unexpected rectangle dimensions {rectangle_coords.shape}. Expected (4,2) for 2D or (4,3) for 3D."
|
|
1812
|
+
|
|
1813
|
+
except (
|
|
1814
|
+
IndexError,
|
|
1815
|
+
KeyError,
|
|
1816
|
+
ValueError,
|
|
1817
|
+
RuntimeError,
|
|
1818
|
+
TypeError,
|
|
1819
|
+
) as e:
|
|
1820
|
+
import traceback
|
|
1821
|
+
|
|
1822
|
+
self.viewer.status = f"Error in rectangle handling: {str(e)}"
|
|
1823
|
+
print("DEBUG: Exception in _on_rectangle_added:")
|
|
1824
|
+
traceback.print_exc()
|
|
1825
|
+
|
|
1826
|
+
def _on_points_clicked(self, layer, event):
|
|
1827
|
+
"""Handle clicks on the points layer for adding/removing points."""
|
|
1828
|
+
device_type = "cuda" if self.device.type == "cuda" else "cpu"
|
|
1829
|
+
try:
|
|
1830
|
+
# Only process clicks, not drags
|
|
1831
|
+
if event.type != "mouse_press":
|
|
1832
|
+
return
|
|
1833
|
+
|
|
1834
|
+
# Check if segmentation result exists
|
|
1835
|
+
if self.segmentation_result is None:
|
|
1836
|
+
self.viewer.status = (
|
|
1837
|
+
"Segmentation not ready. Please wait for image to load."
|
|
1838
|
+
)
|
|
1839
|
+
return
|
|
1840
|
+
|
|
1841
|
+
# Get coordinates of mouse click
|
|
1842
|
+
coords = np.round(event.position).astype(int)
|
|
1843
|
+
|
|
1844
|
+
# Check if Shift is pressed for negative points
|
|
1845
|
+
is_negative = "Shift" in event.modifiers
|
|
1846
|
+
point_label = -1 if is_negative else 1
|
|
1847
|
+
|
|
1848
|
+
# Handle 2D vs 3D coordinates
|
|
1849
|
+
if self.use_3d:
|
|
1850
|
+
if len(coords) == 3:
|
|
1851
|
+
t, y, x = map(int, coords)
|
|
1852
|
+
elif len(coords) == 2:
|
|
1853
|
+
t = int(self.viewer.dims.current_step[0])
|
|
1854
|
+
y, x = map(int, coords)
|
|
1855
|
+
else:
|
|
1856
|
+
self.viewer.status = (
|
|
1857
|
+
f"Unexpected coordinate dimensions: {coords}"
|
|
1858
|
+
)
|
|
1859
|
+
return
|
|
1860
|
+
|
|
1861
|
+
# Add point to the layer immediately for visual feedback
|
|
1862
|
+
new_point = np.array([[t, y, x]])
|
|
1863
|
+
if len(layer.data) == 0:
|
|
1864
|
+
layer.data = new_point
|
|
1865
|
+
else:
|
|
1866
|
+
layer.data = np.vstack([layer.data, new_point])
|
|
1867
|
+
|
|
1868
|
+
# Update point colors
|
|
1869
|
+
colors = layer.face_color
|
|
1870
|
+
if isinstance(colors, list):
|
|
1871
|
+
colors.append("red" if is_negative else "green")
|
|
1872
|
+
else:
|
|
1873
|
+
n_points = len(layer.data)
|
|
1874
|
+
colors = ["green"] * (n_points - 1)
|
|
1875
|
+
colors.append("red" if is_negative else "green")
|
|
1876
|
+
layer.face_color = colors
|
|
1877
|
+
|
|
1878
|
+
# Validate coordinates are within segmentation bounds
|
|
1879
|
+
if (
|
|
1880
|
+
t < 0
|
|
1881
|
+
or t >= self.segmentation_result.shape[0]
|
|
1882
|
+
or y < 0
|
|
1883
|
+
or y >= self.segmentation_result.shape[1]
|
|
1884
|
+
or x < 0
|
|
1885
|
+
or x >= self.segmentation_result.shape[2]
|
|
1886
|
+
):
|
|
1887
|
+
self.viewer.status = (
|
|
1888
|
+
f"Click at ({t}, {y}, {x}) is out of bounds for "
|
|
1889
|
+
f"segmentation shape {self.segmentation_result.shape}. "
|
|
1890
|
+
f"Please click within the image bounds."
|
|
1891
|
+
)
|
|
1892
|
+
# Remove the invalid point that was just added
|
|
1893
|
+
if len(layer.data) > 0:
|
|
1894
|
+
layer.data = layer.data[:-1]
|
|
1895
|
+
return
|
|
1896
|
+
|
|
1897
|
+
# Get the object ID
|
|
1898
|
+
# If clicking on existing segmentation with negative point
|
|
1899
|
+
label_id = self.segmentation_result[t, y, x]
|
|
1900
|
+
if is_negative and label_id > 0:
|
|
1901
|
+
obj_id = label_id
|
|
1902
|
+
else:
|
|
1903
|
+
# For new objects or negative on background
|
|
1904
|
+
if not hasattr(self, "_sam2_next_obj_id"):
|
|
1905
|
+
self._sam2_next_obj_id = 1
|
|
1906
|
+
obj_id = self._sam2_next_obj_id
|
|
1907
|
+
if point_label > 0 and label_id == 0:
|
|
1908
|
+
self._sam2_next_obj_id += 1
|
|
1909
|
+
|
|
1910
|
+
# Store point information
|
|
1911
|
+
if not hasattr(self, "points_data"):
|
|
1912
|
+
self.points_data = {}
|
|
1913
|
+
self.points_labels = {}
|
|
1914
|
+
|
|
1915
|
+
if obj_id not in self.points_data:
|
|
1916
|
+
self.points_data[obj_id] = []
|
|
1917
|
+
self.points_labels[obj_id] = []
|
|
1918
|
+
|
|
1919
|
+
self.points_data[obj_id].append(
|
|
1920
|
+
[x, y]
|
|
1921
|
+
) # Note: SAM2 expects [x,y] format
|
|
1922
|
+
self.points_labels[obj_id].append(point_label)
|
|
1923
|
+
|
|
1924
|
+
# Perform segmentation
|
|
1925
|
+
if (
|
|
1926
|
+
hasattr(self, "_sam2_state")
|
|
1927
|
+
and self._sam2_state is not None
|
|
1928
|
+
):
|
|
1929
|
+
# Prepare points
|
|
1930
|
+
points = np.array(
|
|
1931
|
+
self.points_data[obj_id], dtype=np.float32
|
|
1932
|
+
)
|
|
1933
|
+
labels = np.array(
|
|
1934
|
+
self.points_labels[obj_id], dtype=np.int32
|
|
1935
|
+
)
|
|
1936
|
+
|
|
1937
|
+
# Create progress layer for visual feedback
|
|
1938
|
+
progress_layer = None
|
|
1939
|
+
for existing_layer in self.viewer.layers:
|
|
1940
|
+
if "Propagation Progress" in existing_layer.name:
|
|
1941
|
+
progress_layer = existing_layer
|
|
1942
|
+
break
|
|
1943
|
+
|
|
1944
|
+
if progress_layer is None:
|
|
1945
|
+
progress_data = np.zeros_like(self.segmentation_result)
|
|
1946
|
+
progress_layer = self.viewer.add_image(
|
|
1947
|
+
progress_data,
|
|
1948
|
+
name="Propagation Progress",
|
|
1949
|
+
colormap="magma",
|
|
1950
|
+
opacity=0.5,
|
|
1951
|
+
visible=True,
|
|
1952
|
+
)
|
|
1953
|
+
|
|
1954
|
+
# First update the current frame immediately
|
|
1955
|
+
self.viewer.status = f"Processing object at frame {t}..."
|
|
1956
|
+
|
|
1957
|
+
# Run SAM2 on current frame
|
|
1958
|
+
_, out_obj_ids, out_mask_logits = (
|
|
1959
|
+
self.predictor.add_new_points_or_box(
|
|
1960
|
+
inference_state=self._sam2_state,
|
|
1961
|
+
frame_idx=t,
|
|
1962
|
+
obj_id=obj_id,
|
|
1963
|
+
points=points,
|
|
1964
|
+
labels=labels,
|
|
1965
|
+
)
|
|
1966
|
+
)
|
|
1967
|
+
|
|
1968
|
+
# Update current frame
|
|
1969
|
+
mask = (out_mask_logits[0] > 0.0).cpu().numpy()
|
|
1970
|
+
if mask.ndim > 2:
|
|
1971
|
+
mask = mask.squeeze()
|
|
1972
|
+
|
|
1973
|
+
# Resize if needed
|
|
1974
|
+
if mask.shape != self.segmentation_result[t].shape:
|
|
1975
|
+
from skimage.transform import resize
|
|
1976
|
+
|
|
1977
|
+
mask = resize(
|
|
1978
|
+
mask.astype(float),
|
|
1979
|
+
self.segmentation_result[t].shape,
|
|
1980
|
+
order=0,
|
|
1981
|
+
preserve_range=True,
|
|
1982
|
+
anti_aliasing=False,
|
|
1983
|
+
).astype(bool)
|
|
1984
|
+
|
|
1985
|
+
# Update segmentation for this frame
|
|
1986
|
+
if point_label < 0:
|
|
1987
|
+
# For negative points, only remove from this object
|
|
1988
|
+
self.segmentation_result[t][
|
|
1989
|
+
(self.segmentation_result[t] == obj_id) & mask
|
|
1990
|
+
] = 0
|
|
1991
|
+
else:
|
|
1992
|
+
# For positive points, only replace background
|
|
1993
|
+
self.segmentation_result[t][
|
|
1994
|
+
mask & (self.segmentation_result[t] == 0)
|
|
1995
|
+
] = obj_id
|
|
1996
|
+
|
|
1997
|
+
# Update progress layer for this frame
|
|
1998
|
+
progress_data = progress_layer.data
|
|
1999
|
+
progress_data[t] = (
|
|
1301
2000
|
mask.astype(float) * 0.5
|
|
1302
2001
|
) # Highlight current frame
|
|
1303
2002
|
progress_layer.data = progress_data
|
|
@@ -1398,7 +2097,11 @@ class BatchCropAnything:
|
|
|
1398
2097
|
time.sleep(2)
|
|
1399
2098
|
for layer in list(self.viewer.layers):
|
|
1400
2099
|
if "Propagation Progress" in layer.name:
|
|
1401
|
-
|
|
2100
|
+
# Clean up callbacks before removing the layer to prevent cleanup issues
|
|
2101
|
+
if hasattr(layer, "mouse_drag_callbacks"):
|
|
2102
|
+
layer.mouse_drag_callbacks.clear()
|
|
2103
|
+
with contextlib.suppress(ValueError):
|
|
2104
|
+
self.viewer.layers.remove(layer)
|
|
1402
2105
|
|
|
1403
2106
|
threading.Thread(target=remove_progress).start()
|
|
1404
2107
|
|
|
@@ -1439,6 +2142,23 @@ class BatchCropAnything:
|
|
|
1439
2142
|
colors.append("red" if is_negative else "green")
|
|
1440
2143
|
layer.face_color = colors
|
|
1441
2144
|
|
|
2145
|
+
# Validate coordinates are within segmentation bounds
|
|
2146
|
+
if (
|
|
2147
|
+
y < 0
|
|
2148
|
+
or y >= self.segmentation_result.shape[0]
|
|
2149
|
+
or x < 0
|
|
2150
|
+
or x >= self.segmentation_result.shape[1]
|
|
2151
|
+
):
|
|
2152
|
+
self.viewer.status = (
|
|
2153
|
+
f"Click at ({y}, {x}) is out of bounds for "
|
|
2154
|
+
f"segmentation shape {self.segmentation_result.shape}. "
|
|
2155
|
+
f"Please click within the image bounds."
|
|
2156
|
+
)
|
|
2157
|
+
# Remove the invalid point that was just added
|
|
2158
|
+
if len(layer.data) > 0:
|
|
2159
|
+
layer.data = layer.data[:-1]
|
|
2160
|
+
return
|
|
2161
|
+
|
|
1442
2162
|
# Get object ID
|
|
1443
2163
|
label_id = self.segmentation_result[y, x]
|
|
1444
2164
|
if is_negative and label_id > 0:
|
|
@@ -1483,8 +2203,14 @@ class BatchCropAnything:
|
|
|
1483
2203
|
if image.dtype != np.uint8:
|
|
1484
2204
|
image = (image / np.max(image) * 255).astype(np.uint8)
|
|
1485
2205
|
|
|
1486
|
-
# Set the image in the predictor
|
|
1487
|
-
self.predictor
|
|
2206
|
+
# Set the image in the predictor (only for ImagePredictor, not VideoPredictor)
|
|
2207
|
+
if hasattr(self.predictor, "set_image"):
|
|
2208
|
+
self.predictor.set_image(image)
|
|
2209
|
+
else:
|
|
2210
|
+
self.viewer.status = (
|
|
2211
|
+
"Error: Point mode in 2D requires Image Predictor"
|
|
2212
|
+
)
|
|
2213
|
+
return
|
|
1488
2214
|
|
|
1489
2215
|
# Use only points for current object
|
|
1490
2216
|
points = np.array(
|
|
@@ -1494,7 +2220,7 @@ class BatchCropAnything:
|
|
|
1494
2220
|
|
|
1495
2221
|
self.viewer.status = f"Segmenting object {obj_id} with {len(points)} points..."
|
|
1496
2222
|
|
|
1497
|
-
with torch.inference_mode(), torch.autocast(
|
|
2223
|
+
with torch.inference_mode(), torch.autocast(device_type):
|
|
1498
2224
|
masks, scores, _ = self.predictor.predict(
|
|
1499
2225
|
point_coords=points,
|
|
1500
2226
|
point_labels=labels,
|
|
@@ -1583,16 +2309,23 @@ class BatchCropAnything:
|
|
|
1583
2309
|
def _on_label_clicked(self, layer, event):
|
|
1584
2310
|
"""Handle label selection and user prompts on mouse click."""
|
|
1585
2311
|
try:
|
|
1586
|
-
# Only process
|
|
2312
|
+
# Only process mouse press events
|
|
1587
2313
|
if event.type != "mouse_press":
|
|
1588
2314
|
return
|
|
1589
2315
|
|
|
2316
|
+
# Only handle left mouse button
|
|
2317
|
+
if event.button != 1:
|
|
2318
|
+
return
|
|
2319
|
+
|
|
1590
2320
|
# Get coordinates of mouse click
|
|
1591
2321
|
coords = np.round(event.position).astype(int)
|
|
1592
2322
|
|
|
1593
|
-
# Check
|
|
2323
|
+
# Check modifiers
|
|
1594
2324
|
is_negative = "Shift" in event.modifiers
|
|
1595
|
-
|
|
2325
|
+
is_control = (
|
|
2326
|
+
"Control" in event.modifiers or "Ctrl" in event.modifiers
|
|
2327
|
+
)
|
|
2328
|
+
# point_label = -1 if is_negative else 1
|
|
1596
2329
|
|
|
1597
2330
|
# For 2D data
|
|
1598
2331
|
if not self.use_3d:
|
|
@@ -1613,262 +2346,13 @@ class BatchCropAnything:
|
|
|
1613
2346
|
# Get the label ID at the clicked position
|
|
1614
2347
|
label_id = self.segmentation_result[y, x]
|
|
1615
2348
|
|
|
1616
|
-
#
|
|
1617
|
-
if
|
|
1618
|
-
|
|
1619
|
-
|
|
1620
|
-
self.next_obj_id = (
|
|
1621
|
-
int(self.segmentation_result.max()) + 1
|
|
1622
|
-
)
|
|
1623
|
-
else:
|
|
1624
|
-
self.next_obj_id = 1
|
|
1625
|
-
|
|
1626
|
-
# If clicking on background or using negative click, handle segmentation
|
|
1627
|
-
if label_id == 0 or is_negative:
|
|
1628
|
-
# Find or create points layer for the current object we're working on
|
|
1629
|
-
current_obj_id = None
|
|
1630
|
-
|
|
1631
|
-
# If negative point on existing label, use that label's ID
|
|
1632
|
-
if is_negative and label_id > 0:
|
|
1633
|
-
current_obj_id = label_id
|
|
1634
|
-
# For positive clicks on background, create a new object
|
|
1635
|
-
elif point_label > 0 and label_id == 0:
|
|
1636
|
-
current_obj_id = self.next_obj_id
|
|
1637
|
-
self.next_obj_id += 1
|
|
1638
|
-
# For negative on background, try to find most recent object
|
|
1639
|
-
elif point_label < 0 and label_id == 0:
|
|
1640
|
-
# Use most recently created object if available
|
|
1641
|
-
if hasattr(self, "obj_points") and self.obj_points:
|
|
1642
|
-
current_obj_id = max(self.obj_points.keys())
|
|
1643
|
-
else:
|
|
1644
|
-
self.viewer.status = "No existing object to modify with negative point"
|
|
1645
|
-
return
|
|
1646
|
-
|
|
1647
|
-
if current_obj_id is None:
|
|
1648
|
-
self.viewer.status = (
|
|
1649
|
-
"Could not determine which object to modify"
|
|
1650
|
-
)
|
|
1651
|
-
return
|
|
1652
|
-
|
|
1653
|
-
# Find or create points layer for this object
|
|
1654
|
-
points_layer = None
|
|
1655
|
-
for layer in list(self.viewer.layers):
|
|
1656
|
-
if f"Points for Object {current_obj_id}" in layer.name:
|
|
1657
|
-
points_layer = layer
|
|
1658
|
-
break
|
|
1659
|
-
|
|
1660
|
-
# Initialize object tracking if needed
|
|
1661
|
-
if not hasattr(self, "obj_points"):
|
|
1662
|
-
self.obj_points = {}
|
|
1663
|
-
self.obj_labels = {}
|
|
1664
|
-
|
|
1665
|
-
if current_obj_id not in self.obj_points:
|
|
1666
|
-
self.obj_points[current_obj_id] = []
|
|
1667
|
-
self.obj_labels[current_obj_id] = []
|
|
1668
|
-
|
|
1669
|
-
# Create or update points layer for this object
|
|
1670
|
-
if points_layer is None:
|
|
1671
|
-
# First point for this object
|
|
1672
|
-
points_layer = self.viewer.add_points(
|
|
1673
|
-
np.array([[y, x]]),
|
|
1674
|
-
name=f"Points for Object {current_obj_id}",
|
|
1675
|
-
size=10,
|
|
1676
|
-
face_color=["green" if point_label > 0 else "red"],
|
|
1677
|
-
border_color="white",
|
|
1678
|
-
border_width=1,
|
|
1679
|
-
opacity=0.8,
|
|
1680
|
-
)
|
|
1681
|
-
with contextlib.suppress(AttributeError, ValueError):
|
|
1682
|
-
points_layer.mouse_drag_callbacks.remove(
|
|
1683
|
-
self._on_points_clicked
|
|
1684
|
-
)
|
|
1685
|
-
points_layer.mouse_drag_callbacks.append(
|
|
1686
|
-
self._on_points_clicked
|
|
1687
|
-
)
|
|
1688
|
-
|
|
1689
|
-
self.obj_points[current_obj_id] = [[x, y]]
|
|
1690
|
-
self.obj_labels[current_obj_id] = [point_label]
|
|
1691
|
-
else:
|
|
1692
|
-
# Add point to existing layer
|
|
1693
|
-
current_points = points_layer.data
|
|
1694
|
-
current_colors = points_layer.face_color
|
|
1695
|
-
|
|
1696
|
-
# Add new point
|
|
1697
|
-
new_points = np.vstack([current_points, [y, x]])
|
|
1698
|
-
new_color = "green" if point_label > 0 else "red"
|
|
1699
|
-
|
|
1700
|
-
# Update points layer
|
|
1701
|
-
points_layer.data = new_points
|
|
1702
|
-
|
|
1703
|
-
# Update colors
|
|
1704
|
-
if isinstance(current_colors, list):
|
|
1705
|
-
current_colors.append(new_color)
|
|
1706
|
-
points_layer.face_color = current_colors
|
|
1707
|
-
else:
|
|
1708
|
-
# If it's an array, create a list of colors
|
|
1709
|
-
colors = []
|
|
1710
|
-
for i in range(len(new_points)):
|
|
1711
|
-
if i < len(current_points):
|
|
1712
|
-
colors.append(
|
|
1713
|
-
"green" if point_label > 0 else "red"
|
|
1714
|
-
)
|
|
1715
|
-
else:
|
|
1716
|
-
colors.append(new_color)
|
|
1717
|
-
points_layer.face_color = colors
|
|
1718
|
-
|
|
1719
|
-
# Update object tracking
|
|
1720
|
-
self.obj_points[current_obj_id].append([x, y])
|
|
1721
|
-
self.obj_labels[current_obj_id].append(point_label)
|
|
1722
|
-
|
|
1723
|
-
# Now do the actual segmentation using SAM2
|
|
1724
|
-
if (
|
|
1725
|
-
hasattr(self, "predictor")
|
|
1726
|
-
and self.predictor is not None
|
|
1727
|
-
):
|
|
1728
|
-
try:
|
|
1729
|
-
# Make sure image is loaded
|
|
1730
|
-
if self.current_image_for_segmentation is None:
|
|
1731
|
-
self.viewer.status = (
|
|
1732
|
-
"No image loaded for segmentation"
|
|
1733
|
-
)
|
|
1734
|
-
return
|
|
1735
|
-
|
|
1736
|
-
# Prepare image for SAM2
|
|
1737
|
-
image = self.current_image_for_segmentation
|
|
1738
|
-
if len(image.shape) == 2:
|
|
1739
|
-
image = np.stack([image] * 3, axis=-1)
|
|
1740
|
-
elif len(image.shape) == 3 and image.shape[2] == 1:
|
|
1741
|
-
image = np.concatenate([image] * 3, axis=2)
|
|
1742
|
-
elif len(image.shape) == 3 and image.shape[2] > 3:
|
|
1743
|
-
image = image[:, :, :3]
|
|
1744
|
-
|
|
1745
|
-
if image.dtype != np.uint8:
|
|
1746
|
-
image = (image / np.max(image) * 255).astype(
|
|
1747
|
-
np.uint8
|
|
1748
|
-
)
|
|
1749
|
-
|
|
1750
|
-
# Set the image in the predictor
|
|
1751
|
-
self.predictor.set_image(image)
|
|
1752
|
-
|
|
1753
|
-
# Only use the points for the current object being segmented
|
|
1754
|
-
points = np.array(
|
|
1755
|
-
self.obj_points[current_obj_id],
|
|
1756
|
-
dtype=np.float32,
|
|
1757
|
-
)
|
|
1758
|
-
labels = np.array(
|
|
1759
|
-
self.obj_labels[current_obj_id], dtype=np.int32
|
|
1760
|
-
)
|
|
1761
|
-
|
|
1762
|
-
self.viewer.status = f"Segmenting object {current_obj_id} with {len(points)} points..."
|
|
1763
|
-
|
|
1764
|
-
with torch.inference_mode(), torch.autocast(
|
|
1765
|
-
"cuda"
|
|
1766
|
-
):
|
|
1767
|
-
masks, scores, _ = self.predictor.predict(
|
|
1768
|
-
point_coords=points,
|
|
1769
|
-
point_labels=labels,
|
|
1770
|
-
multimask_output=True,
|
|
1771
|
-
)
|
|
1772
|
-
|
|
1773
|
-
# Get best mask
|
|
1774
|
-
if len(masks) > 0:
|
|
1775
|
-
best_mask = masks[0]
|
|
1776
|
-
|
|
1777
|
-
# Update segmentation result
|
|
1778
|
-
if (
|
|
1779
|
-
best_mask.shape
|
|
1780
|
-
!= self.segmentation_result.shape
|
|
1781
|
-
):
|
|
1782
|
-
from skimage.transform import resize
|
|
1783
|
-
|
|
1784
|
-
best_mask = resize(
|
|
1785
|
-
best_mask.astype(float),
|
|
1786
|
-
self.segmentation_result.shape,
|
|
1787
|
-
order=0,
|
|
1788
|
-
preserve_range=True,
|
|
1789
|
-
anti_aliasing=False,
|
|
1790
|
-
).astype(bool)
|
|
1791
|
-
|
|
1792
|
-
# CRITICAL FIX: For negative points, only remove from this object's mask
|
|
1793
|
-
# For positive points, add to this object's mask without removing other objects
|
|
1794
|
-
if point_label < 0:
|
|
1795
|
-
# Remove only from current object's mask
|
|
1796
|
-
self.segmentation_result[
|
|
1797
|
-
(
|
|
1798
|
-
self.segmentation_result
|
|
1799
|
-
== current_obj_id
|
|
1800
|
-
)
|
|
1801
|
-
& best_mask
|
|
1802
|
-
] = 0
|
|
1803
|
-
else:
|
|
1804
|
-
# Add to current object's mask without affecting other objects
|
|
1805
|
-
# Only overwrite background (value 0)
|
|
1806
|
-
self.segmentation_result[
|
|
1807
|
-
best_mask
|
|
1808
|
-
& (self.segmentation_result == 0)
|
|
1809
|
-
] = current_obj_id
|
|
1810
|
-
|
|
1811
|
-
# Update label info
|
|
1812
|
-
area = np.sum(
|
|
1813
|
-
self.segmentation_result
|
|
1814
|
-
== current_obj_id
|
|
1815
|
-
)
|
|
1816
|
-
y_indices, x_indices = np.where(
|
|
1817
|
-
self.segmentation_result
|
|
1818
|
-
== current_obj_id
|
|
1819
|
-
)
|
|
1820
|
-
center_y = (
|
|
1821
|
-
np.mean(y_indices)
|
|
1822
|
-
if len(y_indices) > 0
|
|
1823
|
-
else 0
|
|
1824
|
-
)
|
|
1825
|
-
center_x = (
|
|
1826
|
-
np.mean(x_indices)
|
|
1827
|
-
if len(x_indices) > 0
|
|
1828
|
-
else 0
|
|
1829
|
-
)
|
|
1830
|
-
|
|
1831
|
-
self.label_info[current_obj_id] = {
|
|
1832
|
-
"area": area,
|
|
1833
|
-
"center_y": center_y,
|
|
1834
|
-
"center_x": center_x,
|
|
1835
|
-
"score": float(scores[0]),
|
|
1836
|
-
}
|
|
1837
|
-
|
|
1838
|
-
self.viewer.status = (
|
|
1839
|
-
f"Updated object {current_obj_id}"
|
|
1840
|
-
)
|
|
1841
|
-
else:
|
|
1842
|
-
self.viewer.status = (
|
|
1843
|
-
"No valid mask produced"
|
|
1844
|
-
)
|
|
1845
|
-
|
|
1846
|
-
# Update the UI
|
|
1847
|
-
self._update_label_layer()
|
|
1848
|
-
if (
|
|
1849
|
-
hasattr(self, "label_table_widget")
|
|
1850
|
-
and self.label_table_widget is not None
|
|
1851
|
-
):
|
|
1852
|
-
self._populate_label_table(
|
|
1853
|
-
self.label_table_widget
|
|
1854
|
-
)
|
|
1855
|
-
|
|
1856
|
-
except (
|
|
1857
|
-
IndexError,
|
|
1858
|
-
KeyError,
|
|
1859
|
-
ValueError,
|
|
1860
|
-
AttributeError,
|
|
1861
|
-
TypeError,
|
|
1862
|
-
) as e:
|
|
1863
|
-
import traceback
|
|
1864
|
-
|
|
1865
|
-
self.viewer.status = (
|
|
1866
|
-
f"Error in SAM2 processing: {str(e)}"
|
|
1867
|
-
)
|
|
1868
|
-
traceback.print_exc()
|
|
2349
|
+
# Handle Ctrl+Click to clear a single label
|
|
2350
|
+
if is_control and label_id > 0:
|
|
2351
|
+
self.clear_label_at_position(y, x)
|
|
2352
|
+
return
|
|
1869
2353
|
|
|
1870
|
-
# If clicking on an existing label, toggle selection
|
|
1871
|
-
|
|
2354
|
+
# If clicking on an existing label (and not using modifiers), toggle selection
|
|
2355
|
+
if label_id > 0 and not is_negative and not is_control:
|
|
1872
2356
|
# Toggle the label selection
|
|
1873
2357
|
if label_id in self.selected_labels:
|
|
1874
2358
|
self.selected_labels.remove(label_id)
|
|
@@ -1880,8 +2364,14 @@ class BatchCropAnything:
|
|
|
1880
2364
|
# Update table and preview
|
|
1881
2365
|
self._update_label_table()
|
|
1882
2366
|
self.preview_crop()
|
|
2367
|
+
return
|
|
2368
|
+
|
|
2369
|
+
# If clicking on background or using Shift (negative points), this should be handled by points layer
|
|
2370
|
+
# Don't process these clicks here to avoid conflicts
|
|
2371
|
+
if label_id == 0 or is_negative:
|
|
2372
|
+
return
|
|
1883
2373
|
|
|
1884
|
-
# 3D case
|
|
2374
|
+
# 3D case
|
|
1885
2375
|
else:
|
|
1886
2376
|
if len(coords) == 3:
|
|
1887
2377
|
t, y, x = map(int, coords)
|
|
@@ -1910,12 +2400,13 @@ class BatchCropAnything:
|
|
|
1910
2400
|
# Get the label ID at the clicked position
|
|
1911
2401
|
label_id = self.segmentation_result[t, y, x]
|
|
1912
2402
|
|
|
1913
|
-
#
|
|
1914
|
-
if label_id
|
|
1915
|
-
|
|
1916
|
-
|
|
1917
|
-
|
|
1918
|
-
|
|
2403
|
+
# Handle Ctrl+Click to clear a single label
|
|
2404
|
+
if is_control and label_id > 0:
|
|
2405
|
+
self.clear_label_at_position_3d(t, y, x)
|
|
2406
|
+
return
|
|
2407
|
+
|
|
2408
|
+
# If clicking on an existing label and not using negative points, handle selection
|
|
2409
|
+
if label_id > 0 and not is_negative and not is_control:
|
|
1919
2410
|
# Toggle the label selection
|
|
1920
2411
|
if label_id in self.selected_labels:
|
|
1921
2412
|
self.selected_labels.remove(label_id)
|
|
@@ -1926,9 +2417,12 @@ class BatchCropAnything:
|
|
|
1926
2417
|
|
|
1927
2418
|
# Update table if it exists
|
|
1928
2419
|
self._update_label_table()
|
|
1929
|
-
|
|
1930
|
-
# Update preview after selection changes
|
|
1931
2420
|
self.preview_crop()
|
|
2421
|
+
return
|
|
2422
|
+
|
|
2423
|
+
# For background clicks or negative points, let the 3D handler deal with it
|
|
2424
|
+
if label_id == 0 or is_negative:
|
|
2425
|
+
return
|
|
1932
2426
|
|
|
1933
2427
|
except (
|
|
1934
2428
|
IndexError,
|
|
@@ -1942,12 +2436,74 @@ class BatchCropAnything:
|
|
|
1942
2436
|
self.viewer.status = f"Error in click handling: {str(e)}"
|
|
1943
2437
|
traceback.print_exc()
|
|
1944
2438
|
|
|
2439
|
+
def _add_segmentation_point(self, x, y, event):
|
|
2440
|
+
"""Add a point for segmentation."""
|
|
2441
|
+
is_negative = "Shift" in event.modifiers
|
|
2442
|
+
|
|
2443
|
+
# Initialize tracking if needed
|
|
2444
|
+
if not hasattr(self, "current_points"):
|
|
2445
|
+
self.current_points = []
|
|
2446
|
+
self.current_labels = []
|
|
2447
|
+
self.current_obj_id = 1
|
|
2448
|
+
|
|
2449
|
+
# Add point
|
|
2450
|
+
self.current_points.append([x, y])
|
|
2451
|
+
self.current_labels.append(0 if is_negative else 1)
|
|
2452
|
+
|
|
2453
|
+
# Run SAM2 prediction
|
|
2454
|
+
if self.predictor is not None:
|
|
2455
|
+
# Prepare image
|
|
2456
|
+
image = self._prepare_image_for_sam2()
|
|
2457
|
+
|
|
2458
|
+
# Set the image in the predictor (only for ImagePredictor, not VideoPredictor)
|
|
2459
|
+
if hasattr(self.predictor, "set_image"):
|
|
2460
|
+
self.predictor.set_image(image)
|
|
2461
|
+
else:
|
|
2462
|
+
self.viewer.status = (
|
|
2463
|
+
"Error: This operation requires Image Predictor (2D mode)"
|
|
2464
|
+
)
|
|
2465
|
+
return
|
|
2466
|
+
|
|
2467
|
+
# Predict
|
|
2468
|
+
device_type = "cuda" if self.device.type == "cuda" else "cpu"
|
|
2469
|
+
with torch.inference_mode(), torch.autocast(device_type):
|
|
2470
|
+
masks, scores, _ = self.predictor.predict(
|
|
2471
|
+
point_coords=np.array(
|
|
2472
|
+
self.current_points, dtype=np.float32
|
|
2473
|
+
),
|
|
2474
|
+
point_labels=np.array(self.current_labels, dtype=np.int32),
|
|
2475
|
+
multimask_output=False,
|
|
2476
|
+
)
|
|
2477
|
+
|
|
2478
|
+
# Update segmentation
|
|
2479
|
+
if len(masks) > 0:
|
|
2480
|
+
mask = masks[0] > 0.5
|
|
2481
|
+
if self.current_scale_factor < 1.0:
|
|
2482
|
+
mask = resize(
|
|
2483
|
+
mask, self.segmentation_result.shape, order=0
|
|
2484
|
+
).astype(bool)
|
|
2485
|
+
|
|
2486
|
+
# Update segmentation result
|
|
2487
|
+
self.segmentation_result[mask] = self.current_obj_id
|
|
2488
|
+
|
|
2489
|
+
# Move to next object if adding positive point
|
|
2490
|
+
if not is_negative:
|
|
2491
|
+
self.current_obj_id += 1
|
|
2492
|
+
self.current_points = []
|
|
2493
|
+
self.current_labels = []
|
|
2494
|
+
|
|
2495
|
+
self._update_label_layer()
|
|
2496
|
+
|
|
1945
2497
|
def _add_point_marker(self, coords, label_type):
|
|
1946
2498
|
"""Add a visible marker for where the user clicked."""
|
|
1947
2499
|
# Remove previous point markers
|
|
1948
2500
|
for layer in list(self.viewer.layers):
|
|
1949
2501
|
if "Point Prompt" in layer.name:
|
|
1950
|
-
|
|
2502
|
+
# Clean up callbacks before removing the layer to prevent cleanup issues
|
|
2503
|
+
if hasattr(layer, "mouse_drag_callbacks"):
|
|
2504
|
+
layer.mouse_drag_callbacks.clear()
|
|
2505
|
+
with contextlib.suppress(ValueError):
|
|
2506
|
+
self.viewer.layers.remove(layer)
|
|
1951
2507
|
|
|
1952
2508
|
# Create points layer
|
|
1953
2509
|
color = (
|
|
@@ -2135,11 +2691,170 @@ class BatchCropAnything:
|
|
|
2135
2691
|
self.viewer.status = f"Selected all {len(self.selected_labels)} labels"
|
|
2136
2692
|
|
|
2137
2693
|
def clear_selection(self):
    """Wipe every label from the segmentation and reset prompt tracking.

    Despite its name (kept for backwards compatibility) this does more
    than deselect: it zeroes ``self.segmentation_result`` in place,
    empties ``selected_labels`` and ``label_info``, removes every layer
    whose name contains "Points for Object", resets whichever tracking
    dictionaries and object-id counters already exist on the instance,
    and refreshes the label layer, label table and crop preview.

    Nothing is changed when there is no segmentation or it holds no
    label above 0; a status message is shown instead.
    """
    seg = self.segmentation_result
    if seg is None:
        self.viewer.status = "No segmentation available"
        return

    # Count distinct foreground ids; 0 is background.
    distinct = np.unique(seg)
    n_labels = int(np.count_nonzero(distinct > 0))
    if n_labels == 0:
        self.viewer.status = "No labels to clear"
        return

    # In-place write so anything sharing this array sees the cleared data.
    seg[:] = 0
    self.selected_labels = set()
    self.label_info = {}

    # Drop the per-object prompt layers.
    for layer in list(self.viewer.layers):
        if "Points for Object" not in layer.name:
            continue
        # Detach callbacks first to prevent cleanup issues on removal.
        if hasattr(layer, "mouse_drag_callbacks"):
            layer.mouse_drag_callbacks.clear()
        with contextlib.suppress(ValueError):
            self.viewer.layers.remove(layer)

    # Only reset state that already exists; never create new attributes.
    for mapping_name in (
        "obj_points",
        "obj_labels",
        "points_data",
        "points_labels",
    ):
        if hasattr(self, mapping_name):
            setattr(self, mapping_name, {})
    for counter_name in ("next_obj_id", "_sam2_next_obj_id"):
        if hasattr(self, counter_name):
            setattr(self, counter_name, 1)

    self._update_label_layer()
    self._update_label_table()
    self.preview_crop()

    self.viewer.status = (
        f"Cleared all {n_labels} labels from segmentation"
    )
|
|
2755
|
+
|
|
2756
|
+
def clear_label_at_position(self, y, x):
    """Delete the label found at a 2D position from the segmentation.

    Every pixel carrying the label under ``(y, x)`` is reset to 0, the
    label is dropped from ``selected_labels``, ``label_info`` and the
    per-object prompt tracking, that object's prompt layer is removed,
    and the label layer, label table and crop preview are refreshed.

    Args:
        y: Row index into ``self.segmentation_result``.
        x: Column index into ``self.segmentation_result``.

    The coordinates are used directly as array indices, so the caller
    is responsible for passing a position inside the array.
    """
    # Local import: the file's import block is not part of this method.
    import re

    if self.segmentation_result is None:
        self.viewer.status = "No segmentation available"
        return

    label_id = self.segmentation_result[y, x]
    if not label_id > 0:
        self.viewer.status = "No label to delete at this position"
        return

    # Remove all pixels with this label ID
    self.segmentation_result[self.segmentation_result == label_id] = 0

    self.selected_labels.discard(label_id)
    self.label_info.pop(label_id, None)

    # Match only this object's prompt layer. A plain substring test
    # treats "Points for Object 1" as part of "Points for Object 12" and
    # would delete the prompts of unrelated objects; the lookahead
    # rejects names where more digits follow the id.
    own_layer = re.compile(rf"Points for Object {int(label_id)}(?!\d)")
    for layer in list(self.viewer.layers):
        if not own_layer.search(layer.name):
            continue
        # Detach callbacks first to prevent cleanup issues on removal.
        if hasattr(layer, "mouse_drag_callbacks"):
            layer.mouse_drag_callbacks.clear()
        with contextlib.suppress(ValueError):
            self.viewer.layers.remove(layer)

    # Forget the prompts recorded for this object, if tracking exists.
    for mapping_name in ("obj_points", "obj_labels"):
        tracked = getattr(self, mapping_name, None)
        if tracked:
            tracked.pop(label_id, None)

    self._update_label_layer()
    self._update_label_table()
    self.preview_crop()

    self.viewer.status = f"Deleted label ID: {label_id}"
|
|
2797
|
+
|
|
2798
|
+
def clear_label_at_position_3d(self, t, y, x):
    """Delete the label found at a 3D position from every frame/slice.

    Every voxel carrying the label under ``(t, y, x)`` is reset to 0
    across the whole array, the label is dropped from
    ``selected_labels``, ``label_info`` and the 3D prompt tracking, that
    object's prompt layer is removed, and the label layer, label table
    (when one exists) and crop preview are refreshed.

    Args:
        t: Index along the first axis of ``self.segmentation_result``.
        y: Row index.
        x: Column index.

    The coordinates are used directly as array indices, so the caller
    is responsible for passing a position inside the array.
    """
    # Local import: the file's import block is not part of this method.
    import re

    if self.segmentation_result is None:
        self.viewer.status = "No segmentation available"
        return

    label_id = self.segmentation_result[t, y, x]
    if not label_id > 0:
        self.viewer.status = "No label to delete at this position"
        return

    # Remove all pixels with this label ID across all timeframes
    self.segmentation_result[self.segmentation_result == label_id] = 0

    self.selected_labels.discard(label_id)
    self.label_info.pop(label_id, None)

    # Match only this object's prompt layer. A plain substring test
    # treats "Points for Object 1" as part of "Points for Object 12" and
    # would delete the prompts of unrelated objects; the lookahead
    # rejects names where more digits follow the id.
    own_layer = re.compile(rf"Points for Object {int(label_id)}(?!\d)")
    for layer in list(self.viewer.layers):
        if not own_layer.search(layer.name):
            continue
        # Detach callbacks first to prevent cleanup issues on removal.
        if hasattr(layer, "mouse_drag_callbacks"):
            layer.mouse_drag_callbacks.clear()
        with contextlib.suppress(ValueError):
            self.viewer.layers.remove(layer)

    # Forget the prompts recorded for this object, if tracking exists.
    for mapping_name in (
        "sam2_points_by_obj",
        "sam2_labels_by_obj",
        "points_data",
        "points_labels",
    ):
        tracked = getattr(self, mapping_name, None)
        if tracked:
            tracked.pop(label_id, None)

    self._update_label_layer()
    # The table widget is optional here, so rebuild it only if present.
    table = getattr(self, "label_table_widget", None)
    if table is not None:
        self._populate_label_table(table)
    self.preview_crop()

    self.viewer.status = (
        f"Deleted label ID: {label_id} from all timeframes"
    )
|
|
2143
2858
|
|
|
2144
2859
|
def preview_crop(self, label_ids=None):
|
|
2145
2860
|
"""Preview the crop result with the selected label IDs."""
|
|
@@ -2159,7 +2874,11 @@ class BatchCropAnything:
|
|
|
2159
2874
|
# Remove previous preview if exists
|
|
2160
2875
|
for layer in list(self.viewer.layers):
|
|
2161
2876
|
if "Preview" in layer.name:
|
|
2162
|
-
|
|
2877
|
+
# Clean up callbacks before removing the layer to prevent cleanup issues
|
|
2878
|
+
if hasattr(layer, "mouse_drag_callbacks"):
|
|
2879
|
+
layer.mouse_drag_callbacks.clear()
|
|
2880
|
+
with contextlib.suppress(ValueError):
|
|
2881
|
+
self.viewer.layers.remove(layer)
|
|
2163
2882
|
|
|
2164
2883
|
# Make sure the segmentation layer is active again
|
|
2165
2884
|
if self.label_layer is not None:
|
|
@@ -2197,7 +2916,11 @@ class BatchCropAnything:
|
|
|
2197
2916
|
# Remove previous preview if exists
|
|
2198
2917
|
for layer in list(self.viewer.layers):
|
|
2199
2918
|
if "Preview" in layer.name:
|
|
2200
|
-
|
|
2919
|
+
# Clean up callbacks before removing the layer to prevent cleanup issues
|
|
2920
|
+
if hasattr(layer, "mouse_drag_callbacks"):
|
|
2921
|
+
layer.mouse_drag_callbacks.clear()
|
|
2922
|
+
with contextlib.suppress(ValueError):
|
|
2923
|
+
self.viewer.layers.remove(layer)
|
|
2201
2924
|
|
|
2202
2925
|
# Add preview layer
|
|
2203
2926
|
if label_ids:
|
|
@@ -2288,17 +3011,14 @@ class BatchCropAnything:
|
|
|
2288
3011
|
# Save cropped image
|
|
2289
3012
|
image_path = self.images[self.current_index]
|
|
2290
3013
|
base_name, ext = os.path.splitext(image_path)
|
|
2291
|
-
|
|
2292
|
-
str(lid) for lid in sorted(self.selected_labels)
|
|
2293
|
-
)
|
|
2294
|
-
output_path = f"{base_name}_cropped_{label_str}.tif"
|
|
3014
|
+
output_path = f"{base_name}_sam2_cropped.tif"
|
|
2295
3015
|
|
|
2296
3016
|
# Save using tifffile with explicit parameters for best compatibility
|
|
2297
3017
|
imwrite(output_path, cropped_image, compression="zlib")
|
|
2298
3018
|
self.viewer.status = f"Saved cropped image to {output_path}"
|
|
2299
3019
|
|
|
2300
3020
|
# Save the label image with exact same dimensions as original
|
|
2301
|
-
label_output_path = f"{base_name}
|
|
3021
|
+
label_output_path = f"{base_name}_sam2_labels.tif"
|
|
2302
3022
|
imwrite(label_output_path, label_image, compression="zlib")
|
|
2303
3023
|
self.viewer.status += f"\nSaved label mask to {label_output_path}"
|
|
2304
3024
|
|
|
@@ -2312,6 +3032,27 @@ class BatchCropAnything:
|
|
|
2312
3032
|
self.viewer.status = f"Error cropping image: {str(e)}"
|
|
2313
3033
|
return False
|
|
2314
3034
|
|
|
3035
|
+
def reset_sam2_state(self):
|
|
3036
|
+
"""Reset SAM2 predictor state for 2D segmentation."""
|
|
3037
|
+
if not self.use_3d and hasattr(self, "prepared_sam2_image"):
|
|
3038
|
+
# Re-set the image in the predictor (only for ImagePredictor)
|
|
3039
|
+
device_type = "cuda" if self.device.type == "cuda" else "cpu"
|
|
3040
|
+
try:
|
|
3041
|
+
if hasattr(self.predictor, "set_image"):
|
|
3042
|
+
with (
|
|
3043
|
+
torch.inference_mode(),
|
|
3044
|
+
torch.autocast(device_type, dtype=torch.float32),
|
|
3045
|
+
):
|
|
3046
|
+
self.predictor.set_image(self.prepared_sam2_image)
|
|
3047
|
+
else:
|
|
3048
|
+
print(
|
|
3049
|
+
"DEBUG: reset_sam2_state - predictor doesn't have set_image method"
|
|
3050
|
+
)
|
|
3051
|
+
except (RuntimeError, AssertionError, TypeError, ValueError) as e:
|
|
3052
|
+
print(f"Error resetting SAM2 state: {e}")
|
|
3053
|
+
# If there's an error, try to reinitialize
|
|
3054
|
+
self._initialize_sam2()
|
|
3055
|
+
|
|
2315
3056
|
|
|
2316
3057
|
def create_crop_widget(processor):
|
|
2317
3058
|
"""Create the crop control widget."""
|
|
@@ -2322,27 +3063,70 @@ def create_crop_widget(processor):
|
|
|
2322
3063
|
|
|
2323
3064
|
# Instructions
|
|
2324
3065
|
dimension_type = "3D (TYX/ZYX)" if processor.use_3d else "2D (YX)"
|
|
2325
|
-
|
|
2326
|
-
|
|
2327
|
-
|
|
2328
|
-
|
|
2329
|
-
|
|
2330
|
-
|
|
2331
|
-
|
|
2332
|
-
|
|
3066
|
+
|
|
3067
|
+
if processor.use_3d:
|
|
3068
|
+
instructions_text = (
|
|
3069
|
+
f"<b>Processing {dimension_type} data</b><br><br>"
|
|
3070
|
+
"<b>⚠️ IMPORTANT for 3D stacks:</b><br>"
|
|
3071
|
+
"<ul>"
|
|
3072
|
+
"<li><b>Navigate to the FIRST SLICE</b> where your object appears (use the time/Z slider)</li>"
|
|
3073
|
+
"<li><b>Switch to 2D view</b> (click 2D icon in napari, NOT 3D view)</li>"
|
|
3074
|
+
"<li><b>Point Mode:</b> Select Points layer and click on objects to segment them</li>"
|
|
3075
|
+
"<li><b>Rectangle Mode:</b> Draw rectangles around objects to segment them</li>"
|
|
3076
|
+
"<li>Segmentation will automatically propagate to all slices</li>"
|
|
3077
|
+
"</ul><br>"
|
|
3078
|
+
"<b>General Controls:</b><br>"
|
|
3079
|
+
"<ul>"
|
|
3080
|
+
"<li>Use <b>Shift+click</b> for negative points (remove areas from segmentation)</li>"
|
|
3081
|
+
"<li>Click on existing objects in <b>Segmentation layer</b> to select for cropping</li>"
|
|
3082
|
+
"<li>Press <b>CTRL+click</b> on labels in <b>Segmentation layer</b> to delete them</li>"
|
|
3083
|
+
"<li>Press <b>'Crop'</b> to save selected objects to disk</li>"
|
|
3084
|
+
"</ul>"
|
|
3085
|
+
)
|
|
3086
|
+
else:
|
|
3087
|
+
instructions_text = (
|
|
3088
|
+
f"<b>Processing {dimension_type} data</b><br><br>"
|
|
3089
|
+
"<b>Point Mode:</b> Click on objects to segment them. Use Shift+click for negative points.<br>"
|
|
3090
|
+
"<b>Rectangle Mode:</b> Draw rectangles around objects to segment them.<br><br>"
|
|
3091
|
+
"<ul>"
|
|
3092
|
+
"<li>Click on existing objects in <b>Segmentation layer</b> to select them for cropping</li>"
|
|
3093
|
+
"<li>Press <b>CTRL+click</b> on labels in <b>Segmentation layer</b> to delete them</li>"
|
|
3094
|
+
"<li>Press <b>'Crop'</b> to save selected objects to disk</li>"
|
|
3095
|
+
"</ul>"
|
|
3096
|
+
)
|
|
3097
|
+
|
|
3098
|
+
instructions_label = QLabel(instructions_text)
|
|
2333
3099
|
instructions_label.setWordWrap(True)
|
|
2334
3100
|
layout.addWidget(instructions_label)
|
|
2335
3101
|
|
|
2336
|
-
# Add
|
|
2337
|
-
|
|
3102
|
+
# Add mode selector
|
|
3103
|
+
mode_layout = QHBoxLayout()
|
|
3104
|
+
mode_label = QLabel("<b>Prompt Mode:</b>")
|
|
3105
|
+
mode_layout.addWidget(mode_label)
|
|
3106
|
+
|
|
3107
|
+
point_mode_button = QPushButton("Points")
|
|
3108
|
+
point_mode_button.setCheckable(True)
|
|
3109
|
+
point_mode_button.setChecked(True)
|
|
3110
|
+
mode_layout.addWidget(point_mode_button)
|
|
3111
|
+
|
|
3112
|
+
box_mode_button = QPushButton("Rectangle")
|
|
3113
|
+
box_mode_button.setCheckable(True)
|
|
3114
|
+
box_mode_button.setChecked(False)
|
|
3115
|
+
mode_layout.addWidget(box_mode_button)
|
|
3116
|
+
|
|
3117
|
+
mode_layout.addStretch()
|
|
3118
|
+
layout.addLayout(mode_layout)
|
|
3119
|
+
|
|
3120
|
+
# Add a button to ensure active layer is correct
|
|
3121
|
+
activate_button = QPushButton("Make Prompt Layer Active")
|
|
2338
3122
|
activate_button.clicked.connect(
|
|
2339
|
-
lambda: processor.
|
|
3123
|
+
lambda: processor._ensure_active_prompt_layer()
|
|
2340
3124
|
)
|
|
2341
3125
|
layout.addWidget(activate_button)
|
|
2342
3126
|
|
|
2343
|
-
# Add a "Clear
|
|
2344
|
-
|
|
2345
|
-
layout.addWidget(
|
|
3127
|
+
# Add a "Clear Prompts" button to reset prompts
|
|
3128
|
+
clear_prompts_button = QPushButton("Clear Prompts")
|
|
3129
|
+
layout.addWidget(clear_prompts_button)
|
|
2346
3130
|
|
|
2347
3131
|
# Create label table
|
|
2348
3132
|
label_table = processor.create_label_table(crop_widget)
|
|
@@ -2353,7 +3137,7 @@ def create_crop_widget(processor):
|
|
|
2353
3137
|
# Selection buttons
|
|
2354
3138
|
selection_layout = QHBoxLayout()
|
|
2355
3139
|
select_all_button = QPushButton("Select All")
|
|
2356
|
-
clear_selection_button = QPushButton("Clear
|
|
3140
|
+
clear_selection_button = QPushButton("Clear All Labels")
|
|
2357
3141
|
selection_layout.addWidget(select_all_button)
|
|
2358
3142
|
selection_layout.addWidget(clear_selection_button)
|
|
2359
3143
|
layout.addLayout(selection_layout)
|
|
@@ -2391,51 +3175,152 @@ def create_crop_widget(processor):
|
|
|
2391
3175
|
# Create new table
|
|
2392
3176
|
label_table = processor.create_label_table(crop_widget)
|
|
2393
3177
|
label_table.setMinimumHeight(200)
|
|
2394
|
-
layout.insertWidget(
|
|
3178
|
+
layout.insertWidget(
|
|
3179
|
+
3, label_table
|
|
3180
|
+
) # Insert after clear prompts button
|
|
2395
3181
|
return label_table
|
|
2396
3182
|
|
|
2397
|
-
# Add helper method to ensure
|
|
2398
|
-
def
|
|
2399
|
-
|
|
2400
|
-
|
|
2401
|
-
|
|
2402
|
-
|
|
2403
|
-
|
|
3183
|
+
# Add helper method to ensure active prompt layer is selected based on mode
|
|
3184
|
+
def _ensure_active_prompt_layer():
|
|
3185
|
+
if processor.prompt_mode == "point":
|
|
3186
|
+
points_layer = None
|
|
3187
|
+
for layer in list(processor.viewer.layers):
|
|
3188
|
+
if "Points" in layer.name and "Object" not in layer.name:
|
|
3189
|
+
points_layer = layer
|
|
3190
|
+
break
|
|
2404
3191
|
|
|
2405
|
-
|
|
2406
|
-
|
|
2407
|
-
|
|
2408
|
-
|
|
2409
|
-
|
|
2410
|
-
|
|
2411
|
-
|
|
2412
|
-
|
|
2413
|
-
|
|
3192
|
+
if points_layer is not None:
|
|
3193
|
+
processor.viewer.layers.selection.active = points_layer
|
|
3194
|
+
if processor.use_3d:
|
|
3195
|
+
status_label.setText(
|
|
3196
|
+
"Points layer active - Navigate to FIRST SLICE of object, ensure 2D view, then click"
|
|
3197
|
+
)
|
|
3198
|
+
else:
|
|
3199
|
+
status_label.setText(
|
|
3200
|
+
"Points layer is now active - click to add points"
|
|
3201
|
+
)
|
|
3202
|
+
else:
|
|
3203
|
+
status_label.setText(
|
|
3204
|
+
"No points layer found. Please load an image first."
|
|
3205
|
+
)
|
|
3206
|
+
else: # box mode
|
|
3207
|
+
shapes_layer = None
|
|
3208
|
+
for layer in list(processor.viewer.layers):
|
|
3209
|
+
if "Rectangles" in layer.name:
|
|
3210
|
+
shapes_layer = layer
|
|
3211
|
+
break
|
|
3212
|
+
|
|
3213
|
+
if shapes_layer is not None:
|
|
3214
|
+
processor.viewer.layers.selection.active = shapes_layer
|
|
3215
|
+
status_label.setText(
|
|
3216
|
+
"Rectangles layer is now active - draw rectangles"
|
|
3217
|
+
)
|
|
3218
|
+
else:
|
|
3219
|
+
status_label.setText(
|
|
3220
|
+
"No rectangles layer found. Please load an image first."
|
|
3221
|
+
)
|
|
3222
|
+
|
|
3223
|
+
processor._ensure_active_prompt_layer = _ensure_active_prompt_layer
|
|
3224
|
+
|
|
3225
|
+
# Keep the old method for backward compatibility
|
|
3226
|
+
processor._ensure_points_layer_active = _ensure_active_prompt_layer
|
|
2414
3227
|
|
|
2415
|
-
|
|
3228
|
+
def on_clear_prompts_clicked():
|
|
3229
|
+
# Find and clear/remove prompt layers based on mode
|
|
3230
|
+
main_points_layer = None
|
|
3231
|
+
object_points_layers = []
|
|
3232
|
+
shapes_layer = None
|
|
2416
3233
|
|
|
2417
|
-
# Connect button signals
|
|
2418
|
-
def on_clear_points_clicked():
|
|
2419
|
-
# Remove all point layers
|
|
2420
3234
|
for layer in list(processor.viewer.layers):
|
|
2421
3235
|
if "Points" in layer.name:
|
|
3236
|
+
if "Object" in layer.name:
|
|
3237
|
+
object_points_layers.append(layer)
|
|
3238
|
+
else:
|
|
3239
|
+
main_points_layer = layer
|
|
3240
|
+
elif "Rectangles" in layer.name:
|
|
3241
|
+
shapes_layer = layer
|
|
3242
|
+
|
|
3243
|
+
# Remove object-specific point layers (these are created dynamically)
|
|
3244
|
+
for layer in object_points_layers:
|
|
3245
|
+
# Clean up callbacks before removing the layer to prevent cleanup issues
|
|
3246
|
+
if hasattr(layer, "mouse_drag_callbacks"):
|
|
3247
|
+
layer.mouse_drag_callbacks.clear()
|
|
3248
|
+
with contextlib.suppress(ValueError):
|
|
2422
3249
|
processor.viewer.layers.remove(layer)
|
|
2423
3250
|
|
|
2424
|
-
#
|
|
2425
|
-
if
|
|
2426
|
-
|
|
2427
|
-
processor.points_labels = {}
|
|
3251
|
+
# Clear shapes layer
|
|
3252
|
+
if shapes_layer is not None:
|
|
3253
|
+
shapes_layer.data = []
|
|
2428
3254
|
|
|
2429
|
-
|
|
2430
|
-
|
|
2431
|
-
|
|
3255
|
+
# Clear data from main points layer instead of removing it
|
|
3256
|
+
if main_points_layer is not None:
|
|
3257
|
+
# Clear the points data
|
|
3258
|
+
main_points_layer.data = np.zeros(
|
|
3259
|
+
(0, 2 if not processor.use_3d else 3)
|
|
3260
|
+
)
|
|
3261
|
+
main_points_layer.face_color = "green"
|
|
2432
3262
|
|
|
2433
|
-
|
|
2434
|
-
|
|
2435
|
-
|
|
3263
|
+
# Ensure the click callback is still connected
|
|
3264
|
+
if (
|
|
3265
|
+
hasattr(main_points_layer, "mouse_drag_callbacks")
|
|
3266
|
+
and processor._on_points_clicked
|
|
3267
|
+
not in main_points_layer.mouse_drag_callbacks
|
|
3268
|
+
):
|
|
3269
|
+
main_points_layer.mouse_drag_callbacks.append(
|
|
3270
|
+
processor._on_points_clicked
|
|
3271
|
+
)
|
|
3272
|
+
|
|
3273
|
+
# Reset all tracking attributes for 2D
|
|
3274
|
+
if not processor.use_3d:
|
|
3275
|
+
# Reset current segmentation tracking
|
|
3276
|
+
if hasattr(processor, "current_points"):
|
|
3277
|
+
processor.current_points = []
|
|
3278
|
+
processor.current_labels = []
|
|
3279
|
+
|
|
3280
|
+
# Reset object tracking
|
|
3281
|
+
if hasattr(processor, "obj_points"):
|
|
3282
|
+
processor.obj_points = {}
|
|
3283
|
+
processor.obj_labels = {}
|
|
3284
|
+
|
|
3285
|
+
# Reset box tracking
|
|
3286
|
+
if hasattr(processor, "obj_boxes"):
|
|
3287
|
+
processor.obj_boxes = {}
|
|
3288
|
+
|
|
3289
|
+
# Reset object ID counters
|
|
3290
|
+
if hasattr(processor, "current_obj_id"):
|
|
3291
|
+
# Find the highest existing label ID
|
|
3292
|
+
if processor.segmentation_result is not None:
|
|
3293
|
+
max_label = processor.segmentation_result.max()
|
|
3294
|
+
processor.current_obj_id = max(int(max_label) + 1, 1)
|
|
3295
|
+
processor.next_obj_id = processor.current_obj_id
|
|
3296
|
+
else:
|
|
3297
|
+
processor.current_obj_id = 1
|
|
3298
|
+
processor.next_obj_id = 1
|
|
3299
|
+
|
|
3300
|
+
# Reset SAM2 predictor state
|
|
3301
|
+
processor.reset_sam2_state()
|
|
3302
|
+
|
|
3303
|
+
# For 3D, reset video-specific tracking
|
|
3304
|
+
else:
|
|
3305
|
+
if hasattr(processor, "sam2_points_by_obj"):
|
|
3306
|
+
processor.sam2_points_by_obj = {}
|
|
3307
|
+
processor.sam2_labels_by_obj = {}
|
|
3308
|
+
|
|
3309
|
+
# Reset box tracking
|
|
3310
|
+
if hasattr(processor, "obj_boxes"):
|
|
3311
|
+
processor.obj_boxes = {}
|
|
3312
|
+
|
|
3313
|
+
if hasattr(processor, "points_data"):
|
|
3314
|
+
processor.points_data = {}
|
|
3315
|
+
processor.points_labels = {}
|
|
3316
|
+
|
|
3317
|
+
# Note: We don't reset _sam2_state for 3D as it needs to maintain video state
|
|
3318
|
+
|
|
3319
|
+
# Make the appropriate prompt layer active based on mode
|
|
3320
|
+
_ensure_active_prompt_layer()
|
|
2436
3321
|
|
|
2437
3322
|
status_label.setText(
|
|
2438
|
-
"Cleared all
|
|
3323
|
+
"Cleared all prompts. Ready to add new segmentation prompts."
|
|
2439
3324
|
)
|
|
2440
3325
|
|
|
2441
3326
|
def on_select_all_clicked():
|
|
@@ -2459,8 +3344,14 @@ def create_crop_widget(processor):
|
|
|
2459
3344
|
)
|
|
2460
3345
|
|
|
2461
3346
|
def on_next_clicked():
|
|
2462
|
-
#
|
|
2463
|
-
|
|
3347
|
+
# Check if we can move to the next image before clearing prompts
|
|
3348
|
+
if processor.current_index >= len(processor.images) - 1:
|
|
3349
|
+
next_button.setEnabled(False)
|
|
3350
|
+
status_label.setText("No more images. Processing complete.")
|
|
3351
|
+
return
|
|
3352
|
+
|
|
3353
|
+
# Clear prompts before moving to next image
|
|
3354
|
+
on_clear_prompts_clicked()
|
|
2464
3355
|
|
|
2465
3356
|
if not processor.next_image():
|
|
2466
3357
|
next_button.setEnabled(False)
|
|
@@ -2470,11 +3361,17 @@ def create_crop_widget(processor):
|
|
|
2470
3361
|
status_label.setText(
|
|
2471
3362
|
f"Showing image {processor.current_index + 1}/{len(processor.images)}"
|
|
2472
3363
|
)
|
|
2473
|
-
processor.
|
|
3364
|
+
processor._ensure_active_prompt_layer()
|
|
2474
3365
|
|
|
2475
3366
|
def on_prev_clicked():
|
|
2476
|
-
#
|
|
2477
|
-
|
|
3367
|
+
# Check if we can move to the previous image before clearing prompts
|
|
3368
|
+
if processor.current_index <= 0:
|
|
3369
|
+
prev_button.setEnabled(False)
|
|
3370
|
+
status_label.setText("Already at the first image.")
|
|
3371
|
+
return
|
|
3372
|
+
|
|
3373
|
+
# Clear prompts before moving to previous image
|
|
3374
|
+
on_clear_prompts_clicked()
|
|
2478
3375
|
|
|
2479
3376
|
if not processor.previous_image():
|
|
2480
3377
|
prev_button.setEnabled(False)
|
|
@@ -2484,15 +3381,33 @@ def create_crop_widget(processor):
|
|
|
2484
3381
|
status_label.setText(
|
|
2485
3382
|
f"Showing image {processor.current_index + 1}/{len(processor.images)}"
|
|
2486
3383
|
)
|
|
2487
|
-
processor.
|
|
3384
|
+
processor._ensure_active_prompt_layer()
|
|
3385
|
+
|
|
3386
|
+
def on_point_mode_clicked():
|
|
3387
|
+
processor.prompt_mode = "point"
|
|
3388
|
+
point_mode_button.setChecked(True)
|
|
3389
|
+
box_mode_button.setChecked(False)
|
|
3390
|
+
processor._update_label_layer()
|
|
3391
|
+
status_label.setText("Point mode active - click on objects to segment")
|
|
2488
3392
|
|
|
2489
|
-
|
|
3393
|
+
def on_box_mode_clicked():
|
|
3394
|
+
processor.prompt_mode = "box"
|
|
3395
|
+
point_mode_button.setChecked(False)
|
|
3396
|
+
box_mode_button.setChecked(True)
|
|
3397
|
+
processor._update_label_layer()
|
|
3398
|
+
status_label.setText(
|
|
3399
|
+
"Rectangle mode active - draw rectangles around objects"
|
|
3400
|
+
)
|
|
3401
|
+
|
|
3402
|
+
clear_prompts_button.clicked.connect(on_clear_prompts_clicked)
|
|
2490
3403
|
select_all_button.clicked.connect(on_select_all_clicked)
|
|
2491
3404
|
clear_selection_button.clicked.connect(on_clear_selection_clicked)
|
|
2492
3405
|
crop_button.clicked.connect(on_crop_clicked)
|
|
2493
3406
|
next_button.clicked.connect(on_next_clicked)
|
|
2494
3407
|
prev_button.clicked.connect(on_prev_clicked)
|
|
2495
|
-
activate_button.clicked.connect(
|
|
3408
|
+
activate_button.clicked.connect(_ensure_active_prompt_layer)
|
|
3409
|
+
point_mode_button.clicked.connect(on_point_mode_clicked)
|
|
3410
|
+
box_mode_button.clicked.connect(on_box_mode_clicked)
|
|
2496
3411
|
|
|
2497
3412
|
return crop_widget
|
|
2498
3413
|
|
|
@@ -2511,6 +3426,19 @@ def batch_crop_anything(
|
|
|
2511
3426
|
viewer: Viewer = None,
|
|
2512
3427
|
):
|
|
2513
3428
|
"""MagicGUI widget for starting Batch Crop Anything using SAM2."""
|
|
3429
|
+
# Check if torch is available
|
|
3430
|
+
if not _HAS_TORCH:
|
|
3431
|
+
QMessageBox.critical(
|
|
3432
|
+
None,
|
|
3433
|
+
"Missing Dependency",
|
|
3434
|
+
"PyTorch not found. Batch Crop Anything requires PyTorch and SAM2.\n\n"
|
|
3435
|
+
"To install the required dependencies, run:\n"
|
|
3436
|
+
"pip install 'napari-tmidas[deep-learning]'\n\n"
|
|
3437
|
+
"Then follow SAM2 installation instructions at:\n"
|
|
3438
|
+
"https://github.com/MercaderLabAnatomy/napari-tmidas#installation",
|
|
3439
|
+
)
|
|
3440
|
+
return
|
|
3441
|
+
|
|
2514
3442
|
# Check if SAM2 is available
|
|
2515
3443
|
try:
|
|
2516
3444
|
import importlib.util
|
|
@@ -2521,15 +3449,15 @@ def batch_crop_anything(
|
|
|
2521
3449
|
None,
|
|
2522
3450
|
"Missing Dependency",
|
|
2523
3451
|
"SAM2 not found. Please follow installation instructions at:\n"
|
|
2524
|
-
"https://github.com/MercaderLabAnatomy/napari-tmidas
|
|
3452
|
+
"https://github.com/MercaderLabAnatomy/napari-tmidas#installation\n",
|
|
2525
3453
|
)
|
|
2526
3454
|
return
|
|
2527
3455
|
except ImportError:
|
|
2528
3456
|
QMessageBox.critical(
|
|
2529
3457
|
None,
|
|
2530
3458
|
"Missing Dependency",
|
|
2531
|
-
"SAM2 package cannot be imported. Please follow installation instructions at
|
|
2532
|
-
"https://github.com/MercaderLabAnatomy/napari-tmidas
|
|
3459
|
+
"SAM2 package cannot be imported. Please follow installation instructions at:\n"
|
|
3460
|
+
"https://github.com/MercaderLabAnatomy/napari-tmidas#installation",
|
|
2533
3461
|
)
|
|
2534
3462
|
return
|
|
2535
3463
|
|
|
@@ -2557,24 +3485,7 @@ def batch_crop_anything_widget():
|
|
|
2557
3485
|
# Create the magicgui widget
|
|
2558
3486
|
widget = batch_crop_anything
|
|
2559
3487
|
|
|
2560
|
-
#
|
|
2561
|
-
|
|
2562
|
-
|
|
2563
|
-
def on_folder_browse_clicked():
|
|
2564
|
-
folder = QFileDialog.getExistingDirectory(
|
|
2565
|
-
None,
|
|
2566
|
-
"Select Folder",
|
|
2567
|
-
os.path.expanduser("~"),
|
|
2568
|
-
QFileDialog.ShowDirsOnly | QFileDialog.DontResolveSymlinks,
|
|
2569
|
-
)
|
|
2570
|
-
if folder:
|
|
2571
|
-
# Update the folder_path field
|
|
2572
|
-
widget.folder_path.value = folder
|
|
2573
|
-
|
|
2574
|
-
folder_browse_button.clicked.connect(on_folder_browse_clicked)
|
|
2575
|
-
|
|
2576
|
-
# Insert the browse button next to the folder_path field
|
|
2577
|
-
folder_layout = widget.folder_path.native.parent().layout()
|
|
2578
|
-
folder_layout.addWidget(folder_browse_button)
|
|
3488
|
+
# Add browse button using common utility
|
|
3489
|
+
add_browse_button_to_folder_field(widget, "folder_path")
|
|
2579
3490
|
|
|
2580
3491
|
return widget
|