geoai-py 0.19.0__py2.py3-none-any.whl → 0.21.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
geoai/auto.py ADDED
@@ -0,0 +1,1982 @@
1
+ """Auto classes for geospatial model inference with GeoTIFF support.
2
+
3
+ This module provides AutoGeoModel and AutoGeoImageProcessor that extend
4
+ Hugging Face transformers' AutoModel and AutoImageProcessor to support
5
+ processing geospatial data (GeoTIFF) and saving outputs as GeoTIFF or vector data.
6
+
7
+ Supported tasks:
8
+ - Semantic segmentation (e.g., SegFormer, Mask2Former)
9
+ - Image classification (e.g., ViT, ResNet)
10
+ - Object detection (e.g., DETR, YOLOS)
11
+ - Zero-shot object detection (e.g., Grounding DINO, OWL-ViT)
12
+ - Depth estimation (e.g., Depth Anything, DPT)
13
+ - Mask generation (e.g., SAM)
14
+
15
+ Example:
16
+ >>> from geoai import AutoGeoModel
17
+ >>> model = AutoGeoModel.from_pretrained(
18
+ ... "nvidia/segformer-b0-finetuned-ade-512-512",
19
+ ... task="semantic-segmentation"
20
+ ... )
21
+ >>> result = model.predict("input.tif", output_path="output.tif")
22
+ """
23
+
24
+ import os
25
+ from typing import Any, Dict, List, Optional, Tuple, Union
26
+
27
+ import cv2
28
+ import geopandas as gpd
29
+ import numpy as np
30
+ import rasterio
31
+ import requests
32
+ import torch
33
+ from PIL import Image
34
+ from rasterio.features import shapes
35
+ from rasterio.windows import Window
36
+ from shapely.geometry import box, shape
37
+ from tqdm import tqdm
38
+ from transformers import (
39
+ AutoConfig,
40
+ AutoImageProcessor,
41
+ AutoModel,
42
+ AutoModelForImageClassification,
43
+ AutoModelForImageSegmentation,
44
+ AutoModelForSemanticSegmentation,
45
+ AutoModelForUniversalSegmentation,
46
+ AutoModelForDepthEstimation,
47
+ AutoModelForMaskGeneration,
48
+ AutoModelForObjectDetection,
49
+ AutoModelForZeroShotObjectDetection,
50
+ AutoProcessor,
51
+ )
52
+
53
+ from transformers.utils import logging as hf_logging
54
+
55
+ from .utils import get_device
56
+
57
+
58
# Only surface transformers errors; suppresses the load reports HF prints
# while weights are being fetched/initialised.
hf_logging.set_verbosity(hf_logging.ERROR)
59
+
60
+
61
class AutoGeoImageProcessor:
    """
    Image processor for geospatial data that wraps AutoImageProcessor.

    This class provides functionality to load and preprocess GeoTIFF images
    while preserving geospatial metadata (CRS, transform, bounds). It wraps
    Hugging Face's AutoImageProcessor and adds geospatial capabilities.

    Use `from_pretrained` to instantiate this class, following the transformers pattern.

    Attributes:
        processor: The underlying AutoImageProcessor (or AutoProcessor) instance.
        processor_name: Name or path the processor was loaded from, if given.
        device (str): The device being used (explicit, or from ``get_device()``).

    Example:
        >>> processor = AutoGeoImageProcessor.from_pretrained("facebook/sam-vit-base")
        >>> data, metadata = processor.load_geotiff("input.tif")
        >>> inputs = processor(data)
    """

    def __init__(
        self,
        processor: "AutoImageProcessor",
        processor_name: Optional[str] = None,
        device: Optional[str] = None,
    ) -> None:
        """Initialize the AutoGeoImageProcessor with an existing processor.

        Note: Use `from_pretrained` class method to load from Hugging Face Hub.

        Args:
            processor: An AutoImageProcessor instance.
            processor_name: Name or path of the processor (for reference).
            device: Device to use ('cuda', 'cpu'). If None, auto-detect.
        """
        self.processor = processor
        self.processor_name = processor_name
        self.device = get_device() if device is None else device

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        device: Optional[str] = None,
        use_full_processor: bool = False,
        **kwargs: Any,
    ) -> "AutoGeoImageProcessor":
        """Load an AutoGeoImageProcessor from a pretrained processor.

        This method wraps AutoImageProcessor.from_pretrained and adds
        geospatial capabilities for processing GeoTIFF files.

        Args:
            pretrained_model_name_or_path: Hugging Face model/processor name or local path.
                Can be a model ID from huggingface.co or a local directory path.
            device: Device to use ('cuda', 'cpu'). If None, auto-detect.
            use_full_processor: If True, use AutoProcessor instead of AutoImageProcessor.
                Required for models that need text inputs (e.g., Grounding DINO).
            **kwargs: Additional arguments passed to the processor's from_pretrained.
                Common options include:
                - trust_remote_code (bool): Whether to trust remote code.
                - revision (str): Specific model version to use.
                - use_fast (bool): Whether to use fast tokenizer.

        Returns:
            AutoGeoImageProcessor instance with geospatial support.

        Example:
            >>> processor = AutoGeoImageProcessor.from_pretrained("facebook/sam-vit-base")
            >>> processor = AutoGeoImageProcessor.from_pretrained(
            ...     "nvidia/segformer-b0-finetuned-ade-512-512",
            ...     device="cuda"
            ... )
        """
        # Name-based heuristic for models whose processor also handles text.
        model_name_lower = pretrained_model_name_or_path.lower()
        needs_full_processor = use_full_processor or any(
            name in model_name_lower
            for name in ["grounding-dino", "owl", "clip", "blip"]
        )

        if needs_full_processor:
            processor = AutoProcessor.from_pretrained(
                pretrained_model_name_or_path, **kwargs
            )
        else:
            try:
                processor = AutoImageProcessor.from_pretrained(
                    pretrained_model_name_or_path, **kwargs
                )
            except Exception:
                # Some repos only ship a combined processor config.
                processor = AutoProcessor.from_pretrained(
                    pretrained_model_name_or_path, **kwargs
                )
        return cls(
            processor=processor,
            processor_name=pretrained_model_name_or_path,
            device=device,
        )

    def load_geotiff(
        self,
        source: Union[str, "rasterio.DatasetReader"],
        window: Optional[Window] = None,
        bands: Optional[List[int]] = None,
    ) -> Tuple[np.ndarray, Dict]:
        """Load a GeoTIFF file and return data with metadata.

        Args:
            source: Path to GeoTIFF file or open rasterio DatasetReader. A path
                is opened and closed here; an open dataset is left open.
            window: Optional rasterio Window for reading a subset.
            bands: List of band indices to read (1-indexed). If None or empty,
                read all bands.

        Returns:
            Tuple of (image array in CHW format, metadata dict with keys
            "profile", "crs", "transform", "bounds", "width", "height", "count").

        Example:
            >>> processor = AutoGeoImageProcessor.from_pretrained("facebook/sam-vit-base")
            >>> data, metadata = processor.load_geotiff("input.tif")
            >>> print(data.shape)  # (C, H, W)
            >>> print(metadata['crs'])  # CRS info
        """
        should_close = isinstance(source, str)
        src = rasterio.open(source) if should_close else source

        try:
            data = src.read(bands, window=window) if bands else src.read(window=window)

            # Describe the window (not the full raster) when one was requested.
            profile = src.profile.copy()
            if window is not None:
                profile.update(
                    {
                        "height": window.height,
                        "width": window.width,
                        "transform": src.window_transform(window),
                    }
                )

            metadata = {
                "profile": profile,
                "crs": src.crs,
                "transform": profile["transform"],
                "bounds": (
                    src.bounds
                    if window is None
                    else rasterio.windows.bounds(window, src.transform)
                ),
                "width": profile["width"],
                "height": profile["height"],
                "count": data.shape[0],
            }
        finally:
            if should_close:
                src.close()

        return data, metadata

    def load_image(
        self,
        source: Union[str, np.ndarray, Image.Image],
        window: Optional[Window] = None,
        bands: Optional[List[int]] = None,
    ) -> Tuple[np.ndarray, Optional[Dict]]:
        """Load an image from various sources.

        Paths that rasterio can open are read as rasters when they have a CRS
        or a .tif/.tiff extension; other paths are read with PIL as RGB.

        Args:
            source: Path to image file, numpy array, or PIL Image.
            window: Optional rasterio Window (only for GeoTIFF).
            bands: List of band indices (only for GeoTIFF, 1-indexed).

        Returns:
            Tuple of (image array in CHW format, metadata dict or None).

        Raises:
            TypeError: If ``source`` is not a str, ndarray or PIL Image.
        """
        if isinstance(source, str):
            try:
                with rasterio.open(source) as src:
                    if src.crs is not None or source.lower().endswith(
                        (".tif", ".tiff")
                    ):
                        # Reuse the open dataset instead of reopening the file.
                        return self.load_geotiff(src, window, bands)
            except rasterio.errors.RasterioIOError:
                # Not readable by GDAL; fall back to loading as a regular image.
                pass

            image = Image.open(source).convert("RGB")
            return np.array(image).transpose(2, 0, 1), None  # HWC -> CHW

        if isinstance(source, np.ndarray):
            if source.ndim == 2:
                source = source[np.newaxis, :, :]
            elif source.ndim == 3 and source.shape[2] in [1, 3, 4]:
                # Heuristic: a small last axis means HWC input.
                source = source.transpose(2, 0, 1)
            return source, None

        if isinstance(source, Image.Image):
            return np.array(source.convert("RGB")).transpose(2, 0, 1), None

        raise TypeError(f"Unsupported source type: {type(source)}")

    @staticmethod
    def _to_hwc_rgb(data: np.ndarray) -> np.ndarray:
        """Return ``data`` (CHW or 2D) as an HWC array with three channels.

        2D and single-band input is replicated to three channels and input with
        more than three bands keeps only the first three. Two-band input is
        returned unchanged. The result may be a view of ``data``.
        """
        img = data.transpose(1, 2, 0) if data.ndim == 3 else data
        if img.ndim == 2:
            return np.stack([img] * 3, axis=-1)
        if img.shape[-1] == 1:
            return np.repeat(img, 3, axis=-1)
        if img.shape[-1] > 3:
            return img[..., :3]
        return img

    @staticmethod
    def _percentile_stretch(
        img: np.ndarray, low: float = 2.0, high: float = 98.0
    ) -> np.ndarray:
        """Stretch each channel of an HWC array to uint8 via percentile clipping.

        Each channel is rescaled so its ``low``..``high`` percentile range maps
        to 0..255. Constant channels and NaN pixels become 0.
        """
        # Work on a float copy. Writing [0, 1] fractions back into an integer
        # array would truncate them to 0/1. Writing into ``img`` itself would
        # also modify the caller's array through the transpose view.
        img = img.astype(np.float64)
        out = np.zeros(img.shape, dtype=np.float64)
        for i in range(img.shape[-1]):
            band = img[..., i]
            lo, hi = np.nanpercentile(band, [low, high])
            if hi > lo:
                out[..., i] = np.nan_to_num(
                    np.clip((band - lo) / (hi - lo), 0, 1), nan=0.0
                )
        return (out * 255).astype(np.uint8)

    @staticmethod
    def _minmax_stretch(img: np.ndarray) -> np.ndarray:
        """Linearly rescale the whole array's min..max range to uint8 0..255.

        Constant (or all-NaN) input yields zeros instead of dividing by zero.
        """
        img = img.astype(np.float64)
        lo, hi = np.nanmin(img), np.nanmax(img)
        if not hi > lo:
            return np.zeros(img.shape, dtype=np.uint8)
        scaled = np.nan_to_num((img - lo) / (hi - lo), nan=0.0)
        return (scaled * 255).astype(np.uint8)

    def _to_pil_rgb(
        self,
        data: np.ndarray,
        normalize: bool = True,
        percentile_clip: bool = True,
    ) -> Image.Image:
        """Convert a CHW/2D array into a PIL image suitable for HF processors.

        uint8 input is treated as display-ready and passed through unchanged.
        Other dtypes are stretched to uint8 when ``normalize`` is True.
        """
        img = self._to_hwc_rgb(data)
        if normalize and img.dtype != np.uint8:
            img = (
                self._percentile_stretch(img)
                if percentile_clip
                else self._minmax_stretch(img)
            )
        return Image.fromarray(img)

    def prepare_for_model(
        self,
        data: np.ndarray,
        normalize: bool = True,
        to_rgb: bool = True,
        percentile_clip: bool = True,
        return_tensors: str = "pt",
    ) -> Dict[str, Any]:
        """Prepare image data for model input.

        Args:
            data: Image array in CHW (or 2D) format. Not modified.
            normalize: Whether to rescale non-uint8 pixel values to uint8.
                uint8 data is passed through unchanged.
            to_rgb: Accepted for API compatibility; the image is currently
                always converted to three channels.
            percentile_clip: Use per-band 2-98 percentile clipping (True) or a
                global min-max stretch (False) when normalizing.
            return_tensors: Return format ('pt' for PyTorch, 'np' for numpy).

        Returns:
            Dictionary with processed inputs ready for model.
        """
        pil_image = self._to_pil_rgb(
            data, normalize=normalize, percentile_clip=percentile_clip
        )
        return self.processor(images=pil_image, return_tensors=return_tensors)

    def __call__(
        self,
        images: Union[str, np.ndarray, Image.Image, List],
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Process images for model input.

        Each image is loaded with ``load_image``, converted to three channels
        and, if not already uint8, percentile-stretched to uint8.

        Args:
            images: Single image or list of images (paths, arrays, or PIL Images).
            **kwargs: Additional arguments passed to the processor.

        Returns:
            Processed inputs ready for model.
        """
        if isinstance(images, (str, np.ndarray, Image.Image)):
            images = [images]

        pil_images = [self._to_pil_rgb(self.load_image(img)[0]) for img in images]
        return self.processor(images=pil_images, **kwargs)

    def save_geotiff(
        self,
        data: np.ndarray,
        output_path: str,
        metadata: Dict,
        dtype: Optional[str] = None,
        compress: str = "lzw",
        nodata: Optional[float] = None,
    ) -> str:
        """Save array as GeoTIFF with geospatial metadata.

        Args:
            data: Array to save (2D or 3D in CHW format). Boolean arrays are
                written as uint8, since GDAL has no boolean band type.
            output_path: Output file path. Parent directories are created.
            metadata: Metadata dictionary from load_geotiff.
            dtype: Output data type. If None, infer from data.
            compress: Compression method.
            nodata: NoData value. If None, the source profile's value is kept.

        Returns:
            Path to saved file.
        """
        if data.dtype == bool:
            data = data.astype(np.uint8)

        profile = metadata["profile"].copy()
        if dtype is None:
            dtype = str(data.dtype)

        if data.ndim == 2:
            count = 1
            height, width = data.shape
        else:
            count = data.shape[0]
            height, width = data.shape[1], data.shape[2]

        profile.update(
            {
                # The source may have been another format (e.g. JP2); always
                # write a GeoTIFF, as the method name promises.
                "driver": "GTiff",
                "dtype": dtype,
                "count": count,
                "height": height,
                "width": width,
                "compress": compress,
            }
        )
        if nodata is not None:
            profile["nodata"] = nodata

        os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)

        with rasterio.open(output_path, "w", **profile) as dst:
            if data.ndim == 2:
                dst.write(data, 1)
            else:
                dst.write(data)

        return output_path
451
+
452
+
453
+ class AutoGeoModel:
454
+ """
455
+ Auto model for geospatial inference with GeoTIFF support.
456
+
457
+ This class wraps Hugging Face transformers' AutoModel classes and adds
458
+ geospatial capabilities for processing GeoTIFF images and saving results
459
+ as georeferenced outputs (GeoTIFF or vector data).
460
+
461
+ Use `from_pretrained` to instantiate this class, following the transformers pattern.
462
+
463
+ Attributes:
464
+ model: The underlying transformers model instance.
465
+ processor: AutoGeoImageProcessor for preprocessing.
466
+ device (str): The device being used ('cuda' or 'cpu').
467
+ task (str): The task type.
468
+ tile_size (int): Size of tiles for processing large images.
469
+ overlap (int): Overlap between tiles.
470
+
471
+ Example:
472
+ >>> model = AutoGeoModel.from_pretrained(
473
+ ... "facebook/sam-vit-base",
474
+ ... task="mask-generation"
475
+ ... )
476
+ >>> result = model.predict("input.tif", output_path="output.tif")
477
+ """
478
+
479
+ TASK_MODEL_MAPPING = {
480
+ "segmentation": AutoModelForSemanticSegmentation,
481
+ "semantic-segmentation": AutoModelForSemanticSegmentation,
482
+ "image-segmentation": AutoModelForImageSegmentation,
483
+ "universal-segmentation": AutoModelForUniversalSegmentation,
484
+ "depth-estimation": AutoModelForDepthEstimation,
485
+ "mask-generation": AutoModelForMaskGeneration,
486
+ "object-detection": AutoModelForObjectDetection,
487
+ "zero-shot-object-detection": AutoModelForZeroShotObjectDetection,
488
+ "classification": AutoModelForImageClassification,
489
+ "image-classification": AutoModelForImageClassification,
490
+ }
491
+
492
+ def __init__(
493
+ self,
494
+ model: torch.nn.Module,
495
+ processor: Optional["AutoGeoImageProcessor"] = None,
496
+ model_name: Optional[str] = None,
497
+ task: Optional[str] = None,
498
+ device: Optional[str] = None,
499
+ tile_size: int = 1024,
500
+ overlap: int = 128,
501
+ ) -> None:
502
+ """Initialize AutoGeoModel with an existing model.
503
+
504
+ Note: Use `from_pretrained` class method to load from Hugging Face Hub.
505
+
506
+ Args:
507
+ model: A transformers model instance.
508
+ processor: An AutoGeoImageProcessor instance (optional).
509
+ model_name: Name or path of the model (for reference).
510
+ task: Task type for the model.
511
+ device: Device to use ('cuda', 'cpu'). If None, auto-detect.
512
+ tile_size: Size of tiles for processing large images.
513
+ overlap: Overlap between tiles.
514
+ """
515
+ self.model = model
516
+ self.processor = processor
517
+ self.model_name = model_name
518
+ self.task = task
519
+ self.tile_size = tile_size
520
+ self.overlap = overlap
521
+
522
+ if device is None:
523
+ self.device = get_device()
524
+ else:
525
+ self.device = device
526
+
527
+ # Ensure model is on the correct device and in eval mode
528
+ self.model = self.model.to(self.device)
529
+ self.model.eval()
530
+
531
+ @classmethod
532
+ def from_pretrained(
533
+ cls,
534
+ pretrained_model_name_or_path: str,
535
+ task: Optional[str] = None,
536
+ device: Optional[str] = None,
537
+ tile_size: int = 1024,
538
+ overlap: int = 128,
539
+ **kwargs: Any,
540
+ ) -> "AutoGeoModel":
541
+ """Load an AutoGeoModel from a pretrained model.
542
+
543
+ This method wraps transformers' AutoModel.from_pretrained and adds
544
+ geospatial capabilities for processing GeoTIFF files.
545
+
546
+ Args:
547
+ pretrained_model_name_or_path: Hugging Face model name or local path.
548
+ Can be a model ID from huggingface.co or a local directory path.
549
+ task: Task type for automatic model class selection. Options:
550
+ - 'segmentation' or 'semantic-segmentation': Semantic segmentation
551
+ - 'image-segmentation': General image segmentation
552
+ - 'universal-segmentation': Universal segmentation (Mask2Former, etc.)
553
+ - 'depth-estimation': Depth estimation
554
+ - 'mask-generation': Mask generation (SAM, etc.)
555
+ - 'object-detection': Object detection
556
+ - 'zero-shot-object-detection': Zero-shot object detection
557
+ - 'classification' or 'image-classification': Image classification
558
+ If None, will try to infer from model config.
559
+ device: Device to use ('cuda', 'cpu'). If None, auto-detect.
560
+ tile_size: Size of tiles for processing large images.
561
+ overlap: Overlap between tiles to avoid edge artifacts.
562
+ **kwargs: Additional arguments passed to the model's from_pretrained.
563
+ Common options include:
564
+ - trust_remote_code (bool): Whether to trust remote code.
565
+ - revision (str): Specific model version to use.
566
+ - torch_dtype: Data type for model weights.
567
+
568
+ Returns:
569
+ AutoGeoModel instance with geospatial support.
570
+
571
+ Example:
572
+ >>> model = AutoGeoModel.from_pretrained("facebook/sam-vit-base", task="mask-generation")
573
+ >>> model = AutoGeoModel.from_pretrained(
574
+ ... "nvidia/segformer-b0-finetuned-ade-512-512",
575
+ ... task="semantic-segmentation",
576
+ ... device="cuda"
577
+ ... )
578
+ """
579
+ # Determine device
580
+ if device is None:
581
+ device = get_device()
582
+
583
+ # Load model using appropriate auto class
584
+ model = cls._load_model_from_pretrained(
585
+ pretrained_model_name_or_path, task, **kwargs
586
+ )
587
+
588
+ # Load processor - use full processor for models that need text inputs
589
+ needs_full_processor = task in (
590
+ "zero-shot-object-detection",
591
+ "object-detection",
592
+ )
593
+ try:
594
+ processor = AutoGeoImageProcessor.from_pretrained(
595
+ pretrained_model_name_or_path,
596
+ device=device,
597
+ use_full_processor=needs_full_processor,
598
+ )
599
+ except Exception:
600
+ processor = None
601
+
602
+ instance = cls(
603
+ model=model,
604
+ processor=processor,
605
+ model_name=pretrained_model_name_or_path,
606
+ task=task,
607
+ device=device,
608
+ tile_size=tile_size,
609
+ overlap=overlap,
610
+ )
611
+
612
+ print(f"Model loaded on {device}")
613
+ return instance
614
+
615
+ @classmethod
616
+ def _load_model_from_pretrained(
617
+ cls,
618
+ model_name_or_path: str,
619
+ task: Optional[str] = None,
620
+ **kwargs: Any,
621
+ ) -> torch.nn.Module:
622
+ """Load the appropriate model based on task using from_pretrained."""
623
+ if task and task in cls.TASK_MODEL_MAPPING:
624
+ model_class = cls.TASK_MODEL_MAPPING[task]
625
+ return model_class.from_pretrained(model_name_or_path, **kwargs)
626
+
627
+ # Try to infer from config
628
+ try:
629
+ config = AutoConfig.from_pretrained(model_name_or_path)
630
+ architectures = getattr(config, "architectures", [])
631
+
632
+ if any("Segmentation" in arch for arch in architectures):
633
+ return AutoModelForSemanticSegmentation.from_pretrained(
634
+ model_name_or_path, **kwargs
635
+ )
636
+ elif any("Detection" in arch for arch in architectures):
637
+ return AutoModelForObjectDetection.from_pretrained(
638
+ model_name_or_path, **kwargs
639
+ )
640
+ elif any("Classification" in arch for arch in architectures):
641
+ return AutoModelForImageClassification.from_pretrained(
642
+ model_name_or_path, **kwargs
643
+ )
644
+ else:
645
+ return AutoModel.from_pretrained(model_name_or_path, **kwargs)
646
+ except Exception:
647
+ return AutoModel.from_pretrained(model_name_or_path, **kwargs)
648
+
649
+ def predict(
650
+ self,
651
+ source: Union[str, np.ndarray, Image.Image],
652
+ output_path: Optional[str] = None,
653
+ output_vector_path: Optional[str] = None,
654
+ window: Optional[Window] = None,
655
+ bands: Optional[List[int]] = None,
656
+ threshold: float = 0.5,
657
+ text: Optional[str] = None,
658
+ labels: Optional[List[str]] = None,
659
+ box_threshold: float = 0.3,
660
+ text_threshold: float = 0.25,
661
+ min_object_area: int = 100,
662
+ simplify_tolerance: float = 1.0,
663
+ batch_size: int = 1,
664
+ return_probabilities: bool = False,
665
+ **kwargs: Any,
666
+ ) -> Dict[str, Any]:
667
+ """Run inference on a GeoTIFF or image.
668
+
669
+ Args:
670
+ source: Input image (path, array, or PIL Image). Can also be a URL.
671
+ output_path: Path to save output GeoTIFF.
672
+ output_vector_path: Path to save output vector file (GeoJSON, GPKG, etc.).
673
+ window: Optional rasterio Window for processing a subset.
674
+ bands: List of band indices to use (1-indexed).
675
+ threshold: Threshold for binary masks (segmentation tasks).
676
+ text: Text prompt for zero-shot detection models (e.g., "a cat. a dog.").
677
+ For Grounding DINO, labels should be lowercase and end with a dot.
678
+ labels: List of labels to detect (alternative to text).
679
+ Will be converted to text prompt format automatically.
680
+ box_threshold: Confidence threshold for bounding boxes (detection tasks).
681
+ text_threshold: Text similarity threshold for zero-shot detection.
682
+ min_object_area: Minimum object area in pixels for vectorization.
683
+ simplify_tolerance: Tolerance for polygon simplification.
684
+ batch_size: Batch size for processing tiles.
685
+ return_probabilities: Whether to return probability maps.
686
+ **kwargs: Additional arguments for specific tasks.
687
+
688
+ Returns:
689
+ Dictionary with results including mask/detections, metadata, and optional vector data.
690
+
691
+ Example:
692
+ >>> # Zero-shot object detection
693
+ >>> model = AutoGeoModel.from_pretrained(
694
+ ... "IDEA-Research/grounding-dino-base",
695
+ ... task="zero-shot-object-detection"
696
+ ... )
697
+ >>> result = model.predict(
698
+ ... "image.jpg",
699
+ ... text="a building. a car. a tree.",
700
+ ... box_threshold=0.3
701
+ ... )
702
+ >>> print(result["boxes"], result["labels"])
703
+ """
704
+ # Convert labels list to text format if provided
705
+ if labels is not None and text is None:
706
+ text = " ".join(f"{label.lower().strip()}." for label in labels)
707
+
708
+ # Handle zero-shot object detection separately
709
+ if self.task in ("zero-shot-object-detection", "object-detection"):
710
+ return self._predict_detection(
711
+ source,
712
+ text=text,
713
+ box_threshold=box_threshold,
714
+ text_threshold=text_threshold,
715
+ output_vector_path=output_vector_path,
716
+ **kwargs,
717
+ )
718
+
719
+ # Load image (handles URLs, local files, arrays, PIL Images)
720
+ pil_image = None
721
+ if isinstance(source, str):
722
+ # Check if URL
723
+ if source.startswith(("http://", "https://")):
724
+ pil_image = Image.open(requests.get(source, stream=True).raw)
725
+ metadata = None
726
+ data = np.array(pil_image.convert("RGB")).transpose(2, 0, 1)
727
+ else:
728
+ # Local file - try to load with geospatial info
729
+ if self.processor is not None:
730
+ data, metadata = self.processor.load_image(source, window, bands)
731
+ else:
732
+ try:
733
+ with rasterio.open(source) as src:
734
+ data = src.read(bands) if bands else src.read()
735
+ profile = src.profile.copy()
736
+ metadata = {
737
+ "profile": profile,
738
+ "crs": src.crs,
739
+ "transform": src.transform,
740
+ "bounds": src.bounds,
741
+ "width": src.width,
742
+ "height": src.height,
743
+ }
744
+ except Exception:
745
+ # Fall back to PIL for regular images
746
+ pil_image = Image.open(source).convert("RGB")
747
+ data = np.array(pil_image).transpose(2, 0, 1)
748
+ metadata = None
749
+ elif isinstance(source, Image.Image):
750
+ pil_image = source
751
+ data = np.array(source.convert("RGB")).transpose(2, 0, 1)
752
+ metadata = None
753
+ else:
754
+ data = np.array(source)
755
+ if data.ndim == 3 and data.shape[2] in [1, 3, 4]:
756
+ data = data.transpose(2, 0, 1)
757
+ metadata = None
758
+
759
+ # Check if we need tiled processing (not for classification tasks)
760
+ if data.ndim == 3:
761
+ _, height, width = data.shape
762
+ elif data.ndim == 2:
763
+ height, width = data.shape
764
+ _ = 1
765
+ else:
766
+ raise ValueError(f"Unexpected data shape: {data.shape}")
767
+
768
+ # Classification tasks should not use tiled processing
769
+ use_tiled = (
770
+ height > self.tile_size or width > self.tile_size
771
+ ) and self.task not in ("classification", "image-classification")
772
+
773
+ if use_tiled:
774
+ result = self._predict_tiled(
775
+ source,
776
+ data,
777
+ metadata,
778
+ threshold=threshold,
779
+ batch_size=batch_size,
780
+ return_probabilities=return_probabilities,
781
+ **kwargs,
782
+ )
783
+ else:
784
+ result = self._predict_single(
785
+ data,
786
+ metadata,
787
+ threshold=threshold,
788
+ return_probabilities=return_probabilities,
789
+ **kwargs,
790
+ )
791
+
792
+ # Save GeoTIFF output
793
+ if output_path and metadata:
794
+ self.save_geotiff(
795
+ result.get("mask", result.get("output")),
796
+ output_path,
797
+ metadata,
798
+ nodata=0,
799
+ )
800
+ result["output_path"] = output_path
801
+
802
+ # Save vector output
803
+ if output_vector_path and metadata and "mask" in result:
804
+ gdf = self.mask_to_vector(
805
+ result["mask"],
806
+ metadata,
807
+ threshold=threshold,
808
+ min_object_area=min_object_area,
809
+ simplify_tolerance=simplify_tolerance,
810
+ )
811
+ if gdf is not None and len(gdf) > 0:
812
+ gdf.to_file(output_vector_path)
813
+ result["vector_path"] = output_vector_path
814
+ result["geodataframe"] = gdf
815
+
816
+ return result
817
+
818
+ def _predict_single(
819
+ self,
820
+ data: np.ndarray,
821
+ metadata: Optional[Dict],
822
+ threshold: float = 0.5,
823
+ return_probabilities: bool = False,
824
+ **kwargs: Any,
825
+ ) -> Dict[str, Any]:
826
+ """Run inference on a single image."""
827
+ # Prepare input
828
+ if self.processor is not None:
829
+ inputs = self.processor.prepare_for_model(data)
830
+ else:
831
+ # Fallback preparation
832
+ img = data.transpose(1, 2, 0) if data.ndim == 3 else data
833
+ if img.ndim == 2:
834
+ img = np.stack([img] * 3, axis=-1)
835
+ elif img.shape[-1] > 3:
836
+ img = img[..., :3]
837
+
838
+ if img.dtype != np.uint8:
839
+ img = ((img - img.min()) / (img.max() - img.min()) * 255).astype(
840
+ np.uint8
841
+ )
842
+
843
+ pil_image = Image.fromarray(img)
844
+ inputs = {
845
+ "pixel_values": torch.from_numpy(
846
+ np.array(pil_image).transpose(2, 0, 1) / 255.0
847
+ )
848
+ .float()
849
+ .unsqueeze(0)
850
+ }
851
+
852
+ # Move to device
853
+ inputs = {
854
+ k: v.to(self.device) if isinstance(v, torch.Tensor) else v
855
+ for k, v in inputs.items()
856
+ }
857
+
858
+ # Run inference
859
+ with torch.no_grad():
860
+ outputs = self.model(**inputs)
861
+
862
+ # Process outputs based on model type
863
+ result = self._process_outputs(
864
+ outputs, data.shape, threshold, return_probabilities
865
+ )
866
+ result["metadata"] = metadata
867
+
868
+ return result
869
+
870
+ def _predict_tiled(
871
+ self,
872
+ source: Union[str, np.ndarray],
873
+ data: np.ndarray,
874
+ metadata: Optional[Dict],
875
+ threshold: float = 0.5,
876
+ batch_size: int = 1,
877
+ return_probabilities: bool = False,
878
+ **kwargs: Any,
879
+ ) -> Dict[str, Any]:
880
+ """Run tiled inference for large images."""
881
+ if data.ndim == 3:
882
+ _, height, width = data.shape
883
+ else:
884
+ height, width = data.shape
885
+
886
+ effective_tile_size = self.tile_size - 2 * self.overlap
887
+
888
+ n_tiles_x = max(1, int(np.ceil(width / effective_tile_size)))
889
+ n_tiles_y = max(1, int(np.ceil(height / effective_tile_size)))
890
+ total_tiles = n_tiles_x * n_tiles_y
891
+
892
+ # Initialize output arrays
893
+ mask_output = np.zeros((height, width), dtype=np.float32)
894
+ count_output = np.zeros((height, width), dtype=np.float32)
895
+
896
+ print(f"Processing {total_tiles} tiles ({n_tiles_x}x{n_tiles_y})")
897
+
898
+ with tqdm(total=total_tiles, desc="Processing tiles") as pbar:
899
+ for y in range(n_tiles_y):
900
+ for x in range(n_tiles_x):
901
+ # Calculate tile coordinates
902
+ x_start = max(0, x * effective_tile_size - self.overlap)
903
+ y_start = max(0, y * effective_tile_size - self.overlap)
904
+ x_end = min(width, (x + 1) * effective_tile_size + self.overlap)
905
+ y_end = min(height, (y + 1) * effective_tile_size + self.overlap)
906
+
907
+ # Extract tile
908
+ if data.ndim == 3:
909
+ tile = data[:, y_start:y_end, x_start:x_end]
910
+ else:
911
+ tile = data[y_start:y_end, x_start:x_end]
912
+
913
+ try:
914
+ # Run inference on tile
915
+ tile_result = self._predict_single(
916
+ tile, None, threshold, return_probabilities
917
+ )
918
+
919
+ tile_mask = tile_result.get("mask", tile_result.get("output"))
920
+ if tile_mask is not None:
921
+ # Handle different dimensions
922
+ if tile_mask.ndim > 2:
923
+ tile_mask = tile_mask.squeeze()
924
+ if tile_mask.ndim > 2:
925
+ tile_mask = tile_mask[0]
926
+
927
+ # Resize if necessary
928
+ tile_h, tile_w = y_end - y_start, x_end - x_start
929
+ if tile_mask.shape != (tile_h, tile_w):
930
+ tile_mask = cv2.resize(
931
+ tile_mask.astype(np.float32),
932
+ (tile_w, tile_h),
933
+ interpolation=cv2.INTER_LINEAR,
934
+ )
935
+
936
+ # Add to output with blending
937
+ mask_output[y_start:y_end, x_start:x_end] += tile_mask
938
+ count_output[y_start:y_end, x_start:x_end] += 1
939
+
940
+ except Exception as e:
941
+ print(f"Error processing tile ({x}, {y}): {e}")
942
+
943
+ pbar.update(1)
944
+
945
+ # Average overlapping regions
946
+ count_output = np.maximum(count_output, 1)
947
+ mask_output = mask_output / count_output
948
+
949
+ result = {
950
+ "mask": (mask_output > threshold).astype(np.uint8),
951
+ "probabilities": mask_output if return_probabilities else None,
952
+ "metadata": metadata,
953
+ }
954
+
955
+ return result
956
+
957
def _predict_detection(
    self,
    source: Union[str, np.ndarray, Image.Image],
    text: Optional[str] = None,
    box_threshold: float = 0.3,
    text_threshold: float = 0.25,
    output_vector_path: Optional[str] = None,
    **kwargs: Any,
) -> Dict[str, Any]:
    """Run (zero-shot) object detection inference on a single image.

    Args:
        source: Input image. One of:
            - an http(s) URL;
            - a local path (tried with rasterio first, then PIL);
            - a numpy array (CHW or HWC);
            - a PIL Image.
        text: Text prompt for zero-shot detection (e.g., "a cat. a dog.").
            Required when ``self.task == "zero-shot-object-detection"``.
        box_threshold: Confidence threshold for bounding boxes.
        text_threshold: Text similarity threshold for zero-shot detection.
        output_vector_path: If given, detections are also written to this
            vector file (box coordinates are in pixel space).
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        Dictionary with ``boxes``, ``scores`` and ``labels`` when the
        post-processing produced them. When a vector file was written, it
        also holds ``vector_path`` and ``geodataframe``.

    Raises:
        ValueError: If the task is zero-shot detection and ``text`` is None.
        requests.HTTPError: If downloading a URL source returns an error status.
    """
    # Load image
    if isinstance(source, str):
        if source.startswith(("http://", "https://")):
            # Bounded timeout plus an explicit status check. A dead or
            # erroring URL then fails loudly instead of hanging, or instead
            # of handing an HTML error page to PIL.
            response = requests.get(source, stream=True, timeout=60)
            response.raise_for_status()
            pil_image = Image.open(response.raw)
        else:
            try:
                # Try loading with rasterio for GeoTIFF
                with rasterio.open(source) as src:
                    data = src.read()
                if data.shape[0] > 3:
                    data = data[:3]
                elif data.shape[0] == 1:
                    data = np.repeat(data, 3, axis=0)
                # Per-band 2-98 percentile stretch to uint8
                if data.dtype != np.uint8:
                    for i in range(data.shape[0]):
                        band = data[i].astype(np.float32)
                        p2, p98 = np.percentile(band, [2, 98])
                        if p98 > p2:
                            data[i] = np.clip(
                                (band - p2) / (p98 - p2) * 255, 0, 255
                            )
                        else:
                            data[i] = 0
                    data = data.astype(np.uint8)
                pil_image = Image.fromarray(data.transpose(1, 2, 0))
            except Exception:
                # Not something rasterio can read: fall back to PIL.
                pil_image = Image.open(source)
    elif isinstance(source, np.ndarray):
        arr = source
        if arr.ndim == 3 and arr.shape[0] in [1, 3, 4]:
            arr = arr.transpose(1, 2, 0)
        if arr.ndim == 3 and arr.shape[-1] == 1:
            # PIL cannot build an image from an (H, W, 1) array.
            arr = arr[..., 0]
        if arr.dtype != np.uint8:
            lo, hi = arr.min(), arr.max()
            if hi > lo:
                arr = ((arr - lo) / (hi - lo) * 255).astype(np.uint8)
            else:
                # A constant image (or NaN extrema) would otherwise divide
                # by zero and produce NaNs.
                arr = np.zeros(arr.shape, dtype=np.uint8)
        pil_image = Image.fromarray(arr)
    else:
        pil_image = source

    pil_image = pil_image.convert("RGB")
    image_size = pil_image.size[::-1]  # (height, width)

    # Get the underlying processor
    processor = self.processor.processor if self.processor else None
    if processor is None:
        # Load processor directly if not available
        processor = AutoProcessor.from_pretrained(self.model_name)

    # Prepare inputs based on task type
    if self.task == "zero-shot-object-detection":
        if text is None:
            raise ValueError(
                "Text prompt is required for zero-shot object detection. "
                "Provide text='a cat. a dog.' or labels=['cat', 'dog']"
            )
        inputs = processor(images=pil_image, text=text, return_tensors="pt")
    else:
        # Standard object detection
        inputs = processor(images=pil_image, return_tensors="pt")

    # BatchFeature supports .to() for moving all tensors at once.
    inputs = inputs.to(self.device)

    with torch.no_grad():
        outputs = self.model(**inputs)

    result = self._process_detection_outputs(
        outputs,
        inputs,
        text,
        image_size,
        box_threshold,
        text_threshold,
        processor,
    )

    # Convert to GeoDataFrame if requested
    if output_vector_path and result.get("boxes") is not None:
        gdf = self._detections_to_geodataframe(result, pil_image.size)
        if gdf is not None and len(gdf) > 0:
            # save_vector creates parent directories and picks the driver
            # from the file extension.
            self.save_vector(gdf, output_vector_path)
            result["vector_path"] = output_vector_path
            result["geodataframe"] = gdf

    return result
1067
+
1068
def _process_detection_outputs(
    self,
    outputs: Any,
    inputs: Dict,
    text: Optional[str],
    image_size: Tuple[int, int],
    box_threshold: float,
    text_threshold: float,
    processor: Any = None,
) -> Dict[str, Any]:
    """Convert raw detection model outputs into boxes, scores and labels.

    Args:
        outputs: Raw model output object.
        inputs: Processor inputs. ``inputs["input_ids"]`` is read for
            zero-shot (grounded) post-processing.
        text: The text prompt. Unused here; kept for interface compatibility.
        image_size: Original image size as ``(height, width)``.
        box_threshold: Minimum box confidence to keep.
        text_threshold: Text similarity threshold for grounded detection.
        processor: Hugging Face processor. Defaults to
            ``self.processor.processor`` when available.

    Returns:
        Dict with ``boxes`` (array of pixel ``[x1, y1, x2, y2]`` rows),
        ``scores`` and ``labels``. Empty when nothing could be post-processed.
    """
    result: Dict[str, Any] = {}

    if processor is None:
        processor = self.processor.processor if self.processor else None

    if self.task == "zero-shot-object-detection":
        try:
            post_process = processor.post_process_grounded_object_detection
            try:
                results = post_process(
                    outputs,
                    inputs["input_ids"],
                    threshold=box_threshold,  # box confidence threshold
                    text_threshold=text_threshold,
                    target_sizes=[image_size],
                )
            except TypeError:
                # Older transformers releases call this argument
                # ``box_threshold`` instead of ``threshold``.
                results = post_process(
                    outputs,
                    inputs["input_ids"],
                    box_threshold=box_threshold,
                    text_threshold=text_threshold,
                    target_sizes=[image_size],
                )
            if results and len(results) > 0:
                r = results[0]
                result["boxes"] = r["boxes"].cpu().numpy()
                result["scores"] = r["scores"].cpu().numpy()
                # Prefer the string labels. Newer transformers versions
                # expose them as "text_labels" and deprecate "labels".
                if "text_labels" in r:
                    result["labels"] = r["text_labels"]
                elif "labels" in r:
                    result["labels"] = r["labels"]
                else:
                    result["labels"] = [
                        f"object_{i}" for i in range(len(r["boxes"]))
                    ]
        except Exception as e:
            # Fallback for models without grounded post-processing
            print(f"Warning: Using fallback detection processing: {e}")
            if hasattr(outputs, "pred_boxes"):
                boxes = outputs.pred_boxes[0].cpu().numpy()
                logits = outputs.logits[0].cpu()
                scores = logits.sigmoid().max(dim=-1).values.numpy()
                keep = scores > box_threshold
                # Raw pred_boxes of DETR-style models are normalized
                # (cx, cy, w, h). Convert them to the pixel [x1, y1, x2, y2]
                # format that the grounded path returns and that
                # downstream consumers expect.
                # NOTE(review): models that pad the input before
                # inference may need an extra rescale; confirm per model.
                height, width = image_size
                cx, cy, bw, bh = boxes[keep].T
                result["boxes"] = np.stack(
                    [
                        (cx - bw / 2) * width,
                        (cy - bh / 2) * height,
                        (cx + bw / 2) * width,
                        (cy + bh / 2) * height,
                    ],
                    axis=-1,
                )
                result["scores"] = scores[keep]
                result["labels"] = [
                    f"object_{i}" for i in range(len(result["boxes"]))
                ]
    else:
        # Standard object detection post-processing
        if hasattr(outputs, "pred_boxes") and processor is not None:
            target_sizes = torch.tensor([image_size], device=self.device)
            results = processor.post_process_object_detection(
                outputs, threshold=box_threshold, target_sizes=target_sizes
            )
            if results and len(results) > 0:
                r = results[0]
                result["boxes"] = r["boxes"].cpu().numpy()
                result["scores"] = r["scores"].cpu().numpy()
                result["labels"] = r["labels"].cpu().numpy()

    return result
1135
+
1136
def _detections_to_geodataframe(
    self,
    detections: Dict[str, Any],
    image_size: Tuple[int, int],
) -> Optional[gpd.GeoDataFrame]:
    """Build a GeoDataFrame of box polygons from detection results.

    Boxes are used as given, so coordinates stay in pixel space and no CRS is
    attached. ``image_size`` is accepted but not used.

    Args:
        detections: Mapping with ``boxes`` (one ``[x1, y1, x2, y2]`` per
            entry) and optional ``scores`` / ``labels`` sequences.
        image_size: Image size (unused).

    Returns:
        GeoDataFrame with ``geometry``, ``score`` and ``label`` columns, or
        None when there are no boxes.
    """
    bboxes = detections.get("boxes")
    if bboxes is None or len(bboxes) == 0:
        return None

    n_boxes = len(bboxes)
    score_column = detections.get("scores", [None] * n_boxes)
    label_column = detections.get("labels", [None] * n_boxes)

    # Each [x1, y1, x2, y2] row becomes an axis-aligned rectangle.
    rectangles = [box(x1, y1, x2, y2) for x1, y1, x2, y2 in bboxes]

    return gpd.GeoDataFrame(
        {
            "geometry": rectangles,
            "score": score_column,
            "label": label_column,
        }
    )
1167
+
1168
def _process_outputs(
    self,
    outputs: Any,
    input_shape: Tuple,
    threshold: float = 0.5,
    return_probabilities: bool = False,
) -> Dict[str, Any]:
    """Turn a raw model output into masks, classes, depth or features.

    The branch is chosen by which attribute ``outputs`` exposes, checked in
    this order: ``logits``, ``pred_masks``, ``predicted_depth``, ``masks``.
    Anything else yields ``features`` (from ``last_hidden_state``) or the raw
    ``output``.

    Args:
        outputs: Model output object.
        input_shape: Shape of the model input image. Only the last two
            entries (height, width) are used, so ``(C, H, W)``, ``(H, W)`` and
            ``(N, C, H, W)`` are all accepted.
        threshold: Threshold used to binarize ``pred_masks`` / ``masks``.
        return_probabilities: Also return per-class probabilities
            (segmentation logits) or the un-thresholded mask (``pred_masks``).

    Returns:
        Dict whose keys depend on the output type:
            - segmentation: ``mask``, plus ``probabilities`` if requested;
            - classification: ``class`` and ``probabilities``;
            - depth: ``output`` and ``depth``;
            - SAM-like outputs: ``mask`` and ``all_masks``.
    """
    result: Dict[str, Any] = {}

    if hasattr(outputs, "logits"):
        logits = outputs.logits
        if logits.dim() == 4:  # Segmentation output
            # Spatial size is always the trailing two entries of input_shape.
            target_hw = tuple(int(s) for s in input_shape[-2:])
            if tuple(logits.shape[2:]) != target_hw:
                logits = torch.nn.functional.interpolate(
                    logits,
                    size=target_hw,
                    mode="bilinear",
                    align_corners=False,
                )

            probs = torch.softmax(logits, dim=1)
            mask = probs.argmax(dim=1).squeeze().cpu().numpy()
            result["mask"] = mask.astype(np.uint8)

            if return_probabilities:
                result["probabilities"] = probs.squeeze().cpu().numpy()

        elif logits.dim() == 2:  # Classification output
            probs = torch.softmax(logits, dim=-1)
            pred_class = probs.argmax(dim=-1).item()
            result["class"] = pred_class
            result["probabilities"] = probs.squeeze().cpu().numpy()

    elif hasattr(outputs, "pred_masks"):
        masks = outputs.pred_masks.squeeze().cpu().numpy()
        # Collapse multiple masks into one by taking the per-pixel maximum.
        mask = masks.max(axis=0) if masks.ndim == 3 else masks
        result["mask"] = (mask > threshold).astype(np.uint8)

        if return_probabilities:
            result["probabilities"] = mask

    elif hasattr(outputs, "predicted_depth"):
        depth = outputs.predicted_depth.squeeze().cpu().numpy()
        result["output"] = depth
        result["depth"] = depth

    elif hasattr(outputs, "masks"):
        # SAM-like output
        masks = outputs.masks
        if isinstance(masks, torch.Tensor):
            masks = masks.cpu().numpy()
        if masks.ndim == 4:
            masks = masks.squeeze(0)
        mask = masks.max(axis=0) if masks.ndim == 3 else masks
        result["mask"] = (mask > threshold).astype(np.uint8)
        result["all_masks"] = masks

    else:
        # Generic output handling
        if hasattr(outputs, "last_hidden_state"):
            result["features"] = outputs.last_hidden_state.cpu().numpy()
        else:
            result["output"] = outputs

    return result
1242
+
1243
def mask_to_vector(
    self,
    mask: np.ndarray,
    metadata: Dict,
    threshold: float = 0.5,
    min_object_area: int = 100,
    max_object_area: Optional[int] = None,
    simplify_tolerance: float = 1.0,
) -> Optional[gpd.GeoDataFrame]:
    """Convert a raster mask to vector polygons.

    Floating-point masks are binarized with ``mask > threshold``. Masks of any
    other dtype use ``mask > 0``. All foreground pixels therefore share the
    value 1, so distinct labels in a multi-class mask are merged.

    Args:
        mask: Binary or probability mask array.
        metadata: Geospatial metadata dictionary; ``crs`` and ``transform``
            are read from it.
        threshold: Threshold for binarizing floating-point masks.
        min_object_area: Minimum area in pixels for valid objects.
        max_object_area: Maximum area in pixels (optional).
        simplify_tolerance: Simplification tolerance in pixels. It is
            converted to map units using the pixel width.

    Returns:
        GeoDataFrame with ``geometry`` and ``class`` columns in the mask's
        CRS. Returns None when CRS/transform are missing, vectorization fails,
        or no polygon passes the filters.
    """
    if metadata is None or metadata.get("crs") is None:
        print("Warning: No CRS information available for vectorization")
        return None

    # Ensure binary mask
    if np.issubdtype(mask.dtype, np.floating):
        mask = (mask > threshold).astype(np.uint8)
    else:
        mask = (mask > 0).astype(np.uint8)

    transform = metadata.get("transform")
    crs = metadata.get("crs")

    if transform is None:
        print("Warning: No transform available for vectorization")
        return None

    # Area of one pixel in map units: the absolute determinant of the affine
    # matrix. It stays positive for north-up, flipped (negative ``a``) and
    # rotated/sheared transforms alike.
    pixel_area_map_units = abs(
        transform.a * transform.e - transform.b * transform.d
    )
    if pixel_area_map_units == 0:
        print("Warning: Degenerate transform (zero pixel area); cannot vectorize")
        return None

    polygons = []
    values = []

    try:
        for geom, value in shapes(mask, transform=transform):
            if value <= 0:  # Only keep non-background
                continue
            poly = shape(geom)

            # Filter by area (expressed in pixels)
            pixel_area = poly.area / pixel_area_map_units
            if pixel_area < min_object_area:
                continue
            if max_object_area and pixel_area > max_object_area:
                continue

            if simplify_tolerance > 0:
                poly = poly.simplify(
                    simplify_tolerance * abs(transform.a),
                    preserve_topology=True,
                )

            if poly.is_valid and not poly.is_empty:
                polygons.append(poly)
                values.append(value)

    except Exception as e:
        print(f"Error during vectorization: {e}")
        return None

    if not polygons:
        return None

    return gpd.GeoDataFrame(
        {"geometry": polygons, "class": values},
        crs=crs,
    )
1324
+
1325
def save_geotiff(
    self,
    data: np.ndarray,
    output_path: str,
    metadata: Dict,
    dtype: Optional[str] = None,
    compress: str = "lzw",
    nodata: Optional[float] = None,
) -> str:
    """Save an array as a GeoTIFF using the georeferencing in ``metadata``.

    Args:
        data: Array to save, 2D (H, W) or 3D (C, H, W). Boolean arrays are
            written as uint8 0/1.
        output_path: Output file path. Parent directories are created.
        metadata: Metadata dictionary. Its ``profile`` entry (a rasterio
            profile) is copied and used as the base profile.
        dtype: Output data type. If None, inferred from ``data``.
        compress: Compression method.
        nodata: NoData value for the output. If None, the profile's nodata is
            kept only when it can be represented in the output dtype.

    Returns:
        Path to saved file.
    """

    def _representable(value: float, dtype_name: str) -> bool:
        """Return True if ``value`` fits in the range of ``dtype_name``."""
        try:
            np_dtype = np.dtype(dtype_name)
        except TypeError:
            return True  # Not a numpy dtype (GDAL-only type): don't interfere.
        if np.issubdtype(np_dtype, np.integer):
            info = np.iinfo(np_dtype)
            return float(value).is_integer() and info.min <= value <= info.max
        if np.issubdtype(np_dtype, np.floating):
            info = np.finfo(np_dtype)
            return (not np.isfinite(value)) or info.min <= value <= info.max
        return True

    profile = metadata["profile"].copy()

    if data.dtype == np.bool_:
        # GDAL has no boolean band type; store masks as 0/1 bytes.
        data = data.astype(np.uint8)

    if dtype is None:
        dtype = str(data.dtype)

    if data.ndim == 2:
        count = 1
        height, width = data.shape
    else:
        count = data.shape[0]
        height, width = data.shape[1], data.shape[2]

    profile.update(
        {
            "dtype": dtype,
            "count": count,
            "height": height,
            "width": width,
            "compress": compress,
        }
    )

    if nodata is not None:
        profile["nodata"] = nodata
    elif profile.get("nodata") is not None and not _representable(
        profile["nodata"], dtype
    ):
        # The source image's nodata (e.g. -9999 on a float raster) may not fit
        # the output dtype (e.g. a uint8 mask). rasterio refuses to create such
        # a dataset, so drop the inherited value.
        profile["nodata"] = None

    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)

    with rasterio.open(output_path, "w", **profile) as dst:
        if data.ndim == 2:
            dst.write(data, 1)
        else:
            dst.write(data)

    return output_path
1382
+
1383
def save_vector(
    self,
    gdf: "gpd.GeoDataFrame",
    output_path: str,
    driver: Optional[str] = None,
) -> str:
    """Save a GeoDataFrame to a vector file.

    Args:
        gdf: GeoDataFrame to save.
        output_path: Output file path. Parent directories are created.
        driver: Driver name. If None, it is inferred from the extension
            (.geojson/.json, .gpkg, .shp, .parquet, .fgb). Unknown extensions
            default to GeoJSON. "Parquet" is written with
            ``GeoDataFrame.to_parquet`` rather than through OGR.

    Returns:
        Path to saved file.
    """
    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)

    if driver is None:
        ext = os.path.splitext(output_path)[1].lower()
        driver_map = {
            ".geojson": "GeoJSON",
            ".json": "GeoJSON",
            ".gpkg": "GPKG",
            ".shp": "ESRI Shapefile",
            ".parquet": "Parquet",
            ".fgb": "FlatGeobuf",
        }
        driver = driver_map.get(ext, "GeoJSON")

    if driver == "Parquet":
        # geopandas writes GeoParquet natively. The OGR "Parquet" driver only
        # exists in GDAL builds compiled with Arrow support.
        gdf.to_parquet(output_path)
    else:
        gdf.to_file(output_path, driver=driver)
    return output_path
1415
+
1416
+
1417
def semantic_segmentation(
    input_path: str,
    output_path: str,
    model_name: str = "nvidia/segformer-b0-finetuned-ade-512-512",
    output_vector_path: Optional[str] = None,
    threshold: float = 0.5,
    tile_size: int = 1024,
    overlap: int = 128,
    min_object_area: int = 100,
    simplify_tolerance: float = 1.0,
    device: Optional[str] = None,
    **kwargs: Any,
) -> Dict[str, Any]:
    """
    Run a Hugging Face semantic-segmentation model on a GeoTIFF.

    A convenience wrapper: it builds an ``AutoGeoModel`` for the
    "semantic-segmentation" task and forwards everything to its ``predict``.

    Args:
        input_path: Path to input GeoTIFF.
        output_path: Where the segmentation GeoTIFF is written.
        model_name: Hugging Face model name.
        output_vector_path: Optional path for the vectorized output.
        threshold: Threshold for binary masks (forwarded to ``predict``).
        tile_size: Tile size used for large images (model setting).
        overlap: Overlap between tiles (model setting).
        min_object_area: Minimum object area for vectorization.
        simplify_tolerance: Tolerance for polygon simplification.
        device: Device to use ('cuda', 'cpu').
        **kwargs: Extra keyword arguments passed on to ``predict``.

    Returns:
        Whatever ``AutoGeoModel.predict`` returns (a results dictionary).

    Example:
        >>> result = semantic_segmentation(
        ...     "input.tif",
        ...     "output.tif",
        ...     model_name="nvidia/segformer-b0-finetuned-ade-512-512",
        ...     output_vector_path="output.geojson"
        ... )
    """
    segmenter = AutoGeoModel.from_pretrained(
        model_name,
        task="semantic-segmentation",
        device=device,
        tile_size=tile_size,
        overlap=overlap,
    )
    predict_options = dict(
        output_path=output_path,
        output_vector_path=output_vector_path,
        threshold=threshold,
        min_object_area=min_object_area,
        simplify_tolerance=simplify_tolerance,
    )
    return segmenter.predict(input_path, **predict_options, **kwargs)
1474
+
1475
+
1476
def depth_estimation(
    input_path: str,
    output_path: str,
    model_name: str = "depth-anything/Depth-Anything-V2-Small-hf",
    tile_size: int = 1024,
    overlap: int = 128,
    device: Optional[str] = None,
    **kwargs: Any,
) -> Dict[str, Any]:
    """
    Estimate depth for a GeoTIFF with a Hugging Face depth model.

    Builds an ``AutoGeoModel`` for the "depth-estimation" task and delegates
    to its ``predict``.

    Args:
        input_path: Path to input GeoTIFF.
        output_path: Where the depth GeoTIFF is written.
        model_name: Hugging Face model name.
        tile_size: Tile size used for large images (model setting).
        overlap: Overlap between tiles (model setting).
        device: Device to use ('cuda', 'cpu').
        **kwargs: Extra keyword arguments passed on to ``predict``.

    Returns:
        Whatever ``AutoGeoModel.predict`` returns (a results dictionary).

    Example:
        >>> result = depth_estimation(
        ...     "input.tif",
        ...     "depth_output.tif",
        ...     model_name="depth-anything/Depth-Anything-V2-Small-hf"
        ... )
    """
    estimator = AutoGeoModel.from_pretrained(
        model_name,
        task="depth-estimation",
        device=device,
        tile_size=tile_size,
        overlap=overlap,
    )
    prediction = estimator.predict(input_path, output_path=output_path, **kwargs)
    return prediction
1516
+
1517
+
1518
def image_classification(
    input_path: str,
    model_name: str = "google/vit-base-patch16-224",
    device: Optional[str] = None,
    **kwargs: Any,
) -> Dict[str, Any]:
    """
    Classify a GeoTIFF image with a Hugging Face image-classification model.

    Args:
        input_path: Path to input GeoTIFF.
        model_name: Hugging Face model name.
        device: Device to use ('cuda', 'cpu').
        **kwargs: Extra keyword arguments passed on to ``predict``.

    Returns:
        The classification results returned by ``AutoGeoModel.predict``.

    Example:
        >>> result = image_classification(
        ...     "input.tif",
        ...     model_name="google/vit-base-patch16-224"
        ... )
        >>> print(result['class'], result['probabilities'])
    """
    classifier = AutoGeoModel.from_pretrained(
        model_name, task="image-classification", device=device
    )
    prediction = classifier.predict(input_path, **kwargs)
    return prediction
1550
+
1551
+
1552
def object_detection(
    input_path: str,
    text: Optional[str] = None,
    labels: Optional[List[str]] = None,
    model_name: str = "IDEA-Research/grounding-dino-base",
    output_vector_path: Optional[str] = None,
    box_threshold: float = 0.3,
    text_threshold: float = 0.25,
    device: Optional[str] = None,
    **kwargs: Any,
) -> Dict[str, Any]:
    """
    Detect objects in an image with Grounding DINO or a similar model.

    The task is "zero-shot-object-detection" when the lower-cased model name
    contains "grounding-dino" or "owl". Otherwise it is "object-detection".

    Args:
        input_path: Path to input image or URL.
        text: Text prompt for detection (e.g., "a building. a car.").
            Labels should be lowercase and end with a dot.
        labels: List of labels to detect (alternative to text). Forwarded to
            ``predict``, which turns them into a prompt.
        model_name: Hugging Face model name.
        output_vector_path: Optional path to save detection boxes as vector file.
        box_threshold: Confidence threshold for bounding boxes.
        text_threshold: Text similarity threshold for zero-shot detection.
        device: Device to use ('cuda', 'cpu').
        **kwargs: Extra keyword arguments passed on to ``predict``.

    Returns:
        Dictionary with detection results (boxes, scores, labels).

    Example:
        >>> result = object_detection(
        ...     "image.jpg",
        ...     labels=["car", "building", "tree"],
        ...     box_threshold=0.3
        ... )
        >>> print(result["boxes"], result["labels"])
    """
    lowered_name = model_name.lower()
    is_zero_shot = "grounding-dino" in lowered_name or "owl" in lowered_name
    task = "zero-shot-object-detection" if is_zero_shot else "object-detection"

    detector = AutoGeoModel.from_pretrained(model_name, task=task, device=device)

    detect_options = dict(
        text=text,
        labels=labels,
        box_threshold=box_threshold,
        text_threshold=text_threshold,
        output_vector_path=output_vector_path,
    )
    return detector.predict(input_path, **detect_options, **kwargs)
1610
+
1611
+
1612
def get_hf_tasks() -> List[str]:
    """Return the Hugging Face pipeline task names, sorted alphabetically.

    The names are the keys of ``transformers.pipelines.SUPPORTED_TASKS`` in the
    installed transformers version.

    Returns:
        Sorted list of task names.
    """
    from transformers.pipelines import SUPPORTED_TASKS as pipeline_tasks

    # Iterating a dict yields its keys, so sorted() gives the sorted key list.
    return sorted(pipeline_tasks)
1621
+
1622
+
1623
def get_hf_model_config(model_id: str) -> Dict[str, Any]:
    """Load a Hugging Face model configuration and return it as a dict.

    Args:
        model_id: The Hugging Face model ID (e.g., "facebook/sam-vit-base").

    Returns:
        The output of ``AutoConfig.from_pretrained(model_id).to_dict()``.

    Example:
        >>> config = get_hf_model_config("nvidia/segformer-b0-finetuned-ade-512-512")
        >>> print(config.get("model_type"))
    """
    return AutoConfig.from_pretrained(model_id).to_dict()
1638
+
1639
+
1640
+ # =============================================================================
1641
+ # Visualization Functions
1642
+ # =============================================================================
1643
+
1644
+
1645
def _load_image_for_display(
    source: Union[str, np.ndarray, "Image.Image"],
) -> Tuple[np.ndarray, Optional[Dict]]:
    """Load an image as an RGB array suitable for plotting.

    Args:
        source: Image source. One of:
            - a file path, tried with rasterio first and then PIL;
            - a numpy array in CHW or HWC layout;
            - a PIL Image.

    Returns:
        Tuple of (H x W x 3 image array, metadata dict or None).
        - Non-uint8 rasters and arrays are stretched to uint8.
        - Metadata (``crs``, ``transform``, ``bounds``) is only returned when
          rasterio could read the file.
        - Single-band inputs (and 2-band rasters) are shown as grayscale RGB.

    Raises:
        TypeError: If ``source`` is of an unsupported type.
    """
    if isinstance(source, str):
        try:
            with rasterio.open(source) as src:
                data = src.read()
                metadata = {
                    "crs": src.crs,
                    "transform": src.transform,
                    "bounds": src.bounds,
                }
            if data.shape[0] >= 3:
                data = data[:3]
            else:
                # 1- or 2-band rasters: show the first band as grayscale.
                data = np.repeat(data[:1], 3, axis=0)
            img = data.transpose(1, 2, 0)  # CHW -> HWC

            # Per-band 2-98 percentile stretch to uint8
            if img.dtype != np.uint8:
                for i in range(img.shape[-1]):
                    band = img[..., i].astype(np.float32)
                    p2, p98 = np.percentile(band, [2, 98])
                    if p98 > p2:
                        img[..., i] = np.clip(
                            (band - p2) / (p98 - p2) * 255, 0, 255
                        )
                    else:
                        img[..., i] = 0
                img = img.astype(np.uint8)
            return img, metadata
        except Exception:
            # Best effort: anything rasterio cannot handle goes to PIL below.
            pass

        img = np.array(Image.open(source).convert("RGB"))
        return img, None

    if isinstance(source, np.ndarray):
        arr = source
        if arr.ndim == 3 and arr.shape[0] in [1, 3, 4]:
            arr = arr.transpose(1, 2, 0)
        if arr.ndim == 3 and arr.shape[-1] == 1:
            # (H, W, 1): drop the channel axis so it is expanded to RGB below.
            arr = arr[..., 0]
        if arr.ndim == 2:
            arr = np.stack([arr] * 3, axis=-1)
        elif arr.shape[-1] > 3:
            arr = arr[..., :3]
        if arr.dtype != np.uint8:
            arr = (
                (arr - arr.min()) / (arr.max() - arr.min() + 1e-8) * 255
            ).astype(np.uint8)
        return arr, None

    if isinstance(source, Image.Image):
        return np.array(source.convert("RGB")), None

    raise TypeError(f"Unsupported source type: {type(source)}")
1712
+
1713
+
1714
def show_image(
    source: Union[str, np.ndarray, Image.Image],
    figsize: Tuple[int, int] = (10, 10),
    title: Optional[str] = None,
    axis_off: bool = True,
    **kwargs: Any,
) -> "plt.Figure":
    """Plot a GeoTIFF or regular image on a single matplotlib axis.

    Args:
        source: Image source (path to file, numpy array, or PIL Image).
        figsize: Figure size as (width, height).
        title: Optional plot title; set only when truthy.
        axis_off: Hide the axes when True.
        **kwargs: Forwarded to ``Axes.imshow``.

    Returns:
        The created matplotlib figure.

    Example:
        >>> fig = show_image("aerial.tif", title="Aerial Image")
    """
    import matplotlib.pyplot as plt

    rgb = _load_image_for_display(source)[0]

    figure, axis = plt.subplots(figsize=figsize)
    axis.imshow(rgb, **kwargs)

    if title:
        axis.set_title(title)
    if axis_off:
        axis.axis("off")

    plt.tight_layout()
    return figure
1750
+
1751
+
1752
def show_detections(
    source: Union[str, np.ndarray, Image.Image],
    detections: Dict[str, Any],
    figsize: Tuple[int, int] = (12, 10),
    title: Optional[str] = None,
    box_color: str = "red",
    text_color: str = "white",
    linewidth: int = 2,
    fontsize: int = 10,
    show_scores: bool = True,
    axis_off: bool = True,
    **kwargs: Any,
) -> "plt.Figure":
    """Display an image with detection bounding boxes.

    Args:
        source: Image source (path to file, numpy array, or PIL Image).
        detections: Detection results dictionary with 'boxes'
            (``[x1, y1, x2, y2]`` in pixels), 'scores' and 'labels'.
            Missing or None entries are treated as absent.
        figsize: Figure size as (width, height).
        title: Optional title for the plot.
        box_color: A single matplotlib color (name, hex string or RGB/RGBA
            tuple), or a sequence of colors cycled over the boxes.
        text_color: Color for label text.
        linewidth: Width of bounding box lines.
        fontsize: Font size for labels.
        show_scores: Whether to show confidence scores next to labels.
        axis_off: Whether to hide axes.
        **kwargs: Additional arguments passed to plt.imshow().

    Returns:
        Matplotlib figure object.

    Raises:
        ValueError: If ``box_color`` is an empty sequence and there are boxes.

    Example:
        >>> result = geoai.auto.object_detection("aerial.tif", labels=["building", "tree"])
        >>> fig = show_detections("aerial.tif", result)
    """
    import matplotlib.patches as patches
    import matplotlib.pyplot as plt
    from matplotlib.colors import is_color_like

    img, _ = _load_image_for_display(source)

    fig, ax = plt.subplots(figsize=figsize)
    ax.imshow(img, **kwargs)

    # Treat keys that are present but None the same as missing keys.
    boxes = detections.get("boxes")
    if boxes is None:
        boxes = []
    n_boxes = len(boxes)
    scores = detections.get("scores")
    if scores is None:
        scores = [None] * n_boxes
    labels = detections.get("labels")
    if labels is None:
        labels = [None] * n_boxes

    # A single color may be a string *or* an RGB(A) tuple. Only treat
    # box_color as a palette when it is not itself a valid color.
    if isinstance(box_color, str) or is_color_like(box_color):
        colors = [box_color]
    else:
        colors = list(box_color)
    if not colors and n_boxes:
        raise ValueError("box_color must not be an empty sequence")

    for i, (bbox, score, label) in enumerate(zip(boxes, scores, labels)):
        x1, y1, x2, y2 = bbox
        color = colors[i % len(colors)]

        rect = patches.Rectangle(
            (x1, y1),
            x2 - x1,
            y2 - y1,
            linewidth=linewidth,
            edgecolor=color,
            facecolor="none",
        )
        ax.add_patch(rect)

        if label is not None:
            text = str(label)
            if show_scores and score is not None:
                text = f"{label}: {score:.2f}"
            ax.text(
                x1,
                y1 - 5,  # place the label just above the box
                text,
                color=text_color,
                fontsize=fontsize,
                bbox=dict(boxstyle="round,pad=0.3", facecolor=color, alpha=0.7),
            )

    if title:
        ax.set_title(title)
    if axis_off:
        ax.axis("off")

    plt.tight_layout()
    return fig
1843
+
1844
+
1845
def show_segmentation(
    source: Union[str, np.ndarray, Image.Image],
    mask: np.ndarray,
    figsize: Tuple[int, int] = (14, 6),
    title: Optional[str] = None,
    alpha: float = 0.5,
    cmap: str = "tab20",
    show_original: bool = True,
    axis_off: bool = True,
    **kwargs: Any,
) -> "plt.Figure":
    """Overlay a segmentation mask on its source image.

    If the mask's height/width differ from the image's, the mask is resized
    with nearest-neighbour interpolation and cast back to its original dtype.

    Args:
        source: Image source (path to file, numpy array, or PIL Image).
        mask: Segmentation mask array.
        figsize: Figure size as (width, height).
        title: Overlay title. With ``show_original`` it defaults to
            "Segmentation Overlay"; otherwise it is set only when truthy.
        alpha: Transparency of the mask overlay.
        cmap: Colormap for the segmentation mask.
        show_original: Add a left-hand panel with the plain image.
        axis_off: Hide the axes of every panel.
        **kwargs: Forwarded to ``Axes.imshow`` for the image (not the mask).

    Returns:
        The created matplotlib figure.

    Example:
        >>> result = geoai.auto.semantic_segmentation("aerial.tif", output_path="seg.tif")
        >>> fig = show_segmentation("aerial.tif", result["mask"])
    """
    import matplotlib.pyplot as plt

    img, _ = _load_image_for_display(source)
    img_h, img_w = img.shape[:2]

    if mask.shape[:2] != (img_h, img_w):
        mask = cv2.resize(
            mask.astype(np.float32),
            (img_w, img_h),
            interpolation=cv2.INTER_NEAREST,
        ).astype(mask.dtype)

    if show_original:
        fig, (ax_orig, ax_overlay) = plt.subplots(1, 2, figsize=figsize)
        ax_orig.imshow(img, **kwargs)
        ax_orig.set_title("Original Image")
        if axis_off:
            ax_orig.axis("off")
        overlay_title = title or "Segmentation Overlay"
    else:
        fig, ax_overlay = plt.subplots(figsize=figsize)
        overlay_title = title

    ax_overlay.imshow(img, **kwargs)
    ax_overlay.imshow(mask, alpha=alpha, cmap=cmap)
    if overlay_title:
        ax_overlay.set_title(overlay_title)
    if axis_off:
        ax_overlay.axis("off")

    plt.tight_layout()
    return fig
1912
+
1913
+
1914
def show_depth(
    source: Union[str, np.ndarray, Image.Image],
    depth: np.ndarray,
    figsize: Tuple[int, int] = (14, 6),
    title: Optional[str] = None,
    cmap: str = "plasma",
    show_original: bool = True,
    show_colorbar: bool = True,
    axis_off: bool = True,
    **kwargs: Any,
) -> "plt.Figure":
    """Plot a depth map, optionally beside the original image.

    If the depth map's height/width differ from the image's, the depth map is
    resized (bilinear, float32) to match.

    Args:
        source: Image source (path to file, numpy array, or PIL Image).
        depth: Depth map array.
        figsize: Figure size as (width, height).
        title: Depth-panel title. With ``show_original`` it defaults to
            "Depth Estimation"; otherwise it is set only when truthy.
        cmap: Colormap for the depth map.
        show_original: Add a left-hand panel with the plain image.
        show_colorbar: Attach a "Relative Depth" colorbar to the depth panel.
        axis_off: Hide the axes of every panel.
        **kwargs: Forwarded to ``Axes.imshow`` for the image (not the depth).

    Returns:
        The created matplotlib figure.

    Example:
        >>> result = geoai.auto.depth_estimation("aerial.tif", output_path="depth.tif")
        >>> fig = show_depth("aerial.tif", result["depth"])
    """
    import matplotlib.pyplot as plt

    img, _ = _load_image_for_display(source)
    img_h, img_w = img.shape[:2]

    if depth.shape[:2] != (img_h, img_w):
        depth = cv2.resize(
            depth.astype(np.float32),
            (img_w, img_h),
            interpolation=cv2.INTER_LINEAR,
        )

    if show_original:
        fig, (ax_orig, ax_depth) = plt.subplots(1, 2, figsize=figsize)
        ax_orig.imshow(img, **kwargs)
        ax_orig.set_title("Original Image")
        if axis_off:
            ax_orig.axis("off")
        depth_title = title or "Depth Estimation"
    else:
        fig, ax_depth = plt.subplots(figsize=figsize)
        depth_title = title

    depth_artist = ax_depth.imshow(depth, cmap=cmap)
    if depth_title:
        ax_depth.set_title(depth_title)
    if axis_off:
        ax_depth.axis("off")
    if show_colorbar:
        plt.colorbar(depth_artist, ax=ax_depth, label="Relative Depth")

    plt.tight_layout()
    return fig