lazylabel-gui 1.3.4__py3-none-any.whl → 1.3.5__py3-none-any.whl

This diff shows the differences between the contents of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
@@ -66,71 +66,152 @@ class Sam2Model:
66
66
 
67
67
  model_filename = Path(model_path).name.lower()
68
68
 
69
- # First, try using the auto-detected config path directly
70
- try:
71
- logger.info(f"SAM2: Attempting to load with config path: {config_path}")
72
- self.model = self._build_sam2_with_fallback(config_path, model_path)
73
- logger.info("SAM2: Successfully loaded with config path")
74
- except Exception as e1:
75
- logger.debug(f"SAM2: Config path approach failed: {e1}")
69
+ # For SAM2.1 models, use manual Hydra initialization since configs aren't in search path
70
+ if "2.1" in model_filename:
71
+ logger.info(
72
+ "SAM2: Loading SAM2.1 model with manual config initialization"
73
+ )
76
74
 
77
- # Second, try just the config filename without path
78
75
  try:
79
- config_filename = Path(config_path).name
80
- logger.info(
81
- f"SAM2: Attempting to load with config filename: {config_filename}"
76
+ # Import required Hydra components
77
+ # Get the configs directory
78
+ import sam2
79
+ from hydra import compose, initialize_config_dir
80
+ from hydra.core.global_hydra import GlobalHydra
81
+
82
+ sam2_configs_dir = os.path.join(
83
+ os.path.dirname(sam2.__file__), "configs", "sam2.1"
82
84
  )
83
- self.model = self._build_sam2_with_fallback(
84
- config_filename, model_path
85
- )
86
- logger.info("SAM2: Successfully loaded with config filename")
87
- except Exception as e2:
88
- logger.debug(f"SAM2: Config filename approach failed: {e2}")
89
85
 
90
- # Third, try the base config name without version
91
- try:
92
- # Map model sizes to base config names
93
- if (
94
- "tiny" in model_filename
95
- or "_t." in model_filename
96
- or "_t_" in model_filename
97
- ):
98
- base_config = "sam2_hiera_t.yaml"
99
- elif (
100
- "small" in model_filename
101
- or "_s." in model_filename
102
- or "_s_" in model_filename
103
- ):
104
- base_config = "sam2_hiera_s.yaml"
105
- elif (
106
- "base_plus" in model_filename
107
- or "_b+." in model_filename
108
- or "_b+_" in model_filename
109
- ):
110
- base_config = "sam2_hiera_b+.yaml"
111
- elif (
112
- "large" in model_filename
113
- or "_l." in model_filename
114
- or "_l_" in model_filename
115
- ):
116
- base_config = "sam2_hiera_l.yaml"
117
- else:
118
- base_config = "sam2_hiera_l.yaml"
86
+ # Clear any existing Hydra instance
87
+ GlobalHydra.instance().clear()
88
+
89
+ # Initialize Hydra with the SAM2.1 configs directory
90
+ with initialize_config_dir(
91
+ config_dir=sam2_configs_dir, version_base=None
92
+ ):
93
+ config_filename = Path(config_path).name
94
+ logger.info(f"SAM2: Loading SAM2.1 config: {config_filename}")
95
+
96
+ # Load the config
97
+ cfg = compose(config_name=config_filename.replace(".yaml", ""))
98
+
99
+ # Manually build the model using the config
100
+ from hydra.utils import instantiate
101
+
102
+ self.model = instantiate(cfg.model)
103
+ self.model.to(self.device)
104
+
105
+ # Load the checkpoint weights
106
+ if model_path:
107
+ checkpoint = torch.load(
108
+ model_path, map_location=self.device
109
+ )
110
+ # Handle nested checkpoint structure
111
+ if "model" in checkpoint:
112
+ model_weights = checkpoint["model"]
113
+ else:
114
+ model_weights = checkpoint
115
+ self.model.load_state_dict(model_weights, strict=False)
119
116
 
120
117
  logger.info(
121
- f"SAM2: Attempting to load with base config: {base_config}"
118
+ "SAM2: Successfully loaded SAM2.1 with manual initialization"
122
119
  )
123
- self.model = self._build_sam2_with_fallback(
124
- base_config, model_path
120
+
121
+ except Exception as e1:
122
+ logger.debug(f"SAM2: SAM2.1 manual initialization failed: {e1}")
123
+ # Fallback to using a compatible SAM2.0 config as a workaround
124
+ logger.warning(
125
+ "SAM2: Falling back to SAM2.0 config for SAM2.1 model (may have reduced functionality)"
126
+ )
127
+ try:
128
+ # Use the closest SAM2.0 config
129
+ fallback_config = (
130
+ "sam2_hiera_l.yaml" # This works according to our test
131
+ )
132
+ logger.info(
133
+ f"SAM2: Attempting fallback with SAM2.0 config: {fallback_config}"
125
134
  )
126
- logger.info("SAM2: Successfully loaded with base config")
127
- except Exception as e3:
128
- # All approaches failed
135
+ self.model = build_sam2(
136
+ fallback_config, model_path, device=self.device
137
+ )
138
+ logger.warning(
139
+ "SAM2: Loaded SAM2.1 model with SAM2.0 config - some features may not work"
140
+ )
141
+ except Exception as e2:
129
142
  raise Exception(
130
- f"Failed to load SAM2 model with any config approach. "
131
- f"Tried: {config_path}, {config_filename}, {base_config}. "
132
- f"Last error: {e3}"
133
- ) from e3
143
+ f"Failed to load SAM2.1 model. Manual initialization failed: {e1}. "
144
+ f"Fallback to SAM2.0 config also failed: {e2}. "
145
+ f"Try reinstalling SAM2 with latest version from official repo."
146
+ ) from e2
147
+ else:
148
+ # Standard SAM2.0 loading approach
149
+ try:
150
+ logger.info(
151
+ f"SAM2: Attempting to load with config path: {config_path}"
152
+ )
153
+ self.model = build_sam2(config_path, model_path, device=self.device)
154
+ logger.info("SAM2: Successfully loaded with config path")
155
+ except Exception as e1:
156
+ logger.debug(f"SAM2: Config path approach failed: {e1}")
157
+
158
+ # Try just the config filename without path (for Hydra)
159
+ try:
160
+ config_filename = Path(config_path).name
161
+ logger.info(
162
+ f"SAM2: Attempting to load with config filename: {config_filename}"
163
+ )
164
+ self.model = build_sam2(
165
+ config_filename, model_path, device=self.device
166
+ )
167
+ logger.info("SAM2: Successfully loaded with config filename")
168
+ except Exception as e2:
169
+ logger.debug(f"SAM2: Config filename approach failed: {e2}")
170
+
171
+ # Try the base config name for SAM2.0 models
172
+ try:
173
+ # Map model sizes to base config names (SAM2.0 only)
174
+ if (
175
+ "tiny" in model_filename
176
+ or "_t." in model_filename
177
+ or "_t_" in model_filename
178
+ ):
179
+ base_config = "sam2_hiera_t.yaml"
180
+ elif (
181
+ "small" in model_filename
182
+ or "_s." in model_filename
183
+ or "_s_" in model_filename
184
+ ):
185
+ base_config = "sam2_hiera_s.yaml"
186
+ elif (
187
+ "base_plus" in model_filename
188
+ or "_b+." in model_filename
189
+ or "_b+_" in model_filename
190
+ ):
191
+ base_config = "sam2_hiera_b+.yaml"
192
+ elif (
193
+ "large" in model_filename
194
+ or "_l." in model_filename
195
+ or "_l_" in model_filename
196
+ ):
197
+ base_config = "sam2_hiera_l.yaml"
198
+ else:
199
+ base_config = "sam2_hiera_l.yaml"
200
+
201
+ logger.info(
202
+ f"SAM2: Attempting to load with base config: {base_config}"
203
+ )
204
+ self.model = build_sam2(
205
+ base_config, model_path, device=self.device
206
+ )
207
+ logger.info("SAM2: Successfully loaded with base config")
208
+ except Exception as e3:
209
+ # All approaches failed
210
+ raise Exception(
211
+ f"Failed to load SAM2 model with any config approach. "
212
+ f"Tried: {config_path}, {config_filename}, {base_config}. "
213
+ f"Last error: {e3}"
214
+ ) from e3
134
215
 
135
216
  # Create predictor
136
217
  self.predictor = SAM2ImagePredictor(self.model)
@@ -155,53 +236,53 @@ class Sam2Model:
155
236
  sam2_dir = Path(sam2.__file__).parent
156
237
  configs_dir = sam2_dir / "configs"
157
238
 
158
- # Map model types to config files
239
+ # Determine if this is a SAM2.1 model
240
+ is_sam21 = "2.1" in filename
241
+
242
+ # Map model types to config files based on version
159
243
  if "tiny" in filename or "_t" in filename:
160
- config_file = (
161
- "sam2.1_hiera_t.yaml" if "2.1" in filename else "sam2_hiera_t.yaml"
162
- )
244
+ config_file = "sam2.1_hiera_t.yaml" if is_sam21 else "sam2_hiera_t.yaml"
163
245
  elif "small" in filename or "_s" in filename:
164
- config_file = (
165
- "sam2.1_hiera_s.yaml" if "2.1" in filename else "sam2_hiera_s.yaml"
166
- )
246
+ config_file = "sam2.1_hiera_s.yaml" if is_sam21 else "sam2_hiera_s.yaml"
167
247
  elif "base_plus" in filename or "_b+" in filename:
168
248
  config_file = (
169
- "sam2.1_hiera_b+.yaml"
170
- if "2.1" in filename
171
- else "sam2_hiera_b+.yaml"
249
+ "sam2.1_hiera_b+.yaml" if is_sam21 else "sam2_hiera_b+.yaml"
172
250
  )
173
251
  elif "large" in filename or "_l" in filename:
174
- config_file = (
175
- "sam2.1_hiera_l.yaml" if "2.1" in filename else "sam2_hiera_l.yaml"
176
- )
252
+ config_file = "sam2.1_hiera_l.yaml" if is_sam21 else "sam2_hiera_l.yaml"
177
253
  else:
178
- # Default to large model
179
- config_file = "sam2.1_hiera_l.yaml"
254
+ # Default to large model with appropriate version
255
+ config_file = "sam2.1_hiera_l.yaml" if is_sam21 else "sam2_hiera_l.yaml"
180
256
 
181
- # Check sam2.1 configs first, then fall back to sam2
182
- if "2.1" in filename:
257
+ # Build config path based on version
258
+ if is_sam21:
183
259
  config_path = configs_dir / "sam2.1" / config_file
184
260
  else:
185
- config_path = configs_dir / "sam2" / config_file.replace("2.1_", "")
261
+ config_path = configs_dir / "sam2" / config_file
186
262
 
187
263
  logger.debug(f"SAM2: Checking config path: {config_path}")
188
264
  if config_path.exists():
189
265
  return str(config_path.absolute())
190
266
 
191
- # Fallback to default large config
192
- fallback_config = configs_dir / "sam2.1" / "sam2.1_hiera_l.yaml"
267
+ # Fallback to default large config of the same version
268
+ fallback_config_file = (
269
+ "sam2.1_hiera_l.yaml" if is_sam21 else "sam2_hiera_l.yaml"
270
+ )
271
+ fallback_subdir = "sam2.1" if is_sam21 else "sam2"
272
+ fallback_config = configs_dir / fallback_subdir / fallback_config_file
193
273
  logger.debug(f"SAM2: Checking fallback config: {fallback_config}")
194
274
  if fallback_config.exists():
195
275
  return str(fallback_config.absolute())
196
276
 
197
- # Try without version subdirectory
198
- direct_config = configs_dir / config_file
199
- logger.debug(f"SAM2: Checking direct config: {direct_config}")
200
- if direct_config.exists():
201
- return str(direct_config.absolute())
277
+ # Try without version subdirectory (only for SAM2.0)
278
+ if not is_sam21:
279
+ direct_config = configs_dir / config_file
280
+ logger.debug(f"SAM2: Checking direct config: {direct_config}")
281
+ if direct_config.exists():
282
+ return str(direct_config.absolute())
202
283
 
203
284
  raise FileNotFoundError(
204
- f"No suitable config found for {filename} in {configs_dir}"
285
+ f"No suitable {'SAM2.1' if is_sam21 else 'SAM2'} config found for {filename} in {configs_dir}"
205
286
  )
206
287
 
207
288
  except Exception as e:
@@ -211,58 +292,19 @@ class Sam2Model:
211
292
  import sam2
212
293
 
213
294
  sam2_dir = Path(sam2.__file__).parent
214
- # Return full path to default config
215
- return str(sam2_dir / "configs" / "sam2.1" / "sam2.1_hiera_l.yaml")
216
- except Exception:
217
- # Last resort - return just the config name and let hydra handle it
218
- return "sam2.1_hiera_l.yaml"
219
-
220
- def _build_sam2_with_fallback(self, config_path, model_path):
221
- """Build SAM2 model with fallback for state_dict compatibility issues."""
222
- try:
223
- # First, try the standard build_sam2 approach
224
- return build_sam2(config_path, model_path, device=self.device)
225
- except RuntimeError as e:
226
- if "Unexpected key(s) in state_dict" in str(e):
227
- logger.warning(f"SAM2: Detected state_dict compatibility issue: {e}")
228
- logger.info("SAM2: Attempting to load with state_dict filtering...")
295
+ filename = Path(model_path).name.lower()
296
+ is_sam21 = "2.1" in filename
229
297
 
230
- # Build model without loading weights first
231
- model = build_sam2(config_path, None, device=self.device)
232
-
233
- # Load checkpoint and handle nested structure
234
- checkpoint = torch.load(model_path, map_location=self.device)
235
-
236
- # Check if checkpoint has nested 'model' key (common in SAM2.1)
237
- if "model" in checkpoint and isinstance(checkpoint["model"], dict):
238
- logger.info(
239
- "SAM2: Detected nested checkpoint structure, extracting model weights"
240
- )
241
- model_weights = checkpoint["model"]
298
+ # Return full path to appropriate default config
299
+ if is_sam21:
300
+ return str(sam2_dir / "configs" / "sam2.1" / "sam2.1_hiera_l.yaml")
242
301
  else:
243
- # Flat structure - filter out the known problematic keys
244
- model_weights = {}
245
- problematic_keys = {
246
- "no_obj_embed_spatial",
247
- "obj_ptr_tpos_proj.weight",
248
- "obj_ptr_tpos_proj.bias",
249
- }
250
- for key, value in checkpoint.items():
251
- if key not in problematic_keys:
252
- model_weights[key] = value
253
-
254
- logger.info(
255
- f"SAM2: Filtered out problematic keys: {list(problematic_keys & set(checkpoint.keys()))}"
256
- )
257
-
258
- # Load the model weights
259
- model.load_state_dict(model_weights, strict=False)
260
- logger.info("SAM2: Successfully loaded model with state_dict filtering")
261
-
262
- return model
263
- else:
264
- # Re-raise if it's a different type of error
265
- raise
302
+ return str(sam2_dir / "configs" / "sam2" / "sam2_hiera_l.yaml")
303
+ except Exception:
304
+ # Last resort - return just the config name and let hydra handle it
305
+ filename = Path(model_path).name.lower()
306
+ is_sam21 = "2.1" in filename
307
+ return "sam2.1_hiera_l.yaml" if is_sam21 else "sam2_hiera_l.yaml"
266
308
 
267
309
  def set_image_from_path(self, image_path: str) -> bool:
268
310
  """Set image for SAM2 model from file path."""
@@ -351,8 +393,85 @@ class Sam2Model:
351
393
  if config_path is None:
352
394
  config_path = self._auto_detect_config(model_path)
353
395
 
354
- # Load new model
355
- self.model = self._build_sam2_with_fallback(config_path, model_path)
396
+ # Load new model with same logic as __init__
397
+ model_filename = Path(model_path).name.lower()
398
+
399
+ # Use same loading logic as __init__
400
+ if "2.1" in model_filename:
401
+ # SAM2.1 models need manual Hydra initialization
402
+ logger.info(
403
+ "SAM2: Loading custom SAM2.1 model with manual config initialization"
404
+ )
405
+
406
+ try:
407
+ import sam2
408
+ from hydra import compose, initialize_config_dir
409
+ from hydra.core.global_hydra import GlobalHydra
410
+
411
+ sam2_configs_dir = os.path.join(
412
+ os.path.dirname(sam2.__file__), "configs", "sam2.1"
413
+ )
414
+ GlobalHydra.instance().clear()
415
+
416
+ with initialize_config_dir(
417
+ config_dir=sam2_configs_dir, version_base=None
418
+ ):
419
+ config_filename = Path(config_path).name
420
+ cfg = compose(config_name=config_filename.replace(".yaml", ""))
421
+
422
+ from hydra.utils import instantiate
423
+
424
+ self.model = instantiate(cfg.model)
425
+ self.model.to(self.device)
426
+
427
+ if model_path:
428
+ checkpoint = torch.load(
429
+ model_path, map_location=self.device
430
+ )
431
+ model_weights = checkpoint.get("model", checkpoint)
432
+ self.model.load_state_dict(model_weights, strict=False)
433
+
434
+ logger.info(
435
+ "SAM2: Successfully loaded custom SAM2.1 with manual initialization"
436
+ )
437
+
438
+ except Exception as e1:
439
+ # Fallback to SAM2.0 config
440
+ logger.warning(
441
+ "SAM2: Falling back to SAM2.0 config for custom SAM2.1 model"
442
+ )
443
+ try:
444
+ fallback_config = "sam2_hiera_l.yaml"
445
+ self.model = build_sam2(
446
+ fallback_config, model_path, device=self.device
447
+ )
448
+ logger.warning(
449
+ "SAM2: Loaded custom SAM2.1 model with SAM2.0 config"
450
+ )
451
+ except Exception as e2:
452
+ raise Exception(
453
+ f"Failed to load custom SAM2.1 model. Manual init failed: {e1}, fallback failed: {e2}"
454
+ ) from e2
455
+ else:
456
+ # Standard SAM2.0 loading
457
+ try:
458
+ logger.info(
459
+ f"SAM2: Attempting to load custom model with config path: {config_path}"
460
+ )
461
+ self.model = build_sam2(config_path, model_path, device=self.device)
462
+ except Exception:
463
+ try:
464
+ config_filename = Path(config_path).name
465
+ logger.info(
466
+ f"SAM2: Attempting to load custom model with config filename: {config_filename}"
467
+ )
468
+ self.model = build_sam2(
469
+ config_filename, model_path, device=self.device
470
+ )
471
+ except Exception as e2:
472
+ raise Exception(
473
+ f"Failed to load custom model. Last error: {e2}"
474
+ ) from e2
356
475
  self.predictor = SAM2ImagePredictor(self.model)
357
476
  self.current_model_path = model_path
358
477
  self.is_loaded = True