napari-tmidas 0.1.5__py3-none-any.whl → 0.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- napari_tmidas/__init__.py +3 -0
- napari_tmidas/_crop_anything.py +1113 -0
- napari_tmidas/_file_conversion.py +488 -256
- napari_tmidas/_file_selector.py +267 -101
- napari_tmidas/_label_inspection.py +10 -0
- napari_tmidas/_roi_colocalization.py +1175 -0
- napari_tmidas/_version.py +2 -2
- napari_tmidas/napari.yaml +10 -0
- napari_tmidas/processing_functions/basic.py +83 -0
- napari_tmidas/processing_functions/colocalization.py +242 -0
- napari_tmidas/processing_functions/skimage_filters.py +17 -32
- {napari_tmidas-0.1.5.dist-info → napari_tmidas-0.1.7.dist-info}/METADATA +44 -14
- napari_tmidas-0.1.7.dist-info/RECORD +29 -0
- napari_tmidas-0.1.5.dist-info/RECORD +0 -26
- {napari_tmidas-0.1.5.dist-info → napari_tmidas-0.1.7.dist-info}/WHEEL +0 -0
- {napari_tmidas-0.1.5.dist-info → napari_tmidas-0.1.7.dist-info}/entry_points.txt +0 -0
- {napari_tmidas-0.1.5.dist-info → napari_tmidas-0.1.7.dist-info}/licenses/LICENSE +0 -0
- {napari_tmidas-0.1.5.dist-info → napari_tmidas-0.1.7.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,1113 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Batch Crop Anything - A Napari plugin for interactive image cropping
|
|
3
|
+
|
|
4
|
+
This plugin combines Segment Anything Model (SAM) for automatic object detection with
|
|
5
|
+
an interactive interface for selecting and cropping objects from images.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import os
|
|
9
|
+
|
|
10
|
+
import numpy as np
|
|
11
|
+
import torch
|
|
12
|
+
from magicgui import magicgui
|
|
13
|
+
from napari.layers import Labels
|
|
14
|
+
from napari.viewer import Viewer
|
|
15
|
+
from qtpy.QtCore import Qt
|
|
16
|
+
from qtpy.QtWidgets import (
|
|
17
|
+
QCheckBox,
|
|
18
|
+
QFileDialog,
|
|
19
|
+
QHBoxLayout,
|
|
20
|
+
QHeaderView,
|
|
21
|
+
QLabel,
|
|
22
|
+
QMessageBox,
|
|
23
|
+
QPushButton,
|
|
24
|
+
QScrollArea,
|
|
25
|
+
QSlider,
|
|
26
|
+
QTableWidget,
|
|
27
|
+
QTableWidgetItem,
|
|
28
|
+
QVBoxLayout,
|
|
29
|
+
QWidget,
|
|
30
|
+
)
|
|
31
|
+
from skimage.io import imread
|
|
32
|
+
from tifffile import imwrite
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class BatchCropAnything:
    """
    Class for processing images with Segment Anything and cropping selected objects.

    Typical use: ``load_images`` collects the images of a folder and shows the
    first one; each displayed image is segmented automatically with Mobile-SAM;
    the user selects labels by clicking them in the viewer or through the label
    table; ``crop_with_selected_labels`` then writes a copy of the original
    image in which every pixel outside the selected labels is set to 0 (the
    image is masked, not cut down to a bounding box).
    """

    # Extensions (lower-case) of files accepted as input images.
    _IMAGE_EXTENSIONS = (".tif", ".tiff", ".png", ".jpg", ".jpeg")
    # Longest label-ID list written verbatim into an output file name; longer
    # selections are summarised so the name stays well below OS limits.
    _MAX_LABEL_SUFFIX_LENGTH = 100

    def __init__(self, viewer: "Viewer"):
        """Initialize the BatchCropAnything processor.

        Args:
            viewer: The napari viewer used for display and status messages.
        """
        # Core components
        self.viewer = viewer
        self.images = []  # paths of the images to process
        self.current_index = 0  # index into self.images

        # Image and segmentation data
        self.original_image = None  # image as read from disk (used for saving)
        self.segmentation_result = None  # 2-D uint32 label image
        self.current_image_for_segmentation = None  # uint8 image fed to SAM

        # UI references
        self.image_layer = None
        self.label_layer = None
        self.label_table_widget = None

        # State tracking
        self.selected_labels = set()
        # label ID -> {"area", "center_y", "center_x", "score"}
        self.label_info = {}

        # Segmentation parameters
        self.sensitivity = 50  # Default sensitivity (0-100 scale)

        # Initialize the SAM model
        self._initialize_sam()

    # --------------------------------------------------
    # Model Initialization
    # --------------------------------------------------

    def _initialize_sam(self):
        """Initialize the Segment Anything Model.

        Sets ``self.mobile_sam`` and ``self.mask_generator``. Both are left as
        None when the package or the weights are unavailable, which disables
        segmentation (every segmentation entry point checks for None).
        """
        try:
            from mobile_sam import (
                SamAutomaticMaskGenerator,
                sam_model_registry,
            )

            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            model_type = "vit_t"

            checkpoint_path = self._find_sam_checkpoint()
            if checkpoint_path is None:
                self.mobile_sam = None
                self.mask_generator = None
                return

            self.mobile_sam = sam_model_registry[model_type](
                checkpoint=checkpoint_path
            )
            self.mobile_sam.to(device=self.device)
            self.mobile_sam.eval()

            # Thresholds are (re)configured on every run by
            # generate_segmentation_with_sensitivity.
            self.mask_generator = SamAutomaticMaskGenerator(self.mobile_sam)
            self.viewer.status = f"Initialized SAM model from {checkpoint_path} on {self.device}"

        except Exception as e:  # includes ImportError for a missing package
            self.viewer.status = f"Error initializing SAM: {str(e)}"
            self.mobile_sam = None
            self.mask_generator = None

    def _find_sam_checkpoint(self):
        """Find the SAM model checkpoint file.

        Checks a list of well-known locations first and otherwise asks the
        user for the file via a dialog.

        Returns:
            The checkpoint path, or None if none was found or selected.
        """
        try:
            import importlib.util

            mobile_sam_spec = importlib.util.find_spec("mobile_sam")
            if mobile_sam_spec is None:
                raise ImportError("mobile_sam package not found")

            mobile_sam_path = os.path.dirname(mobile_sam_spec.origin)

            checkpoint_paths = [
                os.path.join(mobile_sam_path, "weights", "mobile_sam.pt"),
                os.path.join(mobile_sam_path, "mobile_sam.pt"),
                os.path.join(
                    os.path.dirname(mobile_sam_path),
                    "weights",
                    "mobile_sam.pt",
                ),
                os.path.join(
                    os.path.expanduser("~"), "models", "mobile_sam.pt"
                ),
                "/opt/T-MIDAS/models/mobile_sam.pt",
                os.path.join(os.getcwd(), "mobile_sam.pt"),
            ]

            for path in checkpoint_paths:
                if os.path.exists(path):
                    return path

            # If model not found, ask user
            QMessageBox.information(
                None,
                "Model Not Found",
                "Mobile-SAM model weights not found. Please select the mobile_sam.pt file.",
            )

            checkpoint_path, _ = QFileDialog.getOpenFileName(
                None, "Select Mobile-SAM model file", "", "Model Files (*.pt)"
            )

            return checkpoint_path if checkpoint_path else None

        except Exception as e:  # includes ImportError
            self.viewer.status = f"Error finding SAM checkpoint: {str(e)}"
            return None

    # --------------------------------------------------
    # Image Loading and Navigation
    # --------------------------------------------------

    @classmethod
    def _is_source_image(cls, filename):
        """Return True if *filename* is an input image rather than an output.

        The extension must be one of ``_IMAGE_EXTENSIONS`` (case-insensitive).
        Label images (``*_labels``) and crops written by
        ``crop_with_selected_labels`` (``*_cropped_<ids>``, ``*_cropped``)
        are rejected so re-running on a folder does not process its outputs.
        """
        stem, ext = os.path.splitext(filename.lower())
        if ext not in cls._IMAGE_EXTENSIONS:
            return False
        return not (
            stem.endswith(("_labels", "_cropped")) or "_cropped_" in stem
        )

    def load_images(self, folder_path: str):
        """Load images from the specified folder path.

        Collects the source images of *folder_path* (sorted by file name so
        the processing order is deterministic) and displays the first one.
        """
        if not os.path.isdir(folder_path):
            self.viewer.status = f"Folder not found: {folder_path}"
            return

        self.images = [
            os.path.join(folder_path, name)
            for name in sorted(os.listdir(folder_path))
            if self._is_source_image(name)
            and os.path.isfile(os.path.join(folder_path, name))
        ]

        if not self.images:
            self.viewer.status = "No compatible images found in the folder."
            return

        self.viewer.status = f"Found {len(self.images)} images."
        self.current_index = 0
        self._load_current_image()

    def _go_to_image(self, index):
        """Show image *index*, discarding the selection of the previous one."""
        self.current_index = index
        self.selected_labels = set()
        # The table belongs to the previous image (will be recreated).
        self.label_table_widget = None
        self._load_current_image()

    def next_image(self):
        """Move to the next image.

        Returns:
            True if a new image was loaded, False at the end of the list or
            when no images are loaded.
        """
        if not self.images:
            self.viewer.status = "No images to process."
            return False

        if self.current_index >= len(self.images) - 1:
            self.viewer.status = "No more images. Processing complete."
            return False

        self._go_to_image(self.current_index + 1)
        return True

    def previous_image(self):
        """Move to the previous image.

        Returns:
            True if a new image was loaded, False at the first image or when
            no images are loaded.
        """
        if not self.images:
            self.viewer.status = "No images to process."
            return False

        if self.current_index <= 0:
            self.viewer.status = "Already at the first image."
            return False

        self._go_to_image(self.current_index - 1)
        return True

    @staticmethod
    def _to_display_uint8(image):
        """Return *image* as uint8, scaled so that its maximum maps to 255.

        uint8 input is returned unchanged. For other dtypes, NaN/inf values
        are treated as 0, an image without positive values yields zeros
        (instead of dividing by zero), and negative values are clipped to 0
        rather than wrapping around when cast.
        """
        if image.dtype == np.uint8:
            return image
        values = np.nan_to_num(
            image.astype(np.float64), nan=0.0, posinf=0.0, neginf=0.0
        )
        max_value = values.max() if values.size else 0.0
        if max_value <= 0:
            return np.zeros(image.shape, dtype=np.uint8)
        return np.clip(values / max_value * 255, 0, 255).astype(np.uint8)

    def _load_current_image(self):
        """Load the current image and generate segmentation."""
        if not self.images:
            self.viewer.status = "No images to process."
            return

        if self.mobile_sam is None or self.mask_generator is None:
            self.viewer.status = (
                "SAM model not initialized. Cannot segment images."
            )
            return

        image_path = self.images[self.current_index]
        self.viewer.status = f"Processing {os.path.basename(image_path)}"

        # Forget everything belonging to the previously shown image so that a
        # failed load cannot leave stale data (e.g. its shape) behind.
        self.original_image = None
        self.segmentation_result = None
        self.current_image_for_segmentation = None
        self.image_layer = None
        self.label_layer = None
        self.label_info = {}
        self.selected_labels = set()

        try:
            self.viewer.layers.clear()

            # The original image is kept untouched for saving; SAM and the
            # display work on a uint8 version.
            self.original_image = imread(image_path)
            image_for_display = self._to_display_uint8(self.original_image)

            self.image_layer = self.viewer.add_image(
                image_for_display,
                name=f"Image ({os.path.basename(image_path)})",
            )

            self._generate_segmentation(image_for_display)

        except Exception as e:
            import traceback

            self.viewer.status = f"Error processing image: {str(e)}"
            traceback.print_exc()
            # Show an empty segmentation if at least the image could be read
            if self.original_image is not None:
                self.segmentation_result = np.zeros(
                    self.original_image.shape[:2], dtype=np.uint32
                )
                self.label_layer = self.viewer.add_labels(
                    self.segmentation_result, name="Error: No Segmentation"
                )

    # --------------------------------------------------
    # Segmentation Generation and Control
    # --------------------------------------------------

    @staticmethod
    def _to_rgb(image):
        """Return *image* with three colour channels for SAM.

        2-D and single-channel images are replicated to three channels and an
        alpha channel (4 channels) is dropped; other shapes pass through.
        """
        if image.ndim == 2:
            return np.repeat(image[:, :, np.newaxis], 3, axis=2)
        if image.ndim == 3 and image.shape[2] == 1:
            return np.repeat(image, 3, axis=2)
        if image.ndim == 3 and image.shape[2] == 4:
            return image[:, :, :3]
        return image

    def _generate_segmentation(self, image):
        """Generate segmentation for the current image.

        Args:
            image: uint8 version of the current image.
        """
        # Stored for later regeneration when the sensitivity changes
        self.current_image_for_segmentation = self._to_rgb(image)
        self.generate_segmentation_with_sensitivity()

    def generate_segmentation_with_sensitivity(self, sensitivity=None):
        """Generate segmentation with the specified sensitivity.

        Args:
            sensitivity: New sensitivity on a 0-100 scale (stored on the
                instance); None keeps the current value. Values outside 0-100
                are clamped when deriving the SAM parameters.

        The label selection and any crop preview are reset, because label IDs
        refer to the new segmentation.
        """
        if sensitivity is not None:
            self.sensitivity = sensitivity

        if self.mobile_sam is None or self.mask_generator is None:
            self.viewer.status = (
                "SAM model not initialized. Cannot segment images."
            )
            return

        if self.current_image_for_segmentation is None:
            self.viewer.status = "No image loaded for segmentation."
            return

        try:
            # Higher sensitivity => lower thresholds => more objects detected
            fraction = min(max(self.sensitivity, 0), 100) / 100

            # pred_iou_thresh: 0.92 (low sensitivity) to 0.75 (high)
            self.mask_generator.pred_iou_thresh = 0.92 - fraction * 0.17
            # stability_score_thresh: 0.97 (low sensitivity) to 0.85 (high)
            self.mask_generator.stability_score_thresh = (
                0.97 - fraction * 0.12
            )
            # min_mask_region_area: 300 (low sensitivity) to 30 (high)
            self.mask_generator.min_mask_region_area = 300 - fraction * 270

            # Gamma correction on intensities normalised to [0, 1]
            # (x ** gamma): gamma goes from 1.5 at sensitivity 0, which darkens
            # mid-tones, to 0.5 at sensitivity 100, which brightens mid-tones
            # and dim structures.
            gamma = 1.5 - fraction * 1.0
            image_float = (
                self.current_image_for_segmentation.astype(np.float32) / 255.0
            )
            image_gamma = (np.power(image_float, gamma) * 255).astype(
                np.uint8
            )

            self.viewer.status = f"Generating segmentation with sensitivity {self.sensitivity} (gamma={gamma:.2f})..."

            masks = self.mask_generator.generate(image_gamma)
            self.viewer.status = f"Generated {len(masks)} masks"

            shape = self.current_image_for_segmentation.shape[:2]
            if masks:
                self._process_segmentation_masks(masks, shape)
            else:
                self.viewer.status = (
                    "No segments detected. Try increasing the sensitivity."
                )
                # Drop the label info of any previous segmentation so the
                # table and "select all" cannot refer to vanished labels.
                self.label_info = {}
                self.segmentation_result = np.zeros(shape, dtype=np.uint32)
                self._show_segmentation_layer(self.segmentation_result)

            # Label IDs now refer to the new segmentation: reset the
            # selection and remove any preview built from the old one.
            self.selected_labels = set()
            self._remove_preview_layers()
            self._ensure_segmentation_layer_active()

            if self.label_table_widget is not None:
                self._populate_label_table(self.label_table_widget)

        except Exception as e:
            import traceback

            self.viewer.status = f"Error generating segmentation: {str(e)}"
            traceback.print_exc()

    @staticmethod
    def _labels_from_masks(masks, shape):
        """Paint SAM masks into a label image and collect per-label statistics.

        Masks are painted largest-first, so smaller objects lying on top of
        larger ones stay visible; label IDs follow that order (1 = largest
        mask). Area and centre are measured on the painted label image, i.e.
        on the pixels actually shown for each label. Labels that are
        completely covered by later masks are left out of the returned info.

        Args:
            masks: Records with a boolean "segmentation" array of *shape* and
                an optional "stability_score".
            shape: (rows, cols) of the label image.

        Returns:
            (labels, label_info): a uint32 label image and a dict mapping
            label ID to {"area", "center_y", "center_x", "score"}, ordered
            by area (largest first).
        """
        labels = np.zeros(shape, dtype=np.uint32)
        ordered = sorted(
            masks,
            key=lambda m: np.count_nonzero(m["segmentation"]),
            reverse=True,
        )
        for label_id, mask_data in enumerate(ordered, start=1):
            labels[mask_data["segmentation"]] = label_id

        # Per-label pixel counts and coordinate sums in single passes
        count = len(ordered) + 1
        flat = labels.ravel().astype(np.intp)
        rows, cols = np.indices(labels.shape)
        areas = np.bincount(flat, minlength=count)
        sum_y = np.bincount(flat, weights=rows.ravel(), minlength=count)
        sum_x = np.bincount(flat, weights=cols.ravel(), minlength=count)

        label_info = {}
        for label_id, mask_data in enumerate(ordered, start=1):
            area = int(areas[label_id])
            if area == 0:
                continue
            label_info[label_id] = {
                "area": area,
                "center_y": sum_y[label_id] / area,
                "center_x": sum_x[label_id] / area,
                "score": mask_data.get("stability_score", 0),
            }

        label_info = dict(
            sorted(
                label_info.items(),
                key=lambda item: item[1]["area"],
                reverse=True,
            )
        )
        return labels, label_info

    def _show_segmentation_layer(self, labels):
        """Replace the segmentation layer with one showing *labels*.

        The new layer becomes the active layer. Its mouse-drag callbacks are
        replaced by the label-selection handler.
        """
        for layer in list(self.viewer.layers):
            if isinstance(layer, Labels) and "Segmentation" in layer.name:
                self.viewer.layers.remove(layer)

        self.label_layer = self.viewer.add_labels(
            labels,
            name=f"Segmentation ({os.path.basename(self.images[self.current_index])})",
            opacity=0.7,
        )
        self.viewer.layers.selection.active = self.label_layer

        callbacks = self.label_layer.mouse_drag_callbacks
        for callback in list(callbacks):
            callbacks.remove(callback)
        callbacks.append(self._on_label_clicked)

    def _process_segmentation_masks(self, masks, shape):
        """Process segmentation masks and create label layer.

        Args:
            masks: Mask records as produced by the SAM mask generator.
            shape: (rows, cols) of the label image to build.
        """
        labels, self.label_info = self._labels_from_masks(masks, shape)
        self.segmentation_result = labels
        self._show_segmentation_layer(labels)
        self.viewer.status = f"Loaded image {self.current_index + 1}/{len(self.images)} - Found {len(self.label_info)} segments"

    # --------------------------------------------------
    # Label Selection and UI Elements
    # --------------------------------------------------

    def _on_label_clicked(self, layer, event):
        """Toggle the selection of the label under the mouse on a click."""
        try:
            # Only process clicks, not drags
            if event.type != "mouse_press":
                return
            if self.segmentation_result is None:
                return

            # The viewer may have more dimensions than the 2-D label image;
            # the last two coordinates are (row, column).
            row, col = np.round(np.asarray(event.position)[-2:]).astype(int)
            n_rows, n_cols = self.segmentation_result.shape[:2]
            if not (0 <= row < n_rows and 0 <= col < n_cols):
                return

            # Plain int so the selection set and messages hold Python ints
            label_id = int(self.segmentation_result[row, col])

            # Skip if background (0) is clicked
            if label_id == 0:
                return

            if label_id in self.selected_labels:
                self.selected_labels.remove(label_id)
                self.viewer.status = f"Deselected label ID: {label_id} | Selected labels: {self.selected_labels}"
            else:
                self.selected_labels.add(label_id)
                self.viewer.status = f"Selected label ID: {label_id} | Selected labels: {self.selected_labels}"

            self._update_label_table()
            self.preview_crop()

        except Exception as e:
            self.viewer.status = f"Error selecting label: {str(e)}"

    def create_label_table(self, parent_widget):
        """Create a table widget displaying all detected labels.

        Args:
            parent_widget: Unused; kept for interface compatibility.

        Returns:
            The new QTableWidget, also stored as ``self.label_table_widget``.
        """
        table = QTableWidget()
        table.setColumnCount(2)
        table.setHorizontalHeaderLabels(["Select", "Label ID"])

        table.setEditTriggers(QTableWidget.NoEditTriggers)
        table.setSelectionBehavior(QTableWidget.SelectRows)

        # Turn off alternating colors to avoid coloring issues
        table.setAlternatingRowColors(False)

        # Column sizing
        table.horizontalHeader().setSectionResizeMode(0, QHeaderView.Stretch)
        table.horizontalHeader().setSectionResizeMode(
            1, QHeaderView.ResizeToContents
        )
        table.horizontalHeader().setMinimumSectionSize(80)

        self._populate_label_table(table)
        self.label_table_widget = table

        # Make the segmentation layer active when the table is clicked
        table.clicked.connect(lambda: self._ensure_segmentation_layer_active())

        return table

    def _ensure_segmentation_layer_active(self):
        """Ensure the segmentation layer is the active layer."""
        if self.label_layer is not None:
            self.viewer.layers.selection.active = self.label_layer

    def _populate_label_table(self, table):
        """Fill *table* with one row (checkbox, label ID) per label.

        Rows are ordered by label area, largest first. Toggling a checkbox
        adds the label to or removes it from ``self.selected_labels`` and
        refreshes the crop preview.
        """
        if not self.label_info:
            table.setRowCount(0)
            return

        table.setRowCount(len(self.label_info))

        sorted_labels = sorted(
            self.label_info.items(),
            key=lambda item: item[1]["area"],
            reverse=True,
        )

        def make_checkbox_callback(lid):
            # ``toggled`` delivers a bool with every Qt binding, whereas the
            # int sent by ``stateChanged`` may not compare equal to
            # ``Qt.Checked`` under Qt6 enum bindings (e.g. PyQt6).
            def callback(checked):
                if checked:
                    self.selected_labels.add(lid)
                else:
                    self.selected_labels.discard(lid)
                self.preview_crop()

            return callback

        for row, (label_id, _info) in enumerate(sorted_labels):
            # Checkbox for selection, centred in its cell
            checkbox_widget = QWidget()
            checkbox_layout = QHBoxLayout(checkbox_widget)
            checkbox_layout.setContentsMargins(5, 0, 5, 0)
            checkbox_layout.setAlignment(Qt.AlignCenter)

            checkbox = QCheckBox()
            checkbox.setChecked(label_id in self.selected_labels)
            checkbox.toggled.connect(make_checkbox_callback(label_id))

            checkbox_layout.addWidget(checkbox)
            table.setCellWidget(row, 0, checkbox_widget)

            # Label ID as plain text with transparent background
            item = QTableWidgetItem(str(label_id))
            item.setTextAlignment(Qt.AlignCenter)
            brush = item.background()
            brush.setStyle(Qt.NoBrush)
            item.setBackground(brush)

            table.setItem(row, 1, item)

    def _update_label_table(self):
        """Sync the table checkboxes with ``self.selected_labels``."""
        if self.label_table_widget is None:
            return

        table = self.label_table_widget
        table.blockSignals(True)
        try:
            for row in range(table.rowCount()):
                label_id_item = table.item(row, 1)
                if label_id_item is None:
                    continue
                label_id = int(label_id_item.text())

                checkbox_item = table.cellWidget(row, 0)
                if checkbox_item is None:
                    continue

                checkbox = checkbox_item.findChild(QCheckBox)
                if checkbox is not None:
                    # Block the checkbox's own signals: the selection set is
                    # already up to date, and every caller refreshes the
                    # preview itself afterwards.
                    was_blocked = checkbox.blockSignals(True)
                    checkbox.setChecked(label_id in self.selected_labels)
                    checkbox.blockSignals(was_blocked)
        finally:
            table.blockSignals(False)

    def select_all_labels(self):
        """Select all labels."""
        if not self.label_info:
            return

        self.selected_labels = set(self.label_info.keys())
        self._update_label_table()
        self.preview_crop()
        self.viewer.status = f"Selected all {len(self.selected_labels)} labels"

    def clear_selection(self):
        """Clear all selected labels."""
        self.selected_labels = set()
        self._update_label_table()
        self.preview_crop()
        self.viewer.status = "Cleared all selections"

    # --------------------------------------------------
    # Image Processing and Export
    # --------------------------------------------------

    @staticmethod
    def _selection_mask(segmentation, label_ids):
        """Return a boolean mask of the pixels whose label is in *label_ids*."""
        return np.isin(segmentation, list(label_ids))

    @staticmethod
    def _apply_mask(image, mask):
        """Return a copy of *image* with every pixel outside *mask* set to 0.

        *mask* indexes the leading two axes, so trailing (channel) axes are
        zeroed together.
        """
        masked = image.copy()
        masked[~mask] = 0
        return masked

    def _remove_preview_layers(self):
        """Remove preview layers created by ``preview_crop``."""
        # Match the "Preview (...)" prefix only, so an image layer whose file
        # name merely contains "Preview" is not removed.
        for layer in list(self.viewer.layers):
            if layer.name.startswith("Preview ("):
                self.viewer.layers.remove(layer)

    def preview_crop(self, label_ids=None):
        """Preview the crop result with the selected label IDs.

        Args:
            label_ids: Labels to keep; defaults to ``self.selected_labels``.
                An empty selection just removes the current preview.
        """
        if self.segmentation_result is None or self.image_layer is None:
            self.viewer.status = (
                "No image or segmentation available for preview."
            )
            return

        try:
            if label_ids is None:
                label_ids = self.selected_labels

            self._remove_preview_layers()

            if label_ids:
                mask = self._selection_mask(
                    self.segmentation_result, label_ids
                )
                preview_image = self._apply_mask(self.original_image, mask)
                label_str = ", ".join(str(lid) for lid in sorted(label_ids))
                self.viewer.add_image(
                    preview_image,
                    name=f"Preview (Labels: {label_str})",
                    opacity=0.55,
                )

            # Keep the segmentation layer active for further clicks
            self._ensure_segmentation_layer_active()

        except Exception as e:
            self.viewer.status = f"Error generating preview: {str(e)}"

    @classmethod
    def _label_suffix(cls, label_ids, max_length=None):
        """Build the label part of an output file name.

        Returns the sorted IDs joined by "_". When that exceeds *max_length*
        characters (default ``_MAX_LABEL_SUFFIX_LENGTH``), it returns
        "<count>labels_<hash>" instead, so that selecting many labels cannot
        produce a file name longer than the OS allows.
        """
        if max_length is None:
            max_length = cls._MAX_LABEL_SUFFIX_LENGTH
        suffix = "_".join(str(lid) for lid in sorted(int(i) for i in label_ids))
        if len(suffix) <= max_length:
            return suffix

        import hashlib

        digest = hashlib.sha256(suffix.encode("ascii")).hexdigest()[:10]
        return f"{len(label_ids)}labels_{digest}"

    def crop_with_selected_labels(self):
        """Save the current image masked to all selected labels.

        Pixels outside the selected labels are set to 0. The result is
        written next to the input as ``<name>_cropped_<ids><ext>``; TIFF
        output is zlib-compressed.

        Returns:
            True on success, False if nothing could be saved.
        """
        if self.segmentation_result is None or self.original_image is None:
            self.viewer.status = (
                "No image or segmentation available for cropping."
            )
            return False

        if not self.selected_labels:
            self.viewer.status = "No labels selected for cropping."
            return False

        try:
            mask = self._selection_mask(
                self.segmentation_result, self.selected_labels
            )
            cropped_image = self._apply_mask(self.original_image, mask)

            image_path = self.images[self.current_index]
            base_name, ext = os.path.splitext(image_path)
            label_str = self._label_suffix(self.selected_labels)
            output_path = f"{base_name}_cropped_{label_str}{ext}"

            if output_path.lower().endswith((".tif", ".tiff")):
                imwrite(output_path, cropped_image, compression="zlib")
            else:
                from skimage.io import imsave

                imsave(output_path, cropped_image)

            self.viewer.status = f"Saved cropped image to {output_path}"

            self._ensure_segmentation_layer_active()
            return True

        except Exception as e:
            self.viewer.status = f"Error cropping image: {str(e)}"
            return False
|
|
792
|
+
|
|
793
|
+
|
|
794
|
+
# --------------------------------------------------
|
|
795
|
+
# UI Creation Functions
|
|
796
|
+
# --------------------------------------------------
|
|
797
|
+
|
|
798
|
+
|
|
799
|
+
def create_crop_widget(processor):
    """Build the Qt control widget that drives a batch-crop session.

    The widget contains usage instructions, a segmentation-sensitivity slider
    with an "Apply" button, the label table produced by
    ``processor.create_label_table``, select-all / clear-selection buttons, a
    crop button, previous/next navigation buttons and a status line.

    Parameters
    ----------
    processor
        The processor object whose attributes (``sensitivity``,
        ``selected_labels``, ``current_index``, ``images``) are displayed and
        whose methods (``create_label_table``,
        ``generate_segmentation_with_sensitivity``, ``select_all_labels``,
        ``clear_selection``, ``crop_with_selected_labels``, ``next_image``,
        ``previous_image``) the controls invoke.

    Returns
    -------
    QWidget
        The assembled control widget (the caller is responsible for docking it).
    """
    crop_widget = QWidget()
    layout = QVBoxLayout()
    layout.setSpacing(10)  # Add more space between elements
    layout.setContentsMargins(10, 10, 10, 10)  # Margins around all elements

    # Instructions
    instructions_label = QLabel(
        "Select objects to keep in the cropped image.\n"
        "You can select labels using the table below or by clicking directly on objects "
        "in the image (make sure the Segmentation layer is active)."
    )
    instructions_label.setWordWrap(True)
    layout.addWidget(instructions_label)

    # ---- Sensitivity controls -------------------------------------------
    sensitivity_layout = QVBoxLayout()

    # Header: title on the left, current slider value on the right
    sensitivity_header_layout = QHBoxLayout()
    sensitivity_label = QLabel("Segmentation Sensitivity:")
    sensitivity_value_label = QLabel(f"{processor.sensitivity}")
    sensitivity_header_layout.addWidget(sensitivity_label)
    sensitivity_header_layout.addStretch()
    sensitivity_header_layout.addWidget(sensitivity_value_label)
    sensitivity_layout.addLayout(sensitivity_header_layout)

    # Slider (0-100) plus the button that applies the chosen value
    slider_layout = QHBoxLayout()
    sensitivity_slider = QSlider(Qt.Horizontal)
    sensitivity_slider.setMinimum(0)
    sensitivity_slider.setMaximum(100)
    sensitivity_slider.setValue(processor.sensitivity)
    sensitivity_slider.setTickPosition(QSlider.TicksBelow)
    sensitivity_slider.setTickInterval(10)
    slider_layout.addWidget(sensitivity_slider)

    apply_sensitivity_button = QPushButton("Apply")
    apply_sensitivity_button.setToolTip(
        "Apply sensitivity changes to regenerate segmentation"
    )
    slider_layout.addWidget(apply_sensitivity_button)
    sensitivity_layout.addLayout(slider_layout)

    def _describe_sensitivity(value):
        """Return the description text shown under the slider for ``value``."""
        # The displayed gamma falls linearly from 1.5 (value 0) to 0.5
        # (value 100): higher gamma for low sensitivity, lower for high.
        gamma = 1.5 - (value / 100) * 1.0
        if value < 25:
            return f"Low sensitivity - Seeks large, distinct objects (γ={gamma:.2f})"
        if value < 75:
            return f"Medium sensitivity - Balanced detection (γ={gamma:.2f})"
        return f"High sensitivity - Detects subtle, small objects (γ={gamma:.2f})"

    # Description label, initialised from the processor's actual sensitivity
    # rather than a hard-coded "medium" text.
    sensitivity_description = QLabel(_describe_sensitivity(processor.sensitivity))
    sensitivity_description.setStyleSheet("font-style: italic; color: #666;")
    sensitivity_layout.addWidget(sensitivity_description)

    layout.addLayout(sensitivity_layout)

    # ---- Label table ------------------------------------------------------
    def _configure_table(table):
        """Apply the same size constraints to every label table we show."""
        table.setMinimumHeight(150)  # Keep the table compact
        table.setMaximumHeight(300)  # Prevent it from taking too much space
        return table

    label_table = _configure_table(processor.create_label_table(crop_widget))
    layout.addWidget(label_table)

    # ---- Selection buttons ------------------------------------------------
    selection_layout = QHBoxLayout()
    select_all_button = QPushButton("Select All")
    clear_selection_button = QPushButton("Clear Selection")
    selection_layout.addWidget(select_all_button)
    selection_layout.addWidget(clear_selection_button)
    layout.addLayout(selection_layout)

    # Crop button
    crop_button = QPushButton("Crop with Selected Objects")
    layout.addWidget(crop_button)

    # Navigation buttons
    nav_layout = QHBoxLayout()
    prev_button = QPushButton("Previous Image")
    next_button = QPushButton("Next Image")
    nav_layout.addWidget(prev_button)
    nav_layout.addWidget(next_button)
    layout.addLayout(nav_layout)

    # Status label
    status_label = QLabel(
        "Ready to process images. Select objects using the table or by clicking on them."
    )
    status_label.setWordWrap(True)
    layout.addWidget(status_label)

    crop_widget.setLayout(layout)

    def replace_table_widget():
        """Swap the label table for a freshly built one in the same slot."""
        nonlocal label_table
        # Remember the old table's position so the new table goes back
        # directly below the sensitivity controls instead of at a hard-coded
        # index (which, once the old table is removed, lands after the
        # selection buttons).
        table_index = layout.indexOf(label_table)
        layout.removeWidget(label_table)
        label_table.setParent(None)
        label_table.deleteLater()

        label_table = _configure_table(processor.create_label_table(crop_widget))
        layout.insertWidget(table_index, label_table)
        return label_table

    def _refresh_after_navigation():
        """Rebuild the table and reset the slider after switching images."""
        replace_table_widget()
        # Reset sensitivity slider to the processor's value
        sensitivity_slider.setValue(processor.sensitivity)
        sensitivity_value_label.setText(f"{processor.sensitivity}")
        status_label.setText(
            f"Showing image {processor.current_index + 1}/{len(processor.images)}"
        )

    # ---- Signal handlers --------------------------------------------------
    def on_sensitivity_changed(value):
        sensitivity_value_label.setText(f"{value}")
        sensitivity_description.setText(_describe_sensitivity(value))

    def on_apply_sensitivity_clicked():
        new_sensitivity = sensitivity_slider.value()
        processor.generate_segmentation_with_sensitivity(new_sensitivity)
        replace_table_widget()
        status_label.setText(
            f"Regenerated segmentation with sensitivity {new_sensitivity}"
        )

    def on_select_all_clicked():
        processor.select_all_labels()
        status_label.setText(
            f"Selected all {len(processor.selected_labels)} objects"
        )

    def on_clear_selection_clicked():
        processor.clear_selection()
        status_label.setText("Selection cleared")

    def on_crop_clicked():
        if processor.crop_with_selected_labels():
            labels_str = ", ".join(
                str(label) for label in sorted(processor.selected_labels)
            )
            status_label.setText(
                f"Cropped image with {len(processor.selected_labels)} objects (IDs: {labels_str})"
            )
        else:
            # Don't leave a stale success message on screen after a failure.
            status_label.setText(
                "No image was cropped - see the viewer status bar for details."
            )

    def on_next_clicked():
        if not processor.next_image():
            next_button.setEnabled(False)
        else:
            prev_button.setEnabled(True)
            _refresh_after_navigation()

    def on_prev_clicked():
        if not processor.previous_image():
            prev_button.setEnabled(False)
        else:
            next_button.setEnabled(True)
            _refresh_after_navigation()

    sensitivity_slider.valueChanged.connect(on_sensitivity_changed)
    apply_sensitivity_button.clicked.connect(on_apply_sensitivity_clicked)
    select_all_button.clicked.connect(on_select_all_clicked)
    clear_selection_button.clicked.connect(on_clear_selection_clicked)
    crop_button.clicked.connect(on_crop_clicked)
    next_button.clicked.connect(on_next_clicked)
    prev_button.clicked.connect(on_prev_clicked)

    return crop_widget
|
|
991
|
+
|
|
992
|
+
|
|
993
|
+
# --------------------------------------------------
|
|
994
|
+
# Napari Plugin Functions
|
|
995
|
+
# --------------------------------------------------
|
|
996
|
+
|
|
997
|
+
|
|
998
|
+
def _mobile_sam_checkpoint_candidates(package_dir=None):
    """Return the file paths searched for the ``mobile_sam.pt`` weights.

    Parameters
    ----------
    package_dir : str or None
        Directory of the installed ``mobile_sam`` package. When falsy, the
        package-relative locations are skipped and only the fixed locations
        (home ``models`` dir, ``/opt/T-MIDAS/models``, current directory) are
        returned.
    """
    candidates = []
    if package_dir:
        candidates += [
            os.path.join(package_dir, "weights", "mobile_sam.pt"),
            os.path.join(package_dir, "mobile_sam.pt"),
            os.path.join(os.path.dirname(package_dir), "weights", "mobile_sam.pt"),
        ]
    candidates += [
        os.path.join(os.path.expanduser("~"), "models", "mobile_sam.pt"),
        "/opt/T-MIDAS/models/mobile_sam.pt",
        os.path.join(os.getcwd(), "mobile_sam.pt"),
    ]
    return candidates


@magicgui(
    call_button="Start Batch Crop Anything",
    folder_path={"label": "Folder Path", "widget_type": "LineEdit"},
)
def batch_crop_anything(
    folder_path: str,
    viewer: Viewer = None,
):
    """MagicGUI widget for starting Batch Crop Anything.

    Steps:

    1. Verify that the ``mobile_sam`` package is installed. If it is not,
       show a critical dialog and return.
    2. Verify that ``folder_path`` is an existing directory.
    3. Warn, without aborting, when ``mobile_sam.pt`` is not at any known
       location.
    4. Create a ``BatchCropAnything`` processor, load the folder's images and
       dock the crop-control widget in ``viewer``.

    Parameters
    ----------
    folder_path : str
        Folder containing the images to process.
    viewer : napari.viewer.Viewer
        Viewer to dock the controls into. It is presumably injected by
        magicgui from the ``Viewer`` annotation; verify this.
    """
    import importlib.util

    # Locate the Mobile-SAM package without importing it. The original code
    # raised ImportError inside a try that also caught ImportError, so the
    # "Missing Dependency" dialog below could never be shown.
    try:
        mobile_sam_spec = importlib.util.find_spec("mobile_sam")
    except (ImportError, ValueError):
        mobile_sam_spec = None

    if mobile_sam_spec is None:
        QMessageBox.critical(
            None,
            "Missing Dependency",
            "Mobile-SAM not found. Please install with:\n"
            "pip install git+https://github.com/ChaoningZhang/MobileSAM.git\n\n"
            "You'll also need to download the model weights file (mobile_sam.pt) from:\n"
            "https://github.com/ChaoningZhang/MobileSAM/tree/master/weights",
        )
        return

    # Bail out early on an empty or non-existent folder instead of handing it
    # to the processor.
    if not folder_path or not os.path.isdir(folder_path):
        QMessageBox.warning(
            None,
            "Invalid Folder",
            f"Please select an existing folder (got: {folder_path!r}).",
        )
        return

    # Only warn about missing weights. Per the message, the user is prompted
    # for the file later. `origin` may be None (e.g. namespace packages), in
    # which case only the package-independent locations are checked.
    origin = mobile_sam_spec.origin
    package_dir = os.path.dirname(origin) if origin else None
    if not any(
        os.path.exists(path)
        for path in _mobile_sam_checkpoint_candidates(package_dir)
    ):
        QMessageBox.warning(
            None,
            "Model File Missing",
            "Mobile-SAM model weights (mobile_sam.pt) not found. You'll be prompted to locate it when starting the tool.\n\n"
            "You can download it from: https://github.com/ChaoningZhang/MobileSAM/tree/master/weights",
        )

    # Initialize processor and load images
    processor = BatchCropAnything(viewer)
    processor.load_images(folder_path)

    # Create UI
    crop_widget = create_crop_widget(processor)

    # Wrap the widget in a scroll area so it stays usable in small docks
    scroll_area = QScrollArea()
    scroll_area.setWidget(crop_widget)
    scroll_area.setWidgetResizable(True)  # Let the widget resize with the area
    scroll_area.setFrameShape(QScrollArea.NoFrame)  # Hide the frame
    scroll_area.setMinimumHeight(500)  # Ensure the controls remain visible

    # Add scroll area to viewer
    viewer.window.add_dock_widget(scroll_area, name="Crop Controls")
|
|
1086
|
+
|
|
1087
|
+
|
|
1088
|
+
def batch_crop_anything_widget():
    """Provide the batch crop anything widget to Napari.

    Returns the module-level ``batch_crop_anything`` magicgui widget after
    adding a "Browse..." button next to its ``folder_path`` field.

    Because ``batch_crop_anything`` is a single shared object, this factory
    may be called several times on the same widget. The Browse button is
    tagged with an object name and only inserted if it is not already there,
    so repeated calls do not stack duplicate buttons.
    """
    widget = batch_crop_anything
    browse_button_name = "tmidas_crop_folder_browse_button"

    # The LineEdit's parent widget holds the layout the button is added to
    # (presumably magicgui's label/field row for this parameter).
    folder_row = widget.folder_path.native.parent()
    if folder_row.findChild(QPushButton, browse_button_name) is not None:
        return widget

    folder_browse_button = QPushButton("Browse...")
    folder_browse_button.setObjectName(browse_button_name)

    def on_folder_browse_clicked():
        # Start from the folder already entered, if valid, else from home.
        current = widget.folder_path.value
        start_dir = (
            current
            if current and os.path.isdir(current)
            else os.path.expanduser("~")
        )
        folder = QFileDialog.getExistingDirectory(
            None,
            "Select Folder",
            start_dir,
            QFileDialog.ShowDirsOnly | QFileDialog.DontResolveSymlinks,
        )
        if folder:  # Empty string means the dialog was cancelled
            widget.folder_path.value = folder

    folder_browse_button.clicked.connect(on_folder_browse_clicked)

    # Insert the browse button next to the folder_path field
    folder_row.layout().addWidget(folder_browse_button)

    return widget
|