napari-tmidas 0.2.1__py3-none-any.whl → 0.2.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- napari_tmidas/__init__.py +35 -5
- napari_tmidas/_crop_anything.py +1458 -499
- napari_tmidas/_env_manager.py +76 -0
- napari_tmidas/_file_conversion.py +1646 -1131
- napari_tmidas/_file_selector.py +1464 -223
- napari_tmidas/_label_inspection.py +83 -8
- napari_tmidas/_processing_worker.py +309 -0
- napari_tmidas/_reader.py +6 -10
- napari_tmidas/_registry.py +15 -14
- napari_tmidas/_roi_colocalization.py +1221 -84
- napari_tmidas/_tests/test_crop_anything.py +123 -0
- napari_tmidas/_tests/test_env_manager.py +89 -0
- napari_tmidas/_tests/test_file_selector.py +90 -0
- napari_tmidas/_tests/test_grid_view_overlay.py +193 -0
- napari_tmidas/_tests/test_init.py +98 -0
- napari_tmidas/_tests/test_intensity_label_filter.py +222 -0
- napari_tmidas/_tests/test_label_inspection.py +86 -0
- napari_tmidas/_tests/test_processing_basic.py +500 -0
- napari_tmidas/_tests/test_processing_worker.py +142 -0
- napari_tmidas/_tests/test_regionprops_analysis.py +547 -0
- napari_tmidas/_tests/test_registry.py +135 -0
- napari_tmidas/_tests/test_scipy_filters.py +168 -0
- napari_tmidas/_tests/test_skimage_filters.py +259 -0
- napari_tmidas/_tests/test_split_channels.py +217 -0
- napari_tmidas/_tests/test_spotiflow.py +87 -0
- napari_tmidas/_tests/test_tyx_display_fix.py +142 -0
- napari_tmidas/_tests/test_ui_utils.py +68 -0
- napari_tmidas/_tests/test_widget.py +30 -0
- napari_tmidas/_tests/test_windows_basic.py +66 -0
- napari_tmidas/_ui_utils.py +57 -0
- napari_tmidas/_version.py +16 -3
- napari_tmidas/_widget.py +41 -4
- napari_tmidas/processing_functions/basic.py +557 -20
- napari_tmidas/processing_functions/careamics_env_manager.py +72 -99
- napari_tmidas/processing_functions/cellpose_env_manager.py +415 -112
- napari_tmidas/processing_functions/cellpose_segmentation.py +132 -191
- napari_tmidas/processing_functions/colocalization.py +513 -56
- napari_tmidas/processing_functions/grid_view_overlay.py +703 -0
- napari_tmidas/processing_functions/intensity_label_filter.py +422 -0
- napari_tmidas/processing_functions/regionprops_analysis.py +1280 -0
- napari_tmidas/processing_functions/sam2_env_manager.py +53 -69
- napari_tmidas/processing_functions/sam2_mp4.py +274 -195
- napari_tmidas/processing_functions/scipy_filters.py +403 -8
- napari_tmidas/processing_functions/skimage_filters.py +424 -212
- napari_tmidas/processing_functions/spotiflow_detection.py +949 -0
- napari_tmidas/processing_functions/spotiflow_env_manager.py +591 -0
- napari_tmidas/processing_functions/timepoint_merger.py +334 -86
- napari_tmidas/processing_functions/trackastra_tracking.py +24 -5
- {napari_tmidas-0.2.1.dist-info → napari_tmidas-0.2.4.dist-info}/METADATA +92 -39
- napari_tmidas-0.2.4.dist-info/RECORD +63 -0
- napari_tmidas/_tests/__init__.py +0 -0
- napari_tmidas-0.2.1.dist-info/RECORD +0 -38
- {napari_tmidas-0.2.1.dist-info → napari_tmidas-0.2.4.dist-info}/WHEEL +0 -0
- {napari_tmidas-0.2.1.dist-info → napari_tmidas-0.2.4.dist-info}/entry_points.txt +0 -0
- {napari_tmidas-0.2.1.dist-info → napari_tmidas-0.2.4.dist-info}/licenses/LICENSE +0 -0
- {napari_tmidas-0.2.1.dist-info → napari_tmidas-0.2.4.dist-info}/top_level.txt +0 -0
napari_tmidas/_crop_anything.py
CHANGED
|
@@ -8,42 +8,126 @@ The plugin supports both 2D (YX) and 3D (TYX/ZYX) data.
|
|
|
8
8
|
|
|
9
9
|
import contextlib
|
|
10
10
|
import os
|
|
11
|
-
|
|
12
|
-
# Add this at the beginning of your plugin file
|
|
13
11
|
import sys
|
|
12
|
+
from pathlib import Path
|
|
14
13
|
|
|
15
|
-
sys.path.append("/opt/sam2")
|
|
16
14
|
import numpy as np
|
|
17
|
-
import requests
|
|
18
|
-
import torch
|
|
19
|
-
from magicgui import magicgui
|
|
20
|
-
from napari.layers import Labels
|
|
21
|
-
from napari.viewer import Viewer
|
|
22
|
-
from qtpy.QtCore import Qt
|
|
23
|
-
from qtpy.QtWidgets import (
|
|
24
|
-
QCheckBox,
|
|
25
|
-
QFileDialog,
|
|
26
|
-
QHBoxLayout,
|
|
27
|
-
QHeaderView,
|
|
28
|
-
QLabel,
|
|
29
|
-
QMessageBox,
|
|
30
|
-
QPushButton,
|
|
31
|
-
QScrollArea,
|
|
32
|
-
QTableWidget,
|
|
33
|
-
QTableWidgetItem,
|
|
34
|
-
QVBoxLayout,
|
|
35
|
-
QWidget,
|
|
36
|
-
)
|
|
37
|
-
from skimage.io import imread
|
|
38
|
-
from skimage.transform import resize
|
|
39
|
-
from tifffile import imwrite
|
|
40
15
|
|
|
16
|
+
# Lazy imports for optional heavy dependencies
|
|
17
|
+
try:
|
|
18
|
+
import requests
|
|
19
|
+
|
|
20
|
+
_HAS_REQUESTS = True
|
|
21
|
+
except ImportError:
|
|
22
|
+
requests = None
|
|
23
|
+
_HAS_REQUESTS = False
|
|
24
|
+
|
|
25
|
+
try:
|
|
26
|
+
import torch
|
|
27
|
+
|
|
28
|
+
_HAS_TORCH = True
|
|
29
|
+
except ImportError:
|
|
30
|
+
torch = None
|
|
31
|
+
_HAS_TORCH = False
|
|
32
|
+
|
|
33
|
+
try:
|
|
34
|
+
from magicgui import magicgui
|
|
35
|
+
|
|
36
|
+
_HAS_MAGICGUI = True
|
|
37
|
+
except ImportError:
|
|
38
|
+
# Create stub decorator
|
|
39
|
+
def magicgui(*args, **kwargs):
|
|
40
|
+
def decorator(func):
|
|
41
|
+
return func
|
|
42
|
+
|
|
43
|
+
if len(args) == 1 and callable(args[0]) and not kwargs:
|
|
44
|
+
return args[0]
|
|
45
|
+
return decorator
|
|
46
|
+
|
|
47
|
+
_HAS_MAGICGUI = False
|
|
48
|
+
|
|
49
|
+
try:
|
|
50
|
+
from napari.layers import Labels
|
|
51
|
+
from napari.viewer import Viewer
|
|
52
|
+
|
|
53
|
+
_HAS_NAPARI = True
|
|
54
|
+
except ImportError:
|
|
55
|
+
Labels = None
|
|
56
|
+
Viewer = None
|
|
57
|
+
_HAS_NAPARI = False
|
|
58
|
+
|
|
59
|
+
try:
|
|
60
|
+
from qtpy.QtCore import Qt
|
|
61
|
+
from qtpy.QtWidgets import (
|
|
62
|
+
QCheckBox,
|
|
63
|
+
QHBoxLayout,
|
|
64
|
+
QHeaderView,
|
|
65
|
+
QLabel,
|
|
66
|
+
QMessageBox,
|
|
67
|
+
QPushButton,
|
|
68
|
+
QScrollArea,
|
|
69
|
+
QTableWidget,
|
|
70
|
+
QTableWidgetItem,
|
|
71
|
+
QVBoxLayout,
|
|
72
|
+
QWidget,
|
|
73
|
+
)
|
|
74
|
+
|
|
75
|
+
_HAS_QTPY = True
|
|
76
|
+
except ImportError:
|
|
77
|
+
Qt = None
|
|
78
|
+
QCheckBox = QHBoxLayout = QHeaderView = QLabel = QMessageBox = None
|
|
79
|
+
QPushButton = QScrollArea = QTableWidget = QTableWidgetItem = None
|
|
80
|
+
QVBoxLayout = QWidget = None
|
|
81
|
+
_HAS_QTPY = False
|
|
82
|
+
|
|
83
|
+
try:
|
|
84
|
+
from skimage.io import imread
|
|
85
|
+
from skimage.transform import resize
|
|
86
|
+
|
|
87
|
+
_HAS_SKIMAGE = True
|
|
88
|
+
except ImportError:
|
|
89
|
+
imread = None
|
|
90
|
+
resize = None
|
|
91
|
+
_HAS_SKIMAGE = False
|
|
92
|
+
|
|
93
|
+
try:
|
|
94
|
+
from tifffile import imwrite
|
|
95
|
+
|
|
96
|
+
_HAS_TIFFFILE = True
|
|
97
|
+
except ImportError:
|
|
98
|
+
imwrite = None
|
|
99
|
+
_HAS_TIFFFILE = False
|
|
100
|
+
|
|
101
|
+
from napari_tmidas._file_selector import (
|
|
102
|
+
load_image_file as load_any_image,
|
|
103
|
+
)
|
|
104
|
+
from napari_tmidas._ui_utils import add_browse_button_to_folder_field
|
|
41
105
|
from napari_tmidas.processing_functions.sam2_mp4 import tif_to_mp4
|
|
42
106
|
|
|
107
|
+
sam2_paths = [
|
|
108
|
+
os.environ.get("SAM2_PATH"),
|
|
109
|
+
"/opt/sam2",
|
|
110
|
+
os.path.expanduser("~/sam2"),
|
|
111
|
+
"./sam2",
|
|
112
|
+
]
|
|
113
|
+
|
|
114
|
+
for path in sam2_paths:
|
|
115
|
+
if path and os.path.exists(path):
|
|
116
|
+
sys.path.append(path)
|
|
117
|
+
break
|
|
118
|
+
else:
|
|
119
|
+
print(
|
|
120
|
+
"Warning: SAM2 not found in common locations. Please set SAM2_PATH environment variable."
|
|
121
|
+
)
|
|
122
|
+
|
|
123
|
+
|
|
43
124
|
def get_device():
|
|
44
125
|
if sys.platform == "darwin":
|
|
45
126
|
# MacOS: Only check for MPS
|
|
46
|
-
if
|
|
127
|
+
if (
|
|
128
|
+
hasattr(torch.backends, "mps")
|
|
129
|
+
and torch.backends.mps.is_available()
|
|
130
|
+
):
|
|
47
131
|
device = torch.device("mps")
|
|
48
132
|
print("Using Apple Silicon GPU (MPS)")
|
|
49
133
|
else:
|
|
@@ -60,8 +144,6 @@ def get_device():
|
|
|
60
144
|
return device
|
|
61
145
|
|
|
62
146
|
|
|
63
|
-
|
|
64
|
-
|
|
65
147
|
class BatchCropAnything:
|
|
66
148
|
"""Class for processing images with SAM2 and cropping selected objects."""
|
|
67
149
|
|
|
@@ -83,6 +165,7 @@ class BatchCropAnything:
|
|
|
83
165
|
self.image_layer = None
|
|
84
166
|
self.label_layer = None
|
|
85
167
|
self.label_table_widget = None
|
|
168
|
+
self.shapes_layer = None
|
|
86
169
|
|
|
87
170
|
# State tracking
|
|
88
171
|
self.selected_labels = set()
|
|
@@ -91,6 +174,9 @@ class BatchCropAnything:
|
|
|
91
174
|
# Segmentation parameters
|
|
92
175
|
self.sensitivity = 50 # Default sensitivity (0-100 scale)
|
|
93
176
|
|
|
177
|
+
# Prompt mode: 'point' or 'box'
|
|
178
|
+
self.prompt_mode = "point"
|
|
179
|
+
|
|
94
180
|
# Initialize the SAM2 model
|
|
95
181
|
self._initialize_sam2()
|
|
96
182
|
|
|
@@ -104,7 +190,7 @@ class BatchCropAnything:
|
|
|
104
190
|
filename = os.path.join(dest_folder, url.split("/")[-1])
|
|
105
191
|
if not os.path.exists(filename):
|
|
106
192
|
print(f"Downloading checkpoint to {filename}...")
|
|
107
|
-
response = requests.get(url, stream=True)
|
|
193
|
+
response = requests.get(url, stream=True, timeout=30)
|
|
108
194
|
response.raise_for_status()
|
|
109
195
|
with open(filename, "wb") as f:
|
|
110
196
|
for chunk in response.iter_content(chunk_size=8192):
|
|
@@ -116,17 +202,45 @@ class BatchCropAnything:
|
|
|
116
202
|
|
|
117
203
|
try:
|
|
118
204
|
# import torch
|
|
205
|
+
print("DEBUG: Starting SAM2 initialization...")
|
|
119
206
|
|
|
120
207
|
self.device = get_device()
|
|
208
|
+
print(f"DEBUG: Device set to {self.device}")
|
|
121
209
|
|
|
122
210
|
# Download checkpoint if needed
|
|
123
211
|
checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt"
|
|
124
212
|
checkpoint_path = download_checkpoint(
|
|
125
213
|
checkpoint_url, "/opt/sam2/checkpoints/"
|
|
126
214
|
)
|
|
215
|
+
print(f"DEBUG: Checkpoint path: {checkpoint_path}")
|
|
216
|
+
|
|
217
|
+
# Use relative config path for SAM2's Hydra config system
|
|
127
218
|
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
|
|
219
|
+
print(f"DEBUG: Model config: {model_cfg}")
|
|
220
|
+
|
|
221
|
+
# Verify the actual config file exists in the SAM2 installation
|
|
222
|
+
sam2_base_path = None
|
|
223
|
+
for path in sam2_paths:
|
|
224
|
+
if path and os.path.exists(path):
|
|
225
|
+
sam2_base_path = path
|
|
226
|
+
break
|
|
227
|
+
|
|
228
|
+
if sam2_base_path is not None:
|
|
229
|
+
full_config_path = os.path.join(
|
|
230
|
+
sam2_base_path, "sam2", model_cfg
|
|
231
|
+
)
|
|
232
|
+
if not os.path.exists(full_config_path):
|
|
233
|
+
raise FileNotFoundError(
|
|
234
|
+
f"SAM2 config file not found at: {full_config_path}"
|
|
235
|
+
)
|
|
236
|
+
print(f"DEBUG: Verified config exists at: {full_config_path}")
|
|
237
|
+
else:
|
|
238
|
+
print(
|
|
239
|
+
"DEBUG: Warning - could not verify config file exists, but proceeding with relative path"
|
|
240
|
+
)
|
|
128
241
|
|
|
129
242
|
if self.use_3d:
|
|
243
|
+
print("DEBUG: Initializing SAM2 Video Predictor...")
|
|
130
244
|
from sam2.build_sam import build_sam2_video_predictor
|
|
131
245
|
|
|
132
246
|
self.predictor = build_sam2_video_predictor(
|
|
@@ -135,7 +249,9 @@ class BatchCropAnything:
|
|
|
135
249
|
self.viewer.status = (
|
|
136
250
|
f"Initialized SAM2 Video Predictor on {self.device}"
|
|
137
251
|
)
|
|
252
|
+
print("DEBUG: SAM2 Video Predictor initialized successfully")
|
|
138
253
|
else:
|
|
254
|
+
print("DEBUG: Initializing SAM2 Image Predictor...")
|
|
139
255
|
from sam2.build_sam import build_sam2
|
|
140
256
|
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
|
141
257
|
|
|
@@ -145,6 +261,7 @@ class BatchCropAnything:
|
|
|
145
261
|
self.viewer.status = (
|
|
146
262
|
f"Initialized SAM2 Image Predictor on {self.device}"
|
|
147
263
|
)
|
|
264
|
+
print("DEBUG: SAM2 Image Predictor initialized successfully")
|
|
148
265
|
|
|
149
266
|
except (
|
|
150
267
|
ImportError,
|
|
@@ -152,37 +269,79 @@ class BatchCropAnything:
|
|
|
152
269
|
ValueError,
|
|
153
270
|
FileNotFoundError,
|
|
154
271
|
requests.RequestException,
|
|
272
|
+
AttributeError,
|
|
273
|
+
ModuleNotFoundError,
|
|
155
274
|
) as e:
|
|
156
275
|
import traceback
|
|
157
276
|
|
|
158
|
-
|
|
277
|
+
error_msg = f"SAM2 initialization failed: {str(e)}"
|
|
278
|
+
error_type = type(e).__name__
|
|
279
|
+
self.viewer.status = (
|
|
280
|
+
f"{error_msg} - Images will load without segmentation"
|
|
281
|
+
)
|
|
159
282
|
self.predictor = None
|
|
283
|
+
print(f"DEBUG: SAM2 Error ({error_type}): {error_msg}")
|
|
284
|
+
print("DEBUG: Full traceback:")
|
|
160
285
|
print(traceback.format_exc())
|
|
286
|
+
print(
|
|
287
|
+
"DEBUG: Note: Images will still load, but automatic segmentation will not be available."
|
|
288
|
+
)
|
|
289
|
+
|
|
290
|
+
# Provide specific guidance based on error type
|
|
291
|
+
if isinstance(e, FileNotFoundError):
|
|
292
|
+
print(
|
|
293
|
+
"DEBUG: This appears to be a missing file issue. Check SAM2 installation and config paths."
|
|
294
|
+
)
|
|
295
|
+
elif isinstance(e, (ImportError, ModuleNotFoundError)):
|
|
296
|
+
print(
|
|
297
|
+
"DEBUG: This appears to be a SAM2 import issue. Check SAM2 installation."
|
|
298
|
+
)
|
|
299
|
+
elif isinstance(e, RuntimeError):
|
|
300
|
+
print(
|
|
301
|
+
"DEBUG: This appears to be a runtime issue, possibly GPU/CUDA related."
|
|
302
|
+
)
|
|
303
|
+
else:
|
|
304
|
+
print(f"DEBUG: Unexpected error type: {error_type}")
|
|
161
305
|
|
|
162
306
|
def load_images(self, folder_path: str):
|
|
163
307
|
"""Load images from the specified folder path."""
|
|
308
|
+
print(f"DEBUG: Loading images from folder: {folder_path}")
|
|
164
309
|
if not os.path.exists(folder_path):
|
|
165
310
|
self.viewer.status = f"Folder not found: {folder_path}"
|
|
311
|
+
print(f"DEBUG: Folder does not exist: {folder_path}")
|
|
166
312
|
return
|
|
167
313
|
|
|
168
314
|
files = os.listdir(folder_path)
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
and
|
|
178
|
-
|
|
315
|
+
print(f"DEBUG: Found {len(files)} files in folder")
|
|
316
|
+
self.images = []
|
|
317
|
+
for file in files:
|
|
318
|
+
full = os.path.join(folder_path, file)
|
|
319
|
+
low = file.lower()
|
|
320
|
+
if (
|
|
321
|
+
low.endswith((".tif", ".tiff"))
|
|
322
|
+
or (os.path.isdir(full) and low.endswith(".zarr"))
|
|
323
|
+
) and (
|
|
324
|
+
"label" not in low
|
|
325
|
+
and "_labels_" not in low
|
|
326
|
+
and "sam2"
|
|
327
|
+
not in low # Exclude any SAM2-related files (including output from this tool)
|
|
328
|
+
):
|
|
329
|
+
self.images.append(full)
|
|
330
|
+
print(f"DEBUG: Added image: {file}")
|
|
331
|
+
else:
|
|
332
|
+
print(
|
|
333
|
+
f"DEBUG: Excluded file: {file} (reason: filtering criteria)"
|
|
334
|
+
)
|
|
179
335
|
|
|
180
336
|
if not self.images:
|
|
181
337
|
self.viewer.status = "No compatible images found in the folder."
|
|
338
|
+
print("DEBUG: No compatible images found")
|
|
182
339
|
return
|
|
183
340
|
|
|
341
|
+
print(f"DEBUG: Total compatible images found: {len(self.images)}")
|
|
184
342
|
self.viewer.status = f"Found {len(self.images)} .tif images."
|
|
185
343
|
self.current_index = 0
|
|
344
|
+
print(f"DEBUG: About to load first image: {self.images[0]}")
|
|
186
345
|
self._load_current_image()
|
|
187
346
|
|
|
188
347
|
def next_image(self):
|
|
@@ -235,25 +394,69 @@ class BatchCropAnything:
|
|
|
235
394
|
|
|
236
395
|
def _load_current_image(self):
|
|
237
396
|
"""Load the current image and generate segmentation."""
|
|
397
|
+
print("DEBUG: _load_current_image called")
|
|
238
398
|
if not self.images:
|
|
239
399
|
self.viewer.status = "No images to process."
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
if self.predictor is None:
|
|
243
|
-
self.viewer.status = (
|
|
244
|
-
"SAM2 model not initialized. Cannot segment images."
|
|
245
|
-
)
|
|
400
|
+
print("DEBUG: No images to process")
|
|
246
401
|
return
|
|
247
402
|
|
|
248
403
|
image_path = self.images[self.current_index]
|
|
249
|
-
|
|
404
|
+
print(f"DEBUG: Loading image at path: {image_path}")
|
|
405
|
+
|
|
406
|
+
if self.predictor is None:
|
|
407
|
+
self.viewer.status = f"Loading {os.path.basename(image_path)} (SAM2 model not initialized - no segmentation will be available)"
|
|
408
|
+
print("DEBUG: SAM2 predictor is None")
|
|
409
|
+
else:
|
|
410
|
+
self.viewer.status = f"Processing {os.path.basename(image_path)}"
|
|
411
|
+
print("DEBUG: SAM2 predictor is available")
|
|
250
412
|
|
|
251
413
|
try:
|
|
414
|
+
print("DEBUG: About to clear viewer layers")
|
|
252
415
|
# Clear existing layers
|
|
253
416
|
self.viewer.layers.clear()
|
|
417
|
+
print("DEBUG: Viewer layers cleared")
|
|
254
418
|
|
|
419
|
+
print("DEBUG: About to load image file")
|
|
255
420
|
# Load and process image
|
|
256
|
-
|
|
421
|
+
if image_path.lower().endswith(".zarr") or (
|
|
422
|
+
os.path.isdir(image_path)
|
|
423
|
+
and image_path.lower().endswith(".zarr")
|
|
424
|
+
):
|
|
425
|
+
print("DEBUG: Loading Zarr file")
|
|
426
|
+
data = load_any_image(image_path)
|
|
427
|
+
# If multiple layers returned, take first image layer
|
|
428
|
+
if isinstance(data, list):
|
|
429
|
+
img = None
|
|
430
|
+
for entry in data:
|
|
431
|
+
if isinstance(entry, tuple) and len(entry) == 3:
|
|
432
|
+
d, _kwargs, layer_type = entry
|
|
433
|
+
if layer_type == "image":
|
|
434
|
+
img = d
|
|
435
|
+
break
|
|
436
|
+
elif isinstance(entry, tuple) and len(entry) == 2:
|
|
437
|
+
d, _kwargs = entry
|
|
438
|
+
img = d
|
|
439
|
+
break
|
|
440
|
+
else:
|
|
441
|
+
img = entry
|
|
442
|
+
break
|
|
443
|
+
if img is None:
|
|
444
|
+
raise ValueError("No image layer found in Zarr store")
|
|
445
|
+
else:
|
|
446
|
+
img = data
|
|
447
|
+
|
|
448
|
+
# Compute dask arrays to numpy if needed
|
|
449
|
+
if hasattr(img, "compute"):
|
|
450
|
+
img = img.compute()
|
|
451
|
+
|
|
452
|
+
self.original_image = img
|
|
453
|
+
else:
|
|
454
|
+
print("DEBUG: Loading TIFF file")
|
|
455
|
+
self.original_image = imread(image_path)
|
|
456
|
+
|
|
457
|
+
print(
|
|
458
|
+
f"DEBUG: Image loaded, shape: {self.original_image.shape}, dtype: {self.original_image.dtype}"
|
|
459
|
+
)
|
|
257
460
|
|
|
258
461
|
# For 3D/4D data, determine dimensions
|
|
259
462
|
if self.use_3d and len(self.original_image.shape) >= 3:
|
|
@@ -269,10 +472,12 @@ class BatchCropAnything:
|
|
|
269
472
|
|
|
270
473
|
if time_dim_idx == 0: # TZYX format
|
|
271
474
|
# Keep as is, T is already the first dimension
|
|
475
|
+
print("DEBUG: Adding 4D image (TZYX format) to viewer")
|
|
272
476
|
self.image_layer = self.viewer.add_image(
|
|
273
477
|
self.original_image,
|
|
274
478
|
name=f"Image ({os.path.basename(image_path)})",
|
|
275
479
|
)
|
|
480
|
+
print(f"DEBUG: Added image layer: {self.image_layer}")
|
|
276
481
|
# Store time dimension info
|
|
277
482
|
self.time_dim_size = self.original_image.shape[0]
|
|
278
483
|
self.has_z_dim = True
|
|
@@ -294,19 +499,23 @@ class BatchCropAnything:
|
|
|
294
499
|
transposed_image # Replace with transposed version
|
|
295
500
|
)
|
|
296
501
|
|
|
502
|
+
print("DEBUG: Adding transposed 4D image to viewer")
|
|
297
503
|
self.image_layer = self.viewer.add_image(
|
|
298
504
|
self.original_image,
|
|
299
505
|
name=f"Image ({os.path.basename(image_path)})",
|
|
300
506
|
)
|
|
507
|
+
print(f"DEBUG: Added image layer: {self.image_layer}")
|
|
301
508
|
# Store time dimension info
|
|
302
509
|
self.time_dim_size = self.original_image.shape[0]
|
|
303
510
|
self.has_z_dim = True
|
|
304
511
|
else:
|
|
305
512
|
# No time dimension found, treat as ZYX
|
|
513
|
+
print("DEBUG: Adding 4D image (ZYX format) to viewer")
|
|
306
514
|
self.image_layer = self.viewer.add_image(
|
|
307
515
|
self.original_image,
|
|
308
516
|
name=f"Image ({os.path.basename(image_path)})",
|
|
309
517
|
)
|
|
518
|
+
print(f"DEBUG: Added image layer: {self.image_layer}")
|
|
310
519
|
self.time_dim_size = 1
|
|
311
520
|
self.has_z_dim = True
|
|
312
521
|
elif (
|
|
@@ -315,30 +524,37 @@ class BatchCropAnything:
|
|
|
315
524
|
# Check if first dimension is likely time (> 4, < 400)
|
|
316
525
|
if 4 < self.original_image.shape[0] < 400:
|
|
317
526
|
# Likely TYX format
|
|
527
|
+
print("DEBUG: Adding 3D image (TYX format) to viewer")
|
|
318
528
|
self.image_layer = self.viewer.add_image(
|
|
319
529
|
self.original_image,
|
|
320
530
|
name=f"Image ({os.path.basename(image_path)})",
|
|
321
531
|
)
|
|
532
|
+
print(f"DEBUG: Added image layer: {self.image_layer}")
|
|
322
533
|
self.time_dim_size = self.original_image.shape[0]
|
|
323
534
|
self.has_z_dim = False
|
|
324
535
|
else:
|
|
325
536
|
# Likely ZYX format or another 3D format
|
|
537
|
+
print("DEBUG: Adding 3D image (ZYX format) to viewer")
|
|
326
538
|
self.image_layer = self.viewer.add_image(
|
|
327
539
|
self.original_image,
|
|
328
540
|
name=f"Image ({os.path.basename(image_path)})",
|
|
329
541
|
)
|
|
542
|
+
print(f"DEBUG: Added image layer: {self.image_layer}")
|
|
330
543
|
self.time_dim_size = 1
|
|
331
544
|
self.has_z_dim = True
|
|
332
545
|
else:
|
|
333
546
|
# Should not reach here with use_3d=True, but just in case
|
|
547
|
+
print("DEBUG: Adding 3D image (fallback) to viewer")
|
|
334
548
|
self.image_layer = self.viewer.add_image(
|
|
335
549
|
self.original_image,
|
|
336
550
|
name=f"Image ({os.path.basename(image_path)})",
|
|
337
551
|
)
|
|
552
|
+
print(f"DEBUG: Added image layer: {self.image_layer}")
|
|
338
553
|
self.time_dim_size = 1
|
|
339
554
|
self.has_z_dim = False
|
|
340
555
|
else:
|
|
341
556
|
# Handle 2D data as before
|
|
557
|
+
print("DEBUG: Processing 2D image")
|
|
342
558
|
if self.original_image.dtype != np.uint8:
|
|
343
559
|
image_for_display = (
|
|
344
560
|
self.original_image
|
|
@@ -349,18 +565,42 @@ class BatchCropAnything:
|
|
|
349
565
|
image_for_display = self.original_image
|
|
350
566
|
|
|
351
567
|
# Add image to viewer
|
|
568
|
+
print("DEBUG: Adding 2D image to viewer")
|
|
352
569
|
self.image_layer = self.viewer.add_image(
|
|
353
570
|
image_for_display,
|
|
354
571
|
name=f"Image ({os.path.basename(image_path)})",
|
|
355
572
|
)
|
|
573
|
+
print(f"DEBUG: Added image layer: {self.image_layer}")
|
|
574
|
+
|
|
575
|
+
# Generate segmentation only if predictor is available
|
|
576
|
+
if self.predictor is not None:
|
|
577
|
+
print("DEBUG: About to generate segmentation")
|
|
578
|
+
self._generate_segmentation(self.original_image, image_path)
|
|
579
|
+
print("DEBUG: Segmentation generation completed")
|
|
580
|
+
else:
|
|
581
|
+
print("DEBUG: Creating empty segmentation (no predictor)")
|
|
582
|
+
# Create empty segmentation when predictor is not available
|
|
583
|
+
if self.use_3d:
|
|
584
|
+
shape = self.original_image.shape
|
|
585
|
+
else:
|
|
586
|
+
shape = self.original_image.shape[:2]
|
|
587
|
+
|
|
588
|
+
self.segmentation_result = np.zeros(shape, dtype=np.uint32)
|
|
589
|
+
self.label_layer = self.viewer.add_labels(
|
|
590
|
+
self.segmentation_result,
|
|
591
|
+
name="No Segmentation (SAM2 not available)",
|
|
592
|
+
)
|
|
593
|
+
print(f"DEBUG: Added empty label layer: {self.label_layer}")
|
|
356
594
|
|
|
357
|
-
|
|
358
|
-
self._generate_segmentation(self.original_image, image_path)
|
|
595
|
+
print("DEBUG: _load_current_image completed successfully")
|
|
359
596
|
|
|
360
597
|
except (FileNotFoundError, ValueError, TypeError, OSError) as e:
|
|
361
598
|
import traceback
|
|
362
599
|
|
|
363
|
-
|
|
600
|
+
error_msg = f"Error processing image: {str(e)}"
|
|
601
|
+
self.viewer.status = error_msg
|
|
602
|
+
print(f"DEBUG: Exception in _load_current_image: {error_msg}")
|
|
603
|
+
print("DEBUG: Full traceback:")
|
|
364
604
|
traceback.print_exc()
|
|
365
605
|
|
|
366
606
|
# Create empty segmentation in case of error
|
|
@@ -377,6 +617,7 @@ class BatchCropAnything:
|
|
|
377
617
|
self.label_layer = self.viewer.add_labels(
|
|
378
618
|
self.segmentation_result, name="Error: No Segmentation"
|
|
379
619
|
)
|
|
620
|
+
print(f"DEBUG: Added error label layer: {self.label_layer}")
|
|
380
621
|
|
|
381
622
|
def _generate_segmentation(self, image, image_path: str):
|
|
382
623
|
"""Generate segmentation for the current image using SAM2."""
|
|
@@ -432,7 +673,8 @@ class BatchCropAnything:
|
|
|
432
673
|
traceback.print_exc()
|
|
433
674
|
|
|
434
675
|
def _generate_2d_segmentation(self, confidence_threshold):
|
|
435
|
-
"""Generate 2D segmentation
|
|
676
|
+
"""Generate initial 2D segmentation - start with empty labels for interactive mode."""
|
|
677
|
+
device_type = "cuda" if self.device.type == "cuda" else "cpu"
|
|
436
678
|
# Ensure image is in the correct format for SAM2
|
|
437
679
|
image = self.current_image_for_segmentation
|
|
438
680
|
|
|
@@ -454,9 +696,7 @@ class BatchCropAnything:
|
|
|
454
696
|
(new_height, new_width),
|
|
455
697
|
anti_aliasing=True,
|
|
456
698
|
preserve_range=True,
|
|
457
|
-
).astype(
|
|
458
|
-
np.float32
|
|
459
|
-
) # Convert to float32
|
|
699
|
+
).astype(np.float32)
|
|
460
700
|
|
|
461
701
|
self.current_scale_factor = scale_factor
|
|
462
702
|
else:
|
|
@@ -482,73 +722,54 @@ class BatchCropAnything:
|
|
|
482
722
|
if resized_image.max() > 1.0:
|
|
483
723
|
resized_image = resized_image / 255.0
|
|
484
724
|
|
|
485
|
-
#
|
|
486
|
-
|
|
487
|
-
"cuda", dtype=torch.float32
|
|
488
|
-
):
|
|
489
|
-
# Set the image in the predictor
|
|
490
|
-
self.predictor.set_image(resized_image)
|
|
725
|
+
# Store the prepared image for later use
|
|
726
|
+
self.prepared_sam2_image = resized_image
|
|
491
727
|
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
point_labels=None,
|
|
496
|
-
box=None,
|
|
497
|
-
multimask_output=True,
|
|
498
|
-
)
|
|
728
|
+
# Initialize empty segmentation result
|
|
729
|
+
self.segmentation_result = np.zeros(orig_shape, dtype=np.uint32)
|
|
730
|
+
self.label_info = {}
|
|
499
731
|
|
|
500
|
-
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
self.
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
|
|
512
|
-
|
|
513
|
-
|
|
514
|
-
|
|
515
|
-
|
|
516
|
-
center_y = np.mean(y_indices) if len(y_indices) > 0 else 0
|
|
517
|
-
center_x = np.mean(x_indices) if len(x_indices) > 0 else 0
|
|
518
|
-
|
|
519
|
-
# Store label info
|
|
520
|
-
self.label_info[label_id] = {
|
|
521
|
-
"area": area,
|
|
522
|
-
"center_y": center_y,
|
|
523
|
-
"center_x": center_x,
|
|
524
|
-
"score": float(scores[i]),
|
|
525
|
-
}
|
|
526
|
-
|
|
527
|
-
# Handle upscaling if needed
|
|
528
|
-
if self.current_scale_factor < 1.0:
|
|
529
|
-
labels = resize(
|
|
530
|
-
labels,
|
|
531
|
-
orig_shape,
|
|
532
|
-
order=0, # Nearest neighbor interpolation
|
|
533
|
-
preserve_range=True,
|
|
534
|
-
anti_aliasing=False,
|
|
535
|
-
).astype(np.uint32)
|
|
536
|
-
|
|
537
|
-
# Sort labels by area (largest first)
|
|
538
|
-
self.label_info = dict(
|
|
539
|
-
sorted(
|
|
540
|
-
self.label_info.items(),
|
|
541
|
-
key=lambda item: item[1]["area"],
|
|
542
|
-
reverse=True,
|
|
543
|
-
)
|
|
732
|
+
# Initialize tracking for interactive segmentation
|
|
733
|
+
self.current_points = []
|
|
734
|
+
self.current_labels = []
|
|
735
|
+
self.current_obj_id = 1
|
|
736
|
+
self.next_obj_id = 1
|
|
737
|
+
|
|
738
|
+
# Initialize object tracking dictionaries
|
|
739
|
+
self.obj_points = {}
|
|
740
|
+
self.obj_labels = {}
|
|
741
|
+
|
|
742
|
+
# Reset SAM2-specific tracking dictionaries for 2D mode
|
|
743
|
+
self.sam2_points_by_obj = {}
|
|
744
|
+
self.sam2_labels_by_obj = {}
|
|
745
|
+
self._sam2_next_obj_id = 1
|
|
746
|
+
print(
|
|
747
|
+
"DEBUG: Reset _sam2_next_obj_id to 1 in _generate_2d_segmentation"
|
|
544
748
|
)
|
|
545
749
|
|
|
546
|
-
#
|
|
547
|
-
self.
|
|
750
|
+
# Set the image in the predictor for later use (2D mode only)
|
|
751
|
+
device_type = "cuda" if self.device.type == "cuda" else "cpu"
|
|
752
|
+
if hasattr(self.predictor, "set_image"):
|
|
753
|
+
with (
|
|
754
|
+
torch.inference_mode(),
|
|
755
|
+
torch.autocast(device_type, dtype=torch.float32),
|
|
756
|
+
):
|
|
757
|
+
self.predictor.set_image(resized_image)
|
|
758
|
+
else:
|
|
759
|
+
print(
|
|
760
|
+
"DEBUG: Skipping set_image - predictor doesn't support it (likely VideoPredictor)"
|
|
761
|
+
)
|
|
548
762
|
|
|
549
763
|
# Update the label layer
|
|
550
764
|
self._update_label_layer()
|
|
551
765
|
|
|
766
|
+
# Show instructions
|
|
767
|
+
self.viewer.status = (
|
|
768
|
+
"2D Mode: Click on the image to add objects. Use Shift+click for negative points to refine. "
|
|
769
|
+
"Click existing objects to select them for cropping. "
|
|
770
|
+
"Note: For stacks, interactive segmentation only works in 2D view mode."
|
|
771
|
+
)
|
|
772
|
+
|
|
552
773
|
def _generate_3d_segmentation(self, confidence_threshold, image_path):
|
|
553
774
|
"""
|
|
554
775
|
Initialize 3D segmentation using SAM2 Video Predictor.
|
|
@@ -569,9 +790,7 @@ class BatchCropAnything:
|
|
|
569
790
|
import tempfile
|
|
570
791
|
|
|
571
792
|
temp_dir = tempfile.gettempdir()
|
|
572
|
-
mp4_path =
|
|
573
|
-
temp_dir, f"temp_volume_{os.path.basename(image_path)}.mp4"
|
|
574
|
-
)
|
|
793
|
+
mp4_path = None
|
|
575
794
|
|
|
576
795
|
# If we need to save a modified version for MP4 conversion
|
|
577
796
|
need_temp_tif = False
|
|
@@ -601,31 +820,72 @@ class BatchCropAnything:
|
|
|
601
820
|
imwrite(temp_tif_path, projected_volume)
|
|
602
821
|
need_temp_tif = True
|
|
603
822
|
|
|
604
|
-
#
|
|
605
|
-
|
|
606
|
-
|
|
607
|
-
|
|
608
|
-
|
|
823
|
+
# Check if MP4 already exists
|
|
824
|
+
expected_mp4 = str(Path(temp_tif_path).with_suffix(".mp4"))
|
|
825
|
+
if os.path.exists(expected_mp4):
|
|
826
|
+
self.viewer.status = (
|
|
827
|
+
f"Using existing MP4: {os.path.basename(expected_mp4)}"
|
|
828
|
+
)
|
|
829
|
+
print(
|
|
830
|
+
f"DEBUG: MP4 already exists, skipping conversion: {expected_mp4}"
|
|
831
|
+
)
|
|
832
|
+
mp4_path = expected_mp4
|
|
833
|
+
else:
|
|
834
|
+
# Convert the projected TIF to MP4
|
|
835
|
+
self.viewer.status = "Converting projected 3D volume to MP4 format for SAM2..."
|
|
836
|
+
mp4_path = tif_to_mp4(temp_tif_path)
|
|
609
837
|
else:
|
|
610
|
-
#
|
|
611
|
-
|
|
612
|
-
|
|
613
|
-
|
|
614
|
-
|
|
838
|
+
# Check if MP4 already exists for the original image
|
|
839
|
+
expected_mp4 = str(Path(image_path).with_suffix(".mp4"))
|
|
840
|
+
if os.path.exists(expected_mp4):
|
|
841
|
+
self.viewer.status = (
|
|
842
|
+
f"Using existing MP4: {os.path.basename(expected_mp4)}"
|
|
843
|
+
)
|
|
844
|
+
print(
|
|
845
|
+
f"DEBUG: MP4 already exists, skipping conversion: {expected_mp4}"
|
|
846
|
+
)
|
|
847
|
+
mp4_path = expected_mp4
|
|
848
|
+
else:
|
|
849
|
+
# Convert original volume to video format for SAM2
|
|
850
|
+
self.viewer.status = (
|
|
851
|
+
"Converting 3D volume to MP4 format for SAM2..."
|
|
852
|
+
)
|
|
853
|
+
mp4_path = tif_to_mp4(image_path)
|
|
615
854
|
|
|
616
855
|
# Initialize SAM2 state with the video
|
|
617
856
|
self.viewer.status = "Initializing SAM2 Video Predictor..."
|
|
618
|
-
|
|
619
|
-
"cuda"
|
|
620
|
-
|
|
621
|
-
|
|
857
|
+
try:
|
|
858
|
+
device_type = "cuda" if self.device.type == "cuda" else "cpu"
|
|
859
|
+
with (
|
|
860
|
+
torch.inference_mode(),
|
|
861
|
+
torch.autocast(device_type, dtype=torch.float32),
|
|
862
|
+
):
|
|
863
|
+
self._sam2_state = self.predictor.init_state(mp4_path)
|
|
864
|
+
except (
|
|
865
|
+
RuntimeError,
|
|
866
|
+
ValueError,
|
|
867
|
+
TypeError,
|
|
868
|
+
torch.cuda.OutOfMemoryError,
|
|
869
|
+
) as e:
|
|
870
|
+
self.viewer.status = (
|
|
871
|
+
f"Error initializing SAM2 video predictor: {str(e)}"
|
|
872
|
+
)
|
|
873
|
+
print(f"SAM2 video predictor initialization failed: {e}")
|
|
874
|
+
return
|
|
622
875
|
|
|
623
876
|
# Store needed state for 3D processing
|
|
624
877
|
self._sam2_next_obj_id = 1
|
|
878
|
+
print(
|
|
879
|
+
"DEBUG: Reset _sam2_next_obj_id to 1 in _generate_3d_segmentation"
|
|
880
|
+
)
|
|
625
881
|
self._sam2_prompts = (
|
|
626
882
|
{}
|
|
627
883
|
) # Store prompts for each object (points, labels, box)
|
|
628
884
|
|
|
885
|
+
# Reset SAM2-specific tracking dictionaries for 3D mode
|
|
886
|
+
self.sam2_points_by_obj = {}
|
|
887
|
+
self.sam2_labels_by_obj = {}
|
|
888
|
+
|
|
629
889
|
# Update the label layer with empty segmentation
|
|
630
890
|
self._update_label_layer()
|
|
631
891
|
|
|
@@ -633,8 +893,10 @@ class BatchCropAnything:
|
|
|
633
893
|
if self.label_layer is not None and hasattr(
|
|
634
894
|
self.label_layer, "mouse_drag_callbacks"
|
|
635
895
|
):
|
|
896
|
+
# Safely remove all existing callbacks
|
|
636
897
|
for callback in list(self.label_layer.mouse_drag_callbacks):
|
|
637
|
-
|
|
898
|
+
with contextlib.suppress(ValueError):
|
|
899
|
+
self.label_layer.mouse_drag_callbacks.remove(callback)
|
|
638
900
|
|
|
639
901
|
# Add 3D-specific click handler
|
|
640
902
|
self.label_layer.mouse_drag_callbacks.append(
|
|
@@ -658,8 +920,8 @@ class BatchCropAnything:
|
|
|
658
920
|
|
|
659
921
|
# Show instructions
|
|
660
922
|
self.viewer.status = (
|
|
661
|
-
"3D Mode active: Navigate to the
|
|
662
|
-
"Use Shift+click for negative points
|
|
923
|
+
"3D Mode active: IMPORTANT - Navigate to the FIRST SLICE where object appears (using slider), "
|
|
924
|
+
"then click on object in 2D view (not 3D view). Use Shift+click for negative points. "
|
|
663
925
|
"Segmentation will be propagated to all frames automatically."
|
|
664
926
|
)
|
|
665
927
|
|
|
@@ -713,6 +975,9 @@ class BatchCropAnything:
|
|
|
713
975
|
# Create new object for positive points on background
|
|
714
976
|
ann_obj_id = self._sam2_next_obj_id
|
|
715
977
|
if point_label > 0 and label_id == 0:
|
|
978
|
+
print(
|
|
979
|
+
f"DEBUG: Incrementing _sam2_next_obj_id from {self._sam2_next_obj_id} to {self._sam2_next_obj_id + 1}"
|
|
980
|
+
)
|
|
716
981
|
self._sam2_next_obj_id += 1
|
|
717
982
|
|
|
718
983
|
# Find or create points layer for this object
|
|
@@ -733,6 +998,15 @@ class BatchCropAnything:
|
|
|
733
998
|
border_width=1,
|
|
734
999
|
opacity=0.8,
|
|
735
1000
|
)
|
|
1001
|
+
|
|
1002
|
+
with contextlib.suppress(AttributeError, ValueError):
|
|
1003
|
+
points_layer.mouse_drag_callbacks.remove(
|
|
1004
|
+
self._on_points_clicked
|
|
1005
|
+
)
|
|
1006
|
+
points_layer.mouse_drag_callbacks.append(
|
|
1007
|
+
self._on_points_clicked
|
|
1008
|
+
)
|
|
1009
|
+
|
|
736
1010
|
# Initialize points for this object
|
|
737
1011
|
if not hasattr(self, "sam2_points_by_obj"):
|
|
738
1012
|
self.sam2_points_by_obj = {}
|
|
@@ -891,8 +1165,10 @@ class BatchCropAnything:
|
|
|
891
1165
|
# Try to perform SAM2 propagation with error handling
|
|
892
1166
|
try:
|
|
893
1167
|
# Use torch.inference_mode() and torch.autocast to ensure consistent dtypes
|
|
894
|
-
|
|
895
|
-
|
|
1168
|
+
device_type = "cuda" if self.device.type == "cuda" else "cpu"
|
|
1169
|
+
with (
|
|
1170
|
+
torch.inference_mode(),
|
|
1171
|
+
torch.autocast(device_type, dtype=torch.float32),
|
|
896
1172
|
):
|
|
897
1173
|
# Attempt to run SAM2 propagation - this will iterate through all frames
|
|
898
1174
|
for (
|
|
@@ -988,7 +1264,11 @@ class BatchCropAnything:
|
|
|
988
1264
|
time.sleep(2)
|
|
989
1265
|
for layer in list(self.viewer.layers):
|
|
990
1266
|
if "Propagation Progress" in layer.name:
|
|
991
|
-
|
|
1267
|
+
# Clean up callbacks before removing the layer to prevent cleanup issues
|
|
1268
|
+
if hasattr(layer, "mouse_drag_callbacks"):
|
|
1269
|
+
layer.mouse_drag_callbacks.clear()
|
|
1270
|
+
with contextlib.suppress(ValueError):
|
|
1271
|
+
self.viewer.layers.remove(layer)
|
|
992
1272
|
|
|
993
1273
|
threading.Thread(target=remove_progress).start()
|
|
994
1274
|
|
|
@@ -1011,6 +1291,7 @@ class BatchCropAnything:
|
|
|
1011
1291
|
Given a 3D coordinate (x, y, z), run SAM2 video predictor to segment the object at that point,
|
|
1012
1292
|
update the segmentation result and label layer.
|
|
1013
1293
|
"""
|
|
1294
|
+
device_type = "cuda" if self.device.type == "cuda" else "cpu"
|
|
1014
1295
|
if not hasattr(self, "_sam2_state") or self._sam2_state is None:
|
|
1015
1296
|
self.viewer.status = "SAM2 3D state not initialized."
|
|
1016
1297
|
return
|
|
@@ -1024,8 +1305,9 @@ class BatchCropAnything:
|
|
|
1024
1305
|
point_coords = np.array([[x, y, z]])
|
|
1025
1306
|
point_labels = np.array([1]) # 1 = foreground
|
|
1026
1307
|
|
|
1027
|
-
with
|
|
1028
|
-
|
|
1308
|
+
with (
|
|
1309
|
+
torch.inference_mode(),
|
|
1310
|
+
torch.autocast(device_type, dtype=torch.float32),
|
|
1029
1311
|
):
|
|
1030
1312
|
masks, scores, _ = self.predictor.predict(
|
|
1031
1313
|
state=self._sam2_state,
|
|
@@ -1079,7 +1361,11 @@ class BatchCropAnything:
|
|
|
1079
1361
|
# Remove existing label layer if it exists
|
|
1080
1362
|
for layer in list(self.viewer.layers):
|
|
1081
1363
|
if isinstance(layer, Labels) and "Segmentation" in layer.name:
|
|
1082
|
-
|
|
1364
|
+
# Clean up callbacks before removing the layer to prevent cleanup issues
|
|
1365
|
+
if hasattr(layer, "mouse_drag_callbacks"):
|
|
1366
|
+
layer.mouse_drag_callbacks.clear()
|
|
1367
|
+
with contextlib.suppress(ValueError):
|
|
1368
|
+
self.viewer.layers.remove(layer)
|
|
1083
1369
|
|
|
1084
1370
|
# Add label layer to viewer
|
|
1085
1371
|
self.label_layer = self.viewer.add_labels(
|
|
@@ -1088,10 +1374,36 @@ class BatchCropAnything:
|
|
|
1088
1374
|
opacity=0.7,
|
|
1089
1375
|
)
|
|
1090
1376
|
|
|
1091
|
-
#
|
|
1377
|
+
# Connect click handler to the label layer for selection and deletion
|
|
1378
|
+
if hasattr(self.label_layer, "mouse_drag_callbacks"):
|
|
1379
|
+
# Clear existing callbacks to avoid duplicates
|
|
1380
|
+
self.label_layer.mouse_drag_callbacks.clear()
|
|
1381
|
+
# Add our click handler
|
|
1382
|
+
self.label_layer.mouse_drag_callbacks.append(
|
|
1383
|
+
self._on_label_clicked
|
|
1384
|
+
)
|
|
1385
|
+
|
|
1386
|
+
# Create or update interaction layers based on mode
|
|
1387
|
+
if self.prompt_mode == "point":
|
|
1388
|
+
self._ensure_points_layer()
|
|
1389
|
+
self._remove_shapes_layer()
|
|
1390
|
+
else: # box mode
|
|
1391
|
+
self._ensure_shapes_layer()
|
|
1392
|
+
self._remove_points_layer()
|
|
1393
|
+
|
|
1394
|
+
# Update status
|
|
1395
|
+
n_labels = len(np.unique(self.segmentation_result)) - (
|
|
1396
|
+
1 if 0 in np.unique(self.segmentation_result) else 0
|
|
1397
|
+
)
|
|
1398
|
+
self.viewer.status = f"Loaded image {self.current_index + 1}/{len(self.images)} - Found {n_labels} segments"
|
|
1399
|
+
|
|
1400
|
+
def _ensure_points_layer(self):
|
|
1401
|
+
"""Ensure points layer exists and is properly configured."""
|
|
1092
1402
|
points_layer = None
|
|
1093
1403
|
for layer in list(self.viewer.layers):
|
|
1094
|
-
if
|
|
1404
|
+
if (
|
|
1405
|
+
"Points" in layer.name and "Object" not in layer.name
|
|
1406
|
+
): # Main points layer
|
|
1095
1407
|
points_layer = layer
|
|
1096
1408
|
break
|
|
1097
1409
|
|
|
@@ -1108,24 +1420,424 @@ class BatchCropAnything:
|
|
|
1108
1420
|
)
|
|
1109
1421
|
|
|
1110
1422
|
# Connect points layer mouse click event
|
|
1111
|
-
points_layer
|
|
1423
|
+
if hasattr(points_layer, "mouse_drag_callbacks"):
|
|
1424
|
+
points_layer.mouse_drag_callbacks.clear()
|
|
1425
|
+
points_layer.mouse_drag_callbacks.append(
|
|
1426
|
+
self._on_points_clicked
|
|
1427
|
+
)
|
|
1112
1428
|
|
|
1113
1429
|
# Make the points layer active to encourage interaction with it
|
|
1114
1430
|
self.viewer.layers.selection.active = points_layer
|
|
1115
1431
|
|
|
1116
|
-
|
|
1117
|
-
|
|
1118
|
-
|
|
1119
|
-
)
|
|
1120
|
-
|
|
1432
|
+
def _ensure_shapes_layer(self):
|
|
1433
|
+
"""Ensure shapes layer exists and is properly configured."""
|
|
1434
|
+
shapes_layer = None
|
|
1435
|
+
for layer in list(self.viewer.layers):
|
|
1436
|
+
if "Rectangles" in layer.name:
|
|
1437
|
+
shapes_layer = layer
|
|
1438
|
+
break
|
|
1439
|
+
|
|
1440
|
+
if shapes_layer is None:
|
|
1441
|
+
# Initialize an empty shapes layer
|
|
1442
|
+
shapes_layer = self.viewer.add_shapes(
|
|
1443
|
+
None,
|
|
1444
|
+
shape_type="rectangle",
|
|
1445
|
+
edge_width=3,
|
|
1446
|
+
edge_color="green",
|
|
1447
|
+
face_color="transparent",
|
|
1448
|
+
name="Rectangles (Draw to Segment)",
|
|
1449
|
+
)
|
|
1450
|
+
|
|
1451
|
+
# Store reference
|
|
1452
|
+
self.shapes_layer = shapes_layer
|
|
1453
|
+
|
|
1454
|
+
# Initialize processing flag to prevent re-entry
|
|
1455
|
+
if not hasattr(self, "_processing_rectangle"):
|
|
1456
|
+
self._processing_rectangle = False
|
|
1457
|
+
|
|
1458
|
+
# Always ensure the event is connected (disconnect old ones first to avoid duplicates)
|
|
1459
|
+
# Remove any existing callbacks
|
|
1460
|
+
with contextlib.suppress(Exception):
|
|
1461
|
+
shapes_layer.events.data.disconnect()
|
|
1462
|
+
|
|
1463
|
+
# Connect shape added event
|
|
1464
|
+
@shapes_layer.events.data.connect
|
|
1465
|
+
def on_shape_added(event):
|
|
1466
|
+
print(
|
|
1467
|
+
f"DEBUG: Shape event triggered! Shapes: {len(shapes_layer.data)}, Processing: {self._processing_rectangle}"
|
|
1468
|
+
)
|
|
1469
|
+
|
|
1470
|
+
# Ignore if we're already processing or if there are no shapes
|
|
1471
|
+
if self._processing_rectangle:
|
|
1472
|
+
print("DEBUG: Already processing a rectangle, ignoring event")
|
|
1473
|
+
return
|
|
1474
|
+
|
|
1475
|
+
if len(shapes_layer.data) == 0:
|
|
1476
|
+
print("DEBUG: No shapes present, ignoring event")
|
|
1477
|
+
return
|
|
1478
|
+
|
|
1479
|
+
# Only process if we have exactly 1 shape (newly drawn)
|
|
1480
|
+
if len(shapes_layer.data) == 1:
|
|
1481
|
+
print("DEBUG: New shape detected, processing...")
|
|
1482
|
+
# Set flag to prevent re-entry
|
|
1483
|
+
self._processing_rectangle = True
|
|
1484
|
+
try:
|
|
1485
|
+
# Get the shape
|
|
1486
|
+
self._on_rectangle_added(shapes_layer.data[-1])
|
|
1487
|
+
finally:
|
|
1488
|
+
# Always reset flag
|
|
1489
|
+
self._processing_rectangle = False
|
|
1490
|
+
else:
|
|
1491
|
+
print(
|
|
1492
|
+
f"DEBUG: Multiple shapes present ({len(shapes_layer.data)}), skipping"
|
|
1493
|
+
)
|
|
1494
|
+
|
|
1495
|
+
# Make the shapes layer active
|
|
1496
|
+
self.viewer.layers.selection.active = shapes_layer
|
|
1497
|
+
|
|
1498
|
+
def _remove_points_layer(self):
|
|
1499
|
+
"""Remove points layer when not in point mode."""
|
|
1500
|
+
for layer in list(self.viewer.layers):
|
|
1501
|
+
if "Points" in layer.name and "Object" not in layer.name:
|
|
1502
|
+
if hasattr(layer, "mouse_drag_callbacks"):
|
|
1503
|
+
layer.mouse_drag_callbacks.clear()
|
|
1504
|
+
with contextlib.suppress(ValueError):
|
|
1505
|
+
self.viewer.layers.remove(layer)
|
|
1506
|
+
|
|
1507
|
+
def _remove_shapes_layer(self):
|
|
1508
|
+
"""Remove shapes layer when not in box mode."""
|
|
1509
|
+
for layer in list(self.viewer.layers):
|
|
1510
|
+
if "Rectangles" in layer.name:
|
|
1511
|
+
with contextlib.suppress(ValueError):
|
|
1512
|
+
self.viewer.layers.remove(layer)
|
|
1513
|
+
self.shapes_layer = None
|
|
1514
|
+
|
|
1515
|
+
def _on_rectangle_added(self, rectangle_coords):
|
|
1516
|
+
"""Handle rectangle selection for segmentation."""
|
|
1517
|
+
print("DEBUG: _on_rectangle_added called!")
|
|
1518
|
+
device_type = "cuda" if self.device.type == "cuda" else "cpu"
|
|
1519
|
+
try:
|
|
1520
|
+
# Rectangle coords are in the form of a 4x2 or 4x3 array (corners)
|
|
1521
|
+
# Convert to bounding box format [x_min, y_min, x_max, y_max]
|
|
1522
|
+
|
|
1523
|
+
# Debug info
|
|
1524
|
+
print(f"DEBUG: Rectangle coords: {rectangle_coords}")
|
|
1525
|
+
print(f"DEBUG: Rectangle coords shape: {rectangle_coords.shape}")
|
|
1526
|
+
print(f"DEBUG: use_3d flag: {self.use_3d}")
|
|
1527
|
+
print(
|
|
1528
|
+
f"DEBUG: Has predictor: {hasattr(self, 'predictor') and self.predictor is not None}"
|
|
1529
|
+
)
|
|
1530
|
+
if hasattr(self, "predictor") and self.predictor is not None:
|
|
1531
|
+
print(
|
|
1532
|
+
f"DEBUG: Predictor type: {type(self.predictor).__name__}"
|
|
1533
|
+
)
|
|
1534
|
+
else:
|
|
1535
|
+
print("DEBUG: No predictor available!")
|
|
1536
|
+
self.viewer.status = "Error: Predictor not initialized"
|
|
1537
|
+
return
|
|
1538
|
+
|
|
1539
|
+
# Check if we're in 3D mode (use the flag, not coordinate shape)
|
|
1540
|
+
# In 3D mode, even when drawing on a 2D slice, we get (4, 2) coords
|
|
1541
|
+
# but we need to treat it as 3D with propagation
|
|
1542
|
+
if (
|
|
1543
|
+
self.use_3d
|
|
1544
|
+
and len(rectangle_coords.shape) == 2
|
|
1545
|
+
and rectangle_coords.shape[0] == 4
|
|
1546
|
+
):
|
|
1547
|
+
print("DEBUG: Processing as 3D rectangle (will propagate)")
|
|
1548
|
+
|
|
1549
|
+
# Get current frame/slice
|
|
1550
|
+
t = int(self.viewer.dims.current_step[0])
|
|
1551
|
+
print(f"DEBUG: Current frame/slice: {t}")
|
|
1552
|
+
|
|
1553
|
+
# Get Y and X bounds from 2D coordinates
|
|
1554
|
+
if rectangle_coords.shape[1] == 3:
|
|
1555
|
+
# If we somehow got 3D coords (T/Z, Y, X)
|
|
1556
|
+
y_coords = rectangle_coords[:, 1]
|
|
1557
|
+
x_coords = rectangle_coords[:, 2]
|
|
1558
|
+
elif rectangle_coords.shape[1] == 2:
|
|
1559
|
+
# More common: 2D coords (Y, X) when drawing on a slice
|
|
1560
|
+
y_coords = rectangle_coords[:, 0]
|
|
1561
|
+
x_coords = rectangle_coords[:, 1]
|
|
1562
|
+
else:
|
|
1563
|
+
print(
|
|
1564
|
+
f"DEBUG: Unexpected coordinate dimensions: {rectangle_coords.shape[1]}"
|
|
1565
|
+
)
|
|
1566
|
+
self.viewer.status = "Error: Unexpected rectangle format"
|
|
1567
|
+
return
|
|
1568
|
+
|
|
1569
|
+
y_min, y_max = int(min(y_coords)), int(max(y_coords))
|
|
1570
|
+
x_min, x_max = int(min(x_coords)), int(max(x_coords))
|
|
1571
|
+
|
|
1572
|
+
box = np.array([x_min, y_min, x_max, y_max], dtype=np.float32)
|
|
1573
|
+
print(f"DEBUG: Box coordinates: {box}")
|
|
1574
|
+
|
|
1575
|
+
# Use SAM2 with box prompt - use _sam2_next_obj_id for 3D mode
|
|
1576
|
+
if not hasattr(self, "_sam2_next_obj_id"):
|
|
1577
|
+
self._sam2_next_obj_id = 1
|
|
1578
|
+
obj_id = self._sam2_next_obj_id
|
|
1579
|
+
self._sam2_next_obj_id += 1
|
|
1580
|
+
print(
|
|
1581
|
+
f"DEBUG: Box mode - using object ID {obj_id}, next will be {self._sam2_next_obj_id}"
|
|
1582
|
+
)
|
|
1583
|
+
|
|
1584
|
+
# Store box for this object
|
|
1585
|
+
if not hasattr(self, "obj_boxes"):
|
|
1586
|
+
self.obj_boxes = {}
|
|
1587
|
+
self.obj_boxes[obj_id] = box
|
|
1588
|
+
|
|
1589
|
+
# Perform segmentation with 3D propagation
|
|
1590
|
+
if (
|
|
1591
|
+
hasattr(self, "_sam2_state")
|
|
1592
|
+
and self._sam2_state is not None
|
|
1593
|
+
):
|
|
1594
|
+
self.viewer.status = (
|
|
1595
|
+
f"Segmenting object {obj_id} with box at frame {t}..."
|
|
1596
|
+
)
|
|
1597
|
+
print(f"DEBUG: Starting segmentation for object {obj_id}")
|
|
1598
|
+
|
|
1599
|
+
_, out_obj_ids, out_mask_logits = (
|
|
1600
|
+
self.predictor.add_new_points_or_box(
|
|
1601
|
+
inference_state=self._sam2_state,
|
|
1602
|
+
frame_idx=t,
|
|
1603
|
+
obj_id=obj_id,
|
|
1604
|
+
box=box,
|
|
1605
|
+
)
|
|
1606
|
+
)
|
|
1607
|
+
|
|
1608
|
+
print("DEBUG: Segmentation complete, processing mask")
|
|
1609
|
+
# Update current frame
|
|
1610
|
+
mask = (out_mask_logits[0] > 0.0).cpu().numpy()
|
|
1611
|
+
if mask.ndim > 2:
|
|
1612
|
+
mask = mask.squeeze()
|
|
1613
|
+
|
|
1614
|
+
# Resize if needed
|
|
1615
|
+
if mask.shape != self.segmentation_result[t].shape:
|
|
1616
|
+
from skimage.transform import resize
|
|
1617
|
+
|
|
1618
|
+
mask = resize(
|
|
1619
|
+
mask.astype(float),
|
|
1620
|
+
self.segmentation_result[t].shape,
|
|
1621
|
+
order=0,
|
|
1622
|
+
preserve_range=True,
|
|
1623
|
+
anti_aliasing=False,
|
|
1624
|
+
).astype(bool)
|
|
1625
|
+
|
|
1626
|
+
# Update segmentation
|
|
1627
|
+
self.segmentation_result[t][
|
|
1628
|
+
mask & (self.segmentation_result[t] == 0)
|
|
1629
|
+
] = obj_id
|
|
1630
|
+
|
|
1631
|
+
print(f"DEBUG: Starting propagation for object {obj_id}")
|
|
1632
|
+
# Propagate to all frames
|
|
1633
|
+
self._propagate_mask_for_current_object(obj_id, t)
|
|
1634
|
+
|
|
1635
|
+
# Update UI
|
|
1636
|
+
print("DEBUG: Updating label layer")
|
|
1637
|
+
self._update_label_layer()
|
|
1638
|
+
if (
|
|
1639
|
+
hasattr(self, "label_table_widget")
|
|
1640
|
+
and self.label_table_widget is not None
|
|
1641
|
+
):
|
|
1642
|
+
self._populate_label_table(self.label_table_widget)
|
|
1643
|
+
|
|
1644
|
+
self.viewer.status = (
|
|
1645
|
+
f"Segmented and propagated object {obj_id} from box"
|
|
1646
|
+
)
|
|
1647
|
+
print("DEBUG: Rectangle processing complete!")
|
|
1648
|
+
|
|
1649
|
+
# Keep the rectangle visible after processing
|
|
1650
|
+
# Users can manually delete it if needed
|
|
1651
|
+
# if self.shapes_layer is not None:
|
|
1652
|
+
# self.shapes_layer.data = []
|
|
1653
|
+
else:
|
|
1654
|
+
print("DEBUG: _sam2_state not available")
|
|
1655
|
+
self.viewer.status = (
|
|
1656
|
+
"Error: 3D segmentation state not initialized"
|
|
1657
|
+
)
|
|
1658
|
+
|
|
1659
|
+
elif (
|
|
1660
|
+
not self.use_3d
|
|
1661
|
+
and len(rectangle_coords.shape) == 2
|
|
1662
|
+
and rectangle_coords.shape[1] == 2
|
|
1663
|
+
):
|
|
1664
|
+
# 2D case: rectangle_coords shape is (4, 2) for Y, X
|
|
1665
|
+
if rectangle_coords.shape[0] == 4:
|
|
1666
|
+
# Get Y and X bounds
|
|
1667
|
+
y_coords = rectangle_coords[:, 0]
|
|
1668
|
+
x_coords = rectangle_coords[:, 1]
|
|
1669
|
+
y_min, y_max = int(min(y_coords)), int(max(y_coords))
|
|
1670
|
+
x_min, x_max = int(min(x_coords)), int(max(x_coords))
|
|
1671
|
+
|
|
1672
|
+
box = np.array(
|
|
1673
|
+
[x_min, y_min, x_max, y_max], dtype=np.float32
|
|
1674
|
+
)
|
|
1675
|
+
|
|
1676
|
+
# Use SAM2 with box prompt - use next_obj_id for 2D mode
|
|
1677
|
+
if not hasattr(self, "next_obj_id"):
|
|
1678
|
+
self.next_obj_id = 1
|
|
1679
|
+
obj_id = self.next_obj_id
|
|
1680
|
+
self.next_obj_id += 1
|
|
1681
|
+
print(
|
|
1682
|
+
f"DEBUG: 2D Box mode - using object ID {obj_id}, next will be {self.next_obj_id}"
|
|
1683
|
+
)
|
|
1684
|
+
|
|
1685
|
+
# Store box for this object
|
|
1686
|
+
if not hasattr(self, "obj_boxes"):
|
|
1687
|
+
self.obj_boxes = {}
|
|
1688
|
+
self.obj_boxes[obj_id] = box
|
|
1689
|
+
|
|
1690
|
+
# Perform segmentation
|
|
1691
|
+
if (
|
|
1692
|
+
hasattr(self, "predictor")
|
|
1693
|
+
and self.predictor is not None
|
|
1694
|
+
):
|
|
1695
|
+
# Make sure image is loaded
|
|
1696
|
+
if self.current_image_for_segmentation is None:
|
|
1697
|
+
self.viewer.status = (
|
|
1698
|
+
"No image loaded for segmentation"
|
|
1699
|
+
)
|
|
1700
|
+
return
|
|
1701
|
+
|
|
1702
|
+
# Prepare image for SAM2
|
|
1703
|
+
image = self.current_image_for_segmentation
|
|
1704
|
+
if len(image.shape) == 2:
|
|
1705
|
+
image = np.stack([image] * 3, axis=-1)
|
|
1706
|
+
elif len(image.shape) == 3 and image.shape[2] == 1:
|
|
1707
|
+
image = np.concatenate([image] * 3, axis=2)
|
|
1708
|
+
elif len(image.shape) == 3 and image.shape[2] > 3:
|
|
1709
|
+
image = image[:, :, :3]
|
|
1710
|
+
|
|
1711
|
+
if image.dtype != np.uint8:
|
|
1712
|
+
image = (image / np.max(image) * 255).astype(
|
|
1713
|
+
np.uint8
|
|
1714
|
+
)
|
|
1715
|
+
|
|
1716
|
+
# Set the image in the predictor (only for ImagePredictor, not VideoPredictor)
|
|
1717
|
+
if hasattr(self.predictor, "set_image"):
|
|
1718
|
+
self.predictor.set_image(image)
|
|
1719
|
+
else:
|
|
1720
|
+
self.viewer.status = "Error: Rectangle mode requires Image Predictor (2D mode)"
|
|
1721
|
+
return
|
|
1722
|
+
|
|
1723
|
+
self.viewer.status = (
|
|
1724
|
+
f"Segmenting object {obj_id} with box..."
|
|
1725
|
+
)
|
|
1726
|
+
|
|
1727
|
+
with (
|
|
1728
|
+
torch.inference_mode(),
|
|
1729
|
+
torch.autocast(device_type),
|
|
1730
|
+
):
|
|
1731
|
+
masks, scores, _ = self.predictor.predict(
|
|
1732
|
+
box=box,
|
|
1733
|
+
multimask_output=False,
|
|
1734
|
+
)
|
|
1735
|
+
|
|
1736
|
+
# Get the mask
|
|
1737
|
+
if len(masks) > 0:
|
|
1738
|
+
best_mask = masks[0]
|
|
1739
|
+
|
|
1740
|
+
# Resize if needed
|
|
1741
|
+
if (
|
|
1742
|
+
best_mask.shape
|
|
1743
|
+
!= self.segmentation_result.shape
|
|
1744
|
+
):
|
|
1745
|
+
from skimage.transform import resize
|
|
1746
|
+
|
|
1747
|
+
best_mask = resize(
|
|
1748
|
+
best_mask.astype(float),
|
|
1749
|
+
self.segmentation_result.shape,
|
|
1750
|
+
order=0,
|
|
1751
|
+
preserve_range=True,
|
|
1752
|
+
anti_aliasing=False,
|
|
1753
|
+
).astype(bool)
|
|
1754
|
+
|
|
1755
|
+
# Apply mask (only overwrite background)
|
|
1756
|
+
mask_condition = np.logical_and(
|
|
1757
|
+
best_mask, (self.segmentation_result == 0)
|
|
1758
|
+
)
|
|
1759
|
+
self.segmentation_result[mask_condition] = (
|
|
1760
|
+
obj_id
|
|
1761
|
+
)
|
|
1762
|
+
|
|
1763
|
+
# Update label info
|
|
1764
|
+
area = np.sum(
|
|
1765
|
+
self.segmentation_result == obj_id
|
|
1766
|
+
)
|
|
1767
|
+
y_indices, x_indices = np.where(
|
|
1768
|
+
self.segmentation_result == obj_id
|
|
1769
|
+
)
|
|
1770
|
+
center_y = (
|
|
1771
|
+
np.mean(y_indices)
|
|
1772
|
+
if len(y_indices) > 0
|
|
1773
|
+
else 0
|
|
1774
|
+
)
|
|
1775
|
+
center_x = (
|
|
1776
|
+
np.mean(x_indices)
|
|
1777
|
+
if len(x_indices) > 0
|
|
1778
|
+
else 0
|
|
1779
|
+
)
|
|
1780
|
+
|
|
1781
|
+
self.label_info[obj_id] = {
|
|
1782
|
+
"area": area,
|
|
1783
|
+
"center_y": center_y,
|
|
1784
|
+
"center_x": center_x,
|
|
1785
|
+
"score": float(scores[0]),
|
|
1786
|
+
}
|
|
1787
|
+
|
|
1788
|
+
self.viewer.status = (
|
|
1789
|
+
f"Segmented object {obj_id} from box"
|
|
1790
|
+
)
|
|
1791
|
+
else:
|
|
1792
|
+
self.viewer.status = "No valid mask produced"
|
|
1793
|
+
|
|
1794
|
+
# Update the UI
|
|
1795
|
+
self._update_label_layer()
|
|
1796
|
+
if (
|
|
1797
|
+
hasattr(self, "label_table_widget")
|
|
1798
|
+
and self.label_table_widget is not None
|
|
1799
|
+
):
|
|
1800
|
+
self._populate_label_table(self.label_table_widget)
|
|
1801
|
+
|
|
1802
|
+
# Keep the rectangle visible after processing
|
|
1803
|
+
# Users can manually delete it if needed
|
|
1804
|
+
# if self.shapes_layer is not None:
|
|
1805
|
+
# self.shapes_layer.data = []
|
|
1806
|
+
else:
|
|
1807
|
+
# Unexpected shape dimensions
|
|
1808
|
+
print(
|
|
1809
|
+
f"DEBUG: Unexpected rectangle shape: {rectangle_coords.shape}"
|
|
1810
|
+
)
|
|
1811
|
+
self.viewer.status = f"Error: Unexpected rectangle dimensions {rectangle_coords.shape}. Expected (4,2) for 2D or (4,3) for 3D."
|
|
1812
|
+
|
|
1813
|
+
except (
|
|
1814
|
+
IndexError,
|
|
1815
|
+
KeyError,
|
|
1816
|
+
ValueError,
|
|
1817
|
+
RuntimeError,
|
|
1818
|
+
TypeError,
|
|
1819
|
+
) as e:
|
|
1820
|
+
import traceback
|
|
1821
|
+
|
|
1822
|
+
self.viewer.status = f"Error in rectangle handling: {str(e)}"
|
|
1823
|
+
print("DEBUG: Exception in _on_rectangle_added:")
|
|
1824
|
+
traceback.print_exc()
|
|
1121
1825
|
|
|
1122
1826
|
def _on_points_clicked(self, layer, event):
|
|
1123
1827
|
"""Handle clicks on the points layer for adding/removing points."""
|
|
1828
|
+
device_type = "cuda" if self.device.type == "cuda" else "cpu"
|
|
1124
1829
|
try:
|
|
1125
1830
|
# Only process clicks, not drags
|
|
1126
1831
|
if event.type != "mouse_press":
|
|
1127
1832
|
return
|
|
1128
1833
|
|
|
1834
|
+
# Check if segmentation result exists
|
|
1835
|
+
if self.segmentation_result is None:
|
|
1836
|
+
self.viewer.status = (
|
|
1837
|
+
"Segmentation not ready. Please wait for image to load."
|
|
1838
|
+
)
|
|
1839
|
+
return
|
|
1840
|
+
|
|
1129
1841
|
# Get coordinates of mouse click
|
|
1130
1842
|
coords = np.round(event.position).astype(int)
|
|
1131
1843
|
|
|
@@ -1163,6 +1875,25 @@ class BatchCropAnything:
|
|
|
1163
1875
|
colors.append("red" if is_negative else "green")
|
|
1164
1876
|
layer.face_color = colors
|
|
1165
1877
|
|
|
1878
|
+
# Validate coordinates are within segmentation bounds
|
|
1879
|
+
if (
|
|
1880
|
+
t < 0
|
|
1881
|
+
or t >= self.segmentation_result.shape[0]
|
|
1882
|
+
or y < 0
|
|
1883
|
+
or y >= self.segmentation_result.shape[1]
|
|
1884
|
+
or x < 0
|
|
1885
|
+
or x >= self.segmentation_result.shape[2]
|
|
1886
|
+
):
|
|
1887
|
+
self.viewer.status = (
|
|
1888
|
+
f"Click at ({t}, {y}, {x}) is out of bounds for "
|
|
1889
|
+
f"segmentation shape {self.segmentation_result.shape}. "
|
|
1890
|
+
f"Please click within the image bounds."
|
|
1891
|
+
)
|
|
1892
|
+
# Remove the invalid point that was just added
|
|
1893
|
+
if len(layer.data) > 0:
|
|
1894
|
+
layer.data = layer.data[:-1]
|
|
1895
|
+
return
|
|
1896
|
+
|
|
1166
1897
|
# Get the object ID
|
|
1167
1898
|
# If clicking on existing segmentation with negative point
|
|
1168
1899
|
label_id = self.segmentation_result[t, y, x]
|
|
@@ -1366,7 +2097,11 @@ class BatchCropAnything:
|
|
|
1366
2097
|
time.sleep(2)
|
|
1367
2098
|
for layer in list(self.viewer.layers):
|
|
1368
2099
|
if "Propagation Progress" in layer.name:
|
|
1369
|
-
|
|
2100
|
+
# Clean up callbacks before removing the layer to prevent cleanup issues
|
|
2101
|
+
if hasattr(layer, "mouse_drag_callbacks"):
|
|
2102
|
+
layer.mouse_drag_callbacks.clear()
|
|
2103
|
+
with contextlib.suppress(ValueError):
|
|
2104
|
+
self.viewer.layers.remove(layer)
|
|
1370
2105
|
|
|
1371
2106
|
threading.Thread(target=remove_progress).start()
|
|
1372
2107
|
|
|
@@ -1407,6 +2142,23 @@ class BatchCropAnything:
|
|
|
1407
2142
|
colors.append("red" if is_negative else "green")
|
|
1408
2143
|
layer.face_color = colors
|
|
1409
2144
|
|
|
2145
|
+
# Validate coordinates are within segmentation bounds
|
|
2146
|
+
if (
|
|
2147
|
+
y < 0
|
|
2148
|
+
or y >= self.segmentation_result.shape[0]
|
|
2149
|
+
or x < 0
|
|
2150
|
+
or x >= self.segmentation_result.shape[1]
|
|
2151
|
+
):
|
|
2152
|
+
self.viewer.status = (
|
|
2153
|
+
f"Click at ({y}, {x}) is out of bounds for "
|
|
2154
|
+
f"segmentation shape {self.segmentation_result.shape}. "
|
|
2155
|
+
f"Please click within the image bounds."
|
|
2156
|
+
)
|
|
2157
|
+
# Remove the invalid point that was just added
|
|
2158
|
+
if len(layer.data) > 0:
|
|
2159
|
+
layer.data = layer.data[:-1]
|
|
2160
|
+
return
|
|
2161
|
+
|
|
1410
2162
|
# Get object ID
|
|
1411
2163
|
label_id = self.segmentation_result[y, x]
|
|
1412
2164
|
if is_negative and label_id > 0:
|
|
@@ -1451,8 +2203,14 @@ class BatchCropAnything:
|
|
|
1451
2203
|
if image.dtype != np.uint8:
|
|
1452
2204
|
image = (image / np.max(image) * 255).astype(np.uint8)
|
|
1453
2205
|
|
|
1454
|
-
# Set the image in the predictor
|
|
1455
|
-
self.predictor
|
|
2206
|
+
# Set the image in the predictor (only for ImagePredictor, not VideoPredictor)
|
|
2207
|
+
if hasattr(self.predictor, "set_image"):
|
|
2208
|
+
self.predictor.set_image(image)
|
|
2209
|
+
else:
|
|
2210
|
+
self.viewer.status = (
|
|
2211
|
+
"Error: Point mode in 2D requires Image Predictor"
|
|
2212
|
+
)
|
|
2213
|
+
return
|
|
1456
2214
|
|
|
1457
2215
|
# Use only points for current object
|
|
1458
2216
|
points = np.array(
|
|
@@ -1462,7 +2220,7 @@ class BatchCropAnything:
|
|
|
1462
2220
|
|
|
1463
2221
|
self.viewer.status = f"Segmenting object {obj_id} with {len(points)} points..."
|
|
1464
2222
|
|
|
1465
|
-
with torch.inference_mode(), torch.autocast(
|
|
2223
|
+
with torch.inference_mode(), torch.autocast(device_type):
|
|
1466
2224
|
masks, scores, _ = self.predictor.predict(
|
|
1467
2225
|
point_coords=points,
|
|
1468
2226
|
point_labels=labels,
|
|
@@ -1551,16 +2309,23 @@ class BatchCropAnything:
|
|
|
1551
2309
|
def _on_label_clicked(self, layer, event):
|
|
1552
2310
|
"""Handle label selection and user prompts on mouse click."""
|
|
1553
2311
|
try:
|
|
1554
|
-
# Only process
|
|
2312
|
+
# Only process mouse press events
|
|
1555
2313
|
if event.type != "mouse_press":
|
|
1556
2314
|
return
|
|
1557
2315
|
|
|
2316
|
+
# Only handle left mouse button
|
|
2317
|
+
if event.button != 1:
|
|
2318
|
+
return
|
|
2319
|
+
|
|
1558
2320
|
# Get coordinates of mouse click
|
|
1559
2321
|
coords = np.round(event.position).astype(int)
|
|
1560
2322
|
|
|
1561
|
-
# Check
|
|
2323
|
+
# Check modifiers
|
|
1562
2324
|
is_negative = "Shift" in event.modifiers
|
|
1563
|
-
|
|
2325
|
+
is_control = (
|
|
2326
|
+
"Control" in event.modifiers or "Ctrl" in event.modifiers
|
|
2327
|
+
)
|
|
2328
|
+
# point_label = -1 if is_negative else 1
|
|
1564
2329
|
|
|
1565
2330
|
# For 2D data
|
|
1566
2331
|
if not self.use_3d:
|
|
@@ -1581,254 +2346,13 @@ class BatchCropAnything:
|
|
|
1581
2346
|
# Get the label ID at the clicked position
|
|
1582
2347
|
label_id = self.segmentation_result[y, x]
|
|
1583
2348
|
|
|
1584
|
-
#
|
|
1585
|
-
if
|
|
1586
|
-
|
|
1587
|
-
|
|
1588
|
-
self.next_obj_id = (
|
|
1589
|
-
int(self.segmentation_result.max()) + 1
|
|
1590
|
-
)
|
|
1591
|
-
else:
|
|
1592
|
-
self.next_obj_id = 1
|
|
1593
|
-
|
|
1594
|
-
# If clicking on background or using negative click, handle segmentation
|
|
1595
|
-
if label_id == 0 or is_negative:
|
|
1596
|
-
# Find or create points layer for the current object we're working on
|
|
1597
|
-
current_obj_id = None
|
|
1598
|
-
|
|
1599
|
-
# If negative point on existing label, use that label's ID
|
|
1600
|
-
if is_negative and label_id > 0:
|
|
1601
|
-
current_obj_id = label_id
|
|
1602
|
-
# For positive clicks on background, create a new object
|
|
1603
|
-
elif point_label > 0 and label_id == 0:
|
|
1604
|
-
current_obj_id = self.next_obj_id
|
|
1605
|
-
self.next_obj_id += 1
|
|
1606
|
-
# For negative on background, try to find most recent object
|
|
1607
|
-
elif point_label < 0 and label_id == 0:
|
|
1608
|
-
# Use most recently created object if available
|
|
1609
|
-
if hasattr(self, "obj_points") and self.obj_points:
|
|
1610
|
-
current_obj_id = max(self.obj_points.keys())
|
|
1611
|
-
else:
|
|
1612
|
-
self.viewer.status = "No existing object to modify with negative point"
|
|
1613
|
-
return
|
|
1614
|
-
|
|
1615
|
-
if current_obj_id is None:
|
|
1616
|
-
self.viewer.status = (
|
|
1617
|
-
"Could not determine which object to modify"
|
|
1618
|
-
)
|
|
1619
|
-
return
|
|
1620
|
-
|
|
1621
|
-
# Find or create points layer for this object
|
|
1622
|
-
points_layer = None
|
|
1623
|
-
for layer in list(self.viewer.layers):
|
|
1624
|
-
if f"Points for Object {current_obj_id}" in layer.name:
|
|
1625
|
-
points_layer = layer
|
|
1626
|
-
break
|
|
1627
|
-
|
|
1628
|
-
# Initialize object tracking if needed
|
|
1629
|
-
if not hasattr(self, "obj_points"):
|
|
1630
|
-
self.obj_points = {}
|
|
1631
|
-
self.obj_labels = {}
|
|
1632
|
-
|
|
1633
|
-
if current_obj_id not in self.obj_points:
|
|
1634
|
-
self.obj_points[current_obj_id] = []
|
|
1635
|
-
self.obj_labels[current_obj_id] = []
|
|
1636
|
-
|
|
1637
|
-
# Create or update points layer for this object
|
|
1638
|
-
if points_layer is None:
|
|
1639
|
-
# First point for this object
|
|
1640
|
-
points_layer = self.viewer.add_points(
|
|
1641
|
-
np.array([[y, x]]),
|
|
1642
|
-
name=f"Points for Object {current_obj_id}",
|
|
1643
|
-
size=10,
|
|
1644
|
-
face_color=["green" if point_label > 0 else "red"],
|
|
1645
|
-
border_color="white",
|
|
1646
|
-
border_width=1,
|
|
1647
|
-
opacity=0.8,
|
|
1648
|
-
)
|
|
1649
|
-
self.obj_points[current_obj_id] = [[x, y]]
|
|
1650
|
-
self.obj_labels[current_obj_id] = [point_label]
|
|
1651
|
-
else:
|
|
1652
|
-
# Add point to existing layer
|
|
1653
|
-
current_points = points_layer.data
|
|
1654
|
-
current_colors = points_layer.face_color
|
|
1655
|
-
|
|
1656
|
-
# Add new point
|
|
1657
|
-
new_points = np.vstack([current_points, [y, x]])
|
|
1658
|
-
new_color = "green" if point_label > 0 else "red"
|
|
1659
|
-
|
|
1660
|
-
# Update points layer
|
|
1661
|
-
points_layer.data = new_points
|
|
1662
|
-
|
|
1663
|
-
# Update colors
|
|
1664
|
-
if isinstance(current_colors, list):
|
|
1665
|
-
current_colors.append(new_color)
|
|
1666
|
-
points_layer.face_color = current_colors
|
|
1667
|
-
else:
|
|
1668
|
-
# If it's an array, create a list of colors
|
|
1669
|
-
colors = []
|
|
1670
|
-
for i in range(len(new_points)):
|
|
1671
|
-
if i < len(current_points):
|
|
1672
|
-
colors.append(
|
|
1673
|
-
"green" if point_label > 0 else "red"
|
|
1674
|
-
)
|
|
1675
|
-
else:
|
|
1676
|
-
colors.append(new_color)
|
|
1677
|
-
points_layer.face_color = colors
|
|
1678
|
-
|
|
1679
|
-
# Update object tracking
|
|
1680
|
-
self.obj_points[current_obj_id].append([x, y])
|
|
1681
|
-
self.obj_labels[current_obj_id].append(point_label)
|
|
1682
|
-
|
|
1683
|
-
# Now do the actual segmentation using SAM2
|
|
1684
|
-
if (
|
|
1685
|
-
hasattr(self, "predictor")
|
|
1686
|
-
and self.predictor is not None
|
|
1687
|
-
):
|
|
1688
|
-
try:
|
|
1689
|
-
# Make sure image is loaded
|
|
1690
|
-
if self.current_image_for_segmentation is None:
|
|
1691
|
-
self.viewer.status = (
|
|
1692
|
-
"No image loaded for segmentation"
|
|
1693
|
-
)
|
|
1694
|
-
return
|
|
1695
|
-
|
|
1696
|
-
# Prepare image for SAM2
|
|
1697
|
-
image = self.current_image_for_segmentation
|
|
1698
|
-
if len(image.shape) == 2:
|
|
1699
|
-
image = np.stack([image] * 3, axis=-1)
|
|
1700
|
-
elif len(image.shape) == 3 and image.shape[2] == 1:
|
|
1701
|
-
image = np.concatenate([image] * 3, axis=2)
|
|
1702
|
-
elif len(image.shape) == 3 and image.shape[2] > 3:
|
|
1703
|
-
image = image[:, :, :3]
|
|
1704
|
-
|
|
1705
|
-
if image.dtype != np.uint8:
|
|
1706
|
-
image = (image / np.max(image) * 255).astype(
|
|
1707
|
-
np.uint8
|
|
1708
|
-
)
|
|
1709
|
-
|
|
1710
|
-
# Set the image in the predictor
|
|
1711
|
-
self.predictor.set_image(image)
|
|
1712
|
-
|
|
1713
|
-
# Only use the points for the current object being segmented
|
|
1714
|
-
points = np.array(
|
|
1715
|
-
self.obj_points[current_obj_id],
|
|
1716
|
-
dtype=np.float32,
|
|
1717
|
-
)
|
|
1718
|
-
labels = np.array(
|
|
1719
|
-
self.obj_labels[current_obj_id], dtype=np.int32
|
|
1720
|
-
)
|
|
1721
|
-
|
|
1722
|
-
self.viewer.status = f"Segmenting object {current_obj_id} with {len(points)} points..."
|
|
1723
|
-
|
|
1724
|
-
with torch.inference_mode(), torch.autocast(
|
|
1725
|
-
"cuda"
|
|
1726
|
-
):
|
|
1727
|
-
masks, scores, _ = self.predictor.predict(
|
|
1728
|
-
point_coords=points,
|
|
1729
|
-
point_labels=labels,
|
|
1730
|
-
multimask_output=True,
|
|
1731
|
-
)
|
|
1732
|
-
|
|
1733
|
-
# Get best mask
|
|
1734
|
-
if len(masks) > 0:
|
|
1735
|
-
best_mask = masks[0]
|
|
1736
|
-
|
|
1737
|
-
# Update segmentation result
|
|
1738
|
-
if (
|
|
1739
|
-
best_mask.shape
|
|
1740
|
-
!= self.segmentation_result.shape
|
|
1741
|
-
):
|
|
1742
|
-
from skimage.transform import resize
|
|
1743
|
-
|
|
1744
|
-
best_mask = resize(
|
|
1745
|
-
best_mask.astype(float),
|
|
1746
|
-
self.segmentation_result.shape,
|
|
1747
|
-
order=0,
|
|
1748
|
-
preserve_range=True,
|
|
1749
|
-
anti_aliasing=False,
|
|
1750
|
-
).astype(bool)
|
|
1751
|
-
|
|
1752
|
-
# CRITICAL FIX: For negative points, only remove from this object's mask
|
|
1753
|
-
# For positive points, add to this object's mask without removing other objects
|
|
1754
|
-
if point_label < 0:
|
|
1755
|
-
# Remove only from current object's mask
|
|
1756
|
-
self.segmentation_result[
|
|
1757
|
-
(
|
|
1758
|
-
self.segmentation_result
|
|
1759
|
-
== current_obj_id
|
|
1760
|
-
)
|
|
1761
|
-
& best_mask
|
|
1762
|
-
] = 0
|
|
1763
|
-
else:
|
|
1764
|
-
# Add to current object's mask without affecting other objects
|
|
1765
|
-
# Only overwrite background (value 0)
|
|
1766
|
-
self.segmentation_result[
|
|
1767
|
-
best_mask
|
|
1768
|
-
& (self.segmentation_result == 0)
|
|
1769
|
-
] = current_obj_id
|
|
1770
|
-
|
|
1771
|
-
# Update label info
|
|
1772
|
-
area = np.sum(
|
|
1773
|
-
self.segmentation_result
|
|
1774
|
-
== current_obj_id
|
|
1775
|
-
)
|
|
1776
|
-
y_indices, x_indices = np.where(
|
|
1777
|
-
self.segmentation_result
|
|
1778
|
-
== current_obj_id
|
|
1779
|
-
)
|
|
1780
|
-
center_y = (
|
|
1781
|
-
np.mean(y_indices)
|
|
1782
|
-
if len(y_indices) > 0
|
|
1783
|
-
else 0
|
|
1784
|
-
)
|
|
1785
|
-
center_x = (
|
|
1786
|
-
np.mean(x_indices)
|
|
1787
|
-
if len(x_indices) > 0
|
|
1788
|
-
else 0
|
|
1789
|
-
)
|
|
1790
|
-
|
|
1791
|
-
self.label_info[current_obj_id] = {
|
|
1792
|
-
"area": area,
|
|
1793
|
-
"center_y": center_y,
|
|
1794
|
-
"center_x": center_x,
|
|
1795
|
-
"score": float(scores[0]),
|
|
1796
|
-
}
|
|
1797
|
-
|
|
1798
|
-
self.viewer.status = (
|
|
1799
|
-
f"Updated object {current_obj_id}"
|
|
1800
|
-
)
|
|
1801
|
-
else:
|
|
1802
|
-
self.viewer.status = (
|
|
1803
|
-
"No valid mask produced"
|
|
1804
|
-
)
|
|
1805
|
-
|
|
1806
|
-
# Update the UI
|
|
1807
|
-
self._update_label_layer()
|
|
1808
|
-
if (
|
|
1809
|
-
hasattr(self, "label_table_widget")
|
|
1810
|
-
and self.label_table_widget is not None
|
|
1811
|
-
):
|
|
1812
|
-
self._populate_label_table(
|
|
1813
|
-
self.label_table_widget
|
|
1814
|
-
)
|
|
1815
|
-
|
|
1816
|
-
except (
|
|
1817
|
-
IndexError,
|
|
1818
|
-
KeyError,
|
|
1819
|
-
ValueError,
|
|
1820
|
-
AttributeError,
|
|
1821
|
-
TypeError,
|
|
1822
|
-
) as e:
|
|
1823
|
-
import traceback
|
|
1824
|
-
|
|
1825
|
-
self.viewer.status = (
|
|
1826
|
-
f"Error in SAM2 processing: {str(e)}"
|
|
1827
|
-
)
|
|
1828
|
-
traceback.print_exc()
|
|
2349
|
+
# Handle Ctrl+Click to clear a single label
|
|
2350
|
+
if is_control and label_id > 0:
|
|
2351
|
+
self.clear_label_at_position(y, x)
|
|
2352
|
+
return
|
|
1829
2353
|
|
|
1830
|
-
# If clicking on an existing label, toggle selection
|
|
1831
|
-
|
|
2354
|
+
# If clicking on an existing label (and not using modifiers), toggle selection
|
|
2355
|
+
if label_id > 0 and not is_negative and not is_control:
|
|
1832
2356
|
# Toggle the label selection
|
|
1833
2357
|
if label_id in self.selected_labels:
|
|
1834
2358
|
self.selected_labels.remove(label_id)
|
|
@@ -1840,8 +2364,14 @@ class BatchCropAnything:
|
|
|
1840
2364
|
# Update table and preview
|
|
1841
2365
|
self._update_label_table()
|
|
1842
2366
|
self.preview_crop()
|
|
2367
|
+
return
|
|
1843
2368
|
|
|
1844
|
-
|
|
2369
|
+
# If clicking on background or using Shift (negative points), this should be handled by points layer
|
|
2370
|
+
# Don't process these clicks here to avoid conflicts
|
|
2371
|
+
if label_id == 0 or is_negative:
|
|
2372
|
+
return
|
|
2373
|
+
|
|
2374
|
+
# 3D case
|
|
1845
2375
|
else:
|
|
1846
2376
|
if len(coords) == 3:
|
|
1847
2377
|
t, y, x = map(int, coords)
|
|
@@ -1870,12 +2400,13 @@ class BatchCropAnything:
|
|
|
1870
2400
|
# Get the label ID at the clicked position
|
|
1871
2401
|
label_id = self.segmentation_result[t, y, x]
|
|
1872
2402
|
|
|
1873
|
-
#
|
|
1874
|
-
if label_id
|
|
1875
|
-
|
|
1876
|
-
|
|
1877
|
-
|
|
1878
|
-
|
|
2403
|
+
# Handle Ctrl+Click to clear a single label
|
|
2404
|
+
if is_control and label_id > 0:
|
|
2405
|
+
self.clear_label_at_position_3d(t, y, x)
|
|
2406
|
+
return
|
|
2407
|
+
|
|
2408
|
+
# If clicking on an existing label and not using negative points, handle selection
|
|
2409
|
+
if label_id > 0 and not is_negative and not is_control:
|
|
1879
2410
|
# Toggle the label selection
|
|
1880
2411
|
if label_id in self.selected_labels:
|
|
1881
2412
|
self.selected_labels.remove(label_id)
|
|
@@ -1886,9 +2417,12 @@ class BatchCropAnything:
|
|
|
1886
2417
|
|
|
1887
2418
|
# Update table if it exists
|
|
1888
2419
|
self._update_label_table()
|
|
1889
|
-
|
|
1890
|
-
# Update preview after selection changes
|
|
1891
2420
|
self.preview_crop()
|
|
2421
|
+
return
|
|
2422
|
+
|
|
2423
|
+
# For background clicks or negative points, let the 3D handler deal with it
|
|
2424
|
+
if label_id == 0 or is_negative:
|
|
2425
|
+
return
|
|
1892
2426
|
|
|
1893
2427
|
except (
|
|
1894
2428
|
IndexError,
|
|
@@ -1902,12 +2436,74 @@ class BatchCropAnything:
|
|
|
1902
2436
|
self.viewer.status = f"Error in click handling: {str(e)}"
|
|
1903
2437
|
traceback.print_exc()
|
|
1904
2438
|
|
|
2439
|
+
def _add_segmentation_point(self, x, y, event):
|
|
2440
|
+
"""Add a point for segmentation."""
|
|
2441
|
+
is_negative = "Shift" in event.modifiers
|
|
2442
|
+
|
|
2443
|
+
# Initialize tracking if needed
|
|
2444
|
+
if not hasattr(self, "current_points"):
|
|
2445
|
+
self.current_points = []
|
|
2446
|
+
self.current_labels = []
|
|
2447
|
+
self.current_obj_id = 1
|
|
2448
|
+
|
|
2449
|
+
# Add point
|
|
2450
|
+
self.current_points.append([x, y])
|
|
2451
|
+
self.current_labels.append(0 if is_negative else 1)
|
|
2452
|
+
|
|
2453
|
+
# Run SAM2 prediction
|
|
2454
|
+
if self.predictor is not None:
|
|
2455
|
+
# Prepare image
|
|
2456
|
+
image = self._prepare_image_for_sam2()
|
|
2457
|
+
|
|
2458
|
+
# Set the image in the predictor (only for ImagePredictor, not VideoPredictor)
|
|
2459
|
+
if hasattr(self.predictor, "set_image"):
|
|
2460
|
+
self.predictor.set_image(image)
|
|
2461
|
+
else:
|
|
2462
|
+
self.viewer.status = (
|
|
2463
|
+
"Error: This operation requires Image Predictor (2D mode)"
|
|
2464
|
+
)
|
|
2465
|
+
return
|
|
2466
|
+
|
|
2467
|
+
# Predict
|
|
2468
|
+
device_type = "cuda" if self.device.type == "cuda" else "cpu"
|
|
2469
|
+
with torch.inference_mode(), torch.autocast(device_type):
|
|
2470
|
+
masks, scores, _ = self.predictor.predict(
|
|
2471
|
+
point_coords=np.array(
|
|
2472
|
+
self.current_points, dtype=np.float32
|
|
2473
|
+
),
|
|
2474
|
+
point_labels=np.array(self.current_labels, dtype=np.int32),
|
|
2475
|
+
multimask_output=False,
|
|
2476
|
+
)
|
|
2477
|
+
|
|
2478
|
+
# Update segmentation
|
|
2479
|
+
if len(masks) > 0:
|
|
2480
|
+
mask = masks[0] > 0.5
|
|
2481
|
+
if self.current_scale_factor < 1.0:
|
|
2482
|
+
mask = resize(
|
|
2483
|
+
mask, self.segmentation_result.shape, order=0
|
|
2484
|
+
).astype(bool)
|
|
2485
|
+
|
|
2486
|
+
# Update segmentation result
|
|
2487
|
+
self.segmentation_result[mask] = self.current_obj_id
|
|
2488
|
+
|
|
2489
|
+
# Move to next object if adding positive point
|
|
2490
|
+
if not is_negative:
|
|
2491
|
+
self.current_obj_id += 1
|
|
2492
|
+
self.current_points = []
|
|
2493
|
+
self.current_labels = []
|
|
2494
|
+
|
|
2495
|
+
self._update_label_layer()
|
|
2496
|
+
|
|
1905
2497
|
def _add_point_marker(self, coords, label_type):
|
|
1906
2498
|
"""Add a visible marker for where the user clicked."""
|
|
1907
2499
|
# Remove previous point markers
|
|
1908
2500
|
for layer in list(self.viewer.layers):
|
|
1909
2501
|
if "Point Prompt" in layer.name:
|
|
1910
|
-
|
|
2502
|
+
# Clean up callbacks before removing the layer to prevent cleanup issues
|
|
2503
|
+
if hasattr(layer, "mouse_drag_callbacks"):
|
|
2504
|
+
layer.mouse_drag_callbacks.clear()
|
|
2505
|
+
with contextlib.suppress(ValueError):
|
|
2506
|
+
self.viewer.layers.remove(layer)
|
|
1911
2507
|
|
|
1912
2508
|
# Create points layer
|
|
1913
2509
|
color = (
|
|
@@ -1923,6 +2519,14 @@ class BatchCropAnything:
|
|
|
1923
2519
|
opacity=0.8,
|
|
1924
2520
|
)
|
|
1925
2521
|
|
|
2522
|
+
with contextlib.suppress(AttributeError, ValueError):
|
|
2523
|
+
self.points_layer.mouse_drag_callbacks.remove(
|
|
2524
|
+
self._on_points_clicked
|
|
2525
|
+
)
|
|
2526
|
+
self.points_layer.mouse_drag_callbacks.append(
|
|
2527
|
+
self._on_points_clicked
|
|
2528
|
+
)
|
|
2529
|
+
|
|
1926
2530
|
def create_label_table(self, parent_widget):
|
|
1927
2531
|
"""Create a table widget displaying all detected labels."""
|
|
1928
2532
|
# Create table widget
|
|
@@ -2087,11 +2691,170 @@ class BatchCropAnything:
|
|
|
2087
2691
|
self.viewer.status = f"Selected all {len(self.selected_labels)} labels"
|
|
2088
2692
|
|
|
2089
2693
|
def clear_selection(self):
|
|
2090
|
-
"""Clear all
|
|
2694
|
+
"""Clear all labels from the segmentation.
|
|
2695
|
+
|
|
2696
|
+
This removes all segmented objects from the label layer, resets all tracking data,
|
|
2697
|
+
and prepares the interface for new segmentations. Note: The method name is kept as
|
|
2698
|
+
'clear_selection' for backwards compatibility, but it clears all labels, not just
|
|
2699
|
+
the selection.
|
|
2700
|
+
"""
|
|
2701
|
+
if self.segmentation_result is None:
|
|
2702
|
+
self.viewer.status = "No segmentation available"
|
|
2703
|
+
return
|
|
2704
|
+
|
|
2705
|
+
# Get all unique label IDs (excluding background 0)
|
|
2706
|
+
unique_labels = np.unique(self.segmentation_result)
|
|
2707
|
+
label_ids = [label for label in unique_labels if label > 0]
|
|
2708
|
+
|
|
2709
|
+
if len(label_ids) == 0:
|
|
2710
|
+
self.viewer.status = "No labels to clear"
|
|
2711
|
+
return
|
|
2712
|
+
|
|
2713
|
+
# Clear the entire segmentation result
|
|
2714
|
+
self.segmentation_result[:] = 0
|
|
2715
|
+
|
|
2716
|
+
# Clear selected labels
|
|
2091
2717
|
self.selected_labels = set()
|
|
2718
|
+
|
|
2719
|
+
# Clear label info
|
|
2720
|
+
self.label_info = {}
|
|
2721
|
+
|
|
2722
|
+
# Remove any object-specific point layers
|
|
2723
|
+
for layer in list(self.viewer.layers):
|
|
2724
|
+
if "Points for Object" in layer.name:
|
|
2725
|
+
# Clean up callbacks before removing the layer to prevent cleanup issues
|
|
2726
|
+
if hasattr(layer, "mouse_drag_callbacks"):
|
|
2727
|
+
layer.mouse_drag_callbacks.clear()
|
|
2728
|
+
with contextlib.suppress(ValueError):
|
|
2729
|
+
self.viewer.layers.remove(layer)
|
|
2730
|
+
|
|
2731
|
+
# Clean up object tracking data
|
|
2732
|
+
if hasattr(self, "obj_points"):
|
|
2733
|
+
self.obj_points = {}
|
|
2734
|
+
if hasattr(self, "obj_labels"):
|
|
2735
|
+
self.obj_labels = {}
|
|
2736
|
+
if hasattr(self, "points_data"):
|
|
2737
|
+
self.points_data = {}
|
|
2738
|
+
if hasattr(self, "points_labels"):
|
|
2739
|
+
self.points_labels = {}
|
|
2740
|
+
|
|
2741
|
+
# Reset object ID counters
|
|
2742
|
+
if hasattr(self, "next_obj_id"):
|
|
2743
|
+
self.next_obj_id = 1
|
|
2744
|
+
if hasattr(self, "_sam2_next_obj_id"):
|
|
2745
|
+
self._sam2_next_obj_id = 1
|
|
2746
|
+
|
|
2747
|
+
# Update UI
|
|
2748
|
+
self._update_label_layer()
|
|
2092
2749
|
self._update_label_table()
|
|
2093
2750
|
self.preview_crop()
|
|
2094
|
-
|
|
2751
|
+
|
|
2752
|
+
self.viewer.status = (
|
|
2753
|
+
f"Cleared all {len(label_ids)} labels from segmentation"
|
|
2754
|
+
)
|
|
2755
|
+
|
|
2756
|
+
def clear_label_at_position(self, y, x):
|
|
2757
|
+
"""Clear a single label at the specified 2D position."""
|
|
2758
|
+
if self.segmentation_result is None:
|
|
2759
|
+
self.viewer.status = "No segmentation available"
|
|
2760
|
+
return
|
|
2761
|
+
|
|
2762
|
+
label_id = self.segmentation_result[y, x]
|
|
2763
|
+
if label_id > 0:
|
|
2764
|
+
# Remove all pixels with this label ID
|
|
2765
|
+
self.segmentation_result[self.segmentation_result == label_id] = 0
|
|
2766
|
+
|
|
2767
|
+
# Remove from selected labels if it was selected
|
|
2768
|
+
self.selected_labels.discard(label_id)
|
|
2769
|
+
|
|
2770
|
+
# Remove from label info
|
|
2771
|
+
if label_id in self.label_info:
|
|
2772
|
+
del self.label_info[label_id]
|
|
2773
|
+
|
|
2774
|
+
# Remove any object-specific point layers for this label
|
|
2775
|
+
for layer in list(self.viewer.layers):
|
|
2776
|
+
if f"Points for Object {label_id}" in layer.name:
|
|
2777
|
+
# Clean up callbacks before removing the layer to prevent cleanup issues
|
|
2778
|
+
if hasattr(layer, "mouse_drag_callbacks"):
|
|
2779
|
+
layer.mouse_drag_callbacks.clear()
|
|
2780
|
+
with contextlib.suppress(ValueError):
|
|
2781
|
+
self.viewer.layers.remove(layer)
|
|
2782
|
+
|
|
2783
|
+
# Clean up object tracking data
|
|
2784
|
+
if hasattr(self, "obj_points") and label_id in self.obj_points:
|
|
2785
|
+
del self.obj_points[label_id]
|
|
2786
|
+
if hasattr(self, "obj_labels") and label_id in self.obj_labels:
|
|
2787
|
+
del self.obj_labels[label_id]
|
|
2788
|
+
|
|
2789
|
+
# Update UI
|
|
2790
|
+
self._update_label_layer()
|
|
2791
|
+
self._update_label_table()
|
|
2792
|
+
self.preview_crop()
|
|
2793
|
+
|
|
2794
|
+
self.viewer.status = f"Deleted label ID: {label_id}"
|
|
2795
|
+
else:
|
|
2796
|
+
self.viewer.status = "No label to delete at this position"
|
|
2797
|
+
|
|
2798
|
+
def clear_label_at_position_3d(self, t, y, x):
|
|
2799
|
+
"""Clear a single label at the specified 3D position."""
|
|
2800
|
+
if self.segmentation_result is None:
|
|
2801
|
+
self.viewer.status = "No segmentation available"
|
|
2802
|
+
return
|
|
2803
|
+
|
|
2804
|
+
label_id = self.segmentation_result[t, y, x]
|
|
2805
|
+
if label_id > 0:
|
|
2806
|
+
# Remove all pixels with this label ID across all timeframes
|
|
2807
|
+
self.segmentation_result[self.segmentation_result == label_id] = 0
|
|
2808
|
+
|
|
2809
|
+
# Remove from selected labels if it was selected
|
|
2810
|
+
self.selected_labels.discard(label_id)
|
|
2811
|
+
|
|
2812
|
+
# Remove from label info
|
|
2813
|
+
if label_id in self.label_info:
|
|
2814
|
+
del self.label_info[label_id]
|
|
2815
|
+
|
|
2816
|
+
# Remove any object-specific point layers for this label
|
|
2817
|
+
for layer in list(self.viewer.layers):
|
|
2818
|
+
if f"Points for Object {label_id}" in layer.name:
|
|
2819
|
+
# Clean up callbacks before removing the layer to prevent cleanup issues
|
|
2820
|
+
if hasattr(layer, "mouse_drag_callbacks"):
|
|
2821
|
+
layer.mouse_drag_callbacks.clear()
|
|
2822
|
+
with contextlib.suppress(ValueError):
|
|
2823
|
+
self.viewer.layers.remove(layer)
|
|
2824
|
+
|
|
2825
|
+
# Clean up 3D object tracking data
|
|
2826
|
+
if (
|
|
2827
|
+
hasattr(self, "sam2_points_by_obj")
|
|
2828
|
+
and label_id in self.sam2_points_by_obj
|
|
2829
|
+
):
|
|
2830
|
+
del self.sam2_points_by_obj[label_id]
|
|
2831
|
+
if (
|
|
2832
|
+
hasattr(self, "sam2_labels_by_obj")
|
|
2833
|
+
and label_id in self.sam2_labels_by_obj
|
|
2834
|
+
):
|
|
2835
|
+
del self.sam2_labels_by_obj[label_id]
|
|
2836
|
+
if hasattr(self, "points_data") and label_id in self.points_data:
|
|
2837
|
+
del self.points_data[label_id]
|
|
2838
|
+
if (
|
|
2839
|
+
hasattr(self, "points_labels")
|
|
2840
|
+
and label_id in self.points_labels
|
|
2841
|
+
):
|
|
2842
|
+
del self.points_labels[label_id]
|
|
2843
|
+
|
|
2844
|
+
# Update UI
|
|
2845
|
+
self._update_label_layer()
|
|
2846
|
+
if (
|
|
2847
|
+
hasattr(self, "label_table_widget")
|
|
2848
|
+
and self.label_table_widget is not None
|
|
2849
|
+
):
|
|
2850
|
+
self._populate_label_table(self.label_table_widget)
|
|
2851
|
+
self.preview_crop()
|
|
2852
|
+
|
|
2853
|
+
self.viewer.status = (
|
|
2854
|
+
f"Deleted label ID: {label_id} from all timeframes"
|
|
2855
|
+
)
|
|
2856
|
+
else:
|
|
2857
|
+
self.viewer.status = "No label to delete at this position"
|
|
2095
2858
|
|
|
2096
2859
|
def preview_crop(self, label_ids=None):
|
|
2097
2860
|
"""Preview the crop result with the selected label IDs."""
|
|
@@ -2111,7 +2874,11 @@ class BatchCropAnything:
|
|
|
2111
2874
|
# Remove previous preview if exists
|
|
2112
2875
|
for layer in list(self.viewer.layers):
|
|
2113
2876
|
if "Preview" in layer.name:
|
|
2114
|
-
|
|
2877
|
+
# Clean up callbacks before removing the layer to prevent cleanup issues
|
|
2878
|
+
if hasattr(layer, "mouse_drag_callbacks"):
|
|
2879
|
+
layer.mouse_drag_callbacks.clear()
|
|
2880
|
+
with contextlib.suppress(ValueError):
|
|
2881
|
+
self.viewer.layers.remove(layer)
|
|
2115
2882
|
|
|
2116
2883
|
# Make sure the segmentation layer is active again
|
|
2117
2884
|
if self.label_layer is not None:
|
|
@@ -2149,7 +2916,11 @@ class BatchCropAnything:
|
|
|
2149
2916
|
# Remove previous preview if exists
|
|
2150
2917
|
for layer in list(self.viewer.layers):
|
|
2151
2918
|
if "Preview" in layer.name:
|
|
2152
|
-
|
|
2919
|
+
# Clean up callbacks before removing the layer to prevent cleanup issues
|
|
2920
|
+
if hasattr(layer, "mouse_drag_callbacks"):
|
|
2921
|
+
layer.mouse_drag_callbacks.clear()
|
|
2922
|
+
with contextlib.suppress(ValueError):
|
|
2923
|
+
self.viewer.layers.remove(layer)
|
|
2153
2924
|
|
|
2154
2925
|
# Add preview layer
|
|
2155
2926
|
if label_ids:
|
|
@@ -2240,17 +3011,14 @@ class BatchCropAnything:
|
|
|
2240
3011
|
# Save cropped image
|
|
2241
3012
|
image_path = self.images[self.current_index]
|
|
2242
3013
|
base_name, ext = os.path.splitext(image_path)
|
|
2243
|
-
|
|
2244
|
-
str(lid) for lid in sorted(self.selected_labels)
|
|
2245
|
-
)
|
|
2246
|
-
output_path = f"{base_name}_cropped_{label_str}.tif"
|
|
3014
|
+
output_path = f"{base_name}_sam2_cropped.tif"
|
|
2247
3015
|
|
|
2248
3016
|
# Save using tifffile with explicit parameters for best compatibility
|
|
2249
3017
|
imwrite(output_path, cropped_image, compression="zlib")
|
|
2250
3018
|
self.viewer.status = f"Saved cropped image to {output_path}"
|
|
2251
3019
|
|
|
2252
3020
|
# Save the label image with exact same dimensions as original
|
|
2253
|
-
label_output_path = f"{base_name}
|
|
3021
|
+
label_output_path = f"{base_name}_sam2_labels.tif"
|
|
2254
3022
|
imwrite(label_output_path, label_image, compression="zlib")
|
|
2255
3023
|
self.viewer.status += f"\nSaved label mask to {label_output_path}"
|
|
2256
3024
|
|
|
@@ -2264,6 +3032,27 @@ class BatchCropAnything:
|
|
|
2264
3032
|
self.viewer.status = f"Error cropping image: {str(e)}"
|
|
2265
3033
|
return False
|
|
2266
3034
|
|
|
3035
|
+
def reset_sam2_state(self):
|
|
3036
|
+
"""Reset SAM2 predictor state for 2D segmentation."""
|
|
3037
|
+
if not self.use_3d and hasattr(self, "prepared_sam2_image"):
|
|
3038
|
+
# Re-set the image in the predictor (only for ImagePredictor)
|
|
3039
|
+
device_type = "cuda" if self.device.type == "cuda" else "cpu"
|
|
3040
|
+
try:
|
|
3041
|
+
if hasattr(self.predictor, "set_image"):
|
|
3042
|
+
with (
|
|
3043
|
+
torch.inference_mode(),
|
|
3044
|
+
torch.autocast(device_type, dtype=torch.float32),
|
|
3045
|
+
):
|
|
3046
|
+
self.predictor.set_image(self.prepared_sam2_image)
|
|
3047
|
+
else:
|
|
3048
|
+
print(
|
|
3049
|
+
"DEBUG: reset_sam2_state - predictor doesn't have set_image method"
|
|
3050
|
+
)
|
|
3051
|
+
except (RuntimeError, AssertionError, TypeError, ValueError) as e:
|
|
3052
|
+
print(f"Error resetting SAM2 state: {e}")
|
|
3053
|
+
# If there's an error, try to reinitialize
|
|
3054
|
+
self._initialize_sam2()
|
|
3055
|
+
|
|
2267
3056
|
|
|
2268
3057
|
def create_crop_widget(processor):
|
|
2269
3058
|
"""Create the crop control widget."""
|
|
@@ -2274,27 +3063,70 @@ def create_crop_widget(processor):
|
|
|
2274
3063
|
|
|
2275
3064
|
# Instructions
|
|
2276
3065
|
dimension_type = "3D (TYX/ZYX)" if processor.use_3d else "2D (YX)"
|
|
2277
|
-
|
|
2278
|
-
|
|
2279
|
-
|
|
2280
|
-
|
|
2281
|
-
|
|
2282
|
-
|
|
2283
|
-
|
|
2284
|
-
|
|
3066
|
+
|
|
3067
|
+
if processor.use_3d:
|
|
3068
|
+
instructions_text = (
|
|
3069
|
+
f"<b>Processing {dimension_type} data</b><br><br>"
|
|
3070
|
+
"<b>⚠️ IMPORTANT for 3D stacks:</b><br>"
|
|
3071
|
+
"<ul>"
|
|
3072
|
+
"<li><b>Navigate to the FIRST SLICE</b> where your object appears (use the time/Z slider)</li>"
|
|
3073
|
+
"<li><b>Switch to 2D view</b> (click 2D icon in napari, NOT 3D view)</li>"
|
|
3074
|
+
"<li><b>Point Mode:</b> Select Points layer and click on objects to segment them</li>"
|
|
3075
|
+
"<li><b>Rectangle Mode:</b> Draw rectangles around objects to segment them</li>"
|
|
3076
|
+
"<li>Segmentation will automatically propagate to all slices</li>"
|
|
3077
|
+
"</ul><br>"
|
|
3078
|
+
"<b>General Controls:</b><br>"
|
|
3079
|
+
"<ul>"
|
|
3080
|
+
"<li>Use <b>Shift+click</b> for negative points (remove areas from segmentation)</li>"
|
|
3081
|
+
"<li>Click on existing objects in <b>Segmentation layer</b> to select for cropping</li>"
|
|
3082
|
+
"<li>Press <b>CTRL+click</b> on labels in <b>Segmentation layer</b> to delete them</li>"
|
|
3083
|
+
"<li>Press <b>'Crop'</b> to save selected objects to disk</li>"
|
|
3084
|
+
"</ul>"
|
|
3085
|
+
)
|
|
3086
|
+
else:
|
|
3087
|
+
instructions_text = (
|
|
3088
|
+
f"<b>Processing {dimension_type} data</b><br><br>"
|
|
3089
|
+
"<b>Point Mode:</b> Click on objects to segment them. Use Shift+click for negative points.<br>"
|
|
3090
|
+
"<b>Rectangle Mode:</b> Draw rectangles around objects to segment them.<br><br>"
|
|
3091
|
+
"<ul>"
|
|
3092
|
+
"<li>Click on existing objects in <b>Segmentation layer</b> to select them for cropping</li>"
|
|
3093
|
+
"<li>Press <b>CTRL+click</b> on labels in <b>Segmentation layer</b> to delete them</li>"
|
|
3094
|
+
"<li>Press <b>'Crop'</b> to save selected objects to disk</li>"
|
|
3095
|
+
"</ul>"
|
|
3096
|
+
)
|
|
3097
|
+
|
|
3098
|
+
instructions_label = QLabel(instructions_text)
|
|
2285
3099
|
instructions_label.setWordWrap(True)
|
|
2286
3100
|
layout.addWidget(instructions_label)
|
|
2287
3101
|
|
|
2288
|
-
# Add
|
|
2289
|
-
|
|
3102
|
+
# Add mode selector
|
|
3103
|
+
mode_layout = QHBoxLayout()
|
|
3104
|
+
mode_label = QLabel("<b>Prompt Mode:</b>")
|
|
3105
|
+
mode_layout.addWidget(mode_label)
|
|
3106
|
+
|
|
3107
|
+
point_mode_button = QPushButton("Points")
|
|
3108
|
+
point_mode_button.setCheckable(True)
|
|
3109
|
+
point_mode_button.setChecked(True)
|
|
3110
|
+
mode_layout.addWidget(point_mode_button)
|
|
3111
|
+
|
|
3112
|
+
box_mode_button = QPushButton("Rectangle")
|
|
3113
|
+
box_mode_button.setCheckable(True)
|
|
3114
|
+
box_mode_button.setChecked(False)
|
|
3115
|
+
mode_layout.addWidget(box_mode_button)
|
|
3116
|
+
|
|
3117
|
+
mode_layout.addStretch()
|
|
3118
|
+
layout.addLayout(mode_layout)
|
|
3119
|
+
|
|
3120
|
+
# Add a button to ensure active layer is correct
|
|
3121
|
+
activate_button = QPushButton("Make Prompt Layer Active")
|
|
2290
3122
|
activate_button.clicked.connect(
|
|
2291
|
-
lambda: processor.
|
|
3123
|
+
lambda: processor._ensure_active_prompt_layer()
|
|
2292
3124
|
)
|
|
2293
3125
|
layout.addWidget(activate_button)
|
|
2294
3126
|
|
|
2295
|
-
# Add a "Clear
|
|
2296
|
-
|
|
2297
|
-
layout.addWidget(
|
|
3127
|
+
# Add a "Clear Prompts" button to reset prompts
|
|
3128
|
+
clear_prompts_button = QPushButton("Clear Prompts")
|
|
3129
|
+
layout.addWidget(clear_prompts_button)
|
|
2298
3130
|
|
|
2299
3131
|
# Create label table
|
|
2300
3132
|
label_table = processor.create_label_table(crop_widget)
|
|
@@ -2305,7 +3137,7 @@ def create_crop_widget(processor):
|
|
|
2305
3137
|
# Selection buttons
|
|
2306
3138
|
selection_layout = QHBoxLayout()
|
|
2307
3139
|
select_all_button = QPushButton("Select All")
|
|
2308
|
-
clear_selection_button = QPushButton("Clear
|
|
3140
|
+
clear_selection_button = QPushButton("Clear All Labels")
|
|
2309
3141
|
selection_layout.addWidget(select_all_button)
|
|
2310
3142
|
selection_layout.addWidget(clear_selection_button)
|
|
2311
3143
|
layout.addLayout(selection_layout)
|
|
@@ -2343,51 +3175,152 @@ def create_crop_widget(processor):
|
|
|
2343
3175
|
# Create new table
|
|
2344
3176
|
label_table = processor.create_label_table(crop_widget)
|
|
2345
3177
|
label_table.setMinimumHeight(200)
|
|
2346
|
-
layout.insertWidget(
|
|
3178
|
+
layout.insertWidget(
|
|
3179
|
+
3, label_table
|
|
3180
|
+
) # Insert after clear prompts button
|
|
2347
3181
|
return label_table
|
|
2348
3182
|
|
|
2349
|
-
# Add helper method to ensure
|
|
2350
|
-
def
|
|
2351
|
-
|
|
2352
|
-
|
|
2353
|
-
|
|
2354
|
-
|
|
2355
|
-
|
|
3183
|
+
# Add helper method to ensure active prompt layer is selected based on mode
|
|
3184
|
+
def _ensure_active_prompt_layer():
|
|
3185
|
+
if processor.prompt_mode == "point":
|
|
3186
|
+
points_layer = None
|
|
3187
|
+
for layer in list(processor.viewer.layers):
|
|
3188
|
+
if "Points" in layer.name and "Object" not in layer.name:
|
|
3189
|
+
points_layer = layer
|
|
3190
|
+
break
|
|
2356
3191
|
|
|
2357
|
-
|
|
2358
|
-
|
|
2359
|
-
|
|
2360
|
-
|
|
2361
|
-
|
|
2362
|
-
|
|
2363
|
-
|
|
2364
|
-
|
|
2365
|
-
|
|
3192
|
+
if points_layer is not None:
|
|
3193
|
+
processor.viewer.layers.selection.active = points_layer
|
|
3194
|
+
if processor.use_3d:
|
|
3195
|
+
status_label.setText(
|
|
3196
|
+
"Points layer active - Navigate to FIRST SLICE of object, ensure 2D view, then click"
|
|
3197
|
+
)
|
|
3198
|
+
else:
|
|
3199
|
+
status_label.setText(
|
|
3200
|
+
"Points layer is now active - click to add points"
|
|
3201
|
+
)
|
|
3202
|
+
else:
|
|
3203
|
+
status_label.setText(
|
|
3204
|
+
"No points layer found. Please load an image first."
|
|
3205
|
+
)
|
|
3206
|
+
else: # box mode
|
|
3207
|
+
shapes_layer = None
|
|
3208
|
+
for layer in list(processor.viewer.layers):
|
|
3209
|
+
if "Rectangles" in layer.name:
|
|
3210
|
+
shapes_layer = layer
|
|
3211
|
+
break
|
|
3212
|
+
|
|
3213
|
+
if shapes_layer is not None:
|
|
3214
|
+
processor.viewer.layers.selection.active = shapes_layer
|
|
3215
|
+
status_label.setText(
|
|
3216
|
+
"Rectangles layer is now active - draw rectangles"
|
|
3217
|
+
)
|
|
3218
|
+
else:
|
|
3219
|
+
status_label.setText(
|
|
3220
|
+
"No rectangles layer found. Please load an image first."
|
|
3221
|
+
)
|
|
2366
3222
|
|
|
2367
|
-
processor.
|
|
3223
|
+
processor._ensure_active_prompt_layer = _ensure_active_prompt_layer
|
|
3224
|
+
|
|
3225
|
+
# Keep the old method for backward compatibility
|
|
3226
|
+
processor._ensure_points_layer_active = _ensure_active_prompt_layer
|
|
3227
|
+
|
|
3228
|
+
def on_clear_prompts_clicked():
|
|
3229
|
+
# Find and clear/remove prompt layers based on mode
|
|
3230
|
+
main_points_layer = None
|
|
3231
|
+
object_points_layers = []
|
|
3232
|
+
shapes_layer = None
|
|
2368
3233
|
|
|
2369
|
-
# Connect button signals
|
|
2370
|
-
def on_clear_points_clicked():
|
|
2371
|
-
# Remove all point layers
|
|
2372
3234
|
for layer in list(processor.viewer.layers):
|
|
2373
3235
|
if "Points" in layer.name:
|
|
3236
|
+
if "Object" in layer.name:
|
|
3237
|
+
object_points_layers.append(layer)
|
|
3238
|
+
else:
|
|
3239
|
+
main_points_layer = layer
|
|
3240
|
+
elif "Rectangles" in layer.name:
|
|
3241
|
+
shapes_layer = layer
|
|
3242
|
+
|
|
3243
|
+
# Remove object-specific point layers (these are created dynamically)
|
|
3244
|
+
for layer in object_points_layers:
|
|
3245
|
+
# Clean up callbacks before removing the layer to prevent cleanup issues
|
|
3246
|
+
if hasattr(layer, "mouse_drag_callbacks"):
|
|
3247
|
+
layer.mouse_drag_callbacks.clear()
|
|
3248
|
+
with contextlib.suppress(ValueError):
|
|
2374
3249
|
processor.viewer.layers.remove(layer)
|
|
2375
3250
|
|
|
2376
|
-
#
|
|
2377
|
-
if
|
|
2378
|
-
|
|
2379
|
-
processor.points_labels = {}
|
|
3251
|
+
# Clear shapes layer
|
|
3252
|
+
if shapes_layer is not None:
|
|
3253
|
+
shapes_layer.data = []
|
|
2380
3254
|
|
|
2381
|
-
|
|
2382
|
-
|
|
2383
|
-
|
|
3255
|
+
# Clear data from main points layer instead of removing it
|
|
3256
|
+
if main_points_layer is not None:
|
|
3257
|
+
# Clear the points data
|
|
3258
|
+
main_points_layer.data = np.zeros(
|
|
3259
|
+
(0, 2 if not processor.use_3d else 3)
|
|
3260
|
+
)
|
|
3261
|
+
main_points_layer.face_color = "green"
|
|
2384
3262
|
|
|
2385
|
-
|
|
2386
|
-
|
|
2387
|
-
|
|
3263
|
+
# Ensure the click callback is still connected
|
|
3264
|
+
if (
|
|
3265
|
+
hasattr(main_points_layer, "mouse_drag_callbacks")
|
|
3266
|
+
and processor._on_points_clicked
|
|
3267
|
+
not in main_points_layer.mouse_drag_callbacks
|
|
3268
|
+
):
|
|
3269
|
+
main_points_layer.mouse_drag_callbacks.append(
|
|
3270
|
+
processor._on_points_clicked
|
|
3271
|
+
)
|
|
3272
|
+
|
|
3273
|
+
# Reset all tracking attributes for 2D
|
|
3274
|
+
if not processor.use_3d:
|
|
3275
|
+
# Reset current segmentation tracking
|
|
3276
|
+
if hasattr(processor, "current_points"):
|
|
3277
|
+
processor.current_points = []
|
|
3278
|
+
processor.current_labels = []
|
|
3279
|
+
|
|
3280
|
+
# Reset object tracking
|
|
3281
|
+
if hasattr(processor, "obj_points"):
|
|
3282
|
+
processor.obj_points = {}
|
|
3283
|
+
processor.obj_labels = {}
|
|
3284
|
+
|
|
3285
|
+
# Reset box tracking
|
|
3286
|
+
if hasattr(processor, "obj_boxes"):
|
|
3287
|
+
processor.obj_boxes = {}
|
|
3288
|
+
|
|
3289
|
+
# Reset object ID counters
|
|
3290
|
+
if hasattr(processor, "current_obj_id"):
|
|
3291
|
+
# Find the highest existing label ID
|
|
3292
|
+
if processor.segmentation_result is not None:
|
|
3293
|
+
max_label = processor.segmentation_result.max()
|
|
3294
|
+
processor.current_obj_id = max(int(max_label) + 1, 1)
|
|
3295
|
+
processor.next_obj_id = processor.current_obj_id
|
|
3296
|
+
else:
|
|
3297
|
+
processor.current_obj_id = 1
|
|
3298
|
+
processor.next_obj_id = 1
|
|
3299
|
+
|
|
3300
|
+
# Reset SAM2 predictor state
|
|
3301
|
+
processor.reset_sam2_state()
|
|
3302
|
+
|
|
3303
|
+
# For 3D, reset video-specific tracking
|
|
3304
|
+
else:
|
|
3305
|
+
if hasattr(processor, "sam2_points_by_obj"):
|
|
3306
|
+
processor.sam2_points_by_obj = {}
|
|
3307
|
+
processor.sam2_labels_by_obj = {}
|
|
3308
|
+
|
|
3309
|
+
# Reset box tracking
|
|
3310
|
+
if hasattr(processor, "obj_boxes"):
|
|
3311
|
+
processor.obj_boxes = {}
|
|
3312
|
+
|
|
3313
|
+
if hasattr(processor, "points_data"):
|
|
3314
|
+
processor.points_data = {}
|
|
3315
|
+
processor.points_labels = {}
|
|
3316
|
+
|
|
3317
|
+
# Note: We don't reset _sam2_state for 3D as it needs to maintain video state
|
|
3318
|
+
|
|
3319
|
+
# Make the appropriate prompt layer active based on mode
|
|
3320
|
+
_ensure_active_prompt_layer()
|
|
2388
3321
|
|
|
2389
3322
|
status_label.setText(
|
|
2390
|
-
"Cleared all
|
|
3323
|
+
"Cleared all prompts. Ready to add new segmentation prompts."
|
|
2391
3324
|
)
|
|
2392
3325
|
|
|
2393
3326
|
def on_select_all_clicked():
|
|
@@ -2411,8 +3344,14 @@ def create_crop_widget(processor):
|
|
|
2411
3344
|
)
|
|
2412
3345
|
|
|
2413
3346
|
def on_next_clicked():
|
|
2414
|
-
#
|
|
2415
|
-
|
|
3347
|
+
# Check if we can move to the next image before clearing prompts
|
|
3348
|
+
if processor.current_index >= len(processor.images) - 1:
|
|
3349
|
+
next_button.setEnabled(False)
|
|
3350
|
+
status_label.setText("No more images. Processing complete.")
|
|
3351
|
+
return
|
|
3352
|
+
|
|
3353
|
+
# Clear prompts before moving to next image
|
|
3354
|
+
on_clear_prompts_clicked()
|
|
2416
3355
|
|
|
2417
3356
|
if not processor.next_image():
|
|
2418
3357
|
next_button.setEnabled(False)
|
|
@@ -2422,11 +3361,17 @@ def create_crop_widget(processor):
|
|
|
2422
3361
|
status_label.setText(
|
|
2423
3362
|
f"Showing image {processor.current_index + 1}/{len(processor.images)}"
|
|
2424
3363
|
)
|
|
2425
|
-
processor.
|
|
3364
|
+
processor._ensure_active_prompt_layer()
|
|
2426
3365
|
|
|
2427
3366
|
def on_prev_clicked():
|
|
2428
|
-
#
|
|
2429
|
-
|
|
3367
|
+
# Check if we can move to the previous image before clearing prompts
|
|
3368
|
+
if processor.current_index <= 0:
|
|
3369
|
+
prev_button.setEnabled(False)
|
|
3370
|
+
status_label.setText("Already at the first image.")
|
|
3371
|
+
return
|
|
3372
|
+
|
|
3373
|
+
# Clear prompts before moving to previous image
|
|
3374
|
+
on_clear_prompts_clicked()
|
|
2430
3375
|
|
|
2431
3376
|
if not processor.previous_image():
|
|
2432
3377
|
prev_button.setEnabled(False)
|
|
@@ -2436,15 +3381,33 @@ def create_crop_widget(processor):
|
|
|
2436
3381
|
status_label.setText(
|
|
2437
3382
|
f"Showing image {processor.current_index + 1}/{len(processor.images)}"
|
|
2438
3383
|
)
|
|
2439
|
-
processor.
|
|
3384
|
+
processor._ensure_active_prompt_layer()
|
|
3385
|
+
|
|
3386
|
+
def on_point_mode_clicked():
|
|
3387
|
+
processor.prompt_mode = "point"
|
|
3388
|
+
point_mode_button.setChecked(True)
|
|
3389
|
+
box_mode_button.setChecked(False)
|
|
3390
|
+
processor._update_label_layer()
|
|
3391
|
+
status_label.setText("Point mode active - click on objects to segment")
|
|
3392
|
+
|
|
3393
|
+
def on_box_mode_clicked():
|
|
3394
|
+
processor.prompt_mode = "box"
|
|
3395
|
+
point_mode_button.setChecked(False)
|
|
3396
|
+
box_mode_button.setChecked(True)
|
|
3397
|
+
processor._update_label_layer()
|
|
3398
|
+
status_label.setText(
|
|
3399
|
+
"Rectangle mode active - draw rectangles around objects"
|
|
3400
|
+
)
|
|
2440
3401
|
|
|
2441
|
-
|
|
3402
|
+
clear_prompts_button.clicked.connect(on_clear_prompts_clicked)
|
|
2442
3403
|
select_all_button.clicked.connect(on_select_all_clicked)
|
|
2443
3404
|
clear_selection_button.clicked.connect(on_clear_selection_clicked)
|
|
2444
3405
|
crop_button.clicked.connect(on_crop_clicked)
|
|
2445
3406
|
next_button.clicked.connect(on_next_clicked)
|
|
2446
3407
|
prev_button.clicked.connect(on_prev_clicked)
|
|
2447
|
-
activate_button.clicked.connect(
|
|
3408
|
+
activate_button.clicked.connect(_ensure_active_prompt_layer)
|
|
3409
|
+
point_mode_button.clicked.connect(on_point_mode_clicked)
|
|
3410
|
+
box_mode_button.clicked.connect(on_box_mode_clicked)
|
|
2448
3411
|
|
|
2449
3412
|
return crop_widget
|
|
2450
3413
|
|
|
@@ -2463,6 +3426,19 @@ def batch_crop_anything(
|
|
|
2463
3426
|
viewer: Viewer = None,
|
|
2464
3427
|
):
|
|
2465
3428
|
"""MagicGUI widget for starting Batch Crop Anything using SAM2."""
|
|
3429
|
+
# Check if torch is available
|
|
3430
|
+
if not _HAS_TORCH:
|
|
3431
|
+
QMessageBox.critical(
|
|
3432
|
+
None,
|
|
3433
|
+
"Missing Dependency",
|
|
3434
|
+
"PyTorch not found. Batch Crop Anything requires PyTorch and SAM2.\n\n"
|
|
3435
|
+
"To install the required dependencies, run:\n"
|
|
3436
|
+
"pip install 'napari-tmidas[deep-learning]'\n\n"
|
|
3437
|
+
"Then follow SAM2 installation instructions at:\n"
|
|
3438
|
+
"https://github.com/MercaderLabAnatomy/napari-tmidas#installation",
|
|
3439
|
+
)
|
|
3440
|
+
return
|
|
3441
|
+
|
|
2466
3442
|
# Check if SAM2 is available
|
|
2467
3443
|
try:
|
|
2468
3444
|
import importlib.util
|
|
@@ -2473,15 +3449,15 @@ def batch_crop_anything(
|
|
|
2473
3449
|
None,
|
|
2474
3450
|
"Missing Dependency",
|
|
2475
3451
|
"SAM2 not found. Please follow installation instructions at:\n"
|
|
2476
|
-
"https://github.com/MercaderLabAnatomy/napari-tmidas
|
|
3452
|
+
"https://github.com/MercaderLabAnatomy/napari-tmidas#installation\n",
|
|
2477
3453
|
)
|
|
2478
3454
|
return
|
|
2479
3455
|
except ImportError:
|
|
2480
3456
|
QMessageBox.critical(
|
|
2481
3457
|
None,
|
|
2482
3458
|
"Missing Dependency",
|
|
2483
|
-
"SAM2 package cannot be imported. Please follow installation instructions at
|
|
2484
|
-
"https://github.com/MercaderLabAnatomy/napari-tmidas
|
|
3459
|
+
"SAM2 package cannot be imported. Please follow installation instructions at:\n"
|
|
3460
|
+
"https://github.com/MercaderLabAnatomy/napari-tmidas#installation",
|
|
2485
3461
|
)
|
|
2486
3462
|
return
|
|
2487
3463
|
|
|
@@ -2509,24 +3485,7 @@ def batch_crop_anything_widget():
|
|
|
2509
3485
|
# Create the magicgui widget
|
|
2510
3486
|
widget = batch_crop_anything
|
|
2511
3487
|
|
|
2512
|
-
#
|
|
2513
|
-
|
|
2514
|
-
|
|
2515
|
-
def on_folder_browse_clicked():
|
|
2516
|
-
folder = QFileDialog.getExistingDirectory(
|
|
2517
|
-
None,
|
|
2518
|
-
"Select Folder",
|
|
2519
|
-
os.path.expanduser("~"),
|
|
2520
|
-
QFileDialog.ShowDirsOnly | QFileDialog.DontResolveSymlinks,
|
|
2521
|
-
)
|
|
2522
|
-
if folder:
|
|
2523
|
-
# Update the folder_path field
|
|
2524
|
-
widget.folder_path.value = folder
|
|
2525
|
-
|
|
2526
|
-
folder_browse_button.clicked.connect(on_folder_browse_clicked)
|
|
2527
|
-
|
|
2528
|
-
# Insert the browse button next to the folder_path field
|
|
2529
|
-
folder_layout = widget.folder_path.native.parent().layout()
|
|
2530
|
-
folder_layout.addWidget(folder_browse_button)
|
|
3488
|
+
# Add browse button using common utility
|
|
3489
|
+
add_browse_button_to_folder_field(widget, "folder_path")
|
|
2531
3490
|
|
|
2532
3491
|
return widget
|