lazylabel-gui 1.3.3__py3-none-any.whl → 1.3.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lazylabel/core/file_manager.py +1 -1
- lazylabel/models/sam2_model.py +253 -134
- lazylabel/ui/control_panel.py +7 -2
- lazylabel/ui/main_window.py +264 -593
- lazylabel/ui/photo_viewer.py +35 -11
- lazylabel/ui/widgets/channel_threshold_widget.py +8 -9
- lazylabel/ui/widgets/fft_threshold_widget.py +4 -0
- lazylabel/ui/widgets/model_selection_widget.py +9 -0
- lazylabel/ui/workers/__init__.py +15 -0
- lazylabel/ui/workers/image_discovery_worker.py +66 -0
- lazylabel/ui/workers/multi_view_sam_init_worker.py +135 -0
- lazylabel/ui/workers/multi_view_sam_update_worker.py +158 -0
- lazylabel/ui/workers/sam_update_worker.py +129 -0
- lazylabel/ui/workers/single_view_sam_init_worker.py +61 -0
- lazylabel/utils/fast_file_manager.py +422 -78
- {lazylabel_gui-1.3.3.dist-info → lazylabel_gui-1.3.5.dist-info}/METADATA +1 -1
- {lazylabel_gui-1.3.3.dist-info → lazylabel_gui-1.3.5.dist-info}/RECORD +21 -15
- {lazylabel_gui-1.3.3.dist-info → lazylabel_gui-1.3.5.dist-info}/WHEEL +0 -0
- {lazylabel_gui-1.3.3.dist-info → lazylabel_gui-1.3.5.dist-info}/entry_points.txt +0 -0
- {lazylabel_gui-1.3.3.dist-info → lazylabel_gui-1.3.5.dist-info}/licenses/LICENSE +0 -0
- {lazylabel_gui-1.3.3.dist-info → lazylabel_gui-1.3.5.dist-info}/top_level.txt +0 -0
lazylabel/ui/photo_viewer.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
import cv2
|
2
2
|
import numpy as np
|
3
3
|
from PyQt6.QtCore import QRectF, Qt, pyqtSignal
|
4
|
-
from PyQt6.QtGui import QCursor, QImage, QPixmap
|
4
|
+
from PyQt6.QtGui import QCursor, QImage, QPainter, QPixmap
|
5
5
|
from PyQt6.QtWidgets import QGraphicsPixmapItem, QGraphicsScene, QGraphicsView
|
6
6
|
|
7
7
|
|
@@ -17,6 +17,14 @@ class PhotoViewer(QGraphicsView):
|
|
17
17
|
self._scene.addItem(self._pixmap_item)
|
18
18
|
self.setScene(self._scene)
|
19
19
|
|
20
|
+
# Enable antialiasing and smooth pixmap scaling for rendering
|
21
|
+
self.setRenderHint(QPainter.RenderHint.Antialiasing)
|
22
|
+
self.setRenderHint(QPainter.RenderHint.SmoothPixmapTransform)
|
23
|
+
|
24
|
+
# Ensure viewport supports transparency
|
25
|
+
self.viewport().setAttribute(Qt.WidgetAttribute.WA_TranslucentBackground, True)
|
26
|
+
self.setStyleSheet("background: transparent;")
|
27
|
+
|
20
28
|
self.setTransformationAnchor(QGraphicsView.ViewportAnchor.AnchorUnderMouse)
|
21
29
|
self.setResizeAnchor(QGraphicsView.ViewportAnchor.AnchorViewCenter)
|
22
30
|
self.setVerticalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAlwaysOff)
|
@@ -26,7 +34,7 @@ class PhotoViewer(QGraphicsView):
|
|
26
34
|
|
27
35
|
self._original_image = None
|
28
36
|
self._adjusted_pixmap = None
|
29
|
-
self.
|
37
|
+
self._original_image_bgra = None
|
30
38
|
|
31
39
|
def fitInView(self, scale=True):
|
32
40
|
rect = QRectF(self._pixmap_item.pixmap().rect())
|
@@ -50,6 +58,9 @@ class PhotoViewer(QGraphicsView):
|
|
50
58
|
if self._pixmap_item not in self._scene.items():
|
51
59
|
self._pixmap_item = QGraphicsPixmapItem()
|
52
60
|
self._scene.addItem(self._pixmap_item)
|
61
|
+
|
62
|
+
# PNG files are now loaded with proper alpha format at source
|
63
|
+
|
53
64
|
self._pixmap_item.setPixmap(pixmap)
|
54
65
|
|
55
66
|
# Convert QImage to ARGB32 for consistent processing
|
@@ -62,14 +73,16 @@ class PhotoViewer(QGraphicsView):
|
|
62
73
|
img_np = np.array(ptr).reshape(
|
63
74
|
converted_image.height(), converted_image.width(), 4
|
64
75
|
)
|
65
|
-
#
|
66
|
-
self.
|
76
|
+
# QImage ARGB32 is stored as BGRA in memory, keep this format
|
77
|
+
self._original_image_bgra = (
|
78
|
+
img_np.copy()
|
79
|
+
) # Make a copy to avoid memory issues
|
67
80
|
|
68
81
|
self.fitInView()
|
69
82
|
else:
|
70
83
|
self._original_image = None
|
71
84
|
self._adjusted_pixmap = None
|
72
|
-
self.
|
85
|
+
self._original_image_bgra = None
|
73
86
|
# Check if _pixmap_item still exists, recreate if deleted
|
74
87
|
if self._pixmap_item not in self._scene.items():
|
75
88
|
self._pixmap_item = QGraphicsPixmapItem()
|
@@ -77,7 +90,7 @@ class PhotoViewer(QGraphicsView):
|
|
77
90
|
self._pixmap_item.setPixmap(QPixmap())
|
78
91
|
|
79
92
|
def set_image_adjustments(self, brightness: float, contrast: float, gamma: float):
|
80
|
-
if self.
|
93
|
+
if self._original_image_bgra is None or self._original_image is None:
|
81
94
|
return
|
82
95
|
|
83
96
|
# Ensure _pixmap_item exists and is valid
|
@@ -85,9 +98,13 @@ class PhotoViewer(QGraphicsView):
|
|
85
98
|
self._pixmap_item = QGraphicsPixmapItem()
|
86
99
|
self._scene.addItem(self._pixmap_item)
|
87
100
|
|
88
|
-
|
101
|
+
img_bgra = self._original_image_bgra.copy()
|
89
102
|
|
90
|
-
#
|
103
|
+
# Separate alpha channel for transparency preservation
|
104
|
+
alpha_channel = img_bgra[:, :, 3:4] # Keep dimensions
|
105
|
+
img_bgr = img_bgra[:, :, :3] # RGB channels only
|
106
|
+
|
107
|
+
# Apply brightness and contrast to RGB channels only
|
91
108
|
# new_image = alpha * old_image + beta
|
92
109
|
adjusted_img = cv2.convertScaleAbs(
|
93
110
|
img_bgr, alpha=1 + contrast / 100.0, beta=brightness
|
@@ -101,11 +118,18 @@ class PhotoViewer(QGraphicsView):
|
|
101
118
|
).astype("uint8")
|
102
119
|
adjusted_img = cv2.LUT(adjusted_img, table)
|
103
120
|
|
104
|
-
#
|
105
|
-
|
121
|
+
# Recombine with alpha channel to preserve transparency
|
122
|
+
adjusted_bgra = np.concatenate([adjusted_img, alpha_channel], axis=2)
|
123
|
+
|
124
|
+
# Convert back to QImage with alpha support
|
125
|
+
h, w, ch = adjusted_bgra.shape
|
106
126
|
bytes_per_line = ch * w
|
107
127
|
adjusted_qimage = QImage(
|
108
|
-
|
128
|
+
adjusted_bgra.data, w, h, bytes_per_line, QImage.Format.Format_ARGB32
|
129
|
+
)
|
130
|
+
# Convert to premultiplied alpha for proper blending
|
131
|
+
adjusted_qimage = adjusted_qimage.convertToFormat(
|
132
|
+
QImage.Format.Format_ARGB32_Premultiplied
|
109
133
|
)
|
110
134
|
self._adjusted_pixmap = QPixmap.fromImage(adjusted_qimage)
|
111
135
|
|
@@ -270,10 +270,6 @@ class MultiIndicatorSlider(QWidget):
|
|
270
270
|
"""Handle right-click to remove indicator."""
|
271
271
|
slider_rect = self.get_slider_rect()
|
272
272
|
|
273
|
-
# Only allow removal if more than 1 indicator
|
274
|
-
if len(self.indicators) <= 1:
|
275
|
-
return
|
276
|
-
|
277
273
|
# Check if right-clicking on an indicator
|
278
274
|
for i, value in enumerate(self.indicators):
|
279
275
|
x = self.value_to_x(value)
|
@@ -473,7 +469,7 @@ class ChannelThresholdWidget(QWidget):
|
|
473
469
|
|
474
470
|
if self.current_image_channels == 1:
|
475
471
|
# Grayscale image
|
476
|
-
if "Gray" in self.sliders:
|
472
|
+
if "Gray" in self.sliders and self.sliders["Gray"].is_enabled():
|
477
473
|
result = self._apply_channel_thresholding(
|
478
474
|
result, self.sliders["Gray"].get_indicators()
|
479
475
|
)
|
@@ -481,7 +477,10 @@ class ChannelThresholdWidget(QWidget):
|
|
481
477
|
# RGB image
|
482
478
|
channel_names = ["Red", "Green", "Blue"]
|
483
479
|
for i, channel_name in enumerate(channel_names):
|
484
|
-
if
|
480
|
+
if (
|
481
|
+
channel_name in self.sliders
|
482
|
+
and self.sliders[channel_name].is_enabled()
|
483
|
+
):
|
485
484
|
result[:, :, i] = self._apply_channel_thresholding(
|
486
485
|
result[:, :, i], self.sliders[channel_name].get_indicators()
|
487
486
|
)
|
@@ -494,7 +493,7 @@ class ChannelThresholdWidget(QWidget):
|
|
494
493
|
if not indicators:
|
495
494
|
return channel_data
|
496
495
|
|
497
|
-
|
496
|
+
# Sort indicators
|
498
497
|
sorted_indicators = sorted(indicators)
|
499
498
|
|
500
499
|
# Create output array
|
@@ -527,9 +526,9 @@ class ChannelThresholdWidget(QWidget):
|
|
527
526
|
return result
|
528
527
|
|
529
528
|
def has_active_thresholding(self):
    """Report whether at least one enabled channel slider carries indicators.

    Returns:
        bool: True as soon as a slider that is enabled also has a non-empty
        indicator list; False when no slider satisfies both conditions.
    """
    return any(
        slider.is_enabled() and slider.get_indicators()
        for slider in self.sliders.values()
    )
