geoai-py 0.20.0__py2.py3-none-any.whl → 0.21.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- geoai/__init__.py +1 -1
- geoai/auto.py +1982 -0
- {geoai_py-0.20.0.dist-info → geoai_py-0.21.0.dist-info}/METADATA +8 -2
- {geoai_py-0.20.0.dist-info → geoai_py-0.21.0.dist-info}/RECORD +8 -7
- {geoai_py-0.20.0.dist-info → geoai_py-0.21.0.dist-info}/WHEEL +0 -0
- {geoai_py-0.20.0.dist-info → geoai_py-0.21.0.dist-info}/entry_points.txt +0 -0
- {geoai_py-0.20.0.dist-info → geoai_py-0.21.0.dist-info}/licenses/LICENSE +0 -0
- {geoai_py-0.20.0.dist-info → geoai_py-0.21.0.dist-info}/top_level.txt +0 -0
geoai/auto.py
ADDED
@@ -0,0 +1,1982 @@
"""Auto classes for geospatial model inference with GeoTIFF support.

This module provides AutoGeoModel and AutoGeoImageProcessor that extend
Hugging Face transformers' AutoModel and AutoImageProcessor to support
processing geospatial data (GeoTIFF) and saving outputs as GeoTIFF or vector data.

Supported tasks:
- Semantic segmentation (e.g., SegFormer, Mask2Former)
- Image classification (e.g., ViT, ResNet)
- Object detection (e.g., DETR, YOLOS)
- Zero-shot object detection (e.g., Grounding DINO, OWL-ViT)
- Depth estimation (e.g., Depth Anything, DPT)
- Mask generation (e.g., SAM)

Example:
    >>> from geoai import AutoGeoModel
    >>> model = AutoGeoModel.from_pretrained(
    ...     "nvidia/segformer-b0-finetuned-ade-512-512",
    ...     task="semantic-segmentation"
    ... )
    >>> result = model.predict("input.tif", output_path="output.tif")
"""

import os
from typing import Any, Dict, List, Optional, Tuple, Union

import cv2
import geopandas as gpd
import numpy as np
import rasterio
import requests
import torch
from PIL import Image
from rasterio.features import shapes
from rasterio.windows import Window
from shapely.geometry import box, shape
from tqdm import tqdm
from transformers import (
    AutoConfig,
    AutoImageProcessor,
    AutoModel,
    AutoModelForImageClassification,
    AutoModelForImageSegmentation,
    AutoModelForSemanticSegmentation,
    AutoModelForUniversalSegmentation,
    AutoModelForDepthEstimation,
    AutoModelForMaskGeneration,
    AutoModelForObjectDetection,
    AutoModelForZeroShotObjectDetection,
    AutoProcessor,
)

from transformers.utils import logging as hf_logging

from .utils import get_device


hf_logging.set_verbosity_error()  # silence HF load reports

class AutoGeoImageProcessor:
    """
    Image processor for geospatial data that wraps AutoImageProcessor.

    This class provides functionality to load and preprocess GeoTIFF images
    while preserving geospatial metadata (CRS, transform, bounds). It wraps
    Hugging Face's AutoImageProcessor and adds geospatial capabilities.

    Use `from_pretrained` to instantiate this class, following the transformers pattern.

    Attributes:
        processor: The underlying AutoImageProcessor instance.
        device (str): The device being used ('cuda' or 'cpu').

    Example:
        >>> processor = AutoGeoImageProcessor.from_pretrained("facebook/sam-vit-base")
        >>> data, metadata = processor.load_geotiff("input.tif")
        >>> inputs = processor(data)
    """

    def __init__(
        self,
        processor: "AutoImageProcessor",
        processor_name: Optional[str] = None,
        device: Optional[str] = None,
    ) -> None:
        """Initialize the AutoGeoImageProcessor with an existing processor.

        Note: Use `from_pretrained` class method to load from Hugging Face Hub.

        Args:
            processor: An AutoImageProcessor instance.
            processor_name: Name or path of the processor (for reference).
            device: Device to use ('cuda', 'cpu'). If None, auto-detect.
        """
        self.processor = processor
        self.processor_name = processor_name

        if device is None:
            self.device = get_device()
        else:
            self.device = device

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        device: Optional[str] = None,
        use_full_processor: bool = False,
        **kwargs: Any,
    ) -> "AutoGeoImageProcessor":
        """Load an AutoGeoImageProcessor from a pretrained processor.

        This method wraps AutoImageProcessor.from_pretrained and adds
        geospatial capabilities for processing GeoTIFF files.

        Args:
            pretrained_model_name_or_path: Hugging Face model/processor name or local path.
                Can be a model ID from huggingface.co or a local directory path.
            device: Device to use ('cuda', 'cpu'). If None, auto-detect.
            use_full_processor: If True, use AutoProcessor instead of AutoImageProcessor.
                Required for models that need text inputs (e.g., Grounding DINO).
            **kwargs: Additional arguments passed to AutoImageProcessor.from_pretrained.
                Common options include:
                - trust_remote_code (bool): Whether to trust remote code.
                - revision (str): Specific model version to use.
                - use_fast (bool): Whether to use fast tokenizer.

        Returns:
            AutoGeoImageProcessor instance with geospatial support.

        Example:
            >>> processor = AutoGeoImageProcessor.from_pretrained("facebook/sam-vit-base")
            >>> processor = AutoGeoImageProcessor.from_pretrained(
            ...     "nvidia/segformer-b0-finetuned-ade-512-512",
            ...     device="cuda"
            ... )
        """
        # Check if this is a model that needs the full processor (text + image)
        model_name_lower = pretrained_model_name_or_path.lower()
        needs_full_processor = use_full_processor or any(
            name in model_name_lower
            for name in ["grounding-dino", "owl", "clip", "blip"]
        )

        if needs_full_processor:
            processor = AutoProcessor.from_pretrained(
                pretrained_model_name_or_path, **kwargs
            )
        else:
            try:
                processor = AutoImageProcessor.from_pretrained(
                    pretrained_model_name_or_path, **kwargs
                )
            except Exception:
                processor = AutoProcessor.from_pretrained(
                    pretrained_model_name_or_path, **kwargs
                )
        return cls(
            processor=processor,
            processor_name=pretrained_model_name_or_path,
            device=device,
        )

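    # Usage sketch (illustrative only; model IDs are the ones used elsewhere
    # in this module's examples). `use_full_processor=True` forces the
    # AutoProcessor path for models that also take text input:
    #
    #     >>> seg_proc = AutoGeoImageProcessor.from_pretrained(
    #     ...     "nvidia/segformer-b0-finetuned-ade-512-512"
    #     ... )
    #     >>> dino_proc = AutoGeoImageProcessor.from_pretrained(
    #     ...     "IDEA-Research/grounding-dino-base",
    #     ...     use_full_processor=True,
    #     ... )
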
    def load_geotiff(
        self,
        source: Union[str, "rasterio.DatasetReader"],
        window: Optional[Window] = None,
        bands: Optional[List[int]] = None,
    ) -> Tuple[np.ndarray, Dict]:
        """Load a GeoTIFF file and return data with metadata.

        Args:
            source: Path to GeoTIFF file or open rasterio DatasetReader.
            window: Optional rasterio Window for reading a subset.
            bands: List of band indices to read (1-indexed). If None, read all bands.

        Returns:
            Tuple of (image array in CHW format, metadata dict).

        Example:
            >>> processor = AutoGeoImageProcessor.from_pretrained("facebook/sam-vit-base")
            >>> data, metadata = processor.load_geotiff("input.tif")
            >>> print(data.shape)  # (C, H, W)
            >>> print(metadata['crs'])  # CRS info
        """
        should_close = False
        if isinstance(source, str):
            src = rasterio.open(source)
            should_close = True
        else:
            src = source

        try:
            # Read specified bands or all bands
            if bands:
                data = src.read(bands, window=window)
            else:
                data = src.read(window=window)

            # Get profile and update for window
            profile = src.profile.copy()
            if window:
                profile.update(
                    {
                        "height": window.height,
                        "width": window.width,
                        "transform": src.window_transform(window),
                    }
                )

            metadata = {
                "profile": profile,
                "crs": src.crs,
                "transform": profile["transform"],
                "bounds": (
                    src.bounds
                    if not window
                    else rasterio.windows.bounds(window, src.transform)
                ),
                "width": profile["width"],
                "height": profile["height"],
                "count": data.shape[0],
            }

        finally:
            if should_close:
                src.close()

        return data, metadata

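    # Illustrative sketch (hypothetical file name): reading a 512x512 subset
    # with a rasterio Window. The returned metadata carries the
    # window-adjusted transform, so downstream outputs stay georeferenced.
    #
    #     >>> from rasterio.windows import Window
    #     >>> data, metadata = processor.load_geotiff(
    #     ...     "input.tif", window=Window(0, 0, 512, 512)
    #     ... )
    #     >>> data.shape[1:], (metadata["height"], metadata["width"])
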
    def load_image(
        self,
        source: Union[str, np.ndarray, Image.Image],
        window: Optional[Window] = None,
        bands: Optional[List[int]] = None,
    ) -> Tuple[np.ndarray, Optional[Dict]]:
        """Load an image from various sources.

        Args:
            source: Path to image file, numpy array, or PIL Image.
            window: Optional rasterio Window (only for GeoTIFF).
            bands: List of band indices (only for GeoTIFF, 1-indexed).

        Returns:
            Tuple of (image array in CHW format, metadata dict or None).
        """
        if isinstance(source, str):
            # Check if GeoTIFF
            try:
                with rasterio.open(source) as src:
                    if src.crs is not None or source.lower().endswith(
                        (".tif", ".tiff")
                    ):
                        return self.load_geotiff(source, window, bands)
            except (rasterio.RasterioIOError, rasterio.errors.RasterioIOError):
                # If opening as GeoTIFF fails, fall back to loading as a regular image.
                pass

            # Load as regular image
            image = Image.open(source).convert("RGB")
            data = np.array(image).transpose(2, 0, 1)  # HWC -> CHW
            return data, None

        elif isinstance(source, np.ndarray):
            # Ensure CHW format
            if source.ndim == 2:
                source = source[np.newaxis, :, :]
            elif source.ndim == 3 and source.shape[2] in [1, 3, 4]:
                source = source.transpose(2, 0, 1)
            return source, None

        elif isinstance(source, Image.Image):
            data = np.array(source.convert("RGB")).transpose(2, 0, 1)
            return data, None

        else:
            raise TypeError(f"Unsupported source type: {type(source)}")

    def prepare_for_model(
        self,
        data: np.ndarray,
        normalize: bool = True,
        to_rgb: bool = True,
        percentile_clip: bool = True,
        return_tensors: str = "pt",
    ) -> Dict[str, Any]:
        """Prepare image data for model input.

        Args:
            data: Image array in CHW format.
            normalize: Whether to normalize pixel values.
            to_rgb: Whether to convert to 3-channel RGB.
            percentile_clip: Whether to use percentile clipping for normalization.
            return_tensors: Return format ('pt' for PyTorch, 'np' for numpy).

        Returns:
            Dictionary with processed inputs ready for model.
        """
        # Convert to HWC format
        if data.ndim == 3:
            img = data.transpose(1, 2, 0)  # CHW -> HWC
        else:
            img = data

        # Handle different band counts
        if img.ndim == 2:
            img = np.stack([img] * 3, axis=-1)
        elif img.shape[-1] == 1:
            img = np.repeat(img, 3, axis=-1)
        elif img.shape[-1] > 3:
            img = img[..., :3]

        # Normalize
        if normalize:
            if percentile_clip:
                for i in range(img.shape[-1]):
                    band = img[..., i]
                    p2, p98 = np.percentile(band, [2, 98])
                    if p98 > p2:
                        img[..., i] = np.clip((band - p2) / (p98 - p2), 0, 1)
                    else:
                        img[..., i] = 0
                img = (img * 255).astype(np.uint8)
            elif img.dtype != np.uint8:
                img = ((img - img.min()) / (img.max() - img.min()) * 255).astype(
                    np.uint8
                )

        # Convert to PIL Image
        pil_image = Image.fromarray(img)

        # Process with transformers processor
        inputs = self.processor(images=pil_image, return_tensors=return_tensors)

        return inputs

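    # Worked example of the percentile clipping above (numbers are assumed):
    # for a band with p2 = 120 and p98 = 3400, a pixel value v is rescaled to
    # clip((v - 120) / (3400 - 120), 0, 1) and then multiplied by 255, so
    # 16-bit imagery is compressed into the uint8 range needed to build the
    # PIL image handed to the wrapped processor.
    #
    #     >>> inputs = processor.prepare_for_model(data, percentile_clip=True)
    #     >>> inputs["pixel_values"].shape  # typically (1, 3, H', W')
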
    def __call__(
        self,
        images: Union[str, np.ndarray, Image.Image, List],
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Process images for model input.

        Args:
            images: Single image or list of images (paths, arrays, or PIL Images).
            **kwargs: Additional arguments passed to the processor.

        Returns:
            Processed inputs ready for model.
        """
        if isinstance(images, (str, np.ndarray, Image.Image)):
            images = [images]

        all_pil_images = []
        for img in images:
            if isinstance(img, str):
                data, _ = self.load_image(img)
            elif isinstance(img, np.ndarray):
                data = img
                if data.ndim == 3 and data.shape[2] in [1, 3, 4]:
                    data = data.transpose(2, 0, 1)
            else:
                data = np.array(img.convert("RGB")).transpose(2, 0, 1)

            # Prepare PIL image
            if data.ndim == 3:
                img_arr = data.transpose(1, 2, 0)
            else:
                img_arr = data

            if img_arr.ndim == 2:
                img_arr = np.stack([img_arr] * 3, axis=-1)
            elif img_arr.shape[-1] == 1:
                img_arr = np.repeat(img_arr, 3, axis=-1)
            elif img_arr.shape[-1] > 3:
                img_arr = img_arr[..., :3]

            # Normalize to uint8 if needed
            if img_arr.dtype != np.uint8:
                for i in range(img_arr.shape[-1]):
                    band = img_arr[..., i]
                    p2, p98 = np.percentile(band, [2, 98])
                    if p98 > p2:
                        img_arr[..., i] = np.clip((band - p2) / (p98 - p2), 0, 1)
                    else:
                        img_arr[..., i] = 0
                img_arr = (img_arr * 255).astype(np.uint8)

            all_pil_images.append(Image.fromarray(img_arr))

        return self.processor(images=all_pil_images, **kwargs)

    def save_geotiff(
        self,
        data: np.ndarray,
        output_path: str,
        metadata: Dict,
        dtype: Optional[str] = None,
        compress: str = "lzw",
        nodata: Optional[float] = None,
    ) -> str:
        """Save array as GeoTIFF with geospatial metadata.

        Args:
            data: Array to save (2D or 3D in CHW format).
            output_path: Output file path.
            metadata: Metadata dictionary from load_geotiff.
            dtype: Output data type. If None, infer from data.
            compress: Compression method.
            nodata: NoData value.

        Returns:
            Path to saved file.
        """
        profile = metadata["profile"].copy()

        if dtype is None:
            dtype = str(data.dtype)

        # Handle 2D vs 3D arrays
        if data.ndim == 2:
            count = 1
            height, width = data.shape
        else:
            count = data.shape[0]
            height, width = data.shape[1], data.shape[2]

        profile.update(
            {
                "dtype": dtype,
                "count": count,
                "height": height,
                "width": width,
                "compress": compress,
            }
        )

        if nodata is not None:
            profile["nodata"] = nodata

        os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)

        with rasterio.open(output_path, "w", **profile) as dst:
            if data.ndim == 2:
                dst.write(data, 1)
            else:
                dst.write(data)

        return output_path


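# Illustrative end-to-end use of AutoGeoImageProcessor on its own
# (file names are placeholders): load a GeoTIFF, build model inputs, and
# later write a result array back out with the same georeferencing.
#
#     >>> processor = AutoGeoImageProcessor.from_pretrained("facebook/sam-vit-base")
#     >>> data, metadata = processor.load_geotiff("input.tif")
#     >>> inputs = processor(data, return_tensors="pt")
#     >>> # ... run a model, derive a 2D array `mask` ...
#     >>> processor.save_geotiff(mask, "mask.tif", metadata, nodata=0)

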
class AutoGeoModel:
    """
    Auto model for geospatial inference with GeoTIFF support.

    This class wraps Hugging Face transformers' AutoModel classes and adds
    geospatial capabilities for processing GeoTIFF images and saving results
    as georeferenced outputs (GeoTIFF or vector data).

    Use `from_pretrained` to instantiate this class, following the transformers pattern.

    Attributes:
        model: The underlying transformers model instance.
        processor: AutoGeoImageProcessor for preprocessing.
        device (str): The device being used ('cuda' or 'cpu').
        task (str): The task type.
        tile_size (int): Size of tiles for processing large images.
        overlap (int): Overlap between tiles.

    Example:
        >>> model = AutoGeoModel.from_pretrained(
        ...     "facebook/sam-vit-base",
        ...     task="mask-generation"
        ... )
        >>> result = model.predict("input.tif", output_path="output.tif")
    """

    TASK_MODEL_MAPPING = {
        "segmentation": AutoModelForSemanticSegmentation,
        "semantic-segmentation": AutoModelForSemanticSegmentation,
        "image-segmentation": AutoModelForImageSegmentation,
        "universal-segmentation": AutoModelForUniversalSegmentation,
        "depth-estimation": AutoModelForDepthEstimation,
        "mask-generation": AutoModelForMaskGeneration,
        "object-detection": AutoModelForObjectDetection,
        "zero-shot-object-detection": AutoModelForZeroShotObjectDetection,
        "classification": AutoModelForImageClassification,
        "image-classification": AutoModelForImageClassification,
    }

    def __init__(
        self,
        model: torch.nn.Module,
        processor: Optional["AutoGeoImageProcessor"] = None,
        model_name: Optional[str] = None,
        task: Optional[str] = None,
        device: Optional[str] = None,
        tile_size: int = 1024,
        overlap: int = 128,
    ) -> None:
        """Initialize AutoGeoModel with an existing model.

        Note: Use `from_pretrained` class method to load from Hugging Face Hub.

        Args:
            model: A transformers model instance.
            processor: An AutoGeoImageProcessor instance (optional).
            model_name: Name or path of the model (for reference).
            task: Task type for the model.
            device: Device to use ('cuda', 'cpu'). If None, auto-detect.
            tile_size: Size of tiles for processing large images.
            overlap: Overlap between tiles.
        """
        self.model = model
        self.processor = processor
        self.model_name = model_name
        self.task = task
        self.tile_size = tile_size
        self.overlap = overlap

        if device is None:
            self.device = get_device()
        else:
            self.device = device

        # Ensure model is on the correct device and in eval mode
        self.model = self.model.to(self.device)
        self.model.eval()

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        task: Optional[str] = None,
        device: Optional[str] = None,
        tile_size: int = 1024,
        overlap: int = 128,
        **kwargs: Any,
    ) -> "AutoGeoModel":
        """Load an AutoGeoModel from a pretrained model.

        This method wraps transformers' AutoModel.from_pretrained and adds
        geospatial capabilities for processing GeoTIFF files.

        Args:
            pretrained_model_name_or_path: Hugging Face model name or local path.
                Can be a model ID from huggingface.co or a local directory path.
            task: Task type for automatic model class selection. Options:
                - 'segmentation' or 'semantic-segmentation': Semantic segmentation
                - 'image-segmentation': General image segmentation
                - 'universal-segmentation': Universal segmentation (Mask2Former, etc.)
                - 'depth-estimation': Depth estimation
                - 'mask-generation': Mask generation (SAM, etc.)
                - 'object-detection': Object detection
                - 'zero-shot-object-detection': Zero-shot object detection
                - 'classification' or 'image-classification': Image classification
                If None, will try to infer from model config.
            device: Device to use ('cuda', 'cpu'). If None, auto-detect.
            tile_size: Size of tiles for processing large images.
            overlap: Overlap between tiles to avoid edge artifacts.
            **kwargs: Additional arguments passed to the model's from_pretrained.
                Common options include:
                - trust_remote_code (bool): Whether to trust remote code.
                - revision (str): Specific model version to use.
                - torch_dtype: Data type for model weights.

        Returns:
            AutoGeoModel instance with geospatial support.

        Example:
            >>> model = AutoGeoModel.from_pretrained("facebook/sam-vit-base", task="mask-generation")
            >>> model = AutoGeoModel.from_pretrained(
            ...     "nvidia/segformer-b0-finetuned-ade-512-512",
            ...     task="semantic-segmentation",
            ...     device="cuda"
            ... )
        """
        # Determine device
        if device is None:
            device = get_device()

        # Load model using appropriate auto class
        model = cls._load_model_from_pretrained(
            pretrained_model_name_or_path, task, **kwargs
        )

        # Load processor - use full processor for models that need text inputs
        needs_full_processor = task in (
            "zero-shot-object-detection",
            "object-detection",
        )
        try:
            processor = AutoGeoImageProcessor.from_pretrained(
                pretrained_model_name_or_path,
                device=device,
                use_full_processor=needs_full_processor,
            )
        except Exception:
            processor = None

        instance = cls(
            model=model,
            processor=processor,
            model_name=pretrained_model_name_or_path,
            task=task,
            device=device,
            tile_size=tile_size,
            overlap=overlap,
        )

        print(f"Model loaded on {device}")
        return instance

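    # Illustrative sketch of how `task` selects the auto class (model IDs are
    # the ones used elsewhere in this module's examples):
    #
    #     >>> seg = AutoGeoModel.from_pretrained(
    #     ...     "nvidia/segformer-b0-finetuned-ade-512-512",
    #     ...     task="semantic-segmentation",  # -> AutoModelForSemanticSegmentation
    #     ... )
    #     >>> sam = AutoGeoModel.from_pretrained(
    #     ...     "facebook/sam-vit-base",
    #     ...     task="mask-generation",        # -> AutoModelForMaskGeneration
    #     ... )
    #     >>> # With task=None, the class is inferred from config.architectures.
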
    @classmethod
    def _load_model_from_pretrained(
        cls,
        model_name_or_path: str,
        task: Optional[str] = None,
        **kwargs: Any,
    ) -> torch.nn.Module:
        """Load the appropriate model based on task using from_pretrained."""
        if task and task in cls.TASK_MODEL_MAPPING:
            model_class = cls.TASK_MODEL_MAPPING[task]
            return model_class.from_pretrained(model_name_or_path, **kwargs)

        # Try to infer from config
        try:
            config = AutoConfig.from_pretrained(model_name_or_path)
            architectures = getattr(config, "architectures", [])

            if any("Segmentation" in arch for arch in architectures):
                return AutoModelForSemanticSegmentation.from_pretrained(
                    model_name_or_path, **kwargs
                )
            elif any("Detection" in arch for arch in architectures):
                return AutoModelForObjectDetection.from_pretrained(
                    model_name_or_path, **kwargs
                )
            elif any("Classification" in arch for arch in architectures):
                return AutoModelForImageClassification.from_pretrained(
                    model_name_or_path, **kwargs
                )
            else:
                return AutoModel.from_pretrained(model_name_or_path, **kwargs)
        except Exception:
            return AutoModel.from_pretrained(model_name_or_path, **kwargs)

    def predict(
        self,
        source: Union[str, np.ndarray, Image.Image],
        output_path: Optional[str] = None,
        output_vector_path: Optional[str] = None,
        window: Optional[Window] = None,
        bands: Optional[List[int]] = None,
        threshold: float = 0.5,
        text: Optional[str] = None,
        labels: Optional[List[str]] = None,
        box_threshold: float = 0.3,
        text_threshold: float = 0.25,
        min_object_area: int = 100,
        simplify_tolerance: float = 1.0,
        batch_size: int = 1,
        return_probabilities: bool = False,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Run inference on a GeoTIFF or image.

        Args:
            source: Input image (path, array, or PIL Image). Can also be a URL.
            output_path: Path to save output GeoTIFF.
            output_vector_path: Path to save output vector file (GeoJSON, GPKG, etc.).
            window: Optional rasterio Window for processing a subset.
            bands: List of band indices to use (1-indexed).
            threshold: Threshold for binary masks (segmentation tasks).
            text: Text prompt for zero-shot detection models (e.g., "a cat. a dog.").
                For Grounding DINO, labels should be lowercase and end with a dot.
            labels: List of labels to detect (alternative to text).
                Will be converted to text prompt format automatically.
            box_threshold: Confidence threshold for bounding boxes (detection tasks).
            text_threshold: Text similarity threshold for zero-shot detection.
            min_object_area: Minimum object area in pixels for vectorization.
            simplify_tolerance: Tolerance for polygon simplification.
            batch_size: Batch size for processing tiles.
            return_probabilities: Whether to return probability maps.
            **kwargs: Additional arguments for specific tasks.

        Returns:
            Dictionary with results including mask/detections, metadata, and optional vector data.

        Example:
            >>> # Zero-shot object detection
            >>> model = AutoGeoModel.from_pretrained(
            ...     "IDEA-Research/grounding-dino-base",
            ...     task="zero-shot-object-detection"
            ... )
            >>> result = model.predict(
            ...     "image.jpg",
            ...     text="a building. a car. a tree.",
            ...     box_threshold=0.3
            ... )
            >>> print(result["boxes"], result["labels"])
        """
        # Convert labels list to text format if provided
        if labels is not None and text is None:
            text = " ".join(f"{label.lower().strip()}." for label in labels)

        # Handle zero-shot object detection separately
        if self.task in ("zero-shot-object-detection", "object-detection"):
            return self._predict_detection(
                source,
                text=text,
                box_threshold=box_threshold,
                text_threshold=text_threshold,
                output_vector_path=output_vector_path,
                **kwargs,
            )

        # Load image (handles URLs, local files, arrays, PIL Images)
        pil_image = None
        if isinstance(source, str):
            # Check if URL
            if source.startswith(("http://", "https://")):
                pil_image = Image.open(requests.get(source, stream=True).raw)
                metadata = None
                data = np.array(pil_image.convert("RGB")).transpose(2, 0, 1)
            else:
                # Local file - try to load with geospatial info
                if self.processor is not None:
                    data, metadata = self.processor.load_image(source, window, bands)
                else:
                    try:
                        with rasterio.open(source) as src:
                            data = src.read(bands) if bands else src.read()
                            profile = src.profile.copy()
                            metadata = {
                                "profile": profile,
                                "crs": src.crs,
                                "transform": src.transform,
                                "bounds": src.bounds,
                                "width": src.width,
                                "height": src.height,
                            }
                    except Exception:
                        # Fall back to PIL for regular images
                        pil_image = Image.open(source).convert("RGB")
                        data = np.array(pil_image).transpose(2, 0, 1)
                        metadata = None
        elif isinstance(source, Image.Image):
            pil_image = source
            data = np.array(source.convert("RGB")).transpose(2, 0, 1)
            metadata = None
        else:
            data = np.array(source)
            if data.ndim == 3 and data.shape[2] in [1, 3, 4]:
                data = data.transpose(2, 0, 1)
            metadata = None

        # Check if we need tiled processing (not for classification tasks)
        if data.ndim == 3:
            _, height, width = data.shape
        elif data.ndim == 2:
            height, width = data.shape
            _ = 1
        else:
            raise ValueError(f"Unexpected data shape: {data.shape}")

        # Classification tasks should not use tiled processing
        use_tiled = (
            height > self.tile_size or width > self.tile_size
        ) and self.task not in ("classification", "image-classification")

        if use_tiled:
            result = self._predict_tiled(
                source,
                data,
                metadata,
                threshold=threshold,
                batch_size=batch_size,
                return_probabilities=return_probabilities,
                **kwargs,
            )
        else:
            result = self._predict_single(
                data,
                metadata,
                threshold=threshold,
                return_probabilities=return_probabilities,
                **kwargs,
            )

        # Save GeoTIFF output
        if output_path and metadata:
            self.save_geotiff(
                result.get("mask", result.get("output")),
                output_path,
                metadata,
                nodata=0,
            )
            result["output_path"] = output_path

        # Save vector output
        if output_vector_path and metadata and "mask" in result:
            gdf = self.mask_to_vector(
                result["mask"],
                metadata,
                threshold=threshold,
                min_object_area=min_object_area,
                simplify_tolerance=simplify_tolerance,
            )
            if gdf is not None and len(gdf) > 0:
                gdf.to_file(output_vector_path)
                result["vector_path"] = output_vector_path
                result["geodataframe"] = gdf

        return result

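    # Illustrative sketch of the labels-to-text conversion performed above:
    # labels=["Building", "Car"] becomes the Grounding DINO style prompt
    # "building. car." before being handed to _predict_detection().
    #
    #     >>> result = model.predict(
    #     ...     "image.jpg",                  # hypothetical input path
    #     ...     labels=["building", "car"],   # equivalent to text="building. car."
    #     ...     box_threshold=0.3,
    #     ...     output_vector_path="detections.geojson",
    #     ... )
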
    def _predict_single(
        self,
        data: np.ndarray,
        metadata: Optional[Dict],
        threshold: float = 0.5,
        return_probabilities: bool = False,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Run inference on a single image."""
        # Prepare input
        if self.processor is not None:
            inputs = self.processor.prepare_for_model(data)
        else:
            # Fallback preparation
            img = data.transpose(1, 2, 0) if data.ndim == 3 else data
            if img.ndim == 2:
                img = np.stack([img] * 3, axis=-1)
            elif img.shape[-1] > 3:
                img = img[..., :3]

            if img.dtype != np.uint8:
                img = ((img - img.min()) / (img.max() - img.min()) * 255).astype(
                    np.uint8
                )

            pil_image = Image.fromarray(img)
            inputs = {
                "pixel_values": torch.from_numpy(
                    np.array(pil_image).transpose(2, 0, 1) / 255.0
                )
                .float()
                .unsqueeze(0)
            }

        # Move to device
        inputs = {
            k: v.to(self.device) if isinstance(v, torch.Tensor) else v
            for k, v in inputs.items()
        }

        # Run inference
        with torch.no_grad():
            outputs = self.model(**inputs)

        # Process outputs based on model type
        result = self._process_outputs(
            outputs, data.shape, threshold, return_probabilities
        )
        result["metadata"] = metadata

        return result

    def _predict_tiled(
        self,
        source: Union[str, np.ndarray],
        data: np.ndarray,
        metadata: Optional[Dict],
        threshold: float = 0.5,
        batch_size: int = 1,
        return_probabilities: bool = False,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Run tiled inference for large images."""
        if data.ndim == 3:
            _, height, width = data.shape
        else:
            height, width = data.shape

        effective_tile_size = self.tile_size - 2 * self.overlap

        n_tiles_x = max(1, int(np.ceil(width / effective_tile_size)))
        n_tiles_y = max(1, int(np.ceil(height / effective_tile_size)))
        total_tiles = n_tiles_x * n_tiles_y

        # Initialize output arrays
        mask_output = np.zeros((height, width), dtype=np.float32)
        count_output = np.zeros((height, width), dtype=np.float32)

        print(f"Processing {total_tiles} tiles ({n_tiles_x}x{n_tiles_y})")

        with tqdm(total=total_tiles, desc="Processing tiles") as pbar:
            for y in range(n_tiles_y):
                for x in range(n_tiles_x):
                    # Calculate tile coordinates
                    x_start = max(0, x * effective_tile_size - self.overlap)
                    y_start = max(0, y * effective_tile_size - self.overlap)
                    x_end = min(width, (x + 1) * effective_tile_size + self.overlap)
                    y_end = min(height, (y + 1) * effective_tile_size + self.overlap)

                    # Extract tile
                    if data.ndim == 3:
                        tile = data[:, y_start:y_end, x_start:x_end]
                    else:
                        tile = data[y_start:y_end, x_start:x_end]

                    try:
                        # Run inference on tile
                        tile_result = self._predict_single(
                            tile, None, threshold, return_probabilities
                        )

                        tile_mask = tile_result.get("mask", tile_result.get("output"))
                        if tile_mask is not None:
                            # Handle different dimensions
                            if tile_mask.ndim > 2:
                                tile_mask = tile_mask.squeeze()
                            if tile_mask.ndim > 2:
                                tile_mask = tile_mask[0]

                            # Resize if necessary
                            tile_h, tile_w = y_end - y_start, x_end - x_start
                            if tile_mask.shape != (tile_h, tile_w):
                                tile_mask = cv2.resize(
                                    tile_mask.astype(np.float32),
                                    (tile_w, tile_h),
                                    interpolation=cv2.INTER_LINEAR,
                                )

                            # Add to output with blending
                            mask_output[y_start:y_end, x_start:x_end] += tile_mask
                            count_output[y_start:y_end, x_start:x_end] += 1

                    except Exception as e:
                        print(f"Error processing tile ({x}, {y}): {e}")

                    pbar.update(1)

        # Average overlapping regions
        count_output = np.maximum(count_output, 1)
        mask_output = mask_output / count_output

        result = {
            "mask": (mask_output > threshold).astype(np.uint8),
            "probabilities": mask_output if return_probabilities else None,
            "metadata": metadata,
        }

        return result

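    # Worked example of the tiling arithmetic above, with the defaults
    # tile_size=1024 and overlap=128: effective_tile_size = 1024 - 2*128 = 768,
    # so a 4000 x 3000 pixel raster yields ceil(4000/768) = 6 columns and
    # ceil(3000/768) = 4 rows, i.e. 24 tiles. Each tile is extended by up to
    # 128 pixels on every side, and overlapping predictions are averaged via
    # count_output before thresholding.
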
    def _predict_detection(
        self,
        source: Union[str, np.ndarray, Image.Image],
        text: Optional[str] = None,
        box_threshold: float = 0.3,
        text_threshold: float = 0.25,
        output_vector_path: Optional[str] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Run object detection inference.

        Args:
            source: Input image (path, URL, array, or PIL Image).
            text: Text prompt for zero-shot detection (e.g., "a cat. a dog.").
            box_threshold: Confidence threshold for bounding boxes.
            text_threshold: Text similarity threshold for zero-shot detection.
            output_vector_path: Path to save detection results as vector file.
            **kwargs: Additional arguments.

        Returns:
            Dictionary with boxes, scores, labels, and optional GeoDataFrame.
        """
        # Load image
        if isinstance(source, str):
            if source.startswith(("http://", "https://")):
                pil_image = Image.open(requests.get(source, stream=True).raw)
            else:
                try:
                    # Try loading with rasterio for GeoTIFF
                    with rasterio.open(source) as src:
                        data = src.read()
                        if data.shape[0] > 3:
                            data = data[:3]
                        elif data.shape[0] == 1:
                            data = np.repeat(data, 3, axis=0)
                        # Normalize to uint8
                        if data.dtype != np.uint8:
                            for i in range(data.shape[0]):
                                band = data[i].astype(np.float32)
                                p2, p98 = np.percentile(band, [2, 98])
                                if p98 > p2:
                                    data[i] = np.clip(
                                        (band - p2) / (p98 - p2) * 255, 0, 255
                                    )
                                else:
                                    data[i] = 0
                            data = data.astype(np.uint8)
                        pil_image = Image.fromarray(data.transpose(1, 2, 0))
                except Exception:
                    pil_image = Image.open(source)
        elif isinstance(source, np.ndarray):
            if source.ndim == 3 and source.shape[0] in [1, 3, 4]:
                source = source.transpose(1, 2, 0)
            if source.dtype != np.uint8:
                source = (
                    (source - source.min()) / (source.max() - source.min()) * 255
                ).astype(np.uint8)
            pil_image = Image.fromarray(source)
        else:
            pil_image = source

        pil_image = pil_image.convert("RGB")
        image_size = pil_image.size[::-1]  # (height, width)

        # Get the underlying processor
        processor = self.processor.processor if self.processor else None
        if processor is None:
            # Load processor directly if not available
            processor = AutoProcessor.from_pretrained(self.model_name)

        # Prepare inputs based on task type
        if self.task == "zero-shot-object-detection":
            if text is None:
                raise ValueError(
                    "Text prompt is required for zero-shot object detection. "
                    "Provide text='a cat. a dog.' or labels=['cat', 'dog']"
                )
            # Use the processor to prepare inputs with text
            inputs = processor(images=pil_image, text=text, return_tensors="pt")
        else:
            # Standard object detection
            inputs = processor(images=pil_image, return_tensors="pt")

        # Move to device - use .to() method for BatchFeature objects
        inputs = inputs.to(self.device)

        # Run inference
        with torch.no_grad():
            outputs = self.model(**inputs)

        # Post-process results
        result = self._process_detection_outputs(
            outputs,
            inputs,
            text,
            image_size,
            box_threshold,
            text_threshold,
            processor,
        )

        # Convert to GeoDataFrame if requested
        if output_vector_path and result.get("boxes") is not None:
            gdf = self._detections_to_geodataframe(result, pil_image.size)
            if gdf is not None and len(gdf) > 0:
                gdf.to_file(output_vector_path)
                result["vector_path"] = output_vector_path
                result["geodataframe"] = gdf

        return result

    def _process_detection_outputs(
        self,
        outputs: Any,
        inputs: Dict,
        text: Optional[str],
        image_size: Tuple[int, int],
        box_threshold: float,
        text_threshold: float,
        processor: Any = None,
    ) -> Dict[str, Any]:
        """Process detection model outputs."""
        result = {}

        if processor is None:
            processor = self.processor.processor if self.processor else None

        if self.task == "zero-shot-object-detection":
            # Use processor's post-processing for grounded detection
            try:
                results = processor.post_process_grounded_object_detection(
                    outputs,
                    inputs["input_ids"],
                    threshold=box_threshold,  # box confidence threshold
                    text_threshold=text_threshold,
                    target_sizes=[image_size],
                )
                if results and len(results) > 0:
                    r = results[0]
                    result["boxes"] = r["boxes"].cpu().numpy()
                    result["scores"] = r["scores"].cpu().numpy()
                    # Handle different output formats for labels
                    if "labels" in r:
                        result["labels"] = r["labels"]
                    elif "text_labels" in r:
                        result["labels"] = r["text_labels"]
                    else:
                        # Extract labels from logits if not provided
                        result["labels"] = [
                            f"object_{i}" for i in range(len(r["boxes"]))
                        ]
            except Exception as e:
                # Fallback for models without grounded post-processing
                print(f"Warning: Using fallback detection processing: {e}")
                if hasattr(outputs, "pred_boxes"):
                    boxes = outputs.pred_boxes[0].cpu().numpy()
                    logits = outputs.logits[0].cpu()
                    scores = logits.sigmoid().max(dim=-1).values.numpy()
                    mask = scores > box_threshold
                    result["boxes"] = boxes[mask]
                    result["scores"] = scores[mask]
                    result["labels"] = [
                        f"object_{i}" for i in range(len(result["boxes"]))
                    ]
        else:
            # Standard object detection post-processing
            if hasattr(outputs, "pred_boxes") and processor is not None:
                target_sizes = torch.tensor([image_size], device=self.device)
                results = processor.post_process_object_detection(
                    outputs, threshold=box_threshold, target_sizes=target_sizes
                )
                if results and len(results) > 0:
                    r = results[0]
                    result["boxes"] = r["boxes"].cpu().numpy()
                    result["scores"] = r["scores"].cpu().numpy()
                    result["labels"] = r["labels"].cpu().numpy()

        return result

    def _detections_to_geodataframe(
        self,
        detections: Dict[str, Any],
        image_size: Tuple[int, int],
    ) -> Optional[gpd.GeoDataFrame]:
        """Convert detection results to a GeoDataFrame.

        Note: Without geospatial metadata, coordinates are in pixel space.
        """
        boxes = detections.get("boxes")
        if boxes is None or len(boxes) == 0:
            return None

        scores = detections.get("scores", [None] * len(boxes))
        labels = detections.get("labels", [None] * len(boxes))

        geometries = []
        for bbox in boxes:
            # Convert [x1, y1, x2, y2] to polygon
            x1, y1, x2, y2 = bbox
            geometries.append(box(x1, y1, x2, y2))

        gdf = gpd.GeoDataFrame(
            {
                "geometry": geometries,
                "score": scores,
                "label": labels,
            }
        )

        return gdf

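    # Note: the detection GeoDataFrame above is in pixel space with no CRS.
    # One illustrative (assumed) way to georeference the boxes when the input
    # was a GeoTIFF is to push the pixel corners through the raster's affine
    # transform from `metadata["transform"]`:
    #
    #     >>> x_min, y_max = metadata["transform"] * (x1, y1)  # (col, row) -> (x, y)
    #     >>> x_max, y_min = metadata["transform"] * (x2, y2)
    #     >>> geom = box(x_min, y_min, x_max, y_max)
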
    def _process_outputs(
        self,
        outputs: Any,
        input_shape: Tuple,
        threshold: float = 0.5,
        return_probabilities: bool = False,
    ) -> Dict[str, Any]:
        """Process model outputs to extract masks or predictions."""
        result = {}

        # Handle different output types
        if hasattr(outputs, "logits"):
            logits = outputs.logits
            if logits.dim() == 4:  # Segmentation output
                # Upsample if needed
                if logits.shape[2:] != input_shape[1:]:
                    logits = torch.nn.functional.interpolate(
                        logits,
                        size=(input_shape[1], input_shape[2]),
                        mode="bilinear",
                        align_corners=False,
                    )

                probs = torch.softmax(logits, dim=1)
                mask = probs.argmax(dim=1).squeeze().cpu().numpy()
                result["mask"] = mask.astype(np.uint8)

                if return_probabilities:
                    result["probabilities"] = probs.squeeze().cpu().numpy()

            elif logits.dim() == 2:  # Classification output
                probs = torch.softmax(logits, dim=-1)
                pred_class = probs.argmax(dim=-1).item()
                result["class"] = pred_class
                result["probabilities"] = probs.squeeze().cpu().numpy()

        elif hasattr(outputs, "pred_masks"):
            masks = outputs.pred_masks.squeeze().cpu().numpy()
            if masks.ndim == 3:
                mask = masks.max(axis=0)
            else:
                mask = masks
            result["mask"] = (mask > threshold).astype(np.uint8)

            if return_probabilities:
                result["probabilities"] = mask

        elif hasattr(outputs, "predicted_depth"):
            depth = outputs.predicted_depth.squeeze().cpu().numpy()
            result["output"] = depth
            result["depth"] = depth

        elif hasattr(outputs, "masks"):
            # SAM-like output
            masks = outputs.masks
            if isinstance(masks, torch.Tensor):
                masks = masks.cpu().numpy()
            if masks.ndim == 4:
                masks = masks.squeeze(0)
            if masks.ndim == 3:
                mask = masks.max(axis=0)
            else:
                mask = masks
            result["mask"] = (mask > threshold).astype(np.uint8)
            result["all_masks"] = masks

        else:
            # Generic output handling
            if hasattr(outputs, "last_hidden_state"):
                result["features"] = outputs.last_hidden_state.cpu().numpy()
            else:
                result["output"] = outputs

        return result

    def mask_to_vector(
        self,
        mask: np.ndarray,
        metadata: Dict,
        threshold: float = 0.5,
        min_object_area: int = 100,
        max_object_area: Optional[int] = None,
        simplify_tolerance: float = 1.0,
    ) -> Optional[gpd.GeoDataFrame]:
        """Convert a raster mask to vector polygons.

        Args:
            mask: Binary or probability mask array.
            metadata: Geospatial metadata dictionary.
            threshold: Threshold for binarizing probability masks.
            min_object_area: Minimum area in pixels for valid objects.
            max_object_area: Maximum area in pixels (optional).
            simplify_tolerance: Tolerance for polygon simplification.

        Returns:
            GeoDataFrame with polygon geometries, or None if no valid polygons.
        """
        if metadata is None or metadata.get("crs") is None:
            print("Warning: No CRS information available for vectorization")
            return None

        # Ensure binary mask
        if mask.dtype == np.float32 or mask.dtype == np.float64:
            mask = (mask > threshold).astype(np.uint8)
        else:
            mask = (mask > 0).astype(np.uint8)

        # Get transform
        transform = metadata.get("transform")
        crs = metadata.get("crs")

        if transform is None:
            print("Warning: No transform available for vectorization")
            return None

        # Extract shapes using rasterio
        polygons = []
        values = []

        try:
            for geom, value in shapes(mask, transform=transform):
                if value > 0:  # Only keep non-background
                    poly = shape(geom)

                    # Filter by area
                    pixel_area = poly.area / (transform.a * abs(transform.e))
                    if pixel_area < min_object_area:
                        continue
                    if max_object_area and pixel_area > max_object_area:
                        continue

                    # Simplify
                    if simplify_tolerance > 0:
                        poly = poly.simplify(
                            simplify_tolerance * abs(transform.a),
                            preserve_topology=True,
                        )

                    if poly.is_valid and not poly.is_empty:
                        polygons.append(poly)
                        values.append(value)

        except Exception as e:
            print(f"Error during vectorization: {e}")
            return None

        if not polygons:
            return None

        # Create GeoDataFrame
        gdf = gpd.GeoDataFrame(
            {"geometry": polygons, "class": values},
            crs=crs,
        )

        return gdf

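    # Illustrative sketch (hypothetical output paths): vectorize a predicted
    # mask, keep only polygons of at least 50 pixels, then write them with
    # save_vector (defined further below).
    #
    #     >>> gdf = model.mask_to_vector(
    #     ...     result["mask"], result["metadata"],
    #     ...     min_object_area=50, simplify_tolerance=1.0,
    #     ... )
    #     >>> if gdf is not None:
    #     ...     model.save_vector(gdf, "objects.gpkg")
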
    def save_geotiff(
        self,
        data: np.ndarray,
        output_path: str,
        metadata: Dict,
        dtype: Optional[str] = None,
        compress: str = "lzw",
        nodata: Optional[float] = None,
    ) -> str:
        """Save array as GeoTIFF with geospatial metadata.

        Args:
            data: Array to save (2D or 3D in CHW format).
            output_path: Output file path.
            metadata: Metadata dictionary from load_geotiff.
            dtype: Output data type. If None, infer from data.
            compress: Compression method.
            nodata: NoData value.

        Returns:
            Path to saved file.
        """
        profile = metadata["profile"].copy()

        if dtype is None:
            dtype = str(data.dtype)

        # Handle 2D vs 3D arrays
        if data.ndim == 2:
            count = 1
            height, width = data.shape
        else:
            count = data.shape[0]
            height, width = data.shape[1], data.shape[2]

        profile.update(
            {
                "dtype": dtype,
                "count": count,
                "height": height,
                "width": width,
                "compress": compress,
            }
        )

        if nodata is not None:
            profile["nodata"] = nodata

        os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)

        with rasterio.open(output_path, "w", **profile) as dst:
            if data.ndim == 2:
                dst.write(data, 1)
            else:
                dst.write(data)

        return output_path

    def save_vector(
        self,
        gdf: gpd.GeoDataFrame,
        output_path: str,
        driver: Optional[str] = None,
    ) -> str:
        """Save GeoDataFrame to file.

        Args:
            gdf: GeoDataFrame to save.
            output_path: Output file path.
            driver: File driver (auto-detected from extension if None).

        Returns:
            Path to saved file.
        """
        os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)

        if driver is None:
            ext = os.path.splitext(output_path)[1].lower()
            driver_map = {
                ".geojson": "GeoJSON",
                ".json": "GeoJSON",
                ".gpkg": "GPKG",
                ".shp": "ESRI Shapefile",
                ".parquet": "Parquet",
                ".fgb": "FlatGeobuf",
            }
            driver = driver_map.get(ext, "GeoJSON")

        gdf.to_file(output_path, driver=driver)
        return output_path


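# Illustrative end-to-end workflow combining the pieces above (file names are
# placeholders): tiled semantic segmentation of a large GeoTIFF with both
# raster and vector outputs.
#
#     >>> model = AutoGeoModel.from_pretrained(
#     ...     "nvidia/segformer-b0-finetuned-ade-512-512",
#     ...     task="semantic-segmentation",
#     ...     tile_size=1024,
#     ...     overlap=128,
#     ... )
#     >>> result = model.predict(
#     ...     "scene.tif",
#     ...     output_path="classes.tif",
#     ...     output_vector_path="classes.geojson",
#     ... )
#     >>> result["output_path"], result.get("vector_path")

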
def semantic_segmentation(
|
|
1418
|
+
input_path: str,
|
|
1419
|
+
output_path: str,
|
|
1420
|
+
model_name: str = "nvidia/segformer-b0-finetuned-ade-512-512",
|
|
1421
|
+
output_vector_path: Optional[str] = None,
|
|
1422
|
+
threshold: float = 0.5,
|
|
1423
|
+
tile_size: int = 1024,
|
|
1424
|
+
overlap: int = 128,
|
|
1425
|
+
min_object_area: int = 100,
|
|
1426
|
+
simplify_tolerance: float = 1.0,
|
|
1427
|
+
device: Optional[str] = None,
|
|
1428
|
+
**kwargs: Any,
|
|
1429
|
+
) -> Dict[str, Any]:
|
|
1430
|
+
"""
|
|
1431
|
+
Perform semantic segmentation on a GeoTIFF image.
|
|
1432
|
+
|
|
1433
|
+
Args:
|
|
1434
|
+
input_path: Path to input GeoTIFF.
|
|
1435
|
+
output_path: Path to save output segmentation GeoTIFF.
|
|
1436
|
+
model_name: Hugging Face model name.
|
|
1437
|
+
output_vector_path: Optional path to save vectorized output.
|
|
1438
|
+
threshold: Threshold for binary masks.
|
|
1439
|
+
tile_size: Size of tiles for processing large images.
|
|
1440
|
+
overlap: Overlap between tiles.
|
|
1441
|
+
min_object_area: Minimum object area for vectorization.
|
|
1442
|
+
simplify_tolerance: Tolerance for polygon simplification.
|
|
1443
|
+
device: Device to use ('cuda', 'cpu').
|
|
1444
|
+
**kwargs: Additional arguments for prediction.
|
|
1445
|
+
|
|
1446
|
+
Returns:
|
|
1447
|
+
Dictionary with results.
|
|
1448
|
+
|
|
1449
|
+
Example:
|
|
1450
|
+
>>> result = semantic_segmentation(
|
|
1451
|
+
... "input.tif",
|
|
1452
|
+
... "output.tif",
|
|
1453
|
+
... model_name="nvidia/segformer-b0-finetuned-ade-512-512",
|
|
1454
|
+
... output_vector_path="output.geojson"
|
|
1455
|
+
... )
|
|
1456
|
+
"""
|
|
1457
|
+
model = AutoGeoModel.from_pretrained(
|
|
1458
|
+
model_name,
|
|
1459
|
+
task="semantic-segmentation",
|
|
1460
|
+
device=device,
|
|
1461
|
+
tile_size=tile_size,
|
|
1462
|
+
overlap=overlap,
|
|
1463
|
+
)
|
|
1464
|
+
|
|
1465
|
+
return model.predict(
|
|
1466
|
+
input_path,
|
|
1467
|
+
output_path=output_path,
|
|
1468
|
+
output_vector_path=output_vector_path,
|
|
1469
|
+
threshold=threshold,
|
|
1470
|
+
min_object_area=min_object_area,
|
|
1471
|
+
simplify_tolerance=simplify_tolerance,
|
|
1472
|
+
**kwargs,
|
|
1473
|
+
)
|
|
1474
|
+
|
|
1475
|
+
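# Illustrative usage sketch (not part of the packaged auto.py): running the wrapper
# above and overlaying the result with show_segmentation, defined later in this file.
# File paths are placeholders; the "mask" key follows the show_segmentation docstring.
#
# >>> from geoai import auto
# >>> result = auto.semantic_segmentation(
# ...     "aerial.tif",
# ...     "segmentation.tif",
# ...     output_vector_path="segmentation.geojson",
# ...     tile_size=512,
# ...     overlap=64,
# ... )
# >>> auto.show_segmentation("aerial.tif", result["mask"])
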
def depth_estimation(
    input_path: str,
    output_path: str,
    model_name: str = "depth-anything/Depth-Anything-V2-Small-hf",
    tile_size: int = 1024,
    overlap: int = 128,
    device: Optional[str] = None,
    **kwargs: Any,
) -> Dict[str, Any]:
    """
    Perform depth estimation on a GeoTIFF image.

    Args:
        input_path: Path to input GeoTIFF.
        output_path: Path to save output depth GeoTIFF.
        model_name: Hugging Face model name.
        tile_size: Size of tiles for processing large images.
        overlap: Overlap between tiles.
        device: Device to use ('cuda', 'cpu').
        **kwargs: Additional arguments for prediction.

    Returns:
        Dictionary with results.

    Example:
        >>> result = depth_estimation(
        ...     "input.tif",
        ...     "depth_output.tif",
        ...     model_name="depth-anything/Depth-Anything-V2-Small-hf"
        ... )
    """
    model = AutoGeoModel.from_pretrained(
        model_name,
        task="depth-estimation",
        device=device,
        tile_size=tile_size,
        overlap=overlap,
    )

    return model.predict(input_path, output_path=output_path, **kwargs)

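# Illustrative usage sketch (not part of the packaged auto.py): estimating depth and
# viewing it with show_depth, defined later in this file. Paths are placeholders; the
# "depth" key follows the show_depth docstring.
#
# >>> from geoai import auto
# >>> result = auto.depth_estimation("aerial.tif", "depth.tif", tile_size=512)
# >>> auto.show_depth("aerial.tif", result["depth"], cmap="plasma")
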
def image_classification(
    input_path: str,
    model_name: str = "google/vit-base-patch16-224",
    device: Optional[str] = None,
    **kwargs: Any,
) -> Dict[str, Any]:
    """
    Perform image classification on a GeoTIFF image.

    Args:
        input_path: Path to input GeoTIFF.
        model_name: Hugging Face model name.
        device: Device to use ('cuda', 'cpu').
        **kwargs: Additional arguments for prediction.

    Returns:
        Dictionary with classification results.

    Example:
        >>> result = image_classification(
        ...     "input.tif",
        ...     model_name="google/vit-base-patch16-224"
        ... )
        >>> print(result['class'], result['probabilities'])
    """
    model = AutoGeoModel.from_pretrained(
        model_name,
        task="image-classification",
        device=device,
    )

    return model.predict(input_path, **kwargs)

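# Illustrative usage sketch (not part of the packaged auto.py): classifying a batch of
# image chips in a loop. The "class" key follows the docstring example above; the chip
# paths are placeholders.
#
# >>> from geoai import auto
# >>> for chip in ["chip_01.tif", "chip_02.tif", "chip_03.tif"]:
# ...     result = auto.image_classification(chip)
# ...     print(chip, result["class"])
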
def object_detection(
    input_path: str,
    text: Optional[str] = None,
    labels: Optional[List[str]] = None,
    model_name: str = "IDEA-Research/grounding-dino-base",
    output_vector_path: Optional[str] = None,
    box_threshold: float = 0.3,
    text_threshold: float = 0.25,
    device: Optional[str] = None,
    **kwargs: Any,
) -> Dict[str, Any]:
    """
    Perform object detection on an image using Grounding DINO or similar models.

    Args:
        input_path: Path to input image or URL.
        text: Text prompt for detection (e.g., "a building. a car.").
            Labels should be lowercase and end with a dot.
        labels: List of labels to detect (alternative to text).
            Will be converted to text format automatically.
        model_name: Hugging Face model name.
        output_vector_path: Optional path to save detection boxes as vector file.
        box_threshold: Confidence threshold for bounding boxes.
        text_threshold: Text similarity threshold for zero-shot detection.
        device: Device to use ('cuda', 'cpu').
        **kwargs: Additional arguments for prediction.

    Returns:
        Dictionary with detection results (boxes, scores, labels).

    Example:
        >>> result = object_detection(
        ...     "image.jpg",
        ...     labels=["car", "building", "tree"],
        ...     box_threshold=0.3
        ... )
        >>> print(result["boxes"], result["labels"])
    """
    # Determine task type based on model
    task = "zero-shot-object-detection"
    if "grounding-dino" not in model_name.lower() and "owl" not in model_name.lower():
        task = "object-detection"

    model = AutoGeoModel.from_pretrained(
        model_name,
        task=task,
        device=device,
    )

    return model.predict(
        input_path,
        text=text,
        labels=labels,
        box_threshold=box_threshold,
        text_threshold=text_threshold,
        output_vector_path=output_vector_path,
        **kwargs,
    )

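# Illustrative usage sketch (not part of the packaged auto.py): prompting the zero-shot
# detector with a labels list, exporting boxes as GeoJSON, and plotting them with
# show_detections, defined later in this file. Paths are placeholders; a pre-formatted
# prompt such as text="a building. a swimming pool." can be passed instead of labels.
#
# >>> from geoai import auto
# >>> result = auto.object_detection(
# ...     "aerial.tif",
# ...     labels=["building", "swimming pool"],
# ...     box_threshold=0.35,
# ...     output_vector_path="detections.geojson",
# ... )
# >>> auto.show_detections("aerial.tif", result)
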
def get_hf_tasks() -> List[str]:
    """Get all supported Hugging Face tasks for this module.

    Returns:
        List of supported task names.
    """
    from transformers.pipelines import SUPPORTED_TASKS

    return sorted(list(SUPPORTED_TASKS.keys()))


def get_hf_model_config(model_id: str) -> Dict[str, Any]:
    """Get the model configuration for a Hugging Face model.

    Args:
        model_id: The Hugging Face model ID (e.g., "facebook/sam-vit-base").

    Returns:
        Dictionary representation of the model config.

    Example:
        >>> config = get_hf_model_config("nvidia/segformer-b0-finetuned-ade-512-512")
        >>> print(config.get("model_type"))
    """
    cfg = AutoConfig.from_pretrained(model_id)
    return cfg.to_dict()

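# Illustrative sketch (not part of the packaged auto.py): inspecting a model before
# choosing a task. Keys such as "id2label" are model-dependent, so treat them as optional.
#
# >>> tasks = get_hf_tasks()
# >>> print(tasks[:5])
# >>> cfg = get_hf_model_config("nvidia/segformer-b0-finetuned-ade-512-512")
# >>> print(cfg.get("model_type"))            # e.g. "segformer"
# >>> print(len(cfg.get("id2label", {})))     # number of classes, if published
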
# =============================================================================
# Visualization Functions
# =============================================================================

def _load_image_for_display(
    source: Union[str, np.ndarray, Image.Image],
) -> Tuple[np.ndarray, Optional[Dict]]:
    """Load an image for display purposes.

    Args:
        source: Image source (path, array, or PIL Image).

    Returns:
        Tuple of (RGB image array in HWC format, metadata dict or None).
    """
    metadata = None

    if isinstance(source, str):
        try:
            with rasterio.open(source) as src:
                data = src.read()
                metadata = {
                    "crs": src.crs,
                    "transform": src.transform,
                    "bounds": src.bounds,
                }
                # Convert to HWC
                if data.shape[0] > 3:
                    data = data[:3]
                elif data.shape[0] == 1:
                    data = np.repeat(data, 3, axis=0)
                img = data.transpose(1, 2, 0)

                # Normalize to uint8
                if img.dtype != np.uint8:
                    for i in range(img.shape[-1]):
                        band = img[..., i].astype(np.float32)
                        p2, p98 = np.percentile(band, [2, 98])
                        if p98 > p2:
                            img[..., i] = np.clip(
                                (band - p2) / (p98 - p2) * 255, 0, 255
                            )
                        else:
                            img[..., i] = 0
                    img = img.astype(np.uint8)
                return img, metadata
        except Exception:
            pass

        # Try as regular image
        img = np.array(Image.open(source).convert("RGB"))
        return img, None

    elif isinstance(source, np.ndarray):
        if source.ndim == 3 and source.shape[0] in [1, 3, 4]:
            source = source.transpose(1, 2, 0)
        if source.ndim == 2:
            source = np.stack([source] * 3, axis=-1)
        elif source.shape[-1] > 3:
            source = source[..., :3]
        if source.dtype != np.uint8:
            source = (
                (source - source.min()) / (source.max() - source.min() + 1e-8) * 255
            ).astype(np.uint8)
        return source, None

    elif isinstance(source, Image.Image):
        return np.array(source.convert("RGB")), None

    else:
        raise TypeError(f"Unsupported source type: {type(source)}")

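# Standalone sketch (illustrative only) of the per-band 2-98 percentile stretch that
# _load_image_for_display applies before display, shown here on a synthetic float band.
#
# >>> import numpy as np
# >>> band = np.random.default_rng(0).normal(1000.0, 200.0, size=(256, 256)).astype(np.float32)
# >>> p2, p98 = np.percentile(band, [2, 98])
# >>> stretched = np.clip((band - p2) / (p98 - p2) * 255, 0, 255).astype(np.uint8)
# >>> # values now span the full 0-255 display range
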
def show_image(
    source: Union[str, np.ndarray, Image.Image],
    figsize: Tuple[int, int] = (10, 10),
    title: Optional[str] = None,
    axis_off: bool = True,
    **kwargs: Any,
) -> "plt.Figure":
    """Display an image (GeoTIFF or regular image).

    Args:
        source: Image source (path to file, numpy array, or PIL Image).
        figsize: Figure size as (width, height).
        title: Optional title for the plot.
        axis_off: Whether to hide axes.
        **kwargs: Additional arguments passed to plt.imshow().

    Returns:
        Matplotlib figure object.

    Example:
        >>> fig = show_image("aerial.tif", title="Aerial Image")
    """
    import matplotlib.pyplot as plt

    img, _ = _load_image_for_display(source)

    fig, ax = plt.subplots(figsize=figsize)
    ax.imshow(img, **kwargs)

    if title:
        ax.set_title(title)
    if axis_off:
        ax.axis("off")

    plt.tight_layout()
    return fig

def show_detections(
    source: Union[str, np.ndarray, Image.Image],
    detections: Dict[str, Any],
    figsize: Tuple[int, int] = (12, 10),
    title: Optional[str] = None,
    box_color: str = "red",
    text_color: str = "white",
    linewidth: int = 2,
    fontsize: int = 10,
    show_scores: bool = True,
    axis_off: bool = True,
    **kwargs: Any,
) -> "plt.Figure":
    """Display an image with detection bounding boxes.

    Args:
        source: Image source (path to file, numpy array, or PIL Image).
        detections: Detection results dictionary with 'boxes', 'scores', 'labels'.
        figsize: Figure size as (width, height).
        title: Optional title for the plot.
        box_color: Color for bounding boxes (can be single color or list).
        text_color: Color for label text.
        linewidth: Width of bounding box lines.
        fontsize: Font size for labels.
        show_scores: Whether to show confidence scores.
        axis_off: Whether to hide axes.
        **kwargs: Additional arguments passed to plt.imshow().

    Returns:
        Matplotlib figure object.

    Example:
        >>> result = geoai.auto.object_detection("aerial.tif", labels=["building", "tree"])
        >>> fig = show_detections("aerial.tif", result)
    """
    import matplotlib.pyplot as plt
    import matplotlib.patches as patches

    img, _ = _load_image_for_display(source)

    fig, ax = plt.subplots(figsize=figsize)
    ax.imshow(img, **kwargs)

    boxes = detections.get("boxes", [])
    scores = detections.get("scores", [None] * len(boxes))
    labels = detections.get("labels", [None] * len(boxes))

    # Handle color list
    if isinstance(box_color, str):
        colors = [box_color] * len(boxes)
    else:
        colors = box_color

    for i, (bbox, score, label) in enumerate(zip(boxes, scores, labels)):
        x1, y1, x2, y2 = bbox
        width = x2 - x1
        height = y2 - y1

        color = colors[i % len(colors)]

        rect = patches.Rectangle(
            (x1, y1),
            width,
            height,
            linewidth=linewidth,
            edgecolor=color,
            facecolor="none",
        )
        ax.add_patch(rect)

        # Add label
        if label is not None:
            text = str(label)
            if show_scores and score is not None:
                text = f"{label}: {score:.2f}"
            ax.text(
                x1,
                y1 - 5,
                text,
                color=text_color,
                fontsize=fontsize,
                bbox=dict(boxstyle="round,pad=0.3", facecolor=color, alpha=0.7),
            )

    if title:
        ax.set_title(title)
    if axis_off:
        ax.axis("off")

    plt.tight_layout()
    return fig

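# Illustrative sketch (not part of the packaged auto.py): passing a list to box_color,
# which is cycled across detections (colors[i % len(colors)] above), then saving the figure.
#
# >>> result = geoai.auto.object_detection("aerial.tif", labels=["building", "tree"])
# >>> fig = show_detections(
# ...     "aerial.tif",
# ...     result,
# ...     box_color=["red", "lime"],
# ...     show_scores=False,
# ... )
# >>> fig.savefig("detections.png", dpi=150)
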
def show_segmentation(
    source: Union[str, np.ndarray, Image.Image],
    mask: np.ndarray,
    figsize: Tuple[int, int] = (14, 6),
    title: Optional[str] = None,
    alpha: float = 0.5,
    cmap: str = "tab20",
    show_original: bool = True,
    axis_off: bool = True,
    **kwargs: Any,
) -> "plt.Figure":
    """Display segmentation results overlaid on the original image.

    Args:
        source: Image source (path to file, numpy array, or PIL Image).
        mask: Segmentation mask array.
        figsize: Figure size as (width, height).
        title: Optional title for the plot.
        alpha: Transparency of the mask overlay.
        cmap: Colormap for the segmentation mask.
        show_original: Whether to show original image side-by-side.
        axis_off: Whether to hide axes.
        **kwargs: Additional arguments passed to plt.imshow().

    Returns:
        Matplotlib figure object.

    Example:
        >>> result = geoai.auto.semantic_segmentation("aerial.tif", output_path="seg.tif")
        >>> fig = show_segmentation("aerial.tif", result["mask"])
    """
    import matplotlib.pyplot as plt

    img, _ = _load_image_for_display(source)

    # Resize mask if necessary
    if mask.shape[:2] != img.shape[:2]:
        mask = cv2.resize(
            mask.astype(np.float32),
            (img.shape[1], img.shape[0]),
            interpolation=cv2.INTER_NEAREST,
        ).astype(mask.dtype)

    if show_original:
        fig, axes = plt.subplots(1, 2, figsize=figsize)

        axes[0].imshow(img, **kwargs)
        axes[0].set_title("Original Image")
        if axis_off:
            axes[0].axis("off")

        axes[1].imshow(img, **kwargs)
        axes[1].imshow(mask, alpha=alpha, cmap=cmap)
        axes[1].set_title(title or "Segmentation Overlay")
        if axis_off:
            axes[1].axis("off")
    else:
        fig, ax = plt.subplots(figsize=figsize)
        ax.imshow(img, **kwargs)
        ax.imshow(mask, alpha=alpha, cmap=cmap)
        if title:
            ax.set_title(title)
        if axis_off:
            ax.axis("off")

    plt.tight_layout()
    return fig

def show_depth(
    source: Union[str, np.ndarray, Image.Image],
    depth: np.ndarray,
    figsize: Tuple[int, int] = (14, 6),
    title: Optional[str] = None,
    cmap: str = "plasma",
    show_original: bool = True,
    show_colorbar: bool = True,
    axis_off: bool = True,
    **kwargs: Any,
) -> "plt.Figure":
    """Display depth estimation results.

    Args:
        source: Image source (path to file, numpy array, or PIL Image).
        depth: Depth map array.
        figsize: Figure size as (width, height).
        title: Optional title for the plot.
        cmap: Colormap for the depth map.
        show_original: Whether to show original image side-by-side.
        show_colorbar: Whether to show a colorbar.
        axis_off: Whether to hide axes.
        **kwargs: Additional arguments passed to plt.imshow().

    Returns:
        Matplotlib figure object.

    Example:
        >>> result = geoai.auto.depth_estimation("aerial.tif", output_path="depth.tif")
        >>> fig = show_depth("aerial.tif", result["depth"])
    """
    import matplotlib.pyplot as plt

    img, _ = _load_image_for_display(source)

    # Resize depth if necessary
    if depth.shape[:2] != img.shape[:2]:
        depth = cv2.resize(
            depth.astype(np.float32),
            (img.shape[1], img.shape[0]),
            interpolation=cv2.INTER_LINEAR,
        )

    if show_original:
        fig, axes = plt.subplots(1, 2, figsize=figsize)

        axes[0].imshow(img, **kwargs)
        axes[0].set_title("Original Image")
        if axis_off:
            axes[0].axis("off")

        im = axes[1].imshow(depth, cmap=cmap)
        axes[1].set_title(title or "Depth Estimation")
        if axis_off:
            axes[1].axis("off")
        if show_colorbar:
            plt.colorbar(im, ax=axes[1], label="Relative Depth")
    else:
        fig, ax = plt.subplots(figsize=figsize)
        im = ax.imshow(depth, cmap=cmap)
        if title:
            ax.set_title(title)
        if axis_off:
            ax.axis("off")
        if show_colorbar:
            plt.colorbar(im, ax=ax, label="Relative Depth")

    plt.tight_layout()
    return fig