napari-tmidas 0.2.0__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- napari_tmidas/_crop_anything.py +1942 -607
- napari_tmidas/_file_selector.py +99 -16
- napari_tmidas/_registry.py +15 -14
- napari_tmidas/_tests/test_file_selector.py +90 -0
- napari_tmidas/_tests/test_registry.py +67 -0
- napari_tmidas/_version.py +2 -2
- napari_tmidas/processing_functions/basic.py +494 -23
- napari_tmidas/processing_functions/careamics_denoising.py +324 -0
- napari_tmidas/processing_functions/careamics_env_manager.py +339 -0
- napari_tmidas/processing_functions/cellpose_env_manager.py +55 -20
- napari_tmidas/processing_functions/cellpose_segmentation.py +105 -218
- napari_tmidas/processing_functions/sam2_mp4.py +283 -0
- napari_tmidas/processing_functions/skimage_filters.py +31 -1
- napari_tmidas/processing_functions/timepoint_merger.py +490 -0
- napari_tmidas/processing_functions/trackastra_tracking.py +322 -0
- {napari_tmidas-0.2.0.dist-info → napari_tmidas-0.2.2.dist-info}/METADATA +37 -17
- {napari_tmidas-0.2.0.dist-info → napari_tmidas-0.2.2.dist-info}/RECORD +21 -14
- {napari_tmidas-0.2.0.dist-info → napari_tmidas-0.2.2.dist-info}/WHEEL +1 -1
- {napari_tmidas-0.2.0.dist-info → napari_tmidas-0.2.2.dist-info}/entry_points.txt +0 -0
- {napari_tmidas-0.2.0.dist-info → napari_tmidas-0.2.2.dist-info}/licenses/LICENSE +0 -0
- {napari_tmidas-0.2.0.dist-info → napari_tmidas-0.2.2.dist-info}/top_level.txt +0 -0
napari_tmidas/_crop_anything.py
CHANGED
|
@@ -1,13 +1,17 @@
|
|
|
1
1
|
"""
|
|
2
2
|
Batch Crop Anything - A Napari plugin for interactive image cropping
|
|
3
3
|
|
|
4
|
-
This plugin combines
|
|
4
|
+
This plugin combines SAM2 for automatic object detection with
|
|
5
5
|
an interactive interface for selecting and cropping objects from images.
|
|
6
|
+
The plugin supports both 2D (YX) and 3D (TYX/ZYX) data.
|
|
6
7
|
"""
|
|
7
8
|
|
|
9
|
+
import contextlib
|
|
8
10
|
import os
|
|
11
|
+
import sys
|
|
9
12
|
|
|
10
13
|
import numpy as np
|
|
14
|
+
import requests
|
|
11
15
|
import torch
|
|
12
16
|
from magicgui import magicgui
|
|
13
17
|
from napari.layers import Labels
|
|
@@ -22,34 +26,73 @@ from qtpy.QtWidgets import (
|
|
|
22
26
|
QMessageBox,
|
|
23
27
|
QPushButton,
|
|
24
28
|
QScrollArea,
|
|
25
|
-
QSlider,
|
|
26
29
|
QTableWidget,
|
|
27
30
|
QTableWidgetItem,
|
|
28
31
|
QVBoxLayout,
|
|
29
32
|
QWidget,
|
|
30
33
|
)
|
|
31
34
|
from skimage.io import imread
|
|
32
|
-
from skimage.transform import resize
|
|
35
|
+
from skimage.transform import resize
|
|
33
36
|
from tifffile import imwrite
|
|
34
37
|
|
|
38
|
+
from napari_tmidas.processing_functions.sam2_mp4 import tif_to_mp4
|
|
39
|
+
|
|
40
|
+
sam2_paths = [
|
|
41
|
+
os.environ.get("SAM2_PATH"),
|
|
42
|
+
"/opt/sam2",
|
|
43
|
+
os.path.expanduser("~/sam2"),
|
|
44
|
+
"./sam2",
|
|
45
|
+
]
|
|
46
|
+
|
|
47
|
+
for path in sam2_paths:
|
|
48
|
+
if path and os.path.exists(path):
|
|
49
|
+
sys.path.append(path)
|
|
50
|
+
break
|
|
51
|
+
else:
|
|
52
|
+
print(
|
|
53
|
+
"Warning: SAM2 not found in common locations. Please set SAM2_PATH environment variable."
|
|
54
|
+
)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def get_device():
|
|
58
|
+
if sys.platform == "darwin":
|
|
59
|
+
# MacOS: Only check for MPS
|
|
60
|
+
if (
|
|
61
|
+
hasattr(torch.backends, "mps")
|
|
62
|
+
and torch.backends.mps.is_available()
|
|
63
|
+
):
|
|
64
|
+
device = torch.device("mps")
|
|
65
|
+
print("Using Apple Silicon GPU (MPS)")
|
|
66
|
+
else:
|
|
67
|
+
device = torch.device("cpu")
|
|
68
|
+
print("Using CPU")
|
|
69
|
+
else:
|
|
70
|
+
# Other platforms: check for CUDA, then CPU
|
|
71
|
+
if hasattr(torch, "cuda") and torch.cuda.is_available():
|
|
72
|
+
device = torch.device("cuda")
|
|
73
|
+
print(f"Using CUDA GPU: {torch.cuda.get_device_name()}")
|
|
74
|
+
else:
|
|
75
|
+
device = torch.device("cpu")
|
|
76
|
+
print("Using CPU")
|
|
77
|
+
return device
|
|
78
|
+
|
|
35
79
|
|
|
36
80
|
class BatchCropAnything:
|
|
37
|
-
"""
|
|
38
|
-
Class for processing images with Segment Anything and cropping selected objects.
|
|
39
|
-
"""
|
|
81
|
+
"""Class for processing images with SAM2 and cropping selected objects."""
|
|
40
82
|
|
|
41
|
-
def __init__(self, viewer: Viewer):
|
|
83
|
+
def __init__(self, viewer: Viewer, use_3d=False):
|
|
42
84
|
"""Initialize the BatchCropAnything processor."""
|
|
43
85
|
# Core components
|
|
44
86
|
self.viewer = viewer
|
|
45
87
|
self.images = []
|
|
46
88
|
self.current_index = 0
|
|
89
|
+
self.use_3d = use_3d
|
|
47
90
|
|
|
48
91
|
# Image and segmentation data
|
|
49
92
|
self.original_image = None
|
|
50
93
|
self.segmentation_result = None
|
|
51
94
|
self.current_image_for_segmentation = None
|
|
52
|
-
self.current_scale_factor = 1.0
|
|
95
|
+
self.current_scale_factor = 1.0
|
|
53
96
|
|
|
54
97
|
# UI references
|
|
55
98
|
self.image_layer = None
|
|
@@ -63,101 +106,73 @@ class BatchCropAnything:
|
|
|
63
106
|
# Segmentation parameters
|
|
64
107
|
self.sensitivity = 50 # Default sensitivity (0-100 scale)
|
|
65
108
|
|
|
66
|
-
# Initialize the
|
|
67
|
-
self.
|
|
68
|
-
|
|
69
|
-
# --------------------------------------------------
|
|
70
|
-
# Model Initialization
|
|
71
|
-
# --------------------------------------------------
|
|
72
|
-
|
|
73
|
-
def _initialize_sam(self):
|
|
74
|
-
"""Initialize the Segment Anything Model."""
|
|
75
|
-
try:
|
|
76
|
-
# Import required modules
|
|
77
|
-
from mobile_sam import (
|
|
78
|
-
SamAutomaticMaskGenerator,
|
|
79
|
-
sam_model_registry,
|
|
80
|
-
)
|
|
109
|
+
# Initialize the SAM2 model
|
|
110
|
+
self._initialize_sam2()
|
|
81
111
|
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
model_type = "vit_t"
|
|
112
|
+
def _initialize_sam2(self):
|
|
113
|
+
"""Initialize the SAM2 model based on dimension mode."""
|
|
85
114
|
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
if checkpoint_path is None:
|
|
89
|
-
self.mobile_sam = None
|
|
90
|
-
self.mask_generator = None
|
|
91
|
-
return
|
|
115
|
+
def download_checkpoint(url, dest_folder):
|
|
116
|
+
import os
|
|
92
117
|
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
118
|
+
os.makedirs(dest_folder, exist_ok=True)
|
|
119
|
+
filename = os.path.join(dest_folder, url.split("/")[-1])
|
|
120
|
+
if not os.path.exists(filename):
|
|
121
|
+
print(f"Downloading checkpoint to {filename}...")
|
|
122
|
+
response = requests.get(url, stream=True, timeout=30)
|
|
123
|
+
response.raise_for_status()
|
|
124
|
+
with open(filename, "wb") as f:
|
|
125
|
+
for chunk in response.iter_content(chunk_size=8192):
|
|
126
|
+
f.write(chunk)
|
|
127
|
+
print("Download complete.")
|
|
128
|
+
else:
|
|
129
|
+
print(f"Checkpoint already exists at {filename}.")
|
|
130
|
+
return filename
|
|
99
131
|
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
self.viewer.status = f"Initialized SAM model from {checkpoint_path} on {self.device}"
|
|
132
|
+
try:
|
|
133
|
+
# import torch
|
|
103
134
|
|
|
104
|
-
|
|
105
|
-
self.viewer.status = f"Error initializing SAM: {str(e)}"
|
|
106
|
-
self.mobile_sam = None
|
|
107
|
-
self.mask_generator = None
|
|
135
|
+
self.device = get_device()
|
|
108
136
|
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
# Find the mobile_sam package location
|
|
115
|
-
mobile_sam_spec = importlib.util.find_spec("mobile_sam")
|
|
116
|
-
if mobile_sam_spec is None:
|
|
117
|
-
raise ImportError("mobile_sam package not found")
|
|
118
|
-
|
|
119
|
-
mobile_sam_path = os.path.dirname(mobile_sam_spec.origin)
|
|
120
|
-
|
|
121
|
-
# Check common locations for the model file
|
|
122
|
-
checkpoint_paths = [
|
|
123
|
-
os.path.join(mobile_sam_path, "weights", "mobile_sam.pt"),
|
|
124
|
-
os.path.join(mobile_sam_path, "mobile_sam.pt"),
|
|
125
|
-
os.path.join(
|
|
126
|
-
os.path.dirname(mobile_sam_path),
|
|
127
|
-
"weights",
|
|
128
|
-
"mobile_sam.pt",
|
|
129
|
-
),
|
|
130
|
-
os.path.join(
|
|
131
|
-
os.path.expanduser("~"), "models", "mobile_sam.pt"
|
|
132
|
-
),
|
|
133
|
-
"/opt/T-MIDAS/models/mobile_sam.pt",
|
|
134
|
-
os.path.join(os.getcwd(), "mobile_sam.pt"),
|
|
135
|
-
]
|
|
136
|
-
|
|
137
|
-
for path in checkpoint_paths:
|
|
138
|
-
if os.path.exists(path):
|
|
139
|
-
return path
|
|
140
|
-
|
|
141
|
-
# If model not found, ask user
|
|
142
|
-
QMessageBox.information(
|
|
143
|
-
None,
|
|
144
|
-
"Model Not Found",
|
|
145
|
-
"Mobile-SAM model weights not found. Please select the mobile_sam.pt file.",
|
|
137
|
+
# Download checkpoint if needed
|
|
138
|
+
checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt"
|
|
139
|
+
checkpoint_path = download_checkpoint(
|
|
140
|
+
checkpoint_url, "/opt/sam2/checkpoints/"
|
|
146
141
|
)
|
|
142
|
+
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
|
|
147
143
|
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
)
|
|
144
|
+
if self.use_3d:
|
|
145
|
+
from sam2.build_sam import build_sam2_video_predictor
|
|
151
146
|
|
|
152
|
-
|
|
147
|
+
self.predictor = build_sam2_video_predictor(
|
|
148
|
+
model_cfg, checkpoint_path, device=self.device
|
|
149
|
+
)
|
|
150
|
+
self.viewer.status = (
|
|
151
|
+
f"Initialized SAM2 Video Predictor on {self.device}"
|
|
152
|
+
)
|
|
153
|
+
else:
|
|
154
|
+
from sam2.build_sam import build_sam2
|
|
155
|
+
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
|
156
|
+
|
|
157
|
+
self.predictor = SAM2ImagePredictor(
|
|
158
|
+
build_sam2(model_cfg, checkpoint_path)
|
|
159
|
+
)
|
|
160
|
+
self.viewer.status = (
|
|
161
|
+
f"Initialized SAM2 Image Predictor on {self.device}"
|
|
162
|
+
)
|
|
153
163
|
|
|
154
|
-
except (
|
|
155
|
-
|
|
156
|
-
|
|
164
|
+
except (
|
|
165
|
+
ImportError,
|
|
166
|
+
RuntimeError,
|
|
167
|
+
ValueError,
|
|
168
|
+
FileNotFoundError,
|
|
169
|
+
requests.RequestException,
|
|
170
|
+
) as e:
|
|
171
|
+
import traceback
|
|
157
172
|
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
173
|
+
self.viewer.status = f"Error initializing SAM2: {str(e)}"
|
|
174
|
+
self.predictor = None
|
|
175
|
+
print(traceback.format_exc())
|
|
161
176
|
|
|
162
177
|
def load_images(self, folder_path: str):
|
|
163
178
|
"""Load images from the specified folder path."""
|
|
@@ -169,17 +184,19 @@ class BatchCropAnything:
|
|
|
169
184
|
self.images = [
|
|
170
185
|
os.path.join(folder_path, file)
|
|
171
186
|
for file in files
|
|
172
|
-
if file.lower().endswith(
|
|
173
|
-
|
|
174
|
-
)
|
|
175
|
-
and not file.
|
|
187
|
+
if file.lower().endswith(".tif")
|
|
188
|
+
or file.lower().endswith(".tiff")
|
|
189
|
+
and "label" not in file.lower()
|
|
190
|
+
and "cropped" not in file.lower()
|
|
191
|
+
and "_labels_" not in file.lower()
|
|
192
|
+
and "_cropped_" not in file.lower()
|
|
176
193
|
]
|
|
177
194
|
|
|
178
195
|
if not self.images:
|
|
179
196
|
self.viewer.status = "No compatible images found in the folder."
|
|
180
197
|
return
|
|
181
198
|
|
|
182
|
-
self.viewer.status = f"Found {len(self.images)} images."
|
|
199
|
+
self.viewer.status = f"Found {len(self.images)} .tif images."
|
|
183
200
|
self.current_index = 0
|
|
184
201
|
self._load_current_image()
|
|
185
202
|
|
|
@@ -237,9 +254,9 @@ class BatchCropAnything:
|
|
|
237
254
|
self.viewer.status = "No images to process."
|
|
238
255
|
return
|
|
239
256
|
|
|
240
|
-
if self.
|
|
257
|
+
if self.predictor is None:
|
|
241
258
|
self.viewer.status = (
|
|
242
|
-
"
|
|
259
|
+
"SAM2 model not initialized. Cannot segment images."
|
|
243
260
|
)
|
|
244
261
|
return
|
|
245
262
|
|
|
@@ -253,66 +270,147 @@ class BatchCropAnything:
|
|
|
253
270
|
# Load and process image
|
|
254
271
|
self.original_image = imread(image_path)
|
|
255
272
|
|
|
256
|
-
#
|
|
257
|
-
if self.original_image.
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
273
|
+
# For 3D/4D data, determine dimensions
|
|
274
|
+
if self.use_3d and len(self.original_image.shape) >= 3:
|
|
275
|
+
# Check shape to identify dimensions
|
|
276
|
+
if len(self.original_image.shape) == 4: # TZYX or similar
|
|
277
|
+
# Identify time dimension as first dim with size > 4 and < 400
|
|
278
|
+
# This is a heuristic to differentiate time from channels/small Z stacks
|
|
279
|
+
time_dim_idx = -1
|
|
280
|
+
for i, dim_size in enumerate(self.original_image.shape):
|
|
281
|
+
if 4 < dim_size < 400:
|
|
282
|
+
time_dim_idx = i
|
|
283
|
+
break
|
|
284
|
+
|
|
285
|
+
if time_dim_idx == 0: # TZYX format
|
|
286
|
+
# Keep as is, T is already the first dimension
|
|
287
|
+
self.image_layer = self.viewer.add_image(
|
|
288
|
+
self.original_image,
|
|
289
|
+
name=f"Image ({os.path.basename(image_path)})",
|
|
290
|
+
)
|
|
291
|
+
# Store time dimension info
|
|
292
|
+
self.time_dim_size = self.original_image.shape[0]
|
|
293
|
+
self.has_z_dim = True
|
|
294
|
+
elif (
|
|
295
|
+
time_dim_idx > 0
|
|
296
|
+
): # Unusual format, we need to transpose
|
|
297
|
+
# Transpose to move T to first dimension
|
|
298
|
+
# Create permutation order that puts time_dim_idx first
|
|
299
|
+
perm_order = list(
|
|
300
|
+
range(len(self.original_image.shape))
|
|
301
|
+
)
|
|
302
|
+
perm_order.remove(time_dim_idx)
|
|
303
|
+
perm_order.insert(0, time_dim_idx)
|
|
304
|
+
|
|
305
|
+
transposed_image = np.transpose(
|
|
306
|
+
self.original_image, perm_order
|
|
307
|
+
)
|
|
308
|
+
self.original_image = (
|
|
309
|
+
transposed_image # Replace with transposed version
|
|
310
|
+
)
|
|
311
|
+
|
|
312
|
+
self.image_layer = self.viewer.add_image(
|
|
313
|
+
self.original_image,
|
|
314
|
+
name=f"Image ({os.path.basename(image_path)})",
|
|
315
|
+
)
|
|
316
|
+
# Store time dimension info
|
|
317
|
+
self.time_dim_size = self.original_image.shape[0]
|
|
318
|
+
self.has_z_dim = True
|
|
319
|
+
else:
|
|
320
|
+
# No time dimension found, treat as ZYX
|
|
321
|
+
self.image_layer = self.viewer.add_image(
|
|
322
|
+
self.original_image,
|
|
323
|
+
name=f"Image ({os.path.basename(image_path)})",
|
|
324
|
+
)
|
|
325
|
+
self.time_dim_size = 1
|
|
326
|
+
self.has_z_dim = True
|
|
327
|
+
elif (
|
|
328
|
+
len(self.original_image.shape) == 3
|
|
329
|
+
): # Could be TYX or ZYX
|
|
330
|
+
# Check if first dimension is likely time (> 4, < 400)
|
|
331
|
+
if 4 < self.original_image.shape[0] < 400:
|
|
332
|
+
# Likely TYX format
|
|
333
|
+
self.image_layer = self.viewer.add_image(
|
|
334
|
+
self.original_image,
|
|
335
|
+
name=f"Image ({os.path.basename(image_path)})",
|
|
336
|
+
)
|
|
337
|
+
self.time_dim_size = self.original_image.shape[0]
|
|
338
|
+
self.has_z_dim = False
|
|
339
|
+
else:
|
|
340
|
+
# Likely ZYX format or another 3D format
|
|
341
|
+
self.image_layer = self.viewer.add_image(
|
|
342
|
+
self.original_image,
|
|
343
|
+
name=f"Image ({os.path.basename(image_path)})",
|
|
344
|
+
)
|
|
345
|
+
self.time_dim_size = 1
|
|
346
|
+
self.has_z_dim = True
|
|
347
|
+
else:
|
|
348
|
+
# Should not reach here with use_3d=True, but just in case
|
|
349
|
+
self.image_layer = self.viewer.add_image(
|
|
350
|
+
self.original_image,
|
|
351
|
+
name=f"Image ({os.path.basename(image_path)})",
|
|
352
|
+
)
|
|
353
|
+
self.time_dim_size = 1
|
|
354
|
+
self.has_z_dim = False
|
|
261
355
|
else:
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
356
|
+
# Handle 2D data as before
|
|
357
|
+
if self.original_image.dtype != np.uint8:
|
|
358
|
+
image_for_display = (
|
|
359
|
+
self.original_image
|
|
360
|
+
/ np.amax(self.original_image)
|
|
361
|
+
* 255
|
|
362
|
+
).astype(np.uint8)
|
|
363
|
+
else:
|
|
364
|
+
image_for_display = self.original_image
|
|
365
|
+
|
|
366
|
+
# Add image to viewer
|
|
367
|
+
self.image_layer = self.viewer.add_image(
|
|
368
|
+
image_for_display,
|
|
369
|
+
name=f"Image ({os.path.basename(image_path)})",
|
|
370
|
+
)
|
|
269
371
|
|
|
270
372
|
# Generate segmentation
|
|
271
|
-
self._generate_segmentation(
|
|
373
|
+
self._generate_segmentation(self.original_image, image_path)
|
|
272
374
|
|
|
273
|
-
except (
|
|
375
|
+
except (FileNotFoundError, ValueError, TypeError, OSError) as e:
|
|
274
376
|
import traceback
|
|
275
377
|
|
|
276
378
|
self.viewer.status = f"Error processing image: {str(e)}"
|
|
277
379
|
traceback.print_exc()
|
|
380
|
+
|
|
278
381
|
# Create empty segmentation in case of error
|
|
279
382
|
if (
|
|
280
383
|
hasattr(self, "original_image")
|
|
281
384
|
and self.original_image is not None
|
|
282
385
|
):
|
|
283
|
-
self.
|
|
284
|
-
self.original_image.shape
|
|
285
|
-
|
|
386
|
+
if self.use_3d:
|
|
387
|
+
shape = self.original_image.shape
|
|
388
|
+
else:
|
|
389
|
+
shape = self.original_image.shape[:2]
|
|
390
|
+
|
|
391
|
+
self.segmentation_result = np.zeros(shape, dtype=np.uint32)
|
|
286
392
|
self.label_layer = self.viewer.add_labels(
|
|
287
393
|
self.segmentation_result, name="Error: No Segmentation"
|
|
288
394
|
)
|
|
289
395
|
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
def _generate_segmentation(self, image):
|
|
295
|
-
"""Generate segmentation for the current image."""
|
|
296
|
-
# Prepare for SAM (add color channel if needed)
|
|
297
|
-
if len(image.shape) == 2:
|
|
298
|
-
image_for_sam = image[:, :, np.newaxis].repeat(3, axis=2)
|
|
299
|
-
else:
|
|
300
|
-
image_for_sam = image
|
|
301
|
-
|
|
302
|
-
# Store the current image for later regeneration if sensitivity changes
|
|
303
|
-
self.current_image_for_segmentation = image_for_sam
|
|
396
|
+
def _generate_segmentation(self, image, image_path: str):
|
|
397
|
+
"""Generate segmentation for the current image using SAM2."""
|
|
398
|
+
# Store the current image for later processing
|
|
399
|
+
self.current_image_for_segmentation = image
|
|
304
400
|
|
|
305
401
|
# Generate segmentation with current sensitivity
|
|
306
|
-
self.generate_segmentation_with_sensitivity()
|
|
402
|
+
self.generate_segmentation_with_sensitivity(image_path)
|
|
307
403
|
|
|
308
|
-
def generate_segmentation_with_sensitivity(
|
|
404
|
+
def generate_segmentation_with_sensitivity(
|
|
405
|
+
self, image_path: str, sensitivity=None
|
|
406
|
+
):
|
|
309
407
|
"""Generate segmentation with the specified sensitivity."""
|
|
310
408
|
if sensitivity is not None:
|
|
311
409
|
self.sensitivity = sensitivity
|
|
312
410
|
|
|
313
|
-
if self.
|
|
411
|
+
if self.predictor is None:
|
|
314
412
|
self.viewer.status = (
|
|
315
|
-
"
|
|
413
|
+
"SAM2 model not initialized. Cannot segment images."
|
|
316
414
|
)
|
|
317
415
|
return
|
|
318
416
|
|
|
@@ -321,298 +419,740 @@ class BatchCropAnything:
|
|
|
321
419
|
return
|
|
322
420
|
|
|
323
421
|
try:
|
|
324
|
-
# Map sensitivity (0-100) to
|
|
325
|
-
#
|
|
326
|
-
|
|
422
|
+
# Map sensitivity (0-100) to SAM2 parameters
|
|
423
|
+
# For SAM2, adjust confidence threshold based on sensitivity
|
|
424
|
+
confidence_threshold = (
|
|
425
|
+
0.9 - (self.sensitivity / 100) * 0.4
|
|
426
|
+
) # Range from 0.9 to 0.5
|
|
427
|
+
|
|
428
|
+
# Process based on dimension mode
|
|
429
|
+
if self.use_3d:
|
|
430
|
+
# Process 3D data
|
|
431
|
+
self._generate_3d_segmentation(
|
|
432
|
+
confidence_threshold, image_path
|
|
433
|
+
)
|
|
434
|
+
else:
|
|
435
|
+
# Process 2D data
|
|
436
|
+
self._generate_2d_segmentation(confidence_threshold)
|
|
437
|
+
|
|
438
|
+
except (
|
|
439
|
+
ValueError,
|
|
440
|
+
RuntimeError,
|
|
441
|
+
torch.cuda.OutOfMemoryError,
|
|
442
|
+
TypeError,
|
|
443
|
+
) as e:
|
|
444
|
+
import traceback
|
|
327
445
|
|
|
328
|
-
|
|
329
|
-
|
|
446
|
+
self.viewer.status = f"Error generating segmentation: {str(e)}"
|
|
447
|
+
traceback.print_exc()
|
|
330
448
|
|
|
331
|
-
|
|
332
|
-
|
|
449
|
+
def _generate_2d_segmentation(self, confidence_threshold):
|
|
450
|
+
"""Generate 2D segmentation using SAM2 Image Predictor."""
|
|
451
|
+
# Ensure image is in the correct format for SAM2
|
|
452
|
+
image = self.current_image_for_segmentation
|
|
453
|
+
|
|
454
|
+
# Handle resizing for very large images
|
|
455
|
+
orig_shape = image.shape[:2]
|
|
456
|
+
image_mp = (orig_shape[0] * orig_shape[1]) / 1e6
|
|
457
|
+
max_mp = 2.0 # Maximum image size in megapixels
|
|
458
|
+
|
|
459
|
+
if image_mp > max_mp:
|
|
460
|
+
scale_factor = np.sqrt(max_mp / image_mp)
|
|
461
|
+
new_height = int(orig_shape[0] * scale_factor)
|
|
462
|
+
new_width = int(orig_shape[1] * scale_factor)
|
|
463
|
+
|
|
464
|
+
self.viewer.status = f"Downscaling image from {orig_shape} to {(new_height, new_width)} for processing"
|
|
465
|
+
|
|
466
|
+
# Resize image
|
|
467
|
+
resized_image = resize(
|
|
468
|
+
image,
|
|
469
|
+
(new_height, new_width),
|
|
470
|
+
anti_aliasing=True,
|
|
471
|
+
preserve_range=True,
|
|
472
|
+
).astype(
|
|
473
|
+
np.float32
|
|
474
|
+
) # Convert to float32
|
|
475
|
+
|
|
476
|
+
self.current_scale_factor = scale_factor
|
|
477
|
+
else:
|
|
478
|
+
# Convert to float32 format
|
|
479
|
+
if image.dtype != np.float32:
|
|
480
|
+
resized_image = image.astype(np.float32)
|
|
481
|
+
else:
|
|
482
|
+
resized_image = image
|
|
483
|
+
self.current_scale_factor = 1.0
|
|
484
|
+
|
|
485
|
+
# Ensure image is in RGB format for SAM2
|
|
486
|
+
if len(resized_image.shape) == 2:
|
|
487
|
+
# Convert grayscale to RGB
|
|
488
|
+
resized_image = np.stack([resized_image] * 3, axis=-1)
|
|
489
|
+
elif len(resized_image.shape) == 3 and resized_image.shape[2] == 1:
|
|
490
|
+
# Convert single channel to RGB
|
|
491
|
+
resized_image = np.concatenate([resized_image] * 3, axis=2)
|
|
492
|
+
elif len(resized_image.shape) == 3 and resized_image.shape[2] > 3:
|
|
493
|
+
# Use first 3 channels
|
|
494
|
+
resized_image = resized_image[:, :, :3]
|
|
495
|
+
|
|
496
|
+
# Normalize the image to [0,1] range if it's not already
|
|
497
|
+
if resized_image.max() > 1.0:
|
|
498
|
+
resized_image = resized_image / 255.0
|
|
499
|
+
|
|
500
|
+
# Set SAM2 prediction parameters based on sensitivity
|
|
501
|
+
with torch.inference_mode(), torch.autocast(
|
|
502
|
+
"cuda", dtype=torch.float32
|
|
503
|
+
):
|
|
504
|
+
# Set the image in the predictor
|
|
505
|
+
self.predictor.set_image(resized_image)
|
|
506
|
+
|
|
507
|
+
# Use automatic points generation with confidence threshold
|
|
508
|
+
masks, scores, _ = self.predictor.predict(
|
|
509
|
+
point_coords=None,
|
|
510
|
+
point_labels=None,
|
|
511
|
+
box=None,
|
|
512
|
+
multimask_output=True,
|
|
513
|
+
)
|
|
333
514
|
|
|
334
|
-
#
|
|
335
|
-
|
|
515
|
+
# Filter masks by confidence threshold
|
|
516
|
+
valid_masks = scores > confidence_threshold
|
|
517
|
+
masks = masks[valid_masks]
|
|
518
|
+
scores = scores[valid_masks]
|
|
336
519
|
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
self.mask_generator.min_mask_region_area = min_area
|
|
520
|
+
# Convert masks to label image
|
|
521
|
+
labels = np.zeros(resized_image.shape[:2], dtype=np.uint32)
|
|
522
|
+
self.label_info = {} # Reset label info
|
|
341
523
|
|
|
342
|
-
|
|
343
|
-
#
|
|
344
|
-
|
|
345
|
-
gamma = (
|
|
346
|
-
1.5 - (self.sensitivity / 100) * 1.0
|
|
347
|
-
) # Range from 1.5 to 0.5
|
|
524
|
+
for i, mask in enumerate(masks):
|
|
525
|
+
label_id = i + 1 # Start label IDs from 1
|
|
526
|
+
labels[mask] = label_id
|
|
348
527
|
|
|
349
|
-
#
|
|
350
|
-
|
|
528
|
+
# Calculate label information
|
|
529
|
+
area = np.sum(mask)
|
|
530
|
+
y_indices, x_indices = np.where(mask)
|
|
531
|
+
center_y = np.mean(y_indices) if len(y_indices) > 0 else 0
|
|
532
|
+
center_x = np.mean(x_indices) if len(x_indices) > 0 else 0
|
|
351
533
|
|
|
352
|
-
#
|
|
353
|
-
|
|
534
|
+
# Store label info
|
|
535
|
+
self.label_info[label_id] = {
|
|
536
|
+
"area": area,
|
|
537
|
+
"center_y": center_y,
|
|
538
|
+
"center_x": center_x,
|
|
539
|
+
"score": float(scores[i]),
|
|
540
|
+
}
|
|
354
541
|
|
|
355
|
-
|
|
356
|
-
|
|
542
|
+
# Handle upscaling if needed
|
|
543
|
+
if self.current_scale_factor < 1.0:
|
|
544
|
+
labels = resize(
|
|
545
|
+
labels,
|
|
546
|
+
orig_shape,
|
|
547
|
+
order=0, # Nearest neighbor interpolation
|
|
548
|
+
preserve_range=True,
|
|
549
|
+
anti_aliasing=False,
|
|
550
|
+
).astype(np.uint32)
|
|
357
551
|
|
|
358
|
-
|
|
359
|
-
|
|
552
|
+
# Sort labels by area (largest first)
|
|
553
|
+
self.label_info = dict(
|
|
554
|
+
sorted(
|
|
555
|
+
self.label_info.items(),
|
|
556
|
+
key=lambda item: item[1]["area"],
|
|
557
|
+
reverse=True,
|
|
558
|
+
)
|
|
559
|
+
)
|
|
360
560
|
|
|
361
|
-
|
|
362
|
-
|
|
561
|
+
# Save segmentation result
|
|
562
|
+
self.segmentation_result = labels
|
|
363
563
|
|
|
364
|
-
|
|
365
|
-
|
|
564
|
+
# Update the label layer
|
|
565
|
+
self._update_label_layer()
|
|
366
566
|
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
567
|
+
def _generate_3d_segmentation(self, confidence_threshold, image_path):
|
|
568
|
+
"""
|
|
569
|
+
Initialize 3D segmentation using SAM2 Video Predictor.
|
|
570
|
+
This correctly sets up interactive segmentation following SAM2's video approach.
|
|
571
|
+
"""
|
|
572
|
+
try:
|
|
573
|
+
# Handle image_path - make sure it's a string
|
|
574
|
+
if not isinstance(image_path, str):
|
|
575
|
+
image_path = self.images[self.current_index]
|
|
370
576
|
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
577
|
+
# Initialize empty segmentation
|
|
578
|
+
volume_shape = self.current_image_for_segmentation.shape
|
|
579
|
+
labels = np.zeros(volume_shape, dtype=np.uint32)
|
|
580
|
+
self.segmentation_result = labels
|
|
375
581
|
|
|
376
|
-
|
|
582
|
+
# Create a temp directory for the MP4 conversion if needed
|
|
583
|
+
import os
|
|
584
|
+
import tempfile
|
|
377
585
|
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
anti_aliasing=True,
|
|
383
|
-
preserve_range=True,
|
|
384
|
-
).astype(np.uint8)
|
|
586
|
+
temp_dir = tempfile.gettempdir()
|
|
587
|
+
mp4_path = os.path.join(
|
|
588
|
+
temp_dir, f"temp_volume_{os.path.basename(image_path)}.mp4"
|
|
589
|
+
)
|
|
385
590
|
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
image_gamma_resized = image_gamma
|
|
390
|
-
self.current_scale_factor = 1.0
|
|
591
|
+
# If we need to save a modified version for MP4 conversion
|
|
592
|
+
need_temp_tif = False
|
|
593
|
+
temp_tif_path = None
|
|
391
594
|
|
|
392
|
-
|
|
595
|
+
# Check if we have a 4D volume with Z dimension
|
|
596
|
+
if (
|
|
597
|
+
hasattr(self, "has_z_dim")
|
|
598
|
+
and self.has_z_dim
|
|
599
|
+
and len(self.current_image_for_segmentation.shape) == 4
|
|
600
|
+
):
|
|
601
|
+
# We need to convert the 4D TZYX to a 3D TYX for proper video conversion
|
|
602
|
+
# by taking maximum intensity projection of Z for each time point
|
|
603
|
+
self.viewer.status = (
|
|
604
|
+
"Converting 4D TZYX volume to 3D TYX for SAM2..."
|
|
605
|
+
)
|
|
606
|
+
|
|
607
|
+
# Create maximum intensity projection along Z axis (axis 1 in TZYX)
|
|
608
|
+
projected_volume = np.max(
|
|
609
|
+
self.current_image_for_segmentation, axis=1
|
|
610
|
+
)
|
|
393
611
|
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
612
|
+
# Save this as a temporary TIF for MP4 conversion
|
|
613
|
+
temp_tif_path = os.path.join(
|
|
614
|
+
temp_dir, f"temp_projected_{os.path.basename(image_path)}"
|
|
615
|
+
)
|
|
616
|
+
imwrite(temp_tif_path, projected_volume)
|
|
617
|
+
need_temp_tif = True
|
|
397
618
|
|
|
398
|
-
|
|
619
|
+
# Convert the projected TIF to MP4
|
|
399
620
|
self.viewer.status = (
|
|
400
|
-
"
|
|
621
|
+
"Converting projected 3D volume to MP4 format for SAM2..."
|
|
401
622
|
)
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
623
|
+
mp4_path = tif_to_mp4(temp_tif_path)
|
|
624
|
+
else:
|
|
625
|
+
# Convert original volume to video format for SAM2
|
|
626
|
+
self.viewer.status = (
|
|
627
|
+
"Converting 3D volume to MP4 format for SAM2..."
|
|
628
|
+
)
|
|
629
|
+
mp4_path = tif_to_mp4(image_path)
|
|
405
630
|
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
self.viewer.layers.remove(layer)
|
|
631
|
+
# Initialize SAM2 state with the video
|
|
632
|
+
self.viewer.status = "Initializing SAM2 Video Predictor..."
|
|
633
|
+
with torch.inference_mode(), torch.autocast(
|
|
634
|
+
"cuda", dtype=torch.bfloat16
|
|
635
|
+
):
|
|
636
|
+
self._sam2_state = self.predictor.init_state(mp4_path)
|
|
413
637
|
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
638
|
+
# Store needed state for 3D processing
|
|
639
|
+
self._sam2_next_obj_id = 1
|
|
640
|
+
self._sam2_prompts = (
|
|
641
|
+
{}
|
|
642
|
+
) # Store prompts for each object (points, labels, box)
|
|
643
|
+
|
|
644
|
+
# Update the label layer with empty segmentation
|
|
645
|
+
self._update_label_layer()
|
|
646
|
+
|
|
647
|
+
# Replace the click handler for interactive 3D segmentation
|
|
648
|
+
if self.label_layer is not None and hasattr(
|
|
649
|
+
self.label_layer, "mouse_drag_callbacks"
|
|
650
|
+
):
|
|
651
|
+
for callback in list(self.label_layer.mouse_drag_callbacks):
|
|
652
|
+
self.label_layer.mouse_drag_callbacks.remove(callback)
|
|
653
|
+
|
|
654
|
+
# Add 3D-specific click handler
|
|
655
|
+
self.label_layer.mouse_drag_callbacks.append(
|
|
656
|
+
self._on_3d_label_clicked
|
|
419
657
|
)
|
|
420
658
|
|
|
421
|
-
|
|
422
|
-
|
|
659
|
+
# Set the viewer to show the first frame
|
|
660
|
+
if hasattr(self.viewer, "dims") and self.viewer.dims.ndim > 2:
|
|
661
|
+
self.viewer.dims.set_point(
|
|
662
|
+
0, 0
|
|
663
|
+
) # Set the first dimension (typically time/z) to 0
|
|
664
|
+
|
|
665
|
+
# Clean up temporary file if we created one
|
|
666
|
+
if (
|
|
667
|
+
need_temp_tif
|
|
668
|
+
and temp_tif_path
|
|
669
|
+
and os.path.exists(temp_tif_path)
|
|
670
|
+
):
|
|
671
|
+
with contextlib.suppress(Exception):
|
|
672
|
+
os.remove(temp_tif_path)
|
|
673
|
+
|
|
674
|
+
# Show instructions
|
|
675
|
+
self.viewer.status = (
|
|
676
|
+
"3D Mode active: Navigate to the first frame where object appears, then click. "
|
|
677
|
+
"Use Shift+click for negative points (to remove areas). "
|
|
678
|
+
"Segmentation will be propagated to all frames automatically."
|
|
679
|
+
)
|
|
680
|
+
|
|
681
|
+
return True
|
|
682
|
+
|
|
683
|
+
except (
|
|
684
|
+
FileNotFoundError,
|
|
685
|
+
RuntimeError,
|
|
686
|
+
torch.cuda.OutOfMemoryError,
|
|
687
|
+
ValueError,
|
|
688
|
+
OSError,
|
|
689
|
+
) as e:
|
|
690
|
+
import traceback
|
|
691
|
+
|
|
692
|
+
self.viewer.status = f"Error in 3D segmentation setup: {str(e)}"
|
|
693
|
+
traceback.print_exc()
|
|
694
|
+
return False
|
|
695
|
+
|
|
696
|
+
def _on_3d_label_clicked(self, layer, event):
|
|
697
|
+
"""Handle click on 3D label layer to add a prompt for segmentation."""
|
|
698
|
+
try:
|
|
699
|
+
if event.button != 1:
|
|
423
700
|
return
|
|
424
701
|
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
self.
|
|
430
|
-
|
|
702
|
+
coords = layer.world_to_data(event.position)
|
|
703
|
+
if len(coords) == 3:
|
|
704
|
+
z, y, x = map(int, coords)
|
|
705
|
+
elif len(coords) == 2:
|
|
706
|
+
z = int(self.viewer.dims.current_step[0])
|
|
707
|
+
y, x = map(int, coords)
|
|
708
|
+
else:
|
|
709
|
+
self.viewer.status = (
|
|
710
|
+
f"Unexpected coordinate dimensions: {coords}"
|
|
431
711
|
)
|
|
712
|
+
return
|
|
713
|
+
|
|
714
|
+
# Check if Shift key is pressed
|
|
715
|
+
is_negative = "Shift" in event.modifiers
|
|
716
|
+
point_label = -1 if is_negative else 1
|
|
717
|
+
|
|
718
|
+
# Initialize a unique object ID for this click
|
|
719
|
+
if not hasattr(self, "_sam2_next_obj_id"):
|
|
720
|
+
self._sam2_next_obj_id = 1
|
|
721
|
+
|
|
722
|
+
# Get current object ID (or create new one)
|
|
723
|
+
label_id = self.segmentation_result[z, y, x]
|
|
724
|
+
if is_negative and label_id > 0:
|
|
725
|
+
# Use existing object ID for negative points
|
|
726
|
+
ann_obj_id = label_id
|
|
727
|
+
else:
|
|
728
|
+
# Create new object for positive points on background
|
|
729
|
+
ann_obj_id = self._sam2_next_obj_id
|
|
730
|
+
if point_label > 0 and label_id == 0:
|
|
731
|
+
self._sam2_next_obj_id += 1
|
|
732
|
+
|
|
733
|
+
# Find or create points layer for this object
|
|
734
|
+
points_layer = None
|
|
735
|
+
for layer in list(self.viewer.layers):
|
|
736
|
+
if f"Points for Object {ann_obj_id}" in layer.name:
|
|
737
|
+
points_layer = layer
|
|
738
|
+
break
|
|
739
|
+
|
|
740
|
+
if points_layer is None:
|
|
741
|
+
# Create new points layer for this object
|
|
742
|
+
points_layer = self.viewer.add_points(
|
|
743
|
+
np.array([[z, y, x]]),
|
|
744
|
+
name=f"Points for Object {ann_obj_id}",
|
|
745
|
+
size=10,
|
|
746
|
+
face_color="green" if point_label > 0 else "red",
|
|
747
|
+
border_color="white",
|
|
748
|
+
border_width=1,
|
|
749
|
+
opacity=0.8,
|
|
750
|
+
)
|
|
751
|
+
|
|
752
|
+
with contextlib.suppress(AttributeError, ValueError):
|
|
753
|
+
points_layer.mouse_drag_callbacks.remove(
|
|
754
|
+
self._on_points_clicked
|
|
755
|
+
)
|
|
756
|
+
points_layer.mouse_drag_callbacks.append(
|
|
757
|
+
self._on_points_clicked
|
|
758
|
+
)
|
|
759
|
+
|
|
760
|
+
# Initialize points for this object
|
|
761
|
+
if not hasattr(self, "sam2_points_by_obj"):
|
|
762
|
+
self.sam2_points_by_obj = {}
|
|
763
|
+
self.sam2_labels_by_obj = {}
|
|
764
|
+
|
|
765
|
+
self.sam2_points_by_obj[ann_obj_id] = [[x, y]]
|
|
766
|
+
self.sam2_labels_by_obj[ann_obj_id] = [point_label]
|
|
432
767
|
else:
|
|
433
|
-
|
|
434
|
-
|
|
768
|
+
# Add to existing points layer
|
|
769
|
+
current_points = points_layer.data
|
|
770
|
+
new_points = np.vstack([current_points, [z, y, x]])
|
|
771
|
+
points_layer.data = new_points
|
|
772
|
+
|
|
773
|
+
# Add to existing point lists
|
|
774
|
+
if not hasattr(self, "sam2_points_by_obj"):
|
|
775
|
+
self.sam2_points_by_obj = {}
|
|
776
|
+
self.sam2_labels_by_obj = {}
|
|
777
|
+
|
|
778
|
+
if ann_obj_id not in self.sam2_points_by_obj:
|
|
779
|
+
self.sam2_points_by_obj[ann_obj_id] = []
|
|
780
|
+
self.sam2_labels_by_obj[ann_obj_id] = []
|
|
781
|
+
|
|
782
|
+
self.sam2_points_by_obj[ann_obj_id].append([x, y])
|
|
783
|
+
self.sam2_labels_by_obj[ann_obj_id].append(point_label)
|
|
784
|
+
|
|
785
|
+
# Perform SAM2 segmentation
|
|
786
|
+
if hasattr(self, "_sam2_state") and self._sam2_state is not None:
|
|
787
|
+
points = np.array(
|
|
788
|
+
self.sam2_points_by_obj[ann_obj_id], dtype=np.float32
|
|
789
|
+
)
|
|
790
|
+
labels = np.array(
|
|
791
|
+
self.sam2_labels_by_obj[ann_obj_id], dtype=np.int32
|
|
435
792
|
)
|
|
436
793
|
|
|
437
|
-
|
|
438
|
-
self.selected_labels = set()
|
|
794
|
+
self.viewer.status = f"Processing object at frame {z}..."
|
|
439
795
|
|
|
440
|
-
|
|
441
|
-
|
|
442
|
-
|
|
796
|
+
_, out_obj_ids, out_mask_logits = (
|
|
797
|
+
self.predictor.add_new_points_or_box(
|
|
798
|
+
inference_state=self._sam2_state,
|
|
799
|
+
frame_idx=z,
|
|
800
|
+
obj_id=ann_obj_id,
|
|
801
|
+
points=points,
|
|
802
|
+
labels=labels,
|
|
803
|
+
)
|
|
804
|
+
)
|
|
443
805
|
|
|
444
|
-
|
|
806
|
+
# Convert logits to mask and update segmentation
|
|
807
|
+
mask = (out_mask_logits[0] > 0.0).cpu().numpy()
|
|
808
|
+
|
|
809
|
+
# Fix mask dimensions if needed
|
|
810
|
+
if mask.ndim > 2:
|
|
811
|
+
mask = mask.squeeze()
|
|
812
|
+
|
|
813
|
+
# Check mask dimensions and resize if needed
|
|
814
|
+
if mask.shape != self.segmentation_result[z].shape:
|
|
815
|
+
from skimage.transform import resize
|
|
816
|
+
|
|
817
|
+
mask = resize(
|
|
818
|
+
mask.astype(float),
|
|
819
|
+
self.segmentation_result[z].shape,
|
|
820
|
+
order=0,
|
|
821
|
+
preserve_range=True,
|
|
822
|
+
anti_aliasing=False,
|
|
823
|
+
).astype(bool)
|
|
824
|
+
|
|
825
|
+
# Apply the mask to current frame
|
|
826
|
+
# For negative points, only remove from the current object
|
|
827
|
+
if point_label < 0:
|
|
828
|
+
# Remove only from current object
|
|
829
|
+
self.segmentation_result[z][
|
|
830
|
+
(self.segmentation_result[z] == ann_obj_id) & mask
|
|
831
|
+
] = 0
|
|
832
|
+
else:
|
|
833
|
+
# Add to current object (only overwrite background)
|
|
834
|
+
self.segmentation_result[z][
|
|
835
|
+
mask & (self.segmentation_result[z] == 0)
|
|
836
|
+
] = ann_obj_id
|
|
837
|
+
|
|
838
|
+
# Automatically propagate to other frames
|
|
839
|
+
self._propagate_mask_for_current_object(ann_obj_id, z)
|
|
840
|
+
|
|
841
|
+
# Update label layer
|
|
842
|
+
self._update_label_layer()
|
|
843
|
+
|
|
844
|
+
# Update label table if needed
|
|
845
|
+
if (
|
|
846
|
+
hasattr(self, "label_table_widget")
|
|
847
|
+
and self.label_table_widget is not None
|
|
848
|
+
):
|
|
849
|
+
self._populate_label_table(self.label_table_widget)
|
|
850
|
+
|
|
851
|
+
self.viewer.status = (
|
|
852
|
+
f"Updated 3D object {ann_obj_id} across all frames"
|
|
853
|
+
)
|
|
854
|
+
else:
|
|
855
|
+
self.viewer.status = "SAM2 3D state not initialized"
|
|
856
|
+
|
|
857
|
+
except (
|
|
858
|
+
IndexError,
|
|
859
|
+
KeyError,
|
|
860
|
+
ValueError,
|
|
861
|
+
RuntimeError,
|
|
862
|
+
torch.cuda.OutOfMemoryError,
|
|
863
|
+
) as e:
|
|
445
864
|
import traceback
|
|
446
865
|
|
|
447
|
-
self.viewer.status = f"Error
|
|
866
|
+
self.viewer.status = f"Error in 3D click handler: {str(e)}"
|
|
448
867
|
traceback.print_exc()
|
|
449
868
|
|
|
450
|
-
def
|
|
451
|
-
"""
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
self.label_info = {} # Reset label info
|
|
869
|
+
def _propagate_mask_for_current_object(self, obj_id, current_frame_idx):
|
|
870
|
+
"""
|
|
871
|
+
Propagate the mask for the current object from the given frame to all other frames.
|
|
872
|
+
Uses SAM2's video propagation with proper error handling.
|
|
455
873
|
|
|
456
|
-
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
|
|
874
|
+
Parameters:
|
|
875
|
+
obj_id: The ID of the object to propagate
|
|
876
|
+
current_frame_idx: The frame index where the object was identified
|
|
877
|
+
"""
|
|
878
|
+
try:
|
|
879
|
+
if not hasattr(self, "_sam2_state") or self._sam2_state is None:
|
|
880
|
+
self.viewer.status = (
|
|
881
|
+
"SAM2 3D state not initialized for propagation"
|
|
882
|
+
)
|
|
883
|
+
return
|
|
460
884
|
|
|
461
|
-
|
|
462
|
-
|
|
463
|
-
y_indices, x_indices = np.where(mask)
|
|
464
|
-
center_y = np.mean(y_indices) if len(y_indices) > 0 else 0
|
|
465
|
-
center_x = np.mean(x_indices) if len(x_indices) > 0 else 0
|
|
885
|
+
total_frames = self.segmentation_result.shape[0]
|
|
886
|
+
self.viewer.status = f"Propagating object {obj_id} through all {total_frames} frames..."
|
|
466
887
|
|
|
467
|
-
#
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
"
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
}
|
|
888
|
+
# Create a progress layer for visualization
|
|
889
|
+
progress_layer = None
|
|
890
|
+
for layer in list(self.viewer.layers):
|
|
891
|
+
if "Propagation Progress" in layer.name:
|
|
892
|
+
progress_layer = layer
|
|
893
|
+
break
|
|
474
894
|
|
|
475
|
-
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
|
|
480
|
-
|
|
895
|
+
if progress_layer is None:
|
|
896
|
+
progress_data = np.zeros_like(
|
|
897
|
+
self.segmentation_result, dtype=float
|
|
898
|
+
)
|
|
899
|
+
progress_layer = self.viewer.add_image(
|
|
900
|
+
progress_data,
|
|
901
|
+
name="Propagation Progress",
|
|
902
|
+
colormap="magma",
|
|
903
|
+
opacity=0.3,
|
|
904
|
+
visible=True,
|
|
905
|
+
)
|
|
906
|
+
|
|
907
|
+
# Update current frame in the progress layer
|
|
908
|
+
progress_data = progress_layer.data
|
|
909
|
+
current_mask = (
|
|
910
|
+
self.segmentation_result[current_frame_idx] == obj_id
|
|
481
911
|
)
|
|
482
|
-
|
|
912
|
+
progress_data[current_frame_idx] = current_mask.astype(float) * 0.8
|
|
913
|
+
progress_layer.data = progress_data
|
|
914
|
+
|
|
915
|
+
# Try to perform SAM2 propagation with error handling
|
|
916
|
+
try:
|
|
917
|
+
# Use torch.inference_mode() and torch.autocast to ensure consistent dtypes
|
|
918
|
+
with torch.inference_mode(), torch.autocast(
|
|
919
|
+
"cuda", dtype=torch.float32
|
|
920
|
+
):
|
|
921
|
+
# Attempt to run SAM2 propagation - this will iterate through all frames
|
|
922
|
+
for (
|
|
923
|
+
frame_idx,
|
|
924
|
+
object_ids,
|
|
925
|
+
mask_logits,
|
|
926
|
+
) in self.predictor.propagate_in_video(self._sam2_state):
|
|
927
|
+
if frame_idx >= total_frames:
|
|
928
|
+
continue
|
|
929
|
+
|
|
930
|
+
# Find our object ID in the results
|
|
931
|
+
# obj_mask = None
|
|
932
|
+
for i, prop_obj_id in enumerate(object_ids):
|
|
933
|
+
if prop_obj_id == obj_id:
|
|
934
|
+
# Get the mask for our object
|
|
935
|
+
mask = (mask_logits[i] > 0.0).cpu().numpy()
|
|
936
|
+
|
|
937
|
+
# Fix dimensions if needed
|
|
938
|
+
if mask.ndim > 2:
|
|
939
|
+
mask = mask.squeeze()
|
|
940
|
+
|
|
941
|
+
# Resize if needed
|
|
942
|
+
if (
|
|
943
|
+
mask.shape
|
|
944
|
+
!= self.segmentation_result[
|
|
945
|
+
frame_idx
|
|
946
|
+
].shape
|
|
947
|
+
):
|
|
948
|
+
from skimage.transform import resize
|
|
949
|
+
|
|
950
|
+
mask = resize(
|
|
951
|
+
mask.astype(float),
|
|
952
|
+
self.segmentation_result[
|
|
953
|
+
frame_idx
|
|
954
|
+
].shape,
|
|
955
|
+
order=0,
|
|
956
|
+
preserve_range=True,
|
|
957
|
+
anti_aliasing=False,
|
|
958
|
+
).astype(bool)
|
|
959
|
+
|
|
960
|
+
# Update segmentation - only replacing background pixels
|
|
961
|
+
self.segmentation_result[frame_idx][
|
|
962
|
+
mask
|
|
963
|
+
& (
|
|
964
|
+
self.segmentation_result[frame_idx]
|
|
965
|
+
== 0
|
|
966
|
+
)
|
|
967
|
+
] = obj_id
|
|
968
|
+
|
|
969
|
+
# Update progress visualization
|
|
970
|
+
progress_data = progress_layer.data
|
|
971
|
+
progress_data[frame_idx] = (
|
|
972
|
+
mask.astype(float) * 0.8
|
|
973
|
+
)
|
|
974
|
+
progress_layer.data = progress_data
|
|
975
|
+
|
|
976
|
+
# Update status occasionally
|
|
977
|
+
if frame_idx % 10 == 0:
|
|
978
|
+
self.viewer.status = f"Propagating: frame {frame_idx+1}/{total_frames}"
|
|
979
|
+
|
|
980
|
+
except RuntimeError as e:
|
|
981
|
+
# If we get a dtype mismatch or other error, the current frame's mask to other frames
|
|
982
|
+
self.viewer.status = f"SAM2 propagation failed with error: {str(e)}. Falling back to alternative method."
|
|
983
|
+
|
|
984
|
+
# Use the current frame's mask for propagation
|
|
985
|
+
for frame_idx in range(total_frames):
|
|
986
|
+
if (
|
|
987
|
+
frame_idx != current_frame_idx
|
|
988
|
+
): # Skip current frame as it's already done
|
|
989
|
+
# Only replace background pixels with the current frame's object
|
|
990
|
+
self.segmentation_result[frame_idx][
|
|
991
|
+
current_mask
|
|
992
|
+
& (self.segmentation_result[frame_idx] == 0)
|
|
993
|
+
] = obj_id
|
|
994
|
+
|
|
995
|
+
# Update progress layer
|
|
996
|
+
progress_data = progress_layer.data
|
|
997
|
+
progress_data[frame_idx] = (
|
|
998
|
+
current_mask.astype(float) * 0.5
|
|
999
|
+
) # Different intensity to indicate fallback
|
|
1000
|
+
progress_layer.data = progress_data
|
|
1001
|
+
|
|
1002
|
+
# Update status occasionally
|
|
1003
|
+
if frame_idx % 10 == 0:
|
|
1004
|
+
self.viewer.status = f"Fallback propagation: frame {frame_idx+1}/{total_frames}"
|
|
1005
|
+
|
|
1006
|
+
# Remove progress layer after 2 seconds
|
|
1007
|
+
import threading
|
|
1008
|
+
|
|
1009
|
+
def remove_progress():
|
|
1010
|
+
import time
|
|
1011
|
+
|
|
1012
|
+
time.sleep(2)
|
|
1013
|
+
for layer in list(self.viewer.layers):
|
|
1014
|
+
if "Propagation Progress" in layer.name:
|
|
1015
|
+
self.viewer.layers.remove(layer)
|
|
483
1016
|
|
|
484
|
-
|
|
485
|
-
self.segmentation_result = labels
|
|
1017
|
+
threading.Thread(target=remove_progress).start()
|
|
486
1018
|
|
|
487
|
-
|
|
488
|
-
for layer in list(self.viewer.layers):
|
|
489
|
-
if isinstance(layer, Labels) and "Segmentation" in layer.name:
|
|
490
|
-
self.viewer.layers.remove(layer)
|
|
1019
|
+
self.viewer.status = f"Propagation of object {obj_id} complete"
|
|
491
1020
|
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
|
|
1021
|
+
except (
|
|
1022
|
+
IndexError,
|
|
1023
|
+
ValueError,
|
|
1024
|
+
RuntimeError,
|
|
1025
|
+
torch.cuda.OutOfMemoryError,
|
|
1026
|
+
TypeError,
|
|
1027
|
+
) as e:
|
|
1028
|
+
import traceback
|
|
498
1029
|
|
|
499
|
-
|
|
500
|
-
|
|
1030
|
+
self.viewer.status = f"Error in propagation: {str(e)}"
|
|
1031
|
+
traceback.print_exc()
|
|
501
1032
|
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
):
|
|
508
|
-
|
|
509
|
-
for callback in list(self.label_layer.mouse_drag_callbacks):
|
|
510
|
-
self.label_layer.mouse_drag_callbacks.remove(callback)
|
|
511
|
-
|
|
512
|
-
# Connect mouse click event to label selection
|
|
513
|
-
self.label_layer.mouse_drag_callbacks.append(self._on_label_clicked)
|
|
514
|
-
|
|
515
|
-
# image_name = os.path.basename(self.images[self.current_index])
|
|
516
|
-
self.viewer.status = f"Loaded image {self.current_index + 1}/{len(self.images)} - Found {len(masks)} segments"
|
|
517
|
-
|
|
518
|
-
# New method for handling scaled segmentation masks
|
|
519
|
-
def _process_segmentation_masks_with_scaling(self, masks, original_shape):
|
|
520
|
-
"""Process segmentation masks with scaling to match the original image size."""
|
|
521
|
-
# Create label image from masks
|
|
522
|
-
# First determine the size of the mask predictions (which are at the downscaled resolution)
|
|
523
|
-
if not masks:
|
|
1033
|
+
def _add_3d_prompt(self, prompt_coords):
|
|
1034
|
+
"""
|
|
1035
|
+
Given a 3D coordinate (x, y, z), run SAM2 video predictor to segment the object at that point,
|
|
1036
|
+
update the segmentation result and label layer.
|
|
1037
|
+
"""
|
|
1038
|
+
if not hasattr(self, "_sam2_state") or self._sam2_state is None:
|
|
1039
|
+
self.viewer.status = "SAM2 3D state not initialized."
|
|
524
1040
|
return
|
|
525
1041
|
|
|
526
|
-
|
|
527
|
-
|
|
528
|
-
|
|
529
|
-
downscaled_labels = np.zeros(mask_shape, dtype=np.uint32)
|
|
530
|
-
self.label_info = {} # Reset label info
|
|
531
|
-
|
|
532
|
-
# Fill in the downscaled labels
|
|
533
|
-
for i, mask_data in enumerate(masks):
|
|
534
|
-
mask = mask_data["segmentation"]
|
|
535
|
-
label_id = i + 1 # Start label IDs from 1
|
|
536
|
-
downscaled_labels[mask] = label_id
|
|
537
|
-
|
|
538
|
-
# Store basic label info
|
|
539
|
-
area = np.sum(mask)
|
|
540
|
-
y_indices, x_indices = np.where(mask)
|
|
541
|
-
center_y = np.mean(y_indices) if len(y_indices) > 0 else 0
|
|
542
|
-
center_x = np.mean(x_indices) if len(x_indices) > 0 else 0
|
|
543
|
-
|
|
544
|
-
# Scale centers to original image coordinates
|
|
545
|
-
center_y_orig = center_y / self.current_scale_factor
|
|
546
|
-
center_x_orig = center_x / self.current_scale_factor
|
|
1042
|
+
if self.predictor is None:
|
|
1043
|
+
self.viewer.status = "SAM2 predictor not initialized."
|
|
1044
|
+
return
|
|
547
1045
|
|
|
548
|
-
|
|
549
|
-
|
|
550
|
-
|
|
551
|
-
|
|
552
|
-
self.current_scale_factor**2
|
|
553
|
-
), # Approximate area in original scale
|
|
554
|
-
"center_y": center_y_orig,
|
|
555
|
-
"center_x": center_x_orig,
|
|
556
|
-
"score": mask_data.get("stability_score", 0),
|
|
557
|
-
}
|
|
1046
|
+
# Prepare prompt for SAM2: point_coords is [[x, y, t]], point_labels is [1]
|
|
1047
|
+
x, y, z = prompt_coords
|
|
1048
|
+
point_coords = np.array([[x, y, z]])
|
|
1049
|
+
point_labels = np.array([1]) # 1 = foreground
|
|
558
1050
|
|
|
559
|
-
|
|
560
|
-
|
|
561
|
-
|
|
562
|
-
|
|
563
|
-
|
|
564
|
-
|
|
565
|
-
|
|
566
|
-
|
|
1051
|
+
with torch.inference_mode(), torch.autocast(
|
|
1052
|
+
"cuda", dtype=torch.bfloat16
|
|
1053
|
+
):
|
|
1054
|
+
masks, scores, _ = self.predictor.predict(
|
|
1055
|
+
state=self._sam2_state,
|
|
1056
|
+
point_coords=point_coords,
|
|
1057
|
+
point_labels=point_labels,
|
|
1058
|
+
multimask_output=True,
|
|
1059
|
+
)
|
|
567
1060
|
|
|
568
|
-
#
|
|
569
|
-
|
|
570
|
-
|
|
571
|
-
|
|
572
|
-
|
|
573
|
-
|
|
1061
|
+
# Pick the best mask (highest score)
|
|
1062
|
+
if masks is not None and len(masks) > 0:
|
|
1063
|
+
best_idx = np.argmax(scores)
|
|
1064
|
+
mask = masks[best_idx]
|
|
1065
|
+
obj_id = self._sam2_next_obj_id
|
|
1066
|
+
self.segmentation_result[mask] = obj_id
|
|
1067
|
+
self._sam2_next_obj_id += 1
|
|
1068
|
+
self.viewer.status = (
|
|
1069
|
+
f"Added object {obj_id} at (x={x}, y={y}, z={z})"
|
|
574
1070
|
)
|
|
575
|
-
|
|
1071
|
+
self._update_label_layer()
|
|
1072
|
+
else:
|
|
1073
|
+
self.viewer.status = "No mask found for this prompt."
|
|
1074
|
+
|
|
1075
|
+
def on_apply_propagate(self):
|
|
1076
|
+
"""Propagate masks across the video and update the segmentation layer."""
|
|
1077
|
+
self.viewer.status = "Propagating masks across all frames..."
|
|
1078
|
+
self.viewer.window._qt_window.setCursor(Qt.WaitCursor)
|
|
1079
|
+
|
|
1080
|
+
self.segmentation_result[:] = 0
|
|
1081
|
+
|
|
1082
|
+
for (
|
|
1083
|
+
frame_idx,
|
|
1084
|
+
object_ids,
|
|
1085
|
+
mask_logits,
|
|
1086
|
+
) in self.predictor.propagate_in_video(self._sam2_state):
|
|
1087
|
+
masks = (mask_logits > 0.0).cpu().numpy()
|
|
1088
|
+
if frame_idx >= self.segmentation_result.shape[0]:
|
|
1089
|
+
print(
|
|
1090
|
+
f"Warning: frame_idx {frame_idx} out of bounds for segmentation_result with shape {self.segmentation_result.shape}"
|
|
1091
|
+
)
|
|
1092
|
+
continue
|
|
1093
|
+
for i, obj_id in enumerate(object_ids):
|
|
1094
|
+
self.segmentation_result[frame_idx][masks[i]] = obj_id
|
|
1095
|
+
self.viewer.status = f"Propagating: frame {frame_idx+1}"
|
|
576
1096
|
|
|
577
|
-
|
|
578
|
-
self.
|
|
1097
|
+
self._update_label_layer()
|
|
1098
|
+
self.viewer.status = "Propagation complete!"
|
|
1099
|
+
self.viewer.window._qt_window.setCursor(Qt.ArrowCursor)
|
|
579
1100
|
|
|
580
|
-
|
|
1101
|
+
def _update_label_layer(self):
|
|
1102
|
+
"""Update the label layer in the viewer."""
|
|
1103
|
+
# Remove existing label layer if it exists
|
|
581
1104
|
for layer in list(self.viewer.layers):
|
|
582
1105
|
if isinstance(layer, Labels) and "Segmentation" in layer.name:
|
|
583
1106
|
self.viewer.layers.remove(layer)
|
|
584
1107
|
|
|
585
1108
|
# Add label layer to viewer
|
|
586
1109
|
self.label_layer = self.viewer.add_labels(
|
|
587
|
-
|
|
1110
|
+
self.segmentation_result,
|
|
588
1111
|
name=f"Segmentation ({os.path.basename(self.images[self.current_index])})",
|
|
589
1112
|
opacity=0.7,
|
|
590
1113
|
)
|
|
591
1114
|
|
|
592
|
-
#
|
|
593
|
-
|
|
1115
|
+
# Create points layer for interaction if it doesn't exist
|
|
1116
|
+
points_layer = None
|
|
1117
|
+
for layer in list(self.viewer.layers):
|
|
1118
|
+
if "Points" in layer.name:
|
|
1119
|
+
points_layer = layer
|
|
1120
|
+
break
|
|
1121
|
+
|
|
1122
|
+
if points_layer is None:
|
|
1123
|
+
# Initialize an empty points layer
|
|
1124
|
+
points_layer = self.viewer.add_points(
|
|
1125
|
+
np.zeros((0, 2 if not self.use_3d else 3)),
|
|
1126
|
+
name="Points (Click to Add)",
|
|
1127
|
+
size=10,
|
|
1128
|
+
face_color="green",
|
|
1129
|
+
border_color="white",
|
|
1130
|
+
border_width=1,
|
|
1131
|
+
opacity=0.8,
|
|
1132
|
+
)
|
|
594
1133
|
|
|
595
|
-
|
|
596
|
-
|
|
597
|
-
|
|
598
|
-
|
|
599
|
-
|
|
600
|
-
|
|
601
|
-
|
|
602
|
-
for callback in list(self.label_layer.mouse_drag_callbacks):
|
|
603
|
-
self.label_layer.mouse_drag_callbacks.remove(callback)
|
|
1134
|
+
with contextlib.suppress(AttributeError, ValueError):
|
|
1135
|
+
points_layer.mouse_drag_callbacks.remove(
|
|
1136
|
+
self._on_points_clicked
|
|
1137
|
+
)
|
|
1138
|
+
points_layer.mouse_drag_callbacks.append(
|
|
1139
|
+
self._on_points_clicked
|
|
1140
|
+
)
|
|
604
1141
|
|
|
605
|
-
|
|
606
|
-
|
|
1142
|
+
# Connect points layer mouse click event
|
|
1143
|
+
points_layer.mouse_drag_callbacks.append(self._on_points_clicked)
|
|
607
1144
|
|
|
608
|
-
|
|
1145
|
+
# Make the points layer active to encourage interaction with it
|
|
1146
|
+
self.viewer.layers.selection.active = points_layer
|
|
609
1147
|
|
|
610
|
-
|
|
611
|
-
|
|
612
|
-
|
|
1148
|
+
# Update status
|
|
1149
|
+
n_labels = len(np.unique(self.segmentation_result)) - (
|
|
1150
|
+
1 if 0 in np.unique(self.segmentation_result) else 0
|
|
1151
|
+
)
|
|
1152
|
+
self.viewer.status = f"Loaded image {self.current_index + 1}/{len(self.images)} - Found {n_labels} segments"
|
|
613
1153
|
|
|
614
|
-
def
|
|
615
|
-
"""Handle
|
|
1154
|
+
def _on_points_clicked(self, layer, event):
|
|
1155
|
+
"""Handle clicks on the points layer for adding/removing points."""
|
|
616
1156
|
try:
|
|
617
1157
|
# Only process clicks, not drags
|
|
618
1158
|
if event.type != "mouse_press":
|
|
@@ -621,39 +1161,815 @@ class BatchCropAnything:
|
|
|
621
1161
|
# Get coordinates of mouse click
|
|
622
1162
|
coords = np.round(event.position).astype(int)
|
|
623
1163
|
|
|
624
|
-
#
|
|
625
|
-
|
|
626
|
-
if
|
|
627
|
-
|
|
628
|
-
|
|
629
|
-
|
|
630
|
-
|
|
631
|
-
|
|
632
|
-
|
|
1164
|
+
# Check if Shift is pressed for negative points
|
|
1165
|
+
is_negative = "Shift" in event.modifiers
|
|
1166
|
+
point_label = -1 if is_negative else 1
|
|
1167
|
+
|
|
1168
|
+
# Handle 2D vs 3D coordinates
|
|
1169
|
+
if self.use_3d:
|
|
1170
|
+
if len(coords) == 3:
|
|
1171
|
+
t, y, x = map(int, coords)
|
|
1172
|
+
elif len(coords) == 2:
|
|
1173
|
+
t = int(self.viewer.dims.current_step[0])
|
|
1174
|
+
y, x = map(int, coords)
|
|
1175
|
+
+                else:
+                    self.viewer.status = (
+                        f"Unexpected coordinate dimensions: {coords}"
+                    )
+                    return
+
+                # Add point to the layer immediately for visual feedback
+                new_point = np.array([[t, y, x]])
+                if len(layer.data) == 0:
+                    layer.data = new_point
+                else:
+                    layer.data = np.vstack([layer.data, new_point])
+
+                # Update point colors
+                colors = layer.face_color
+                if isinstance(colors, list):
+                    colors.append("red" if is_negative else "green")
+                else:
+                    n_points = len(layer.data)
+                    colors = ["green"] * (n_points - 1)
+                    colors.append("red" if is_negative else "green")
+                layer.face_color = colors
+
+                # Get the object ID
+                # If clicking on existing segmentation with negative point
+                label_id = self.segmentation_result[t, y, x]
+                if is_negative and label_id > 0:
+                    obj_id = label_id
+                else:
+                    # For new objects or negative on background
+                    if not hasattr(self, "_sam2_next_obj_id"):
+                        self._sam2_next_obj_id = 1
+                    obj_id = self._sam2_next_obj_id
+                    if point_label > 0 and label_id == 0:
+                        self._sam2_next_obj_id += 1
+
+                # Store point information
+                if not hasattr(self, "points_data"):
+                    self.points_data = {}
+                    self.points_labels = {}
+
+                if obj_id not in self.points_data:
+                    self.points_data[obj_id] = []
+                    self.points_labels[obj_id] = []
+
+                self.points_data[obj_id].append(
+                    [x, y]
+                )  # Note: SAM2 expects [x,y] format
+                self.points_labels[obj_id].append(point_label)
+
+                # Perform segmentation
+                if (
+                    hasattr(self, "_sam2_state")
+                    and self._sam2_state is not None
+                ):
+                    # Prepare points
+                    points = np.array(
+                        self.points_data[obj_id], dtype=np.float32
+                    )
+                    labels = np.array(
+                        self.points_labels[obj_id], dtype=np.int32
+                    )
+
+                    # Create progress layer for visual feedback
+                    progress_layer = None
+                    for existing_layer in self.viewer.layers:
+                        if "Propagation Progress" in existing_layer.name:
+                            progress_layer = existing_layer
+                            break
+
+                    if progress_layer is None:
+                        progress_data = np.zeros_like(self.segmentation_result)
+                        progress_layer = self.viewer.add_image(
+                            progress_data,
+                            name="Propagation Progress",
+                            colormap="magma",
+                            opacity=0.5,
+                            visible=True,
+                        )
+
+                    # First update the current frame immediately
+                    self.viewer.status = f"Processing object at frame {t}..."
+
+                    # Run SAM2 on current frame
+                    _, out_obj_ids, out_mask_logits = (
+                        self.predictor.add_new_points_or_box(
+                            inference_state=self._sam2_state,
+                            frame_idx=t,
+                            obj_id=obj_id,
+                            points=points,
+                            labels=labels,
+                        )
+                    )
+
+                    # Update current frame
+                    mask = (out_mask_logits[0] > 0.0).cpu().numpy()
+                    if mask.ndim > 2:
+                        mask = mask.squeeze()
+
+                    # Resize if needed
+                    if mask.shape != self.segmentation_result[t].shape:
+                        from skimage.transform import resize
+
+                        mask = resize(
+                            mask.astype(float),
+                            self.segmentation_result[t].shape,
+                            order=0,
+                            preserve_range=True,
+                            anti_aliasing=False,
+                        ).astype(bool)
+
+                    # Update segmentation for this frame
+                    if point_label < 0:
+                        # For negative points, only remove from this object
+                        self.segmentation_result[t][
+                            (self.segmentation_result[t] == obj_id) & mask
+                        ] = 0
+                    else:
+                        # For positive points, only replace background
+                        self.segmentation_result[t][
+                            mask & (self.segmentation_result[t] == 0)
+                        ] = obj_id
+
+                    # Update progress layer for this frame
+                    progress_data = progress_layer.data
+                    progress_data[t] = (
+                        mask.astype(float) * 0.5
+                    )  # Highlight current frame
+                    progress_layer.data = progress_data
+
+                    # Now propagate to all frames with visual feedback
+                    self.viewer.status = "Propagating to all frames..."
+
+                    # Run propagation
+                    frame_count = self.segmentation_result.shape[0]
+                    for (
+                        frame_idx,
+                        prop_obj_ids,
+                        mask_logits,
+                    ) in self.predictor.propagate_in_video(self._sam2_state):
+                        if frame_idx >= frame_count:
+                            continue
+
+                        # Find our object
+                        obj_mask = None
+                        for i, prop_obj_id in enumerate(prop_obj_ids):
+                            if prop_obj_id == obj_id:
+                                obj_mask = (mask_logits[i] > 0.0).cpu().numpy()
+                                if obj_mask.ndim > 2:
+                                    obj_mask = obj_mask.squeeze()
+
+                                # Resize if needed
+                                if (
+                                    obj_mask.shape
+                                    != self.segmentation_result[
+                                        frame_idx
+                                    ].shape
+                                ):
+                                    obj_mask = resize(
+                                        obj_mask.astype(float),
+                                        self.segmentation_result[
+                                            frame_idx
+                                        ].shape,
+                                        order=0,
+                                        preserve_range=True,
+                                        anti_aliasing=False,
+                                    ).astype(bool)
+
+                                # Update segmentation
+                                self.segmentation_result[frame_idx][
+                                    obj_mask
+                                    & (
+                                        self.segmentation_result[frame_idx]
+                                        == 0
+                                    )
+                                ] = obj_id
+
+                                # Update progress visualization
+                                progress_data = progress_layer.data
+                                progress_data[frame_idx] = (
+                                    obj_mask.astype(float) * 0.8
+                                )  # Show as processed
+                                progress_layer.data = progress_data
+
+                        # Update status
+                        if frame_idx % 5 == 0:
+                            self.viewer.status = f"Propagating: frame {frame_idx+1}/{frame_count}"
+                            # Remove the viewer.update() call as it's causing errors
+
+                    # Process any missing frames
+                    processed_frames = set(range(frame_count))
+                    for frame_idx in range(frame_count):
+                        if (
+                            progress_data[frame_idx].max() == 0
+                        ):  # Frame not processed yet
+                            # Use nearest processed frame's mask
+                            nearest_idx = min(
+                                processed_frames,
+                                key=lambda x: abs(x - frame_idx),
+                            )
+                            if progress_data[nearest_idx].max() > 0:
+                                self.segmentation_result[frame_idx][
+                                    (self.segmentation_result[frame_idx] == 0)
+                                    & (
+                                        self.segmentation_result[nearest_idx]
+                                        == obj_id
+                                    )
+                                ] = obj_id
+
+                                # Update progress visualization
+                                progress_data[frame_idx] = (
+                                    progress_data[nearest_idx] * 0.6
+                                )  # Mark as copied
+
+                    # Final update of progress layer
+                    progress_layer.data = progress_data
+
+                    # Remove progress layer after 2 seconds
+                    import threading
+
+                    def remove_progress():
+                        import time
+
+                        time.sleep(2)
+                        for layer in list(self.viewer.layers):
+                            if "Propagation Progress" in layer.name:
+                                self.viewer.layers.remove(layer)
+
+                    threading.Thread(target=remove_progress).start()
+
+                    # Update UI
+                    self._update_label_layer()
+                    if (
+                        hasattr(self, "label_table_widget")
+                        and self.label_table_widget is not None
+                    ):
+                        self._populate_label_table(self.label_table_widget)
+
+                    self.viewer.status = f"Object {obj_id} segmented and propagated to all frames"
+
+            else:
+                # 2D case
+                if len(coords) == 2:
+                    y, x = map(int, coords)
+                else:
+                    self.viewer.status = (
+                        f"Unexpected coordinate dimensions: {coords}"
+                    )
+                    return
+
+                # Add point to the layer immediately for visual feedback
+                new_point = np.array([[y, x]])
+                if len(layer.data) == 0:
+                    layer.data = new_point
+                else:
+                    layer.data = np.vstack([layer.data, new_point])
+
+                # Update point colors
+                colors = layer.face_color
+                if isinstance(colors, list):
+                    colors.append("red" if is_negative else "green")
+                else:
+                    n_points = len(layer.data)
+                    colors = ["green"] * (n_points - 1)
+                    colors.append("red" if is_negative else "green")
+                layer.face_color = colors
+
+                # Get object ID
+                label_id = self.segmentation_result[y, x]
+                if is_negative and label_id > 0:
+                    obj_id = label_id
+                else:
+                    if not hasattr(self, "next_obj_id"):
+                        self.next_obj_id = 1
+                    obj_id = self.next_obj_id
+                    if point_label > 0 and label_id == 0:
+                        self.next_obj_id += 1
+
+                # Store point information
+                if not hasattr(self, "obj_points"):
+                    self.obj_points = {}
+                    self.obj_labels = {}
+
+                if obj_id not in self.obj_points:
+                    self.obj_points[obj_id] = []
+                    self.obj_labels[obj_id] = []
+
+                self.obj_points[obj_id].append(
+                    [x, y]
+                )  # SAM2 expects [x,y] format
+                self.obj_labels[obj_id].append(point_label)
+
+                # Perform segmentation
+                if hasattr(self, "predictor") and self.predictor is not None:
+                    # Make sure image is loaded
+                    if self.current_image_for_segmentation is None:
+                        self.viewer.status = "No image loaded for segmentation"
+                        return
+
+                    # Prepare image for SAM2
+                    image = self.current_image_for_segmentation
+                    if len(image.shape) == 2:
+                        image = np.stack([image] * 3, axis=-1)
+                    elif len(image.shape) == 3 and image.shape[2] == 1:
+                        image = np.concatenate([image] * 3, axis=2)
+                    elif len(image.shape) == 3 and image.shape[2] > 3:
+                        image = image[:, :, :3]
+
+                    if image.dtype != np.uint8:
+                        image = (image / np.max(image) * 255).astype(np.uint8)
+
+                    # Set the image in the predictor
+                    self.predictor.set_image(image)
+
+                    # Use only points for current object
+                    points = np.array(
+                        self.obj_points[obj_id], dtype=np.float32
+                    )
+                    labels = np.array(self.obj_labels[obj_id], dtype=np.int32)
+
+                    self.viewer.status = f"Segmenting object {obj_id} with {len(points)} points..."
+
+                    with torch.inference_mode(), torch.autocast("cuda"):
+                        masks, scores, _ = self.predictor.predict(
+                            point_coords=points,
+                            point_labels=labels,
+                            multimask_output=True,
+                        )
+
+                    # Get best mask
+                    if len(masks) > 0:
+                        best_mask = masks[0]
+
+                        # Update segmentation result
+                        if (
+                            best_mask.shape
+                            != self.segmentation_result.shape
+                        ):
+                            from skimage.transform import resize
+
+                            best_mask = resize(
+                                best_mask.astype(float),
+                                self.segmentation_result.shape,
+                                order=0,
+                                preserve_range=True,
+                                anti_aliasing=False,
+                            ).astype(bool)
+
+                        # Apply mask based on point type
+                        if point_label < 0:
+                            # Remove only from current object
+                            mask_condition = np.logical_and(
+                                self.segmentation_result == obj_id,
+                                best_mask,
+                            )
+                            self.segmentation_result[mask_condition] = 0
+                        else:
+                            # Add to current object (only overwrite background)
+                            mask_condition = np.logical_and(
+                                best_mask, (self.segmentation_result == 0)
+                            )
+                            self.segmentation_result[mask_condition] = (
+                                obj_id
+                            )
+
+                        # Update label info
+                        area = np.sum(self.segmentation_result == obj_id)
+                        y_indices, x_indices = np.where(
+                            self.segmentation_result == obj_id
+                        )
+                        center_y = (
+                            np.mean(y_indices) if len(y_indices) > 0 else 0
+                        )
+                        center_x = (
+                            np.mean(x_indices) if len(x_indices) > 0 else 0
+                        )
+
+                        self.label_info[obj_id] = {
+                            "area": area,
+                            "center_y": center_y,
+                            "center_x": center_x,
+                            "score": float(scores[0]),
+                        }
+
+                        self.viewer.status = f"Updated object {obj_id}"
+                    else:
+                        self.viewer.status = "No valid mask produced"
+
+                # Update the UI
+                self._update_label_layer()
+                if (
+                    hasattr(self, "label_table_widget")
+                    and self.label_table_widget is not None
+                ):
+                    self._populate_label_table(self.label_table_widget)
+
+        except (
+            IndexError,
+            KeyError,
+            ValueError,
+            RuntimeError,
+            TypeError,
+        ) as e:
+            import traceback

-
-
+            self.viewer.status = f"Error in points handling: {str(e)}"
+            traceback.print_exc()

-
-
+    def _on_label_clicked(self, layer, event):
+        """Handle label selection and user prompts on mouse click."""
+        try:
+            # Only process clicks, not drags
+            if event.type != "mouse_press":
                return

-            #
-
-
-
+            # Get coordinates of mouse click
+            coords = np.round(event.position).astype(int)
+
+            # Check if Shift is pressed (negative point)
+            is_negative = "Shift" in event.modifiers
+            point_label = -1 if is_negative else 1
+
+            # For 2D data
+            if not self.use_3d:
+                if len(coords) == 2:
+                    y, x = map(int, coords)
+                else:
+                    self.viewer.status = (
+                        f"Unexpected coordinate dimensions: {coords}"
+                    )
+                    return
+
+                # Check if within image bounds
+                shape = self.segmentation_result.shape
+                if y < 0 or x < 0 or y >= shape[0] or x >= shape[1]:
+                    self.viewer.status = "Click is outside image bounds"
+                    return
+
+                # Get the label ID at the clicked position
+                label_id = self.segmentation_result[y, x]
+
+                # Initialize a unique object ID for this click (if needed)
+                if not hasattr(self, "next_obj_id"):
+                    # Start with highest existing ID + 1
+                    if self.segmentation_result.max() > 0:
+                        self.next_obj_id = (
+                            int(self.segmentation_result.max()) + 1
+                        )
+                    else:
+                        self.next_obj_id = 1
+
+                # If clicking on background or using negative click, handle segmentation
+                if label_id == 0 or is_negative:
+                    # Find or create points layer for the current object we're working on
+                    current_obj_id = None
+
+                    # If negative point on existing label, use that label's ID
+                    if is_negative and label_id > 0:
+                        current_obj_id = label_id
+                    # For positive clicks on background, create a new object
+                    elif point_label > 0 and label_id == 0:
+                        current_obj_id = self.next_obj_id
+                        self.next_obj_id += 1
+                    # For negative on background, try to find most recent object
+                    elif point_label < 0 and label_id == 0:
+                        # Use most recently created object if available
+                        if hasattr(self, "obj_points") and self.obj_points:
+                            current_obj_id = max(self.obj_points.keys())
+                        else:
+                            self.viewer.status = "No existing object to modify with negative point"
+                            return
+
+                    if current_obj_id is None:
+                        self.viewer.status = (
+                            "Could not determine which object to modify"
+                        )
+                        return
+
+                    # Find or create points layer for this object
+                    points_layer = None
+                    for layer in list(self.viewer.layers):
+                        if f"Points for Object {current_obj_id}" in layer.name:
+                            points_layer = layer
+                            break
+
+                    # Initialize object tracking if needed
+                    if not hasattr(self, "obj_points"):
+                        self.obj_points = {}
+                        self.obj_labels = {}
+
+                    if current_obj_id not in self.obj_points:
+                        self.obj_points[current_obj_id] = []
+                        self.obj_labels[current_obj_id] = []
+
+                    # Create or update points layer for this object
+                    if points_layer is None:
+                        # First point for this object
+                        points_layer = self.viewer.add_points(
+                            np.array([[y, x]]),
+                            name=f"Points for Object {current_obj_id}",
+                            size=10,
+                            face_color=["green" if point_label > 0 else "red"],
+                            border_color="white",
+                            border_width=1,
+                            opacity=0.8,
+                        )
+                        with contextlib.suppress(AttributeError, ValueError):
+                            points_layer.mouse_drag_callbacks.remove(
+                                self._on_points_clicked
+                            )
+                        points_layer.mouse_drag_callbacks.append(
+                            self._on_points_clicked
+                        )
+
+                        self.obj_points[current_obj_id] = [[x, y]]
+                        self.obj_labels[current_obj_id] = [point_label]
+                    else:
+                        # Add point to existing layer
+                        current_points = points_layer.data
+                        current_colors = points_layer.face_color
+
+                        # Add new point
+                        new_points = np.vstack([current_points, [y, x]])
+                        new_color = "green" if point_label > 0 else "red"
+
+                        # Update points layer
+                        points_layer.data = new_points
+
+                        # Update colors
+                        if isinstance(current_colors, list):
+                            current_colors.append(new_color)
+                            points_layer.face_color = current_colors
+                        else:
+                            # If it's an array, create a list of colors
+                            colors = []
+                            for i in range(len(new_points)):
+                                if i < len(current_points):
+                                    colors.append(
+                                        "green" if point_label > 0 else "red"
+                                    )
+                                else:
+                                    colors.append(new_color)
+                            points_layer.face_color = colors
+
+                    # Update object tracking
+                    self.obj_points[current_obj_id].append([x, y])
+                    self.obj_labels[current_obj_id].append(point_label)
+
+                    # Now do the actual segmentation using SAM2
+                    if (
+                        hasattr(self, "predictor")
+                        and self.predictor is not None
+                    ):
+                        try:
+                            # Make sure image is loaded
+                            if self.current_image_for_segmentation is None:
+                                self.viewer.status = (
+                                    "No image loaded for segmentation"
+                                )
+                                return
+
+                            # Prepare image for SAM2
+                            image = self.current_image_for_segmentation
+                            if len(image.shape) == 2:
+                                image = np.stack([image] * 3, axis=-1)
+                            elif len(image.shape) == 3 and image.shape[2] == 1:
+                                image = np.concatenate([image] * 3, axis=2)
+                            elif len(image.shape) == 3 and image.shape[2] > 3:
+                                image = image[:, :, :3]
+
+                            if image.dtype != np.uint8:
+                                image = (image / np.max(image) * 255).astype(
+                                    np.uint8
+                                )
+
+                            # Set the image in the predictor
+                            self.predictor.set_image(image)
+
+                            # Only use the points for the current object being segmented
+                            points = np.array(
+                                self.obj_points[current_obj_id],
+                                dtype=np.float32,
+                            )
+                            labels = np.array(
+                                self.obj_labels[current_obj_id], dtype=np.int32
+                            )
+
+                            self.viewer.status = f"Segmenting object {current_obj_id} with {len(points)} points..."
+
+                            with torch.inference_mode(), torch.autocast(
+                                "cuda"
+                            ):
+                                masks, scores, _ = self.predictor.predict(
+                                    point_coords=points,
+                                    point_labels=labels,
+                                    multimask_output=True,
+                                )
+
+                            # Get best mask
+                            if len(masks) > 0:
+                                best_mask = masks[0]
+
+                                # Update segmentation result
+                                if (
+                                    best_mask.shape
+                                    != self.segmentation_result.shape
+                                ):
+                                    from skimage.transform import resize
+
+                                    best_mask = resize(
+                                        best_mask.astype(float),
+                                        self.segmentation_result.shape,
+                                        order=0,
+                                        preserve_range=True,
+                                        anti_aliasing=False,
+                                    ).astype(bool)
+
+                                # CRITICAL FIX: For negative points, only remove from this object's mask
+                                # For positive points, add to this object's mask without removing other objects
+                                if point_label < 0:
+                                    # Remove only from current object's mask
+                                    self.segmentation_result[
+                                        (
+                                            self.segmentation_result
+                                            == current_obj_id
+                                        )
+                                        & best_mask
+                                    ] = 0
+                                else:
+                                    # Add to current object's mask without affecting other objects
+                                    # Only overwrite background (value 0)
+                                    self.segmentation_result[
+                                        best_mask
+                                        & (self.segmentation_result == 0)
+                                    ] = current_obj_id
+
+                                # Update label info
+                                area = np.sum(
+                                    self.segmentation_result
+                                    == current_obj_id
+                                )
+                                y_indices, x_indices = np.where(
+                                    self.segmentation_result
+                                    == current_obj_id
+                                )
+                                center_y = (
+                                    np.mean(y_indices)
+                                    if len(y_indices) > 0
+                                    else 0
+                                )
+                                center_x = (
+                                    np.mean(x_indices)
+                                    if len(x_indices) > 0
+                                    else 0
+                                )
+
+                                self.label_info[current_obj_id] = {
+                                    "area": area,
+                                    "center_y": center_y,
+                                    "center_x": center_x,
+                                    "score": float(scores[0]),
+                                }
+
+                                self.viewer.status = (
+                                    f"Updated object {current_obj_id}"
+                                )
+                            else:
+                                self.viewer.status = (
+                                    "No valid mask produced"
+                                )
+
+                            # Update the UI
+                            self._update_label_layer()
+                            if (
+                                hasattr(self, "label_table_widget")
+                                and self.label_table_widget is not None
+                            ):
+                                self._populate_label_table(
+                                    self.label_table_widget
+                                )
+
+                        except (
+                            IndexError,
+                            KeyError,
+                            ValueError,
+                            AttributeError,
+                            TypeError,
+                        ) as e:
+                            import traceback
+
+                            self.viewer.status = (
+                                f"Error in SAM2 processing: {str(e)}"
+                            )
+                            traceback.print_exc()
+
+                # If clicking on an existing label, toggle selection
+                elif label_id > 0:
+                    # Toggle the label selection
+                    if label_id in self.selected_labels:
+                        self.selected_labels.remove(label_id)
+                        self.viewer.status = f"Deselected label ID: {label_id} | Selected labels: {self.selected_labels}"
+                    else:
+                        self.selected_labels.add(label_id)
+                        self.viewer.status = f"Selected label ID: {label_id} | Selected labels: {self.selected_labels}"
+
+                    # Update table and preview
+                    self._update_label_table()
+                    self.preview_crop()
+
+            # 3D case (handle differently)
            else:
-
-
+                if len(coords) == 3:
+                    t, y, x = map(int, coords)
+                elif len(coords) == 2:
+                    t = int(self.viewer.dims.current_step[0])
+                    y, x = map(int, coords)
+                else:
+                    self.viewer.status = (
+                        f"Unexpected coordinate dimensions: {coords}"
+                    )
+                    return
+
+                # Check if within bounds
+                shape = self.segmentation_result.shape
+                if (
+                    t < 0
+                    or t >= shape[0]
+                    or y < 0
+                    or y >= shape[1]
+                    or x < 0
+                    or x >= shape[2]
+                ):
+                    self.viewer.status = "Click is outside volume bounds"
+                    return
+
+                # Get the label ID at the clicked position
+                label_id = self.segmentation_result[t, y, x]
+
+                # If background or shift is pressed, handle in _on_3d_label_clicked
+                if label_id == 0 or is_negative:
+                    # This will be handled by _on_3d_label_clicked already attached
+                    pass
+                # If clicking on an existing label, handle selection
+                elif label_id > 0:
+                    # Toggle the label selection
+                    if label_id in self.selected_labels:
+                        self.selected_labels.remove(label_id)
+                        self.viewer.status = f"Deselected label ID: {label_id} | Selected labels: {self.selected_labels}"
+                    else:
+                        self.selected_labels.add(label_id)
+                        self.viewer.status = f"Selected label ID: {label_id} | Selected labels: {self.selected_labels}"

-
-
+            # Update table if it exists
+            self._update_label_table()

-
-
+            # Update preview after selection changes
+            self.preview_crop()

-        except (
-
+        except (
+            IndexError,
+            KeyError,
+            ValueError,
+            AttributeError,
+            TypeError,
+        ) as e:
+            import traceback
+
+            self.viewer.status = f"Error in click handling: {str(e)}"
+            traceback.print_exc()
+
+    def _add_point_marker(self, coords, label_type):
+        """Add a visible marker for where the user clicked."""
+        # Remove previous point markers
+        for layer in list(self.viewer.layers):
+            if "Point Prompt" in layer.name:
+                self.viewer.layers.remove(layer)
+
+        # Create points layer
+        color = (
+            "red" if label_type < 0 else "green"
+        )  # Red for negative, green for positive
+        self.viewer.add_points(
+            [coords],
+            name="Point Prompt",
+            size=10,
+            face_color=color,
+            edge_color="white",
+            edge_width=2,
+            opacity=0.8,
+        )
+
+        with contextlib.suppress(AttributeError, ValueError):
+            self.points_layer.mouse_drag_callbacks.remove(
+                self._on_points_clicked
+            )
+        self.points_layer.mouse_drag_callbacks.append(
+            self._on_points_clicked
+        )

    def create_label_table(self, parent_widget):
        """Create a table widget displaying all detected labels."""
@@ -694,57 +2010,86 @@ class BatchCropAnything:

    def _populate_label_table(self, table):
        """Populate the table with label information."""
-
-
-
+        try:
+            # Get all unique non-zero labels from the segmentation result safely
+            if self.segmentation_result is None:
+                # No segmentation yet
+                table.setRowCount(0)
+                self.viewer.status = "No segmentation available"
+                return

-
-
+            # Get unique labels, safely handling None values
+            unique_labels = []
+            for val in np.unique(self.segmentation_result):
+                if val is not None and val > 0:
+                    unique_labels.append(val)

-
-
-
-
-                reverse=True,
-            )
+            if len(unique_labels) == 0:
+                table.setRowCount(0)
+                self.viewer.status = "No labeled objects found"
+                return

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            # Set row count
+            table.setRowCount(len(unique_labels))
+
+            # Fill in label info for any missing labels
+            for label_id in unique_labels:
+                if label_id not in self.label_info:
+                    # Calculate basic info for this label
+                    mask = self.segmentation_result == label_id
+                    area = np.sum(mask)
+
+                    # Add info to label_info dictionary
+                    self.label_info[label_id] = {
+                        "area": area,
+                        "score": 1.0,  # Default score
+                    }
+
+            # Fill table with data
+            for row, label_id in enumerate(unique_labels):
+                # Checkbox for selection
+                checkbox_widget = QWidget()
+                checkbox_layout = QHBoxLayout(checkbox_widget)
+                checkbox_layout.setContentsMargins(5, 0, 5, 0)
+                checkbox_layout.setAlignment(Qt.AlignCenter)
+
+                checkbox = QCheckBox()
+                checkbox.setChecked(label_id in self.selected_labels)
+
+                # Connect checkbox to label selection
+                def make_checkbox_callback(lid):
+                    def callback(state):
+                        if state == Qt.Checked:
+                            self.selected_labels.add(lid)
+                        else:
+                            self.selected_labels.discard(lid)
+                        self.preview_crop()
+
+                    return callback

-
+                checkbox.stateChanged.connect(make_checkbox_callback(label_id))

-
+                checkbox_layout.addWidget(checkbox)
+                table.setCellWidget(row, 0, checkbox_widget)

-
-
+                # Label ID as plain text with transparent background
+                item = QTableWidgetItem(str(label_id))
+                item.setTextAlignment(Qt.AlignCenter)

-
-
-
+                # Set the background color to transparent
+                brush = item.background()
+                brush.setStyle(Qt.NoBrush)
+                item.setBackground(brush)

-
-            brush = item.background()
-            brush.setStyle(Qt.NoBrush)
-            item.setBackground(brush)
+                table.setItem(row, 1, item)

-
+        except (KeyError, TypeError, ValueError, AttributeError) as e:
+            import traceback
+
+            self.viewer.status = f"Error populating table: {str(e)}"
+            traceback.print_exc()
+            # Set empty table as fallback
+            table.setRowCount(0)

    def _update_label_table(self):
        """Update the label selection table if it exists."""
@@ -754,6 +2099,9 @@ class BatchCropAnything:
        # Block signals during update
        self.label_table_widget.blockSignals(True)

+        # Completely repopulate the table to ensure it's up to date
+        self._populate_label_table(self.label_table_widget)
+
        # Update checkboxes
        for row in range(self.label_table_widget.rowCount()):
            # Get label ID from the visible column
@@ -793,10 +2141,6 @@ class BatchCropAnything:
        self.preview_crop()
        self.viewer.status = "Cleared all selections"

-    # --------------------------------------------------
-    # Image Processing and Export
-    # --------------------------------------------------
-
    def preview_crop(self, label_ids=None):
        """Preview the crop result with the selected label IDs."""
        if self.segmentation_result is None or self.image_layer is None:
@@ -826,20 +2170,29 @@ class BatchCropAnything:
        image = self.original_image.copy()

        # Create mask from selected label IDs
-
-
-        mask
+        if self.use_3d:
+            # For 3D data
+            mask = np.zeros_like(self.segmentation_result, dtype=bool)
+            for label_id in label_ids:
+                mask |= self.segmentation_result == label_id

-
-        if len(image.shape) == 2:
-            # Grayscale image
+            # Apply mask
            preview_image = image.copy()
            preview_image[~mask] = 0
        else:
-            #
-
-            for
-
+            # For 2D data
+            mask = np.zeros_like(self.segmentation_result, dtype=bool)
+            for label_id in label_ids:
+                mask |= self.segmentation_result == label_id
+
+            # Apply mask
+            if len(image.shape) == 2:
+                preview_image = image.copy()
+                preview_image[~mask] = 0
+            else:
+                preview_image = image.copy()
+                for c in range(preview_image.shape[2]):
+                    preview_image[:, :, c][~mask] = 0

        # Remove previous preview if exists
        for layer in list(self.viewer.layers):
@@ -879,20 +2232,58 @@ class BatchCropAnything:
        image = self.original_image

        # Create mask from all selected label IDs
-
-
-        mask
+        if self.use_3d:
+            # For 3D data, create a 3D mask
+            mask = np.zeros_like(self.segmentation_result, dtype=bool)
+            for label_id in self.selected_labels:
+                mask |= self.segmentation_result == label_id

-
-        if len(image.shape) == 2:
-            # Grayscale image
+            # Apply mask to image (set everything outside mask to 0)
            cropped_image = image.copy()
            cropped_image[~mask] = 0
+
+            # Save label image with same dimensions as original
+            label_image = np.zeros_like(
+                self.segmentation_result, dtype=np.uint32
+            )
+            for label_id in self.selected_labels:
+                label_image[self.segmentation_result == label_id] = (
+                    label_id
+                )
        else:
-            #
-
-            for
-
+            # For 2D data, handle as before
+            mask = np.zeros_like(self.segmentation_result, dtype=bool)
+            for label_id in self.selected_labels:
+                mask |= self.segmentation_result == label_id
+
+            # Apply mask to image (set everything outside mask to 0)
+            if len(image.shape) == 2:
+                # Grayscale image
+                cropped_image = image.copy()
+                cropped_image[~mask] = 0
+
+                # Create label image with same dimensions
+                label_image = np.zeros_like(
+                    self.segmentation_result, dtype=np.uint32
+                )
+                for label_id in self.selected_labels:
+                    label_image[self.segmentation_result == label_id] = (
+                        label_id
+                    )
+            else:
+                # Color image - mask must be expanded to match channel dimension
+                cropped_image = image.copy()
+                for c in range(cropped_image.shape[2]):
+                    cropped_image[:, :, c][~mask] = 0
+
+                # Create label image with 2D dimensions (without channels)
+                label_image = np.zeros_like(
+                    self.segmentation_result, dtype=np.uint32
+                )
+                for label_id in self.selected_labels:
+                    label_image[self.segmentation_result == label_id] = (
+                        label_id
+                    )

        # Save cropped image
        image_path = self.images[self.current_index]
@@ -900,18 +2291,17 @@ class BatchCropAnything:
        label_str = "_".join(
            str(lid) for lid in sorted(self.selected_labels)
        )
-        output_path = f"{base_name}_cropped_{label_str}
-
-        # Save using appropriate method based on file type
-        if output_path.lower().endswith((".tif", ".tiff")):
-            imwrite(output_path, cropped_image, compression="zlib")
-        else:
-            from skimage.io import imsave
-
-            imsave(output_path, cropped_image)
+        output_path = f"{base_name}_cropped_{label_str}.tif"

+        # Save using tifffile with explicit parameters for best compatibility
+        imwrite(output_path, cropped_image, compression="zlib")
        self.viewer.status = f"Saved cropped image to {output_path}"

+        # Save the label image with exact same dimensions as original
+        label_output_path = f"{base_name}_labels_{label_str}.tif"
+        imwrite(label_output_path, label_image, compression="zlib")
+        self.viewer.status += f"\nSaved label mask to {label_output_path}"
+
        # Make sure the segmentation layer is active again
        if self.label_layer is not None:
            self.viewer.layers.selection.active = self.label_layer
@@ -923,76 +2313,44 @@ class BatchCropAnything:
        return False


-# --------------------------------------------------
-# UI Creation Functions
-# --------------------------------------------------
-
-
def create_crop_widget(processor):
    """Create the crop control widget."""
    crop_widget = QWidget()
    layout = QVBoxLayout()
-    layout.setSpacing(10)
-    layout.setContentsMargins(
-        10, 10, 10, 10
-    )  # Add margins around all elements
+    layout.setSpacing(10)
+    layout.setContentsMargins(10, 10, 10, 10)

    # Instructions
+    dimension_type = "3D (TYX/ZYX)" if processor.use_3d else "2D (YX)"
    instructions_label = QLabel(
-        "
-        "
-        "
+        f"<b>Processing {dimension_type} data</b><br><br>"
+        "To create/edit objects:<br>"
+        "1. <b>Click on the POINTS layer</b> to add positive points<br>"
+        "2. Use Shift+click for negative points to refine segmentation<br>"
+        "3. Click on existing objects in the Segmentation layer to select them<br>"
+        "4. Press 'Crop' to save the selected objects to disk"
    )
    instructions_label.setWordWrap(True)
    layout.addWidget(instructions_label)

-    #
-
-
-
-    sensitivity_header_layout = QHBoxLayout()
-    sensitivity_label = QLabel("Segmentation Sensitivity:")
-    sensitivity_value_label = QLabel(f"{processor.sensitivity}")
-    sensitivity_header_layout.addWidget(sensitivity_label)
-    sensitivity_header_layout.addStretch()
-    sensitivity_header_layout.addWidget(sensitivity_value_label)
-    sensitivity_layout.addLayout(sensitivity_header_layout)
-
-    # Slider
-    slider_layout = QHBoxLayout()
-    sensitivity_slider = QSlider(Qt.Horizontal)
-    sensitivity_slider.setMinimum(0)
-    sensitivity_slider.setMaximum(100)
-    sensitivity_slider.setValue(processor.sensitivity)
-    sensitivity_slider.setTickPosition(QSlider.TicksBelow)
-    sensitivity_slider.setTickInterval(10)
-    slider_layout.addWidget(sensitivity_slider)
-
-    apply_sensitivity_button = QPushButton("Apply")
-    apply_sensitivity_button.setToolTip(
-        "Apply sensitivity changes to regenerate segmentation"
-    )
-    slider_layout.addWidget(apply_sensitivity_button)
-    sensitivity_layout.addLayout(slider_layout)
-
-    # Description label
-    sensitivity_description = QLabel(
-        "Medium sensitivity - Balanced detection (γ=1.00)"
+    # Add a button to ensure points layer is active
+    activate_button = QPushButton("Make Points Layer Active")
+    activate_button.clicked.connect(
+        lambda: processor._ensure_points_layer_active()
    )
-
-    sensitivity_layout.addWidget(sensitivity_description)
+    layout.addWidget(activate_button)

-
+    # Add a "Clear Points" button to reset prompts
+    clear_points_button = QPushButton("Clear Points")
+    layout.addWidget(clear_points_button)

    # Create label table
    label_table = processor.create_label_table(crop_widget)
-    label_table.setMinimumHeight(150)
-    label_table.setMaximumHeight(
-        300
-    )  # Set maximum height to prevent taking too much space
+    label_table.setMinimumHeight(150)
+    label_table.setMaximumHeight(300)
    layout.addWidget(label_table)

-    #
+    # Selection buttons
    selection_layout = QHBoxLayout()
    select_all_button = QPushButton("Select All")
    clear_selection_button = QPushButton("Clear Selection")
@@ -1014,7 +2372,7 @@ def create_crop_widget(processor):

    # Status label
    status_label = QLabel(
-        "Ready to process images.
+        "Ready to process images. Click on POINTS layer to add segmentation points."
    )
    status_label.setWordWrap(True)
    layout.addWidget(status_label)
@@ -1033,36 +2391,51 @@ def create_crop_widget(processor):
        # Create new table
        label_table = processor.create_label_table(crop_widget)
        label_table.setMinimumHeight(200)
-        layout.insertWidget(3, label_table)  # Insert after
+        layout.insertWidget(3, label_table)  # Insert after clear points button
        return label_table

-    #
-    def
-
-
-
-
-
-
-
-
-
-
-                f"Medium sensitivity - Balanced detection (γ={gamma:.2f})"
+    # Add helper method to ensure points layer is active
+    def _ensure_points_layer_active():
+        points_layer = None
+        for layer in list(processor.viewer.layers):
+            if "Points" in layer.name:
+                points_layer = layer
+                break
+
+        if points_layer is not None:
+            processor.viewer.layers.selection.active = points_layer
+            status_label.setText(
+                "Points layer is now active - click to add points"
            )
        else:
-
-
-            )
-
-
-
-
-
-
-
+            status_label.setText(
+                "No points layer found. Please load an image first."
+            )
+
+    processor._ensure_points_layer_active = _ensure_points_layer_active
+
+    # Connect button signals
+    def on_clear_points_clicked():
+        # Remove all point layers
+        for layer in list(processor.viewer.layers):
+            if "Points" in layer.name:
+                processor.viewer.layers.remove(layer)
+
+        # Reset point tracking attributes
+        if hasattr(processor, "points_data"):
+            processor.points_data = {}
+            processor.points_labels = {}
+
+        if hasattr(processor, "obj_points"):
+            processor.obj_points = {}
+            processor.obj_labels = {}
+
+        # Re-create empty points layer
+        processor._update_label_layer()
+        processor._ensure_points_layer_active()
+
        status_label.setText(
-
+            "Cleared all points. Click on Points layer to add new points."
        )

    def on_select_all_clicked():
@@ -1086,117 +2459,83 @@ def create_crop_widget(processor):
        )

    def on_next_clicked():
+        # Clear points before moving to next image
+        on_clear_points_clicked()
+
        if not processor.next_image():
            next_button.setEnabled(False)
        else:
            prev_button.setEnabled(True)
            replace_table_widget()
-            # Reset sensitivity slider to default
-            sensitivity_slider.setValue(processor.sensitivity)
-            sensitivity_value_label.setText(f"{processor.sensitivity}")
            status_label.setText(
                f"Showing image {processor.current_index + 1}/{len(processor.images)}"
            )
+            processor._ensure_points_layer_active()

    def on_prev_clicked():
+        # Clear points before moving to previous image
+        on_clear_points_clicked()
+
        if not processor.previous_image():
            prev_button.setEnabled(False)
        else:
            next_button.setEnabled(True)
            replace_table_widget()
-            # Reset sensitivity slider to default
-            sensitivity_slider.setValue(processor.sensitivity)
-            sensitivity_value_label.setText(f"{processor.sensitivity}")
            status_label.setText(
                f"Showing image {processor.current_index + 1}/{len(processor.images)}"
            )
+            processor._ensure_points_layer_active()

-
-    apply_sensitivity_button.clicked.connect(on_apply_sensitivity_clicked)
+    clear_points_button.clicked.connect(on_clear_points_clicked)
    select_all_button.clicked.connect(on_select_all_clicked)
    clear_selection_button.clicked.connect(on_clear_selection_clicked)
    crop_button.clicked.connect(on_crop_clicked)
    next_button.clicked.connect(on_next_clicked)
    prev_button.clicked.connect(on_prev_clicked)
+    activate_button.clicked.connect(_ensure_points_layer_active)

    return crop_widget


-# --------------------------------------------------
-# Napari Plugin Functions
-# --------------------------------------------------
-
-
@magicgui(
    call_button="Start Batch Crop Anything",
    folder_path={"label": "Folder Path", "widget_type": "LineEdit"},
+    data_dimensions={
+        "label": "Data Dimensions",
+        "choices": ["YX (2D)", "TYX/ZYX (3D)"],
+    },
)
def batch_crop_anything(
    folder_path: str,
+    data_dimensions: str,
    viewer: Viewer = None,
):
-    """MagicGUI widget for starting Batch Crop Anything."""
-    # Check if
+    """MagicGUI widget for starting Batch Crop Anything using SAM2."""
+    # Check if SAM2 is available
    try:
-
-        # from mobile_sam import sam_model_registry
-
-        # Check if the required files are included with the package
-        try:
-            import importlib.util
-            import os
-
-            mobile_sam_spec = importlib.util.find_spec("mobile_sam")
-            if mobile_sam_spec is None:
-                raise ImportError("mobile_sam package not found")
-
-            mobile_sam_path = os.path.dirname(mobile_sam_spec.origin)
-
-            # Check for model file in package
-            model_found = False
-            checkpoint_paths = [
-                os.path.join(mobile_sam_path, "weights", "mobile_sam.pt"),
-                os.path.join(mobile_sam_path, "mobile_sam.pt"),
-                os.path.join(
-                    os.path.dirname(mobile_sam_path),
-                    "weights",
-                    "mobile_sam.pt",
-                ),
-                os.path.join(
-                    os.path.expanduser("~"), "models", "mobile_sam.pt"
-                ),
-                "/opt/T-MIDAS/models/mobile_sam.pt",
-                os.path.join(os.getcwd(), "mobile_sam.pt"),
-            ]
-
-            for path in checkpoint_paths:
-                if os.path.exists(path):
-                    model_found = True
-                    break
-
-            if not model_found:
-                QMessageBox.warning(
-                    None,
-                    "Model File Missing",
-                    "Mobile-SAM model weights (mobile_sam.pt) not found. You'll be prompted to locate it when starting the tool.\n\n"
-                    "You can download it from: https://github.com/ChaoningZhang/MobileSAM/tree/master/weights",
-                )
-        except (ImportError, AttributeError) as e:
-            print(f"Warning checking for model file: {str(e)}")
+        import importlib.util

+        sam2_spec = importlib.util.find_spec("sam2")
+        if sam2_spec is None:
+            QMessageBox.critical(
+                None,
+                "Missing Dependency",
+                "SAM2 not found. Please follow installation instructions at:\n"
+                "https://github.com/MercaderLabAnatomy/napari-tmidas?tab=readme-ov-file#dependencies\n",
+            )
+            return
    except ImportError:
        QMessageBox.critical(
            None,
            "Missing Dependency",
-            "
-            "
-            "You'll also need to download the model weights file (mobile_sam.pt) from:\n"
-            "https://github.com/ChaoningZhang/MobileSAM/tree/master/weights",
+            "SAM2 package cannot be imported. Please follow installation instructions at\n"
+            "https://github.com/MercaderLabAnatomy/napari-tmidas?tab=readme-ov-file#dependencies",
        )
        return

-    # Initialize processor
-
+    # Initialize processor with the selected dimensions mode
+    use_3d = "TYX/ZYX" in data_dimensions
+    processor = BatchCropAnything(viewer, use_3d=use_3d)
    processor.load_images(folder_path)

    # Create UI
@@ -1205,13 +2544,9 @@ def batch_crop_anything(
    # Wrap the widget in a scroll area
    scroll_area = QScrollArea()
    scroll_area.setWidget(crop_widget)
-    scroll_area.setWidgetResizable(
-
-    )
-    scroll_area.setFrameShape(QScrollArea.NoFrame)  # Hide the frame
-    scroll_area.setMinimumHeight(
-        500
-    )  # Set a minimum height to ensure visibility
+    scroll_area.setWidgetResizable(True)
+    scroll_area.setFrameShape(QScrollArea.NoFrame)
+    scroll_area.setMinimumHeight(500)

    # Add scroll area to viewer
    viewer.window.add_dock_widget(scroll_area, name="Crop Controls")