lazylabel-gui 1.0.9__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lazylabel/__init__.py +9 -0
- lazylabel/config/__init__.py +7 -0
- lazylabel/config/hotkeys.py +169 -0
- lazylabel/config/paths.py +41 -0
- lazylabel/config/settings.py +66 -0
- lazylabel/core/__init__.py +7 -0
- lazylabel/core/file_manager.py +106 -0
- lazylabel/core/model_manager.py +97 -0
- lazylabel/core/segment_manager.py +171 -0
- lazylabel/main.py +20 -1262
- lazylabel/models/__init__.py +5 -0
- lazylabel/models/sam_model.py +195 -0
- lazylabel/ui/__init__.py +8 -0
- lazylabel/ui/control_panel.py +237 -0
- lazylabel/{editable_vertex.py → ui/editable_vertex.py} +25 -3
- lazylabel/ui/hotkey_dialog.py +384 -0
- lazylabel/{hoverable_polygon_item.py → ui/hoverable_polygon_item.py} +17 -1
- lazylabel/ui/main_window.py +1546 -0
- lazylabel/ui/right_panel.py +315 -0
- lazylabel/ui/widgets/__init__.py +8 -0
- lazylabel/ui/widgets/adjustments_widget.py +107 -0
- lazylabel/ui/widgets/model_selection_widget.py +94 -0
- lazylabel/ui/widgets/settings_widget.py +106 -0
- lazylabel/ui/widgets/status_bar.py +109 -0
- lazylabel/utils/__init__.py +6 -0
- lazylabel/{custom_file_system_model.py → utils/custom_file_system_model.py} +9 -3
- {lazylabel_gui-1.0.9.dist-info → lazylabel_gui-1.1.1.dist-info}/METADATA +61 -11
- lazylabel_gui-1.1.1.dist-info/RECORD +37 -0
- lazylabel/controls.py +0 -265
- lazylabel/sam_model.py +0 -70
- lazylabel_gui-1.0.9.dist-info/RECORD +0 -17
- /lazylabel/{hoverable_pixelmap_item.py → ui/hoverable_pixelmap_item.py} +0 -0
- /lazylabel/{numeric_table_widget_item.py → ui/numeric_table_widget_item.py} +0 -0
- /lazylabel/{photo_viewer.py → ui/photo_viewer.py} +0 -0
- /lazylabel/{reorderable_class_table.py → ui/reorderable_class_table.py} +0 -0
- /lazylabel/{utils.py → utils/utils.py} +0 -0
- {lazylabel_gui-1.0.9.dist-info → lazylabel_gui-1.1.1.dist-info}/WHEEL +0 -0
- {lazylabel_gui-1.0.9.dist-info → lazylabel_gui-1.1.1.dist-info}/entry_points.txt +0 -0
- {lazylabel_gui-1.0.9.dist-info → lazylabel_gui-1.1.1.dist-info}/licenses/LICENSE +0 -0
- {lazylabel_gui-1.0.9.dist-info → lazylabel_gui-1.1.1.dist-info}/top_level.txt +0 -0
lazylabel/config/hotkeys.py
ADDED
@@ -0,0 +1,169 @@
|
|
1
|
+
"""Hotkey management system."""
|
2
|
+
|
3
|
+
import json
|
4
|
+
import os
|
5
|
+
from dataclasses import dataclass, asdict
|
6
|
+
from typing import Dict, Optional, List, Tuple
|
7
|
+
from PyQt6.QtCore import Qt
|
8
|
+
from PyQt6.QtGui import QKeySequence
|
9
|
+
|
10
|
+
|
11
|
+
@dataclass
class HotkeyAction:
    """Represents a hotkey action with primary and secondary keys.

    ``name`` is the internal identifier used by :class:`HotkeyManager`;
    ``description`` is the human-readable label. Actions flagged
    ``mouse_related`` describe mouse gestures: the manager refuses to rebind
    them and does not persist them.
    """
    name: str
    description: str
    primary_key: str
    secondary_key: Optional[str] = None
    category: str = "General"
    mouse_related: bool = False  # Cannot be reassigned if True


