lazylabel-gui 1.1.8__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -44,7 +44,7 @@ class HotkeyManager:
44
44
  ),
45
45
  HotkeyAction("fit_view", "Fit View", ".", category="Navigation"),
46
46
  # Modes
47
- HotkeyAction("sam_mode", "Point Mode (SAM)", "1", category="Modes"),
47
+ HotkeyAction("sam_mode", "AI Mode (Points + Box)", "1", category="Modes"),
48
48
  HotkeyAction("polygon_mode", "Polygon Mode", "2", category="Modes"),
49
49
  HotkeyAction("bbox_mode", "Bounding Box Mode", "3", category="Modes"),
50
50
  HotkeyAction("selection_mode", "Selection Mode", "E", category="Modes"),
@@ -95,7 +95,7 @@ class HotkeyManager:
95
95
  # Mouse-related (cannot be reassigned)
96
96
  HotkeyAction(
97
97
  "left_click",
98
- "Add Positive Point / Select",
98
+ "AI: Point (click) / Box (drag) / Select",
99
99
  "Left Click",
100
100
  category="Mouse",
101
101
  mouse_related=True,
@@ -7,6 +7,18 @@ from ..config import Paths
7
7
  from ..models.sam_model import SamModel
8
8
  from ..utils.logger import logger
9
9
 
10
+ # Optional SAM-2 support
11
+ try:
12
+ from ..models.sam2_model import Sam2Model
13
+
14
+ SAM2_AVAILABLE = True
15
+ except ImportError:
16
+ logger.info(
17
+ "SAM-2 not available. Install with: pip install git+https://github.com/facebookresearch/sam2.git"
18
+ )
19
+ Sam2Model = None
20
+ SAM2_AVAILABLE = False
21
+
10
22
 
11
23
  class ModelManager:
12
24
  """Manages SAM model loading and selection."""
@@ -42,7 +54,7 @@ class ModelManager:
42
54
  pth_files = []
43
55
  for root, _dirs, files in os.walk(folder_path):
44
56
  for file in files:
45
- if file.lower().endswith(".pth"):
57
+ if file.lower().endswith(".pth") or file.lower().endswith(".pt"):
46
58
  full_path = os.path.join(root, file)
47
59
  rel_path = os.path.relpath(full_path, folder_path)
48
60
  pth_files.append((rel_path, full_path))
@@ -52,13 +64,34 @@ class ModelManager:
52
64
  def detect_model_type(self, model_path: str) -> str:
53
65
  """Detect model type from filename."""
54
66
  filename = os.path.basename(model_path).lower()
55
- if "vit_l" in filename or "large" in filename:
56
- return "vit_l"
57
- elif "vit_b" in filename or "base" in filename:
58
- return "vit_b"
59
- elif "vit_h" in filename or "huge" in filename:
60
- return "vit_h"
61
- return "vit_h" # default
67
+
68
+ # Check if it's a SAM2 model
69
+ if self._is_sam2_model(model_path):
70
+ if "tiny" in filename or "_t" in filename:
71
+ return "sam2_tiny"
72
+ elif "small" in filename or "_s" in filename:
73
+ return "sam2_small"
74
+ elif "base_plus" in filename or "_b+" in filename:
75
+ return "sam2_base_plus"
76
+ elif "large" in filename or "_l" in filename:
77
+ return "sam2_large"
78
+ else:
79
+ return "sam2_large" # default for SAM2
80
+ else:
81
+ # Original SAM model types
82
+ if "vit_l" in filename or "large" in filename:
83
+ return "vit_l"
84
+ elif "vit_b" in filename or "base" in filename:
85
+ return "vit_b"
86
+ elif "vit_h" in filename or "huge" in filename:
87
+ return "vit_h"
88
+ return "vit_h" # default for SAM1
89
+
90
+ def _is_sam2_model(self, model_path: str) -> bool:
91
+ """Check if the model is a SAM2 model based on filename patterns."""
92
+ filename = os.path.basename(model_path).lower()
93
+ sam2_indicators = ["sam2", "sam2.1", "hiera", "_t.", "_s.", "_b+.", "_l."]
94
+ return any(indicator in filename for indicator in sam2_indicators)
62
95
 
63
96
  def load_custom_model(self, model_path: str) -> bool:
64
97
  """Load a custom model from path.
@@ -66,20 +99,62 @@ class ModelManager:
66
99
  Returns:
67
100
  True if successful, False otherwise
68
101
  """
69
- if not self.sam_model:
70
- return False
71
-
72
102
  if not os.path.exists(model_path):
73
103
  return False
74
104
 
75
105
  model_type = self.detect_model_type(model_path)
76
- success = self.sam_model.load_custom_model(model_path, model_type)
77
106
 
78
- if success and self.on_model_changed:
79
- model_name = os.path.basename(model_path)
80
- self.on_model_changed(f"Current: {model_name}")
107
+ try:
108
+ # Clear existing model from memory
109
+ if self.sam_model is not None:
110
+ del self.sam_model
111
+ import torch
112
+
113
+ torch.cuda.empty_cache() if torch.cuda.is_available() else None
114
+
115
+ # Create appropriate model instance
116
+ if self._is_sam2_model(model_path):
117
+ if not SAM2_AVAILABLE:
118
+ logger.warning(
119
+ f"SAM-2 model detected but SAM-2 not installed: {model_path}"
120
+ )
121
+ logger.info(
122
+ "Install SAM-2 with: pip install git+https://github.com/facebookresearch/sam2.git"
123
+ )
124
+ return False
125
+
126
+ logger.info(f"Loading SAM2 model: {model_type}")
127
+ self.sam_model = Sam2Model(model_path)
128
+ else:
129
+ logger.info(f"Loading SAM1 model: {model_type}")
130
+ # Convert SAM2 model types back to SAM1 types for compatibility
131
+ sam1_model_type = model_type
132
+ if model_type.startswith("sam2_"):
133
+ type_mapping = {
134
+ "sam2_tiny": "vit_b",
135
+ "sam2_small": "vit_b",
136
+ "sam2_base_plus": "vit_l",
137
+ "sam2_large": "vit_h",
138
+ }
139
+ sam1_model_type = type_mapping.get(model_type, "vit_h")
140
+
141
+ # Create SAM1 model with custom path
142
+ self.sam_model = SamModel(
143
+ model_type=sam1_model_type, custom_model_path=model_path
144
+ )
145
+
146
+ success = self.sam_model.is_loaded
147
+
148
+ if success and self.on_model_changed:
149
+ model_name = os.path.basename(model_path)
150
+ self.on_model_changed(f"Current: {model_name}")
151
+
152
+ return success
81
153
 
82
- return success
154
+ except Exception as e:
155
+ logger.error(f"Failed to load custom model: {e}")
156
+ self.sam_model = None
157
+ return False
83
158
 
84
159
  def set_models_folder(self, folder_path: str) -> None:
85
160
  """Set the current models folder."""
@@ -0,0 +1,223 @@
1
+ import os
2
+
3
+ import cv2
4
+ import numpy as np
5
+ import torch
6
+
7
+ from ..utils.logger import logger
8
+
9
+ # SAM-2 specific imports - will fail gracefully if not available
10
+ try:
11
+ from sam2.build_sam import build_sam2
12
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
13
+ except ImportError as e:
14
+ logger.error(f"SAM-2 dependencies not found: {e}")
15
+ logger.info(
16
+ "Install SAM-2 with: pip install git+https://github.com/facebookresearch/sam2.git"
17
+ )
18
+ raise ImportError("SAM-2 dependencies required for Sam2Model") from e
19
+
20
+
21
+ class Sam2Model:
22
+ """SAM2 model wrapper that provides the same interface as SamModel."""
23
+
24
+ def __init__(self, model_path: str, config_path: str | None = None):
25
+ """Initialize SAM2 model.
26
+
27
+ Args:
28
+ model_path: Path to the SAM2 model checkpoint (.pt file)
29
+ config_path: Path to the config file (optional, will auto-detect if None)
30
+ """
31
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
32
+ logger.info(f"SAM2: Detected device: {str(self.device).upper()}")
33
+
34
+ self.current_model_path = model_path
35
+ self.model = None
36
+ self.predictor = None
37
+ self.image = None
38
+ self.is_loaded = False
39
+
40
+ # Auto-detect config if not provided
41
+ if config_path is None:
42
+ config_path = self._auto_detect_config(model_path)
43
+
44
+ try:
45
+ logger.info(f"SAM2: Loading model from {model_path}...")
46
+ logger.info(f"SAM2: Using config: {config_path}")
47
+
48
+ # Build SAM2 model
49
+ self.model = build_sam2(config_path, model_path, device=self.device)
50
+
51
+ # Create predictor
52
+ self.predictor = SAM2ImagePredictor(self.model)
53
+
54
+ self.is_loaded = True
55
+ logger.info("SAM2: Model loaded successfully.")
56
+
57
+ except Exception as e:
58
+ logger.error(f"SAM2: Failed to load model: {e}")
59
+ logger.warning("SAM2: SAM2 functionality will be disabled.")
60
+ self.is_loaded = False
61
+
62
+ def _auto_detect_config(self, model_path: str) -> str:
63
+ """Auto-detect the appropriate config file based on model filename."""
64
+ filename = os.path.basename(model_path).lower()
65
+
66
+ # Get the sam2 package directory
67
+ try:
68
+ import sam2
69
+
70
+ sam2_dir = os.path.dirname(sam2.__file__)
71
+ configs_dir = os.path.join(sam2_dir, "configs")
72
+
73
+ # Map model types to config files
74
+ if "tiny" in filename or "_t" in filename:
75
+ config_file = (
76
+ "sam2.1_hiera_t.yaml" if "2.1" in filename else "sam2_hiera_t.yaml"
77
+ )
78
+ elif "small" in filename or "_s" in filename:
79
+ config_file = (
80
+ "sam2.1_hiera_s.yaml" if "2.1" in filename else "sam2_hiera_s.yaml"
81
+ )
82
+ elif "base_plus" in filename or "_b+" in filename:
83
+ config_file = (
84
+ "sam2.1_hiera_b+.yaml"
85
+ if "2.1" in filename
86
+ else "sam2_hiera_b+.yaml"
87
+ )
88
+ elif "large" in filename or "_l" in filename:
89
+ config_file = (
90
+ "sam2.1_hiera_l.yaml" if "2.1" in filename else "sam2_hiera_l.yaml"
91
+ )
92
+ else:
93
+ # Default to large model
94
+ config_file = "sam2.1_hiera_l.yaml"
95
+
96
+ # Check sam2.1 configs first, then fall back to sam2
97
+ if "2.1" in filename:
98
+ config_path = os.path.join(configs_dir, "sam2.1", config_file)
99
+ else:
100
+ config_path = os.path.join(
101
+ configs_dir, "sam2", config_file.replace("2.1_", "")
102
+ )
103
+
104
+ if os.path.exists(config_path):
105
+ return config_path
106
+
107
+ # Fallback to default large config
108
+ fallback_config = os.path.join(configs_dir, "sam2.1", "sam2.1_hiera_l.yaml")
109
+ if os.path.exists(fallback_config):
110
+ return fallback_config
111
+
112
+ raise FileNotFoundError(f"No suitable config found for {filename}")
113
+
114
+ except Exception as e:
115
+ logger.error(f"SAM2: Failed to auto-detect config: {e}")
116
+ # Return a reasonable default path
117
+ return "sam2.1_hiera_l.yaml"
118
+
119
+ def set_image_from_path(self, image_path: str) -> bool:
120
+ """Set image for SAM2 model from file path."""
121
+ if not self.is_loaded:
122
+ return False
123
+ try:
124
+ self.image = cv2.imread(image_path)
125
+ self.image = cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB)
126
+ self.predictor.set_image(self.image)
127
+ return True
128
+ except Exception as e:
129
+ logger.error(f"SAM2: Error setting image from path: {e}")
130
+ return False
131
+
132
+ def set_image_from_array(self, image_array: np.ndarray) -> bool:
133
+ """Set image for SAM2 model from numpy array."""
134
+ if not self.is_loaded:
135
+ return False
136
+ try:
137
+ self.image = image_array
138
+ self.predictor.set_image(self.image)
139
+ return True
140
+ except Exception as e:
141
+ logger.error(f"SAM2: Error setting image from array: {e}")
142
+ return False
143
+
144
+ def predict(self, positive_points, negative_points):
145
+ """Generate predictions using SAM2."""
146
+ if not self.is_loaded or not positive_points:
147
+ return None
148
+
149
+ try:
150
+ points = np.array(positive_points + negative_points)
151
+ labels = np.array([1] * len(positive_points) + [0] * len(negative_points))
152
+
153
+ masks, scores, logits = self.predictor.predict(
154
+ point_coords=points,
155
+ point_labels=labels,
156
+ multimask_output=True,
157
+ )
158
+
159
+ # Return the mask with the highest score
160
+ best_mask_idx = np.argmax(scores)
161
+ return masks[best_mask_idx], scores[best_mask_idx], logits[best_mask_idx]
162
+
163
+ except Exception as e:
164
+ logger.error(f"SAM2: Error during prediction: {e}")
165
+ return None
166
+
167
+ def predict_from_box(self, box):
168
+ """Generate predictions from bounding box using SAM2."""
169
+ if not self.is_loaded:
170
+ return None
171
+
172
+ try:
173
+ masks, scores, logits = self.predictor.predict(
174
+ box=np.array(box),
175
+ multimask_output=True,
176
+ )
177
+
178
+ # Return the mask with the highest score
179
+ best_mask_idx = np.argmax(scores)
180
+ return masks[best_mask_idx], scores[best_mask_idx], logits[best_mask_idx]
181
+
182
+ except Exception as e:
183
+ logger.error(f"SAM2: Error during box prediction: {e}")
184
+ return None
185
+
186
+ def load_custom_model(
187
+ self, model_path: str, config_path: str | None = None
188
+ ) -> bool:
189
+ """Load a custom SAM2 model from the specified path."""
190
+ if not os.path.exists(model_path):
191
+ logger.warning(f"SAM2: Model file not found: {model_path}")
192
+ return False
193
+
194
+ logger.info(f"SAM2: Loading custom model from {model_path}...")
195
+ try:
196
+ # Clear existing model from memory
197
+ if hasattr(self, "model") and self.model is not None:
198
+ del self.model
199
+ del self.predictor
200
+ torch.cuda.empty_cache() if torch.cuda.is_available() else None
201
+
202
+ # Auto-detect config if not provided
203
+ if config_path is None:
204
+ config_path = self._auto_detect_config(model_path)
205
+
206
+ # Load new model
207
+ self.model = build_sam2(config_path, model_path, device=self.device)
208
+ self.predictor = SAM2ImagePredictor(self.model)
209
+ self.current_model_path = model_path
210
+ self.is_loaded = True
211
+
212
+ # Re-set image if one was previously loaded
213
+ if self.image is not None:
214
+ self.predictor.set_image(self.image)
215
+
216
+ logger.info("SAM2: Custom model loaded successfully.")
217
+ return True
218
+ except Exception as e:
219
+ logger.error(f"SAM2: Error loading custom model: {e}")
220
+ self.is_loaded = False
221
+ self.model = None
222
+ self.predictor = None
223
+ return False
@@ -200,12 +200,34 @@ class SamModel:
200
200
  points = np.array(positive_points + negative_points)
