napari-tmidas 0.1.5__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1113 @@
1
+ """
2
+ Batch Crop Anything - A Napari plugin for interactive image cropping
3
+
4
+ This plugin combines Segment Anything Model (SAM) for automatic object detection with
5
+ an interactive interface for selecting and cropping objects from images.
6
+ """
7
+
8
+ import os
9
+
10
+ import numpy as np
11
+ import torch
12
+ from magicgui import magicgui
13
+ from napari.layers import Labels
14
+ from napari.viewer import Viewer
15
+ from qtpy.QtCore import Qt
16
+ from qtpy.QtWidgets import (
17
+ QCheckBox,
18
+ QFileDialog,
19
+ QHBoxLayout,
20
+ QHeaderView,
21
+ QLabel,
22
+ QMessageBox,
23
+ QPushButton,
24
+ QScrollArea,
25
+ QSlider,
26
+ QTableWidget,
27
+ QTableWidgetItem,
28
+ QVBoxLayout,
29
+ QWidget,
30
+ )
31
+ from skimage.io import imread
32
+ from tifffile import imwrite
33
+
34
+
35
class BatchCropAnything:
    """
    Class for processing images with Segment Anything and cropping selected objects.

    Workflow:

    1. ``load_images`` collects the input images of a folder.
    2. Each image is shown in the napari viewer and segmented with
       Mobile-SAM's automatic mask generator.
    3. The user selects labels by clicking on the Segmentation layer or
       through the checkbox table.
    4. ``crop_with_selected_labels`` zeroes every pixel outside the selected
       labels and saves the result next to the source file as
       ``<name>_cropped_<ids><ext>``.
    """

    # File extensions accepted as input images (compared case-insensitively).
    IMAGE_EXTENSIONS = (".tif", ".tiff", ".png", ".jpg", ".jpeg")

    def __init__(self, viewer: "Viewer"):
        """Initialize the BatchCropAnything processor.

        Parameters
        ----------
        viewer : napari.viewer.Viewer
            Viewer used to display layers and report status messages.
        """
        # Core components
        self.viewer = viewer
        self.images = []
        self.current_index = 0

        # Image and segmentation data
        self.original_image = None  # pixels as read from disk (used for saving)
        self.segmentation_result = None  # 2-D uint32 label image
        self.current_image_for_segmentation = None  # uint8 image fed to SAM

        # UI references
        self.image_layer = None
        self.label_layer = None
        self.label_table_widget = None

        # State tracking
        self.selected_labels = set()
        self.label_info = {}  # label_id -> {"area", "center_y", "center_x", "score"}

        # Segmentation parameters
        self.sensitivity = 50  # Default sensitivity (0-100 scale)

        # Model state; defined up front so the attributes exist even if
        # model loading fails.
        self.device = "cpu"
        self.mobile_sam = None
        self.mask_generator = None
        self._initialize_sam()

    # --------------------------------------------------
    # Model Initialization
    # --------------------------------------------------

    def _initialize_sam(self):
        """Load Mobile-SAM and create the automatic mask generator.

        On any failure (package missing, weights not found, load error) the
        model attributes stay ``None``, which disables segmentation, and the
        error is reported on the viewer status bar.
        """
        self.mobile_sam = None
        self.mask_generator = None
        try:
            from mobile_sam import (
                SamAutomaticMaskGenerator,
                sam_model_registry,
            )

            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            model_type = "vit_t"

            checkpoint_path = self._find_sam_checkpoint()
            if checkpoint_path is None:
                return

            self.mobile_sam = sam_model_registry[model_type](
                checkpoint=checkpoint_path
            )
            self.mobile_sam.to(device=self.device)
            self.mobile_sam.eval()

            # Mask generator with library-default parameters; the thresholds
            # are overridden per run from the sensitivity setting.
            self.mask_generator = SamAutomaticMaskGenerator(self.mobile_sam)
            self.viewer.status = f"Initialized SAM model from {checkpoint_path} on {self.device}"

        except Exception as e:  # report any failure and run without SAM
            self.viewer.status = f"Error initializing SAM: {str(e)}"
            self.mobile_sam = None
            self.mask_generator = None

    def _find_sam_checkpoint(self):
        """Locate the Mobile-SAM weights file ``mobile_sam.pt``.

        Checks these locations in order:

        - inside and next to the installed ``mobile_sam`` package
        - ``~/models``
        - ``/opt/T-MIDAS/models``
        - the current working directory

        If none of them exists, the user is asked to pick the file.

        Returns
        -------
        str or None
            Path to the checkpoint, or ``None`` if it could not be found or
            the user cancelled the file dialog.
        """
        try:
            import importlib.util

            spec = importlib.util.find_spec("mobile_sam")
            # origin is None for namespace packages; treat that as "not found"
            # instead of crashing in os.path.dirname(None).
            if spec is None or spec.origin is None:
                raise ImportError("mobile_sam package not found")

            package_dir = os.path.dirname(spec.origin)
            candidates = [
                os.path.join(package_dir, "weights", "mobile_sam.pt"),
                os.path.join(package_dir, "mobile_sam.pt"),
                os.path.join(
                    os.path.dirname(package_dir), "weights", "mobile_sam.pt"
                ),
                os.path.join(
                    os.path.expanduser("~"), "models", "mobile_sam.pt"
                ),
                "/opt/T-MIDAS/models/mobile_sam.pt",
                os.path.join(os.getcwd(), "mobile_sam.pt"),
            ]
            for path in candidates:
                if os.path.exists(path):
                    return path

            # Not found in any known location: ask the user.
            QMessageBox.information(
                None,
                "Model Not Found",
                "Mobile-SAM model weights not found. Please select the mobile_sam.pt file.",
            )
            checkpoint_path, _ = QFileDialog.getOpenFileName(
                None, "Select Mobile-SAM model file", "", "Model Files (*.pt)"
            )
            return checkpoint_path or None

        except Exception as e:
            self.viewer.status = f"Error finding SAM checkpoint: {str(e)}"
            return None

    # --------------------------------------------------
    # Image Loading and Navigation
    # --------------------------------------------------

    @classmethod
    def _is_source_image(cls, filename):
        """Return True if *filename* is an input image rather than a plugin output.

        Accepts names with one of ``IMAGE_EXTENSIONS`` (case-insensitive).
        Rejects the following:

        - ``*_labels.tif`` label files
        - ``*_cropped.tif`` files
        - ``<stem>_cropped_<id>[_<id>...]<ext>`` files, the name pattern
          ``crop_with_selected_labels`` writes
        """
        name = filename.lower()
        if not name.endswith(cls.IMAGE_EXTENSIONS):
            return False
        if name.endswith(("_labels.tif", "_cropped.tif")):
            return False
        stem = os.path.splitext(name)[0]
        _, sep, ids = stem.rpartition("_cropped_")
        if sep and ids and all(part.isdigit() for part in ids.split("_")):
            return False
        return True

    def load_images(self, folder_path: str):
        """Collect the input images in *folder_path* and show the first one.

        Files are processed in sorted (deterministic) order; see
        ``_is_source_image`` for which names count as inputs.
        """
        if not os.path.isdir(folder_path):
            self.viewer.status = f"Folder not found: {folder_path}"
            return

        self.images = [
            os.path.join(folder_path, name)
            for name in sorted(os.listdir(folder_path))
            if self._is_source_image(name)
            and os.path.isfile(os.path.join(folder_path, name))
        ]

        if not self.images:
            self.viewer.status = "No compatible images found in the folder."
            return

        self.viewer.status = f"Found {len(self.images)} images."
        self.current_index = 0
        self.selected_labels = set()
        self._load_current_image()

    def _go_to_image(self, index):
        """Make *index* the current image, dropping per-image UI state."""
        self.current_index = index
        self.selected_labels = set()
        # The control widget recreates the table after navigation.
        self.label_table_widget = None
        self._load_current_image()

    def next_image(self):
        """Move to the next image.

        Returns
        -------
        bool
            True if the index advanced, False if there are no images or the
            last image is already shown.
        """
        if not self.images:
            self.viewer.status = "No images to process."
            return False

        if self.current_index >= len(self.images) - 1:
            self.viewer.status = "No more images. Processing complete."
            return False

        self._go_to_image(self.current_index + 1)
        return True

    def previous_image(self):
        """Move to the previous image.

        Returns
        -------
        bool
            True if the index moved back, False if there are no images or the
            first image is already shown.
        """
        if not self.images:
            self.viewer.status = "No images to process."
            return False

        if self.current_index <= 0:
            self.viewer.status = "Already at the first image."
            return False

        self._go_to_image(self.current_index - 1)
        return True

    def _reset_image_state(self):
        """Forget all data derived from the previously loaded image."""
        self.original_image = None
        self.segmentation_result = None
        self.current_image_for_segmentation = None
        self.image_layer = None
        self.label_layer = None
        self.label_info = {}

    @staticmethod
    def _to_display_uint8(image):
        """Return *image* scaled to uint8 by its maximum (uint8 input is returned as-is).

        Non-positive values map to 0. An all-zero (or empty) image yields
        zeros instead of dividing by zero.
        """
        if image.dtype == np.uint8:
            return image
        max_val = float(np.amax(image)) if image.size else 0.0
        if max_val <= 0:
            return np.zeros(image.shape, dtype=np.uint8)
        scaled = image.astype(np.float64) / max_val * 255
        return np.clip(scaled, 0, 255).astype(np.uint8)

    def _load_current_image(self):
        """Load the current image, display it and generate its segmentation."""
        if not self.images:
            self.viewer.status = "No images to process."
            return

        if self.mobile_sam is None or self.mask_generator is None:
            self.viewer.status = (
                "SAM model not initialized. Cannot segment images."
            )
            return

        image_path = self.images[self.current_index]
        self.viewer.status = f"Processing {os.path.basename(image_path)}"

        # Clear first so a failed read can never pair this file's name with
        # the previous image's pixels (which preview/crop would then save).
        self._reset_image_state()

        try:
            self.viewer.layers.clear()

            # Keep the original pixels for saving; SAM and display get uint8.
            self.original_image = imread(image_path)
            image_for_display = self._to_display_uint8(self.original_image)

            self.image_layer = self.viewer.add_image(
                image_for_display,
                name=f"Image ({os.path.basename(image_path)})",
            )

            self._generate_segmentation(image_for_display)

        except Exception as e:
            import traceback

            self.viewer.status = f"Error processing image: {str(e)}"
            traceback.print_exc()
            # Show an empty segmentation if the image itself could be read.
            if self.original_image is not None:
                self.segmentation_result = np.zeros(
                    self.original_image.shape[:2], dtype=np.uint32
                )
                self.label_layer = self.viewer.add_labels(
                    self.segmentation_result, name="Error: No Segmentation"
                )

    # --------------------------------------------------
    # Segmentation Generation and Control
    # --------------------------------------------------

    def _generate_segmentation(self, image):
        """Prepare *image* (uint8) for SAM and run segmentation on it."""
        if image.ndim == 2:
            # Grayscale -> 3 identical channels.
            image_for_sam = np.repeat(image[:, :, np.newaxis], 3, axis=2)
        elif image.ndim == 3 and image.shape[2] == 4:
            # Drop alpha: SAM takes 3-channel RGB input.
            image_for_sam = image[:, :, :3]
        else:
            image_for_sam = image

        # Kept so the segmentation can be regenerated when sensitivity changes.
        self.current_image_for_segmentation = image_for_sam
        self.generate_segmentation_with_sensitivity()

    def generate_segmentation_with_sensitivity(self, sensitivity=None):
        """(Re)generate the segmentation of the current image.

        Parameters
        ----------
        sensitivity : int, optional
            New sensitivity on a 0-100 scale; stored on the processor. If
            omitted, the current ``self.sensitivity`` is used.
        """
        if sensitivity is not None:
            self.sensitivity = sensitivity

        if self.mobile_sam is None or self.mask_generator is None:
            self.viewer.status = (
                "SAM model not initialized. Cannot segment images."
            )
            return

        if self.current_image_for_segmentation is None:
            self.viewer.status = "No image loaded for segmentation."
            return

        try:
            level = self.sensitivity / 100

            # Higher sensitivity -> lower thresholds -> more objects detected.
            # pred_iou_thresh: 0.92 (sensitivity 0) .. 0.75 (sensitivity 100)
            self.mask_generator.pred_iou_thresh = 0.92 - level * 0.17
            # stability_score_thresh: 0.97 .. 0.85
            self.mask_generator.stability_score_thresh = 0.97 - level * 0.12
            # min_mask_region_area: 300 .. 30
            self.mask_generator.min_mask_region_area = 300 - level * 270

            # Gamma goes from 1.5 (sensitivity 0) to 0.5 (sensitivity 100).
            # On [0, 1] intensities gamma > 1 darkens mid-tones (suppressing
            # faint structure) and gamma < 1 brightens them (revealing it).
            gamma = 1.5 - level * 1.0
            image_float = (
                self.current_image_for_segmentation.astype(np.float32) / 255.0
            )
            image_gamma = (np.power(image_float, gamma) * 255).astype(np.uint8)

            self.viewer.status = f"Generating segmentation with sensitivity {self.sensitivity} (gamma={gamma:.2f})..."

            masks = self.mask_generator.generate(image_gamma)
            self.viewer.status = f"Generated {len(masks)} masks"

            shape = self.current_image_for_segmentation.shape[:2]
            if masks:
                self._process_segmentation_masks(masks, shape)
            else:
                self.viewer.status = (
                    "No segments detected. Try increasing the sensitivity."
                )
                # Drop the previous labels too, so the table cannot offer
                # IDs that no longer exist.
                self.label_info = {}
                self.segmentation_result = np.zeros(shape, dtype=np.uint32)
                self._replace_label_layer(self.segmentation_result)

            # Label IDs refer to the new segmentation: drop old selection/preview.
            self.selected_labels = set()
            self._remove_preview_layers()

            if self.label_table_widget is not None:
                self._populate_label_table(self.label_table_widget)

        except Exception as e:
            import traceback

            self.viewer.status = f"Error generating segmentation: {str(e)}"
            traceback.print_exc()

    def _replace_label_layer(self, labels):
        """Replace the Segmentation Labels layer with one showing *labels*."""
        for layer in list(self.viewer.layers):
            if isinstance(layer, Labels) and "Segmentation" in layer.name:
                self.viewer.layers.remove(layer)

        self.label_layer = self.viewer.add_labels(
            labels,
            name=f"Segmentation ({os.path.basename(self.images[self.current_index])})",
            opacity=0.7,
        )
        self.viewer.layers.selection.active = self.label_layer

        # Clear whatever drag callbacks the new layer already has, then route
        # mouse presses to the label-selection handler exactly once.
        callbacks = self.label_layer.mouse_drag_callbacks
        for callback in list(callbacks):
            callbacks.remove(callback)
        callbacks.append(self._on_label_clicked)

    def _process_segmentation_masks(self, masks, shape):
        """Turn SAM mask records into a label image, label info and a layer.

        Parameters
        ----------
        masks : list of dict
            SAM output records; uses each record's ``"segmentation"`` mask and
            optional ``"stability_score"``.
        shape : tuple
            (height, width) of the label image.
        """
        labels = np.zeros(shape, dtype=np.uint32)
        label_info = {}

        # Label IDs start at 1 (0 is background). Later masks overwrite
        # earlier ones where they overlap.
        for label_id, mask_data in enumerate(masks, start=1):
            mask = mask_data["segmentation"]
            labels[mask] = label_id

            y_indices, x_indices = np.nonzero(mask)
            label_info[label_id] = {
                "area": int(y_indices.size),
                "center_y": float(y_indices.mean()) if y_indices.size else 0,
                "center_x": float(x_indices.mean()) if x_indices.size else 0,
                "score": mask_data.get("stability_score", 0),
            }

        # Largest objects first.
        self.label_info = dict(
            sorted(
                label_info.items(),
                key=lambda item: item[1]["area"],
                reverse=True,
            )
        )
        self.segmentation_result = labels
        self._replace_label_layer(labels)

        self.viewer.status = f"Loaded image {self.current_index + 1}/{len(self.images)} - Found {len(masks)} segments"

    # --------------------------------------------------
    # Label Selection and UI Elements
    # --------------------------------------------------

    def _on_label_clicked(self, layer, event):
        """Toggle selection of the label under a mouse press on the label layer."""
        try:
            # Only process presses, not drags/releases.
            if event.type != "mouse_press":
                return
            if self.segmentation_result is None:
                return

            coords = np.round(event.position).astype(int)
            rows, cols = self.segmentation_result.shape[:2]
            y, x = coords[0], coords[1]
            if not (0 <= y < rows and 0 <= x < cols):
                return

            # Plain int so click- and table-selected IDs are the same type.
            label_id = int(self.segmentation_result[y, x])
            if label_id == 0:  # background
                return

            if label_id in self.selected_labels:
                self.selected_labels.remove(label_id)
                self.viewer.status = f"Deselected label ID: {label_id} | Selected labels: {self.selected_labels}"
            else:
                self.selected_labels.add(label_id)
                self.viewer.status = f"Selected label ID: {label_id} | Selected labels: {self.selected_labels}"

            self._update_label_table()
            self.preview_crop()

        except Exception as e:
            self.viewer.status = f"Error selecting label: {str(e)}"

    def create_label_table(self, parent_widget):
        """Create a table widget listing all detected labels.

        Column 0 holds a selection checkbox, column 1 the label ID. The table
        is stored as ``self.label_table_widget`` and returned.
        ``parent_widget`` is accepted for API compatibility but not used.
        """
        table = QTableWidget()
        table.setColumnCount(2)
        table.setHorizontalHeaderLabels(["Select", "Label ID"])

        table.setEditTriggers(QTableWidget.NoEditTriggers)
        table.setSelectionBehavior(QTableWidget.SelectRows)
        # Alternating colors are off to avoid coloring issues.
        table.setAlternatingRowColors(False)

        header = table.horizontalHeader()
        header.setSectionResizeMode(0, QHeaderView.Stretch)
        header.setSectionResizeMode(1, QHeaderView.ResizeToContents)
        header.setMinimumSectionSize(80)

        self._populate_label_table(table)
        self.label_table_widget = table

        # Clicking the table makes the segmentation layer active again.
        table.clicked.connect(lambda: self._ensure_segmentation_layer_active())

        return table

    def _ensure_segmentation_layer_active(self):
        """Ensure the segmentation layer is the active layer."""
        if self.label_layer is not None:
            self.viewer.layers.selection.active = self.label_layer

    def _make_checkbox_callback(self, label_id, checkbox):
        """Return a ``stateChanged`` slot that syncs *label_id* with *checkbox*."""

        def callback(_state):
            # Read isChecked() instead of comparing the int signal payload
            # with Qt.Checked: with enum-based Qt6 bindings (e.g. PyQt6) that
            # comparison is never true.
            if checkbox.isChecked():
                self.selected_labels.add(label_id)
            else:
                self.selected_labels.discard(label_id)
            self.preview_crop()

        return callback

    def _populate_label_table(self, table):
        """Fill *table* with one row per label, largest area first."""
        if not self.label_info:
            table.setRowCount(0)
            return

        table.setRowCount(len(self.label_info))

        sorted_labels = sorted(
            self.label_info.items(),
            key=lambda item: item[1]["area"],
            reverse=True,
        )

        for row, (label_id, _info) in enumerate(sorted_labels):
            # Centered checkbox in column 0.
            checkbox_widget = QWidget()
            checkbox_layout = QHBoxLayout(checkbox_widget)
            checkbox_layout.setContentsMargins(5, 0, 5, 0)
            checkbox_layout.setAlignment(Qt.AlignCenter)

            checkbox = QCheckBox()
            checkbox.setChecked(label_id in self.selected_labels)
            checkbox.stateChanged.connect(
                self._make_checkbox_callback(label_id, checkbox)
            )
            checkbox_layout.addWidget(checkbox)
            table.setCellWidget(row, 0, checkbox_widget)

            # Label ID as plain text with a transparent background.
            item = QTableWidgetItem(str(label_id))
            item.setTextAlignment(Qt.AlignCenter)
            brush = item.background()
            brush.setStyle(Qt.NoBrush)
            item.setBackground(brush)
            table.setItem(row, 1, item)

    def _update_label_table(self):
        """Sync the table's checkboxes with ``self.selected_labels``.

        Callers refresh the preview themselves, so checkbox signals are
        suppressed here.
        """
        table = self.label_table_widget
        if table is None:
            return

        for row in range(table.rowCount()):
            label_id_item = table.item(row, 1)
            checkbox_holder = table.cellWidget(row, 0)
            if label_id_item is None or checkbox_holder is None:
                continue

            checkbox = checkbox_holder.findChild(QCheckBox)
            if checkbox is None:
                continue

            # Block the checkbox itself: blocking the table does not silence
            # child widgets, and each stateChanged would rebuild the preview.
            checkbox.blockSignals(True)
            checkbox.setChecked(
                int(label_id_item.text()) in self.selected_labels
            )
            checkbox.blockSignals(False)

    def select_all_labels(self):
        """Select all labels."""
        if not self.label_info:
            return

        self.selected_labels = set(self.label_info.keys())
        self._update_label_table()
        self.preview_crop()
        self.viewer.status = f"Selected all {len(self.selected_labels)} labels"

    def clear_selection(self):
        """Clear all selected labels."""
        self.selected_labels = set()
        self._update_label_table()
        self.preview_crop()
        self.viewer.status = "Cleared all selections"

    # --------------------------------------------------
    # Image Processing and Export
    # --------------------------------------------------

    def _masked_copy(self, label_ids):
        """Return a copy of the original image with pixels outside *label_ids* set to 0."""
        mask = np.isin(self.segmentation_result, list(label_ids))
        result = self.original_image.copy()
        # A 2-D boolean mask indexes the first two axes, so for (H, W, C)
        # images all channels of an unselected pixel are cleared.
        result[~mask] = 0
        return result

    def _remove_preview_layers(self):
        """Remove every layer whose name contains "Preview"."""
        for layer in list(self.viewer.layers):
            if "Preview" in layer.name:
                self.viewer.layers.remove(layer)

    def preview_crop(self, label_ids=None):
        """Show a preview layer of the image restricted to *label_ids*.

        Parameters
        ----------
        label_ids : iterable of int, optional
            Labels to keep; defaults to the current selection. With no labels
            any existing preview is removed.
        """
        if (
            self.segmentation_result is None
            or self.image_layer is None
            or self.original_image is None
        ):
            self.viewer.status = (
                "No image or segmentation available for preview."
            )
            return

        try:
            if label_ids is None:
                label_ids = self.selected_labels

            # Build the new preview before removing the old one.
            preview_image = self._masked_copy(label_ids) if label_ids else None

            self._remove_preview_layers()

            if preview_image is not None:
                label_str = ", ".join(str(lid) for lid in sorted(label_ids))
                self.viewer.add_image(
                    preview_image,
                    name=f"Preview (Labels: {label_str})",
                    opacity=0.55,
                )

            self._ensure_segmentation_layer_active()

        except Exception as e:
            self.viewer.status = f"Error generating preview: {str(e)}"

    def crop_with_selected_labels(self):
        """Save the current image restricted to the selected labels.

        The output is written next to the input as
        ``<name>_cropped_<id>_<id>...<ext>``. TIFFs are written with zlib
        compression via tifffile; other formats go through
        ``skimage.io.imsave``.

        Returns
        -------
        bool
            True if the file was written, False otherwise.
        """
        if self.segmentation_result is None or self.original_image is None:
            self.viewer.status = (
                "No image or segmentation available for cropping."
            )
            return False

        if not self.selected_labels:
            self.viewer.status = "No labels selected for cropping."
            return False

        try:
            cropped_image = self._masked_copy(self.selected_labels)

            image_path = self.images[self.current_index]
            base_name, ext = os.path.splitext(image_path)
            label_str = "_".join(
                str(lid) for lid in sorted(self.selected_labels)
            )
            output_path = f"{base_name}_cropped_{label_str}{ext}"

            if output_path.lower().endswith((".tif", ".tiff")):
                imwrite(output_path, cropped_image, compression="zlib")
            else:
                from skimage.io import imsave

                imsave(output_path, cropped_image)

            self.viewer.status = f"Saved cropped image to {output_path}"
            self._ensure_segmentation_layer_active()
            return True

        except Exception as e:
            self.viewer.status = f"Error cropping image: {str(e)}"
            return False