|
535
534
|
|
@@ -308,6 +308,8 @@ class FFTThresholdWidget(QWidget):
|
|
308
308
|
self.status_label.setStyleSheet(
|
309
309
|
"color: #F44336; font-size: 9px; font-style: italic;"
|
310
310
|
)
|
311
|
+
# Disable FFT processing for color images
|
312
|
+
self.enable_checkbox.setChecked(False)
|
311
313
|
else:
|
312
314
|
# Unknown format
|
313
315
|
self.current_image_channels = 0
|
@@ -315,6 +317,8 @@ class FFTThresholdWidget(QWidget):
|
|
315
317
|
self.status_label.setStyleSheet(
|
316
318
|
"color: #F44336; font-size: 9px; font-style: italic;"
|
317
319
|
)
|
320
|
+
# Disable FFT processing for unsupported formats
|
321
|
+
self.enable_checkbox.setChecked(False)
|
318
322
|
|
319
323
|
def is_active(self):
|
320
324
|
"""Check if FFT processing is active (checkbox enabled and image is grayscale)."""
|
@@ -111,6 +111,15 @@ class CustomDropdown(QToolButton):
|
|
111
111
|
text, _ = self.items[index]
|
112
112
|
self.setText(text)
|
113
113
|
|
114
|
+
def count(self):
    """Return how many entries the dropdown currently holds."""
    entries = self.items
    return len(entries)
