lazylabel-gui 1.1.0__tar.gz → 1.1.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/PKG-INFO +1 -1
  2. {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/pyproject.toml +1 -1
  3. {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel/core/model_manager.py +22 -19
  4. {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel/core/segment_manager.py +65 -34
  5. {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel/main.py +17 -3
  6. {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel/models/sam_model.py +72 -31
  7. {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel/ui/control_panel.py +83 -66
  8. {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel/ui/main_window.py +322 -40
  9. {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel/ui/right_panel.py +149 -73
  10. {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel/ui/widgets/__init__.py +2 -1
  11. lazylabel_gui-1.1.1/src/lazylabel/ui/widgets/status_bar.py +109 -0
  12. {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel_gui.egg-info/PKG-INFO +1 -1
  13. {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel_gui.egg-info/SOURCES.txt +1 -0
  14. {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/LICENSE +0 -0
  15. {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/README.md +0 -0
  16. {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/setup.cfg +0 -0
  17. {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel/__init__.py +0 -0
  18. {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel/config/__init__.py +0 -0
  19. {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel/config/hotkeys.py +0 -0
  20. {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel/config/paths.py +0 -0
  21. {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel/config/settings.py +0 -0
  22. {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel/core/__init__.py +0 -0
  23. {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel/core/file_manager.py +0 -0
  24. {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel/models/__init__.py +0 -0
  25. {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel/ui/__init__.py +0 -0
  26. {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel/ui/editable_vertex.py +0 -0
  27. {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel/ui/hotkey_dialog.py +0 -0
  28. {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel/ui/hoverable_pixelmap_item.py +0 -0
  29. {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel/ui/hoverable_polygon_item.py +0 -0
  30. {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel/ui/numeric_table_widget_item.py +0 -0
  31. {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel/ui/photo_viewer.py +0 -0
  32. {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel/ui/reorderable_class_table.py +0 -0
  33. {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel/ui/widgets/adjustments_widget.py +0 -0
  34. {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel/ui/widgets/model_selection_widget.py +0 -0
  35. {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel/ui/widgets/settings_widget.py +0 -0
  36. {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel/utils/__init__.py +0 -0
  37. {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel/utils/custom_file_system_model.py +0 -0
  38. {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel/utils/utils.py +0 -0
  39. {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel_gui.egg-info/dependency_links.txt +0 -0
  40. {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel_gui.egg-info/entry_points.txt +0 -0
  41. {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel_gui.egg-info/requires.txt +0 -0
  42. {lazylabel_gui-1.1.0 → lazylabel_gui-1.1.1}/src/lazylabel_gui.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: lazylabel-gui
3
- Version: 1.1.0
3
+ Version: 1.1.1
4
4
  Summary: An image segmentation GUI for generating ML ready mask tensors and annotations.
5
5
  Author-email: "Deniz N. Cakan" <deniz.n.cakan@gmail.com>
6
6
  License: MIT License
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "lazylabel-gui"
7
- version = "1.1.0"
7
+ version = "1.1.1"
8
8
  authors = [
9
9
  { name="Deniz N. Cakan", email="deniz.n.cakan@gmail.com" },
10
10
  ]
@@ -11,44 +11,45 @@ from ..config import Paths
11
11
 
12
12
  class ModelManager:
13
13
  """Manages SAM model loading and selection."""
14
-
14
+
15
15
  def __init__(self, paths: Paths):
16
16
  self.paths = paths
17
17
  self.sam_model: Optional[SamModel] = None
18
18
  self.current_models_folder: Optional[str] = None
19
19
  self.on_model_changed: Optional[Callable[[str], None]] = None
20
-
20
+
21
21
  def initialize_default_model(self, model_type: str = "vit_h") -> Optional[SamModel]:
22
22
  """Initialize the default SAM model.
23
-
23
+
24
24
  Returns:
25
25
  SamModel instance if successful, None if failed
26
26
  """
27
27
  try:
28
+ print(f"[8/20] Loading {model_type.upper()} model...")
28
29
  self.sam_model = SamModel(model_type=model_type)
29
30
  self.current_models_folder = str(self.paths.models_dir)
30
31
  return self.sam_model
31
32
  except Exception as e:
32
- print(f"Failed to initialize default model: {e}")
33
+ print(f"[8/20] Failed to initialize default model: {e}")
33
34
  self.sam_model = None
34
35
  return None
35
-
36
+
36
37
  def get_available_models(self, folder_path: str) -> List[Tuple[str, str]]:
37
38
  """Get list of available .pth models in folder.
38
-
39
+
39
40
  Returns:
40
41
  List of (display_name, full_path) tuples
41
42
  """
42
43
  pth_files = []
43
44
  for root, dirs, files in os.walk(folder_path):
44
45
  for file in files:
45
- if file.lower().endswith('.pth'):
46
+ if file.lower().endswith(".pth"):
46
47
  full_path = os.path.join(root, file)
47
48
  rel_path = os.path.relpath(full_path, folder_path)
48
49
  pth_files.append((rel_path, full_path))
49
-
50
+
50
51
  return sorted(pth_files, key=lambda x: x[0])
51
-
52
+
52
53
  def detect_model_type(self, model_path: str) -> str:
53
54
  """Detect model type from filename."""
54
55
  filename = os.path.basename(model_path).lower()
@@ -59,36 +60,38 @@ class ModelManager:
59
60
  elif "vit_h" in filename or "huge" in filename:
60
61
  return "vit_h"
61
62
  return "vit_h" # default
62
-
63
+
63
64
  def load_custom_model(self, model_path: str) -> bool:
64
65
  """Load a custom model from path.
65
-
66
+
66
67
  Returns:
67
68
  True if successful, False otherwise
68
69
  """
69
70
  if not self.sam_model:
70
71
  return False
71
-
72
+
72
73
  if not os.path.exists(model_path):
73
74
  return False
74
-
75
+
75
76
  model_type = self.detect_model_type(model_path)
76
77
  success = self.sam_model.load_custom_model(model_path, model_type)
77
-
78
+
78
79
  if success and self.on_model_changed:
79
80
  model_name = os.path.basename(model_path)
80
81
  self.on_model_changed(f"Current: {model_name}")
81
-
82
+
82
83
  return success
83
-
84
+
84
85
  def set_models_folder(self, folder_path: str) -> None:
85
86
  """Set the current models folder."""
86
87
  self.current_models_folder = folder_path
87
-
88
+
88
89
  def get_models_folder(self) -> Optional[str]:
89
90
  """Get the current models folder."""
90
91
  return self.current_models_folder
91
-
92
+
92
93
  def is_model_available(self) -> bool:
93
94
  """Check if a SAM model is loaded and available."""
94
- return self.sam_model is not None and getattr(self.sam_model, 'is_loaded', False)
95
+ return self.sam_model is not None and getattr(
96
+ self.sam_model, "is_loaded", False
97
+ )
@@ -8,108 +8,122 @@ from PyQt6.QtCore import QPointF
8
8
 
9
9
  class SegmentManager:
10
10
  """Manages image segments and classes."""
11
-
11
+
12
12
  def __init__(self):
13
13
  self.segments: List[Dict[str, Any]] = []
14
14
  self.class_aliases: Dict[int, str] = {}
15
15
  self.next_class_id: int = 0
16
-
16
+ self.active_class_id: Optional[int] = None # Currently active/toggled class
17
+
17
18
  def clear(self) -> None:
18
19
  """Clear all segments and reset state."""
19
20
  self.segments.clear()
20
21
  self.class_aliases.clear()
21
22
  self.next_class_id = 0
22
-
23
+ self.active_class_id = None
24
+
23
25
  def add_segment(self, segment_data: Dict[str, Any]) -> None:
24
26
  """Add a new segment."""
25
- if 'class_id' not in segment_data:
26
- segment_data['class_id'] = self.next_class_id
27
+ if "class_id" not in segment_data:
28
+ # Use active class if available, otherwise use next class ID
29
+ if self.active_class_id is not None:
30
+ segment_data["class_id"] = self.active_class_id
31
+ else:
32
+ segment_data["class_id"] = self.next_class_id
27
33
  self.segments.append(segment_data)
28
34
  self._update_next_class_id()
29
-
35
+
30
36
  def delete_segments(self, indices: List[int]) -> None:
31
37
  """Delete segments by indices."""
32
38
  for i in sorted(indices, reverse=True):
33
39
  if 0 <= i < len(self.segments):
34
40
  del self.segments[i]
35
41
  self._update_next_class_id()
36
-
42
+
37
43
  def assign_segments_to_class(self, indices: List[int]) -> None:
38
44
  """Assign selected segments to a class."""
39
45
  if not indices:
40
46
  return
41
-
47
+
42
48
  existing_class_ids = [
43
49
  self.segments[i]["class_id"]
44
50
  for i in indices
45
51
  if i < len(self.segments) and self.segments[i].get("class_id") is not None
46
52
  ]
47
-
53
+
48
54
  if existing_class_ids:
49
55
  target_class_id = min(existing_class_ids)
50
56
  else:
51
57
  target_class_id = self.next_class_id
52
-
58
+
53
59
  for i in indices:
54
60
  if i < len(self.segments):
55
61
  self.segments[i]["class_id"] = target_class_id
56
-
62
+
57
63
  self._update_next_class_id()
58
-
64
+
59
65
  def get_unique_class_ids(self) -> List[int]:
60
66
  """Get sorted list of unique class IDs."""
61
- return sorted(list({
62
- seg.get("class_id")
63
- for seg in self.segments
64
- if seg.get("class_id") is not None
65
- }))
66
-
67
- def rasterize_polygon(self, vertices: List[QPointF], image_size: Tuple[int, int]) -> Optional[np.ndarray]:
67
+ return sorted(
68
+ list(
69
+ {
70
+ seg.get("class_id")
71
+ for seg in self.segments
72
+ if seg.get("class_id") is not None
73
+ }
74
+ )
75
+ )
76
+
77
+ def rasterize_polygon(
78
+ self, vertices: List[QPointF], image_size: Tuple[int, int]
79
+ ) -> Optional[np.ndarray]:
68
80
  """Convert polygon vertices to binary mask."""
69
81
  if not vertices:
70
82
  return None
71
-
83
+
72
84
  h, w = image_size
73
85
  points_np = np.array([[p.x(), p.y()] for p in vertices], dtype=np.int32)
74
86
  mask = np.zeros((h, w), dtype=np.uint8)
75
87
  cv2.fillPoly(mask, [points_np], 1)
76
88
  return mask.astype(bool)
77
-
78
- def create_final_mask_tensor(self, image_size: Tuple[int, int], class_order: List[int]) -> np.ndarray:
89
+
90
+ def create_final_mask_tensor(
91
+ self, image_size: Tuple[int, int], class_order: List[int]
92
+ ) -> np.ndarray:
79
93
  """Create final mask tensor for saving."""
80
94
  h, w = image_size
81
95
  id_map = {old_id: new_id for new_id, old_id in enumerate(class_order)}
82
96
  num_final_classes = len(class_order)
83
97
  final_mask_tensor = np.zeros((h, w, num_final_classes), dtype=np.uint8)
84
-
98
+
85
99
  for seg in self.segments:
86
100
  class_id = seg.get("class_id")
87
101
  if class_id not in id_map:
88
102
  continue
89
-
103
+
90
104
  new_channel_idx = id_map[class_id]
91
-
105
+
92
106
  if seg["type"] == "Polygon":
93
107
  mask = self.rasterize_polygon(seg["vertices"], image_size)
94
108
  else:
95
109
  mask = seg.get("mask")
96
-
110
+
97
111
  if mask is not None:
98
112
  final_mask_tensor[:, :, new_channel_idx] = np.logical_or(
99
113
  final_mask_tensor[:, :, new_channel_idx], mask
100
114
  )
101
-
115
+
102
116
  return final_mask_tensor
103
-
117
+
104
118
  def reassign_class_ids(self, new_order: List[int]) -> None:
105
119
  """Reassign class IDs based on new order."""
106
120
  id_map = {old_id: new_id for new_id, old_id in enumerate(new_order)}
107
-
121
+
108
122
  for seg in self.segments:
109
123
  old_id = seg.get("class_id")
110
124
  if old_id in id_map:
111
125
  seg["class_id"] = id_map[old_id]
112
-
126
+
113
127
  # Update aliases
114
128
  new_aliases = {
115
129
  id_map[old_id]: self.class_aliases.get(old_id, str(old_id))
@@ -118,15 +132,32 @@ class SegmentManager:
118
132
  }
119
133
  self.class_aliases = new_aliases
120
134
  self._update_next_class_id()
121
-
135
+
122
136
  def set_class_alias(self, class_id: int, alias: str) -> None:
123
137
  """Set alias for a class."""
124
138
  self.class_aliases[class_id] = alias
125
-
139
+
126
140
  def get_class_alias(self, class_id: int) -> str:
127
141
  """Get alias for a class."""
128
142
  return self.class_aliases.get(class_id, str(class_id))
129
-
143
+
144
+ def set_active_class(self, class_id: Optional[int]) -> None:
145
+ """Set the active class ID."""
146
+ self.active_class_id = class_id
147
+
148
+ def get_active_class(self) -> Optional[int]:
149
+ """Get the active class ID."""
150
+ return self.active_class_id
151
+
152
+ def toggle_active_class(self, class_id: int) -> bool:
153
+ """Toggle a class as active. Returns True if now active, False if deactivated."""
154
+ if self.active_class_id == class_id:
155
+ self.active_class_id = None
156
+ return False
157
+ else:
158
+ self.active_class_id = class_id
159
+ return True
160
+
130
161
  def _update_next_class_id(self) -> None:
131
162
  """Update the next available class ID."""
132
163
  all_ids = {
@@ -137,4 +168,4 @@ class SegmentManager:
137
168
  if not all_ids:
138
169
  self.next_class_id = 0
139
170
  else:
140
- self.next_class_id = max(all_ids) + 1
171
+ self.next_class_id = max(all_ids) + 1
@@ -9,14 +9,28 @@ from .ui.main_window import MainWindow
9
9
 
10
10
  def main():
11
11
  """Main application entry point."""
12
+ print("=" * 50)
13
+ print("LazyLabel - AI-Assisted Image Labeling")
14
+ print("=" * 50)
15
+ print()
16
+
17
+ print("[1/20] Initializing application...")
12
18
  app = QApplication(sys.argv)
19
+
20
+ print("[2/20] Applying dark theme...")
13
21
  qdarktheme.setup_theme()
14
-
22
+
15
23
  main_window = MainWindow()
24
+
25
+ print("[19/20] Showing main window...")
16
26
  main_window.show()
17
-
27
+
28
+ print()
29
+ print("[20/20] LazyLabel is ready! Happy labeling!")
30
+ print("=" * 50)
31
+
18
32
  sys.exit(app.exec())
19
33
 
20
34
 
21
35
  if __name__ == "__main__":
22
- main()
36
+ main()
@@ -9,15 +9,18 @@ from segment_anything import sam_model_registry, SamPredictor
9
9
 
10
10
  def download_model(url, download_path):
11
11
  """Downloads file with a progress bar."""
12
- print(
13
- f"SAM model not found. Downloading from Meta's GitHub repository to: {download_path}"
14
- )
12
+ print(f"[10/20] SAM model not found. Downloading from Meta's repository...")
13
+ print(f" Downloading to: {download_path}")
15
14
  try:
15
+ print(f"[10/20] Connecting to download server...")
16
16
  response = requests.get(url, stream=True, timeout=30)
17
17
  response.raise_for_status()
18
18
  total_size_in_bytes = int(response.headers.get("content-length", 0))
19
19
  block_size = 1024 # 1 Kibibyte
20
20
 
21
+ print(
22
+ f"[10/20] Starting download ({total_size_in_bytes / (1024*1024*1024):.1f} GB)..."
23
+ )
21
24
  progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
22
25
  with open(download_path, "wb") as file:
23
26
  for data in response.iter_content(block_size):
@@ -27,93 +30,131 @@ def download_model(url, download_path):
27
30
 
28
31
  if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
29
32
  raise RuntimeError("Download incomplete - file size mismatch")
30
-
31
- print("Model download completed successfully.")
32
-
33
+
34
+ print("[10/20] Model download completed successfully.")
35
+
36
+ except requests.exceptions.ConnectionError as e:
37
+ raise RuntimeError(
38
+ f"[10/20] Network connection failed: Check your internet connection"
39
+ )
40
+ except requests.exceptions.Timeout as e:
41
+ raise RuntimeError(f"[10/20] Download timeout: Server took too long to respond")
42
+ except requests.exceptions.HTTPError as e:
43
+ raise RuntimeError(
44
+ f"[10/20] HTTP error {e.response.status_code}: Server rejected request"
45
+ )
33
46
  except requests.exceptions.RequestException as e:
34
- raise RuntimeError(f"Network error during download: {e}")
47
+ raise RuntimeError(f"[10/20] Network error during download: {e}")
48
+ except PermissionError as e:
49
+ raise RuntimeError(
50
+ f"[10/20] Permission denied: Cannot write to {download_path}"
51
+ )
52
+ except OSError as e:
53
+ raise RuntimeError(f"[10/20] Disk error: {e} (check available disk space)")
35
54
  except Exception as e:
36
55
  # Clean up partial download
37
56
  if os.path.exists(download_path):
38
- os.remove(download_path)
39
- raise RuntimeError(f"Download failed: {e}")
57
+ try:
58
+ os.remove(download_path)
59
+ except:
60
+ pass
61
+ raise RuntimeError(f"[10/20] Download failed: {e}")
40
62
 
41
63
 
42
64
  class SamModel:
43
- def __init__(self, model_type="vit_h", model_filename="sam_vit_h_4b8939.pth", custom_model_path=None):
65
+ def __init__(
66
+ self,
67
+ model_type="vit_h",
68
+ model_filename="sam_vit_h_4b8939.pth",
69
+ custom_model_path=None,
70
+ ):
44
71
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
72
+ print(f"[9/20] Detected device: {str(self.device).upper()}")
73
+
45
74
  self.current_model_type = model_type
46
75
  self.current_model_path = custom_model_path
47
76
  self.model = None
48
77
  self.predictor = None
49
78
  self.image = None
50
79
  self.is_loaded = False
51
-
80
+
52
81
  try:
53
82
  if custom_model_path and os.path.exists(custom_model_path):
54
83
  # Use custom model path
55
84
  model_path = custom_model_path
56
- print(f"Loading custom SAM model from {model_path}...")
85
+ print(f"[10/20] Loading custom SAM model from {model_path}...")
57
86
  else:
58
87
  # Use default model with download if needed - store in models folder
59
- model_url = f"https://dl.fbaipublicfiles.com/segment_anything/{model_filename}"
60
-
88
+ model_url = (
89
+ f"https://dl.fbaipublicfiles.com/segment_anything/{model_filename}"
90
+ )
91
+
61
92
  # Use models folder instead of cache folder
62
93
  models_dir = os.path.dirname(__file__) # Already in models directory
63
94
  os.makedirs(models_dir, exist_ok=True)
64
95
  model_path = os.path.join(models_dir, model_filename)
65
-
96
+
66
97
  # Also check the old cache location and move it if it exists
67
- old_cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "lazylabel")
98
+ old_cache_dir = os.path.join(
99
+ os.path.expanduser("~"), ".cache", "lazylabel"
100
+ )
68
101
  old_model_path = os.path.join(old_cache_dir, model_filename)
69
-
102
+
70
103
  if os.path.exists(old_model_path) and not os.path.exists(model_path):
71
- print(f"Moving existing model from cache to models folder...")
104
+ print(
105
+ f"[10/20] Moving existing model from cache to models folder..."
106
+ )
72
107
  import shutil
108
+
73
109
  shutil.move(old_model_path, model_path)
74
110
  elif not os.path.exists(model_path):
75
111
  # Download the model if it doesn't exist
76
112
  download_model(model_url, model_path)
77
-
78
- print(f"Loading default SAM model from {model_path}...")
79
113
 
114
+ print(f"[10/20] Loading default SAM model from {model_path}...")
115
+
116
+ print(f"[11/20] Initializing {model_type.upper()} model architecture...")
80
117
  self.model = sam_model_registry[model_type](checkpoint=model_path).to(
81
118
  self.device
82
119
  )
120
+
121
+ print(f"[12/20] Setting up predictor...")
83
122
  self.predictor = SamPredictor(self.model)
84
123
  self.is_loaded = True
85
- print("SAM model loaded successfully.")
86
-
124
+ print("[13/20] SAM model loaded successfully.")
125
+
87
126
  except Exception as e:
88
- print(f"Failed to load SAM model: {e}")
89
- print("SAM point functionality will be disabled.")
127
+ print(f"[8/20] Failed to load SAM model: {e}")
128
+ print(f"[8/20] SAM point functionality will be disabled.")
90
129
  self.is_loaded = False
91
-
130
+
92
131
  def load_custom_model(self, model_path, model_type="vit_h"):
93
132
  """Load a custom model from the specified path."""
94
133
  if not os.path.exists(model_path):
95
134
  print(f"Model file not found: {model_path}")
96
135
  return False
97
-
136
+
98
137
  print(f"Loading custom SAM model from {model_path}...")
99
138
  try:
100
139
  # Clear existing model from memory
101
- if hasattr(self, 'model') and self.model is not None:
140
+ if hasattr(self, "model") and self.model is not None:
102
141
  del self.model
103
142
  del self.predictor
104
143
  torch.cuda.empty_cache() if torch.cuda.is_available() else None
105
-
144
+
106
145
  # Load new model
107
- self.model = sam_model_registry[model_type](checkpoint=model_path).to(self.device)
146
+ self.model = sam_model_registry[model_type](checkpoint=model_path).to(
147
+ self.device
148
+ )
108
149
  self.predictor = SamPredictor(self.model)
109
150
  self.current_model_type = model_type
110
151
  self.current_model_path = model_path
111
152
  self.is_loaded = True
112
-
153
+
113
154
  # Re-set image if one was previously loaded
114
155
  if self.image is not None:
115
156
  self.predictor.set_image(self.image)
116
-
157
+
117
158
  print("Custom SAM model loaded successfully.")
118
159
  return True
119
160
  except Exception as e: