napari-tmidas 0.2.0__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- napari_tmidas/_crop_anything.py +1895 -608
- napari_tmidas/_file_selector.py +87 -6
- napari_tmidas/_version.py +2 -2
- napari_tmidas/processing_functions/basic.py +494 -23
- napari_tmidas/processing_functions/careamics_denoising.py +324 -0
- napari_tmidas/processing_functions/careamics_env_manager.py +339 -0
- napari_tmidas/processing_functions/cellpose_env_manager.py +55 -20
- napari_tmidas/processing_functions/cellpose_segmentation.py +105 -218
- napari_tmidas/processing_functions/sam2_mp4.py +283 -0
- napari_tmidas/processing_functions/skimage_filters.py +31 -1
- napari_tmidas/processing_functions/timepoint_merger.py +490 -0
- napari_tmidas/processing_functions/trackastra_tracking.py +303 -0
- {napari_tmidas-0.2.0.dist-info → napari_tmidas-0.2.1.dist-info}/METADATA +15 -8
- {napari_tmidas-0.2.0.dist-info → napari_tmidas-0.2.1.dist-info}/RECORD +18 -13
- {napari_tmidas-0.2.0.dist-info → napari_tmidas-0.2.1.dist-info}/WHEEL +1 -1
- {napari_tmidas-0.2.0.dist-info → napari_tmidas-0.2.1.dist-info}/entry_points.txt +0 -0
- {napari_tmidas-0.2.0.dist-info → napari_tmidas-0.2.1.dist-info}/licenses/LICENSE +0 -0
- {napari_tmidas-0.2.0.dist-info → napari_tmidas-0.2.1.dist-info}/top_level.txt +0 -0
napari_tmidas/_crop_anything.py
CHANGED
@@ -1,13 +1,20 @@
 """
 Batch Crop Anything - A Napari plugin for interactive image cropping

-This plugin combines
+This plugin combines SAM2 for automatic object detection with
 an interactive interface for selecting and cropping objects from images.
+The plugin supports both 2D (YX) and 3D (TYX/ZYX) data.
 """

+import contextlib
 import os

+# Add this at the beginning of your plugin file
+import sys
+
+sys.path.append("/opt/sam2")
 import numpy as np
+import requests
 import torch
 from magicgui import magicgui
 from napari.layers import Labels
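The new import block hard-codes /opt/sam2 onto sys.path. A minimal sketch (not part of the package) of how the same idea could be made location-independent; the SAM2_PATH environment variable used here is purely hypothetical:

    import os
    import sys

    # Hypothetical override: fall back to /opt/sam2 only when SAM2_PATH is unset.
    sam2_path = os.environ.get("SAM2_PATH", "/opt/sam2")
    if os.path.isdir(sam2_path) and sam2_path not in sys.path:
        sys.path.append(sam2_path)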
@@ -22,34 +29,55 @@ from qtpy.QtWidgets import (
     QMessageBox,
     QPushButton,
     QScrollArea,
-    QSlider,
     QTableWidget,
     QTableWidgetItem,
     QVBoxLayout,
     QWidget,
 )
 from skimage.io import imread
-from skimage.transform import resize
+from skimage.transform import resize
 from tifffile import imwrite

+from napari_tmidas.processing_functions.sam2_mp4 import tif_to_mp4
+
+def get_device():
+    if sys.platform == "darwin":
+        # MacOS: Only check for MPS
+        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+            device = torch.device("mps")
+            print("Using Apple Silicon GPU (MPS)")
+        else:
+            device = torch.device("cpu")
+            print("Using CPU")
+    else:
+        # Other platforms: check for CUDA, then CPU
+        if hasattr(torch, "cuda") and torch.cuda.is_available():
+            device = torch.device("cuda")
+            print(f"Using CUDA GPU: {torch.cuda.get_device_name()}")
+        else:
+            device = torch.device("cpu")
+            print("Using CPU")
+    return device
+
+
+

 class BatchCropAnything:
-    """
-    Class for processing images with Segment Anything and cropping selected objects.
-    """
+    """Class for processing images with SAM2 and cropping selected objects."""

-    def __init__(self, viewer: Viewer):
+    def __init__(self, viewer: Viewer, use_3d=False):
         """Initialize the BatchCropAnything processor."""
         # Core components
         self.viewer = viewer
         self.images = []
         self.current_index = 0
+        self.use_3d = use_3d

         # Image and segmentation data
         self.original_image = None
         self.segmentation_result = None
         self.current_image_for_segmentation = None
-        self.current_scale_factor = 1.0
+        self.current_scale_factor = 1.0

         # UI references
         self.image_layer = None
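The new get_device() helper prefers Apple's MPS backend on macOS and CUDA elsewhere, falling back to CPU in both cases. A compact sketch of the same selection order, for illustration only (pick_device is not a function in this package):

    import sys
    import torch

    def pick_device() -> torch.device:
        # Same preference order as get_device(): MPS on macOS, CUDA elsewhere, CPU otherwise.
        if sys.platform == "darwin":
            if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
                return torch.device("mps")
            return torch.device("cpu")
        if torch.cuda.is_available():
            return torch.device("cuda")
        return torch.device("cpu")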
@@ -63,101 +91,73 @@ class BatchCropAnything:
         # Segmentation parameters
         self.sensitivity = 50  # Default sensitivity (0-100 scale)

-        # Initialize the
-        self.
-
-    # --------------------------------------------------
-    # Model Initialization
-    # --------------------------------------------------
-
-    def _initialize_sam(self):
-        """Initialize the Segment Anything Model."""
-        try:
-            # Import required modules
-            from mobile_sam import (
-                SamAutomaticMaskGenerator,
-                sam_model_registry,
-            )
+        # Initialize the SAM2 model
+        self._initialize_sam2()

-
-
-            model_type = "vit_t"
+    def _initialize_sam2(self):
+        """Initialize the SAM2 model based on dimension mode."""

-
-
-            if checkpoint_path is None:
-                self.mobile_sam = None
-                self.mask_generator = None
-                return
+        def download_checkpoint(url, dest_folder):
+            import os

-
-
-
-
-
-
+            os.makedirs(dest_folder, exist_ok=True)
+            filename = os.path.join(dest_folder, url.split("/")[-1])
+            if not os.path.exists(filename):
+                print(f"Downloading checkpoint to {filename}...")
+                response = requests.get(url, stream=True)
+                response.raise_for_status()
+                with open(filename, "wb") as f:
+                    for chunk in response.iter_content(chunk_size=8192):
+                        f.write(chunk)
+                print("Download complete.")
+            else:
+                print(f"Checkpoint already exists at {filename}.")
+            return filename

-
-
-            self.viewer.status = f"Initialized SAM model from {checkpoint_path} on {self.device}"
+        try:
+            # import torch

-
-            self.viewer.status = f"Error initializing SAM: {str(e)}"
-            self.mobile_sam = None
-            self.mask_generator = None
+            self.device = get_device()

-
-
-
-
-
-            # Find the mobile_sam package location
-            mobile_sam_spec = importlib.util.find_spec("mobile_sam")
-            if mobile_sam_spec is None:
-                raise ImportError("mobile_sam package not found")
-
-            mobile_sam_path = os.path.dirname(mobile_sam_spec.origin)
-
-            # Check common locations for the model file
-            checkpoint_paths = [
-                os.path.join(mobile_sam_path, "weights", "mobile_sam.pt"),
-                os.path.join(mobile_sam_path, "mobile_sam.pt"),
-                os.path.join(
-                    os.path.dirname(mobile_sam_path),
-                    "weights",
-                    "mobile_sam.pt",
-                ),
-                os.path.join(
-                    os.path.expanduser("~"), "models", "mobile_sam.pt"
-                ),
-                "/opt/T-MIDAS/models/mobile_sam.pt",
-                os.path.join(os.getcwd(), "mobile_sam.pt"),
-            ]
-
-            for path in checkpoint_paths:
-                if os.path.exists(path):
-                    return path
-
-            # If model not found, ask user
-            QMessageBox.information(
-                None,
-                "Model Not Found",
-                "Mobile-SAM model weights not found. Please select the mobile_sam.pt file.",
+            # Download checkpoint if needed
+            checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt"
+            checkpoint_path = download_checkpoint(
+                checkpoint_url, "/opt/sam2/checkpoints/"
             )
+            model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"

-
-
-            )
+            if self.use_3d:
+                from sam2.build_sam import build_sam2_video_predictor

-
+                self.predictor = build_sam2_video_predictor(
+                    model_cfg, checkpoint_path, device=self.device
+                )
+                self.viewer.status = (
+                    f"Initialized SAM2 Video Predictor on {self.device}"
+                )
+            else:
+                from sam2.build_sam import build_sam2
+                from sam2.sam2_image_predictor import SAM2ImagePredictor
+
+                self.predictor = SAM2ImagePredictor(
+                    build_sam2(model_cfg, checkpoint_path)
+                )
+                self.viewer.status = (
+                    f"Initialized SAM2 Image Predictor on {self.device}"
+                )

-        except (
-
-
+        except (
+            ImportError,
+            RuntimeError,
+            ValueError,
+            FileNotFoundError,
+            requests.RequestException,
+        ) as e:
+            import traceback

-
-
-
+            self.viewer.status = f"Error initializing SAM2: {str(e)}"
+            self.predictor = None
+            print(traceback.format_exc())

     def load_images(self, folder_path: str):
         """Load images from the specified folder path."""
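The checkpoint URL, checkpoint directory, and config path above come from the hunk itself. A standalone sketch of the same download-then-build flow, assuming the sam2 package is installed and the checkpoint directory is writable (the ~/sam2_checkpoints location is hypothetical):

    import os
    import requests
    from sam2.build_sam import build_sam2
    from sam2.sam2_image_predictor import SAM2ImagePredictor

    url = "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt"
    ckpt_dir = os.path.expanduser("~/sam2_checkpoints")  # hypothetical location
    os.makedirs(ckpt_dir, exist_ok=True)
    ckpt = os.path.join(ckpt_dir, url.split("/")[-1])
    if not os.path.exists(ckpt):
        # Stream the large checkpoint to disk in chunks, as download_checkpoint does.
        with requests.get(url, stream=True) as response:
            response.raise_for_status()
            with open(ckpt, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)

    # 2D mode: image predictor built from the same config name used in the diff.
    predictor = SAM2ImagePredictor(build_sam2("configs/sam2.1/sam2.1_hiera_l.yaml", ckpt))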
@@ -169,17 +169,19 @@
         self.images = [
             os.path.join(folder_path, file)
             for file in files
-            if file.lower().endswith(
-
-            )
-            and not file.
+            if file.lower().endswith(".tif")
+            or file.lower().endswith(".tiff")
+            and "label" not in file.lower()
+            and "cropped" not in file.lower()
+            and "_labels_" not in file.lower()
+            and "_cropped_" not in file.lower()
         ]

         if not self.images:
             self.viewer.status = "No compatible images found in the folder."
             return

-        self.viewer.status = f"Found {len(self.images)} images."
+        self.viewer.status = f"Found {len(self.images)} .tif images."
         self.current_index = 0
         self._load_current_image()

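One thing to note about the new filter: Python's `and` binds tighter than `or`, so the name-based exclusions only constrain the `.tiff` branch, and plain `.tif` files are kept regardless of name. A hedged sketch of an explicitly grouped version of what the filter appears to intend, reusing the surrounding `files` and `folder_path` variables:

    def keep(filename: str) -> bool:
        name = filename.lower()
        is_tif = name.endswith((".tif", ".tiff"))
        excluded = any(tag in name for tag in ("label", "cropped", "_labels_", "_cropped_"))
        return is_tif and not excluded

    self.images = [
        os.path.join(folder_path, file) for file in files if keep(file)
    ]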
@@ -237,9 +239,9 @@
             self.viewer.status = "No images to process."
             return

-        if self.
+        if self.predictor is None:
             self.viewer.status = (
-                "
+                "SAM2 model not initialized. Cannot segment images."
             )
             return

@@ -253,66 +255,147 @@
             # Load and process image
             self.original_image = imread(image_path)

-            #
-            if self.original_image.
-
-
-
+            # For 3D/4D data, determine dimensions
+            if self.use_3d and len(self.original_image.shape) >= 3:
+                # Check shape to identify dimensions
+                if len(self.original_image.shape) == 4:  # TZYX or similar
+                    # Identify time dimension as first dim with size > 4 and < 400
+                    # This is a heuristic to differentiate time from channels/small Z stacks
+                    time_dim_idx = -1
+                    for i, dim_size in enumerate(self.original_image.shape):
+                        if 4 < dim_size < 400:
+                            time_dim_idx = i
+                            break
+
+                    if time_dim_idx == 0:  # TZYX format
+                        # Keep as is, T is already the first dimension
+                        self.image_layer = self.viewer.add_image(
+                            self.original_image,
+                            name=f"Image ({os.path.basename(image_path)})",
+                        )
+                        # Store time dimension info
+                        self.time_dim_size = self.original_image.shape[0]
+                        self.has_z_dim = True
+                    elif (
+                        time_dim_idx > 0
+                    ):  # Unusual format, we need to transpose
+                        # Transpose to move T to first dimension
+                        # Create permutation order that puts time_dim_idx first
+                        perm_order = list(
+                            range(len(self.original_image.shape))
+                        )
+                        perm_order.remove(time_dim_idx)
+                        perm_order.insert(0, time_dim_idx)
+
+                        transposed_image = np.transpose(
+                            self.original_image, perm_order
+                        )
+                        self.original_image = (
+                            transposed_image  # Replace with transposed version
+                        )
+
+                        self.image_layer = self.viewer.add_image(
+                            self.original_image,
+                            name=f"Image ({os.path.basename(image_path)})",
+                        )
+                        # Store time dimension info
+                        self.time_dim_size = self.original_image.shape[0]
+                        self.has_z_dim = True
+                    else:
+                        # No time dimension found, treat as ZYX
+                        self.image_layer = self.viewer.add_image(
+                            self.original_image,
+                            name=f"Image ({os.path.basename(image_path)})",
+                        )
+                        self.time_dim_size = 1
+                        self.has_z_dim = True
+                elif (
+                    len(self.original_image.shape) == 3
+                ):  # Could be TYX or ZYX
+                    # Check if first dimension is likely time (> 4, < 400)
+                    if 4 < self.original_image.shape[0] < 400:
+                        # Likely TYX format
+                        self.image_layer = self.viewer.add_image(
+                            self.original_image,
+                            name=f"Image ({os.path.basename(image_path)})",
+                        )
+                        self.time_dim_size = self.original_image.shape[0]
+                        self.has_z_dim = False
+                    else:
+                        # Likely ZYX format or another 3D format
+                        self.image_layer = self.viewer.add_image(
+                            self.original_image,
+                            name=f"Image ({os.path.basename(image_path)})",
+                        )
+                        self.time_dim_size = 1
+                        self.has_z_dim = True
+                else:
+                    # Should not reach here with use_3d=True, but just in case
+                    self.image_layer = self.viewer.add_image(
+                        self.original_image,
+                        name=f"Image ({os.path.basename(image_path)})",
+                    )
+                    self.time_dim_size = 1
+                    self.has_z_dim = False
             else:
-
-
-
-
-
-
-
+                # Handle 2D data as before
+                if self.original_image.dtype != np.uint8:
+                    image_for_display = (
+                        self.original_image
+                        / np.amax(self.original_image)
+                        * 255
+                    ).astype(np.uint8)
+                else:
+                    image_for_display = self.original_image
+
+                # Add image to viewer
+                self.image_layer = self.viewer.add_image(
+                    image_for_display,
+                    name=f"Image ({os.path.basename(image_path)})",
+                )

             # Generate segmentation
-            self._generate_segmentation(
+            self._generate_segmentation(self.original_image, image_path)

-        except (
+        except (FileNotFoundError, ValueError, TypeError, OSError) as e:
             import traceback

             self.viewer.status = f"Error processing image: {str(e)}"
             traceback.print_exc()
+
             # Create empty segmentation in case of error
             if (
                 hasattr(self, "original_image")
                 and self.original_image is not None
             ):
-                self.
-                    self.original_image.shape
-
+                if self.use_3d:
+                    shape = self.original_image.shape
+                else:
+                    shape = self.original_image.shape[:2]
+
+                self.segmentation_result = np.zeros(shape, dtype=np.uint32)
                 self.label_layer = self.viewer.add_labels(
                     self.segmentation_result, name="Error: No Segmentation"
                 )

-
-
-
-
-    def _generate_segmentation(self, image):
-        """Generate segmentation for the current image."""
-        # Prepare for SAM (add color channel if needed)
-        if len(image.shape) == 2:
-            image_for_sam = image[:, :, np.newaxis].repeat(3, axis=2)
-        else:
-            image_for_sam = image
-
-        # Store the current image for later regeneration if sensitivity changes
-        self.current_image_for_segmentation = image_for_sam
+    def _generate_segmentation(self, image, image_path: str):
+        """Generate segmentation for the current image using SAM2."""
+        # Store the current image for later processing
+        self.current_image_for_segmentation = image

         # Generate segmentation with current sensitivity
-        self.generate_segmentation_with_sensitivity()
+        self.generate_segmentation_with_sensitivity(image_path)

-    def generate_segmentation_with_sensitivity(
+    def generate_segmentation_with_sensitivity(
+        self, image_path: str, sensitivity=None
+    ):
         """Generate segmentation with the specified sensitivity."""
         if sensitivity is not None:
             self.sensitivity = sensitivity

-        if self.
+        if self.predictor is None:
             self.viewer.status = (
-                "
+                "SAM2 model not initialized. Cannot segment images."
             )
             return

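The 3D branch above guesses which axis is time by taking the first dimension whose size is strictly between 4 and 400. A minimal sketch of that heuristic in isolation (guess_time_axis is an illustrative name, not part of the plugin):

    def guess_time_axis(shape: tuple) -> int:
        # First axis with 4 < size < 400 is treated as time; -1 means none found.
        for i, dim_size in enumerate(shape):
            if 4 < dim_size < 400:
                return i
        return -1

    print(guess_time_axis((60, 5, 512, 512)))  # 0 -> treated as TZYX
    print(guess_time_axis((512, 512)))         # -1 -> no time dimension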
@@ -321,298 +404,723 @@ class BatchCropAnything:
|
|
|
321
404
|
return
|
|
322
405
|
|
|
323
406
|
try:
|
|
324
|
-
# Map sensitivity (0-100) to
|
|
325
|
-
#
|
|
326
|
-
|
|
407
|
+
# Map sensitivity (0-100) to SAM2 parameters
|
|
408
|
+
# For SAM2, adjust confidence threshold based on sensitivity
|
|
409
|
+
confidence_threshold = (
|
|
410
|
+
0.9 - (self.sensitivity / 100) * 0.4
|
|
411
|
+
) # Range from 0.9 to 0.5
|
|
412
|
+
|
|
413
|
+
# Process based on dimension mode
|
|
414
|
+
if self.use_3d:
|
|
415
|
+
# Process 3D data
|
|
416
|
+
self._generate_3d_segmentation(
|
|
417
|
+
confidence_threshold, image_path
|
|
418
|
+
)
|
|
419
|
+
else:
|
|
420
|
+
# Process 2D data
|
|
421
|
+
self._generate_2d_segmentation(confidence_threshold)
|
|
422
|
+
|
|
423
|
+
except (
|
|
424
|
+
ValueError,
|
|
425
|
+
RuntimeError,
|
|
426
|
+
torch.cuda.OutOfMemoryError,
|
|
427
|
+
TypeError,
|
|
428
|
+
) as e:
|
|
429
|
+
import traceback
|
|
327
430
|
|
|
328
|
-
|
|
329
|
-
|
|
431
|
+
self.viewer.status = f"Error generating segmentation: {str(e)}"
|
|
432
|
+
traceback.print_exc()
|
|
330
433
|
|
|
331
|
-
|
|
332
|
-
|
|
434
|
+
def _generate_2d_segmentation(self, confidence_threshold):
|
|
435
|
+
"""Generate 2D segmentation using SAM2 Image Predictor."""
|
|
436
|
+
# Ensure image is in the correct format for SAM2
|
|
437
|
+
image = self.current_image_for_segmentation
|
|
438
|
+
|
|
439
|
+
# Handle resizing for very large images
|
|
440
|
+
orig_shape = image.shape[:2]
|
|
441
|
+
image_mp = (orig_shape[0] * orig_shape[1]) / 1e6
|
|
442
|
+
max_mp = 2.0 # Maximum image size in megapixels
|
|
443
|
+
|
|
444
|
+
if image_mp > max_mp:
|
|
445
|
+
scale_factor = np.sqrt(max_mp / image_mp)
|
|
446
|
+
new_height = int(orig_shape[0] * scale_factor)
|
|
447
|
+
new_width = int(orig_shape[1] * scale_factor)
|
|
448
|
+
|
|
449
|
+
self.viewer.status = f"Downscaling image from {orig_shape} to {(new_height, new_width)} for processing"
|
|
450
|
+
|
|
451
|
+
# Resize image
|
|
452
|
+
resized_image = resize(
|
|
453
|
+
image,
|
|
454
|
+
(new_height, new_width),
|
|
455
|
+
anti_aliasing=True,
|
|
456
|
+
preserve_range=True,
|
|
457
|
+
).astype(
|
|
458
|
+
np.float32
|
|
459
|
+
) # Convert to float32
|
|
460
|
+
|
|
461
|
+
self.current_scale_factor = scale_factor
|
|
462
|
+
else:
|
|
463
|
+
# Convert to float32 format
|
|
464
|
+
if image.dtype != np.float32:
|
|
465
|
+
resized_image = image.astype(np.float32)
|
|
466
|
+
else:
|
|
467
|
+
resized_image = image
|
|
468
|
+
self.current_scale_factor = 1.0
|
|
469
|
+
|
|
470
|
+
# Ensure image is in RGB format for SAM2
|
|
471
|
+
if len(resized_image.shape) == 2:
|
|
472
|
+
# Convert grayscale to RGB
|
|
473
|
+
resized_image = np.stack([resized_image] * 3, axis=-1)
|
|
474
|
+
elif len(resized_image.shape) == 3 and resized_image.shape[2] == 1:
|
|
475
|
+
# Convert single channel to RGB
|
|
476
|
+
resized_image = np.concatenate([resized_image] * 3, axis=2)
|
|
477
|
+
elif len(resized_image.shape) == 3 and resized_image.shape[2] > 3:
|
|
478
|
+
# Use first 3 channels
|
|
479
|
+
resized_image = resized_image[:, :, :3]
|
|
480
|
+
|
|
481
|
+
# Normalize the image to [0,1] range if it's not already
|
|
482
|
+
if resized_image.max() > 1.0:
|
|
483
|
+
resized_image = resized_image / 255.0
|
|
484
|
+
|
|
485
|
+
# Set SAM2 prediction parameters based on sensitivity
|
|
486
|
+
with torch.inference_mode(), torch.autocast(
|
|
487
|
+
"cuda", dtype=torch.float32
|
|
488
|
+
):
|
|
489
|
+
# Set the image in the predictor
|
|
490
|
+
self.predictor.set_image(resized_image)
|
|
491
|
+
|
|
492
|
+
# Use automatic points generation with confidence threshold
|
|
493
|
+
masks, scores, _ = self.predictor.predict(
|
|
494
|
+
point_coords=None,
|
|
495
|
+
point_labels=None,
|
|
496
|
+
box=None,
|
|
497
|
+
multimask_output=True,
|
|
498
|
+
)
|
|
333
499
|
|
|
334
|
-
#
|
|
335
|
-
|
|
500
|
+
# Filter masks by confidence threshold
|
|
501
|
+
valid_masks = scores > confidence_threshold
|
|
502
|
+
masks = masks[valid_masks]
|
|
503
|
+
scores = scores[valid_masks]
|
|
336
504
|
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
self.mask_generator.min_mask_region_area = min_area
|
|
505
|
+
# Convert masks to label image
|
|
506
|
+
labels = np.zeros(resized_image.shape[:2], dtype=np.uint32)
|
|
507
|
+
self.label_info = {} # Reset label info
|
|
341
508
|
|
|
342
|
-
|
|
343
|
-
#
|
|
344
|
-
|
|
345
|
-
gamma = (
|
|
346
|
-
1.5 - (self.sensitivity / 100) * 1.0
|
|
347
|
-
) # Range from 1.5 to 0.5
|
|
509
|
+
for i, mask in enumerate(masks):
|
|
510
|
+
label_id = i + 1 # Start label IDs from 1
|
|
511
|
+
labels[mask] = label_id
|
|
348
512
|
|
|
349
|
-
#
|
|
350
|
-
|
|
513
|
+
# Calculate label information
|
|
514
|
+
area = np.sum(mask)
|
|
515
|
+
y_indices, x_indices = np.where(mask)
|
|
516
|
+
center_y = np.mean(y_indices) if len(y_indices) > 0 else 0
|
|
517
|
+
center_x = np.mean(x_indices) if len(x_indices) > 0 else 0
|
|
351
518
|
|
|
352
|
-
#
|
|
353
|
-
|
|
519
|
+
# Store label info
|
|
520
|
+
self.label_info[label_id] = {
|
|
521
|
+
"area": area,
|
|
522
|
+
"center_y": center_y,
|
|
523
|
+
"center_x": center_x,
|
|
524
|
+
"score": float(scores[i]),
|
|
525
|
+
}
|
|
354
526
|
|
|
355
|
-
|
|
356
|
-
|
|
527
|
+
# Handle upscaling if needed
|
|
528
|
+
if self.current_scale_factor < 1.0:
|
|
529
|
+
labels = resize(
|
|
530
|
+
labels,
|
|
531
|
+
orig_shape,
|
|
532
|
+
order=0, # Nearest neighbor interpolation
|
|
533
|
+
preserve_range=True,
|
|
534
|
+
anti_aliasing=False,
|
|
535
|
+
).astype(np.uint32)
|
|
357
536
|
|
|
358
|
-
|
|
359
|
-
|
|
537
|
+
# Sort labels by area (largest first)
|
|
538
|
+
self.label_info = dict(
|
|
539
|
+
sorted(
|
|
540
|
+
self.label_info.items(),
|
|
541
|
+
key=lambda item: item[1]["area"],
|
|
542
|
+
reverse=True,
|
|
543
|
+
)
|
|
544
|
+
)
|
|
360
545
|
|
|
361
|
-
|
|
362
|
-
|
|
546
|
+
# Save segmentation result
|
|
547
|
+
self.segmentation_result = labels
|
|
548
|
+
|
|
549
|
+
# Update the label layer
|
|
550
|
+
self._update_label_layer()
|
|
363
551
|
|
|
364
|
-
|
|
365
|
-
|
|
552
|
+
def _generate_3d_segmentation(self, confidence_threshold, image_path):
|
|
553
|
+
"""
|
|
554
|
+
Initialize 3D segmentation using SAM2 Video Predictor.
|
|
555
|
+
This correctly sets up interactive segmentation following SAM2's video approach.
|
|
556
|
+
"""
|
|
557
|
+
try:
|
|
558
|
+
# Handle image_path - make sure it's a string
|
|
559
|
+
if not isinstance(image_path, str):
|
|
560
|
+
image_path = self.images[self.current_index]
|
|
366
561
|
|
|
367
|
-
#
|
|
368
|
-
|
|
369
|
-
|
|
562
|
+
# Initialize empty segmentation
|
|
563
|
+
volume_shape = self.current_image_for_segmentation.shape
|
|
564
|
+
labels = np.zeros(volume_shape, dtype=np.uint32)
|
|
565
|
+
self.segmentation_result = labels
|
|
370
566
|
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
new_width = int(orig_shape[1] * scale_factor)
|
|
567
|
+
# Create a temp directory for the MP4 conversion if needed
|
|
568
|
+
import os
|
|
569
|
+
import tempfile
|
|
375
570
|
|
|
376
|
-
|
|
571
|
+
temp_dir = tempfile.gettempdir()
|
|
572
|
+
mp4_path = os.path.join(
|
|
573
|
+
temp_dir, f"temp_volume_{os.path.basename(image_path)}.mp4"
|
|
574
|
+
)
|
|
377
575
|
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
(new_height, new_width),
|
|
382
|
-
anti_aliasing=True,
|
|
383
|
-
preserve_range=True,
|
|
384
|
-
).astype(np.uint8)
|
|
576
|
+
# If we need to save a modified version for MP4 conversion
|
|
577
|
+
need_temp_tif = False
|
|
578
|
+
temp_tif_path = None
|
|
385
579
|
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
self.
|
|
580
|
+
# Check if we have a 4D volume with Z dimension
|
|
581
|
+
if (
|
|
582
|
+
hasattr(self, "has_z_dim")
|
|
583
|
+
and self.has_z_dim
|
|
584
|
+
and len(self.current_image_for_segmentation.shape) == 4
|
|
585
|
+
):
|
|
586
|
+
# We need to convert the 4D TZYX to a 3D TYX for proper video conversion
|
|
587
|
+
# by taking maximum intensity projection of Z for each time point
|
|
588
|
+
self.viewer.status = (
|
|
589
|
+
"Converting 4D TZYX volume to 3D TYX for SAM2..."
|
|
590
|
+
)
|
|
391
591
|
|
|
392
|
-
|
|
592
|
+
# Create maximum intensity projection along Z axis (axis 1 in TZYX)
|
|
593
|
+
projected_volume = np.max(
|
|
594
|
+
self.current_image_for_segmentation, axis=1
|
|
595
|
+
)
|
|
393
596
|
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
597
|
+
# Save this as a temporary TIF for MP4 conversion
|
|
598
|
+
temp_tif_path = os.path.join(
|
|
599
|
+
temp_dir, f"temp_projected_{os.path.basename(image_path)}"
|
|
600
|
+
)
|
|
601
|
+
imwrite(temp_tif_path, projected_volume)
|
|
602
|
+
need_temp_tif = True
|
|
397
603
|
|
|
398
|
-
|
|
604
|
+
# Convert the projected TIF to MP4
|
|
399
605
|
self.viewer.status = (
|
|
400
|
-
"
|
|
606
|
+
"Converting projected 3D volume to MP4 format for SAM2..."
|
|
401
607
|
)
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
608
|
+
mp4_path = tif_to_mp4(temp_tif_path)
|
|
609
|
+
else:
|
|
610
|
+
# Convert original volume to video format for SAM2
|
|
611
|
+
self.viewer.status = (
|
|
612
|
+
"Converting 3D volume to MP4 format for SAM2..."
|
|
613
|
+
)
|
|
614
|
+
mp4_path = tif_to_mp4(image_path)
|
|
405
615
|
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
self.viewer.layers.remove(layer)
|
|
616
|
+
# Initialize SAM2 state with the video
|
|
617
|
+
self.viewer.status = "Initializing SAM2 Video Predictor..."
|
|
618
|
+
with torch.inference_mode(), torch.autocast(
|
|
619
|
+
"cuda", dtype=torch.bfloat16
|
|
620
|
+
):
|
|
621
|
+
self._sam2_state = self.predictor.init_state(mp4_path)
|
|
413
622
|
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
623
|
+
# Store needed state for 3D processing
|
|
624
|
+
self._sam2_next_obj_id = 1
|
|
625
|
+
self._sam2_prompts = (
|
|
626
|
+
{}
|
|
627
|
+
) # Store prompts for each object (points, labels, box)
|
|
628
|
+
|
|
629
|
+
# Update the label layer with empty segmentation
|
|
630
|
+
self._update_label_layer()
|
|
631
|
+
|
|
632
|
+
# Replace the click handler for interactive 3D segmentation
|
|
633
|
+
if self.label_layer is not None and hasattr(
|
|
634
|
+
self.label_layer, "mouse_drag_callbacks"
|
|
635
|
+
):
|
|
636
|
+
for callback in list(self.label_layer.mouse_drag_callbacks):
|
|
637
|
+
self.label_layer.mouse_drag_callbacks.remove(callback)
|
|
638
|
+
|
|
639
|
+
# Add 3D-specific click handler
|
|
640
|
+
self.label_layer.mouse_drag_callbacks.append(
|
|
641
|
+
self._on_3d_label_clicked
|
|
419
642
|
)
|
|
420
643
|
|
|
421
|
-
|
|
422
|
-
|
|
644
|
+
# Set the viewer to show the first frame
|
|
645
|
+
if hasattr(self.viewer, "dims") and self.viewer.dims.ndim > 2:
|
|
646
|
+
self.viewer.dims.set_point(
|
|
647
|
+
0, 0
|
|
648
|
+
) # Set the first dimension (typically time/z) to 0
|
|
649
|
+
|
|
650
|
+
# Clean up temporary file if we created one
|
|
651
|
+
if (
|
|
652
|
+
need_temp_tif
|
|
653
|
+
and temp_tif_path
|
|
654
|
+
and os.path.exists(temp_tif_path)
|
|
655
|
+
):
|
|
656
|
+
with contextlib.suppress(Exception):
|
|
657
|
+
os.remove(temp_tif_path)
|
|
658
|
+
|
|
659
|
+
# Show instructions
|
|
660
|
+
self.viewer.status = (
|
|
661
|
+
"3D Mode active: Navigate to the first frame where object appears, then click. "
|
|
662
|
+
"Use Shift+click for negative points (to remove areas). "
|
|
663
|
+
"Segmentation will be propagated to all frames automatically."
|
|
664
|
+
)
|
|
665
|
+
|
|
666
|
+
return True
|
|
667
|
+
|
|
668
|
+
except (
|
|
669
|
+
FileNotFoundError,
|
|
670
|
+
RuntimeError,
|
|
671
|
+
torch.cuda.OutOfMemoryError,
|
|
672
|
+
ValueError,
|
|
673
|
+
OSError,
|
|
674
|
+
) as e:
|
|
675
|
+
import traceback
|
|
676
|
+
|
|
677
|
+
self.viewer.status = f"Error in 3D segmentation setup: {str(e)}"
|
|
678
|
+
traceback.print_exc()
|
|
679
|
+
return False
|
|
680
|
+
|
|
681
|
+
def _on_3d_label_clicked(self, layer, event):
|
|
682
|
+
"""Handle click on 3D label layer to add a prompt for segmentation."""
|
|
683
|
+
try:
|
|
684
|
+
if event.button != 1:
|
|
685
|
+
return
|
|
686
|
+
|
|
687
|
+
coords = layer.world_to_data(event.position)
|
|
688
|
+
if len(coords) == 3:
|
|
689
|
+
z, y, x = map(int, coords)
|
|
690
|
+
elif len(coords) == 2:
|
|
691
|
+
z = int(self.viewer.dims.current_step[0])
|
|
692
|
+
y, x = map(int, coords)
|
|
693
|
+
else:
|
|
694
|
+
self.viewer.status = (
|
|
695
|
+
f"Unexpected coordinate dimensions: {coords}"
|
|
696
|
+
)
|
|
423
697
|
return
|
|
424
698
|
|
|
425
|
-
#
|
|
426
|
-
|
|
427
|
-
if
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
699
|
+
# Check if Shift key is pressed
|
|
700
|
+
is_negative = "Shift" in event.modifiers
|
|
701
|
+
point_label = -1 if is_negative else 1
|
|
702
|
+
|
|
703
|
+
# Initialize a unique object ID for this click
|
|
704
|
+
if not hasattr(self, "_sam2_next_obj_id"):
|
|
705
|
+
self._sam2_next_obj_id = 1
|
|
706
|
+
|
|
707
|
+
# Get current object ID (or create new one)
|
|
708
|
+
label_id = self.segmentation_result[z, y, x]
|
|
709
|
+
if is_negative and label_id > 0:
|
|
710
|
+
# Use existing object ID for negative points
|
|
711
|
+
ann_obj_id = label_id
|
|
712
|
+
else:
|
|
713
|
+
# Create new object for positive points on background
|
|
714
|
+
ann_obj_id = self._sam2_next_obj_id
|
|
715
|
+
if point_label > 0 and label_id == 0:
|
|
716
|
+
self._sam2_next_obj_id += 1
|
|
717
|
+
|
|
718
|
+
# Find or create points layer for this object
|
|
719
|
+
points_layer = None
|
|
720
|
+
for layer in list(self.viewer.layers):
|
|
721
|
+
if f"Points for Object {ann_obj_id}" in layer.name:
|
|
722
|
+
points_layer = layer
|
|
723
|
+
break
|
|
724
|
+
|
|
725
|
+
if points_layer is None:
|
|
726
|
+
# Create new points layer for this object
|
|
727
|
+
points_layer = self.viewer.add_points(
|
|
728
|
+
np.array([[z, y, x]]),
|
|
729
|
+
name=f"Points for Object {ann_obj_id}",
|
|
730
|
+
size=10,
|
|
731
|
+
face_color="green" if point_label > 0 else "red",
|
|
732
|
+
border_color="white",
|
|
733
|
+
border_width=1,
|
|
734
|
+
opacity=0.8,
|
|
431
735
|
)
|
|
736
|
+
# Initialize points for this object
|
|
737
|
+
if not hasattr(self, "sam2_points_by_obj"):
|
|
738
|
+
self.sam2_points_by_obj = {}
|
|
739
|
+
self.sam2_labels_by_obj = {}
|
|
740
|
+
|
|
741
|
+
self.sam2_points_by_obj[ann_obj_id] = [[x, y]]
|
|
742
|
+
self.sam2_labels_by_obj[ann_obj_id] = [point_label]
|
|
432
743
|
else:
|
|
433
|
-
|
|
434
|
-
|
|
744
|
+
# Add to existing points layer
|
|
745
|
+
current_points = points_layer.data
|
|
746
|
+
new_points = np.vstack([current_points, [z, y, x]])
|
|
747
|
+
points_layer.data = new_points
|
|
748
|
+
|
|
749
|
+
# Add to existing point lists
|
|
750
|
+
if not hasattr(self, "sam2_points_by_obj"):
|
|
751
|
+
self.sam2_points_by_obj = {}
|
|
752
|
+
self.sam2_labels_by_obj = {}
|
|
753
|
+
|
|
754
|
+
if ann_obj_id not in self.sam2_points_by_obj:
|
|
755
|
+
self.sam2_points_by_obj[ann_obj_id] = []
|
|
756
|
+
self.sam2_labels_by_obj[ann_obj_id] = []
|
|
757
|
+
|
|
758
|
+
self.sam2_points_by_obj[ann_obj_id].append([x, y])
|
|
759
|
+
self.sam2_labels_by_obj[ann_obj_id].append(point_label)
|
|
760
|
+
|
|
761
|
+
# Perform SAM2 segmentation
|
|
762
|
+
if hasattr(self, "_sam2_state") and self._sam2_state is not None:
|
|
763
|
+
points = np.array(
|
|
764
|
+
self.sam2_points_by_obj[ann_obj_id], dtype=np.float32
|
|
765
|
+
)
|
|
766
|
+
labels = np.array(
|
|
767
|
+
self.sam2_labels_by_obj[ann_obj_id], dtype=np.int32
|
|
435
768
|
)
|
|
436
769
|
|
|
437
|
-
|
|
438
|
-
self.selected_labels = set()
|
|
770
|
+
self.viewer.status = f"Processing object at frame {z}..."
|
|
439
771
|
|
|
440
|
-
|
|
441
|
-
|
|
442
|
-
|
|
772
|
+
_, out_obj_ids, out_mask_logits = (
|
|
773
|
+
self.predictor.add_new_points_or_box(
|
|
774
|
+
inference_state=self._sam2_state,
|
|
775
|
+
frame_idx=z,
|
|
776
|
+
obj_id=ann_obj_id,
|
|
777
|
+
points=points,
|
|
778
|
+
labels=labels,
|
|
779
|
+
)
|
|
780
|
+
)
|
|
443
781
|
|
|
444
|
-
|
|
782
|
+
# Convert logits to mask and update segmentation
|
|
783
|
+
mask = (out_mask_logits[0] > 0.0).cpu().numpy()
|
|
784
|
+
|
|
785
|
+
# Fix mask dimensions if needed
|
|
786
|
+
if mask.ndim > 2:
|
|
787
|
+
mask = mask.squeeze()
|
|
788
|
+
|
|
789
|
+
# Check mask dimensions and resize if needed
|
|
790
|
+
if mask.shape != self.segmentation_result[z].shape:
|
|
791
|
+
from skimage.transform import resize
|
|
792
|
+
|
|
793
|
+
mask = resize(
|
|
794
|
+
mask.astype(float),
|
|
795
|
+
self.segmentation_result[z].shape,
|
|
796
|
+
order=0,
|
|
797
|
+
preserve_range=True,
|
|
798
|
+
anti_aliasing=False,
|
|
799
|
+
).astype(bool)
|
|
800
|
+
|
|
801
|
+
# Apply the mask to current frame
|
|
802
|
+
# For negative points, only remove from the current object
|
|
803
|
+
if point_label < 0:
|
|
804
|
+
# Remove only from current object
|
|
805
|
+
self.segmentation_result[z][
|
|
806
|
+
(self.segmentation_result[z] == ann_obj_id) & mask
|
|
807
|
+
] = 0
|
|
808
|
+
else:
|
|
809
|
+
# Add to current object (only overwrite background)
|
|
810
|
+
self.segmentation_result[z][
|
|
811
|
+
mask & (self.segmentation_result[z] == 0)
|
|
812
|
+
] = ann_obj_id
|
|
813
|
+
|
|
814
|
+
# Automatically propagate to other frames
|
|
815
|
+
self._propagate_mask_for_current_object(ann_obj_id, z)
|
|
816
|
+
|
|
817
|
+
# Update label layer
|
|
818
|
+
self._update_label_layer()
|
|
819
|
+
|
|
820
|
+
# Update label table if needed
|
|
821
|
+
if (
|
|
822
|
+
hasattr(self, "label_table_widget")
|
|
823
|
+
and self.label_table_widget is not None
|
|
824
|
+
):
|
|
825
|
+
self._populate_label_table(self.label_table_widget)
|
|
826
|
+
|
|
827
|
+
self.viewer.status = (
|
|
828
|
+
f"Updated 3D object {ann_obj_id} across all frames"
|
|
829
|
+
)
|
|
830
|
+
else:
|
|
831
|
+
self.viewer.status = "SAM2 3D state not initialized"
|
|
832
|
+
|
|
833
|
+
except (
|
|
834
|
+
IndexError,
|
|
835
|
+
KeyError,
|
|
836
|
+
ValueError,
|
|
837
|
+
RuntimeError,
|
|
838
|
+
torch.cuda.OutOfMemoryError,
|
|
839
|
+
) as e:
|
|
445
840
|
import traceback
|
|
446
841
|
|
|
447
|
-
self.viewer.status = f"Error
|
|
842
|
+
self.viewer.status = f"Error in 3D click handler: {str(e)}"
|
|
448
843
|
traceback.print_exc()
|
|
449
844
|
|
|
450
|
-
def
|
|
451
|
-
"""
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
self.label_info = {} # Reset label info
|
|
845
|
+
def _propagate_mask_for_current_object(self, obj_id, current_frame_idx):
|
|
846
|
+
"""
|
|
847
|
+
Propagate the mask for the current object from the given frame to all other frames.
|
|
848
|
+
Uses SAM2's video propagation with proper error handling.
|
|
455
849
|
|
|
456
|
-
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
|
|
850
|
+
Parameters:
|
|
851
|
+
obj_id: The ID of the object to propagate
|
|
852
|
+
current_frame_idx: The frame index where the object was identified
|
|
853
|
+
"""
|
|
854
|
+
try:
|
|
855
|
+
if not hasattr(self, "_sam2_state") or self._sam2_state is None:
|
|
856
|
+
self.viewer.status = (
|
|
857
|
+
"SAM2 3D state not initialized for propagation"
|
|
858
|
+
)
|
|
859
|
+
return
|
|
460
860
|
|
|
461
|
-
|
|
462
|
-
|
|
463
|
-
y_indices, x_indices = np.where(mask)
|
|
464
|
-
center_y = np.mean(y_indices) if len(y_indices) > 0 else 0
|
|
465
|
-
center_x = np.mean(x_indices) if len(x_indices) > 0 else 0
|
|
861
|
+
total_frames = self.segmentation_result.shape[0]
|
|
862
|
+
self.viewer.status = f"Propagating object {obj_id} through all {total_frames} frames..."
|
|
466
863
|
|
|
467
|
-
#
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
"
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
}
|
|
864
|
+
# Create a progress layer for visualization
|
|
865
|
+
progress_layer = None
|
|
866
|
+
for layer in list(self.viewer.layers):
|
|
867
|
+
if "Propagation Progress" in layer.name:
|
|
868
|
+
progress_layer = layer
|
|
869
|
+
break
|
|
474
870
|
|
|
475
|
-
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
|
|
480
|
-
|
|
871
|
+
if progress_layer is None:
|
|
872
|
+
progress_data = np.zeros_like(
|
|
873
|
+
self.segmentation_result, dtype=float
|
|
874
|
+
)
|
|
875
|
+
progress_layer = self.viewer.add_image(
|
|
876
|
+
progress_data,
|
|
877
|
+
name="Propagation Progress",
|
|
878
|
+
colormap="magma",
|
|
879
|
+
opacity=0.3,
|
|
880
|
+
visible=True,
|
|
881
|
+
)
|
|
882
|
+
|
|
883
|
+
# Update current frame in the progress layer
|
|
884
|
+
progress_data = progress_layer.data
|
|
885
|
+
current_mask = (
|
|
886
|
+
self.segmentation_result[current_frame_idx] == obj_id
|
|
481
887
|
)
|
|
482
|
-
|
|
888
|
+
progress_data[current_frame_idx] = current_mask.astype(float) * 0.8
|
|
889
|
+
progress_layer.data = progress_data
|
|
890
|
+
|
|
891
|
+
# Try to perform SAM2 propagation with error handling
|
|
892
|
+
try:
|
|
893
|
+
# Use torch.inference_mode() and torch.autocast to ensure consistent dtypes
|
|
894
|
+
with torch.inference_mode(), torch.autocast(
|
|
895
|
+
"cuda", dtype=torch.float32
|
|
896
|
+
):
|
|
897
|
+
# Attempt to run SAM2 propagation - this will iterate through all frames
|
|
898
|
+
for (
|
|
899
|
+
frame_idx,
|
|
900
|
+
object_ids,
|
|
901
|
+
mask_logits,
|
|
902
|
+
) in self.predictor.propagate_in_video(self._sam2_state):
|
|
903
|
+
if frame_idx >= total_frames:
|
|
904
|
+
continue
|
|
905
|
+
|
|
906
|
+
# Find our object ID in the results
|
|
907
|
+
# obj_mask = None
|
|
908
|
+
for i, prop_obj_id in enumerate(object_ids):
|
|
909
|
+
if prop_obj_id == obj_id:
|
|
910
|
+
# Get the mask for our object
|
|
911
|
+
mask = (mask_logits[i] > 0.0).cpu().numpy()
|
|
912
|
+
|
|
913
|
+
# Fix dimensions if needed
|
|
914
|
+
if mask.ndim > 2:
|
|
915
|
+
mask = mask.squeeze()
|
|
916
|
+
|
|
917
|
+
# Resize if needed
|
|
918
|
+
if (
|
|
919
|
+
mask.shape
|
|
920
|
+
!= self.segmentation_result[
|
|
921
|
+
frame_idx
|
|
922
|
+
].shape
|
|
923
|
+
):
|
|
924
|
+
from skimage.transform import resize
|
|
925
|
+
|
|
926
|
+
mask = resize(
|
|
927
|
+
mask.astype(float),
|
|
928
|
+
self.segmentation_result[
|
|
929
|
+
frame_idx
|
|
930
|
+
].shape,
|
|
931
|
+
order=0,
|
|
932
|
+
preserve_range=True,
|
|
933
|
+
anti_aliasing=False,
|
|
934
|
+
).astype(bool)
|
|
935
|
+
|
|
936
|
+
# Update segmentation - only replacing background pixels
|
|
937
|
+
self.segmentation_result[frame_idx][
|
|
938
|
+
mask
|
|
939
|
+
& (
|
|
940
|
+
self.segmentation_result[frame_idx]
|
|
941
|
+
== 0
|
|
942
|
+
)
|
|
943
|
+
] = obj_id
|
|
944
|
+
|
|
945
|
+
# Update progress visualization
|
|
946
|
+
progress_data = progress_layer.data
|
|
947
|
+
progress_data[frame_idx] = (
|
|
948
|
+
mask.astype(float) * 0.8
|
|
949
|
+
)
|
|
950
|
+
progress_layer.data = progress_data
|
|
951
|
+
|
|
952
|
+
# Update status occasionally
|
|
953
|
+
if frame_idx % 10 == 0:
|
|
954
|
+
self.viewer.status = f"Propagating: frame {frame_idx+1}/{total_frames}"
|
|
955
|
+
|
|
956
|
+
except RuntimeError as e:
|
|
957
|
+
# If we get a dtype mismatch or other error, the current frame's mask to other frames
|
|
958
|
+
self.viewer.status = f"SAM2 propagation failed with error: {str(e)}. Falling back to alternative method."
|
|
959
|
+
|
|
960
|
+
# Use the current frame's mask for propagation
|
|
961
|
+
for frame_idx in range(total_frames):
|
|
962
|
+
if (
|
|
963
|
+
frame_idx != current_frame_idx
|
|
964
|
+
): # Skip current frame as it's already done
|
|
965
|
+
# Only replace background pixels with the current frame's object
|
|
966
|
+
self.segmentation_result[frame_idx][
|
|
967
|
+
current_mask
|
|
968
|
+
& (self.segmentation_result[frame_idx] == 0)
|
|
969
|
+
] = obj_id
|
|
970
|
+
|
|
971
|
+
# Update progress layer
|
|
972
|
+
progress_data = progress_layer.data
|
|
973
|
+
progress_data[frame_idx] = (
|
|
974
|
+
current_mask.astype(float) * 0.5
|
|
975
|
+
) # Different intensity to indicate fallback
|
|
976
|
+
progress_layer.data = progress_data
|
|
977
|
+
|
|
978
|
+
# Update status occasionally
|
|
979
|
+
if frame_idx % 10 == 0:
|
|
980
|
+
self.viewer.status = f"Fallback propagation: frame {frame_idx+1}/{total_frames}"
|
|
981
|
+
|
|
982
|
+
# Remove progress layer after 2 seconds
|
|
983
|
+
import threading
|
|
984
|
+
|
|
985
|
+
def remove_progress():
|
|
986
|
+
import time
|
|
987
|
+
|
|
988
|
+
time.sleep(2)
|
|
989
|
+
for layer in list(self.viewer.layers):
|
|
990
|
+
if "Propagation Progress" in layer.name:
|
|
991
|
+
self.viewer.layers.remove(layer)
|
|
483
992
|
|
|
484
|
-
|
|
485
|
-
self.segmentation_result = labels
|
|
993
|
+
threading.Thread(target=remove_progress).start()
|
|
486
994
|
|
|
487
|
-
|
|
488
|
-
for layer in list(self.viewer.layers):
|
|
489
|
-
if isinstance(layer, Labels) and "Segmentation" in layer.name:
|
|
490
|
-
self.viewer.layers.remove(layer)
|
|
995
|
+
self.viewer.status = f"Propagation of object {obj_id} complete"
|
|
491
996
|
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
|
|
997
|
+
except (
|
|
998
|
+
IndexError,
|
|
999
|
+
ValueError,
|
|
1000
|
+
RuntimeError,
|
|
1001
|
+
torch.cuda.OutOfMemoryError,
|
|
1002
|
+
TypeError,
|
|
1003
|
+
) as e:
|
|
1004
|
+
import traceback
|
|
498
1005
|
|
|
499
|
-
|
|
500
|
-
|
|
1006
|
+
self.viewer.status = f"Error in propagation: {str(e)}"
|
|
1007
|
+
traceback.print_exc()
|
|
501
1008
|
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
):
|
|
508
|
-
|
|
509
|
-
for callback in list(self.label_layer.mouse_drag_callbacks):
|
|
510
|
-
self.label_layer.mouse_drag_callbacks.remove(callback)
|
|
511
|
-
|
|
512
|
-
# Connect mouse click event to label selection
|
|
513
|
-
self.label_layer.mouse_drag_callbacks.append(self._on_label_clicked)
|
|
514
|
-
|
|
515
|
-
# image_name = os.path.basename(self.images[self.current_index])
|
|
516
|
-
self.viewer.status = f"Loaded image {self.current_index + 1}/{len(self.images)} - Found {len(masks)} segments"
|
|
517
|
-
|
|
518
|
-
# New method for handling scaled segmentation masks
|
|
519
|
-
def _process_segmentation_masks_with_scaling(self, masks, original_shape):
|
|
520
|
-
"""Process segmentation masks with scaling to match the original image size."""
|
|
521
|
-
# Create label image from masks
|
|
522
|
-
# First determine the size of the mask predictions (which are at the downscaled resolution)
|
|
523
|
-
if not masks:
|
|
1009
|
+
def _add_3d_prompt(self, prompt_coords):
|
|
1010
|
+
"""
|
|
1011
|
+
Given a 3D coordinate (x, y, z), run SAM2 video predictor to segment the object at that point,
|
|
1012
|
+
update the segmentation result and label layer.
|
|
1013
|
+
"""
|
|
1014
|
+
if not hasattr(self, "_sam2_state") or self._sam2_state is None:
|
|
1015
|
+
self.viewer.status = "SAM2 3D state not initialized."
|
|
524
1016
|
return
|
|
525
1017
|
|
|
526
|
-
|
|
527
|
-
|
|
528
|
-
|
|
529
|
-
downscaled_labels = np.zeros(mask_shape, dtype=np.uint32)
|
|
530
|
-
self.label_info = {} # Reset label info
|
|
531
|
-
|
|
532
|
-
# Fill in the downscaled labels
|
|
533
|
-
for i, mask_data in enumerate(masks):
|
|
534
|
-
mask = mask_data["segmentation"]
|
|
535
|
-
label_id = i + 1 # Start label IDs from 1
|
|
536
|
-
downscaled_labels[mask] = label_id
|
|
537
|
-
|
|
538
|
-
# Store basic label info
|
|
539
|
-
area = np.sum(mask)
|
|
540
|
-
y_indices, x_indices = np.where(mask)
|
|
541
|
-
center_y = np.mean(y_indices) if len(y_indices) > 0 else 0
|
|
542
|
-
center_x = np.mean(x_indices) if len(x_indices) > 0 else 0
|
|
543
|
-
|
|
544
|
-
# Scale centers to original image coordinates
|
|
545
|
-
center_y_orig = center_y / self.current_scale_factor
|
|
546
|
-
center_x_orig = center_x / self.current_scale_factor
|
|
1018
|
+
if self.predictor is None:
|
|
1019
|
+
self.viewer.status = "SAM2 predictor not initialized."
|
|
1020
|
+
return
|
|
547
1021
|
|
|
548
|
-
|
|
549
|
-
|
|
550
|
-
|
|
551
|
-
|
|
552
|
-
self.current_scale_factor**2
|
|
553
|
-
), # Approximate area in original scale
|
|
554
|
-
"center_y": center_y_orig,
|
|
555
|
-
"center_x": center_x_orig,
|
|
556
|
-
"score": mask_data.get("stability_score", 0),
|
|
557
|
-
}
|
|
1022
|
+
# Prepare prompt for SAM2: point_coords is [[x, y, t]], point_labels is [1]
|
|
1023
|
+
x, y, z = prompt_coords
|
|
1024
|
+
point_coords = np.array([[x, y, z]])
|
|
1025
|
+
point_labels = np.array([1]) # 1 = foreground
|
|
558
1026
|
|
|
559
|
-
|
|
560
|
-
|
|
561
|
-
|
|
562
|
-
|
|
563
|
-
|
|
564
|
-
|
|
565
|
-
|
|
566
|
-
|
|
1027
|
+
with torch.inference_mode(), torch.autocast(
|
|
1028
|
+
"cuda", dtype=torch.bfloat16
|
|
1029
|
+
):
|
|
1030
|
+
masks, scores, _ = self.predictor.predict(
|
|
1031
|
+
state=self._sam2_state,
|
|
1032
|
+
point_coords=point_coords,
|
|
1033
|
+
point_labels=point_labels,
|
|
1034
|
+
multimask_output=True,
|
|
1035
|
+
)
|
|
567
1036
|
|
|
568
|
-
#
|
|
569
|
-
|
|
570
|
-
|
|
571
|
-
|
|
572
|
-
|
|
573
|
-
|
|
1037
|
+
# Pick the best mask (highest score)
|
|
1038
|
+
if masks is not None and len(masks) > 0:
|
|
1039
|
+
best_idx = np.argmax(scores)
|
|
1040
|
+
mask = masks[best_idx]
|
|
1041
|
+
obj_id = self._sam2_next_obj_id
|
|
1042
|
+
self.segmentation_result[mask] = obj_id
|
|
1043
|
+
self._sam2_next_obj_id += 1
|
|
1044
|
+
self.viewer.status = (
|
|
1045
|
+
f"Added object {obj_id} at (x={x}, y={y}, z={z})"
|
|
574
1046
|
)
|
|
575
|
-
|
|
1047
|
+
self._update_label_layer()
|
|
1048
|
+
else:
|
|
1049
|
+
self.viewer.status = "No mask found for this prompt."
|
|
1050
|
+
|
|
1051
|
+
def on_apply_propagate(self):
|
|
1052
|
+
"""Propagate masks across the video and update the segmentation layer."""
|
|
1053
|
+
self.viewer.status = "Propagating masks across all frames..."
|
|
1054
|
+
self.viewer.window._qt_window.setCursor(Qt.WaitCursor)
|
|
1055
|
+
|
|
1056
|
+
self.segmentation_result[:] = 0
|
|
1057
|
+
|
|
1058
|
+
for (
|
|
1059
|
+
frame_idx,
|
|
1060
|
+
object_ids,
|
|
1061
|
+
mask_logits,
|
|
1062
|
+
) in self.predictor.propagate_in_video(self._sam2_state):
|
|
1063
|
+
masks = (mask_logits > 0.0).cpu().numpy()
|
|
1064
|
+
if frame_idx >= self.segmentation_result.shape[0]:
|
|
1065
|
+
print(
|
|
1066
|
+
f"Warning: frame_idx {frame_idx} out of bounds for segmentation_result with shape {self.segmentation_result.shape}"
|
|
1067
|
+
)
|
|
1068
|
+
continue
|
|
1069
|
+
for i, obj_id in enumerate(object_ids):
|
|
1070
|
+
self.segmentation_result[frame_idx][masks[i]] = obj_id
|
|
1071
|
+
self.viewer.status = f"Propagating: frame {frame_idx+1}"
|
|
576
1072
|
|
|
577
|
-
|
|
578
|
-
self.
|
|
1073
|
+
self._update_label_layer()
|
|
1074
|
+
self.viewer.status = "Propagation complete!"
|
|
1075
|
+
self.viewer.window._qt_window.setCursor(Qt.ArrowCursor)
|
|
579
1076
|
|
|
580
|
-
|
|
1077
|
+
def _update_label_layer(self):
|
|
1078
|
+
"""Update the label layer in the viewer."""
|
|
1079
|
+
# Remove existing label layer if it exists
|
|
581
1080
|
for layer in list(self.viewer.layers):
|
|
582
1081
|
if isinstance(layer, Labels) and "Segmentation" in layer.name:
|
|
583
1082
|
self.viewer.layers.remove(layer)
|
|
584
1083
|
|
|
585
1084
|
# Add label layer to viewer
|
|
586
1085
|
self.label_layer = self.viewer.add_labels(
|
|
587
|
-
|
|
1086
|
+
self.segmentation_result,
|
|
588
1087
|
name=f"Segmentation ({os.path.basename(self.images[self.current_index])})",
|
|
589
1088
|
opacity=0.7,
|
|
590
1089
|
)
|
|
591
1090
|
|
|
592
|
-
#
|
|
593
|
-
|
|
594
|
-
|
|
595
|
-
|
|
596
|
-
|
|
597
|
-
|
|
598
|
-
|
|
599
|
-
|
|
600
|
-
|
|
601
|
-
|
|
602
|
-
|
|
603
|
-
|
|
1091
|
+
# Create points layer for interaction if it doesn't exist
|
|
1092
|
+
points_layer = None
|
|
1093
|
+
for layer in list(self.viewer.layers):
|
|
1094
|
+
if "Points" in layer.name:
|
|
1095
|
+
points_layer = layer
|
|
1096
|
+
break
|
|
1097
|
+
|
|
1098
|
+
if points_layer is None:
|
|
1099
|
+
# Initialize an empty points layer
|
|
1100
|
+
points_layer = self.viewer.add_points(
|
|
1101
|
+
np.zeros((0, 2 if not self.use_3d else 3)),
|
|
1102
|
+
name="Points (Click to Add)",
|
|
1103
|
+
size=10,
|
|
1104
|
+
face_color="green",
|
|
1105
|
+
border_color="white",
|
|
1106
|
+
border_width=1,
|
|
1107
|
+
opacity=0.8,
|
|
1108
|
+
)
|
|
604
1109
|
|
|
605
|
-
|
|
606
|
-
|
|
1110
|
+
# Connect points layer mouse click event
|
|
1111
|
+
points_layer.mouse_drag_callbacks.append(self._on_points_clicked)
|
|
607
1112
|
|
|
608
|
-
|
|
1113
|
+
# Make the points layer active to encourage interaction with it
|
|
1114
|
+
self.viewer.layers.selection.active = points_layer
|
|
609
1115
|
|
|
610
|
-
|
|
611
|
-
|
|
612
|
-
|
|
1116
|
+
# Update status
|
|
1117
|
+
n_labels = len(np.unique(self.segmentation_result)) - (
|
|
1118
|
+
1 if 0 in np.unique(self.segmentation_result) else 0
|
|
1119
|
+
)
|
|
1120
|
+
self.viewer.status = f"Loaded image {self.current_index + 1}/{len(self.images)} - Found {n_labels} segments"
|
|
613
1121
|
|
|
614
|
-
def
|
|
615
|
-
"""Handle
|
|
1122
|
+
def _on_points_clicked(self, layer, event):
|
|
1123
|
+
"""Handle clicks on the points layer for adding/removing points."""
|
|
616
1124
|
try:
|
|
617
1125
|
# Only process clicks, not drags
|
|
618
1126
|
if event.type != "mouse_press":
|
|
@@ -621,39 +1129,799 @@ class BatchCropAnything:
|
|
|
621
1129
|
# Get coordinates of mouse click
|
|
622
1130
|
coords = np.round(event.position).astype(int)
|
|
623
1131
|
|
|
624
|
-
#
|
|
625
|
-
|
|
626
|
-
if
|
|
627
|
-
|
|
628
|
-
|
|
629
|
-
|
|
630
|
-
|
|
631
|
-
|
|
632
|
-
|
|
1132
|
+
# Check if Shift is pressed for negative points
|
|
1133
|
+
is_negative = "Shift" in event.modifiers
|
|
1134
|
+
point_label = -1 if is_negative else 1
|
|
1135
|
+
|
|
1136
|
+
# Handle 2D vs 3D coordinates
|
|
1137
|
+
if self.use_3d:
|
|
1138
|
+
if len(coords) == 3:
|
|
1139
|
+
t, y, x = map(int, coords)
|
|
1140
|
+
elif len(coords) == 2:
|
|
1141
|
+
t = int(self.viewer.dims.current_step[0])
|
|
1142
|
+
y, x = map(int, coords)
|
|
1143
|
+
else:
|
|
1144
|
+
self.viewer.status = (
|
|
1145
|
+
f"Unexpected coordinate dimensions: {coords}"
|
|
1146
|
+
)
|
|
1147
|
+
return
|
|
1148
|
+
|
|
1149
|
+
# Add point to the layer immediately for visual feedback
|
|
1150
|
+
new_point = np.array([[t, y, x]])
|
|
1151
|
+
if len(layer.data) == 0:
|
|
1152
|
+
layer.data = new_point
|
|
1153
|
+
else:
|
|
1154
|
+
layer.data = np.vstack([layer.data, new_point])
|
|
1155
|
+
|
|
1156
|
+
# Update point colors
|
|
1157
|
+
colors = layer.face_color
|
|
1158
|
+
if isinstance(colors, list):
|
|
1159
|
+
colors.append("red" if is_negative else "green")
|
|
1160
|
+
else:
|
|
1161
|
+
n_points = len(layer.data)
|
|
1162
|
+
colors = ["green"] * (n_points - 1)
|
|
1163
|
+
colors.append("red" if is_negative else "green")
|
|
1164
|
+
layer.face_color = colors
|
|
1165
|
+
|
|
1166
|
+
# Get the object ID
|
|
1167
|
+
# If clicking on existing segmentation with negative point
|
|
1168
|
+
label_id = self.segmentation_result[t, y, x]
|
|
1169
|
+
if is_negative and label_id > 0:
|
|
1170
|
+
obj_id = label_id
|
|
1171
|
+
else:
|
|
1172
|
+
# For new objects or negative on background
|
|
1173
|
+
if not hasattr(self, "_sam2_next_obj_id"):
|
|
1174
|
+
self._sam2_next_obj_id = 1
|
|
1175
|
+
+                        obj_id = self._sam2_next_obj_id
+                        if point_label > 0 and label_id == 0:
+                            self._sam2_next_obj_id += 1
+
+                    # Store point information
+                    if not hasattr(self, "points_data"):
+                        self.points_data = {}
+                        self.points_labels = {}
+
+                    if obj_id not in self.points_data:
+                        self.points_data[obj_id] = []
+                        self.points_labels[obj_id] = []
+
+                    self.points_data[obj_id].append([x, y])  # Note: SAM2 expects [x,y] format
+                    self.points_labels[obj_id].append(point_label)
+
+                    # Perform segmentation
+                    if hasattr(self, "_sam2_state") and self._sam2_state is not None:
+                        # Prepare points
+                        points = np.array(self.points_data[obj_id], dtype=np.float32)
+                        labels = np.array(self.points_labels[obj_id], dtype=np.int32)
+
+                        # Create progress layer for visual feedback
+                        progress_layer = None
+                        for existing_layer in self.viewer.layers:
+                            if "Propagation Progress" in existing_layer.name:
+                                progress_layer = existing_layer
+                                break
+
+                        if progress_layer is None:
+                            progress_data = np.zeros_like(self.segmentation_result)
+                            progress_layer = self.viewer.add_image(
+                                progress_data,
+                                name="Propagation Progress",
+                                colormap="magma",
+                                opacity=0.5,
+                                visible=True,
+                            )
+
+                        # First update the current frame immediately
+                        self.viewer.status = f"Processing object at frame {t}..."
+
+                        # Run SAM2 on current frame
+                        _, out_obj_ids, out_mask_logits = self.predictor.add_new_points_or_box(
+                            inference_state=self._sam2_state,
+                            frame_idx=t,
+                            obj_id=obj_id,
+                            points=points,
+                            labels=labels,
+                        )
+
+                        # Update current frame
+                        mask = (out_mask_logits[0] > 0.0).cpu().numpy()
+                        if mask.ndim > 2:
+                            mask = mask.squeeze()
+
+                        # Resize if needed
+                        if mask.shape != self.segmentation_result[t].shape:
+                            from skimage.transform import resize
+
+                            mask = resize(
+                                mask.astype(float),
+                                self.segmentation_result[t].shape,
+                                order=0,
+                                preserve_range=True,
+                                anti_aliasing=False,
+                            ).astype(bool)
+
+                        # Update segmentation for this frame
+                        if point_label < 0:
+                            # For negative points, only remove from this object
+                            self.segmentation_result[t][(self.segmentation_result[t] == obj_id) & mask] = 0
+                        else:
+                            # For positive points, only replace background
+                            self.segmentation_result[t][mask & (self.segmentation_result[t] == 0)] = obj_id
+
+                        # Update progress layer for this frame
+                        progress_data = progress_layer.data
+                        progress_data[t] = mask.astype(float) * 0.5  # Highlight current frame
+                        progress_layer.data = progress_data
+
+                        # Now propagate to all frames with visual feedback
+                        self.viewer.status = "Propagating to all frames..."
+
+                        # Run propagation
+                        frame_count = self.segmentation_result.shape[0]
+                        for (
+                            frame_idx,
+                            prop_obj_ids,
+                            mask_logits,
+                        ) in self.predictor.propagate_in_video(self._sam2_state):
+                            if frame_idx >= frame_count:
+                                continue
+
+                            # Find our object
+                            obj_mask = None
+                            for i, prop_obj_id in enumerate(prop_obj_ids):
+                                if prop_obj_id == obj_id:
+                                    obj_mask = (mask_logits[i] > 0.0).cpu().numpy()
+                                    if obj_mask.ndim > 2:
+                                        obj_mask = obj_mask.squeeze()
+
+                                    # Resize if needed
+                                    if obj_mask.shape != self.segmentation_result[frame_idx].shape:
+                                        obj_mask = resize(
+                                            obj_mask.astype(float),
+                                            self.segmentation_result[frame_idx].shape,
+                                            order=0,
+                                            preserve_range=True,
+                                            anti_aliasing=False,
+                                        ).astype(bool)
+
+                                    # Update segmentation
+                                    self.segmentation_result[frame_idx][
+                                        obj_mask & (self.segmentation_result[frame_idx] == 0)
+                                    ] = obj_id
+
+                                    # Update progress visualization
+                                    progress_data = progress_layer.data
+                                    progress_data[frame_idx] = obj_mask.astype(float) * 0.8  # Show as processed
+                                    progress_layer.data = progress_data
+
+                            # Update status
+                            if frame_idx % 5 == 0:
+                                self.viewer.status = f"Propagating: frame {frame_idx+1}/{frame_count}"
+                                # Remove the viewer.update() call as it's causing errors
+
+                        # Process any missing frames
+                        processed_frames = set(range(frame_count))
+                        for frame_idx in range(frame_count):
+                            if progress_data[frame_idx].max() == 0:  # Frame not processed yet
+                                # Use nearest processed frame's mask
+                                nearest_idx = min(
+                                    processed_frames,
+                                    key=lambda x: abs(x - frame_idx),
+                                )
+                                if progress_data[nearest_idx].max() > 0:
+                                    self.segmentation_result[frame_idx][
+                                        (self.segmentation_result[frame_idx] == 0)
+                                        & (self.segmentation_result[nearest_idx] == obj_id)
+                                    ] = obj_id
+
+                                    # Update progress visualization
+                                    progress_data[frame_idx] = progress_data[nearest_idx] * 0.6  # Mark as copied
+
+                        # Final update of progress layer
+                        progress_layer.data = progress_data
+
+                        # Remove progress layer after 2 seconds
+                        import threading
+
+                        def remove_progress():
+                            import time
+
+                            time.sleep(2)
+                            for layer in list(self.viewer.layers):
+                                if "Propagation Progress" in layer.name:
+                                    self.viewer.layers.remove(layer)
+
+                        threading.Thread(target=remove_progress).start()
+
+                    # Update UI
+                    self._update_label_layer()
+                    if hasattr(self, "label_table_widget") and self.label_table_widget is not None:
+                        self._populate_label_table(self.label_table_widget)
 
-
-            label_id = self.segmentation_result[coords[0], coords[1]]
+                    self.viewer.status = f"Object {obj_id} segmented and propagated to all frames"
 
-
-
+            else:
+                # 2D case
+                if len(coords) == 2:
+                    y, x = map(int, coords)
+                else:
+                    self.viewer.status = f"Unexpected coordinate dimensions: {coords}"
+                    return
+
+                # Add point to the layer immediately for visual feedback
+                new_point = np.array([[y, x]])
+                if len(layer.data) == 0:
+                    layer.data = new_point
+                else:
+                    layer.data = np.vstack([layer.data, new_point])
+
+                # Update point colors
+                colors = layer.face_color
+                if isinstance(colors, list):
+                    colors.append("red" if is_negative else "green")
+                else:
+                    n_points = len(layer.data)
+                    colors = ["green"] * (n_points - 1)
+                    colors.append("red" if is_negative else "green")
+                layer.face_color = colors
+
+                # Get object ID
+                label_id = self.segmentation_result[y, x]
+                if is_negative and label_id > 0:
+                    obj_id = label_id
+                else:
+                    if not hasattr(self, "next_obj_id"):
+                        self.next_obj_id = 1
+                    obj_id = self.next_obj_id
+                    if point_label > 0 and label_id == 0:
+                        self.next_obj_id += 1
+
+                # Store point information
+                if not hasattr(self, "obj_points"):
+                    self.obj_points = {}
+                    self.obj_labels = {}
+
+                if obj_id not in self.obj_points:
+                    self.obj_points[obj_id] = []
+                    self.obj_labels[obj_id] = []
+
+                self.obj_points[obj_id].append([x, y])  # SAM2 expects [x,y] format
+                self.obj_labels[obj_id].append(point_label)
+
+                # Perform segmentation
+                if hasattr(self, "predictor") and self.predictor is not None:
+                    # Make sure image is loaded
+                    if self.current_image_for_segmentation is None:
+                        self.viewer.status = "No image loaded for segmentation"
+                        return
+
+                    # Prepare image for SAM2
+                    image = self.current_image_for_segmentation
+                    if len(image.shape) == 2:
+                        image = np.stack([image] * 3, axis=-1)
+                    elif len(image.shape) == 3 and image.shape[2] == 1:
+                        image = np.concatenate([image] * 3, axis=2)
+                    elif len(image.shape) == 3 and image.shape[2] > 3:
+                        image = image[:, :, :3]
+
+                    if image.dtype != np.uint8:
+                        image = (image / np.max(image) * 255).astype(np.uint8)
+
+                    # Set the image in the predictor
+                    self.predictor.set_image(image)
+
+                    # Use only points for current object
+                    points = np.array(self.obj_points[obj_id], dtype=np.float32)
+                    labels = np.array(self.obj_labels[obj_id], dtype=np.int32)
+
+                    self.viewer.status = f"Segmenting object {obj_id} with {len(points)} points..."
+
+                    with torch.inference_mode(), torch.autocast("cuda"):
+                        masks, scores, _ = self.predictor.predict(
+                            point_coords=points,
+                            point_labels=labels,
+                            multimask_output=True,
+                        )
+
+                    # Get best mask
+                    if len(masks) > 0:
+                        best_mask = masks[0]
+
+                        # Update segmentation result
+                        if best_mask.shape != self.segmentation_result.shape:
+                            from skimage.transform import resize
+
+                            best_mask = resize(
+                                best_mask.astype(float),
+                                self.segmentation_result.shape,
+                                order=0,
+                                preserve_range=True,
+                                anti_aliasing=False,
+                            ).astype(bool)
+
+                        # Apply mask based on point type
+                        if point_label < 0:
+                            # Remove only from current object
+                            mask_condition = np.logical_and(
+                                self.segmentation_result == obj_id,
+                                best_mask,
+                            )
+                            self.segmentation_result[mask_condition] = 0
+                        else:
+                            # Add to current object (only overwrite background)
+                            mask_condition = np.logical_and(
+                                best_mask, (self.segmentation_result == 0)
+                            )
+                            self.segmentation_result[mask_condition] = obj_id
+
+                        # Update label info
+                        area = np.sum(self.segmentation_result == obj_id)
+                        y_indices, x_indices = np.where(self.segmentation_result == obj_id)
+                        center_y = np.mean(y_indices) if len(y_indices) > 0 else 0
+                        center_x = np.mean(x_indices) if len(x_indices) > 0 else 0
+
+                        self.label_info[obj_id] = {
+                            "area": area,
+                            "center_y": center_y,
+                            "center_x": center_x,
+                            "score": float(scores[0]),
+                        }
+
+                        self.viewer.status = f"Updated object {obj_id}"
+                    else:
+                        self.viewer.status = "No valid mask produced"
+
+                # Update the UI
+                self._update_label_layer()
+                if hasattr(self, "label_table_widget") and self.label_table_widget is not None:
+                    self._populate_label_table(self.label_table_widget)
+
+        except (
+            IndexError,
+            KeyError,
+            ValueError,
+            RuntimeError,
+            TypeError,
+        ) as e:
+            import traceback
+
+            self.viewer.status = f"Error in points handling: {str(e)}"
+            traceback.print_exc()
+
+    def _on_label_clicked(self, layer, event):
+        """Handle label selection and user prompts on mouse click."""
+        try:
+            # Only process clicks, not drags
+            if event.type != "mouse_press":
                 return
 
-            #
-
-
-
+            # Get coordinates of mouse click
+            coords = np.round(event.position).astype(int)
+
+            # Check if Shift is pressed (negative point)
+            is_negative = "Shift" in event.modifiers
+            point_label = -1 if is_negative else 1
+
+            # For 2D data
+            if not self.use_3d:
+                if len(coords) == 2:
+                    y, x = map(int, coords)
+                else:
+                    self.viewer.status = f"Unexpected coordinate dimensions: {coords}"
+                    return
+
+                # Check if within image bounds
+                shape = self.segmentation_result.shape
+                if y < 0 or x < 0 or y >= shape[0] or x >= shape[1]:
+                    self.viewer.status = "Click is outside image bounds"
+                    return
+
+                # Get the label ID at the clicked position
+                label_id = self.segmentation_result[y, x]
+
+                # Initialize a unique object ID for this click (if needed)
+                if not hasattr(self, "next_obj_id"):
+                    # Start with highest existing ID + 1
+                    if self.segmentation_result.max() > 0:
+                        self.next_obj_id = int(self.segmentation_result.max()) + 1
+                    else:
+                        self.next_obj_id = 1
+
+                # If clicking on background or using negative click, handle segmentation
+                if label_id == 0 or is_negative:
+                    # Find or create points layer for the current object we're working on
+                    current_obj_id = None
+
+                    # If negative point on existing label, use that label's ID
+                    if is_negative and label_id > 0:
+                        current_obj_id = label_id
+                    # For positive clicks on background, create a new object
+                    elif point_label > 0 and label_id == 0:
+                        current_obj_id = self.next_obj_id
+                        self.next_obj_id += 1
+                    # For negative on background, try to find most recent object
+                    elif point_label < 0 and label_id == 0:
+                        # Use most recently created object if available
+                        if hasattr(self, "obj_points") and self.obj_points:
+                            current_obj_id = max(self.obj_points.keys())
+                        else:
+                            self.viewer.status = "No existing object to modify with negative point"
+                            return
+
+                    if current_obj_id is None:
+                        self.viewer.status = "Could not determine which object to modify"
+                        return
+
+                    # Find or create points layer for this object
+                    points_layer = None
+                    for layer in list(self.viewer.layers):
+                        if f"Points for Object {current_obj_id}" in layer.name:
+                            points_layer = layer
+                            break
+
+                    # Initialize object tracking if needed
+                    if not hasattr(self, "obj_points"):
+                        self.obj_points = {}
+                        self.obj_labels = {}
+
+                    if current_obj_id not in self.obj_points:
+                        self.obj_points[current_obj_id] = []
+                        self.obj_labels[current_obj_id] = []
+
+                    # Create or update points layer for this object
+                    if points_layer is None:
+                        # First point for this object
+                        points_layer = self.viewer.add_points(
+                            np.array([[y, x]]),
+                            name=f"Points for Object {current_obj_id}",
+                            size=10,
+                            face_color=["green" if point_label > 0 else "red"],
+                            border_color="white",
+                            border_width=1,
+                            opacity=0.8,
+                        )
+                        self.obj_points[current_obj_id] = [[x, y]]
+                        self.obj_labels[current_obj_id] = [point_label]
+                    else:
+                        # Add point to existing layer
+                        current_points = points_layer.data
+                        current_colors = points_layer.face_color
+
+                        # Add new point
+                        new_points = np.vstack([current_points, [y, x]])
+                        new_color = "green" if point_label > 0 else "red"
+
+                        # Update points layer
+                        points_layer.data = new_points
+
+                        # Update colors
+                        if isinstance(current_colors, list):
+                            current_colors.append(new_color)
+                            points_layer.face_color = current_colors
+                        else:
+                            # If it's an array, create a list of colors
+                            colors = []
+                            for i in range(len(new_points)):
+                                if i < len(current_points):
+                                    colors.append("green" if point_label > 0 else "red")
+                                else:
+                                    colors.append(new_color)
+                            points_layer.face_color = colors
+
+                        # Update object tracking
+                        self.obj_points[current_obj_id].append([x, y])
+                        self.obj_labels[current_obj_id].append(point_label)
+
+                    # Now do the actual segmentation using SAM2
+                    if hasattr(self, "predictor") and self.predictor is not None:
+                        try:
+                            # Make sure image is loaded
+                            if self.current_image_for_segmentation is None:
+                                self.viewer.status = "No image loaded for segmentation"
+                                return
+
+                            # Prepare image for SAM2
+                            image = self.current_image_for_segmentation
+                            if len(image.shape) == 2:
+                                image = np.stack([image] * 3, axis=-1)
+                            elif len(image.shape) == 3 and image.shape[2] == 1:
+                                image = np.concatenate([image] * 3, axis=2)
+                            elif len(image.shape) == 3 and image.shape[2] > 3:
+                                image = image[:, :, :3]
+
+                            if image.dtype != np.uint8:
+                                image = (image / np.max(image) * 255).astype(np.uint8)
+
+                            # Set the image in the predictor
+                            self.predictor.set_image(image)
+
+                            # Only use the points for the current object being segmented
+                            points = np.array(self.obj_points[current_obj_id], dtype=np.float32)
+                            labels = np.array(self.obj_labels[current_obj_id], dtype=np.int32)
+
+                            self.viewer.status = f"Segmenting object {current_obj_id} with {len(points)} points..."
+
+                            with torch.inference_mode(), torch.autocast("cuda"):
+                                masks, scores, _ = self.predictor.predict(
+                                    point_coords=points,
+                                    point_labels=labels,
+                                    multimask_output=True,
+                                )
+
+                            # Get best mask
+                            if len(masks) > 0:
+                                best_mask = masks[0]
+
+                                # Update segmentation result
+                                if best_mask.shape != self.segmentation_result.shape:
+                                    from skimage.transform import resize
+
+                                    best_mask = resize(
+                                        best_mask.astype(float),
+                                        self.segmentation_result.shape,
+                                        order=0,
+                                        preserve_range=True,
+                                        anti_aliasing=False,
+                                    ).astype(bool)
+
+                                # CRITICAL FIX: For negative points, only remove from this object's mask
+                                # For positive points, add to this object's mask without removing other objects
+                                if point_label < 0:
+                                    # Remove only from current object's mask
+                                    self.segmentation_result[
+                                        (self.segmentation_result == current_obj_id) & best_mask
+                                    ] = 0
+                                else:
+                                    # Add to current object's mask without affecting other objects
+                                    # Only overwrite background (value 0)
+                                    self.segmentation_result[
+                                        best_mask & (self.segmentation_result == 0)
+                                    ] = current_obj_id
+
+                                # Update label info
+                                area = np.sum(self.segmentation_result == current_obj_id)
+                                y_indices, x_indices = np.where(self.segmentation_result == current_obj_id)
+                                center_y = np.mean(y_indices) if len(y_indices) > 0 else 0
+                                center_x = np.mean(x_indices) if len(x_indices) > 0 else 0
+
+                                self.label_info[current_obj_id] = {
+                                    "area": area,
+                                    "center_y": center_y,
+                                    "center_x": center_x,
+                                    "score": float(scores[0]),
+                                }
+
+                                self.viewer.status = f"Updated object {current_obj_id}"
+                            else:
+                                self.viewer.status = "No valid mask produced"
+
+                            # Update the UI
+                            self._update_label_layer()
+                            if hasattr(self, "label_table_widget") and self.label_table_widget is not None:
+                                self._populate_label_table(self.label_table_widget)
+
+                        except (
+                            IndexError,
+                            KeyError,
+                            ValueError,
+                            AttributeError,
+                            TypeError,
+                        ) as e:
+                            import traceback
+
+                            self.viewer.status = f"Error in SAM2 processing: {str(e)}"
+                            traceback.print_exc()
+
+                # If clicking on an existing label, toggle selection
+                elif label_id > 0:
+                    # Toggle the label selection
+                    if label_id in self.selected_labels:
+                        self.selected_labels.remove(label_id)
+                        self.viewer.status = f"Deselected label ID: {label_id} | Selected labels: {self.selected_labels}"
+                    else:
+                        self.selected_labels.add(label_id)
+                        self.viewer.status = f"Selected label ID: {label_id} | Selected labels: {self.selected_labels}"
+
+                    # Update table and preview
+                    self._update_label_table()
+                    self.preview_crop()
+
+            # 3D case (handle differently)
             else:
-
-
+                if len(coords) == 3:
+                    t, y, x = map(int, coords)
+                elif len(coords) == 2:
+                    t = int(self.viewer.dims.current_step[0])
+                    y, x = map(int, coords)
+                else:
+                    self.viewer.status = f"Unexpected coordinate dimensions: {coords}"
+                    return
+
+                # Check if within bounds
+                shape = self.segmentation_result.shape
+                if (
+                    t < 0
+                    or t >= shape[0]
+                    or y < 0
+                    or y >= shape[1]
+                    or x < 0
+                    or x >= shape[2]
+                ):
+                    self.viewer.status = "Click is outside volume bounds"
+                    return
+
+                # Get the label ID at the clicked position
+                label_id = self.segmentation_result[t, y, x]
+
+                # If background or shift is pressed, handle in _on_3d_label_clicked
+                if label_id == 0 or is_negative:
+                    # This will be handled by _on_3d_label_clicked already attached
+                    pass
+                # If clicking on an existing label, handle selection
+                elif label_id > 0:
+                    # Toggle the label selection
+                    if label_id in self.selected_labels:
+                        self.selected_labels.remove(label_id)
+                        self.viewer.status = f"Deselected label ID: {label_id} | Selected labels: {self.selected_labels}"
+                    else:
+                        self.selected_labels.add(label_id)
+                        self.viewer.status = f"Selected label ID: {label_id} | Selected labels: {self.selected_labels}"
 
-
-
+                    # Update table if it exists
+                    self._update_label_table()
 
-
-
+                    # Update preview after selection changes
+                    self.preview_crop()
 
-        except (
-
+        except (
+            IndexError,
+            KeyError,
+            ValueError,
+            AttributeError,
+            TypeError,
+        ) as e:
+            import traceback
+
+            self.viewer.status = f"Error in click handling: {str(e)}"
+            traceback.print_exc()
+
+    def _add_point_marker(self, coords, label_type):
+        """Add a visible marker for where the user clicked."""
+        # Remove previous point markers
+        for layer in list(self.viewer.layers):
+            if "Point Prompt" in layer.name:
+                self.viewer.layers.remove(layer)
+
+        # Create points layer
+        color = "red" if label_type < 0 else "green"  # Red for negative, green for positive
+        self.viewer.add_points(
+            [coords],
+            name="Point Prompt",
+            size=10,
+            face_color=color,
+            edge_color="white",
+            edge_width=2,
+            opacity=0.8,
+        )
 
     def create_label_table(self, parent_widget):
         """Create a table widget displaying all detected labels."""
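
The interactive 2D path added above reduces to a small SAM2 recipe: set the image on the predictor, pass the accumulated point prompts for one object, and merge the best mask into an integer label image without overwriting other objects. The following is a minimal, hedged sketch of that recipe only; `predictor` is assumed to be an already-built SAM2 image predictor (as in the plugin), and the function and variable names are illustrative rather than part of the plugin API.

import numpy as np
import torch


def apply_point_prompt(predictor, image, points_xy, point_labels, label_image, obj_id):
    # Ensure a 3-channel uint8 image, as the plugin does before calling SAM2
    if image.ndim == 2:
        image = np.stack([image] * 3, axis=-1)
    if image.dtype != np.uint8:
        image = (image / np.max(image) * 255).astype(np.uint8)

    predictor.set_image(image)
    with torch.inference_mode():
        masks, scores, _ = predictor.predict(
            point_coords=np.asarray(points_xy, dtype=np.float32),   # [[x, y], ...]
            point_labels=np.asarray(point_labels, dtype=np.int32),  # 1 = positive (the plugin uses -1 for negative)
            multimask_output=True,
        )
    if len(masks) == 0:
        return label_image, None

    best_mask = masks[0].astype(bool)
    # Only claim background pixels so previously segmented objects are preserved
    label_image[best_mask & (label_image == 0)] = obj_id
    return label_image, float(scores[0])
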
@@ -694,57 +1962,86 @@ class BatchCropAnything:
 
     def _populate_label_table(self, table):
         """Populate the table with label information."""
-
-
-
+        try:
+            # Get all unique non-zero labels from the segmentation result safely
+            if self.segmentation_result is None:
+                # No segmentation yet
+                table.setRowCount(0)
+                self.viewer.status = "No segmentation available"
+                return
 
-
-
+            # Get unique labels, safely handling None values
+            unique_labels = []
+            for val in np.unique(self.segmentation_result):
+                if val is not None and val > 0:
+                    unique_labels.append(val)
 
-
-
-
-
-                reverse=True,
-            )
+            if len(unique_labels) == 0:
+                table.setRowCount(0)
+                self.viewer.status = "No labeled objects found"
+                return
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            # Set row count
+            table.setRowCount(len(unique_labels))
+
+            # Fill in label info for any missing labels
+            for label_id in unique_labels:
+                if label_id not in self.label_info:
+                    # Calculate basic info for this label
+                    mask = self.segmentation_result == label_id
+                    area = np.sum(mask)
+
+                    # Add info to label_info dictionary
+                    self.label_info[label_id] = {
+                        "area": area,
+                        "score": 1.0,  # Default score
+                    }
+
+            # Fill table with data
+            for row, label_id in enumerate(unique_labels):
+                # Checkbox for selection
+                checkbox_widget = QWidget()
+                checkbox_layout = QHBoxLayout(checkbox_widget)
+                checkbox_layout.setContentsMargins(5, 0, 5, 0)
+                checkbox_layout.setAlignment(Qt.AlignCenter)
+
+                checkbox = QCheckBox()
+                checkbox.setChecked(label_id in self.selected_labels)
+
+                # Connect checkbox to label selection
+                def make_checkbox_callback(lid):
+                    def callback(state):
+                        if state == Qt.Checked:
+                            self.selected_labels.add(lid)
+                        else:
+                            self.selected_labels.discard(lid)
+                        self.preview_crop()
 
-
+                    return callback
 
-
+                checkbox.stateChanged.connect(make_checkbox_callback(label_id))
 
-
-
+                checkbox_layout.addWidget(checkbox)
+                table.setCellWidget(row, 0, checkbox_widget)
 
-
-
-
+                # Label ID as plain text with transparent background
+                item = QTableWidgetItem(str(label_id))
+                item.setTextAlignment(Qt.AlignCenter)
 
-
-
-
-
+                # Set the background color to transparent
+                brush = item.background()
+                brush.setStyle(Qt.NoBrush)
+                item.setBackground(brush)
 
-
+                table.setItem(row, 1, item)
+
+        except (KeyError, TypeError, ValueError, AttributeError) as e:
+            import traceback
+
+            self.viewer.status = f"Error populating table: {str(e)}"
+            traceback.print_exc()
+            # Set empty table as fallback
+            table.setRowCount(0)
 
     def _update_label_table(self):
         """Update the label selection table if it exists."""
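
The table-population code above depends on a closure factory so that each checkbox remembers its own label ID instead of closing over the loop variable. A stripped-down sketch of that pattern, with illustrative names and assuming qtpy as used elsewhere in this module:

from qtpy.QtCore import Qt
from qtpy.QtWidgets import QCheckBox


def make_label_checkbox(label_id, selected_labels, on_change):
    """Create a checkbox whose callback captures label_id by value."""
    box = QCheckBox()
    box.setChecked(label_id in selected_labels)

    def callback(state, lid=label_id):  # default argument freezes the current label_id
        if state == Qt.Checked:
            selected_labels.add(lid)
        else:
            selected_labels.discard(lid)
        on_change()  # e.g. refresh the crop preview

    box.stateChanged.connect(callback)
    return box
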
@@ -754,6 +2051,9 @@ class BatchCropAnything:
         # Block signals during update
         self.label_table_widget.blockSignals(True)
 
+        # Completely repopulate the table to ensure it's up to date
+        self._populate_label_table(self.label_table_widget)
+
         # Update checkboxes
         for row in range(self.label_table_widget.rowCount()):
             # Get label ID from the visible column
@@ -793,10 +2093,6 @@ class BatchCropAnything:
         self.preview_crop()
         self.viewer.status = "Cleared all selections"
 
-    # --------------------------------------------------
-    # Image Processing and Export
-    # --------------------------------------------------
-
     def preview_crop(self, label_ids=None):
         """Preview the crop result with the selected label IDs."""
         if self.segmentation_result is None or self.image_layer is None:
@@ -826,20 +2122,29 @@ class BatchCropAnything:
         image = self.original_image.copy()
 
         # Create mask from selected label IDs
-
-
-        mask
+        if self.use_3d:
+            # For 3D data
+            mask = np.zeros_like(self.segmentation_result, dtype=bool)
+            for label_id in label_ids:
+                mask |= self.segmentation_result == label_id
 
-
-        if len(image.shape) == 2:
-            # Grayscale image
+            # Apply mask
             preview_image = image.copy()
             preview_image[~mask] = 0
         else:
-            #
-
-            for
-
+            # For 2D data
+            mask = np.zeros_like(self.segmentation_result, dtype=bool)
+            for label_id in label_ids:
+                mask |= self.segmentation_result == label_id
+
+            # Apply mask
+            if len(image.shape) == 2:
+                preview_image = image.copy()
+                preview_image[~mask] = 0
+            else:
+                preview_image = image.copy()
+                for c in range(preview_image.shape[2]):
+                    preview_image[:, :, c][~mask] = 0
 
         # Remove previous preview if exists
         for layer in list(self.viewer.layers):
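
Both the 2D and 3D preview paths above build the preview the same way: OR together boolean masks for the selected label IDs, then zero everything outside. A minimal NumPy illustration of the same idea (for single-channel data of the same shape as the label array; RGB data needs the per-channel loop shown in the diff):

import numpy as np


def masked_preview(image, labels, selected_ids):
    """Keep only pixels whose label is in selected_ids."""
    mask = np.isin(labels, list(selected_ids))  # equivalent to OR-ing labels == id
    preview = image.copy()
    preview[~mask] = 0
    return preview
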
@@ -879,20 +2184,58 @@ class BatchCropAnything:
         image = self.original_image
 
         # Create mask from all selected label IDs
-
-
-        mask
+        if self.use_3d:
+            # For 3D data, create a 3D mask
+            mask = np.zeros_like(self.segmentation_result, dtype=bool)
+            for label_id in self.selected_labels:
+                mask |= self.segmentation_result == label_id
 
-
-        if len(image.shape) == 2:
-            # Grayscale image
+            # Apply mask to image (set everything outside mask to 0)
             cropped_image = image.copy()
             cropped_image[~mask] = 0
+
+            # Save label image with same dimensions as original
+            label_image = np.zeros_like(self.segmentation_result, dtype=np.uint32)
+            for label_id in self.selected_labels:
+                label_image[self.segmentation_result == label_id] = label_id
         else:
-            #
-
-            for
-
+            # For 2D data, handle as before
+            mask = np.zeros_like(self.segmentation_result, dtype=bool)
+            for label_id in self.selected_labels:
+                mask |= self.segmentation_result == label_id
+
+            # Apply mask to image (set everything outside mask to 0)
+            if len(image.shape) == 2:
+                # Grayscale image
+                cropped_image = image.copy()
+                cropped_image[~mask] = 0
+
+                # Create label image with same dimensions
+                label_image = np.zeros_like(self.segmentation_result, dtype=np.uint32)
+                for label_id in self.selected_labels:
+                    label_image[self.segmentation_result == label_id] = label_id
+            else:
+                # Color image - mask must be expanded to match channel dimension
+                cropped_image = image.copy()
+                for c in range(cropped_image.shape[2]):
+                    cropped_image[:, :, c][~mask] = 0
+
+                # Create label image with 2D dimensions (without channels)
+                label_image = np.zeros_like(self.segmentation_result, dtype=np.uint32)
+                for label_id in self.selected_labels:
+                    label_image[self.segmentation_result == label_id] = label_id
 
         # Save cropped image
         image_path = self.images[self.current_index]
@@ -900,18 +2243,17 @@ class BatchCropAnything:
         label_str = "_".join(
             str(lid) for lid in sorted(self.selected_labels)
         )
-        output_path = f"{base_name}_cropped_{label_str}
-
-        # Save using appropriate method based on file type
-        if output_path.lower().endswith((".tif", ".tiff")):
-            imwrite(output_path, cropped_image, compression="zlib")
-        else:
-            from skimage.io import imsave
-
-            imsave(output_path, cropped_image)
+        output_path = f"{base_name}_cropped_{label_str}.tif"
 
+        # Save using tifffile with explicit parameters for best compatibility
+        imwrite(output_path, cropped_image, compression="zlib")
         self.viewer.status = f"Saved cropped image to {output_path}"
 
+        # Save the label image with exact same dimensions as original
+        label_output_path = f"{base_name}_labels_{label_str}.tif"
+        imwrite(label_output_path, label_image, compression="zlib")
+        self.viewer.status += f"\nSaved label mask to {label_output_path}"
+
         # Make sure the segmentation layer is active again
         if self.label_layer is not None:
             self.viewer.layers.selection.active = self.label_layer
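
The export step above now always writes TIFFs through tifffile, once for the masked image and once for the label mask. A compact sketch of the same output convention (the `_cropped_`/`_labels_` suffixes follow the diff; the helper name is illustrative):

import numpy as np
from tifffile import imwrite


def save_crop(base_name, cropped_image, label_image, selected_ids):
    """Write the masked image and its label mask as zlib-compressed TIFFs."""
    label_str = "_".join(str(i) for i in sorted(selected_ids))
    imwrite(f"{base_name}_cropped_{label_str}.tif", cropped_image, compression="zlib")
    imwrite(f"{base_name}_labels_{label_str}.tif", label_image.astype(np.uint32), compression="zlib")
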
@@ -923,76 +2265,44 @@ class BatchCropAnything:
             return False
 
 
-# --------------------------------------------------
-# UI Creation Functions
-# --------------------------------------------------
-
-
 def create_crop_widget(processor):
     """Create the crop control widget."""
     crop_widget = QWidget()
     layout = QVBoxLayout()
-    layout.setSpacing(10)
-    layout.setContentsMargins(
-        10, 10, 10, 10
-    )  # Add margins around all elements
+    layout.setSpacing(10)
+    layout.setContentsMargins(10, 10, 10, 10)
 
     # Instructions
+    dimension_type = "3D (TYX/ZYX)" if processor.use_3d else "2D (YX)"
     instructions_label = QLabel(
-        "
-        "
-        "
+        f"<b>Processing {dimension_type} data</b><br><br>"
+        "To create/edit objects:<br>"
+        "1. <b>Click on the POINTS layer</b> to add positive points<br>"
+        "2. Use Shift+click for negative points to refine segmentation<br>"
+        "3. Click on existing objects in the Segmentation layer to select them<br>"
+        "4. Press 'Crop' to save the selected objects to disk"
     )
     instructions_label.setWordWrap(True)
     layout.addWidget(instructions_label)
 
-    #
-
-
-
-    sensitivity_header_layout = QHBoxLayout()
-    sensitivity_label = QLabel("Segmentation Sensitivity:")
-    sensitivity_value_label = QLabel(f"{processor.sensitivity}")
-    sensitivity_header_layout.addWidget(sensitivity_label)
-    sensitivity_header_layout.addStretch()
-    sensitivity_header_layout.addWidget(sensitivity_value_label)
-    sensitivity_layout.addLayout(sensitivity_header_layout)
-
-    # Slider
-    slider_layout = QHBoxLayout()
-    sensitivity_slider = QSlider(Qt.Horizontal)
-    sensitivity_slider.setMinimum(0)
-    sensitivity_slider.setMaximum(100)
-    sensitivity_slider.setValue(processor.sensitivity)
-    sensitivity_slider.setTickPosition(QSlider.TicksBelow)
-    sensitivity_slider.setTickInterval(10)
-    slider_layout.addWidget(sensitivity_slider)
-
-    apply_sensitivity_button = QPushButton("Apply")
-    apply_sensitivity_button.setToolTip(
-        "Apply sensitivity changes to regenerate segmentation"
+    # Add a button to ensure points layer is active
+    activate_button = QPushButton("Make Points Layer Active")
+    activate_button.clicked.connect(
+        lambda: processor._ensure_points_layer_active()
     )
-
-    sensitivity_layout.addLayout(slider_layout)
+    layout.addWidget(activate_button)
 
-    #
-
-
-    )
-    sensitivity_description.setStyleSheet("font-style: italic; color: #666;")
-    sensitivity_layout.addWidget(sensitivity_description)
-
-    layout.addLayout(sensitivity_layout)
+    # Add a "Clear Points" button to reset prompts
+    clear_points_button = QPushButton("Clear Points")
+    layout.addWidget(clear_points_button)
 
     # Create label table
     label_table = processor.create_label_table(crop_widget)
-    label_table.setMinimumHeight(150)
-    label_table.setMaximumHeight(
-        300
-    )  # Set maximum height to prevent taking too much space
+    label_table.setMinimumHeight(150)
+    label_table.setMaximumHeight(300)
     layout.addWidget(label_table)
 
-    #
+    # Selection buttons
     selection_layout = QHBoxLayout()
     select_all_button = QPushButton("Select All")
     clear_selection_button = QPushButton("Clear Selection")
@@ -1014,7 +2324,7 @@ def create_crop_widget(processor):
 
     # Status label
     status_label = QLabel(
-        "Ready to process images.
+        "Ready to process images. Click on POINTS layer to add segmentation points."
     )
     status_label.setWordWrap(True)
     layout.addWidget(status_label)
@@ -1033,36 +2343,51 @@ def create_crop_widget(processor):
         # Create new table
         label_table = processor.create_label_table(crop_widget)
         label_table.setMinimumHeight(200)
-        layout.insertWidget(3, label_table)  # Insert after
+        layout.insertWidget(3, label_table)  # Insert after clear points button
         return label_table
 
-    #
-    def
-
-
-
-
-
-
-
-
-
-        f"Medium sensitivity - Balanced detection (γ={gamma:.2f})"
+    # Add helper method to ensure points layer is active
+    def _ensure_points_layer_active():
+        points_layer = None
+        for layer in list(processor.viewer.layers):
+            if "Points" in layer.name:
+                points_layer = layer
+                break
+
+        if points_layer is not None:
+            processor.viewer.layers.selection.active = points_layer
+            status_label.setText(
+                "Points layer is now active - click to add points"
             )
         else:
-
-
-            )
-
-
-
-
-
-
-
+            status_label.setText(
+                "No points layer found. Please load an image first."
+            )
+
+    processor._ensure_points_layer_active = _ensure_points_layer_active
+
+    # Connect button signals
+    def on_clear_points_clicked():
+        # Remove all point layers
+        for layer in list(processor.viewer.layers):
+            if "Points" in layer.name:
+                processor.viewer.layers.remove(layer)
+
+        # Reset point tracking attributes
+        if hasattr(processor, "points_data"):
+            processor.points_data = {}
+            processor.points_labels = {}
+
+        if hasattr(processor, "obj_points"):
+            processor.obj_points = {}
+            processor.obj_labels = {}
+
+        # Re-create empty points layer
+        processor._update_label_layer()
+        processor._ensure_points_layer_active()
+
         status_label.setText(
-
+            "Cleared all points. Click on Points layer to add new points."
         )
 
     def on_select_all_clicked():
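
The navigation changes in the next hunk all follow the same pattern introduced here: clear stale point prompts, move to another image, then refresh the table and re-activate the points layer. A simplified sketch of that handler composition, with illustrative names and without the button-enabling details from the diff:

def make_nav_handler(processor, move, refresh_table, status_label, clear_points):
    """Compose a next/previous handler: clear prompts, navigate, refresh the UI."""
    def handler():
        clear_points()  # drop point prompts left over from the previous image
        if move():      # e.g. processor.next_image or processor.previous_image
            refresh_table()
            status_label.setText(
                f"Showing image {processor.current_index + 1}/{len(processor.images)}"
            )
            processor._ensure_points_layer_active()
    return handler
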
@@ -1086,117 +2411,83 @@ def create_crop_widget(processor):
     )
 
     def on_next_clicked():
+        # Clear points before moving to next image
+        on_clear_points_clicked()
+
         if not processor.next_image():
             next_button.setEnabled(False)
         else:
             prev_button.setEnabled(True)
             replace_table_widget()
-            # Reset sensitivity slider to default
-            sensitivity_slider.setValue(processor.sensitivity)
-            sensitivity_value_label.setText(f"{processor.sensitivity}")
             status_label.setText(
                 f"Showing image {processor.current_index + 1}/{len(processor.images)}"
             )
+            processor._ensure_points_layer_active()
 
     def on_prev_clicked():
+        # Clear points before moving to previous image
+        on_clear_points_clicked()
+
         if not processor.previous_image():
             prev_button.setEnabled(False)
         else:
             next_button.setEnabled(True)
             replace_table_widget()
-            # Reset sensitivity slider to default
-            sensitivity_slider.setValue(processor.sensitivity)
-            sensitivity_value_label.setText(f"{processor.sensitivity}")
             status_label.setText(
                 f"Showing image {processor.current_index + 1}/{len(processor.images)}"
             )
+            processor._ensure_points_layer_active()
 
-
-    apply_sensitivity_button.clicked.connect(on_apply_sensitivity_clicked)
+    clear_points_button.clicked.connect(on_clear_points_clicked)
     select_all_button.clicked.connect(on_select_all_clicked)
     clear_selection_button.clicked.connect(on_clear_selection_clicked)
     crop_button.clicked.connect(on_crop_clicked)
     next_button.clicked.connect(on_next_clicked)
     prev_button.clicked.connect(on_prev_clicked)
+    activate_button.clicked.connect(_ensure_points_layer_active)
 
     return crop_widget
 
 
-# --------------------------------------------------
-# Napari Plugin Functions
-# --------------------------------------------------
-
-
 @magicgui(
     call_button="Start Batch Crop Anything",
     folder_path={"label": "Folder Path", "widget_type": "LineEdit"},
+    data_dimensions={
+        "label": "Data Dimensions",
+        "choices": ["YX (2D)", "TYX/ZYX (3D)"],
+    },
 )
 def batch_crop_anything(
     folder_path: str,
+    data_dimensions: str,
     viewer: Viewer = None,
 ):
-    """MagicGUI widget for starting Batch Crop Anything."""
-    # Check if
+    """MagicGUI widget for starting Batch Crop Anything using SAM2."""
+    # Check if SAM2 is available
     try:
-
-        # from mobile_sam import sam_model_registry
-
-        # Check if the required files are included with the package
-        try:
-            import importlib.util
-            import os
-
-            mobile_sam_spec = importlib.util.find_spec("mobile_sam")
-            if mobile_sam_spec is None:
-                raise ImportError("mobile_sam package not found")
-
-            mobile_sam_path = os.path.dirname(mobile_sam_spec.origin)
-
-            # Check for model file in package
-            model_found = False
-            checkpoint_paths = [
-                os.path.join(mobile_sam_path, "weights", "mobile_sam.pt"),
-                os.path.join(mobile_sam_path, "mobile_sam.pt"),
-                os.path.join(
-                    os.path.dirname(mobile_sam_path),
-                    "weights",
-                    "mobile_sam.pt",
-                ),
-                os.path.join(
-                    os.path.expanduser("~"), "models", "mobile_sam.pt"
-                ),
-                "/opt/T-MIDAS/models/mobile_sam.pt",
-                os.path.join(os.getcwd(), "mobile_sam.pt"),
-            ]
-
-            for path in checkpoint_paths:
-                if os.path.exists(path):
-                    model_found = True
-                    break
-
-            if not model_found:
-                QMessageBox.warning(
-                    None,
-                    "Model File Missing",
-                    "Mobile-SAM model weights (mobile_sam.pt) not found. You'll be prompted to locate it when starting the tool.\n\n"
-                    "You can download it from: https://github.com/ChaoningZhang/MobileSAM/tree/master/weights",
-                )
-        except (ImportError, AttributeError) as e:
-            print(f"Warning checking for model file: {str(e)}")
+        import importlib.util
 
+        sam2_spec = importlib.util.find_spec("sam2")
+        if sam2_spec is None:
+            QMessageBox.critical(
+                None,
+                "Missing Dependency",
+                "SAM2 not found. Please follow installation instructions at:\n"
+                "https://github.com/MercaderLabAnatomy/napari-tmidas?tab=readme-ov-file#dependencies\n",
+            )
+            return
     except ImportError:
         QMessageBox.critical(
             None,
             "Missing Dependency",
-            "
-            "
-            "You'll also need to download the model weights file (mobile_sam.pt) from:\n"
-            "https://github.com/ChaoningZhang/MobileSAM/tree/master/weights",
+            "SAM2 package cannot be imported. Please follow installation instructions at\n"
+            "https://github.com/MercaderLabAnatomy/napari-tmidas?tab=readme-ov-file#dependencies",
        )
         return
 
-    # Initialize processor
-
+    # Initialize processor with the selected dimensions mode
+    use_3d = "TYX/ZYX" in data_dimensions
+    processor = BatchCropAnything(viewer, use_3d=use_3d)
     processor.load_images(folder_path)
 
     # Create UI
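
The dependency check added above boils down to probing for the `sam2` package before any UI is built, instead of searching the filesystem for Mobile-SAM weights. The same check in isolation (the helper name is illustrative):

import importlib.util


def sam2_available() -> bool:
    """True if the sam2 package can be found on the current interpreter."""
    return importlib.util.find_spec("sam2") is not None
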
@@ -1205,13 +2496,9 @@ def batch_crop_anything(
     # Wrap the widget in a scroll area
     scroll_area = QScrollArea()
     scroll_area.setWidget(crop_widget)
-    scroll_area.setWidgetResizable(
-
-    )
-    scroll_area.setFrameShape(QScrollArea.NoFrame)  # Hide the frame
-    scroll_area.setMinimumHeight(
-        500
-    )  # Set a minimum height to ensure visibility
+    scroll_area.setWidgetResizable(True)
+    scroll_area.setFrameShape(QScrollArea.NoFrame)
+    scroll_area.setMinimumHeight(500)
 
     # Add scroll area to viewer
     viewer.window.add_dock_widget(scroll_area, name="Crop Controls")