class HotkeyManager:
    """Manages application hotkeys with persistence.

    Bindings start from built-in defaults and are then overridden by the
    user's ``hotkeys.json`` in *config_dir* (if present).
    """

    def __init__(self, config_dir: str):
        self.config_dir = config_dir
        self.hotkeys_file = os.path.join(config_dir, "hotkeys.json")
        self.actions: Dict[str, HotkeyAction] = {}
        self._initialize_default_hotkeys()
        self.load_hotkeys()

    def _initialize_default_hotkeys(self):
        """(Re)install the built-in default binding for every action."""
        default_hotkeys = [
            # Navigation
            HotkeyAction("load_next_image", "Load Next Image", "Right", category="Navigation"),
            HotkeyAction("load_previous_image", "Load Previous Image", "Left", category="Navigation"),
            HotkeyAction("fit_view", "Fit View", "Period", category="Navigation"),

            # Modes
            HotkeyAction("sam_mode", "Point Mode (SAM)", "1", category="Modes"),
            HotkeyAction("polygon_mode", "Polygon Mode", "2", category="Modes"),
            HotkeyAction("selection_mode", "Selection Mode", "E", category="Modes"),
            HotkeyAction("pan_mode", "Pan Mode", "Q", category="Modes"),
            HotkeyAction("edit_mode", "Edit Mode", "R", category="Modes"),

            # Actions
            HotkeyAction("clear_points", "Clear Points/Vertices", "C", category="Actions"),
            HotkeyAction("save_segment", "Save Current Segment", "Space", category="Actions"),
            HotkeyAction("save_output", "Save Output", "Return", category="Actions"),
            HotkeyAction("save_output_alt", "Save Output (Alt)", "Enter", category="Actions"),
            HotkeyAction("undo", "Undo Last Action", "Ctrl+Z", category="Actions"),
            HotkeyAction("escape", "Cancel/Clear Selection", "Escape", category="Actions"),

            # Segments
            HotkeyAction("merge_segments", "Merge Selected Segments", "M", category="Segments"),
            HotkeyAction("delete_segments", "Delete Selected Segments", "V", category="Segments"),
            HotkeyAction("delete_segments_alt", "Delete Selected Segments (Alt)", "Backspace", category="Segments"),
            HotkeyAction("select_all", "Select All Segments", "Ctrl+A", category="Segments"),

            # View
            HotkeyAction("zoom_in", "Zoom In", "Ctrl+Plus", category="View"),
            HotkeyAction("zoom_out", "Zoom Out", "Ctrl+Minus", category="View"),

            # Movement (WASD)
            HotkeyAction("pan_up", "Pan Up", "W", category="Movement"),
            HotkeyAction("pan_down", "Pan Down", "S", category="Movement"),
            HotkeyAction("pan_left", "Pan Left", "A", category="Movement"),
            HotkeyAction("pan_right", "Pan Right", "D", category="Movement"),

            # Mouse-related (cannot be reassigned)
            HotkeyAction("left_click", "Add Positive Point / Select", "Left Click",
                         category="Mouse", mouse_related=True),
            HotkeyAction("right_click", "Add Negative Point", "Right Click",
                         category="Mouse", mouse_related=True),
            HotkeyAction("mouse_drag", "Drag/Pan", "Mouse Drag",
                         category="Mouse", mouse_related=True),
        ]

        for action in default_hotkeys:
            self.actions[action.name] = action

    def get_action(self, action_name: str) -> Optional[HotkeyAction]:
        """Get hotkey action by name, or None if unknown."""
        return self.actions.get(action_name)

    def get_actions_by_category(self) -> Dict[str, List[HotkeyAction]]:
        """Get actions grouped by category, in registration order."""
        categories: Dict[str, List[HotkeyAction]] = {}
        for action in self.actions.values():
            categories.setdefault(action.category, []).append(action)
        return categories

    def set_primary_key(self, action_name: str, key: str) -> bool:
        """Set primary key for an action.

        Returns False (and changes nothing) for unknown or mouse-related actions.
        """
        action = self.actions.get(action_name)
        if action is None or action.mouse_related:
            return False
        action.primary_key = key
        return True

    def set_secondary_key(self, action_name: str, key: Optional[str]) -> bool:
        """Set (or clear with None) the secondary key for an action.

        Returns False (and changes nothing) for unknown or mouse-related actions.
        """
        action = self.actions.get(action_name)
        if action is None or action.mouse_related:
            return False
        action.secondary_key = key
        return True

    def get_key_for_action(self, action_name: str) -> Tuple[Optional[str], Optional[str]]:
        """Get ``(primary, secondary)`` keys for an action; ``(None, None)`` if unknown."""
        action = self.actions.get(action_name)
        if action:
            return action.primary_key, action.secondary_key
        return None, None

    def is_key_in_use(self, key: str, exclude_action: Optional[str] = None) -> Optional[str]:
        """Return the name of an action already bound to *key*, or None.

        Both primary and secondary bindings are checked; *exclude_action*
        (typically the action being edited) is skipped. An empty or None
        *key* means "unbound" and never conflicts -- without this guard it
        would match every action whose secondary key is unset.
        """
        if not key:
            return None
        for name, action in self.actions.items():
            if name == exclude_action:
                continue
            if key in (action.primary_key, action.secondary_key):
                return name
        return None

    def reset_to_defaults(self):
        """Reset all hotkeys to default values (in memory only; not saved)."""
        self._initialize_default_hotkeys()

    def save_hotkeys(self):
        """Write all non-mouse bindings to ``hotkeys.json`` in the config dir."""
        os.makedirs(self.config_dir, exist_ok=True)

        # Convert to serializable format
        data = {}
        for name, action in self.actions.items():
            if not action.mouse_related:  # Don't save mouse-related actions
                data[name] = {
                    'primary_key': action.primary_key,
                    'secondary_key': action.secondary_key
                }

        with open(self.hotkeys_file, 'w') as f:
            json.dump(data, f, indent=4)

    def load_hotkeys(self):
        """Apply user overrides from ``hotkeys.json`` on top of current bindings.

        A missing, unreadable, or malformed file leaves the current bindings
        untouched. Entries for unknown or mouse-related actions, and entries
        that are not JSON objects, are skipped. A field absent from an entry
        leaves that binding unchanged; an explicit ``null`` clears it.
        """
        if not os.path.exists(self.hotkeys_file):
            return

        try:
            with open(self.hotkeys_file, 'r') as f:
                data = json.load(f)
        except (OSError, ValueError):
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            # If loading fails, keep defaults.
            return

        if not isinstance(data, dict):
            return

        for name, keys in data.items():
            action = self.actions.get(name)
            if action is None or action.mouse_related or not isinstance(keys, dict):
                continue
            action.primary_key = keys.get('primary_key', action.primary_key)
            action.secondary_key = keys.get('secondary_key', action.secondary_key)

    def key_sequence_to_string(self, key_sequence: "QKeySequence") -> str:
        """Convert QKeySequence to string representation."""
        return key_sequence.toString()

    def string_to_key_sequence(self, key_string: str) -> "QKeySequence":
        """Convert string to QKeySequence."""
        return QKeySequence(key_string)