|
117
|
+
|
118
|
+
def currentData(self):
    """Return the data payload of the entry that is selected right now.

    Looks up the current index and delegates to ``itemData`` for it.
    """
    return self.itemData(self.currentIndex())
|
122
|
+
|
114
123
|
def blockSignals(self, block):
|
115
124
|
"""Block/unblock signals."""
|
116
125
|
super().blockSignals(block)
|
@@ -0,0 +1,15 @@
|
|
1
|
+
"""Worker thread classes for background operations."""
|
2
|
+
|
3
|
+
from .image_discovery_worker import ImageDiscoveryWorker
|
4
|
+
from .multi_view_sam_init_worker import MultiViewSAMInitWorker
|
5
|
+
from .multi_view_sam_update_worker import MultiViewSAMUpdateWorker
|
6
|
+
from .sam_update_worker import SAMUpdateWorker
|
7
|
+
from .single_view_sam_init_worker import SingleViewSAMInitWorker
|
8
|
+
|
9
|
+
__all__ = [
|
10
|
+
"ImageDiscoveryWorker",
|
11
|
+
"MultiViewSAMInitWorker",
|
12
|
+
"MultiViewSAMUpdateWorker",
|
13
|
+
"SAMUpdateWorker",
|
14
|
+
"SingleViewSAMInitWorker",
|
15
|
+
]
|
@@ -0,0 +1,66 @@
|
|
1
|
+
"""Worker thread for discovering image files in background."""
|
2
|
+
|
3
|
+
from PyQt6.QtCore import QThread, pyqtSignal
|
4
|
+
|
5
|
+
|
6
|
+
class ImageDiscoveryWorker(QThread):
    """Background thread that walks a file model and collects image paths.

    Signals:
        images_discovered(list): sorted list of every image path found
            (an empty list when the model has no root path).
        progress(str): human-readable status message.
        error(str): description of an exception raised during the scan.
    """

    images_discovered = pyqtSignal(list)  # List of all image file paths
    progress = pyqtSignal(str)  # Progress message
    error = pyqtSignal(str)

    def __init__(self, file_model, file_manager, parent=None):
        """Store the model to walk and the manager used to classify files.

        Args:
            file_model: object offering ``rootPath``, ``index``, ``rowCount``,
                ``isDir`` and ``filePath`` (presumably a QFileSystemModel --
                confirm against the caller).
            file_manager: object offering ``is_image_file(path)``.
            parent: optional QObject parent forwarded to ``QThread``.
        """
        super().__init__(parent)
        self.file_model = file_model
        self.file_manager = file_manager
        self._should_stop = False

    def stop(self):
        """Ask the scan to end at its next cancellation checkpoint."""
        self._should_stop = True

    def _walk(self, parent_index, found):
        """Depth-first visit of ``parent_index``, appending image paths to ``found``."""
        if self._should_stop:
            return

        model = self.file_model
        for row in range(model.rowCount(parent_index)):
            if self._should_stop:
                return

            child = model.index(row, 0, parent_index)
            if model.isDir(child):
                # Descend into the sub-directory before moving to the next row.
                self._walk(child, found)
                continue

            candidate = model.filePath(child)
            if self.file_manager.is_image_file(candidate):
                # Every image path is kept; no NPZ companion check is done.
                found.append(candidate)

    def run(self):
        """Scan the model tree and emit the sorted list of image paths."""
        try:
            if self._should_stop:
                return

            self.progress.emit("Scanning for images...")

            model = self.file_model
            has_root = hasattr(model, "rootPath") and bool(model.rootPath())
            if not has_root:
                self.images_discovered.emit([])
                return

            found = []
            self._walk(model.index(model.rootPath()), found)

            if self._should_stop:
                return
            self.progress.emit(f"Found {len(found)} images")
            self.images_discovered.emit(sorted(found))

        except Exception as e:
            # Errors are reported through the signal unless a stop was requested.
            if not self._should_stop:
                self.error.emit(f"Error discovering images: {str(e)}")
