lazylabel-gui 1.3.4__tar.gz → 1.3.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {lazylabel_gui-1.3.4 → lazylabel_gui-1.3.5}/PKG-INFO +1 -1
- {lazylabel_gui-1.3.4 → lazylabel_gui-1.3.5}/pyproject.toml +1 -1
- lazylabel_gui-1.3.5/src/lazylabel/models/sam2_model.py +490 -0
- {lazylabel_gui-1.3.4 → lazylabel_gui-1.3.5}/src/lazylabel/ui/main_window.py +100 -530
- {lazylabel_gui-1.3.4 → lazylabel_gui-1.3.5}/src/lazylabel/ui/photo_viewer.py +35 -11
- lazylabel_gui-1.3.5/src/lazylabel/ui/workers/__init__.py +15 -0
- lazylabel_gui-1.3.5/src/lazylabel/ui/workers/image_discovery_worker.py +66 -0
- lazylabel_gui-1.3.5/src/lazylabel/ui/workers/multi_view_sam_init_worker.py +135 -0
- lazylabel_gui-1.3.5/src/lazylabel/ui/workers/multi_view_sam_update_worker.py +158 -0
- lazylabel_gui-1.3.5/src/lazylabel/ui/workers/sam_update_worker.py +129 -0
- lazylabel_gui-1.3.5/src/lazylabel/ui/workers/single_view_sam_init_worker.py +61 -0
- {lazylabel_gui-1.3.4 → lazylabel_gui-1.3.5}/src/lazylabel_gui.egg-info/PKG-INFO +1 -1
- {lazylabel_gui-1.3.4 → lazylabel_gui-1.3.5}/src/lazylabel_gui.egg-info/SOURCES.txt +6 -0
- lazylabel_gui-1.3.4/src/lazylabel/models/sam2_model.py +0 -371
- {lazylabel_gui-1.3.4 → lazylabel_gui-1.3.5}/LICENSE +0 -0
- {lazylabel_gui-1.3.4 → lazylabel_gui-1.3.5}/README.md +0 -0
- {lazylabel_gui-1.3.4 → lazylabel_gui-1.3.5}/setup.cfg +0 -0
- {lazylabel_gui-1.3.4 → lazylabel_gui-1.3.5}/src/lazylabel/__init__.py +0 -0
- {lazylabel_gui-1.3.4 → lazylabel_gui-1.3.5}/src/lazylabel/__main__.py +0 -0
- {lazylabel_gui-1.3.4 → lazylabel_gui-1.3.5}/src/lazylabel/config/__init__.py +0 -0
- {lazylabel_gui-1.3.4 → lazylabel_gui-1.3.5}/src/lazylabel/config/hotkeys.py +0 -0
- {lazylabel_gui-1.3.4 → lazylabel_gui-1.3.5}/src/lazylabel/config/paths.py +0 -0
- {lazylabel_gui-1.3.4 → lazylabel_gui-1.3.5}/src/lazylabel/config/settings.py +0 -0
- {lazylabel_gui-1.3.4 → lazylabel_gui-1.3.5}/src/lazylabel/core/__init__.py +0 -0
- {lazylabel_gui-1.3.4 → lazylabel_gui-1.3.5}/src/lazylabel/core/file_manager.py +0 -0
- {lazylabel_gui-1.3.4 → lazylabel_gui-1.3.5}/src/lazylabel/core/model_manager.py +0 -0
- {lazylabel_gui-1.3.4 → lazylabel_gui-1.3.5}/src/lazylabel/core/segment_manager.py +0 -0
- {lazylabel_gui-1.3.4 → lazylabel_gui-1.3.5}/src/lazylabel/main.py +0 -0
- {lazylabel_gui-1.3.4 → lazylabel_gui-1.3.5}/src/lazylabel/models/__init__.py +0 -0
- {lazylabel_gui-1.3.4 → lazylabel_gui-1.3.5}/src/lazylabel/models/sam_model.py +0 -0
- {lazylabel_gui-1.3.4 → lazylabel_gui-1.3.5}/src/lazylabel/ui/__init__.py +0 -0
- {lazylabel_gui-1.3.4 → lazylabel_gui-1.3.5}/src/lazylabel/ui/control_panel.py +0 -0
- {lazylabel_gui-1.3.4 → lazylabel_gui-1.3.5}/src/lazylabel/ui/editable_vertex.py +0 -0
- {lazylabel_gui-1.3.4 → lazylabel_gui-1.3.5}/src/lazylabel/ui/hotkey_dialog.py +0 -0
- {lazylabel_gui-1.3.4 → lazylabel_gui-1.3.5}/src/lazylabel/ui/hoverable_pixelmap_item.py +0 -0
- {lazylabel_gui-1.3.4 → lazylabel_gui-1.3.5}/src/lazylabel/ui/hoverable_polygon_item.py +0 -0
- {lazylabel_gui-1.3.4 → lazylabel_gui-1.3.5}/src/lazylabel/ui/modes/__init__.py +0 -0
- {lazylabel_gui-1.3.4 → lazylabel_gui-1.3.5}/src/lazylabel/ui/modes/base_mode.py +0 -0
- {lazylabel_gui-1.3.4 → lazylabel_gui-1.3.5}/src/lazylabel/ui/modes/multi_view_mode.py +0 -0
- {lazylabel_gui-1.3.4 → lazylabel_gui-1.3.5}/src/lazylabel/ui/modes/single_view_mode.py +0 -0
- {lazylabel_gui-1.3.4 → lazylabel_gui-1.3.5}/src/lazylabel/ui/numeric_table_widget_item.py +0 -0
- {lazylabel_gui-1.3.4 → lazylabel_gui-1.3.5}/src/lazylabel/ui/reorderable_class_table.py +0 -0
- {lazylabel_gui-1.3.4 → lazylabel_gui-1.3.5}/src/lazylabel/ui/right_panel.py +0 -0
- {lazylabel_gui-1.3.4 → lazylabel_gui-1.3.5}/src/lazylabel/ui/widgets/__init__.py +0 -0
- {lazylabel_gui-1.3.4 → lazylabel_gui-1.3.5}/src/lazylabel/ui/widgets/adjustments_widget.py +0 -0
- {lazylabel_gui-1.3.4 → lazylabel_gui-1.3.5}/src/lazylabel/ui/widgets/border_crop_widget.py +0 -0
- {lazylabel_gui-1.3.4 → lazylabel_gui-1.3.5}/src/lazylabel/ui/widgets/channel_threshold_widget.py +0 -0
- {lazylabel_gui-1.3.4 → lazylabel_gui-1.3.5}/src/lazylabel/ui/widgets/fft_threshold_widget.py +0 -0
- {lazylabel_gui-1.3.4 → lazylabel_gui-1.3.5}/src/lazylabel/ui/widgets/fragment_threshold_widget.py +0 -0
- {lazylabel_gui-1.3.4 → lazylabel_gui-1.3.5}/src/lazylabel/ui/widgets/model_selection_widget.py +0 -0
- {lazylabel_gui-1.3.4 → lazylabel_gui-1.3.5}/src/lazylabel/ui/widgets/settings_widget.py +0 -0
- {lazylabel_gui-1.3.4 → lazylabel_gui-1.3.5}/src/lazylabel/ui/widgets/status_bar.py +0 -0
- {lazylabel_gui-1.3.4 → lazylabel_gui-1.3.5}/src/lazylabel/utils/__init__.py +0 -0
- {lazylabel_gui-1.3.4 → lazylabel_gui-1.3.5}/src/lazylabel/utils/custom_file_system_model.py +0 -0
- {lazylabel_gui-1.3.4 → lazylabel_gui-1.3.5}/src/lazylabel/utils/fast_file_manager.py +0 -0
- {lazylabel_gui-1.3.4 → lazylabel_gui-1.3.5}/src/lazylabel/utils/logger.py +0 -0
- {lazylabel_gui-1.3.4 → lazylabel_gui-1.3.5}/src/lazylabel/utils/utils.py +0 -0
- {lazylabel_gui-1.3.4 → lazylabel_gui-1.3.5}/src/lazylabel_gui.egg-info/dependency_links.txt +0 -0
- {lazylabel_gui-1.3.4 → lazylabel_gui-1.3.5}/src/lazylabel_gui.egg-info/entry_points.txt +0 -0
- {lazylabel_gui-1.3.4 → lazylabel_gui-1.3.5}/src/lazylabel_gui.egg-info/requires.txt +0 -0
- {lazylabel_gui-1.3.4 → lazylabel_gui-1.3.5}/src/lazylabel_gui.egg-info/top_level.txt +0 -0
@@ -0,0 +1,490 @@
|
|
1
|
+
import os
|
2
|
+
from pathlib import Path
|
3
|
+
|
4
|
+
import cv2
|
5
|
+
import numpy as np
|
6
|
+
import torch
|
7
|
+
|
8
|
+
from ..utils.logger import logger
|
9
|
+
|
10
|
+
# SAM-2 specific imports - will fail gracefully if not available
|
11
|
+
try:
|
12
|
+
from sam2.build_sam import build_sam2
|
13
|
+
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
14
|
+
except ImportError as e:
|
15
|
+
logger.error(f"SAM-2 dependencies not found: {e}")
|
16
|
+
logger.info(
|
17
|
+
"Install SAM-2 with: pip install git+https://github.com/facebookresearch/sam2.git"
|
18
|
+
)
|
19
|
+
raise ImportError("SAM-2 dependencies required for Sam2Model") from e
|
20
|
+
|
21
|
+
|
22
|
+
class Sam2Model:
|
23
|
+
"""SAM2 model wrapper that provides the same interface as SamModel."""
|
24
|
+
|
25
|
+
def __init__(self, model_path: str, config_path: str | None = None):
|
26
|
+
"""Initialize SAM2 model.
|
27
|
+
|
28
|
+
Args:
|
29
|
+
model_path: Path to the SAM2 model checkpoint (.pt file)
|
30
|
+
config_path: Path to the config file (optional, will auto-detect if None)
|
31
|
+
"""
|
32
|
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
33
|
+
logger.info(f"SAM2: Detected device: {str(self.device).upper()}")
|
34
|
+
|
35
|
+
self.current_model_path = model_path
|
36
|
+
self.model = None
|
37
|
+
self.predictor = None
|
38
|
+
self.image = None
|
39
|
+
self.is_loaded = False
|
40
|
+
|
41
|
+
# Auto-detect config if not provided
|
42
|
+
if config_path is None:
|
43
|
+
config_path = self._auto_detect_config(model_path)
|
44
|
+
|
45
|
+
try:
|
46
|
+
logger.info(f"SAM2: Loading model from {model_path}...")
|
47
|
+
logger.info(f"SAM2: Using config: {config_path}")
|
48
|
+
|
49
|
+
# Ensure config_path is absolute
|
50
|
+
if not os.path.isabs(config_path):
|
51
|
+
# Try to make it absolute if it's relative
|
52
|
+
import sam2
|
53
|
+
|
54
|
+
sam2_dir = os.path.dirname(sam2.__file__)
|
55
|
+
config_path = os.path.join(sam2_dir, "configs", config_path)
|
56
|
+
|
57
|
+
# Verify the config exists before passing to build_sam2
|
58
|
+
if not os.path.exists(config_path):
|
59
|
+
raise FileNotFoundError(f"Config file not found: {config_path}")
|
60
|
+
|
61
|
+
logger.info(f"SAM2: Resolved config path: {config_path}")
|
62
|
+
|
63
|
+
# Build SAM2 model
|
64
|
+
# SAM2 uses Hydra for configuration - we need to pass the right config name
|
65
|
+
# Try different approaches based on what's available
|
66
|
+
|
67
|
+
model_filename = Path(model_path).name.lower()
|
68
|
+
|
69
|
+
# For SAM2.1 models, use manual Hydra initialization since configs aren't in search path
|
70
|
+
if "2.1" in model_filename:
|
71
|
+
logger.info(
|
72
|
+
"SAM2: Loading SAM2.1 model with manual config initialization"
|
73
|
+
)
|
74
|
+
|
75
|
+
try:
|
76
|
+
# Import required Hydra components
|
77
|
+
# Get the configs directory
|
78
|
+
import sam2
|
79
|
+
from hydra import compose, initialize_config_dir
|
80
|
+
from hydra.core.global_hydra import GlobalHydra
|
81
|
+
|
82
|
+
sam2_configs_dir = os.path.join(
|
83
|
+
os.path.dirname(sam2.__file__), "configs", "sam2.1"
|
84
|
+
)
|
85
|
+
|
86
|
+
# Clear any existing Hydra instance
|
87
|
+
GlobalHydra.instance().clear()
|
88
|
+
|
89
|
+
# Initialize Hydra with the SAM2.1 configs directory
|
90
|
+
with initialize_config_dir(
|
91
|
+
config_dir=sam2_configs_dir, version_base=None
|
92
|
+
):
|
93
|
+
config_filename = Path(config_path).name
|
94
|
+
logger.info(f"SAM2: Loading SAM2.1 config: {config_filename}")
|
95
|
+
|
96
|
+
# Load the config
|
97
|
+
cfg = compose(config_name=config_filename.replace(".yaml", ""))
|
98
|
+
|
99
|
+
# Manually build the model using the config
|
100
|
+
from hydra.utils import instantiate
|
101
|
+
|
102
|
+
self.model = instantiate(cfg.model)
|
103
|
+
self.model.to(self.device)
|
104
|
+
|
105
|
+
# Load the checkpoint weights
|
106
|
+
if model_path:
|
107
|
+
checkpoint = torch.load(
|
108
|
+
model_path, map_location=self.device
|
109
|
+
)
|
110
|
+
# Handle nested checkpoint structure
|
111
|
+
if "model" in checkpoint:
|
112
|
+
model_weights = checkpoint["model"]
|
113
|
+
else:
|
114
|
+
model_weights = checkpoint
|
115
|
+
self.model.load_state_dict(model_weights, strict=False)
|
116
|
+
|
117
|
+
logger.info(
|
118
|
+
"SAM2: Successfully loaded SAM2.1 with manual initialization"
|
119
|
+
)
|
120
|
+
|
121
|
+
except Exception as e1:
|
122
|
+
logger.debug(f"SAM2: SAM2.1 manual initialization failed: {e1}")
|
123
|
+
# Fallback to using a compatible SAM2.0 config as a workaround
|
124
|
+
logger.warning(
|
125
|
+
"SAM2: Falling back to SAM2.0 config for SAM2.1 model (may have reduced functionality)"
|
126
|
+
)
|
127
|
+
try:
|
128
|
+
# Use the closest SAM2.0 config
|
129
|
+
fallback_config = (
|
130
|
+
"sam2_hiera_l.yaml" # This works according to our test
|
131
|
+
)
|
132
|
+
logger.info(
|
133
|
+
f"SAM2: Attempting fallback with SAM2.0 config: {fallback_config}"
|
134
|
+
)
|
135
|
+
self.model = build_sam2(
|
136
|
+
fallback_config, model_path, device=self.device
|
137
|
+
)
|
138
|
+
logger.warning(
|
139
|
+
"SAM2: Loaded SAM2.1 model with SAM2.0 config - some features may not work"
|
140
|
+
)
|
141
|
+
except Exception as e2:
|
142
|
+
raise Exception(
|
143
|
+
f"Failed to load SAM2.1 model. Manual initialization failed: {e1}. "
|
144
|
+
f"Fallback to SAM2.0 config also failed: {e2}. "
|
145
|
+
f"Try reinstalling SAM2 with latest version from official repo."
|
146
|
+
) from e2
|
147
|
+
else:
|
148
|
+
# Standard SAM2.0 loading approach
|
149
|
+
try:
|
150
|
+
logger.info(
|
151
|
+
f"SAM2: Attempting to load with config path: {config_path}"
|
152
|
+
)
|
153
|
+
self.model = build_sam2(config_path, model_path, device=self.device)
|
154
|
+
logger.info("SAM2: Successfully loaded with config path")
|
155
|
+
except Exception as e1:
|
156
|
+
logger.debug(f"SAM2: Config path approach failed: {e1}")
|
157
|
+
|
158
|
+
# Try just the config filename without path (for Hydra)
|
159
|
+
try:
|
160
|
+
config_filename = Path(config_path).name
|
161
|
+
logger.info(
|
162
|
+
f"SAM2: Attempting to load with config filename: {config_filename}"
|
163
|
+
)
|
164
|
+
self.model = build_sam2(
|
165
|
+
config_filename, model_path, device=self.device
|
166
|
+
)
|
167
|
+
logger.info("SAM2: Successfully loaded with config filename")
|
168
|
+
except Exception as e2:
|
169
|
+
logger.debug(f"SAM2: Config filename approach failed: {e2}")
|
170
|
+
|
171
|
+
# Try the base config name for SAM2.0 models
|
172
|
+
try:
|
173
|
+
# Map model sizes to base config names (SAM2.0 only)
|
174
|
+
if (
|
175
|
+
"tiny" in model_filename
|
176
|
+
or "_t." in model_filename
|
177
|
+
or "_t_" in model_filename
|
178
|
+
):
|
179
|
+
base_config = "sam2_hiera_t.yaml"
|
180
|
+
elif (
|
181
|
+
"small" in model_filename
|
182
|
+
or "_s." in model_filename
|
183
|
+
or "_s_" in model_filename
|
184
|
+
):
|
185
|
+
base_config = "sam2_hiera_s.yaml"
|
186
|
+
elif (
|
187
|
+
"base_plus" in model_filename
|
188
|
+
or "_b+." in model_filename
|
189
|
+
or "_b+_" in model_filename
|
190
|
+
):
|
191
|
+
base_config = "sam2_hiera_b+.yaml"
|
192
|
+
elif (
|
193
|
+
"large" in model_filename
|
194
|
+
or "_l." in model_filename
|
195
|
+
or "_l_" in model_filename
|
196
|
+
):
|
197
|
+
base_config = "sam2_hiera_l.yaml"
|
198
|
+
else:
|
199
|
+
base_config = "sam2_hiera_l.yaml"
|
200
|
+
|
201
|
+
logger.info(
|
202
|
+
f"SAM2: Attempting to load with base config: {base_config}"
|
203
|
+
)
|
204
|
+
self.model = build_sam2(
|
205
|
+
base_config, model_path, device=self.device
|
206
|
+
)
|
207
|
+
logger.info("SAM2: Successfully loaded with base config")
|
208
|
+
except Exception as e3:
|
209
|
+
# All approaches failed
|
210
|
+
raise Exception(
|
211
|
+
f"Failed to load SAM2 model with any config approach. "
|
212
|
+
f"Tried: {config_path}, {config_filename}, {base_config}. "
|
213
|
+
f"Last error: {e3}"
|
214
|
+
) from e3
|
215
|
+
|
216
|
+
# Create predictor
|
217
|
+
self.predictor = SAM2ImagePredictor(self.model)
|
218
|
+
|
219
|
+
self.is_loaded = True
|
220
|
+
logger.info("SAM2: Model loaded successfully.")
|
221
|
+
|
222
|
+
except Exception as e:
|
223
|
+
logger.error(f"SAM2: Failed to load model: {e}")
|
224
|
+
logger.warning("SAM2: SAM2 functionality will be disabled.")
|
225
|
+
self.is_loaded = False
|
226
|
+
|
227
|
+
def _auto_detect_config(self, model_path: str) -> str:
|
228
|
+
"""Auto-detect the appropriate config file based on model filename."""
|
229
|
+
model_path = Path(model_path)
|
230
|
+
filename = model_path.name.lower()
|
231
|
+
|
232
|
+
# Get the sam2 package directory
|
233
|
+
try:
|
234
|
+
import sam2
|
235
|
+
|
236
|
+
sam2_dir = Path(sam2.__file__).parent
|
237
|
+
configs_dir = sam2_dir / "configs"
|
238
|
+
|
239
|
+
# Determine if this is a SAM2.1 model
|
240
|
+
is_sam21 = "2.1" in filename
|
241
|
+
|
242
|
+
# Map model types to config files based on version
|
243
|
+
if "tiny" in filename or "_t" in filename:
|
244
|
+
config_file = "sam2.1_hiera_t.yaml" if is_sam21 else "sam2_hiera_t.yaml"
|
245
|
+
elif "small" in filename or "_s" in filename:
|
246
|
+
config_file = "sam2.1_hiera_s.yaml" if is_sam21 else "sam2_hiera_s.yaml"
|
247
|
+
elif "base_plus" in filename or "_b+" in filename:
|
248
|
+
config_file = (
|
249
|
+
"sam2.1_hiera_b+.yaml" if is_sam21 else "sam2_hiera_b+.yaml"
|
250
|
+
)
|
251
|
+
elif "large" in filename or "_l" in filename:
|
252
|
+
config_file = "sam2.1_hiera_l.yaml" if is_sam21 else "sam2_hiera_l.yaml"
|
253
|
+
else:
|
254
|
+
# Default to large model with appropriate version
|
255
|
+
config_file = "sam2.1_hiera_l.yaml" if is_sam21 else "sam2_hiera_l.yaml"
|
256
|
+
|
257
|
+
# Build config path based on version
|
258
|
+
if is_sam21:
|
259
|
+
config_path = configs_dir / "sam2.1" / config_file
|
260
|
+
else:
|
261
|
+
config_path = configs_dir / "sam2" / config_file
|
262
|
+
|
263
|
+
logger.debug(f"SAM2: Checking config path: {config_path}")
|
264
|
+
if config_path.exists():
|
265
|
+
return str(config_path.absolute())
|
266
|
+
|
267
|
+
# Fallback to default large config of the same version
|
268
|
+
fallback_config_file = (
|
269
|
+
"sam2.1_hiera_l.yaml" if is_sam21 else "sam2_hiera_l.yaml"
|
270
|
+
)
|
271
|
+
fallback_subdir = "sam2.1" if is_sam21 else "sam2"
|
272
|
+
fallback_config = configs_dir / fallback_subdir / fallback_config_file
|
273
|
+
logger.debug(f"SAM2: Checking fallback config: {fallback_config}")
|
274
|
+
if fallback_config.exists():
|
275
|
+
return str(fallback_config.absolute())
|
276
|
+
|
277
|
+
# Try without version subdirectory (only for SAM2.0)
|
278
|
+
if not is_sam21:
|
279
|
+
direct_config = configs_dir / config_file
|
280
|
+
logger.debug(f"SAM2: Checking direct config: {direct_config}")
|
281
|
+
if direct_config.exists():
|
282
|
+
return str(direct_config.absolute())
|
283
|
+
|
284
|
+
raise FileNotFoundError(
|
285
|
+
f"No suitable {'SAM2.1' if is_sam21 else 'SAM2'} config found for {filename} in {configs_dir}"
|
286
|
+
)
|
287
|
+
|
288
|
+
except Exception as e:
|
289
|
+
logger.error(f"SAM2: Failed to auto-detect config: {e}")
|
290
|
+
# Try to construct a full path even if auto-detection failed
|
291
|
+
try:
|
292
|
+
import sam2
|
293
|
+
|
294
|
+
sam2_dir = Path(sam2.__file__).parent
|
295
|
+
filename = Path(model_path).name.lower()
|
296
|
+
is_sam21 = "2.1" in filename
|
297
|
+
|
298
|
+
# Return full path to appropriate default config
|
299
|
+
if is_sam21:
|
300
|
+
return str(sam2_dir / "configs" / "sam2.1" / "sam2.1_hiera_l.yaml")
|
301
|
+
else:
|
302
|
+
return str(sam2_dir / "configs" / "sam2" / "sam2_hiera_l.yaml")
|
303
|
+
except Exception:
|
304
|
+
# Last resort - return just the config name and let hydra handle it
|
305
|
+
filename = Path(model_path).name.lower()
|
306
|
+
is_sam21 = "2.1" in filename
|
307
|
+
return "sam2.1_hiera_l.yaml" if is_sam21 else "sam2_hiera_l.yaml"
|
308
|
+
|
309
|
+
def set_image_from_path(self, image_path: str) -> bool:
|
310
|
+
"""Set image for SAM2 model from file path."""
|
311
|
+
if not self.is_loaded:
|
312
|
+
return False
|
313
|
+
try:
|
314
|
+
self.image = cv2.imread(image_path)
|
315
|
+
self.image = cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB)
|
316
|
+
self.predictor.set_image(self.image)
|
317
|
+
return True
|
318
|
+
except Exception as e:
|
319
|
+
logger.error(f"SAM2: Error setting image from path: {e}")
|
320
|
+
return False
|
321
|
+
|
322
|
+
def set_image_from_array(self, image_array: np.ndarray) -> bool:
|
323
|
+
"""Set image for SAM2 model from numpy array."""
|
324
|
+
if not self.is_loaded:
|
325
|
+
return False
|
326
|
+
try:
|
327
|
+
self.image = image_array
|
328
|
+
self.predictor.set_image(self.image)
|
329
|
+
return True
|
330
|
+
except Exception as e:
|
331
|
+
logger.error(f"SAM2: Error setting image from array: {e}")
|
332
|
+
return False
|
333
|
+
|
334
|
+
def predict(self, positive_points, negative_points):
|
335
|
+
"""Generate predictions using SAM2."""
|
336
|
+
if not self.is_loaded or not positive_points:
|
337
|
+
return None
|
338
|
+
|
339
|
+
try:
|
340
|
+
points = np.array(positive_points + negative_points)
|
341
|
+
labels = np.array([1] * len(positive_points) + [0] * len(negative_points))
|
342
|
+
|
343
|
+
masks, scores, logits = self.predictor.predict(
|
344
|
+
point_coords=points,
|
345
|
+
point_labels=labels,
|
346
|
+
multimask_output=True,
|
347
|
+
)
|
348
|
+
|
349
|
+
# Return the mask with the highest score
|
350
|
+
best_mask_idx = np.argmax(scores)
|
351
|
+
return masks[best_mask_idx], scores[best_mask_idx], logits[best_mask_idx]
|
352
|
+
|
353
|
+
except Exception as e:
|
354
|
+
logger.error(f"SAM2: Error during prediction: {e}")
|
355
|
+
return None
|
356
|
+
|
357
|
+
def predict_from_box(self, box):
|
358
|
+
"""Generate predictions from bounding box using SAM2."""
|
359
|
+
if not self.is_loaded:
|
360
|
+
return None
|
361
|
+
|
362
|
+
try:
|
363
|
+
masks, scores, logits = self.predictor.predict(
|
364
|
+
box=np.array(box),
|
365
|
+
multimask_output=True,
|
366
|
+
)
|
367
|
+
|
368
|
+
# Return the mask with the highest score
|
369
|
+
best_mask_idx = np.argmax(scores)
|
370
|
+
return masks[best_mask_idx], scores[best_mask_idx], logits[best_mask_idx]
|
371
|
+
|
372
|
+
except Exception as e:
|
373
|
+
logger.error(f"SAM2: Error during box prediction: {e}")
|
374
|
+
return None
|
375
|
+
|
376
|
+
def load_custom_model(
|
377
|
+
self, model_path: str, config_path: str | None = None
|
378
|
+
) -> bool:
|
379
|
+
"""Load a custom SAM2 model from the specified path."""
|
380
|
+
if not os.path.exists(model_path):
|
381
|
+
logger.warning(f"SAM2: Model file not found: {model_path}")
|
382
|
+
return False
|
383
|
+
|
384
|
+
logger.info(f"SAM2: Loading custom model from {model_path}...")
|
385
|
+
try:
|
386
|
+
# Clear existing model from memory
|
387
|
+
if hasattr(self, "model") and self.model is not None:
|
388
|
+
del self.model
|
389
|
+
del self.predictor
|
390
|
+
torch.cuda.empty_cache() if torch.cuda.is_available() else None
|
391
|
+
|
392
|
+
# Auto-detect config if not provided
|
393
|
+
if config_path is None:
|
394
|
+
config_path = self._auto_detect_config(model_path)
|
395
|
+
|
396
|
+
# Load new model with same logic as __init__
|
397
|
+
model_filename = Path(model_path).name.lower()
|
398
|
+
|
399
|
+
# Use same loading logic as __init__
|
400
|
+
if "2.1" in model_filename:
|
401
|
+
# SAM2.1 models need manual Hydra initialization
|
402
|
+
logger.info(
|
403
|
+
"SAM2: Loading custom SAM2.1 model with manual config initialization"
|
404
|
+
)
|
405
|
+
|
406
|
+
try:
|
407
|
+
import sam2
|
408
|
+
from hydra import compose, initialize_config_dir
|
409
|
+
from hydra.core.global_hydra import GlobalHydra
|
410
|
+
|
411
|
+
sam2_configs_dir = os.path.join(
|
412
|
+
os.path.dirname(sam2.__file__), "configs", "sam2.1"
|
413
|
+
)
|
414
|
+
GlobalHydra.instance().clear()
|
415
|
+
|
416
|
+
with initialize_config_dir(
|
417
|
+
config_dir=sam2_configs_dir, version_base=None
|
418
|
+
):
|
419
|
+
config_filename = Path(config_path).name
|
420
|
+
cfg = compose(config_name=config_filename.replace(".yaml", ""))
|
421
|
+
|
422
|
+
from hydra.utils import instantiate
|
423
|
+
|
424
|
+
self.model = instantiate(cfg.model)
|
425
|
+
self.model.to(self.device)
|
426
|
+
|
427
|
+
if model_path:
|
428
|
+
checkpoint = torch.load(
|
429
|
+
model_path, map_location=self.device
|
430
|
+
)
|
431
|
+
model_weights = checkpoint.get("model", checkpoint)
|
432
|
+
self.model.load_state_dict(model_weights, strict=False)
|
433
|
+
|
434
|
+
logger.info(
|
435
|
+
"SAM2: Successfully loaded custom SAM2.1 with manual initialization"
|
436
|
+
)
|
437
|
+
|
438
|
+
except Exception as e1:
|
439
|
+
# Fallback to SAM2.0 config
|
440
|
+
logger.warning(
|
441
|
+
"SAM2: Falling back to SAM2.0 config for custom SAM2.1 model"
|
442
|
+
)
|
443
|
+
try:
|
444
|
+
fallback_config = "sam2_hiera_l.yaml"
|
445
|
+
self.model = build_sam2(
|
446
|
+
fallback_config, model_path, device=self.device
|
447
|
+
)
|
448
|
+
logger.warning(
|
449
|
+
"SAM2: Loaded custom SAM2.1 model with SAM2.0 config"
|
450
|
+
)
|
451
|
+
except Exception as e2:
|
452
|
+
raise Exception(
|
453
|
+
f"Failed to load custom SAM2.1 model. Manual init failed: {e1}, fallback failed: {e2}"
|
454
|
+
) from e2
|
455
|
+
else:
|
456
|
+
# Standard SAM2.0 loading
|
457
|
+
try:
|
458
|
+
logger.info(
|
459
|
+
f"SAM2: Attempting to load custom model with config path: {config_path}"
|
460
|
+
)
|
461
|
+
self.model = build_sam2(config_path, model_path, device=self.device)
|
462
|
+
except Exception:
|
463
|
+
try:
|
464
|
+
config_filename = Path(config_path).name
|
465
|
+
logger.info(
|
466
|
+
f"SAM2: Attempting to load custom model with config filename: {config_filename}"
|
467
|
+
)
|
468
|
+
self.model = build_sam2(
|
469
|
+
config_filename, model_path, device=self.device
|
470
|
+
)
|
471
|
+
except Exception as e2:
|
472
|
+
raise Exception(
|
473
|
+
f"Failed to load custom model. Last error: {e2}"
|
474
|
+
) from e2
|
475
|
+
self.predictor = SAM2ImagePredictor(self.model)
|
476
|
+
self.current_model_path = model_path
|
477
|
+
self.is_loaded = True
|
478
|
+
|
479
|
+
# Re-set image if one was previously loaded
|
480
|
+
if self.image is not None:
|
481
|
+
self.predictor.set_image(self.image)
|
482
|
+
|
483
|
+
logger.info("SAM2: Custom model loaded successfully.")
|
484
|
+
return True
|
485
|
+
except Exception as e:
|
486
|
+
logger.error(f"SAM2: Error loading custom model: {e}")
|
487
|
+
self.is_loaded = False
|
488
|
+
self.model = None
|
489
|
+
self.predictor = None
|
490
|
+
return False
|