|
@@ -0,0 +1,41 @@
|
|
1
|
+
"""Path management for LazyLabel."""
|
2
|
+
|
3
|
+
import os
|
4
|
+
from pathlib import Path
|
5
|
+
|
6
|
+
|
7
|
+
class Paths:
    """Resolve every filesystem location LazyLabel uses.

    Building an instance creates the models directory (two levels above this
    module) and the per-user config directory when they are missing. The
    cache directory is only resolved here, never created.
    """

    def __init__(self):
        package_root = Path(__file__).parent.parent
        self.app_dir = package_root
        self.models_dir = package_root.joinpath("models")

        home = Path.home()
        self.config_dir = home.joinpath(".config", "lazylabel")
        self.cache_dir = home.joinpath(".cache", "lazylabel")

        # The models folder's parent already exists, so no parents=True;
        # the config folder may be nested under the home directory.
        self.models_dir.mkdir(exist_ok=True)
        self.config_dir.mkdir(parents=True, exist_ok=True)

    @property
    def settings_file(self) -> Path:
        """``settings.json`` inside the config directory."""
        return self.config_dir.joinpath("settings.json")

    @property
    def demo_pictures_dir(self) -> Path:
        """``demo_pictures`` folder inside the application directory."""
        return self.app_dir.joinpath("demo_pictures")

    @property
    def logo_path(self) -> Path:
        """Application logo image inside the demo pictures folder."""
        return self.demo_pictures_dir.joinpath("logo2.png")

    def get_model_path(self, filename: str) -> Path:
        """Location of *filename* within the models directory."""
        return self.models_dir.joinpath(filename)

    def get_old_cache_model_path(self, filename: str) -> Path:
        """Location of *filename* within the legacy cache directory."""
        return self.cache_dir.joinpath(filename)
