lazylabel-gui 1.1.4__py3-none-any.whl → 1.1.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,200 +1,211 @@
1
- import os
2
-
3
- import cv2
4
- import numpy as np
5
- import requests
6
- import torch
7
- from segment_anything import SamPredictor, sam_model_registry
8
- from tqdm import tqdm
9
-
10
-
11
def download_model(url, download_path):
    """Download ``url`` to ``download_path``, showing a tqdm progress bar.

    The response is streamed into a sibling ``<download_path>.part`` file.
    That file is renamed to ``download_path`` only after the whole payload has
    been written and, when the server sent a ``Content-Length``, its size has
    been checked. On any failure the partial file is removed. An interrupted
    download therefore never leaves a truncated file at ``download_path`` that
    a later start-up would mistake for a complete checkpoint.

    Args:
        url: HTTP(S) URL of the file to fetch.
        download_path: Final destination path of the downloaded file.

    Raises:
        RuntimeError: On network, timeout, HTTP, permission, disk, or
            size-mismatch errors; the original exception is chained as
            ``__cause__``.
    """
    import contextlib

    print("[10/20] SAM model not found. Downloading from Meta's repository...")
    print(f" Downloading to: {download_path}")
    partial_path = f"{download_path}.part"
    completed = False
    try:
        print("[10/20] Connecting to download server...")
        # Using the response as a context manager releases the streamed
        # connection even if writing to disk fails part-way.
        with requests.get(url, stream=True, timeout=30) as response:
            response.raise_for_status()
            total_size_in_bytes = int(response.headers.get("content-length", 0))
            block_size = 1024 * 1024  # 1 MiB: far fewer iterations for GB-sized files

            print(
                f"[10/20] Starting download ({total_size_in_bytes / (1024 * 1024 * 1024):.1f} GB)..."
            )
            with tqdm(
                total=total_size_in_bytes, unit="iB", unit_scale=True
            ) as progress_bar, open(partial_path, "wb") as file:
                for data in response.iter_content(block_size):
                    progress_bar.update(len(data))
                    file.write(data)
                downloaded_bytes = progress_bar.n

        # A Content-Length of 0 means the server did not report a size, so there
        # is nothing to compare against.
        if total_size_in_bytes != 0 and downloaded_bytes != total_size_in_bytes:
            raise RuntimeError("Download incomplete - file size mismatch")

        # Only a fully written file ever appears under the final name.
        os.replace(partial_path, download_path)
        completed = True
        print("[10/20] Model download completed successfully.")

    except requests.exceptions.ConnectionError as e:
        raise RuntimeError(
            "[10/20] Network connection failed: Check your internet connection"
        ) from e
    except requests.exceptions.Timeout as e:
        raise RuntimeError(
            "[10/20] Download timeout: Server took too long to respond"
        ) from e
    except requests.exceptions.HTTPError as e:
        raise RuntimeError(
            f"[10/20] HTTP error {e.response.status_code}: Server rejected request"
        ) from e
    except requests.exceptions.RequestException as e:
        raise RuntimeError(f"[10/20] Network error during download: {e}") from e
    except PermissionError as e:
        raise RuntimeError(
            f"[10/20] Permission denied: Cannot write to {download_path}"
        ) from e
    except OSError as e:
        raise RuntimeError(
            f"[10/20] Disk error: {e} (check available disk space)"
        ) from e
    except Exception as e:
        raise RuntimeError(f"[10/20] Download failed: {e}") from e
    finally:
        if not completed:
            # Clean up the partial download whichever branch above failed.
            with contextlib.suppress(OSError):
                os.remove(partial_path)
