geoai-py 0.5.6__py2.py3-none-any.whl → 0.6.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- geoai/__init__.py +1 -1
- geoai/classify.py +23 -24
- geoai/extract.py +1 -1
- geoai/geoai.py +3 -3
- geoai/sam.py +832 -0
- geoai/utils.py +329 -0
- {geoai_py-0.5.6.dist-info → geoai_py-0.6.0.dist-info}/METADATA +2 -1
- geoai_py-0.6.0.dist-info/RECORD +17 -0
- {geoai_py-0.5.6.dist-info → geoai_py-0.6.0.dist-info}/WHEEL +1 -1
- geoai_py-0.5.6.dist-info/RECORD +0 -16
- {geoai_py-0.5.6.dist-info → geoai_py-0.6.0.dist-info}/entry_points.txt +0 -0
- {geoai_py-0.5.6.dist-info → geoai_py-0.6.0.dist-info}/licenses/LICENSE +0 -0
- {geoai_py-0.5.6.dist-info → geoai_py-0.6.0.dist-info}/top_level.txt +0 -0
geoai/sam.py
ADDED
@@ -0,0 +1,832 @@
+"""
+The SamGeo class provides an interface for segmenting geospatial data using the Segment Anything Model (SAM).
+"""
+
+import os
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import cv2
+import numpy as np
+import torch
+from leafmap import array_to_image, blend_images
+from PIL import Image
+from transformers import SamModel, SamProcessor, pipeline
+
+from .utils import *
+
+
+class SamGeo:
+    """The main class for segmenting geospatial data with the Segment Anything Model (SAM). See
+    https://huggingface.co/docs/transformers/main/en/model_doc/sam for details.
+    """
+
+    def __init__(
+        self,
+        model: str = "facebook/sam-vit-huge",
+        automatic: bool = True,
+        device: Optional[Union[str, int]] = None,
+        sam_kwargs: Optional[Dict[str, Any]] = None,
+        **kwargs: Any,
+    ) -> None:
+        """
+        Initialize the class.
+
+        Args:
+            model (str, optional): The model type, such as "facebook/sam-vit-huge", "facebook/sam-vit-large", or "facebook/sam-vit-base".
+                Defaults to 'facebook/sam-vit-huge'. See https://bit.ly/3VrpxUh for more details.
+            automatic (bool, optional): Whether to use the automatic mask generator or input prompts. Defaults to True.
+                The automatic mask generator will segment the entire image, while the input prompts will segment selected objects.
+            device (Union[str, int], optional): The device to use. It can be one of the following: 'cpu', 'cuda', or an integer
+                representing the CUDA device index. Defaults to None, which will use 'cuda' if available.
+            sam_kwargs (Dict[str, Any], optional): Optional arguments for fine-tuning the SAM model. Defaults to None.
+            kwargs (Any): Other arguments for the automatic mask generator.
+        """
+
+        self.model = model
+        self.model_version = "sam"
+
+        self.sam_kwargs = sam_kwargs  # Optional arguments for fine-tuning the SAM model
+        self.source = None  # Store the input image path
+        self.image = None  # Store the input image as a numpy array
+        self.embeddings = None  # Store the image embeddings
+        # Store the masks as a list of dictionaries. Each mask is a dictionary
+        # containing segmentation, area, bbox, predicted_iou, point_coords, stability_score, and crop_box
+        self.masks = None
+        self.objects = None  # Store the mask objects as a numpy array
+        # Store the annotations (objects with random color) as a numpy array.
+        self.annotations = None
+
+        # Store the predicted masks, iou_predictions, and low_res_masks
+        self.prediction = None
+        self.scores = None
+        self.logits = None
+
+        # Build the SAM model
+        sam_kwargs = self.sam_kwargs if self.sam_kwargs is not None else {}
+
+        if automatic:
+            # Use cuda if available
+            if device is None:
+                device = 0 if torch.cuda.is_available() else -1
+                if device >= 0:
+                    torch.cuda.empty_cache()
+            self.device = device
+
+            self.mask_generator = pipeline(
+                task="mask-generation",
+                model=model,
+                device=device,
+                **kwargs,
+            )
+
+        else:
+            if device is None:
+                device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+                if device.type == "cuda":
+                    torch.cuda.empty_cache()
+            self.device = device
+
+            self.predictor = SamModel.from_pretrained("facebook/sam-vit-huge").to(
+                device
+            )
+            self.processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
+
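A minimal sketch of the two constructor modes (arguments are illustrative; note that the prompt branch above hard-codes "facebook/sam-vit-huge" regardless of the `model` argument):

    from geoai.sam import SamGeo

    # Automatic mode: wraps the Hugging Face "mask-generation" pipeline.
    sam = SamGeo(model="facebook/sam-vit-base", automatic=True)

    # Prompt mode: loads SamModel/SamProcessor for point- and box-prompted prediction.
    sam_prompt = SamGeo(automatic=False, device="cpu")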
+    def generate(
+        self,
+        source: Union[str, np.ndarray],
+        output: Optional[str] = None,
+        foreground: bool = True,
+        erosion_kernel: Optional[Tuple[int, int]] = None,
+        mask_multiplier: int = 255,
+        unique: bool = True,
+        min_size: int = 0,
+        max_size: Optional[int] = None,
+        output_args: Optional[Dict[str, Any]] = None,
+        **kwargs: Any,
+    ) -> None:
+        """
+        Generate masks for the input image.
+
+        Args:
+            source (Union[str, np.ndarray]): The path to the input image or the input image as a numpy array.
+            output (Optional[str], optional): The path to the output image. Defaults to None.
+            foreground (bool, optional): Whether to generate the foreground mask. Defaults to True.
+            erosion_kernel (Optional[Tuple[int, int]], optional): The erosion kernel for filtering object masks and extracting borders.
+                For example, (3, 3) or (5, 5). Set to None to disable it. Defaults to None.
+            mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1].
+                You can use this parameter to scale the mask to a larger range, for example [0, 255]. Defaults to 255.
+            unique (bool, optional): Whether to assign a unique value to each object. Defaults to True.
+                The unique value increases from 1 to the number of objects. The larger the number, the larger the object area.
+            min_size (int, optional): The minimum size of the objects. Defaults to 0.
+            max_size (Optional[int], optional): The maximum size of the objects. Defaults to None.
+            output_args (Optional[Dict[str, Any]], optional): Additional arguments for saving the output. Defaults to None.
+            **kwargs (Any): Other arguments for the mask generator.
+
+        Raises:
+            ValueError: If the input source is not a valid path or numpy array.
+        """
+
+        if isinstance(source, str):
+            if source.startswith("http"):
+                source = download_file(source)
+
+            if not os.path.exists(source):
+                raise ValueError(f"Input path {source} does not exist.")
+
+            image = cv2.imread(source)
+            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+        elif isinstance(source, np.ndarray):
+            image = source
+            source = None
+        else:
+            raise ValueError("Input source must be either a path or a numpy array.")
+
+        if output_args is None:
+            output_args = {}
+
+        self.source = source  # Store the input image path
+        self.image = image  # Store the input image as a numpy array
+        mask_generator = self.mask_generator  # The automatic mask generator
+        # masks = mask_generator.generate(image)  # Segment the input image
+        result = mask_generator(source, **kwargs)
+        masks = result["masks"] if "masks" in result else result  # Get the masks
+        scores = result["scores"] if "scores" in result else None  # Get the scores
+
+        # format the masks as a list of dictionaries, similar to the output of SAM.
+        formatted_masks = []
+        for mask, score in zip(masks, scores):
+            area = int(np.sum(mask))  # number of True pixels
+            formatted_masks.append(
+                {
+                    "segmentation": mask,
+                    "area": area,
+                    "score": float(score),  # ensure it's a native Python float
+                }
+            )
+
+        self.output = result  # Store the result
+        self.masks = formatted_masks  # Store the masks as a list of dictionaries
+        self.batch = False
+        # self.scores = scores  # Store the scores
+        self._min_size = min_size
+        self._max_size = max_size
+
+        # Save the masks to the output path. The output is either a binary mask or a mask of objects with unique values.
+        self.save_masks(
+            output,
+            foreground,
+            unique,
+            erosion_kernel,
+            mask_multiplier,
+            min_size,
+            max_size,
+            **output_args,
+        )
+
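A hedged usage sketch for `generate()` (the GeoTIFF path is hypothetical; `points_per_batch` is a standard argument of the Hugging Face mask-generation pipeline, forwarded via `**kwargs`):

    sam = SamGeo(automatic=True)
    sam.generate(
        "satellite.tif",        # hypothetical input raster
        output="objects.tif",   # objects written with unique integer values (unique=True)
        points_per_batch=64,    # forwarded to the Hugging Face pipeline
    )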
+    def generate_batch(
+        self,
+        inputs: List[Union[str, np.ndarray]],
+        output_dir: Optional[str] = None,
+        suffix: str = "_masks",
+        foreground: bool = True,
+        erosion_kernel: Optional[Tuple[int, int]] = None,
+        mask_multiplier: int = 255,
+        unique: bool = True,
+        min_size: int = 0,
+        max_size: Optional[int] = None,
+        output_args: Optional[Dict[str, Any]] = None,
+        **kwargs: Any,
+    ) -> None:
+        """
+        Generate masks for a batch of input images.
+
+        Args:
+            inputs (List[Union[str, np.ndarray]]): A list of paths to input images or numpy arrays representing the images.
+            output_dir (Optional[str], optional): The directory to save the output masks. Defaults to the current working directory.
+            suffix (str, optional): The suffix to append to the output filenames. Defaults to "_masks".
+            foreground (bool, optional): Whether to generate the foreground mask. Defaults to True.
+            erosion_kernel (Optional[Tuple[int, int]], optional): The erosion kernel for filtering object masks and extracting borders.
+                For example, (3, 3) or (5, 5). Set to None to disable it. Defaults to None.
+            mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1].
+                You can use this parameter to scale the mask to a larger range, for example [0, 255]. Defaults to 255.
+            unique (bool, optional): Whether to assign a unique value to each object. Defaults to True.
+                The unique value increases from 1 to the number of objects. The larger the number, the larger the object area.
+            min_size (int, optional): The minimum size of the objects. Defaults to 0.
+            max_size (Optional[int], optional): The maximum size of the objects. Defaults to None.
+            output_args (Optional[Dict[str, Any]], optional): Additional arguments for saving the output. Defaults to None.
+            **kwargs (Any): Other arguments for the mask generator.
+
+        Raises:
+            ValueError: If the input list is empty or contains invalid paths.
+        """
+
+        mask_generator = self.mask_generator  # The automatic mask generator
+        outputs = mask_generator(inputs, **kwargs)
+
+        if output_args is None:
+            output_args = {}
+
+        if output_dir is None:
+            output_dir = os.getcwd()
+
+        for index, result in enumerate(outputs):
+
+            basename = os.path.basename(inputs[index])
+            file_ext = os.path.splitext(basename)[1]
+            filename = f"{os.path.splitext(basename)[0]}{suffix}{file_ext}"
+            filepath = os.path.join(output_dir, filename)
+
+            masks = result["masks"] if "masks" in result else result  # Get the masks
+            scores = result["scores"] if "scores" in result else None  # Get the scores
+
+            # format the masks as a list of dictionaries, similar to the output of SAM.
+            formatted_masks = []
+            for mask, score in zip(masks, scores):
+                area = int(np.sum(mask))  # number of True pixels
+                formatted_masks.append(
+                    {
+                        "segmentation": mask,
+                        "area": area,
+                        "score": float(score),  # ensure it's a native Python float
+                    }
+                )
+
+            self.source = inputs[index]  # Store the input image path
+            self.output = result  # Store the result
+            self.masks = formatted_masks  # Store the masks as a list of dictionaries
+            # self.scores = scores  # Store the scores
+            self._min_size = min_size
+            self._max_size = max_size
+
+            # Save the masks to the output path. The output is either a binary mask or a mask of objects with unique values.
+            self.save_masks(
+                filepath,
+                foreground,
+                unique,
+                erosion_kernel,
+                mask_multiplier,
+                min_size,
+                max_size,
+                **output_args,
+            )
+
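A sketch of batch segmentation (tile names and the output directory are hypothetical):

    sam.generate_batch(
        ["tile_01.tif", "tile_02.tif"],  # hypothetical input tiles
        output_dir="masks",              # writes tile_01_masks.tif, tile_02_masks.tif
    )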
+    def save_masks(
+        self,
+        output: Optional[str] = None,
+        foreground: bool = True,
+        unique: bool = True,
+        erosion_kernel: Optional[Tuple[int, int]] = None,
+        mask_multiplier: int = 255,
+        min_size: int = 0,
+        max_size: Optional[int] = None,
+        **kwargs: Any,
+    ) -> None:
+        """
+        Save the masks to the output path. The output is either a binary mask or a mask of objects with unique values.
+
+        Args:
+            output (Optional[str], optional): The path to the output image. Defaults to None, saving the masks to `SamGeo.objects`.
+            foreground (bool, optional): Whether to generate the foreground mask. Defaults to True.
+            unique (bool, optional): Whether to assign a unique value to each object. Defaults to True.
+            erosion_kernel (Optional[Tuple[int, int]], optional): The erosion kernel for filtering object masks and extracting borders.
+                For example, (3, 3) or (5, 5). Set to None to disable it. Defaults to None.
+            mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1].
+                You can use this parameter to scale the mask to a larger range, for example [0, 255]. Defaults to 255.
+            min_size (int, optional): The minimum size of the objects. Defaults to 0.
+            max_size (Optional[int], optional): The maximum size of the objects. Defaults to None.
+            **kwargs (Any): Other arguments for `array_to_image()`.
+
+        Raises:
+            ValueError: If no masks are found or if `generate()` has not been run.
+        """
+
+        if self.masks is None:
+            raise ValueError("No masks found. Please run generate() first.")
+
+        if self.image is None:
+            h, w = self.masks[0]["segmentation"].shape
+        else:
+            h, w, _ = self.image.shape
+        masks = self.masks
+
+        # Set output image data type based on the number of objects
+        if len(masks) < 255:
+            dtype = np.uint8
+        elif len(masks) < 65535:
+            dtype = np.uint16
+        else:
+            dtype = np.uint32
+
+        # Generate a mask of objects with unique values
+        if unique:
+            # Sort the masks by area in descending order
+            sorted_masks = sorted(masks, key=(lambda x: x["area"]), reverse=True)
+
+            # Create an output image with the same size as the input image
+            objects = np.zeros(
+                (
+                    sorted_masks[0]["segmentation"].shape[0],
+                    sorted_masks[0]["segmentation"].shape[1],
+                )
+            )
+            # Assign a unique value to each object
+            count = len(sorted_masks)
+            for index, ann in enumerate(sorted_masks):
+                m = ann["segmentation"]
+                if min_size > 0 and ann["area"] < min_size:
+                    continue
+                if max_size is not None and ann["area"] > max_size:
+                    continue
+                objects[m] = count - index
+
+        # Generate a binary mask
+        else:
+            if foreground:  # Extract foreground objects only
+                resulting_mask = np.zeros((h, w), dtype=dtype)
+            else:
+                resulting_mask = np.ones((h, w), dtype=dtype)
+            resulting_borders = np.zeros((h, w), dtype=dtype)
+
+            for m in masks:
+                if min_size > 0 and m["area"] < min_size:
+                    continue
+                if max_size is not None and m["area"] > max_size:
+                    continue
+                mask = (m["segmentation"] > 0).astype(dtype)
+                resulting_mask += mask
+
+                # Apply erosion to the mask
+                if erosion_kernel is not None:
+                    mask_erode = cv2.erode(mask, erosion_kernel, iterations=1)
+                    mask_erode = (mask_erode > 0).astype(dtype)
+                    edge_mask = mask - mask_erode
+                    resulting_borders += edge_mask
+
+            resulting_mask = (resulting_mask > 0).astype(dtype)
+            resulting_borders = (resulting_borders > 0).astype(dtype)
+            objects = resulting_mask - resulting_borders
+            objects = objects * mask_multiplier
+
+        objects = objects.astype(dtype)
+        self.objects = objects
+
+        if output is not None:  # Save the output image
+            array_to_image(self.objects, output, self.source, **kwargs)
+
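A sketch of re-saving the stored masks as a binary raster with borders removed (assumes `generate()` has run; since the kernel is passed straight to `cv2.erode()`, an explicit ndarray kernel is the safe choice even though the docstring shows a tuple):

    import numpy as np

    sam.save_masks(
        output="binary_mask.tif",                   # hypothetical output path
        unique=False,                               # binary output instead of unique labels
        erosion_kernel=np.ones((3, 3), np.uint8),   # explicit kernel for cv2.erode
    )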
+    def show_masks(
+        self,
+        figsize: Tuple[int, int] = (12, 10),
+        cmap: str = "binary_r",
+        axis: str = "off",
+        foreground: bool = True,
+        **kwargs: Any,
+    ) -> None:
+        """
+        Display the binary mask or the mask of objects with unique values.
+
+        Args:
+            figsize (Tuple[int, int], optional): The figure size. Defaults to (12, 10).
+            cmap (str, optional): The colormap to use for displaying the mask. Defaults to "binary_r".
+            axis (str, optional): Whether to show the axis. Defaults to "off".
+            foreground (bool, optional): Whether to show the foreground mask only. Defaults to True.
+            **kwargs (Any): Additional arguments for the `save_masks()` method.
+
+        Raises:
+            ValueError: If no masks are available and `save_masks()` cannot generate them.
+        """
+        import matplotlib.pyplot as plt
+
+        if self.batch:
+            self.objects = cv2.imread(self.masks)
+        else:
+            if self.objects is None:
+                self.save_masks(foreground=foreground, **kwargs)
+
+        plt.figure(figsize=figsize)
+        plt.imshow(self.objects, cmap=cmap)
+        plt.axis(axis)
+        plt.show()
+
+    def show_anns(
+        self,
+        figsize: Tuple[int, int] = (12, 10),
+        axis: str = "off",
+        alpha: float = 0.35,
+        output: Optional[str] = None,
+        blend: bool = True,
+        **kwargs: Any,
+    ) -> None:
+        """
+        Show the annotations (objects with random color) on the input image.
+
+        Args:
+            figsize (Tuple[int, int], optional): The figure size. Defaults to (12, 10).
+            axis (str, optional): Whether to show the axis. Defaults to "off".
+            alpha (float, optional): The alpha value for the annotations. Defaults to 0.35.
+            output (Optional[str], optional): The path to the output image. Defaults to None.
+            blend (bool, optional): Whether to show the input image blended with annotations. Defaults to True.
+            **kwargs (Any): Additional arguments for saving the output image.
+
+        Raises:
+            ValueError: If the input image or annotations are not available.
+        """
+
+        import matplotlib.pyplot as plt
+
+        anns = self.masks
+
+        if self.image is None:
+            print("Please run generate() first.")
+            return
+
+        if anns is None or len(anns) == 0:
+            return
+
+        plt.figure(figsize=figsize)
+        plt.imshow(self.image)
+
+        sorted_anns = sorted(anns, key=(lambda x: x["area"]), reverse=True)
+
+        ax = plt.gca()
+        ax.set_autoscale_on(False)
+
+        img = np.ones(
+            (
+                sorted_anns[0]["segmentation"].shape[0],
+                sorted_anns[0]["segmentation"].shape[1],
+                4,
+            )
+        )
+        img[:, :, 3] = 0
+        for ann in sorted_anns:
+            if hasattr(self, "_min_size") and (ann["area"] < self._min_size):
+                continue
+            if (
+                hasattr(self, "_max_size")
+                and isinstance(self._max_size, int)
+                and ann["area"] > self._max_size
+            ):
+                continue
+            m = ann["segmentation"]
+            color_mask = np.concatenate([np.random.random(3), [alpha]])
+            img[m] = color_mask
+        ax.imshow(img)
+
+        # if "dpi" not in kwargs:
+        #     kwargs["dpi"] = 100
+
+        # if "bbox_inches" not in kwargs:
+        #     kwargs["bbox_inches"] = "tight"
+
+        plt.axis(axis)
+
+        self.annotations = (img[:, :, 0:3] * 255).astype(np.uint8)
+
+        if output is not None:
+            if blend:
+                array = blend_images(
+                    self.annotations, self.image, alpha=alpha, show=False
+                )
+            else:
+                array = self.annotations
+            array_to_image(array, output, self.source, **kwargs)
+
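Both visualization helpers operate on the stored results; a sketch (the overlay path is hypothetical):

    sam.show_masks(cmap="binary_r")                       # plot the object/binary mask
    sam.show_anns(alpha=0.5, output="annotations.tif")    # random-color overlay, optionally saved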
+    def set_image(self, image: Union[str, np.ndarray], **kwargs: Any) -> None:
+        """
+        Set the input image as a numpy array.
+
+        Args:
+            image (Union[str, np.ndarray]): The input image, either as a file path (string) or a numpy array.
+            **kwargs (Any): Additional arguments for the image processor.
+
+        Raises:
+            ValueError: If the input image path does not exist.
+        """
+        if isinstance(image, str):
+            if image.startswith("http"):
+                image = download_file(image)
+
+            if not os.path.exists(image):
+                raise ValueError(f"Input path {image} does not exist.")
+
+            self.source = image
+
+            image = Image.open(image).convert("RGB")
+            self.image = image
+
+            inputs = self.processor(image, return_tensors="pt").to(self.device)
+            self.embeddings = self.predictor.get_image_embeddings(
+                inputs["pixel_values"], **kwargs
+            )
+
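In prompt mode, `set_image()` is called once per image so the cached embeddings can be reused across `predict()` calls; a sketch (the path is hypothetical):

    sam = SamGeo(automatic=False)
    sam.set_image("satellite.tif")   # computes and caches the image embeddings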
+    def save_prediction(
+        self,
+        output: str,
+        index: Optional[int] = None,
+        mask_multiplier: int = 255,
+        dtype: np.dtype = np.float32,
+        vector: Optional[str] = None,
+        simplify_tolerance: Optional[float] = None,
+        **kwargs: Any,
+    ) -> None:
+        """
+        Save the predicted mask to the output path.
+
+        Args:
+            output (str): The path to the output image.
+            index (Optional[int], optional): The index of the mask to save. Defaults to None,
+                which will save the mask with the highest score.
+            mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1].
+                Defaults to 255.
+            dtype (np.dtype, optional): The data type of the output image. Defaults to np.float32.
+            vector (Optional[str], optional): The path to the output vector file. Defaults to None.
+            simplify_tolerance (Optional[float], optional): The maximum allowed geometry displacement.
+                The higher this value, the smaller the number of vertices in the resulting geometry. Defaults to None.
+            **kwargs (Any): Additional arguments for saving the output image.
+
+        Raises:
+            ValueError: If no predictions are found.
+        """
+        if self.scores is None:
+            raise ValueError("No predictions found. Please run predict() first.")
+
+        if index is None:
+            index = self.scores.argmax(axis=0)
+
+        array = self.masks[index] * mask_multiplier
+        self.prediction = array
+        array_to_image(array, output, self.source, dtype=dtype, **kwargs)
+
+        if vector is not None:
+            raster_to_vector(output, vector, simplify_tolerance=simplify_tolerance)
+
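After a `predict()` call (defined below), the best-scoring mask can be written out and optionally vectorized; a sketch with hypothetical paths (`raster_to_vector` is assumed to come from the `from .utils import *` star import):

    sam.save_prediction("mask.tif", vector="mask.geojson", simplify_tolerance=0.5)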
+    def predict(
+        self,
+        point_coords=None,
+        point_labels=None,
+        boxes=None,
+        point_crs=None,
+        mask_input=None,
+        multimask_output=True,
+        return_logits=False,
+        output=None,
+        index=None,
+        mask_multiplier=255,
+        dtype="float32",
+        return_results=False,
+        **kwargs,
+    ):
+        """Predict masks for the given input prompts, using the currently set image.
+
+        Args:
+            point_coords (str | dict | list | np.ndarray, optional): A Nx2 array of point prompts to the
+                model. Each point is in (X,Y) in pixels. It can be a path to a vector file, a GeoJSON
+                dictionary, a list of coordinates [lon, lat], or a numpy array. Defaults to None.
+            point_labels (list | int | np.ndarray, optional): A length N array of labels for the
+                point prompts. 1 indicates a foreground point and 0 indicates a background point.
+            boxes (list | np.ndarray, optional): A length 4 array given a box prompt to the
+                model, in XYXY format.
+            point_crs (str, optional): The coordinate reference system (CRS) of the point prompts.
+            mask_input (np.ndarray, optional): A low resolution mask input to the model, typically
+                coming from a previous prediction iteration. Has form 1xHxW, where for SAM, H=W=256.
+            multimask_output (bool, optional): If true, the model will return three masks.
+                For ambiguous input prompts (such as a single click), this will often
+                produce better masks than a single prediction. If only a single
+                mask is needed, the model's predicted quality score can be used
+                to select the best mask. For non-ambiguous prompts, such as multiple
+                input prompts, multimask_output=False can give better results.
+            return_logits (bool, optional): If true, returns un-thresholded mask logits
+                instead of a binary mask.
+            output (str, optional): The path to the output image. Defaults to None.
+            index (int, optional): The index of the mask to save. Defaults to None,
+                which will save the mask with the highest score.
+            mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1].
+            dtype (np.dtype, optional): The data type of the output image. Defaults to "float32".
+            return_results (bool, optional): Whether to return the predicted masks, scores, and logits. Defaults to False.
+
+        """
+        out_of_bounds = []
+
+        if isinstance(boxes, str):
+            gdf = gpd.read_file(boxes)
+            if gdf.crs is not None:
+                gdf = gdf.to_crs("epsg:4326")
+            boxes = gdf.geometry.bounds.values.tolist()
+        elif isinstance(boxes, dict):
+            import json
+
+            geojson = json.dumps(boxes)
+            gdf = gpd.read_file(geojson, driver="GeoJSON")
+            boxes = gdf.geometry.bounds.values.tolist()
+
+        if isinstance(point_coords, str):
+            point_coords = vector_to_geojson(point_coords)
+
+        if isinstance(point_coords, dict):
+            point_coords = geojson_to_coords(point_coords)
+
+        if hasattr(self, "point_coords"):
+            point_coords = self.point_coords
+
+        if hasattr(self, "point_labels"):
+            point_labels = self.point_labels
+
+        if (point_crs is not None) and (point_coords is not None):
+            point_coords, out_of_bounds = coords_to_xy(
+                self.source, point_coords, point_crs, return_out_of_bounds=True
+            )
+
+        if isinstance(point_coords, list):
+            point_coords = np.array(point_coords)
+
+        if point_coords is not None:
+            if point_labels is None:
+                point_labels = [1] * len(point_coords)
+            elif isinstance(point_labels, int):
+                point_labels = [point_labels] * len(point_coords)
+
+        if isinstance(point_labels, list):
+            if len(point_labels) != len(point_coords):
+                if len(point_labels) == 1:
+                    point_labels = point_labels * len(point_coords)
+                elif len(out_of_bounds) > 0:
+                    print(f"Removing {len(out_of_bounds)} out-of-bound points.")
+                    point_labels_new = []
+                    for i, p in enumerate(point_labels):
+                        if i not in out_of_bounds:
+                            point_labels_new.append(p)
+                    point_labels = point_labels_new
+                else:
+                    raise ValueError(
+                        "The length of point_labels must be equal to the length of point_coords."
+                    )
+            point_labels = np.array(point_labels)
+
+        predictor = self.predictor
+
+        input_boxes = None
+        if isinstance(boxes, list) and (point_crs is not None):
+            coords = bbox_to_xy(self.source, boxes, point_crs)
+            input_boxes = np.array(coords)
+            if isinstance(coords[0], int):
+                input_boxes = input_boxes[None, :]
+            else:
+                input_boxes = torch.tensor(input_boxes, device=self.device)
+                input_boxes = predictor.transform.apply_boxes_torch(
+                    input_boxes, self.image.shape[:2]
+                )
+        elif isinstance(boxes, list) and (point_crs is None):
+            input_boxes = np.array(boxes)
+            if isinstance(boxes[0], int):
+                input_boxes = input_boxes[None, :]
+
+        self.boxes = input_boxes
+        self.point_coords = point_coords
+        self.point_labels = point_labels
+
+        if input_boxes is not None:
+            input_boxes = [input_boxes]
+
+        if point_coords is not None:
+            point_coords = [[point_coords]]
+            point_labels = [[point_labels]]
+
+        inputs = self.processor(
+            self.image,
+            input_points=point_coords,
+            # input_labels=point_labels,
+            input_boxes=input_boxes,
+            return_tensors="pt",
+            **kwargs,
+        ).to(self.device)
+
+        inputs.pop("pixel_values", None)
+        inputs.update({"image_embeddings": self.embeddings})
+
+        with torch.no_grad():
+            outputs = self.predictor(**inputs)
+
+        # https://huggingface.co/docs/transformers/en/model_doc/sam#transformers.SamImageProcessor.post_process_masks
+        self.masks = self.processor.image_processor.post_process_masks(
+            outputs.pred_masks.cpu(),
+            inputs["original_sizes"].cpu(),
+            inputs["reshaped_input_sizes"].cpu(),
+        )
+        self.scores = outputs.iou_scores
+
+        # if (
+        #     boxes is None
+        #     or (len(boxes) == 1)
+        #     or (len(boxes) == 4 and isinstance(boxes[0], float))
+        # ):
+        #     if isinstance(boxes, list) and isinstance(boxes[0], list):
+        #         boxes = boxes[0]
+        #     masks, scores, logits = predictor.predict(
+        #         point_coords,
+        #         point_labels,
+        #         input_boxes,
+        #         mask_input,
+        #         multimask_output,
+        #         return_logits,
+        #     )
+        # else:
+        #     masks, scores, logits = predictor.predict_torch(
+        #         point_coords=point_coords,
+        #         point_labels=point_coords,
+        #         boxes=input_boxes,
+        #         multimask_output=True,
+        #     )
+
+        # self.masks = masks
+        # self.scores = scores
+        # self.logits = logits
+
+        # if output is not None:
+        #     if boxes is None or (not isinstance(boxes[0], list)):
+        #         self.save_prediction(output, index, mask_multiplier, dtype, **kwargs)
+        #     else:
+        #         self.tensor_to_numpy(
+        #             index, output, mask_multiplier, dtype, save_args=kwargs
+        #         )
+
+        # if return_results:
+        #     return masks, scores, logits
+
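A hedged sketch of point-prompted prediction after `set_image()` (pixel coordinates are illustrative; per the docstring, geographic prompts additionally require `point_crs`):

    sam.predict(point_coords=[[500, 375]], point_labels=1)  # one foreground (X, Y) pixel prompt
    sam.save_prediction("mask.tif")                         # hypothetical output path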
+    def tensor_to_numpy(
+        self,
+        index: Optional[int] = None,
+        output: Optional[str] = None,
+        mask_multiplier: int = 255,
+        dtype: Union[str, np.dtype] = "uint8",
+        save_args: Optional[Dict[str, Any]] = None,
+    ) -> Optional[np.ndarray]:
+        """
+        Convert the predicted masks from tensors to numpy arrays.
+
+        Args:
+            index (Optional[int], optional): The index of the mask to save. Defaults to None,
+                which will save the mask with the highest score.
+            output (Optional[str], optional): The path to the output image. Defaults to None.
+            mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1].
+                Defaults to 255.
+            dtype (Union[str, np.dtype], optional): The data type of the output image. Defaults to "uint8".
+            save_args (Optional[Dict[str, Any]], optional): Optional arguments for saving the output image. Defaults to None.
+
+        Returns:
+            Optional[np.ndarray]: The predicted mask as a numpy array if `output` is None. Otherwise, saves the mask to the specified path.
+
+        Raises:
+            ValueError: If no objects are found in the image or if the masks are not available.
+        """
+
+        if save_args is None:
+            save_args = {}
+
+        if self.masks is None:
+            raise ValueError("No masks found. Please run the prediction method first.")
+
+        boxes = self.boxes
+        masks = self.masks
+
+        image_pil = self.image
+        image_np = np.array(image_pil)
+
+        if index is None:
+            index = 1
+
+        masks = masks[:, index, :, :]
+        masks = masks.squeeze(1)
+
+        if boxes is None or (len(boxes) == 0):  # No "object" instances found
+            print("No objects found in the image.")
+            return
+        else:
+            # Create an empty image to store the mask overlays
+            mask_overlay = np.zeros_like(
+                image_np[..., 0], dtype=dtype
+            )  # Adjusted for single channel
+
+            for i, (box, mask) in enumerate(zip(boxes, masks)):
+                # Convert tensor to numpy array if necessary and ensure it contains integers
+                if isinstance(mask, torch.Tensor):
+                    mask = (
+                        mask.cpu().numpy().astype(dtype)
+                    )  # If mask is on GPU, use .cpu() before .numpy()
+                mask_overlay += ((mask > 0) * (i + 1)).astype(
+                    dtype
+                )  # Assign a unique value for each mask
+
+            # Normalize mask_overlay to be in [0, 255]
+            mask_overlay = (
+                mask_overlay > 0
+            ) * mask_multiplier  # Binary mask in [0, 255]
+
+        if output is not None:
+            array_to_image(mask_overlay, output, self.source, dtype=dtype, **save_args)
+        else:
+            return mask_overlay
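A sketch of `tensor_to_numpy()` for box-prompted results (assumes a prior `predict()` with box prompts populated `self.boxes`; the output path is hypothetical):

    overlay = sam.tensor_to_numpy(index=1)               # labeled overlay as a numpy array
    sam.tensor_to_numpy(index=1, output="overlay.tif")   # or write a georeferenced raster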