|
@@ -0,0 +1,66 @@
|
|
1
|
+
"""Application settings and configuration."""
|
2
|
+
|
3
|
+
import os
|
4
|
+
from dataclasses import dataclass, asdict
|
5
|
+
from typing import Dict, Any
|
6
|
+
import json
|
7
|
+
|
8
|
+
|
9
|
+
@dataclass
class Settings:
    """Application settings with defaults.

    Every field has a default, so ``Settings()`` is always a complete
    configuration. Instances round-trip through JSON via
    :meth:`save_to_file` and :meth:`load_from_file`.
    """

    # UI Settings
    window_width: int = 1600
    window_height: int = 900
    left_panel_width: int = 250
    right_panel_width: int = 350

    # Annotation Settings
    point_radius: float = 0.3
    line_thickness: float = 0.5
    pan_multiplier: float = 1.0
    polygon_join_threshold: int = 2

    # Model Settings
    default_model_type: str = "vit_h"
    default_model_filename: str = "sam_vit_h_4b8939.pth"

    # Save Settings
    auto_save: bool = True
    save_npz: bool = True
    save_txt: bool = True
    save_class_aliases: bool = False
    yolo_use_alias: bool = True

    # UI State
    annotation_size_multiplier: float = 1.0

    def save_to_file(self, filepath: str) -> None:
        """Write all fields to *filepath* as indented JSON.

        The parent directory is created if *filepath* has one; a bare
        filename is written relative to the current working directory.
        """
        directory = os.path.dirname(filepath)
        if directory:
            # os.makedirs("") raises FileNotFoundError, so only create a
            # directory when the path actually names one.
            os.makedirs(directory, exist_ok=True)
        with open(filepath, 'w') as f:
            json.dump(asdict(self), f, indent=4)

    @classmethod
    def load_from_file(cls, filepath: str) -> 'Settings':
        """Load settings from *filepath*, falling back to defaults.

        Returns ``cls()`` when the file is missing, unreadable, not valid
        JSON, or not a JSON object. Keys that are not fields (e.g. written
        by another version) are ignored rather than discarding the whole
        file; fields absent from the file keep their defaults.
        """
        from dataclasses import fields

        if not os.path.exists(filepath):
            return cls()

        try:
            with open(filepath, 'r') as f:
                data = json.load(f)
        except (OSError, ValueError):
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            return cls()

        if not isinstance(data, dict):
            return cls()

        known = {fld.name for fld in fields(cls)}
        try:
            return cls(**{key: value for key, value in data.items() if key in known})
        except TypeError:
            # Only reachable for subclasses that add fields without defaults.
            return cls()

    def update(self, **kwargs) -> None:
        """Set each named field to the given value; non-field names are ignored."""
        from dataclasses import fields

        known = {fld.name for fld in fields(self)}
        for key, value in kwargs.items():
            if key in known:
                setattr(self, key, value)