792
+
793
+
794
+ # --------------------------------------------------
795
+ # UI Creation Functions
796
+ # --------------------------------------------------
797
+
798
+
799
def create_crop_widget(processor):
    """Create the crop control widget.

    The panel contains:

    - instructions
    - a sensitivity slider with an "Apply" button and a description
    - the label selection table
    - Select All / Clear Selection buttons
    - a crop button
    - Previous/Next navigation
    - a status line

    All controls act on *processor*.

    Parameters
    ----------
    processor : BatchCropAnything
        Processor whose state the controls display and drive.

    Returns
    -------
    QWidget
        The assembled control panel.
    """
    crop_widget = QWidget()
    layout = QVBoxLayout()
    layout.setSpacing(10)  # space between elements
    layout.setContentsMargins(10, 10, 10, 10)  # margins around all elements

    def describe_sensitivity(value):
        """Return the description shown under the sensitivity slider."""
        # Same gamma mapping the processor applies before running SAM.
        gamma = 1.5 - (value / 100) * 1.0
        if value < 25:
            return f"Low sensitivity - Seeks large, distinct objects (γ={gamma:.2f})"
        if value < 75:
            return f"Medium sensitivity - Balanced detection (γ={gamma:.2f})"
        return f"High sensitivity - Detects subtle, small objects (γ={gamma:.2f})"

    def configure_table(table):
        """Apply the panel's size limits to a label table (initial and replacements)."""
        table.setMinimumHeight(150)
        table.setMaximumHeight(300)
        return table

    # Instructions
    instructions_label = QLabel(
        "Select objects to keep in the cropped image.\n"
        "You can select labels using the table below or by clicking directly on objects "
        "in the image (make sure the Segmentation layer is active)."
    )
    instructions_label.setWordWrap(True)
    layout.addWidget(instructions_label)

    # Sensitivity section: header (name + current value), slider + Apply, description
    sensitivity_layout = QVBoxLayout()

    sensitivity_header_layout = QHBoxLayout()
    sensitivity_label = QLabel("Segmentation Sensitivity:")
    sensitivity_value_label = QLabel(f"{processor.sensitivity}")
    sensitivity_header_layout.addWidget(sensitivity_label)
    sensitivity_header_layout.addStretch()
    sensitivity_header_layout.addWidget(sensitivity_value_label)
    sensitivity_layout.addLayout(sensitivity_header_layout)

    slider_layout = QHBoxLayout()
    sensitivity_slider = QSlider(Qt.Horizontal)
    sensitivity_slider.setMinimum(0)
    sensitivity_slider.setMaximum(100)
    sensitivity_slider.setValue(processor.sensitivity)
    sensitivity_slider.setTickPosition(QSlider.TicksBelow)
    sensitivity_slider.setTickInterval(10)
    slider_layout.addWidget(sensitivity_slider)

    apply_sensitivity_button = QPushButton("Apply")
    apply_sensitivity_button.setToolTip(
        "Apply sensitivity changes to regenerate segmentation"
    )
    slider_layout.addWidget(apply_sensitivity_button)
    sensitivity_layout.addLayout(slider_layout)

    # Derived from the processor's actual sensitivity rather than hard-coded.
    sensitivity_description = QLabel(
        describe_sensitivity(processor.sensitivity)
    )
    sensitivity_description.setStyleSheet("font-style: italic; color: #666;")
    sensitivity_layout.addWidget(sensitivity_description)

    layout.addLayout(sensitivity_layout)

    # Label table
    label_table = configure_table(processor.create_label_table(crop_widget))
    layout.addWidget(label_table)

    # Selection helpers
    selection_layout = QHBoxLayout()
    select_all_button = QPushButton("Select All")
    clear_selection_button = QPushButton("Clear Selection")
    selection_layout.addWidget(select_all_button)
    selection_layout.addWidget(clear_selection_button)
    layout.addLayout(selection_layout)

    # Crop button
    crop_button = QPushButton("Crop with Selected Objects")
    layout.addWidget(crop_button)

    # Navigation buttons
    nav_layout = QHBoxLayout()
    prev_button = QPushButton("Previous Image")
    next_button = QPushButton("Next Image")
    nav_layout.addWidget(prev_button)
    nav_layout.addWidget(next_button)
    layout.addLayout(nav_layout)

    # Status label
    status_label = QLabel(
        "Ready to process images. Select objects using the table or by clicking on them."
    )
    status_label.setWordWrap(True)
    layout.addWidget(status_label)

    crop_widget.setLayout(layout)

    def replace_table_widget():
        """Swap the label table for a freshly built one in the same layout slot."""
        nonlocal label_table
        # Remember the slot: after removal the following items shift up, so
        # a fixed insert index would put the table below the selection buttons.
        index = layout.indexOf(label_table)
        layout.removeWidget(label_table)
        label_table.setParent(None)
        label_table.deleteLater()

        label_table = configure_table(processor.create_label_table(crop_widget))
        layout.insertWidget(index, label_table)
        return label_table

    def sync_nav_buttons():
        """Enable Previous/Next only where the processor can actually move."""
        prev_button.setEnabled(processor.current_index > 0)
        next_button.setEnabled(
            processor.current_index < len(processor.images) - 1
        )

    def after_navigation():
        """Refresh table, slider and status after moving to another image."""
        replace_table_widget()
        # Sync the slider with the last *applied* sensitivity (discarding
        # unapplied slider moves); the processor keeps it across images.
        sensitivity_slider.setValue(processor.sensitivity)
        sensitivity_value_label.setText(f"{processor.sensitivity}")
        status_label.setText(
            f"Showing image {processor.current_index + 1}/{len(processor.images)}"
        )

    def on_sensitivity_changed(value):
        sensitivity_value_label.setText(f"{value}")
        sensitivity_description.setText(describe_sensitivity(value))

    def on_apply_sensitivity_clicked():
        new_sensitivity = sensitivity_slider.value()
        processor.generate_segmentation_with_sensitivity(new_sensitivity)
        replace_table_widget()
        status_label.setText(
            f"Regenerated segmentation with sensitivity {new_sensitivity}"
        )

    def on_select_all_clicked():
        processor.select_all_labels()
        status_label.setText(
            f"Selected all {len(processor.selected_labels)} objects"
        )

    def on_clear_selection_clicked():
        processor.clear_selection()
        status_label.setText("Selection cleared")

    def on_crop_clicked():
        if processor.crop_with_selected_labels():
            labels_str = ", ".join(
                str(label) for label in sorted(processor.selected_labels)
            )
            status_label.setText(
                f"Cropped image with {len(processor.selected_labels)} objects (IDs: {labels_str})"
            )
        else:
            status_label.setText(
                "Nothing was cropped - see the viewer status bar for details"
            )

    def on_next_clicked():
        moved = processor.next_image()
        sync_nav_buttons()
        if moved:
            after_navigation()

    def on_prev_clicked():
        moved = processor.previous_image()
        sync_nav_buttons()
        if moved:
            after_navigation()

    sensitivity_slider.valueChanged.connect(on_sensitivity_changed)
    apply_sensitivity_button.clicked.connect(on_apply_sensitivity_clicked)
    select_all_button.clicked.connect(on_select_all_clicked)
    clear_selection_button.clicked.connect(on_clear_selection_clicked)
    crop_button.clicked.connect(on_crop_clicked)
    next_button.clicked.connect(on_next_clicked)
    prev_button.clicked.connect(on_prev_clicked)

    return crop_widget