67
-
68
-
69
class SamModel:
    """Wrap a Segment Anything (SAM) model and predictor for point-prompted masks.

    Problems while locating, downloading, or loading a checkpoint are printed
    rather than raised. Callers should check ``is_loaded`` before relying on
    the model.

    Attributes:
        device: ``torch.device`` used for inference (CUDA when available).
        current_model_type: ``sam_model_registry`` key of the active model.
        current_model_path: Path of the active custom checkpoint, or ``None``
            when the default checkpoint is in use.
        model: Loaded SAM model, or ``None``.
        predictor: ``SamPredictor`` wrapping ``model``, or ``None``.
        image: RGB image last accepted by the predictor, or ``None``.
        is_loaded: ``True`` when ``model`` and ``predictor`` are ready.
    """

    def __init__(
        self,
        model_type="vit_h",
        model_filename="sam_vit_h_4b8939.pth",
        custom_model_path=None,
    ):
        """Pick a device and load either the custom or the default checkpoint.

        Args:
            model_type: Key into ``sam_model_registry`` (e.g. ``"vit_h"``).
            model_filename: File name of the default checkpoint. It is used for
                both the local copy next to this module and the download URL.
            custom_model_path: Optional checkpoint to load instead of the
                default. If the path does not exist, a message is printed and
                the default checkpoint is used.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"[9/20] Detected device: {str(self.device).upper()}")

        self.current_model_type = model_type
        self.current_model_path = custom_model_path
        self.model = None
        self.predictor = None
        self.image = None
        self.is_loaded = False

        try:
            if custom_model_path and os.path.exists(custom_model_path):
                model_path = custom_model_path
                print(f"[10/20] Loading custom SAM model from {model_path}...")
            else:
                if custom_model_path:
                    # The requested checkpoint is missing. Say so, and do not
                    # report a path that was never actually loaded.
                    print(
                        f"[10/20] Custom SAM model not found at {custom_model_path}; "
                        "falling back to the default model."
                    )
                    self.current_model_path = None
                model_path = self._resolve_default_model_path(model_filename)
                print(f"[10/20] Loading default SAM model from {model_path}...")

            print(f"[11/20] Initializing {model_type.upper()} model architecture...")
            self.model = sam_model_registry[model_type](checkpoint=model_path).to(
                self.device
            )

            print("[12/20] Setting up predictor...")
            self.predictor = SamPredictor(self.model)
            self.is_loaded = True
            print("[13/20] SAM model loaded successfully.")

        except Exception as e:
            print(f"[8/20] Failed to load SAM model: {e}")
            print("[8/20] SAM point functionality will be disabled.")
            # Do not keep a half-initialised model alive.
            self.model = None
            self.predictor = None
            self.is_loaded = False

    @staticmethod
    def _resolve_default_model_path(model_filename):
        """Return the local path of the default checkpoint, fetching it if needed.

        The checkpoint is stored in this module's directory. If it is not there,
        a copy in the legacy ``~/.cache/lazylabel`` directory is moved over.
        Otherwise the file is downloaded from Meta's public
        ``segment_anything`` bucket.
        """
        import shutil

        models_dir = os.path.dirname(__file__)  # Already in models directory
        os.makedirs(models_dir, exist_ok=True)
        model_path = os.path.join(models_dir, model_filename)
        if os.path.exists(model_path):
            return model_path

        old_model_path = os.path.join(
            os.path.expanduser("~"), ".cache", "lazylabel", model_filename
        )
        if os.path.exists(old_model_path):
            print("[10/20] Moving existing model from cache to models folder...")
            shutil.move(old_model_path, model_path)
        else:
            model_url = (
                f"https://dl.fbaipublicfiles.com/segment_anything/{model_filename}"
            )
            download_model(model_url, model_path)
        return model_path

    def load_custom_model(self, model_path, model_type="vit_h"):
        """Replace the active model with the checkpoint at ``model_path``.

        The current model is released before the new one is built, so only one
        model is held in memory at a time. As a consequence, a failed load
        leaves the instance with no model and ``is_loaded`` set to ``False``.
        A previously set image is re-applied to the new predictor.

        Args:
            model_path: Path to the checkpoint file.
            model_type: Key into ``sam_model_registry`` for that checkpoint.

        Returns:
            ``True`` on success; ``False`` if the file is missing or loading fails.
        """
        if not os.path.exists(model_path):
            print(f"Model file not found: {model_path}")
            return False

        print(f"Loading custom SAM model from {model_path}...")
        try:
            # Clear existing model from memory before loading the replacement.
            if getattr(self, "model", None) is not None:
                self.model = None
                self.predictor = None
                self.is_loaded = False
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

            self.model = sam_model_registry[model_type](checkpoint=model_path).to(
                self.device
            )
            self.predictor = SamPredictor(self.model)
            self.current_model_type = model_type
            self.current_model_path = model_path
            self.is_loaded = True

            # Re-set image if one was previously loaded
            if self.image is not None:
                self.predictor.set_image(self.image)

            print("Custom SAM model loaded successfully.")
            return True
        except Exception as e:
            print(f"Error loading custom model: {e}")
            self.is_loaded = False
            self.model = None
            self.predictor = None
            return False

    def set_image(self, image_path):
        """Read an image file and make it the predictor's current image.

        The file is read with OpenCV (BGR) and converted to RGB before being
        handed to the predictor. ``self.image`` is only updated once the
        predictor has accepted the image. On failure, the previous value is
        kept.

        Returns:
            ``True`` on success; ``False`` if no model is loaded, the file cannot
            be decoded, or the predictor rejects the image.
        """
        if not self.is_loaded:
            return False
        try:
            bgr_image = cv2.imread(image_path)
            if bgr_image is None:
                # cv2.imread reports unreadable/missing files by returning None.
                raise ValueError(f"could not read image file: {image_path}")
            rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
            self.predictor.set_image(rgb_image)
            self.image = rgb_image
            return True
        except Exception as e:
            print(f"Error setting image: {e}")
            return False

    def predict(self, positive_points, negative_points):
        """Predict a single mask from foreground/background click prompts.

        Args:
            positive_points: Foreground click coordinates (list or array of
                points). At least one is required.
            negative_points: Background click coordinates; may be empty or
                ``None``.

        Returns:
            The mask from ``SamPredictor.predict(..., multimask_output=False)``,
            or ``None`` if no model is loaded, there are no positive points, or
            prediction fails.
        """
        if not self.is_loaded or positive_points is None:
            return None

        try:
            # list(...) concatenates NumPy inputs instead of adding them
            # element-wise, and accepts any iterable of points.
            positive = list(positive_points)
            if not positive:
                return None
            negative = [] if negative_points is None else list(negative_points)

            points = np.array(positive + negative)
            labels = np.array([1] * len(positive) + [0] * len(negative))

            masks, _, _ = self.predictor.predict(
                point_coords=points,
                point_labels=labels,
                multimask_output=False,
            )
            return masks[0]
        except Exception as e:
            print(f"Error during prediction: {e}")
            return None
1
+ import os
2
+
3
+ import cv2
4
+ import numpy as np
5
+ import requests
6
+ import torch
7
+ from segment_anything import SamPredictor, sam_model_registry
8
+ from tqdm import tqdm
9
+
10
+ from ..utils.logger import logger
11
+
12
+
13
def download_model(url, download_path):
    """Download ``url`` to ``download_path``, showing a tqdm progress bar.

    The response is streamed into a sibling ``<download_path>.part`` file.
    That file is renamed to ``download_path`` only after the whole payload has
    been written and, when the server sent a ``Content-Length``, its size has
    been checked. On any failure the partial file is removed. An interrupted
    download therefore never leaves a truncated file at ``download_path`` that
    a later start-up would mistake for a complete checkpoint.

    Args:
        url: HTTP(S) URL of the file to fetch.
        download_path: Final destination path of the downloaded file.

    Raises:
        RuntimeError: On network, timeout, HTTP, permission, disk, or
            size-mismatch errors; the original exception is chained as
            ``__cause__``.
    """
    import contextlib

    partial_path = f"{download_path}.part"
    completed = False
    try:
        logger.info("Step 5/8: Connecting to download server...")
        # Using the response as a context manager releases the streamed
        # connection even if writing to disk fails part-way.
        with requests.get(url, stream=True, timeout=30) as response:
            response.raise_for_status()
            total_size_in_bytes = int(response.headers.get("content-length", 0))
            block_size = 1024 * 1024  # 1 MiB: far fewer iterations for GB-sized files

            with tqdm(
                total=total_size_in_bytes, unit="iB", unit_scale=True
            ) as progress_bar, open(partial_path, "wb") as file:
                for data in response.iter_content(block_size):
                    progress_bar.update(len(data))
                    file.write(data)
                downloaded_bytes = progress_bar.n

        # A Content-Length of 0 means the server did not report a size, so there
        # is nothing to compare against.
        if total_size_in_bytes != 0 and downloaded_bytes != total_size_in_bytes:
            raise RuntimeError("Download incomplete - file size mismatch")

        # Only a fully written file ever appears under the final name.
        os.replace(partial_path, download_path)
        completed = True
        logger.info("Step 5/8: Model download completed successfully.")

    except requests.exceptions.ConnectionError as e:
        raise RuntimeError(
            "Step 5/8: Network connection failed: Check your internet connection"
        ) from e
    except requests.exceptions.Timeout as e:
        raise RuntimeError(
            "Step 5/8: Download timeout: Server took too long to respond"
        ) from e
    except requests.exceptions.HTTPError as e:
        raise RuntimeError(
            f"Step 5/8: HTTP error {e.response.status_code}: Server rejected request"
        ) from e
    except requests.exceptions.RequestException as e:
        raise RuntimeError(f"Step 5/8: Network error during download: {e}") from e
    except PermissionError as e:
        raise RuntimeError(
            f"Step 5/8: Permission denied: Cannot write to {download_path}"
        ) from e
    except OSError as e:
        raise RuntimeError(
            f"Step 5/8: Disk error: {e} (check available disk space)"
        ) from e
    except Exception as e:
        raise RuntimeError(f"Step 5/8: Download failed: {e}") from e
    finally:
        if not completed:
            # Clean up the partial download whichever branch above failed.
            with contextlib.suppress(OSError):
                os.remove(partial_path)
65
+
66
+
67
class SamModel:
    """Wrap a Segment Anything (SAM) model and predictor for point-prompted masks.

    Problems while locating, downloading, or loading a checkpoint are logged
    rather than raised. Callers should check ``is_loaded`` before relying on
    the model.

    Attributes:
        device: ``torch.device`` used for inference (CUDA when available).
        current_model_type: ``sam_model_registry`` key of the active model.
        current_model_path: Path of the active custom checkpoint, or ``None``
            when the default checkpoint is in use.
        model: Loaded SAM model, or ``None``.
        predictor: ``SamPredictor`` wrapping ``model``, or ``None``.
        image: Image last accepted by the predictor (RGB when read from a
            path), or ``None``.
        is_loaded: ``True`` when ``model`` and ``predictor`` are ready.
    """

    def __init__(
        self,
        model_type="vit_h",
        model_filename="sam_vit_h_4b8939.pth",
        custom_model_path=None,
    ):
        """Pick a device and load either the custom or the default checkpoint.

        Args:
            model_type: Key into ``sam_model_registry`` (e.g. ``"vit_h"``).
            model_filename: File name of the default checkpoint. It is used for
                both the local copy next to this module and the download URL.
            custom_model_path: Optional checkpoint to load instead of the
                default. If the path does not exist, a warning is logged and
                the default checkpoint is used.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Step 5/8: Detected device: {str(self.device).upper()}")

        self.current_model_type = model_type
        self.current_model_path = custom_model_path
        self.model = None
        self.predictor = None
        self.image = None
        self.is_loaded = False

        try:
            if custom_model_path and os.path.exists(custom_model_path):
                model_path = custom_model_path
                logger.info(f"Step 5/8: Loading custom SAM model from {model_path}...")
            else:
                if custom_model_path:
                    # The requested checkpoint is missing. Say so, and do not
                    # report a path that was never actually loaded.
                    logger.warning(
                        f"Step 5/8: Custom SAM model not found at {custom_model_path}; "
                        "falling back to the default model."
                    )
                    self.current_model_path = None
                model_path = self._resolve_default_model_path(model_filename)
                logger.info(f"Step 5/8: Loading default SAM model from {model_path}...")

            logger.info(
                f"Step 5/8: Initializing {model_type.upper()} model architecture..."
            )
            self.model = sam_model_registry[model_type](checkpoint=model_path).to(
                self.device
            )

            logger.info("Step 5/8: Setting up predictor...")
            self.predictor = SamPredictor(self.model)
            self.is_loaded = True
            logger.info("Step 5/8: SAM model loaded successfully.")

        except Exception as e:
            logger.error(f"Step 4/8: Failed to load SAM model: {e}")
            logger.warning("Step 4/8: SAM point functionality will be disabled.")
            # Do not keep a half-initialised model alive.
            self.model = None
            self.predictor = None
            self.is_loaded = False

    @staticmethod
    def _resolve_default_model_path(model_filename):
        """Return the local path of the default checkpoint, fetching it if needed.

        The checkpoint is stored in this module's directory. If it is not there,
        a copy in the legacy ``~/.cache/lazylabel`` directory is moved over.
        Otherwise the file is downloaded from Meta's public
        ``segment_anything`` bucket.
        """
        import shutil

        models_dir = os.path.dirname(__file__)  # Already in models directory
        os.makedirs(models_dir, exist_ok=True)
        model_path = os.path.join(models_dir, model_filename)
        if os.path.exists(model_path):
            return model_path

        old_model_path = os.path.join(
            os.path.expanduser("~"), ".cache", "lazylabel", model_filename
        )
        if os.path.exists(old_model_path):
            logger.info(
                "Step 5/8: Moving existing model from cache to models folder..."
            )
            shutil.move(old_model_path, model_path)
        else:
            model_url = (
                f"https://dl.fbaipublicfiles.com/segment_anything/{model_filename}"
            )
            download_model(model_url, model_path)
        return model_path

    def load_custom_model(self, model_path, model_type="vit_h"):
        """Replace the active model with the checkpoint at ``model_path``.

        The current model is released before the new one is built, so only one
        model is held in memory at a time. As a consequence, a failed load
        leaves the instance with no model and ``is_loaded`` set to ``False``.
        A previously set image is re-applied to the new predictor.

        Args:
            model_path: Path to the checkpoint file.
            model_type: Key into ``sam_model_registry`` for that checkpoint.

        Returns:
            ``True`` on success; ``False`` if the file is missing or loading fails.
        """
        if not os.path.exists(model_path):
            logger.warning(f"Model file not found: {model_path}")
            return False

        logger.info(f"Loading custom SAM model from {model_path}...")
        try:
            # Clear existing model from memory before loading the replacement.
            if getattr(self, "model", None) is not None:
                self.model = None
                self.predictor = None
                self.is_loaded = False
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

            self.model = sam_model_registry[model_type](checkpoint=model_path).to(
                self.device
            )
            self.predictor = SamPredictor(self.model)
            self.current_model_type = model_type
            self.current_model_path = model_path
            self.is_loaded = True

            # Re-set image if one was previously loaded
            if self.image is not None:
                self.predictor.set_image(self.image)

            logger.info("Custom SAM model loaded successfully.")
            return True
        except Exception as e:
            logger.error(f"Error loading custom model: {e}")
            self.is_loaded = False
            self.model = None
            self.predictor = None
            return False

    def set_image_from_path(self, image_path):
        """Read an image file and make it the predictor's current image.

        The file is read with OpenCV (BGR) and converted to RGB before being
        handed to the predictor. ``self.image`` is only updated once the
        predictor has accepted the image. On failure, the previous value is
        kept.

        Returns:
            ``True`` on success; ``False`` if no model is loaded, the file cannot
            be decoded, or the predictor rejects the image.
        """
        if not self.is_loaded:
            return False
        try:
            bgr_image = cv2.imread(image_path)
            if bgr_image is None:
                # cv2.imread reports unreadable/missing files by returning None.
                raise ValueError(f"could not read image file: {image_path}")
            rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
            self.predictor.set_image(rgb_image)
            self.image = rgb_image
            return True
        except Exception as e:
            logger.error(f"Error setting image from path: {e}")
            return False

    def set_image_from_array(self, image_array: np.ndarray):
        """Make an in-memory image the predictor's current image.

        The array is passed to ``SamPredictor.set_image`` unchanged.
        NOTE(review): presumably an RGB HxWx3 uint8 array, as the path-based
        loader produces; confirm with callers. ``self.image`` is only updated
        once the predictor has accepted the array.

        Returns:
            ``True`` on success; ``False`` if no model is loaded or the predictor
            rejects the array.
        """
        if not self.is_loaded:
            return False
        try:
            self.predictor.set_image(image_array)
            self.image = image_array
            return True
        except Exception as e:
            logger.error(f"Error setting image from array: {e}")
            return False

    def predict(self, positive_points, negative_points):
        """Predict a single mask from foreground/background click prompts.

        Args:
            positive_points: Foreground click coordinates (list or array of
                points). At least one is required.
            negative_points: Background click coordinates; may be empty or
                ``None``.

        Returns:
            The mask from ``SamPredictor.predict(..., multimask_output=False)``,
            or ``None`` if no model is loaded, there are no positive points, or
            prediction fails.
        """
        if not self.is_loaded or positive_points is None:
            return None

        try:
            # list(...) concatenates NumPy inputs instead of adding them
            # element-wise, and accepts any iterable of points.
            positive = list(positive_points)
            if not positive:
                return None
            negative = [] if negative_points is None else list(negative_points)

            points = np.array(positive + negative)
            labels = np.array([1] * len(positive) + [0] * len(negative))

            masks, _, _ = self.predictor.predict(
                point_coords=points,
                point_labels=labels,
                multimask_output=False,
            )
            return masks[0]
        except Exception as e:
            logger.error(f"Error during prediction: {e}")
            return None