# Default settings instance
DEFAULT_SETTINGS = Settings()
|
@@ -0,0 +1,106 @@
|
|
1
|
+
"""File management functionality."""
|
2
|
+
|
3
|
+
import os
|
4
|
+
import json
|
5
|
+
import numpy as np
|
6
|
+
import cv2
|
7
|
+
from typing import List, Dict, Any, Optional, Tuple
|
8
|
+
from pathlib import Path
|
9
|
+
|
10
|
+
from .segment_manager import SegmentManager
|
11
|
+
|
12
|
+
|
13
|
+
class FileManager:
    """Manages file operations for saving and loading annotations.

    All files sit next to the image and share its path minus the extension:
    ``<stem>.npz`` (mask tensor), ``<stem>.txt`` (YOLO-style boxes) and
    ``<stem>.json`` (class aliases).
    """

    def __init__(self, segment_manager: "SegmentManager"):
        self.segment_manager = segment_manager

    @staticmethod
    def _sidecar_path(image_path: str, extension: str) -> str:
        """Return *image_path* with its extension replaced by *extension*."""
        return os.path.splitext(image_path)[0] + extension

    def save_npz(self, image_path: str, image_size: Tuple[int, int], class_order: List[int]) -> str:
        """Save the segment manager's final mask tensor as ``<stem>.npz``.

        The array is stored under the key ``"mask"`` as uint8.

        Returns:
            Path of the written ``.npz`` file.
        """
        final_mask_tensor = self.segment_manager.create_final_mask_tensor(image_size, class_order)
        npz_path = self._sidecar_path(image_path, ".npz")
        np.savez_compressed(npz_path, mask=final_mask_tensor.astype(np.uint8))
        return npz_path

    def save_yolo_txt(self, image_path: str, image_size: Tuple[int, int],
                      class_order: List[int], class_labels: List[str]) -> Optional[str]:
        """Save one bounding box per external contour as ``<stem>.txt``.

        *image_size* is ``(height, width)``. Each line is
        ``"<label> <cx> <cy> <w> <h>"`` with the box centre and size
        normalized by the image width/height. *class_labels* is indexed by
        mask channel (assumed to have one entry per channel -- confirm with
        caller).

        Returns:
            Path of the written file, or None if there was nothing to write
            (no file is created in that case).
        """
        final_mask_tensor = self.segment_manager.create_final_mask_tensor(image_size, class_order)
        output_path = self._sidecar_path(image_path, ".txt")
        h, w = image_size

        yolo_annotations = []
        for channel in range(final_mask_tensor.shape[2]):
            single_channel_image = final_mask_tensor[:, :, channel]
            if not np.any(single_channel_image):
                continue

            contours, _ = cv2.findContours(
                single_channel_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
            )

            class_label = class_labels[channel]
            for contour in contours:
                x, y, width, height = cv2.boundingRect(contour)
                center_x = (x + width / 2) / w
                center_y = (y + height / 2) / h
                normalized_width = width / w
                normalized_height = height / h
                yolo_entry = f"{class_label} {center_x} {center_y} {normalized_width} {normalized_height}"
                yolo_annotations.append(yolo_entry)

        if not yolo_annotations:
            return None

        output_dir = os.path.dirname(output_path)
        if output_dir:
            # os.makedirs("") raises FileNotFoundError for bare filenames.
            os.makedirs(output_dir, exist_ok=True)
        with open(output_path, "w") as file:
            file.writelines(annotation + "\n" for annotation in yolo_annotations)

        return output_path

    def save_class_aliases(self, image_path: str) -> str:
        """Save the segment manager's class aliases as ``<stem>.json``.

        Integer class IDs become string keys, since JSON object keys are strings.

        Returns:
            Path of the written ``.json`` file.
        """
        aliases_path = self._sidecar_path(image_path, ".json")
        aliases_to_save = {str(k): v for k, v in self.segment_manager.class_aliases.items()}
        with open(aliases_path, "w") as f:
            json.dump(aliases_to_save, f, indent=4)
        return aliases_path

    def load_class_aliases(self, image_path: str) -> None:
        """Load class aliases from ``<stem>.json`` into the segment manager.

        Does nothing if the file is absent. If the file cannot be read, is
        not a JSON object, or has non-integer keys, an error is printed and
        the aliases are cleared.
        """
        json_path = self._sidecar_path(image_path, ".json")
        if not os.path.exists(json_path):
            return

        try:
            with open(json_path, "r") as f:
                loaded_aliases = json.load(f)
            if not isinstance(loaded_aliases, dict):
                raise ValueError("expected a JSON object mapping class IDs to aliases")
            self.segment_manager.class_aliases = {int(k): v for k, v in loaded_aliases.items()}
        except (OSError, ValueError) as e:
            # ValueError also covers json.JSONDecodeError and bad int() keys.
            print(f"Error loading class aliases from {json_path}: {e}")
            self.segment_manager.class_aliases.clear()

    def load_existing_mask(self, image_path: str) -> None:
        """Add one "Loaded" segment per non-empty channel of ``<stem>.npz``.

        The channel index becomes the segment's ``class_id``. A 2-D mask is
        treated as a single channel. Nothing happens if the file or its
        ``"mask"`` array is absent.
        """
        npz_path = self._sidecar_path(image_path, ".npz")
        if not os.path.exists(npz_path):
            return

        with np.load(npz_path) as data:
            if "mask" not in data:
                return
            mask_data = data["mask"]
            if mask_data.ndim == 2:
                mask_data = np.expand_dims(mask_data, axis=-1)

            num_classes = mask_data.shape[2]
            for i in range(num_classes):
                class_mask = mask_data[:, :, i].astype(bool)
                if np.any(class_mask):
                    self.segment_manager.add_segment({
                        "mask": class_mask,
                        "type": "Loaded",
                        "vertices": None,
                        "class_id": i,
                    })

    def is_image_file(self, filepath: str) -> bool:
        """Check (case-insensitively) if file is a supported image format."""
        return filepath.lower().endswith((".png", ".jpg", ".jpeg", ".tiff", ".tif"))