991
+
992
+
993
+ # --------------------------------------------------
994
+ # Napari Plugin Functions
995
+ # --------------------------------------------------
996
+
997
+
998
def _locate_mobile_sam_package():
    """Return the directory of the installed ``mobile_sam`` package.

    Returns:
        The package directory, or ``None`` if the package is importable but
        its on-disk location cannot be determined (e.g. a namespace package
        without a concrete ``origin``).

    Raises:
        ImportError: if ``mobile_sam`` is not installed.
    """
    import importlib.util

    spec = importlib.util.find_spec("mobile_sam")
    if spec is None:
        raise ImportError("mobile_sam package not found")
    # ``origin`` is None (or a marker string) for namespace packages, so only
    # trust it when it points at a real file.
    if spec.origin and os.path.isfile(spec.origin):
        return os.path.dirname(spec.origin)
    search_locations = list(spec.submodule_search_locations or [])
    return search_locations[0] if search_locations else None


def _mobile_sam_checkpoint_candidates(package_dir):
    """Return the paths searched for ``mobile_sam.pt``, in lookup order.

    Args:
        package_dir: Directory of the ``mobile_sam`` package, or ``None`` to
            skip the package-relative locations.
    """
    candidates = []
    if package_dir:
        candidates.extend(
            [
                os.path.join(package_dir, "weights", "mobile_sam.pt"),
                os.path.join(package_dir, "mobile_sam.pt"),
                os.path.join(
                    os.path.dirname(package_dir), "weights", "mobile_sam.pt"
                ),
            ]
        )
    candidates.extend(
        [
            os.path.join(os.path.expanduser("~"), "models", "mobile_sam.pt"),
            "/opt/T-MIDAS/models/mobile_sam.pt",
            os.path.join(os.getcwd(), "mobile_sam.pt"),
        ]
    )
    return candidates