|
@@ -0,0 +1,135 @@
|
|
1
|
+
"""Worker thread for initializing multi-view SAM models in background."""
|
2
|
+
|
3
|
+
from PyQt6.QtCore import QThread, pyqtSignal
|
4
|
+
|
5
|
+
from ...utils.logger import logger
|
6
|
+
|
7
|
+
|
8
|
+
class MultiViewSAMInitWorker(QThread):
    """Background thread that creates one SAM model instance per viewer.

    The worker reads the model selection and the multi-view configuration
    from its parent object, so it must be constructed with a parent that
    provides ``_get_multi_view_config()``.

    Signals:
        model_initialized(int, object): viewer index and the loaded model.
        all_models_initialized(int): number of models created in total.
        error(str): description of the failure that aborted initialization.
        progress(int, int): 1-based index of the model being created, total.
    """

    model_initialized = pyqtSignal(int, object)  # viewer_index, model_instance
    all_models_initialized = pyqtSignal(int)  # total_models_count
    error = pyqtSignal(str)
    progress = pyqtSignal(int, int)  # current, total

    def __init__(self, model_manager, parent=None):
        """Store the model manager used to detect the type of custom models.

        Args:
            model_manager: object offering ``detect_model_type(path)``.
            parent: QObject parent; ``run`` queries it for the selected model
                and the multi-view configuration.
        """
        super().__init__(parent)
        self.model_manager = model_manager
        self._should_stop = False
        self.models_created = []

    def stop(self):
        """Request the worker to stop at its next cancellation checkpoint."""
        self._should_stop = True

    def _resolve_model_selection(self, parent):
        """Return ``(custom_model_path, model_type)`` chosen in the parent GUI.

        Falls back to ``(None, "vit_h")`` when the parent exposes no control
        panel.
        """
        if not hasattr(parent, "control_panel"):
            return None, "vit_h"

        model_path = parent.control_panel.model_widget.get_selected_model_path()
        if model_path:
            # Custom model: its type is derived from the file name.
            return model_path, self.model_manager.detect_model_type(model_path)

        if hasattr(parent, "settings"):
            return None, parent.settings.default_model_type
        return None, "vit_h"

    @staticmethod
    def _build_model(sam_cls, sam2_cls, use_sam2, model_type, custom_model_path):
        """Instantiate a SAM2 or SAM1 model for the given selection."""
        if use_sam2:
            if custom_model_path:
                return sam2_cls(custom_model_path)
            return sam2_cls(model_type=model_type)

        if custom_model_path:
            return sam_cls(
                model_type=model_type,
                custom_model_path=custom_model_path,
            )
        return sam_cls(model_type=model_type)

    @staticmethod
    def _sync_and_clear_gpu_cache():
        """Synchronize CUDA and drop cached blocks; no-op without PyTorch/CUDA."""
        try:
            import torch
        except ImportError:
            return  # PyTorch not available

        if torch.cuda.is_available():
            torch.cuda.synchronize()  # Ensure model is fully loaded
            torch.cuda.empty_cache()

    def run(self):
        """Initialize multi-view SAM models in the background thread."""
        try:
            if self._should_stop:
                return

            # Imported lazily, inside the thread, as in the original code.
            from ...models.sam_model import SamModel

            try:
                from ...models.sam2_model import Sam2Model
            except ImportError:
                Sam2Model = None

            parent = self.parent()
            if parent is None or not hasattr(parent, "_get_multi_view_config"):
                # Without this guard the failure surfaced as an opaque
                # "'NoneType' object has no attribute ..." message.
                if not self._should_stop:
                    self.error.emit(
                        "MultiViewSAMInitWorker requires a parent that provides "
                        "_get_multi_view_config()"
                    )
                return

            custom_model_path, model_type = self._resolve_model_selection(parent)

            wants_sam2 = model_type.startswith("sam2")
            use_sam2 = wants_sam2 and Sam2Model is not None
            if wants_sam2 and not use_sam2:
                # Behavior is unchanged (SAM1 class is tried), but the
                # fallback is no longer silent.
                logger.warning(
                    f"Model type {model_type} requested but SAM2 is not available; "
                    "falling back to the SAM1 model class."
                )

            config = parent._get_multi_view_config()
            num_viewers = config["num_viewers"]

            # Warn about performance implications for VIT_H in multi-view
            if num_viewers > 2 and model_type == "vit_h":
                logger.warning(
                    f"Using vit_h model with {num_viewers} viewers may cause performance issues. Consider using vit_b for better performance."
                )

            for i in range(num_viewers):
                if self._should_stop:
                    return

                self.progress.emit(i + 1, num_viewers)

                try:
                    model_instance = self._build_model(
                        SamModel, Sam2Model, use_sam2, model_type, custom_model_path
                    )

                    if self._should_stop:
                        return

                    if not (
                        model_instance and getattr(model_instance, "is_loaded", False)
                    ):
                        raise RuntimeError(f"Model instance {i + 1} failed to load")

                    self.models_created.append(model_instance)
                    self.model_initialized.emit(i, model_instance)

                    # Synchronize and clear GPU cache after each model for stability
                    self._sync_and_clear_gpu_cache()

                except Exception as model_error:
                    logger.error(
                        f"Error creating model instance {i + 1}: {model_error}"
                    )
                    if not self._should_stop:
                        self.error.emit(
                            f"Failed to create model instance {i + 1}: {model_error}"
                        )
                    # One failed model aborts the whole initialization.
                    return

            if not self._should_stop:
                self.all_models_initialized.emit(len(self.models_created))

        except Exception as e:
            if not self._should_stop:
                self.error.emit(str(e))