|
@@ -0,0 +1,97 @@
|
|
1
|
+
"""Model management functionality."""
|
2
|
+
|
3
|
+
import os
|
4
|
+
import glob
|
5
|
+
from typing import List, Tuple, Optional, Callable
|
6
|
+
from pathlib import Path
|
7
|
+
|
8
|
+
from ..models.sam_model import SamModel
|
9
|
+
from ..config import Paths
|
10
|
+
|
11
|
+
|
12
|
+
class ModelManager:
    """Manages SAM model loading and selection."""

    def __init__(self, paths: "Paths"):
        self.paths = paths
        self.sam_model: Optional["SamModel"] = None
        self.current_models_folder: Optional[str] = None
        # Invoked with a status string ("Current: <filename>") after a
        # custom model loads successfully.
        self.on_model_changed: Optional[Callable[[str], None]] = None

    def initialize_default_model(self, model_type: str = "vit_h") -> Optional["SamModel"]:
        """Initialize the default SAM model.

        On success the models folder is set to ``paths.models_dir``. Any
        exception is printed and swallowed.

        Returns:
            SamModel instance if successful, None if failed
        """
        try:
            print(f"[8/20] Loading {model_type.upper()} model...")
            self.sam_model = SamModel(model_type=model_type)
            self.current_models_folder = str(self.paths.models_dir)
            return self.sam_model
        except Exception as e:
            print(f"[8/20] Failed to initialize default model: {e}")
            self.sam_model = None
            return None

    def get_available_models(self, folder_path: str) -> List[Tuple[str, str]]:
        """Get list of available .pth models in folder (searched recursively).

        The extension match is case-insensitive. A missing folder yields [].

        Returns:
            List of (display_name, full_path) tuples sorted by display name,
            where display_name is the path relative to *folder_path*.
        """
        pth_files = []
        for root, _dirs, files in os.walk(folder_path):
            for file in files:
                if file.lower().endswith(".pth"):
                    full_path = os.path.join(root, file)
                    rel_path = os.path.relpath(full_path, folder_path)
                    pth_files.append((rel_path, full_path))

        return sorted(pth_files, key=lambda x: x[0])

    def detect_model_type(self, model_path: str) -> str:
        """Guess the SAM backbone ("vit_h", "vit_l" or "vit_b") from a filename.

        An explicit ``vit_l``/``vit_b``/``vit_h`` tag in the (lower-cased)
        filename wins; otherwise the looser words "large", "base" and "huge"
        are tried. Falls back to "vit_h".
        """
        filename = os.path.basename(model_path).lower()
        # Explicit tags first, so e.g. "sam_vit_h_base_finetune.pth" is not
        # misread as vit_b because it contains the word "base".
        for tag in ("vit_l", "vit_b", "vit_h"):
            if tag in filename:
                return tag
        for word, tag in (("large", "vit_l"), ("base", "vit_b"), ("huge", "vit_h")):
            if word in filename:
                return tag
        return "vit_h"  # default

    def load_custom_model(self, model_path: str) -> bool:
        """Load a custom model from path into the existing SAM model.

        Requires a model to have been initialized first. Notifies
        ``on_model_changed`` on success.

        Returns:
            True if successful, False otherwise
        """
        if not self.sam_model:
            return False

        if not os.path.exists(model_path):
            return False

        model_type = self.detect_model_type(model_path)
        success = self.sam_model.load_custom_model(model_path, model_type)

        if success and self.on_model_changed:
            model_name = os.path.basename(model_path)
            self.on_model_changed(f"Current: {model_name}")

        return success

    def set_models_folder(self, folder_path: str) -> None:
        """Set the current models folder."""
        self.current_models_folder = folder_path

    def get_models_folder(self) -> Optional[str]:
        """Get the current models folder."""
        return self.current_models_folder

    def is_model_available(self) -> bool:
        """Check if a SAM model exists and reports ``is_loaded`` truthy."""
        return self.sam_model is not None and getattr(
            self.sam_model, "is_loaded", False
        )
