lazylabel_gui-1.3.4-py3-none-any.whl → lazylabel_gui-1.3.6-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lazylabel/models/sam2_model.py +253 -134
- lazylabel/ui/main_window.py +100 -530
- lazylabel/ui/photo_viewer.py +35 -11
- lazylabel/ui/widgets/channel_threshold_widget.py +18 -4
- lazylabel/ui/workers/__init__.py +15 -0
- lazylabel/ui/workers/image_discovery_worker.py +66 -0
- lazylabel/ui/workers/multi_view_sam_init_worker.py +135 -0
- lazylabel/ui/workers/multi_view_sam_update_worker.py +158 -0
- lazylabel/ui/workers/sam_update_worker.py +129 -0
- lazylabel/ui/workers/single_view_sam_init_worker.py +61 -0
- {lazylabel_gui-1.3.4.dist-info → lazylabel_gui-1.3.6.dist-info}/METADATA +52 -49
- {lazylabel_gui-1.3.4.dist-info → lazylabel_gui-1.3.6.dist-info}/RECORD +16 -10
- {lazylabel_gui-1.3.4.dist-info → lazylabel_gui-1.3.6.dist-info}/WHEEL +0 -0
- {lazylabel_gui-1.3.4.dist-info → lazylabel_gui-1.3.6.dist-info}/entry_points.txt +0 -0
- {lazylabel_gui-1.3.4.dist-info → lazylabel_gui-1.3.6.dist-info}/licenses/LICENSE +0 -0
- {lazylabel_gui-1.3.4.dist-info → lazylabel_gui-1.3.6.dist-info}/top_level.txt +0 -0
lazylabel/ui/photo_viewer.py
CHANGED
@@ -1,7 +1,7 @@
 import cv2
 import numpy as np
 from PyQt6.QtCore import QRectF, Qt, pyqtSignal
-from PyQt6.QtGui import QCursor, QImage, QPixmap
+from PyQt6.QtGui import QCursor, QImage, QPainter, QPixmap
 from PyQt6.QtWidgets import QGraphicsPixmapItem, QGraphicsScene, QGraphicsView
 
 
@@ -17,6 +17,14 @@ class PhotoViewer(QGraphicsView):
         self._scene.addItem(self._pixmap_item)
         self.setScene(self._scene)
 
+        # Enable proper alpha blending for transparency
+        self.setRenderHint(QPainter.RenderHint.Antialiasing)
+        self.setRenderHint(QPainter.RenderHint.SmoothPixmapTransform)
+
+        # Ensure viewport supports transparency
+        self.viewport().setAttribute(Qt.WidgetAttribute.WA_TranslucentBackground, True)
+        self.setStyleSheet("background: transparent;")
+
         self.setTransformationAnchor(QGraphicsView.ViewportAnchor.AnchorUnderMouse)
         self.setResizeAnchor(QGraphicsView.ViewportAnchor.AnchorViewCenter)
         self.setVerticalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAlwaysOff)
@@ -26,7 +34,7 @@ class PhotoViewer(QGraphicsView):
 
         self._original_image = None
         self._adjusted_pixmap = None
-        self.
+        self._original_image_bgra = None
 
     def fitInView(self, scale=True):
         rect = QRectF(self._pixmap_item.pixmap().rect())
@@ -50,6 +58,9 @@ class PhotoViewer(QGraphicsView):
             if self._pixmap_item not in self._scene.items():
                 self._pixmap_item = QGraphicsPixmapItem()
                 self._scene.addItem(self._pixmap_item)
+
+            # PNG files are now loaded with proper alpha format at source
+
             self._pixmap_item.setPixmap(pixmap)
 
             # Convert QImage to ARGB32 for consistent processing
@@ -62,14 +73,16 @@ class PhotoViewer(QGraphicsView):
             img_np = np.array(ptr).reshape(
                 converted_image.height(), converted_image.width(), 4
             )
-            #
-            self.
+            # QImage ARGB32 is stored as BGRA in memory, keep this format
+            self._original_image_bgra = (
+                img_np.copy()
+            )  # Make a copy to avoid memory issues
 
             self.fitInView()
         else:
             self._original_image = None
             self._adjusted_pixmap = None
-            self.
+            self._original_image_bgra = None
             # Check if _pixmap_item still exists, recreate if deleted
             if self._pixmap_item not in self._scene.items():
                 self._pixmap_item = QGraphicsPixmapItem()
@@ -77,7 +90,7 @@ class PhotoViewer(QGraphicsView):
             self._pixmap_item.setPixmap(QPixmap())
 
     def set_image_adjustments(self, brightness: float, contrast: float, gamma: float):
-        if self.
+        if self._original_image_bgra is None or self._original_image is None:
             return
 
         # Ensure _pixmap_item exists and is valid
@@ -85,9 +98,13 @@ class PhotoViewer(QGraphicsView):
             self._pixmap_item = QGraphicsPixmapItem()
             self._scene.addItem(self._pixmap_item)
 
-
+        img_bgra = self._original_image_bgra.copy()
 
-        #
+        # Separate alpha channel for transparency preservation
+        alpha_channel = img_bgra[:, :, 3:4]  # Keep dimensions
+        img_bgr = img_bgra[:, :, :3]  # RGB channels only
+
+        # Apply brightness and contrast to RGB channels only
         # new_image = alpha * old_image + beta
         adjusted_img = cv2.convertScaleAbs(
             img_bgr, alpha=1 + contrast / 100.0, beta=brightness
@@ -101,11 +118,18 @@ class PhotoViewer(QGraphicsView):
         ).astype("uint8")
         adjusted_img = cv2.LUT(adjusted_img, table)
 
-        #
-
+        # Recombine with alpha channel to preserve transparency
+        adjusted_bgra = np.concatenate([adjusted_img, alpha_channel], axis=2)
+
+        # Convert back to QImage with alpha support
+        h, w, ch = adjusted_bgra.shape
         bytes_per_line = ch * w
         adjusted_qimage = QImage(
-
+            adjusted_bgra.data, w, h, bytes_per_line, QImage.Format.Format_ARGB32
+        )
+        # Convert to premultiplied alpha for proper blending
+        adjusted_qimage = adjusted_qimage.convertToFormat(
+            QImage.Format.Format_ARGB32_Premultiplied
         )
         self._adjusted_pixmap = QPixmap.fromImage(adjusted_qimage)
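The set_image_adjustments change above splits the alpha channel off before applying brightness/contrast/gamma and reattaches it afterwards, so transparency survives the adjustment. A minimal standalone sketch of that pipeline (illustrative names, not package API; assumes a 4-channel BGRA uint8 array as produced by a QImage in ARGB32 format):

import cv2
import numpy as np

def adjust_preserving_alpha(bgra: np.ndarray, brightness: float, contrast: float, gamma: float) -> np.ndarray:
    """Apply brightness/contrast/gamma to the color channels only."""
    alpha = bgra[:, :, 3:4]                      # keep (H, W, 1) shape
    bgr = bgra[:, :, :3]
    out = cv2.convertScaleAbs(bgr, alpha=1 + contrast / 100.0, beta=brightness)
    inv = 1.0 / max(gamma, 1e-6)                 # guard against gamma == 0
    table = ((np.arange(256) / 255.0) ** inv * 255).astype("uint8")
    out = cv2.LUT(out, table)                    # per-channel gamma lookup
    return np.concatenate([out, alpha], axis=2)  # reattach untouched alpha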
lazylabel/ui/widgets/channel_threshold_widget.py
CHANGED
@@ -372,6 +372,7 @@ class ChannelThresholdWidget(QWidget):
         super().__init__(parent)
         self.current_image_channels = 0  # 0 = no image, 1 = grayscale, 3 = RGB
         self.sliders = {}  # Dictionary of channel name -> slider
+        self.is_dragging = False  # Track if any slider is being dragged
 
         self.setupUI()
 
@@ -434,10 +435,8 @@ class ChannelThresholdWidget(QWidget):
         for channel_name in channel_names:
             slider_widget = ChannelSliderWidget(channel_name, self)
             slider_widget.valueChanged.connect(self._on_slider_changed)
-            slider_widget.dragStarted.connect(
-                self.dragStarted.emit
-            )  # Forward drag signals
-            slider_widget.dragFinished.connect(self.dragFinished.emit)
+            slider_widget.dragStarted.connect(self._on_drag_started)
+            slider_widget.dragFinished.connect(self._on_drag_finished)
             self.sliders[channel_name] = slider_widget
             self.sliders_layout.addWidget(slider_widget)
 
@@ -450,6 +449,21 @@ class ChannelThresholdWidget(QWidget):
 
     def _on_slider_changed(self):
         """Handle slider value change."""
+        # Only emit thresholdChanged if not currently dragging
+        # This prevents expensive calculations during drag operations
+        if not self.is_dragging:
+            self.thresholdChanged.emit()
+
+    def _on_drag_started(self):
+        """Handle drag start - suppress expensive calculations during drag."""
+        self.is_dragging = True
+        self.dragStarted.emit()
+
+    def _on_drag_finished(self):
+        """Handle drag finish - perform final calculation."""
+        self.is_dragging = False
+        self.dragFinished.emit()
+        # Emit threshold changed now that dragging is complete
        self.thresholdChanged.emit()
 
     def get_threshold_settings(self):
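The is_dragging flag above is a simple gate: intermediate valueChanged signals are dropped while a drag is in progress, and a single thresholdChanged fires on release. A stripped-down, hypothetical sketch of the same pattern outside the widget:

from PyQt6.QtCore import QObject, pyqtSignal

class DragGate(QObject):
    """Suppress expensive recomputation while a slider drag is in progress."""
    recompute = pyqtSignal()

    def __init__(self):
        super().__init__()
        self.is_dragging = False

    def on_value_changed(self):
        if not self.is_dragging:      # mid-drag updates are dropped
            self.recompute.emit()

    def on_drag_started(self):
        self.is_dragging = True

    def on_drag_finished(self):
        self.is_dragging = False
        self.recompute.emit()         # one recomputation for the settled value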
lazylabel/ui/workers/__init__.py
ADDED
@@ -0,0 +1,15 @@
+"""Worker thread classes for background operations."""
+
+from .image_discovery_worker import ImageDiscoveryWorker
+from .multi_view_sam_init_worker import MultiViewSAMInitWorker
+from .multi_view_sam_update_worker import MultiViewSAMUpdateWorker
+from .sam_update_worker import SAMUpdateWorker
+from .single_view_sam_init_worker import SingleViewSAMInitWorker
+
+__all__ = [
+    "ImageDiscoveryWorker",
+    "MultiViewSAMInitWorker",
+    "MultiViewSAMUpdateWorker",
+    "SAMUpdateWorker",
+    "SingleViewSAMInitWorker",
+]
lazylabel/ui/workers/image_discovery_worker.py
ADDED
@@ -0,0 +1,66 @@
+"""Worker thread for discovering image files in background."""
+
+from PyQt6.QtCore import QThread, pyqtSignal
+
+
+class ImageDiscoveryWorker(QThread):
+    """Worker thread for discovering all image file paths in background."""
+
+    images_discovered = pyqtSignal(list)  # List of all image file paths
+    progress = pyqtSignal(str)  # Progress message
+    error = pyqtSignal(str)
+
+    def __init__(self, file_model, file_manager, parent=None):
+        super().__init__(parent)
+        self.file_model = file_model
+        self.file_manager = file_manager
+        self._should_stop = False
+
+    def stop(self):
+        """Request the worker to stop."""
+        self._should_stop = True
+
+    def run(self):
+        """Discover all image file paths in background."""
+        try:
+            if self._should_stop:
+                return
+
+            self.progress.emit("Scanning for images...")
+
+            if (
+                not hasattr(self.file_model, "rootPath")
+                or not self.file_model.rootPath()
+            ):
+                self.images_discovered.emit([])
+                return
+
+            all_image_paths = []
+            root_index = self.file_model.index(self.file_model.rootPath())
+
+            def scan_directory(parent_index):
+                if self._should_stop:
+                    return
+
+                for row in range(self.file_model.rowCount(parent_index)):
+                    if self._should_stop:
+                        return
+
+                    index = self.file_model.index(row, 0, parent_index)
+                    if self.file_model.isDir(index):
+                        scan_directory(index)  # Recursively scan subdirectories
+                    else:
+                        path = self.file_model.filePath(index)
+                        if self.file_manager.is_image_file(path):
+                            # Simply add all image file paths without checking for NPZ
+                            all_image_paths.append(path)
+
+            scan_directory(root_index)
+
+            if not self._should_stop:
+                self.progress.emit(f"Found {len(all_image_paths)} images")
+                self.images_discovered.emit(sorted(all_image_paths))
+
+        except Exception as e:
+            if not self._should_stop:
+                self.error.emit(f"Error discovering images: {str(e)}")
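A hypothetical usage sketch for the worker above; `window`, its file_model (a QFileSystemModel), file_manager, and the _on_images_discovered slot are assumptions standing in for the application's real objects:

from lazylabel.ui.workers import ImageDiscoveryWorker

def start_discovery(window):
    """Hypothetical wiring; window must expose file_model, file_manager,
    and an _on_images_discovered(list) slot."""
    worker = ImageDiscoveryWorker(window.file_model, window.file_manager, parent=window)
    worker.images_discovered.connect(window._on_images_discovered)
    worker.progress.connect(print)   # or route to a status bar
    worker.error.connect(print)
    worker.start()                   # run() executes off the GUI thread
    return worker                    # keep a reference so it isn't GC'd

# When the root folder changes, cancel the old scan cooperatively:
#   worker.stop(); worker.wait()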
lazylabel/ui/workers/multi_view_sam_init_worker.py
ADDED
@@ -0,0 +1,135 @@
+"""Worker thread for initializing multi-view SAM models in background."""
+
+from PyQt6.QtCore import QThread, pyqtSignal
+
+from ...utils.logger import logger
+
+
+class MultiViewSAMInitWorker(QThread):
+    """Worker thread for initializing multi-view SAM models in background."""
+
+    model_initialized = pyqtSignal(int, object)  # viewer_index, model_instance
+    all_models_initialized = pyqtSignal(int)  # total_models_count
+    error = pyqtSignal(str)
+    progress = pyqtSignal(int, int)  # current, total
+
+    def __init__(self, model_manager, parent=None):
+        super().__init__(parent)
+        self.model_manager = model_manager
+        self._should_stop = False
+        self.models_created = []
+
+    def stop(self):
+        """Request the worker to stop."""
+        self._should_stop = True
+
+    def run(self):
+        """Initialize multi-view SAM models in background thread."""
+        try:
+            if self._should_stop:
+                return
+
+            # Import the required model classes
+            from ...models.sam_model import SamModel
+
+            try:
+                from ...models.sam2_model import Sam2Model
+
+                SAM2_AVAILABLE = True
+            except ImportError:
+                Sam2Model = None
+                SAM2_AVAILABLE = False
+
+            # Determine which type of model to create
+            # Get the currently selected model from the GUI
+            parent = self.parent()
+            custom_model_path = None
+            default_model_type = "vit_h"  # fallback
+
+            if parent and hasattr(parent, "control_panel"):
+                # Get the selected model path from the model selection widget
+                model_path = parent.control_panel.model_widget.get_selected_model_path()
+                if model_path:
+                    # User has selected a custom model
+                    custom_model_path = model_path
+                    # Detect model type from filename
+                    default_model_type = self.model_manager.detect_model_type(
+                        model_path
+                    )
+                else:
+                    # Using default model
+                    default_model_type = (
+                        parent.settings.default_model_type
+                        if hasattr(parent, "settings")
+                        else "vit_h"
+                    )
+
+            is_sam2 = default_model_type.startswith("sam2")
+
+            # Create model instances for all viewers - but optimize memory usage
+            config = parent._get_multi_view_config()
+            num_viewers = config["num_viewers"]
+
+            # Warn about performance implications for VIT_H in multi-view
+            if num_viewers > 2 and default_model_type == "vit_h":
+                logger.warning(
+                    f"Using vit_h model with {num_viewers} viewers may cause performance issues. Consider using vit_b for better performance."
+                )
+            for i in range(num_viewers):
+                if self._should_stop:
+                    return
+
+                self.progress.emit(i + 1, num_viewers)
+
+                try:
+                    if is_sam2 and SAM2_AVAILABLE:
+                        # Create SAM2 model instance
+                        if custom_model_path:
+                            model_instance = Sam2Model(custom_model_path)
+                        else:
+                            model_instance = Sam2Model(model_type=default_model_type)
+                    else:
+                        # Create SAM1 model instance
+                        if custom_model_path:
+                            model_instance = SamModel(
+                                model_type=default_model_type,
+                                custom_model_path=custom_model_path,
+                            )
+                        else:
+                            model_instance = SamModel(model_type=default_model_type)
+
+                    if self._should_stop:
+                        return
+
+                    if model_instance and getattr(model_instance, "is_loaded", False):
+                        self.models_created.append(model_instance)
+                        self.model_initialized.emit(i, model_instance)
+
+                        # Synchronize and clear GPU cache after each model for stability
+                        try:
+                            import torch
+
+                            if torch.cuda.is_available():
+                                torch.cuda.synchronize()  # Ensure model is fully loaded
+                                torch.cuda.empty_cache()
+                        except ImportError:
+                            pass  # PyTorch not available
+                    else:
+                        raise Exception(f"Model instance {i + 1} failed to load")
+
+                except Exception as model_error:
+                    logger.error(
+                        f"Error creating model instance {i + 1}: {model_error}"
+                    )
+                    if not self._should_stop:
+                        self.error.emit(
+                            f"Failed to create model instance {i + 1}: {model_error}"
+                        )
+                    return
+
+            if not self._should_stop:
+                self.all_models_initialized.emit(len(self.models_created))
+
+        except Exception as e:
+            if not self._should_stop:
+                self.error.emit(str(e))
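One way the initializer above might be wired up. This is a hypothetical sketch: main_window, model_manager, and on_model_ready stand in for the application's real objects, and the worker reads the selected model through its parent(), so the parent is assumed to expose control_panel and settings:

from lazylabel.ui.workers import MultiViewSAMInitWorker

def start_multi_view_init(main_window, model_manager, on_model_ready):
    """on_model_ready receives (viewer_index, model_instance)."""
    worker = MultiViewSAMInitWorker(model_manager, parent=main_window)
    worker.progress.connect(lambda i, n: print(f"Loading model {i}/{n}"))
    worker.model_initialized.connect(on_model_ready)
    worker.all_models_initialized.connect(lambda n: print(f"{n} models ready"))
    worker.error.connect(print)
    worker.start()   # run() executes off the GUI thread
    return worker    # keep a reference alive for the thread's lifetime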
lazylabel/ui/workers/multi_view_sam_update_worker.py
ADDED
@@ -0,0 +1,158 @@
+"""Worker thread for updating SAM model image in multi-view mode."""
+
+import cv2
+import numpy as np
+from PyQt6.QtCore import Qt, QThread, pyqtSignal
+from PyQt6.QtGui import QPixmap
+
+
+class MultiViewSAMUpdateWorker(QThread):
+    """Worker thread for updating SAM model image in multi-view mode."""
+
+    finished = pyqtSignal(int)  # viewer_index
+    error = pyqtSignal(int, str)  # viewer_index, error_message
+
+    def __init__(
+        self,
+        viewer_index,
+        model,
+        image_path,
+        operate_on_view=False,
+        current_image=None,
+        parent=None,
+    ):
+        super().__init__(parent)
+        self.viewer_index = viewer_index
+        self.model = model
+        self.image_path = image_path
+        self.operate_on_view = operate_on_view
+        self.current_image = current_image
+        self._should_stop = False
+        self.scale_factor = 1.0
+
+    def stop(self):
+        """Request the worker to stop."""
+        self._should_stop = True
+
+    def get_scale_factor(self):
+        """Get the scale factor used for image resizing."""
+        return self.scale_factor
+
+    def run(self):
+        """Update SAM model image in background thread."""
+        try:
+            if self._should_stop:
+                return
+
+            # Clear GPU cache to reduce memory pressure in multi-view mode
+            try:
+                import torch
+
+                if torch.cuda.is_available():
+                    torch.cuda.empty_cache()
+            except ImportError:
+                pass  # PyTorch not available
+
+            if self.operate_on_view and self.current_image is not None:
+                # Use the provided modified image
+                if self._should_stop:
+                    return
+
+                # Optimize image size for faster SAM processing
+                image = self.current_image
+                original_height, original_width = image.shape[:2]
+                max_size = 1024
+
+                if original_height > max_size or original_width > max_size:
+                    # Calculate scaling factor
+                    self.scale_factor = min(
+                        max_size / original_width, max_size / original_height
+                    )
+                    new_width = int(original_width * self.scale_factor)
+                    new_height = int(original_height * self.scale_factor)
+
+                    # Resize using OpenCV for speed
+                    image = cv2.resize(
+                        image, (new_width, new_height), interpolation=cv2.INTER_AREA
+                    )
+                else:
+                    self.scale_factor = 1.0
+
+                if self._should_stop:
+                    return
+
+                # Set image from numpy array
+                self.model.set_image_from_array(image)
+            else:
+                # Load original image
+                if self._should_stop:
+                    return
+
+                # Optimize image size for faster SAM processing
+                pixmap = QPixmap(self.image_path)
+                if pixmap.isNull():
+                    if not self._should_stop:
+                        self.error.emit(self.viewer_index, "Failed to load image")
+                    return
+
+                original_width = pixmap.width()
+                original_height = pixmap.height()
+                max_size = 1024
+
+                if original_width > max_size or original_height > max_size:
+                    # Calculate scaling factor
+                    self.scale_factor = min(
+                        max_size / original_width, max_size / original_height
+                    )
+
+                    # Scale down while maintaining aspect ratio
+                    scaled_pixmap = pixmap.scaled(
+                        max_size,
+                        max_size,
+                        Qt.AspectRatioMode.KeepAspectRatio,
+                        Qt.TransformationMode.SmoothTransformation,
+                    )
+
+                    # Convert to numpy array for SAM
+                    qimage = scaled_pixmap.toImage()
+                    width = qimage.width()
+                    height = qimage.height()
+                    ptr = qimage.bits()
+                    ptr.setsize(height * width * 4)
+                    arr = np.array(ptr).reshape(height, width, 4)
+                    # Convert RGBA to RGB
+                    image_array = arr[:, :, :3]
+
+                    if self._should_stop:
+                        return
+
+                    # Add CUDA synchronization for multi-model scenarios
+                    try:
+                        import torch
+
+                        if torch.cuda.is_available():
+                            torch.cuda.synchronize()
+                    except ImportError:
+                        pass
+
+                    self.model.set_image_from_array(image_array)
+                else:
+                    self.scale_factor = 1.0
+
+                    # Add CUDA synchronization for multi-model scenarios
+                    try:
+                        import torch
+
+                        if torch.cuda.is_available():
+                            torch.cuda.synchronize()
+                    except ImportError:
+                        pass
+
+                    self.model.set_image_from_path(self.image_path)
+
+            if not self._should_stop:
+                self.finished.emit(self.viewer_index)
+
+        except Exception as e:
+            if not self._should_stop:
+                self.error.emit(self.viewer_index, str(e))
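Both update workers downscale images larger than 1024 px before handing them to SAM, and expose get_scale_factor() so callers can map full-resolution viewer coordinates into the image the model actually received. A sketch of that mapping (to_sam_coords is illustrative, not a package API):

def to_sam_coords(x: float, y: float, worker) -> tuple[int, int]:
    """Map viewer (full-resolution) coordinates into the possibly
    downscaled image given to SAM; worker is a finished update worker."""
    s = worker.get_scale_factor()  # 1.0 when no resizing happened
    return int(x * s), int(y * s)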
lazylabel/ui/workers/sam_update_worker.py
ADDED
@@ -0,0 +1,129 @@
+"""Worker thread for updating SAM model in background."""
+
+import cv2
+import numpy as np
+from PyQt6.QtCore import Qt, QThread, pyqtSignal
+from PyQt6.QtGui import QPixmap
+
+
+class SAMUpdateWorker(QThread):
+    """Worker thread for updating SAM model in background."""
+
+    finished = pyqtSignal()
+    error = pyqtSignal(str)
+
+    def __init__(
+        self,
+        model_manager,
+        image_path,
+        operate_on_view,
+        current_image=None,
+        parent=None,
+    ):
+        super().__init__(parent)
+        self.model_manager = model_manager
+        self.image_path = image_path
+        self.operate_on_view = operate_on_view
+        self.current_image = current_image  # Numpy array of current modified image
+        self._should_stop = False
+        self.scale_factor = 1.0  # Track scaling factor for coordinate transformation
+
+    def stop(self):
+        """Request the worker to stop."""
+        self._should_stop = True
+
+    def get_scale_factor(self):
+        """Get the scale factor used for image resizing."""
+        return self.scale_factor
+
+    def run(self):
+        """Run SAM update in background thread."""
+        try:
+            if self._should_stop:
+                return
+
+            if self.operate_on_view and self.current_image is not None:
+                # Use the provided modified image
+                if self._should_stop:
+                    return
+
+                # Optimize image size for faster SAM processing
+                image = self.current_image
+                original_height, original_width = image.shape[:2]
+                max_size = 1024
+
+                if original_height > max_size or original_width > max_size:
+                    # Calculate scaling factor
+                    self.scale_factor = min(
+                        max_size / original_width, max_size / original_height
+                    )
+                    new_width = int(original_width * self.scale_factor)
+                    new_height = int(original_height * self.scale_factor)
+
+                    # Resize using OpenCV for speed
+                    image = cv2.resize(
+                        image, (new_width, new_height), interpolation=cv2.INTER_AREA
+                    )
+                else:
+                    self.scale_factor = 1.0
+
+                if self._should_stop:
+                    return
+
+                # Set image from numpy array (FIXED: use resized image, not original)
+                self.model_manager.set_image_from_array(image)
+            else:
+                # Load original image
+                pixmap = QPixmap(self.image_path)
+                if pixmap.isNull():
+                    self.error.emit("Failed to load image")
+                    return
+
+                if self._should_stop:
+                    return
+
+                original_width = pixmap.width()
+                original_height = pixmap.height()
+
+                # Optimize image size for faster SAM processing
+                max_size = 1024
+                if original_width > max_size or original_height > max_size:
+                    # Calculate scaling factor
+                    self.scale_factor = min(
+                        max_size / original_width, max_size / original_height
+                    )
+
+                    # Scale down while maintaining aspect ratio
+                    scaled_pixmap = pixmap.scaled(
+                        max_size,
+                        max_size,
+                        Qt.AspectRatioMode.KeepAspectRatio,
+                        Qt.TransformationMode.SmoothTransformation,
+                    )
+
+                    # Convert to numpy array for SAM
+                    qimage = scaled_pixmap.toImage()
+                    width = qimage.width()
+                    height = qimage.height()
+                    ptr = qimage.bits()
+                    ptr.setsize(height * width * 4)
+                    arr = np.array(ptr).reshape(height, width, 4)
+                    # Convert RGBA to RGB
+                    image_array = arr[:, :, :3]
+
+                    if self._should_stop:
+                        return
+
+                    # FIXED: Use the resized image array, not original path
+                    self.model_manager.set_image_from_array(image_array)
+                else:
+                    self.scale_factor = 1.0
+                    # For images that don't need resizing, use original path
+                    self.model_manager.set_image_from_path(self.image_path)
+
+            if not self._should_stop:
+                self.finished.emit()
+
+        except Exception as e:
+            if not self._should_stop:
+                self.error.emit(str(e))
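For reference, a standalone version of the QPixmap-to-numpy conversion both update workers perform. The explicit convertToFormat is an addition of this sketch (the workers call toImage() directly) so the 4-bytes-per-pixel reshape always holds; channel order follows QImage's in-memory layout:

import numpy as np
from PyQt6.QtGui import QImage, QPixmap

def pixmap_to_array(pixmap: QPixmap) -> np.ndarray:
    """Convert a QPixmap to an (H, W, 3) uint8 array, dropping alpha."""
    qimage = pixmap.toImage().convertToFormat(QImage.Format.Format_ARGB32)
    w, h = qimage.width(), qimage.height()
    ptr = qimage.bits()
    ptr.setsize(h * w * 4)
    arr = np.array(ptr).reshape(h, w, 4)
    return arr[:, :, :3].copy()  # copy so the data outlives the QImage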