|
@@ -0,0 +1,158 @@
|
|
1
|
+
"""Worker thread for updating SAM model image in multi-view mode."""
|
2
|
+
|
3
|
+
import cv2
|
4
|
+
import numpy as np
|
5
|
+
from PyQt6.QtCore import Qt, QThread, pyqtSignal
|
6
|
+
from PyQt6.QtGui import QPixmap
|
7
|
+
|
8
|
+
|
9
|
+
class MultiViewSAMUpdateWorker(QThread):
    """Background thread that feeds one viewer's image to its SAM model.

    Images whose width or height exceeds ``MAX_IMAGE_SIZE`` are scaled down
    (aspect ratio preserved) before being handed to the model; the factor is
    available from ``get_scale_factor()`` afterwards.

    Signals:
        finished(int): viewer index, emitted after the model image was set.
        error(int, str): viewer index and error message.

    NOTE(review): ``finished`` reuses the name of QThread's built-in
    ``finished`` signal; kept as-is because callers connect to it.
    """

    finished = pyqtSignal(int)  # viewer_index
    error = pyqtSignal(int, str)  # viewer_index, error_message

    # Largest edge length (pixels) passed to the model without downscaling.
    MAX_IMAGE_SIZE = 1024

    def __init__(
        self,
        viewer_index,
        model,
        image_path,
        operate_on_view=False,
        current_image=None,
        parent=None,
    ):
        """Capture everything the update needs.

        Args:
            viewer_index: index of the viewer this update belongs to.
            model: object offering ``set_image_from_array`` and
                ``set_image_from_path``.
            image_path: path of the image to load when no array is used.
            operate_on_view: when True and ``current_image`` is given, the
                array is used instead of loading ``image_path``.
            current_image: image array (indexed as ``shape[:2]`` = height,
                width); channel order is whatever the caller supplies.
            parent: optional QObject parent forwarded to ``QThread``.
        """
        super().__init__(parent)
        self.viewer_index = viewer_index
        self.model = model
        self.image_path = image_path
        self.operate_on_view = operate_on_view
        self.current_image = current_image
        self._should_stop = False
        self.scale_factor = 1.0

    def stop(self):
        """Request the worker to stop at its next cancellation checkpoint."""
        self._should_stop = True

    def get_scale_factor(self):
        """Get the scale factor used for image resizing."""
        return self.scale_factor

    @staticmethod
    def _cuda_call(function_name):
        """Call ``torch.cuda.<function_name>()`` if PyTorch and CUDA exist."""
        try:
            import torch
        except ImportError:
            return  # PyTorch not available

        if torch.cuda.is_available():
            getattr(torch.cuda, function_name)()

    @staticmethod
    def _qimage_to_rgb_array(qimage):
        """Copy a QImage into a contiguous ``(height, width, 3)`` RGB array.

        The image is first converted to ``Format_RGBA8888`` so that the bytes
        are really ordered R, G, B, A and every pixel occupies 4 bytes. Qt's
        native 32-bit formats are laid out B, G, R, A in memory on
        little-endian machines, so slicing them directly yields BGR.
        """
        from PyQt6.QtGui import QImage

        rgba_image = qimage.convertToFormat(QImage.Format.Format_RGBA8888)
        width = rgba_image.width()
        height = rgba_image.height()
        ptr = rgba_image.bits()
        ptr.setsize(height * width * 4)
        # np.array copies, so the result stays valid after rgba_image is freed.
        rgba = np.array(ptr).reshape(height, width, 4)
        return np.ascontiguousarray(rgba[:, :, :3])

    def _update_from_array(self):
        """Set the model image from ``current_image``; return False if stopped."""
        if self._should_stop:
            return False

        image = self.current_image
        original_height, original_width = image.shape[:2]
        max_size = self.MAX_IMAGE_SIZE

        if original_height > max_size or original_width > max_size:
            self.scale_factor = min(
                max_size / original_width, max_size / original_height
            )
            # Clamp to 1 px so extreme aspect ratios cannot produce a
            # zero-sized target, which cv2.resize rejects.
            new_width = max(1, int(original_width * self.scale_factor))
            new_height = max(1, int(original_height * self.scale_factor))

            # Resize using OpenCV for speed
            image = cv2.resize(
                image, (new_width, new_height), interpolation=cv2.INTER_AREA
            )
        else:
            self.scale_factor = 1.0

        if self._should_stop:
            return False

        self.model.set_image_from_array(image)
        return True

    def _update_from_path(self):
        """Set the model image from ``image_path``; return False if not done."""
        # QImage is used instead of QPixmap because Qt only supports QPixmap
        # in the GUI thread, and this code runs in a worker thread.
        from PyQt6.QtGui import QImage

        if self._should_stop:
            return False

        source = QImage(self.image_path)
        if source.isNull():
            if not self._should_stop:
                self.error.emit(self.viewer_index, "Failed to load image")
            return False

        original_width = source.width()
        original_height = source.height()
        max_size = self.MAX_IMAGE_SIZE

        if original_width <= max_size and original_height <= max_size:
            # Small enough: let the model load the file itself.
            self.scale_factor = 1.0
            self._cuda_call("synchronize")
            self.model.set_image_from_path(self.image_path)
            return True

        self.scale_factor = min(max_size / original_width, max_size / original_height)

        # Scale down while maintaining aspect ratio
        scaled = source.scaled(
            max_size,
            max_size,
            Qt.AspectRatioMode.KeepAspectRatio,
            Qt.TransformationMode.SmoothTransformation,
        )
        image_array = self._qimage_to_rgb_array(scaled)

        if self._should_stop:
            return False

        # CUDA synchronization for multi-model scenarios
        self._cuda_call("synchronize")
        self.model.set_image_from_array(image_array)
        return True

    def run(self):
        """Update the SAM model image in the background thread."""
        try:
            if self._should_stop:
                return

            # Clear GPU cache to reduce memory pressure in multi-view mode
            self._cuda_call("empty_cache")

            if self.operate_on_view and self.current_image is not None:
                completed = self._update_from_array()
            else:
                completed = self._update_from_path()

            if completed and not self._should_stop:
                self.finished.emit(self.viewer_index)

        except Exception as e:
            if not self._should_stop:
                self.error.emit(self.viewer_index, str(e))
|