|
@@ -0,0 +1,171 @@
|
|
1
|
+
"""Segment management functionality."""
|
2
|
+
|
3
|
+
import numpy as np
|
4
|
+
import cv2
|
5
|
+
from typing import List, Dict, Any, Optional, Tuple
|
6
|
+
from PyQt6.QtCore import QPointF
|
7
|
+
|
8
|
+
|
9
|
+
class SegmentManager:
    """Manages image segments and classes.

    Each segment is a dict. The keys read here are ``"class_id"``,
    ``"type"`` (``"Polygon"`` segments are rasterized from ``"vertices"``)
    and ``"mask"`` (used for non-polygon segments; presumably an (H, W)
    boolean array -- confirm with callers).
    """

    def __init__(self):
        self.segments: List[Dict[str, Any]] = []
        self.class_aliases: Dict[int, str] = {}
        self.next_class_id: int = 0
        self.active_class_id: Optional[int] = None  # Currently active/toggled class

    def clear(self) -> None:
        """Clear all segments and reset state."""
        self.segments.clear()
        self.class_aliases.clear()
        self.next_class_id = 0
        self.active_class_id = None

    def add_segment(self, segment_data: Dict[str, Any]) -> None:
        """Append *segment_data*, assigning a class if it has none.

        Without a ``"class_id"`` the segment gets the active class, or else
        the next unused class ID. Note: *segment_data* itself is mutated.
        """
        if "class_id" not in segment_data:
            # Use active class if available, otherwise use next class ID
            if self.active_class_id is not None:
                segment_data["class_id"] = self.active_class_id
            else:
                segment_data["class_id"] = self.next_class_id
        self.segments.append(segment_data)
        self._update_next_class_id()

    def delete_segments(self, indices: List[int]) -> None:
        """Delete segments by index; duplicate and out-of-range indices are ignored."""
        # Deduplicate so a repeated index cannot delete a second, unrelated
        # segment; delete from the back so remaining indices stay valid.
        for i in sorted(set(indices), reverse=True):
            if 0 <= i < len(self.segments):
                del self.segments[i]
        self._update_next_class_id()

    def assign_segments_to_class(self, indices: List[int]) -> None:
        """Put the segments at *indices* into one class.

        The target is the smallest class ID already held by one of them, or
        the next unused ID if none has a class. Out-of-range (including
        negative) indices are ignored, matching :meth:`delete_segments`.
        """
        count = len(self.segments)
        # Negative indices would otherwise wrap to the end of the list.
        valid = [i for i in dict.fromkeys(indices) if 0 <= i < count]
        if not valid:
            return

        existing_class_ids = [
            self.segments[i]["class_id"]
            for i in valid
            if self.segments[i].get("class_id") is not None
        ]
        target_class_id = min(existing_class_ids) if existing_class_ids else self.next_class_id

        for i in valid:
            self.segments[i]["class_id"] = target_class_id

        self._update_next_class_id()

    def get_unique_class_ids(self) -> List[int]:
        """Get sorted list of unique, non-None class IDs."""
        return sorted(
            {seg.get("class_id") for seg in self.segments if seg.get("class_id") is not None}
        )

    def rasterize_polygon(
        self, vertices: List["QPointF"], image_size: Tuple[int, int]
    ) -> Optional[np.ndarray]:
        """Fill polygon *vertices* into an ``(h, w)`` boolean mask.

        *image_size* is ``(height, width)``. Vertex coordinates are converted
        with ``np.int32`` (truncated). Returns None for empty/None vertices.
        """
        if not vertices:
            return None

        h, w = image_size
        points_np = np.array([[p.x(), p.y()] for p in vertices], dtype=np.int32)
        mask = np.zeros((h, w), dtype=np.uint8)
        cv2.fillPoly(mask, [points_np], 1)
        return mask.astype(bool)

    def create_final_mask_tensor(
        self, image_size: Tuple[int, int], class_order: List[int]
    ) -> np.ndarray:
        """Create an ``(h, w, len(class_order))`` uint8 tensor for saving.

        Channel ``k`` is the union (0/1) of all segments whose class ID is
        ``class_order[k]``; segments of classes not in *class_order* are
        skipped.
        """
        h, w = image_size
        id_map = {old_id: new_id for new_id, old_id in enumerate(class_order)}
        num_final_classes = len(class_order)
        final_mask_tensor = np.zeros((h, w, num_final_classes), dtype=np.uint8)

        for seg in self.segments:
            class_id = seg.get("class_id")
            if class_id not in id_map:
                continue

            new_channel_idx = id_map[class_id]

            if seg["type"] == "Polygon":
                mask = self.rasterize_polygon(seg["vertices"], image_size)
            else:
                mask = seg.get("mask")

            if mask is not None:
                final_mask_tensor[:, :, new_channel_idx] = np.logical_or(
                    final_mask_tensor[:, :, new_channel_idx], mask
                )

        return final_mask_tensor

    def reassign_class_ids(self, new_order: List[int]) -> None:
        """Renumber classes so ``new_order[k]`` becomes class ``k``.

        Aliases follow their class; aliases of IDs not in *new_order* are dropped.
        """
        id_map = {old_id: new_id for new_id, old_id in enumerate(new_order)}

        for seg in self.segments:
            old_id = seg.get("class_id")
            if old_id in id_map:
                seg["class_id"] = id_map[old_id]

        # NOTE(review): class IDs absent from new_order keep their old value
        # and may collide with a renumbered class; confirm callers always
        # pass every class ID.
        self.class_aliases = {
            id_map[old_id]: self.class_aliases[old_id]
            for old_id in new_order
            if old_id in self.class_aliases
        }
        self._update_next_class_id()

    def set_class_alias(self, class_id: int, alias: str) -> None:
        """Set alias for a class."""
        self.class_aliases[class_id] = alias

    def get_class_alias(self, class_id: int) -> str:
        """Get alias for a class, defaulting to the ID as a string."""
        return self.class_aliases.get(class_id, str(class_id))

    def set_active_class(self, class_id: Optional[int]) -> None:
        """Set the active class ID (None to deactivate)."""
        self.active_class_id = class_id

    def get_active_class(self) -> Optional[int]:
        """Get the active class ID."""
        return self.active_class_id

    def toggle_active_class(self, class_id: int) -> bool:
        """Toggle a class as active. Returns True if now active, False if deactivated."""
        if self.active_class_id == class_id:
            self.active_class_id = None
            return False
        self.active_class_id = class_id
        return True

    def _update_next_class_id(self) -> None:
        """Set ``next_class_id`` to one past the largest class ID in use (0 if none)."""
        all_ids = {
            seg.get("class_id")
            for seg in self.segments
            if seg.get("class_id") is not None
        }
        self.next_class_id = max(all_ids) + 1 if all_ids else 0
|