@magicgui(
    call_button="Start Batch Crop Anything",
    folder_path={"label": "Folder Path", "widget_type": "LineEdit"},
)
def batch_crop_anything(
    folder_path: str,
    viewer: Viewer = None,
):
    """MagicGUI widget for starting Batch Crop Anything.

    Steps:
      1. Abort with an error dialog if the ``mobile_sam`` package is missing.
      2. Abort with a warning if ``folder_path`` is not an existing directory.
      3. Warn (non-fatally) if the ``mobile_sam.pt`` weights are not found in
         any of the known locations.
      4. Load the folder's images and dock the crop controls in ``viewer``.

    Args:
        folder_path: Folder containing the images to process.
        viewer: The napari viewer the controls are docked into.
    """
    # A missing package is fatal: show install instructions and stop.
    try:
        package_dir = _locate_mobile_sam_package()
    except ImportError:
        QMessageBox.critical(
            None,
            "Missing Dependency",
            "Mobile-SAM not found. Please install with:\n"
            "pip install git+https://github.com/ChaoningZhang/MobileSAM.git\n\n"
            "You'll also need to download the model weights file (mobile_sam.pt) from:\n"
            "https://github.com/ChaoningZhang/MobileSAM/tree/master/weights",
        )
        return

    if not folder_path or not os.path.isdir(folder_path):
        QMessageBox.warning(
            None,
            "Invalid Folder",
            f"Please select an existing folder (got: {folder_path!r}).",
        )
        return

    # Missing weights are only a warning; per the message, the user is
    # prompted to locate the file when the tool starts.
    checkpoint_found = any(
        os.path.exists(path)
        for path in _mobile_sam_checkpoint_candidates(package_dir)
    )
    if not checkpoint_found:
        QMessageBox.warning(
            None,
            "Model File Missing",
            "Mobile-SAM model weights (mobile_sam.pt) not found. You'll be prompted to locate it when starting the tool.\n\n"
            "You can download it from: https://github.com/ChaoningZhang/MobileSAM/tree/master/weights",
        )

    # Initialize processor and load images
    processor = BatchCropAnything(viewer)
    processor.load_images(folder_path)

    crop_widget = create_crop_widget(processor)

    # Wrap the controls in a scroll area so they stay usable in small docks.
    scroll_area = QScrollArea()
    scroll_area.setWidget(crop_widget)
    scroll_area.setWidgetResizable(True)  # let the controls resize with the dock
    scroll_area.setFrameShape(QScrollArea.NoFrame)  # hide the frame
    scroll_area.setMinimumHeight(500)  # keep the controls visible

    viewer.window.add_dock_widget(scroll_area, name="Crop Controls")