201
201
  labels = np.array([1] * len(positive_points) + [0] * len(negative_points))
202
202
 
203
- masks, _, _ = self.predictor.predict(
203
+ masks, scores, logits = self.predictor.predict(
204
204
  point_coords=points,
205
205
  point_labels=labels,
206
- multimask_output=False,
206
+ multimask_output=True,
207
207
  )
208
- return masks[0]
208
+
209
+ # Return the mask with the highest score (consistent with SAM2)
210
+ best_mask_idx = np.argmax(scores)
211
+ return masks[best_mask_idx], scores[best_mask_idx], logits[best_mask_idx]
209
212
  except Exception as e:
210
213
  logger.error(f"Error during prediction: {e}")
211
214
  return None
215
+
216
def predict_from_box(self, box):
    """Segment the region enclosed by ``box`` with SAM.

    The predictor is asked for several candidate masks, and the candidate
    with the highest score is kept.

    Returns:
        ``(mask, score, logits)`` for the winning candidate, or ``None`` when
        the model is not loaded or prediction raises (the error is logged).
    """
    if self.is_loaded:
        try:
            candidates, confidences, raw_logits = self.predictor.predict(
                box=np.array(box),
                multimask_output=True,
            )
            winner = np.argmax(confidences)
            return candidates[winner], confidences[winner], raw_logits[winner]
        except Exception as e:
            logger.error(f"Error during box prediction: {e}")
    return None