geoai-py 0.5.6__py2.py3-none-any.whl → 0.6.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
geoai/sam.py ADDED
@@ -0,0 +1,832 @@
1
+ """
2
+ The SamGeo class provides an interface for segmenting geospatial data using the Segment Anything Model (SAM).
3
+ """
4
+
5
+ import os
6
+ from typing import Any, Dict, List, Optional, Tuple, Union
7
+
8
+ import cv2
9
+ import numpy as np
10
+ import torch
11
+ from leafmap import array_to_image, blend_images
12
+ from PIL import Image
13
+ from transformers import SamModel, SamProcessor, pipeline
14
+
15
+ from .utils import *
16
+
17
+
18
class SamGeo:
    """The main class for segmenting geospatial data with the Segment Anything Model (SAM).

    Two modes are supported:

    * ``automatic=True``: a Hugging Face ``mask-generation`` pipeline segments the whole
      image (see :meth:`generate` and :meth:`generate_batch`).
    * ``automatic=False``: a ``SamModel``/``SamProcessor`` pair segments objects selected
      by point or box prompts (see :meth:`set_image` and :meth:`predict`).

    See https://huggingface.co/docs/transformers/main/en/model_doc/sam for details.
    """

    def __init__(
        self,
        model: str = "facebook/sam-vit-huge",
        automatic: bool = True,
        device: Optional[Union[str, int]] = None,
        sam_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        """
        Initialize the class.

        Args:
            model (str, optional): The model checkpoint, such as "facebook/sam-vit-huge",
                "facebook/sam-vit-large", or "facebook/sam-vit-base". Used in both automatic
                and prompt modes. Defaults to 'facebook/sam-vit-huge'.
                See https://bit.ly/3VrpxUh for more details.
            automatic (bool, optional): Whether to use the automatic mask generator or input prompts.
                Defaults to True. The automatic mask generator segments the entire image, while
                the input prompts segment selected objects.
            device (Union[str, int], optional): The device to use: 'cpu', 'cuda' (or 'cuda:N'), or
                an integer CUDA device index (-1 means CPU). Defaults to None, which uses CUDA
                if available.
            sam_kwargs (Dict[str, Any], optional): Optional arguments for fine-tuning the SAM model.
                Stored on the instance as ``sam_kwargs``. Defaults to None.
            kwargs (Any): Other arguments for the automatic mask generator pipeline.
        """
        self.model = model
        self.model_version = "sam"

        self.sam_kwargs = sam_kwargs  # Optional arguments for fine-tuning the SAM model
        self.source = None  # Local path of the input image (None for in-memory images)
        self.image = None  # The input image
        self.embeddings = None  # Image embeddings computed by set_image()
        # Automatic mode: a list of dicts with "segmentation", "area" and "score" keys.
        # Prompt mode: the post-processed masks produced by predict().
        self.masks = None
        self.objects = None  # Mask of objects produced by save_masks()
        self.annotations = None  # Objects with random colors produced by show_anns()
        self.output = None  # Raw result of the mask-generation pipeline
        self.batch = False  # Whether the last automatic run was generate_batch()

        # Prompt-mode results and the prompts that produced them.
        self.prediction = None
        self.scores = None
        self.logits = None
        self.boxes = None
        self.point_coords = None
        self.point_labels = None

        if automatic:
            if device is None:
                device = 0 if torch.cuda.is_available() else -1
            if self._is_cuda(device):
                torch.cuda.empty_cache()
            self.device = device

            self.mask_generator = pipeline(
                task="mask-generation",
                model=model,
                device=device,
                **kwargs,
            )
        else:
            if device is None:
                device = "cuda" if torch.cuda.is_available() else "cpu"
            elif isinstance(device, int) and device < 0:
                # Same convention as the automatic branch: -1 selects the CPU.
                device = "cpu"
            # Normalize str / int / torch.device into a torch.device so `.type` is available.
            device = torch.device(device)
            if device.type == "cuda":
                torch.cuda.empty_cache()
            self.device = device

            # Load the requested checkpoint rather than a hard-coded one.
            self.predictor = SamModel.from_pretrained(model).to(device)
            self.processor = SamProcessor.from_pretrained(model)

    @staticmethod
    def _is_cuda(device: Any) -> bool:
        """Return True if ``device`` (int index, string, or torch.device) names a CUDA device."""
        if isinstance(device, int):
            return device >= 0
        return str(device).startswith("cuda")

    @staticmethod
    def _load_image(source: Union[str, np.ndarray]) -> Tuple[np.ndarray, Optional[str]]:
        """Load an input image as a numpy array.

        Args:
            source (Union[str, np.ndarray]): A local path, an http(s) URL (downloaded first),
                or a numpy array (returned unchanged).

        Returns:
            Tuple[np.ndarray, Optional[str]]: The image (RGB when read from a file) and the
                local file path, which is None for array input.

        Raises:
            ValueError: If the path does not exist, cannot be read, or the input type is unsupported.
        """
        if isinstance(source, str):
            if source.startswith("http"):
                source = download_file(source)
            if not os.path.exists(source):
                raise ValueError(f"Input path {source} does not exist.")
            image = cv2.imread(source)
            if image is None:  # cv2.imread signals unreadable files by returning None
                raise ValueError(f"Unable to read image {source}.")
            return cv2.cvtColor(image, cv2.COLOR_BGR2RGB), source
        if isinstance(source, np.ndarray):
            return source, None
        raise ValueError("Input source must be either a path or a numpy array.")

    @staticmethod
    def _pipeline_input(image: np.ndarray, path: Optional[str]) -> Any:
        """Return the object to feed to the mask-generation pipeline for one image.

        Files are passed by path. In-memory arrays are wrapped in a PIL image.
        NOTE(review): assumes the pipeline's image loader accepts a path or a PIL image but
        not a raw array, and that the array is RGB uint8-compatible for ``Image.fromarray``.
        """
        if path is not None:
            return path
        return Image.fromarray(image)

    @staticmethod
    def _format_masks(result: Any) -> List[Dict[str, Any]]:
        """Convert a pipeline result into a list of SAM-style mask dictionaries.

        Args:
            result (Any): Either a dict with a "masks" list (and optionally "scores"),
                or a bare list of masks.

        Returns:
            List[Dict[str, Any]]: One dict per mask with "segmentation", "area" (number of
                non-zero pixels) and "score" (a Python float, or None when no scores exist).
        """
        if isinstance(result, dict):
            masks = result.get("masks", [])
            scores = result.get("scores")
        else:
            masks, scores = result, None
        if scores is None:
            scores = [None] * len(masks)

        formatted_masks = []
        for mask, score in zip(masks, scores):
            formatted_masks.append(
                {
                    "segmentation": mask,
                    "area": int(np.sum(mask)),  # number of True pixels
                    "score": None if score is None else float(score),
                }
            )
        return formatted_masks

    def generate(
        self,
        source: Union[str, np.ndarray],
        output: Optional[str] = None,
        foreground: bool = True,
        erosion_kernel: Optional[Tuple[int, int]] = None,
        mask_multiplier: int = 255,
        unique: bool = True,
        min_size: int = 0,
        max_size: Optional[int] = None,
        output_args: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        """
        Generate masks for the input image.

        Args:
            source (Union[str, np.ndarray]): The path/URL of the input image or the image as a numpy array.
            output (Optional[str], optional): The path to the output image. Defaults to None.
            foreground (bool, optional): Whether to generate the foreground mask. Defaults to True.
            erosion_kernel (Optional[Tuple[int, int]], optional): The erosion kernel for filtering object masks
                and extracting borders. For example, (3, 3) or (5, 5). Set to None to disable it. Defaults to None.
            mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary
                mask [0, 1]. You can use this parameter to scale the mask to a larger range, for example [0, 255].
                Defaults to 255.
            unique (bool, optional): Whether to assign a unique value to each object. Defaults to True.
                The unique value increases from 1 to the number of objects. The larger the number, the larger
                the object area.
            min_size (int, optional): The minimum size of the objects. Defaults to 0.
            max_size (Optional[int], optional): The maximum size of the objects. Defaults to None.
            output_args (Optional[Dict[str, Any]], optional): Additional arguments for saving the output.
                Defaults to None.
            **kwargs (Any): Other arguments for the mask generator.

        Raises:
            ValueError: If the input source is not a valid, readable path or a numpy array.
        """
        image, path = self._load_image(source)
        if output_args is None:
            output_args = {}

        self.source = path  # None for in-memory arrays
        self.image = image
        result = self.mask_generator(self._pipeline_input(image, path), **kwargs)

        self.output = result
        self.masks = self._format_masks(result)
        self.batch = False
        self._min_size = min_size
        self._max_size = max_size

        # Save the masks. The output is either a binary mask or a mask of objects with unique values.
        self.save_masks(
            output,
            foreground,
            unique,
            erosion_kernel,
            mask_multiplier,
            min_size,
            max_size,
            **output_args,
        )

    def generate_batch(
        self,
        inputs: List[Union[str, np.ndarray]],
        output_dir: Optional[str] = None,
        suffix: str = "_masks",
        foreground: bool = True,
        erosion_kernel: Optional[Tuple[int, int]] = None,
        mask_multiplier: int = 255,
        unique: bool = True,
        min_size: int = 0,
        max_size: Optional[int] = None,
        output_args: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        """
        Generate masks for a batch of input images, saving one output file per input.

        Args:
            inputs (List[Union[str, np.ndarray]]): A list of paths/URLs to input images or numpy arrays.
            output_dir (Optional[str], optional): The directory to save the output masks (created if missing).
                Defaults to the current working directory.
            suffix (str, optional): The suffix appended to the output filenames. Defaults to "_masks".
                Array inputs are saved as ``image_<index><suffix>.tif``.
            foreground (bool, optional): Whether to generate the foreground mask. Defaults to True.
            erosion_kernel (Optional[Tuple[int, int]], optional): The erosion kernel for filtering object masks
                and extracting borders. For example, (3, 3) or (5, 5). Set to None to disable it. Defaults to None.
            mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary
                mask [0, 1]. Defaults to 255.
            unique (bool, optional): Whether to assign a unique value to each object. Defaults to True.
            min_size (int, optional): The minimum size of the objects. Defaults to 0.
            max_size (Optional[int], optional): The maximum size of the objects. Defaults to None.
            output_args (Optional[Dict[str, Any]], optional): Additional arguments for saving the output.
                Defaults to None.
            **kwargs (Any): Other arguments for the mask generator.

        Raises:
            ValueError: If the input list is empty or contains invalid paths.
        """
        if not inputs:
            raise ValueError("The input list is empty.")
        if output_args is None:
            output_args = {}
        if output_dir is None:
            output_dir = os.getcwd()
        os.makedirs(output_dir, exist_ok=True)

        for index, item in enumerate(inputs):
            image, path = self._load_image(item)
            if path is not None:
                stem, file_ext = os.path.splitext(os.path.basename(path))
            else:
                stem, file_ext = f"image_{index}", ".tif"
            filepath = os.path.join(output_dir, f"{stem}{suffix}{file_ext}")

            result = self.mask_generator(self._pipeline_input(image, path), **kwargs)

            # Refresh all per-image state so save_masks() never sees a previous image.
            self.source = path
            self.image = image
            self.output = result
            self.masks = self._format_masks(result)
            self.batch = True
            self._min_size = min_size
            self._max_size = max_size

            self.save_masks(
                filepath,
                foreground,
                unique,
                erosion_kernel,
                mask_multiplier,
                min_size,
                max_size,
                **output_args,
            )

    def save_masks(
        self,
        output: Optional[str] = None,
        foreground: bool = True,
        unique: bool = True,
        erosion_kernel: Optional[Tuple[int, int]] = None,
        mask_multiplier: int = 255,
        min_size: int = 0,
        max_size: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        """
        Save the masks to the output path. The output is either a binary mask or a mask of objects with unique values.

        Args:
            output (Optional[str], optional): The path to the output image. Defaults to None, in which case the
                result is only stored in `SamGeo.objects`.
            foreground (bool, optional): Whether to generate the foreground mask. Defaults to True.
            unique (bool, optional): Whether to assign a unique value to each object (larger area -> larger value).
                Defaults to True.
            erosion_kernel (Optional[Tuple[int, int]], optional): The erosion kernel for extracting borders in
                binary mode, e.g. (3, 3), or a ready-made numpy kernel. Defaults to None.
            mask_multiplier (int, optional): The multiplier for the binary mask. Defaults to 255.
            min_size (int, optional): The minimum size of the objects. Defaults to 0.
            max_size (Optional[int], optional): The maximum size of the objects. Defaults to None.
            **kwargs (Any): Other arguments for `array_to_image()`.

        Raises:
            ValueError: If `generate()` has not been run, or there are no masks and no image to size the output.
        """
        if self.masks is None:
            raise ValueError("No masks found. Please run generate() first.")
        masks = self.masks

        # The masks' own shape is authoritative since they are used as boolean indices.
        if len(masks) > 0:
            h, w = np.asarray(masks[0]["segmentation"]).shape[:2]
        elif self.image is not None:
            h, w = np.asarray(self.image).shape[:2]  # also works for grayscale images
        else:
            raise ValueError("No masks found. Please run generate() first.")

        # Set output image data type based on the number of objects
        if len(masks) < 255:
            dtype = np.uint8
        elif len(masks) < 65535:
            dtype = np.uint16
        else:
            dtype = np.uint32

        def keep(area: int) -> bool:
            # Size filter shared by both output modes.
            if min_size > 0 and area < min_size:
                return False
            return max_size is None or area <= max_size

        if unique:
            # Largest objects are painted first so smaller ones overwrite them and stay visible.
            sorted_masks = sorted(masks, key=lambda x: x["area"], reverse=True)
            objects = np.zeros((h, w))
            count = len(sorted_masks)
            for index, ann in enumerate(sorted_masks):
                if not keep(ann["area"]):
                    continue
                objects[np.asarray(ann["segmentation"], dtype=bool)] = count - index
        else:
            if erosion_kernel is not None and not isinstance(erosion_kernel, np.ndarray):
                # A (rows, cols) tuple such as (3, 3) describes a rectangular structuring element.
                erosion_kernel = np.ones(erosion_kernel, dtype=np.uint8)

            if foreground:  # Extract foreground objects only
                covered = np.zeros((h, w), dtype=bool)
            else:
                covered = np.ones((h, w), dtype=bool)
            borders = np.zeros((h, w), dtype=bool)

            for ann in masks:
                if not keep(ann["area"]):
                    continue
                seg = np.asarray(ann["segmentation"]) > 0
                covered |= seg
                if erosion_kernel is not None:
                    eroded = cv2.erode(seg.astype(np.uint8), erosion_kernel, iterations=1) > 0
                    borders |= seg & ~eroded  # pixels removed by erosion form the border

            objects = (covered & ~borders).astype(dtype) * mask_multiplier

        objects = objects.astype(dtype)
        self.objects = objects

        if output is not None:  # Save the output image
            array_to_image(self.objects, output, self.source, **kwargs)

    def show_masks(
        self,
        figsize: Tuple[int, int] = (12, 10),
        cmap: str = "binary_r",
        axis: str = "off",
        foreground: bool = True,
        **kwargs: Any,
    ) -> None:
        """
        Display the binary mask or the mask of objects with unique values.

        Args:
            figsize (Tuple[int, int], optional): The figure size. Defaults to (12, 10).
            cmap (str, optional): The colormap to use for displaying the mask. Defaults to "binary_r".
            axis (str, optional): Whether to show the axis. Defaults to "off".
            foreground (bool, optional): Whether to show the foreground mask only. Defaults to True.
            **kwargs (Any): Additional arguments for the `save_masks()` method.

        Raises:
            ValueError: If no masks are available and `save_masks()` cannot generate them.
        """
        import matplotlib.pyplot as plt

        # After generate() or generate_batch(), save_masks() has already filled self.objects
        # (for batches: the last image); only compute it when missing.
        if self.objects is None:
            self.save_masks(foreground=foreground, **kwargs)

        plt.figure(figsize=figsize)
        plt.imshow(self.objects, cmap=cmap)
        plt.axis(axis)
        plt.show()

    def show_anns(
        self,
        figsize: Tuple[int, int] = (12, 10),
        axis: str = "off",
        alpha: float = 0.35,
        output: Optional[str] = None,
        blend: bool = True,
        **kwargs: Any,
    ) -> None:
        """
        Show the annotations (objects with random color) on the input image.

        Args:
            figsize (Tuple[int, int], optional): The figure size. Defaults to (12, 10).
            axis (str, optional): Whether to show the axis. Defaults to "off".
            alpha (float, optional): The alpha value for the annotations. Defaults to 0.35.
            output (Optional[str], optional): The path to the output image. Defaults to None.
            blend (bool, optional): Whether to save the input image blended with annotations. Defaults to True.
            **kwargs (Any): Additional arguments for saving the output image.
        """
        import matplotlib.pyplot as plt

        anns = self.masks

        if self.image is None:
            print("Please run generate() first.")
            return

        if anns is None or len(anns) == 0:
            return

        plt.figure(figsize=figsize)
        plt.imshow(self.image)

        sorted_anns = sorted(anns, key=lambda x: x["area"], reverse=True)

        ax = plt.gca()
        ax.set_autoscale_on(False)

        h, w = np.asarray(sorted_anns[0]["segmentation"]).shape[:2]
        img = np.ones((h, w, 4))
        img[:, :, 3] = 0  # fully transparent where no object is drawn

        # Reuse the size filter from the last generate() call, if any.
        min_size = getattr(self, "_min_size", 0) or 0
        max_size = getattr(self, "_max_size", None)
        for ann in sorted_anns:
            if ann["area"] < min_size:
                continue
            if isinstance(max_size, int) and ann["area"] > max_size:
                continue
            m = np.asarray(ann["segmentation"], dtype=bool)
            img[m] = np.concatenate([np.random.random(3), [alpha]])
        ax.imshow(img)

        plt.axis(axis)

        self.annotations = (img[:, :, 0:3] * 255).astype(np.uint8)

        if output is not None:
            if blend:
                array = blend_images(self.annotations, self.image, alpha=alpha, show=False)
            else:
                array = self.annotations
            array_to_image(array, output, self.source, **kwargs)

    def set_image(self, image: Union[str, np.ndarray], **kwargs: Any) -> None:
        """
        Set the input image and compute its embeddings for prompt-based prediction.

        Args:
            image (Union[str, np.ndarray]): The input image, either as a file path/URL or an in-memory image
                (e.g. a numpy array). In-memory images reset `source`, since they carry no georeferencing.
            **kwargs (Any): Additional arguments for `SamModel.get_image_embeddings()`.

        Raises:
            ValueError: If the input image path does not exist.
        """
        if isinstance(image, str):
            if image.startswith("http"):
                image = download_file(image)

            if not os.path.exists(image):
                raise ValueError(f"Input path {image} does not exist.")

            self.source = image
            with Image.open(image) as img:  # close the file once converted
                image = img.convert("RGB")
        else:
            # Avoid georeferencing prompts/outputs against a previously set raster.
            self.source = None

        self.image = image

        inputs = self.processor(image, return_tensors="pt").to(self.device)
        self.embeddings = self.predictor.get_image_embeddings(inputs["pixel_values"], **kwargs)

    def _select_prompt_masks(self, index: Optional[int] = None) -> torch.Tensor:
        """Pick one candidate mask per prompt from the latest `predict()` results.

        Args:
            index (Optional[int]): Candidate index to use for every prompt. None picks, per prompt,
                the candidate with the highest IoU score (index 1, or 0 if only one candidate
                exists, when no scores are stored).

        Returns:
            torch.Tensor: A CPU tensor of shape (num_prompts, H, W).
        """
        masks = self.masks[0] if isinstance(self.masks, (list, tuple)) else self.masks
        masks = torch.as_tensor(masks).detach().cpu()
        # Normalize to (num_prompts, num_candidates, H, W).
        while masks.dim() < 4:
            masks = masks.unsqueeze(0)
        num_prompts, num_candidates = masks.shape[0], masks.shape[1]

        if index is None and self.scores is not None:
            scores = torch.as_tensor(self.scores).detach().cpu()
            scores = scores.reshape(-1, scores.shape[-1])[:num_prompts]
            indices = scores.argmax(dim=-1)
        else:
            if index is None:
                index = 1 if num_candidates > 1 else 0
            indices = torch.full((num_prompts,), int(index), dtype=torch.long)
        return masks[torch.arange(num_prompts), indices]

    def save_prediction(
        self,
        output: str,
        index: Optional[int] = None,
        mask_multiplier: int = 255,
        dtype: np.dtype = np.float32,
        vector: Optional[str] = None,
        simplify_tolerance: Optional[float] = None,
        **kwargs: Any,
    ) -> None:
        """
        Save the predicted mask to the output path.

        Args:
            output (str): The path to the output image.
            index (Optional[int], optional): The candidate mask index to save. Defaults to None,
                which saves the candidate with the highest score for each prompt. Masks of all
                prompts are merged into one layer.
            mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a
                binary mask [0, 1]. Defaults to 255.
            dtype (np.dtype, optional): The data type of the output image. Defaults to np.float32.
            vector (Optional[str], optional): The path to the output vector file. Defaults to None.
            simplify_tolerance (Optional[float], optional): The maximum allowed geometry displacement.
                The higher this value, the fewer vertices in the resulting geometry. Defaults to None.
            **kwargs (Any): Additional arguments for saving the output image.

        Raises:
            ValueError: If no predictions are found.
        """
        if self.scores is None or self.masks is None:
            raise ValueError("No predictions found. Please run predict() first.")

        selected = self._select_prompt_masks(index)
        # Union across prompts: logical OR for binary masks, max for logits.
        if selected.dtype == torch.bool:
            merged = selected.any(dim=0)
        else:
            merged = selected.amax(dim=0)

        array = merged.numpy() * mask_multiplier
        self.prediction = array
        array_to_image(array, output, self.source, dtype=dtype, **kwargs)

        if vector is not None:
            raster_to_vector(output, vector, simplify_tolerance=simplify_tolerance)

    def predict(
        self,
        point_coords=None,
        point_labels=None,
        boxes=None,
        point_crs=None,
        mask_input=None,
        multimask_output=True,
        return_logits=False,
        output=None,
        index=None,
        mask_multiplier=255,
        dtype="float32",
        return_results=False,
        **kwargs,
    ):
        """Predict masks for the given input prompts, using the image set by `set_image()`.

        Args:
            point_coords (str | dict | list | np.ndarray, optional): A Nx2 array of point prompts in (X, Y)
                pixels, a path to a vector file, a GeoJSON dictionary, or a list of [lon, lat] coordinates
                (with `point_crs`). When neither points nor boxes are given, the prompts stored from the
                previous call (or assigned to `point_coords`/`point_labels`) are reused.
            point_labels (list | int | np.ndarray, optional): A length N array of labels for the point
                prompts. 1 indicates a foreground point and 0 a background point. Defaults to all 1.
            boxes (str | dict | list | np.ndarray, optional): One box [x1, y1, x2, y2] or a list of boxes
                in XYXY format, or a vector file path / GeoJSON dictionary.
            point_crs (str, optional): The coordinate reference system (CRS) of the point/box prompts.
            mask_input (np.ndarray, optional): A low resolution mask input. Currently not forwarded to
                the model.
            multimask_output (bool, optional): If true, the model returns three candidate masks per
                prompt; otherwise one.
            return_logits (bool, optional): If true, keep un-thresholded mask logits instead of binary masks.
            output (str, optional): The path to the output image. Defaults to None.
            index (int, optional): The candidate index to save. Defaults to None, which saves the
                candidate with the highest score.
            mask_multiplier (int, optional): The mask multiplier for the output mask.
            dtype (np.dtype, optional): The data type of the output image. Defaults to "float32".
            return_results (bool, optional): Whether to return the predicted masks, scores, and logits.
                Defaults to False.
            **kwargs: Additional arguments for the SAM processor.

        Returns:
            tuple | None: (masks, scores, logits) when `return_results` is True.

        Raises:
            ValueError: If `set_image()` has not been run, or point labels do not match point coordinates.
        """
        if self.embeddings is None:
            raise ValueError("Please run set_image() first.")

        out_of_bounds = []

        if isinstance(boxes, str):
            gdf = gpd.read_file(boxes)
            if gdf.crs is not None:
                gdf = gdf.to_crs("epsg:4326")
            boxes = gdf.geometry.bounds.values.tolist()
        elif isinstance(boxes, dict):
            import json

            geojson = json.dumps(boxes)
            gdf = gpd.read_file(geojson, driver="GeoJSON")
            boxes = gdf.geometry.bounds.values.tolist()
        elif isinstance(boxes, np.ndarray):
            boxes = boxes.tolist()

        if isinstance(point_coords, str):
            point_coords = vector_to_geojson(point_coords)

        if isinstance(point_coords, dict):
            point_coords = geojson_to_coords(point_coords)

        # Reuse stored prompts only when the caller supplied none, so new prompts are never overridden.
        if point_coords is None and boxes is None and getattr(self, "point_coords", None) is not None:
            point_coords = self.point_coords
            if point_labels is None:
                point_labels = getattr(self, "point_labels", None)

        if (point_crs is not None) and (point_coords is not None):
            point_coords, out_of_bounds = coords_to_xy(
                self.source, point_coords, point_crs, return_out_of_bounds=True
            )

        if point_coords is not None:
            point_coords = np.asarray(point_coords)
            if point_coords.ndim == 1:  # a single [x, y] point
                point_coords = point_coords[None, :]

            if isinstance(point_labels, np.ndarray):
                point_labels = point_labels.tolist()
            if point_labels is None:
                point_labels = [1] * len(point_coords)
            elif isinstance(point_labels, (int, np.integer)):
                point_labels = [int(point_labels)] * len(point_coords)

            if isinstance(point_labels, list) and len(point_labels) != len(point_coords):
                if len(point_labels) == 1:
                    point_labels = point_labels * len(point_coords)
                elif len(out_of_bounds) > 0:
                    print(f"Removing {len(out_of_bounds)} out-of-bound points.")
                    point_labels = [
                        p for i, p in enumerate(point_labels) if i not in out_of_bounds
                    ]
                else:
                    raise ValueError(
                        "The length of point_labels must be equal to the length of point_coords."
                    )
            point_labels = np.array(point_labels)

        input_boxes = None
        if isinstance(boxes, list) and len(boxes) > 0:
            if point_crs is not None:
                boxes = bbox_to_xy(self.source, boxes, point_crs)
            input_boxes = np.array(boxes)
            if input_boxes.ndim == 1:  # a single [x1, y1, x2, y2] box
                input_boxes = input_boxes[None, :]

        self.boxes = input_boxes
        self.point_coords = point_coords
        self.point_labels = point_labels

        # SamProcessor expects nested Python lists: boxes [image][box][4],
        # points [image][point_batch][point][2], labels [image][point_batch][point].
        processor_boxes = [input_boxes.tolist()] if input_boxes is not None else None
        processor_points = None
        processor_labels = None
        if point_coords is not None:
            processor_points = [[point_coords.tolist()]]
            processor_labels = [[point_labels.tolist()]]

        inputs = self.processor(
            self.image,
            input_points=processor_points,
            input_labels=processor_labels,
            input_boxes=processor_boxes,
            return_tensors="pt",
            **kwargs,
        ).to(self.device)

        # Use the cached embeddings from set_image() instead of re-encoding the image.
        inputs.pop("pixel_values", None)
        inputs.update({"image_embeddings": self.embeddings})

        with torch.no_grad():
            outputs = self.predictor(**inputs, multimask_output=multimask_output)

        # https://huggingface.co/docs/transformers/en/model_doc/sam#transformers.SamImageProcessor.post_process_masks
        self.masks = self.processor.image_processor.post_process_masks(
            outputs.pred_masks.cpu(),
            inputs["original_sizes"].cpu(),
            inputs["reshaped_input_sizes"].cpu(),
            binarize=not return_logits,
        )
        self.scores = outputs.iou_scores
        self.logits = outputs.pred_masks  # low-resolution mask logits

        if output is not None:
            if self.boxes is not None and len(self.boxes) > 1:
                self.tensor_to_numpy(index, output, mask_multiplier, dtype)
            else:
                self.save_prediction(output, index, mask_multiplier, dtype)

        if return_results:
            return self.masks, self.scores, self.logits
        return None

    def tensor_to_numpy(
        self,
        index: Optional[int] = None,
        output: Optional[str] = None,
        mask_multiplier: int = 255,
        dtype: Union[str, np.dtype] = "uint8",
        save_args: Optional[Dict[str, Any]] = None,
    ) -> Optional[np.ndarray]:
        """
        Merge the predicted per-box masks into one binary numpy array.

        Args:
            index (Optional[int], optional): The candidate mask index to use for every box. Defaults to None,
                which uses the candidate with the highest score for each box.
            output (Optional[str], optional): The path to the output image. Defaults to None.
            mask_multiplier (int, optional): The value assigned to mask pixels. Defaults to 255.
            dtype (Union[str, np.dtype], optional): The data type of the saved image. Defaults to "uint8".
            save_args (Optional[Dict[str, Any]], optional): Optional arguments for saving the output image.
                Defaults to None.

        Returns:
            Optional[np.ndarray]: The merged mask if `output` is None and boxes exist; otherwise None.

        Raises:
            ValueError: If the masks are not available.
        """
        if save_args is None:
            save_args = {}

        if self.masks is None:
            raise ValueError("No masks found. Please run the prediction method first.")

        boxes = self.boxes
        if boxes is None or len(boxes) == 0:  # No "object" instances found
            print("No objects found in the image.")
            return None

        selected = self._select_prompt_masks(index)
        # Accumulate with logical OR so overlapping masks can never overflow a small dtype.
        combined = np.zeros(tuple(selected.shape[-2:]), dtype=bool)
        for mask in selected[: len(boxes)]:
            combined |= mask.numpy() > 0

        mask_overlay = combined * mask_multiplier  # Binary mask scaled to mask_multiplier

        if output is not None:
            array_to_image(mask_overlay, output, self.source, dtype=dtype, **save_args)
            return None
        return mask_overlay