lazylabel-gui 1.1.0__tar.gz → 1.1.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/PKG-INFO +1 -1
- {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/pyproject.toml +1 -1
- {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel/core/model_manager.py +22 -19
- {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel/core/segment_manager.py +65 -34
- {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel/main.py +17 -3
- {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel/models/sam_model.py +72 -31
- {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel/ui/control_panel.py +83 -66
- {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel/ui/main_window.py +322 -40
- {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel/ui/right_panel.py +149 -73
- {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel/ui/widgets/__init__.py +2 -1
- lazylabel_gui-1.1.1/src/lazylabel/ui/widgets/status_bar.py +109 -0
- {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel_gui.egg-info/PKG-INFO +1 -1
- {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel_gui.egg-info/SOURCES.txt +1 -0
- {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/LICENSE +0 -0
- {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/README.md +0 -0
- {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/setup.cfg +0 -0
- {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel/__init__.py +0 -0
- {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel/config/__init__.py +0 -0
- {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel/config/hotkeys.py +0 -0
- {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel/config/paths.py +0 -0
- {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel/config/settings.py +0 -0
- {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel/core/__init__.py +0 -0
- {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel/core/file_manager.py +0 -0
- {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel/models/__init__.py +0 -0
- {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel/ui/__init__.py +0 -0
- {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel/ui/editable_vertex.py +0 -0
- {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel/ui/hotkey_dialog.py +0 -0
- {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel/ui/hoverable_pixelmap_item.py +0 -0
- {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel/ui/hoverable_polygon_item.py +0 -0
- {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel/ui/numeric_table_widget_item.py +0 -0
- {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel/ui/photo_viewer.py +0 -0
- {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel/ui/reorderable_class_table.py +0 -0
- {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel/ui/widgets/adjustments_widget.py +0 -0
- {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel/ui/widgets/model_selection_widget.py +0 -0
- {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel/ui/widgets/settings_widget.py +0 -0
- {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel/utils/__init__.py +0 -0
- {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel/utils/custom_file_system_model.py +0 -0
- {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel/utils/utils.py +0 -0
- {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel_gui.egg-info/dependency_links.txt +0 -0
- {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel_gui.egg-info/entry_points.txt +0 -0
- {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel_gui.egg-info/requires.txt +0 -0
- {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel_gui.egg-info/top_level.txt +0 -0
@@ -11,44 +11,45 @@ from ..config import Paths
|
|
11
11
|
|
12
12
|
class ModelManager:
|
13
13
|
"""Manages SAM model loading and selection."""
|
14
|
-
|
14
|
+
|
15
15
|
def __init__(self, paths: Paths):
    """Set up an empty manager bound to the application's path configuration.

    Args:
        paths: Path configuration; its ``models_dir`` is read when the
            default model is initialized.
    """
    # Callback invoked with a human-readable status string after a
    # successful custom-model swap.
    self.on_model_changed: Optional[Callable[[str], None]] = None
    self.current_models_folder: Optional[str] = None
    self.sam_model: Optional[SamModel] = None
    self.paths = paths
|
20
|
-
|
20
|
+
|
21
21
|
def initialize_default_model(self, model_type: str = "vit_h") -> Optional[SamModel]:
    """Construct the default SAM model and remember the models folder.

    Args:
        model_type: SAM architecture key passed to ``SamModel``.

    Returns:
        SamModel instance if successful, None if failed
    """
    try:
        print(f"[8/20] Loading {model_type.upper()} model...")
        model = SamModel(model_type=model_type)
        self.sam_model = model
        self.current_models_folder = str(self.paths.models_dir)
    except Exception as exc:
        # Any construction/path error leaves the manager with no model.
        print(f"[8/20] Failed to initialize default model: {exc}")
        self.sam_model = None
        return None
    return self.sam_model
|
35
|
-
|
36
|
+
|
36
37
|
def get_available_models(self, folder_path: str) -> List[Tuple[str, str]]:
    """Recursively collect ``.pth`` checkpoint files beneath *folder_path*.

    The extension match is case-insensitive.

    Returns:
        List of (display_name, full_path) tuples, where display_name is the
        path relative to *folder_path*; the list is sorted by display_name.
    """

    def _checkpoint_paths():
        # Yield the absolute-ish path of every matching file, walk order.
        for root, _subdirs, names in os.walk(folder_path):
            for name in names:
                if name.lower().endswith(".pth"):
                    yield os.path.join(root, name)

    return sorted(
        ((os.path.relpath(full, folder_path), full) for full in _checkpoint_paths()),
        key=lambda pair: pair[0],
    )
|
51
|
-
|
52
|
+
|
52
53
|
def detect_model_type(self, model_path: str) -> str:
|
53
54
|
"""Detect model type from filename."""
|
54
55
|
filename = os.path.basename(model_path).lower()
|
@@ -59,36 +60,38 @@ class ModelManager:
|
|
59
60
|
elif "vit_h" in filename or "huge" in filename:
|
60
61
|
return "vit_h"
|
61
62
|
return "vit_h" # default
|
62
|
-
|
63
|
+
|
63
64
|
def load_custom_model(self, model_path: str) -> bool:
    """Swap the loaded SAM weights for the checkpoint at *model_path*.

    The architecture is inferred from the filename via ``detect_model_type``.
    On success the ``on_model_changed`` callback, if set, receives a
    ``"Current: <filename>"`` status string.

    Returns:
        True if successful, False otherwise (no model instance yet, missing
        file, or the underlying load reported failure).
    """
    if not self.sam_model or not os.path.exists(model_path):
        return False

    loaded = self.sam_model.load_custom_model(
        model_path, self.detect_model_type(model_path)
    )
    if loaded and self.on_model_changed:
        self.on_model_changed(f"Current: {os.path.basename(model_path)}")
    return loaded
|
83
|
-
|
84
|
+
|
84
85
|
def set_models_folder(self, folder_path: str) -> None:
    """Remember *folder_path* as the folder models are browsed from."""
    # Stored verbatim; no existence check is performed here.
    self.current_models_folder = folder_path
|
87
|
-
|
88
|
+
|
88
89
|
def get_models_folder(self) -> Optional[str]:
    """Return the remembered models folder, or None if none has been set."""
    folder = self.current_models_folder
    return folder
|
91
|
-
|
92
|
+
|
92
93
|
def is_model_available(self) -> bool:
    """Report whether a SAM model instance exists and reports itself loaded.

    A model object without an ``is_loaded`` attribute counts as unavailable.
    """
    model = self.sam_model
    if model is None:
        return False
    return getattr(model, "is_loaded", False)
|
@@ -8,108 +8,122 @@ from PyQt6.QtCore import QPointF
|
|
8
8
|
|
9
9
|
class SegmentManager:
|
10
10
|
"""Manages image segments and classes."""
|
11
|
-
|
11
|
+
|
12
12
|
def __init__(self):
    """Start with no segments, no class aliases and no active class."""
    # Class that segments added without a "class_id" default to;
    # None means "fall back to next_class_id".
    self.active_class_id: Optional[int] = None
    # Kept in sync by _update_next_class_id (max used ID + 1, or 0 if none).
    self.next_class_id: int = 0
    # Optional display names keyed by numeric class ID.
    self.class_aliases: Dict[int, str] = {}
    self.segments: List[Dict[str, Any]] = []
|
17
|
+
|
17
18
|
def clear(self) -> None:
    """Drop every segment and alias and reset the class-ID bookkeeping.

    The segment list and alias dict are emptied in place, so existing
    references to them observe the reset.
    """
    self.active_class_id = None
    self.next_class_id = 0
    self.class_aliases.clear()
    self.segments.clear()
|
24
|
+
|
23
25
|
def add_segment(self, segment_data: Dict[str, Any]) -> None:
    """Append *segment_data*, filling in a class ID when it lacks one.

    A segment without a ``"class_id"`` key receives the active class if one
    is set, otherwise ``next_class_id``. The dict is modified in place and
    stored by reference.
    """
    if "class_id" not in segment_data:
        active = self.active_class_id
        segment_data["class_id"] = self.next_class_id if active is None else active
    self.segments.append(segment_data)
    self._update_next_class_id()
|
29
|
-
|
35
|
+
|
30
36
|
def delete_segments(self, indices: List[int]) -> None:
    """Delete segments by indices.

    Out-of-range (including negative) indices are ignored. Duplicates are
    collapsed first: deleting from the highest position down keeps lower
    positions valid, but a repeated index would otherwise remove whichever
    segment shifted into that slot after the first deletion.
    """
    for i in sorted(set(indices), reverse=True):
        if 0 <= i < len(self.segments):
            del self.segments[i]
    self._update_next_class_id()
|
36
|
-
|
42
|
+
|
37
43
|
def assign_segments_to_class(self, indices: List[int]) -> None:
    """Assign selected segments to a class.

    The target is the smallest class ID already held by any selected segment;
    if none of them has a class ID, ``next_class_id`` is used. Out-of-range
    indices are ignored, including negative ones, which would otherwise wrap
    around to the end of the list (``delete_segments`` likewise requires
    ``0 <= i < len(segments)``).
    """
    if not indices:
        return

    valid = [i for i in indices if 0 <= i < len(self.segments)]
    existing_class_ids = [
        self.segments[i]["class_id"]
        for i in valid
        if self.segments[i].get("class_id") is not None
    ]
    target_class_id = (
        min(existing_class_ids) if existing_class_ids else self.next_class_id
    )

    for i in valid:
        self.segments[i]["class_id"] = target_class_id

    self._update_next_class_id()
|
58
|
-
|
64
|
+
|
59
65
|
def get_unique_class_ids(self) -> List[int]:
    """Get sorted list of unique class IDs.

    Segments whose ``class_id`` is missing or None are skipped.
    """
    # sorted() takes any iterable, so the set needs no intermediate list().
    return sorted(
        {seg["class_id"] for seg in self.segments if seg.get("class_id") is not None}
    )
|
76
|
+
|
77
|
+
def rasterize_polygon(
    self, vertices: List[QPointF], image_size: Tuple[int, int]
) -> Optional[np.ndarray]:
    """Convert polygon vertices to binary mask.

    Args:
        vertices: Polygon corners; each must provide ``x()`` and ``y()``.
            Coordinates are truncated to int32 before filling.
        image_size: ``(height, width)`` of the output mask.

    Returns:
        A boolean ``(height, width)`` array that is True inside the polygon,
        or None when *vertices* is empty.
    """
    if not vertices:
        return None

    height, width = image_size
    polygon = np.array([(pt.x(), pt.y()) for pt in vertices], dtype=np.int32)
    canvas = np.zeros((height, width), dtype=np.uint8)
    cv2.fillPoly(canvas, [polygon], 1)
    return canvas.astype(bool)
|
77
|
-
|
78
|
-
def create_final_mask_tensor(
    self, image_size: Tuple[int, int], class_order: List[int]
) -> np.ndarray:
    """Build the per-class mask stack used for saving.

    Returns a ``(height, width, len(class_order))`` uint8 array whose
    channel k is the union of every segment with class ID
    ``class_order[k]``. Segments of classes not in *class_order* are left
    out. Polygon segments are rasterized; other segments use their stored
    ``"mask"`` and are skipped when it is absent.
    """
    height, width = image_size
    channel_of = {cid: idx for idx, cid in enumerate(class_order)}
    tensor = np.zeros((height, width, len(class_order)), dtype=np.uint8)

    for seg in self.segments:
        # Channel indices are ints, so None only signals "class not exported".
        channel = channel_of.get(seg.get("class_id"))
        if channel is None:
            continue

        if seg["type"] == "Polygon":
            mask = self.rasterize_polygon(seg["vertices"], image_size)
        else:
            mask = seg.get("mask")
        if mask is None:
            continue

        tensor[:, :, channel] = np.logical_or(tensor[:, :, channel], mask)

    return tensor
|
103
|
-
|
117
|
+
|
104
118
|
def reassign_class_ids(self, new_order: List[int]) -> None:
|
105
119
|
"""Reassign class IDs based on new order."""
|
106
120
|
id_map = {old_id: new_id for new_id, old_id in enumerate(new_order)}
|
107
|
-
|
121
|
+
|
108
122
|
for seg in self.segments:
|
109
123
|
old_id = seg.get("class_id")
|
110
124
|
if old_id in id_map:
|
111
125
|
seg["class_id"] = id_map[old_id]
|
112
|
-
|
126
|
+
|
113
127
|
# Update aliases
|
114
128
|
new_aliases = {
|
115
129
|
id_map[old_id]: self.class_aliases.get(old_id, str(old_id))
|
@@ -118,15 +132,32 @@ class SegmentManager:
|
|
118
132
|
}
|
119
133
|
self.class_aliases = new_aliases
|
120
134
|
self._update_next_class_id()
|
121
|
-
|
135
|
+
|
122
136
|
def set_class_alias(self, class_id: int, alias: str) -> None:
    """Store *alias* as the display name for *class_id*, replacing any old one."""
    self.class_aliases.update({class_id: alias})
|
125
|
-
|
139
|
+
|
126
140
|
def get_class_alias(self, class_id: int) -> str:
    """Return the alias for *class_id*, or the ID rendered as text if unset."""
    aliases = self.class_aliases
    return aliases[class_id] if class_id in aliases else str(class_id)
|
129
|
-
|
143
|
+
|
144
|
+
def set_active_class(self, class_id: Optional[int]) -> None:
    """Make *class_id* the active class; pass None to clear it."""
    # New segments without a class ID pick this up in add_segment.
    self.active_class_id = class_id
|
147
|
+
|
148
|
+
def get_active_class(self) -> Optional[int]:
    """Return the active class ID, or None when no class is active."""
    active = self.active_class_id
    return active
|
151
|
+
|
152
|
+
def toggle_active_class(self, class_id: int) -> bool:
    """Flip *class_id* between active and inactive.

    Returns:
        True if *class_id* is active after the call; False if it was the
        active class and has now been deactivated.
    """
    was_active = self.active_class_id == class_id
    self.active_class_id = None if was_active else class_id
    return not was_active
|
160
|
+
|
130
161
|
def _update_next_class_id(self) -> None:
|
131
162
|
"""Update the next available class ID."""
|
132
163
|
all_ids = {
|
@@ -137,4 +168,4 @@ class SegmentManager:
|
|
137
168
|
if not all_ids:
|
138
169
|
self.next_class_id = 0
|
139
170
|
else:
|
140
|
-
self.next_class_id = max(all_ids) + 1
|
171
|
+
self.next_class_id = max(all_ids) + 1
|
@@ -9,14 +9,28 @@ from .ui.main_window import MainWindow
|
|
9
9
|
|
10
10
|
def main():
    """Launch the LazyLabel GUI and exit with the Qt event loop's return code.

    Progress is reported on stdout as numbered startup steps.
    """
    banner = "=" * 50
    print(banner)
    print("LazyLabel - AI-Assisted Image Labeling")
    print(banner)
    print()

    print("[1/20] Initializing application...")
    app = QApplication(sys.argv)

    print("[2/20] Applying dark theme...")
    qdarktheme.setup_theme()

    # Steps 3-18 are reported while the main window is constructed.
    window = MainWindow()

    print("[19/20] Showing main window...")
    window.show()

    print()
    print("[20/20] LazyLabel is ready! Happy labeling!")
    print(banner)

    sys.exit(app.exec())


if __name__ == "__main__":
    main()
|
@@ -9,15 +9,18 @@ from segment_anything import sam_model_registry, SamPredictor
|
|
9
9
|
|
10
10
|
def download_model(url, download_path):
|
11
11
|
"""Downloads file with a progress bar."""
|
12
|
-
print(
|
13
|
-
|
14
|
-
)
|
12
|
+
print(f"[10/20] SAM model not found. Downloading from Meta's repository...")
|
13
|
+
print(f" Downloading to: {download_path}")
|
15
14
|
try:
|
15
|
+
print(f"[10/20] Connecting to download server...")
|
16
16
|
response = requests.get(url, stream=True, timeout=30)
|
17
17
|
response.raise_for_status()
|
18
18
|
total_size_in_bytes = int(response.headers.get("content-length", 0))
|
19
19
|
block_size = 1024 # 1 Kibibyte
|
20
20
|
|
21
|
+
print(
|
22
|
+
f"[10/20] Starting download ({total_size_in_bytes / (1024*1024*1024):.1f} GB)..."
|
23
|
+
)
|
21
24
|
progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
|
22
25
|
with open(download_path, "wb") as file:
|
23
26
|
for data in response.iter_content(block_size):
|
@@ -27,93 +30,131 @@ def download_model(url, download_path):
|
|
27
30
|
|
28
31
|
if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
|
29
32
|
raise RuntimeError("Download incomplete - file size mismatch")
|
30
|
-
|
31
|
-
print("Model download completed successfully.")
|
32
|
-
|
33
|
+
|
34
|
+
print("[10/20] Model download completed successfully.")
|
35
|
+
|
36
|
+
except requests.exceptions.ConnectionError as e:
|
37
|
+
raise RuntimeError(
|
38
|
+
f"[10/20] Network connection failed: Check your internet connection"
|
39
|
+
)
|
40
|
+
except requests.exceptions.Timeout as e:
|
41
|
+
raise RuntimeError(f"[10/20] Download timeout: Server took too long to respond")
|
42
|
+
except requests.exceptions.HTTPError as e:
|
43
|
+
raise RuntimeError(
|
44
|
+
f"[10/20] HTTP error {e.response.status_code}: Server rejected request"
|
45
|
+
)
|
33
46
|
except requests.exceptions.RequestException as e:
|
34
|
-
raise RuntimeError(f"Network error during download: {e}")
|
47
|
+
raise RuntimeError(f"[10/20] Network error during download: {e}")
|
48
|
+
except PermissionError as e:
|
49
|
+
raise RuntimeError(
|
50
|
+
f"[10/20] Permission denied: Cannot write to {download_path}"
|
51
|
+
)
|
52
|
+
except OSError as e:
|
53
|
+
raise RuntimeError(f"[10/20] Disk error: {e} (check available disk space)")
|
35
54
|
except Exception as e:
|
36
55
|
# Clean up partial download
|
37
56
|
if os.path.exists(download_path):
|
38
|
-
|
39
|
-
|
57
|
+
try:
|
58
|
+
os.remove(download_path)
|
59
|
+
except:
|
60
|
+
pass
|
61
|
+
raise RuntimeError(f"[10/20] Download failed: {e}")
|
40
62
|
|
41
63
|
|
42
64
|
class SamModel:
|
43
|
-
def __init__(
    self,
    model_type="vit_h",
    model_filename=None,
    custom_model_path=None,
):
    """Build the SAM network and predictor, fetching the checkpoint if needed.

    Args:
        model_type: Key into ``sam_model_registry`` such as ``"vit_h"``,
            ``"vit_l"`` or ``"vit_b"``.
        model_filename: Checkpoint filename used when no usable custom path
            is given. When None (the default), the official checkpoint name
            for *model_type* is chosen. Unknown types fall back to the ViT-H
            checkpoint, which was the previous fixed default.
        custom_model_path: Optional checkpoint to load instead of the
            default. If it is set but does not exist, a notice is printed and
            the default checkpoint is used.

    Any exception while locating, downloading or loading the checkpoint is
    caught and printed. In that case ``is_loaded`` stays False and ``model``
    and ``predictor`` are left as None.
    """
    # Official segment-anything checkpoint names. Previously the ViT-H file
    # was used for every model_type, so e.g. "vit_b" fed ViT-H weights into a
    # ViT-B architecture and loading always failed.
    default_checkpoints = {
        "vit_h": "sam_vit_h_4b8939.pth",
        "vit_l": "sam_vit_l_0b3195.pth",
        "vit_b": "sam_vit_b_01ec64.pth",
    }
    if model_filename is None:
        model_filename = default_checkpoints.get(model_type, "sam_vit_h_4b8939.pth")

    self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[9/20] Detected device: {str(self.device).upper()}")

    self.current_model_type = model_type
    self.current_model_path = custom_model_path
    self.model = None
    self.predictor = None
    self.image = None
    self.is_loaded = False

    try:
        if custom_model_path and os.path.exists(custom_model_path):
            # Use custom model path
            model_path = custom_model_path
            print(f"[10/20] Loading custom SAM model from {model_path}...")
        else:
            if custom_model_path:
                # Make the fallback visible; the default checkpoint may be
                # downloaded below.
                print(
                    f"[10/20] Custom model not found at {custom_model_path}; "
                    "using default model instead."
                )
            model_url = (
                f"https://dl.fbaipublicfiles.com/segment_anything/{model_filename}"
            )

            # Checkpoints are stored next to this module (the models package).
            models_dir = os.path.dirname(__file__)
            os.makedirs(models_dir, exist_ok=True)
            model_path = os.path.join(models_dir, model_filename)

            # Migrate a checkpoint left in the old cache location, if any.
            old_cache_dir = os.path.join(
                os.path.expanduser("~"), ".cache", "lazylabel"
            )
            old_model_path = os.path.join(old_cache_dir, model_filename)

            if os.path.exists(old_model_path) and not os.path.exists(model_path):
                print("[10/20] Moving existing model from cache to models folder...")
                import shutil

                shutil.move(old_model_path, model_path)
            elif not os.path.exists(model_path):
                # Download the model if it doesn't exist
                download_model(model_url, model_path)

            print(f"[10/20] Loading default SAM model from {model_path}...")

        print(f"[11/20] Initializing {model_type.upper()} model architecture...")
        self.model = sam_model_registry[model_type](checkpoint=model_path).to(
            self.device
        )

        print("[12/20] Setting up predictor...")
        self.predictor = SamPredictor(self.model)
        self.is_loaded = True
        print("[13/20] SAM model loaded successfully.")

    except Exception as e:
        # Drop any half-built state so a failed load keeps no partial model
        # referenced.
        self.model = None
        self.predictor = None
        print(f"[8/20] Failed to load SAM model: {e}")
        print("[8/20] SAM point functionality will be disabled.")
        self.is_loaded = False
|
91
|
-
|
130
|
+
|
92
131
|
def load_custom_model(self, model_path, model_type="vit_h"):
|
93
132
|
"""Load a custom model from the specified path."""
|
94
133
|
if not os.path.exists(model_path):
|
95
134
|
print(f"Model file not found: {model_path}")
|
96
135
|
return False
|
97
|
-
|
136
|
+
|
98
137
|
print(f"Loading custom SAM model from {model_path}...")
|
99
138
|
try:
|
100
139
|
# Clear existing model from memory
|
101
|
-
if hasattr(self,
|
140
|
+
if hasattr(self, "model") and self.model is not None:
|
102
141
|
del self.model
|
103
142
|
del self.predictor
|
104
143
|
torch.cuda.empty_cache() if torch.cuda.is_available() else None
|
105
|
-
|
144
|
+
|
106
145
|
# Load new model
|
107
|
-
self.model = sam_model_registry[model_type](checkpoint=model_path).to(
|
146
|
+
self.model = sam_model_registry[model_type](checkpoint=model_path).to(
|
147
|
+
self.device
|
148
|
+
)
|
108
149
|
self.predictor = SamPredictor(self.model)
|
109
150
|
self.current_model_type = model_type
|
110
151
|
self.current_model_path = model_path
|
111
152
|
self.is_loaded = True
|
112
|
-
|
153
|
+
|
113
154
|
# Re-set image if one was previously loaded
|
114
155
|
if self.image is not None:
|
115
156
|
self.predictor.set_image(self.image)
|
116
|
-
|
157
|
+
|
117
158
|
print("Custom SAM model loaded successfully.")
|
118
159
|
return True
|
119
160
|
except Exception as e:
|