1086
+
1087
+
1088
def batch_crop_anything_widget():
    """Provide the batch crop anything widget to Napari.

    Returns the module-level ``batch_crop_anything`` magicgui widget, with a
    "Browse..." button attached next to its folder-path field. Because the
    same widget object is returned on every call, the button is attached only
    once; later calls do not add duplicates.
    """
    widget = batch_crop_anything
    button_name = "batchCropAnythingBrowseButton"

    folder_parent = widget.folder_path.native.parent()
    folder_layout = (
        folder_parent.layout() if folder_parent is not None else None
    )
    if folder_layout is None:
        # Nowhere to place the button; the path can still be typed manually.
        return widget
    if folder_parent.findChild(QPushButton, button_name) is not None:
        # A previous call already attached the button.
        return widget

    folder_browse_button = QPushButton("Browse...")
    folder_browse_button.setObjectName(button_name)

    def on_folder_browse_clicked():
        # Start from the currently entered folder when it exists.
        current = widget.folder_path.value
        start_dir = (
            current
            if current and os.path.isdir(current)
            else os.path.expanduser("~")
        )
        folder = QFileDialog.getExistingDirectory(
            None,
            "Select Folder",
            start_dir,
            QFileDialog.ShowDirsOnly | QFileDialog.DontResolveSymlinks,
        )
        if folder:
            widget.folder_path.value = folder

    folder_browse_button.clicked.connect(on_folder_browse_clicked)
    folder_layout.addWidget(folder_browse_button)

    return widget