geoai-py 0.26.0__py2.py3-none-any.whl → 0.28.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- geoai/__init__.py +41 -1
- geoai/auto.py +4 -1
- geoai/change_detection.py +1 -1
- geoai/detectron2.py +4 -1
- geoai/extract.py +10 -7
- geoai/hf.py +3 -3
- geoai/moondream.py +2 -2
- geoai/onnx.py +1155 -0
- geoai/prithvi.py +92 -7
- geoai/sam.py +2 -1
- geoai/segment.py +10 -1
- geoai/timm_regress.py +1652 -0
- geoai/train.py +1 -1
- geoai/utils.py +550 -1
- {geoai_py-0.26.0.dist-info → geoai_py-0.28.0.dist-info}/METADATA +9 -7
- {geoai_py-0.26.0.dist-info → geoai_py-0.28.0.dist-info}/RECORD +20 -18
- {geoai_py-0.26.0.dist-info → geoai_py-0.28.0.dist-info}/WHEEL +1 -1
- {geoai_py-0.26.0.dist-info → geoai_py-0.28.0.dist-info}/entry_points.txt +0 -0
- {geoai_py-0.26.0.dist-info → geoai_py-0.28.0.dist-info}/licenses/LICENSE +0 -0
- {geoai_py-0.26.0.dist-info → geoai_py-0.28.0.dist-info}/top_level.txt +0 -0
geoai/onnx.py
ADDED
|
@@ -0,0 +1,1155 @@
|
|
|
1
|
+
"""ONNX Runtime support for geospatial model inference.
|
|
2
|
+
|
|
3
|
+
This module provides ONNXGeoModel for loading and running inference with
|
|
4
|
+
ONNX models on geospatial data (GeoTIFF), and export_to_onnx for converting
|
|
5
|
+
existing PyTorch/Hugging Face models to ONNX format.
|
|
6
|
+
|
|
7
|
+
Supported tasks:
|
|
8
|
+
- Semantic segmentation (e.g., SegFormer, Mask2Former)
|
|
9
|
+
- Image classification (e.g., ViT, ResNet)
|
|
10
|
+
- Object detection (e.g., DETR, YOLOS)
|
|
11
|
+
- Depth estimation (e.g., Depth Anything, DPT)
|
|
12
|
+
|
|
13
|
+
Requirements:
|
|
14
|
+
- onnx
|
|
15
|
+
- onnxruntime (or onnxruntime-gpu for GPU acceleration)
|
|
16
|
+
|
|
17
|
+
Install with::
|
|
18
|
+
|
|
19
|
+
pip install geoai-py[onnx]
|
|
20
|
+
|
|
21
|
+
Example:
|
|
22
|
+
>>> from geoai import export_to_onnx, ONNXGeoModel
|
|
23
|
+
>>> # Export a HuggingFace model to ONNX
|
|
24
|
+
>>> export_to_onnx(
|
|
25
|
+
... "nvidia/segformer-b0-finetuned-ade-512-512",
|
|
26
|
+
... "segformer.onnx",
|
|
27
|
+
... task="semantic-segmentation",
|
|
28
|
+
... )
|
|
29
|
+
>>> # Load and run inference with the ONNX model
|
|
30
|
+
>>> model = ONNXGeoModel("segformer.onnx", task="semantic-segmentation")
|
|
31
|
+
>>> result = model.predict("input.tif", output_path="output.tif")
|
|
32
|
+
"""
|
|
33
|
+
|
|
34
|
+
import json
|
|
35
|
+
import os
|
|
36
|
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
37
|
+
|
|
38
|
+
import geopandas as gpd
|
|
39
|
+
import numpy as np
|
|
40
|
+
import rasterio
|
|
41
|
+
from PIL import Image
|
|
42
|
+
from rasterio.features import shapes
|
|
43
|
+
from rasterio.windows import Window
|
|
44
|
+
from shapely.geometry import shape
|
|
45
|
+
from tqdm import tqdm
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def _check_onnx_deps() -> None:
|
|
49
|
+
"""Check that onnx and onnxruntime are installed.
|
|
50
|
+
|
|
51
|
+
Raises:
|
|
52
|
+
ImportError: If onnx or onnxruntime is not installed.
|
|
53
|
+
"""
|
|
54
|
+
try:
|
|
55
|
+
import onnx # noqa: F401
|
|
56
|
+
except ImportError:
|
|
57
|
+
raise ImportError(
|
|
58
|
+
"The 'onnx' package is required for ONNX support. "
|
|
59
|
+
"Install it with: pip install geoai-py[onnx]"
|
|
60
|
+
)
|
|
61
|
+
|
|
62
|
+
try:
|
|
63
|
+
import onnxruntime # noqa: F401
|
|
64
|
+
except ImportError:
|
|
65
|
+
raise ImportError(
|
|
66
|
+
"The 'onnxruntime' package is required for ONNX support. "
|
|
67
|
+
"Install it with: pip install geoai-py[onnx] "
|
|
68
|
+
"(use 'onnxruntime-gpu' for GPU acceleration)"
|
|
69
|
+
)
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def _check_torch_deps() -> None:
|
|
73
|
+
"""Check that torch and transformers are installed (needed for export).
|
|
74
|
+
|
|
75
|
+
Raises:
|
|
76
|
+
ImportError: If torch or transformers is not installed.
|
|
77
|
+
"""
|
|
78
|
+
try:
|
|
79
|
+
import torch # noqa: F401
|
|
80
|
+
except ImportError:
|
|
81
|
+
raise ImportError(
|
|
82
|
+
"PyTorch is required for exporting models to ONNX. "
|
|
83
|
+
"Install it from https://pytorch.org/"
|
|
84
|
+
)
|
|
85
|
+
|
|
86
|
+
try:
|
|
87
|
+
import transformers # noqa: F401
|
|
88
|
+
except ImportError:
|
|
89
|
+
raise ImportError(
|
|
90
|
+
"The 'transformers' package is required for exporting "
|
|
91
|
+
"Hugging Face models to ONNX. "
|
|
92
|
+
"Install it with: pip install transformers"
|
|
93
|
+
)
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
# ---------------------------------------------------------------------------
|
|
97
|
+
# Export helpers
|
|
98
|
+
# ---------------------------------------------------------------------------
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def export_to_onnx(
    model_name_or_path: str,
    output_path: str,
    task: Optional[str] = None,
    input_height: int = 512,
    input_width: int = 512,
    input_channels: int = 3,
    opset_version: int = 17,
    dynamic_axes: Optional[Dict[str, Dict[int, str]]] = None,
    simplify: bool = True,
    device: Optional[str] = None,
    **kwargs: Any,
) -> str:
    """Export a PyTorch / Hugging Face model to ONNX format.

    Besides the ``.onnx`` file, a sidecar ``<output_path>.json`` file is
    written containing the task, dummy-input size, opset version, output
    names and, when the model config provides them, ``id2label`` and
    ``num_labels``.

    Args:
        model_name_or_path: Hugging Face model name or local checkpoint path.
        output_path: Path where the ``.onnx`` file will be saved. Missing
            parent directories are created.
        task: Model task. One of ``"semantic-segmentation"`` (alias
            ``"segmentation"``), ``"image-classification"`` (alias
            ``"classification"``), ``"object-detection"``, or
            ``"depth-estimation"``. If *None* (or empty) the task is
            inferred from the ``architectures`` entry of the model config.
        input_height: Height of the dummy input tensor (pixels). Overridden
            by the image processor's declared ``size`` when a processor can
            be loaded for the model.
        input_width: Width of the dummy input tensor (pixels). Overridden
            in the same way as ``input_height``.
        input_channels: Number of input channels (default 3 for RGB).
        opset_version: ONNX opset version (default 17).
        dynamic_axes: Optional mapping of dynamic axes for variable-size
            inputs/outputs. When *None*, the input's batch, height and width
            axes and every output's batch axis are made dynamic.
        simplify: Whether to simplify the exported graph with
            ``onnxsim.simplify``. Skipped when the optional ``onnxsim``
            package is not installed; a warning is emitted if
            simplification fails, and the unsimplified model is kept.
        device: Device used for tracing. Defaults to ``"cpu"``.
        **kwargs: Extra keyword arguments forwarded to the task model's
            ``from_pretrained``.

    Returns:
        Absolute path to the saved ONNX file.

    Raises:
        ImportError: If required packages are missing.
        ValueError: If *task* is not a supported task, or if it is not given
            and cannot be inferred from the model configuration.

    Example:
        >>> export_to_onnx(  # doctest: +SKIP
        ...     "nvidia/segformer-b0-finetuned-ade-512-512",
        ...     "segformer.onnx",
        ...     task="semantic-segmentation",
        ... )
        '/absolute/path/to/segformer.onnx'
    """
    _check_torch_deps()
    import warnings

    import onnx  # noqa: F811
    import torch
    from transformers import (
        AutoConfig,
        AutoImageProcessor,
        AutoModelForDepthEstimation,
        AutoModelForImageClassification,
        AutoModelForObjectDetection,
        AutoModelForSemanticSegmentation,
    )

    if device is None:
        device = "cpu"

    # ------------------------------------------------------------------
    # Resolve task and load model
    # ------------------------------------------------------------------
    task_model_map = {
        "segmentation": AutoModelForSemanticSegmentation,
        "semantic-segmentation": AutoModelForSemanticSegmentation,
        "classification": AutoModelForImageClassification,
        "image-classification": AutoModelForImageClassification,
        "object-detection": AutoModelForObjectDetection,
        "depth-estimation": AutoModelForDepthEstimation,
    }

    if task:
        # Reject unknown tasks up front instead of recording a task name in
        # the sidecar metadata that no consumer understands.
        if task not in task_model_map:
            raise ValueError(
                f"Unsupported task {task!r}. "
                f"Expected one of: {', '.join(sorted(task_model_map))}."
            )
        model_cls = task_model_map[task]
    else:
        # Guard only the config download; the more specific
        # "Found architectures" error below must reach the caller intact.
        try:
            config = AutoConfig.from_pretrained(model_name_or_path)
        except Exception as exc:
            raise ValueError(
                "Cannot determine the model task. Please specify task= explicitly."
            ) from exc

        # ``architectures`` may be missing or explicitly None in HF configs.
        architectures = getattr(config, "architectures", None) or []
        # Checked in order; the first keyword found in any architecture wins.
        inference_rules = (
            (
                "Segmentation",
                AutoModelForSemanticSegmentation,
                "semantic-segmentation",
            ),
            (
                "Classification",
                AutoModelForImageClassification,
                "image-classification",
            ),
            ("Detection", AutoModelForObjectDetection, "object-detection"),
            ("Depth", AutoModelForDepthEstimation, "depth-estimation"),
        )
        for keyword, cls, inferred_task in inference_rules:
            if any(keyword in arch for arch in architectures):
                model_cls, task = cls, inferred_task
                break
        else:
            raise ValueError(
                f"Cannot infer task from model config. "
                f"Found architectures: {architectures}. "
                f"Please specify task= explicitly."
            )

    model = model_cls.from_pretrained(model_name_or_path, **kwargs)
    model = model.to(device).eval()

    # Use the image processor's declared size for the dummy input when one
    # can be loaded; processor introspection is optional.
    try:
        processor = AutoImageProcessor.from_pretrained(model_name_or_path)
    except Exception:
        processor = None
    size = getattr(processor, "size", None)
    if isinstance(size, dict):
        input_height = size.get("height", input_height)
        input_width = size.get("width", input_width)
    elif isinstance(size, (list, tuple)) and len(size) == 2:
        input_height, input_width = size
    elif isinstance(size, int) and not isinstance(size, bool):
        # Older processors declare a single square size.
        input_height = input_width = size

    # ------------------------------------------------------------------
    # Build dummy input & dynamic axes
    # ------------------------------------------------------------------
    dummy_input = torch.randn(
        1, input_channels, input_height, input_width, device=device
    )

    input_names = ["pixel_values"]
    output_names = (
        ["logits", "pred_boxes"] if task == "object-detection" else ["logits"]
    )

    if dynamic_axes is None:
        dynamic_axes = {
            "pixel_values": {0: "batch", 2: "height", 3: "width"},
        }
        for name in output_names:
            dynamic_axes[name] = {0: "batch"}

    # ------------------------------------------------------------------
    # Export
    # ------------------------------------------------------------------
    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)

    torch.onnx.export(
        model,
        ({"pixel_values": dummy_input},),
        output_path,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes,
        opset_version=opset_version,
        do_constant_folding=True,
    )

    # Validate
    onnx_model = onnx.load(output_path)
    onnx.checker.check_model(onnx_model)

    # Optional simplification (best effort)
    if simplify:
        try:
            import onnxsim
        except ImportError:
            onnxsim = None  # onnxsim is optional
        if onnxsim is not None:
            try:
                onnx_model_simplified, check = onnxsim.simplify(onnx_model)
            except Exception as exc:
                warnings.warn(
                    f"ONNX simplification failed ({exc}); "
                    "keeping the unsimplified model."
                )
            else:
                if check:
                    onnx.save(onnx_model_simplified, output_path)

    # ------------------------------------------------------------------
    # Save metadata alongside the model
    # ------------------------------------------------------------------
    meta = {
        "model_name": model_name_or_path,
        "task": task,
        "input_height": input_height,
        "input_width": input_width,
        "input_channels": input_channels,
        "opset_version": opset_version,
        "output_names": output_names,
    }

    # Include label information when the model config provides it
    config = getattr(model, "config", None)
    id2label = getattr(config, "id2label", None)
    if id2label:
        meta["id2label"] = {str(k): v for k, v in id2label.items()}
    if config is not None and hasattr(config, "num_labels"):
        meta["num_labels"] = config.num_labels

    meta_path = output_path + ".json"
    with open(meta_path, "w", encoding="utf-8") as fh:
        json.dump(meta, fh, indent=2)

    print(f"ONNX model exported to {output_path}")
    print(f"Metadata saved to {meta_path}")
    return os.path.abspath(output_path)
|
|
309
|
+
|
|
310
|
+
|
|
311
|
+
# ---------------------------------------------------------------------------
|
|
312
|
+
# ONNXGeoModel
|
|
313
|
+
# ---------------------------------------------------------------------------
|
|
314
|
+
|
|
315
|
+
|
|
316
|
+
class ONNXGeoModel:
|
|
317
|
+
"""ONNX Runtime model for geospatial inference with GeoTIFF support.
|
|
318
|
+
|
|
319
|
+
This class mirrors the :class:`~geoai.auto.AutoGeoModel` API but uses
|
|
320
|
+
ONNX Runtime instead of PyTorch for inference, enabling deployment on
|
|
321
|
+
edge devices and environments without GPU drivers.
|
|
322
|
+
|
|
323
|
+
Attributes:
|
|
324
|
+
session: The ``onnxruntime.InferenceSession`` instance.
|
|
325
|
+
task (str): The model task (e.g. ``"semantic-segmentation"``).
|
|
326
|
+
tile_size (int): Tile size used for processing large images.
|
|
327
|
+
overlap (int): Overlap between adjacent tiles.
|
|
328
|
+
metadata (dict): Model metadata loaded from the sidecar JSON file.
|
|
329
|
+
|
|
330
|
+
Example:
|
|
331
|
+
>>> model = ONNXGeoModel("segformer.onnx", task="semantic-segmentation")
|
|
332
|
+
>>> result = model.predict("input.tif", output_path="output.tif")
|
|
333
|
+
"""
|
|
334
|
+
|
|
335
|
+
def __init__(
    self,
    model_path: str,
    task: Optional[str] = None,
    providers: Optional[List[str]] = None,
    tile_size: int = 1024,
    overlap: int = 128,
    metadata: Optional[Dict[str, Any]] = None,
) -> None:
    """Load an ONNX model for geospatial inference.

    Args:
        model_path: Path to the ``.onnx`` model file.
        task: Model task. One of ``"semantic-segmentation"``,
            ``"image-classification"``, ``"object-detection"``, or
            ``"depth-estimation"``. If *None*, the task is read from the
            sidecar ``<model>.onnx.json`` metadata file.
        providers: ONNX Runtime execution providers in priority order.
            Defaults to every provider reported by
            ``onnxruntime.get_available_providers()``.
        tile_size: Tile size for processing large images.
        overlap: Overlap between adjacent tiles (in pixels). Must be
            non-negative and satisfy ``tile_size > 2 * overlap`` so that
            tiling advances by a positive stride.
        metadata: Optional pre-loaded metadata dict. When *None* the
            constructor looks for ``<model_path>.json``.

    Raises:
        FileNotFoundError: If *model_path* does not exist.
        ImportError: If onnx or onnxruntime is not installed.
        ValueError: If *overlap* is negative or ``tile_size <= 2 * overlap``.
    """
    _check_onnx_deps()
    import onnxruntime as ort

    if not os.path.isfile(model_path):
        raise FileNotFoundError(f"ONNX model not found: {model_path}")

    # Fail fast on tiling parameters that would yield a non-positive stride
    # (tile_size - 2 * overlap) or gaps between tiles.
    if overlap < 0:
        raise ValueError(f"overlap must be non-negative, got {overlap}")
    if tile_size <= 2 * overlap:
        raise ValueError(
            f"tile_size ({tile_size}) must be greater than 2 * overlap "
            f"({2 * overlap})"
        )

    self.model_path = os.path.abspath(model_path)
    self.tile_size = tile_size
    self.overlap = overlap

    # Sidecar metadata, as written by export_to_onnx to "<model>.onnx.json"
    if metadata is not None:
        self.metadata = metadata
    else:
        meta_path = model_path + ".json"
        if os.path.isfile(meta_path):
            with open(meta_path, encoding="utf-8") as fh:
                self.metadata = json.load(fh)
        else:
            self.metadata = {}

    # Resolve task: explicit argument wins over metadata
    self.task = task or self.metadata.get("task")

    # Label mapping; JSON object keys are strings, so convert back to int
    raw_labels = self.metadata.get("id2label") or {}
    self.id2label: Dict[int, str] = {int(k): v for k, v in raw_labels.items()}

    # Create session
    if providers is None:
        providers = ort.get_available_providers()
    self.session = ort.InferenceSession(model_path, providers=providers)

    # Inspect inputs / outputs
    model_input = self.session.get_inputs()[0]
    self.input_name = model_input.name
    self.input_shape = model_input.shape  # may contain symbolic (str) dims
    self.output_names = [o.name for o in self.session.get_outputs()]

    # Expected spatial size: metadata first, then a static NCHW input shape
    self._model_height = self.metadata.get("input_height")
    self._model_width = self.metadata.get("input_width")
    if (
        self._model_height is None
        and isinstance(self.input_shape, list)
        and len(self.input_shape) == 4
    ):
        h, w = self.input_shape[2], self.input_shape[3]
        if isinstance(h, int) and isinstance(w, int):
            self._model_height = h
            self._model_width = w

    active = self.session.get_providers()
    print(f"ONNX model loaded from {model_path}")
    print(f"Execution providers: {active}")
    if self.task:
        print(f"Task: {self.task}")
|
|
419
|
+
|
|
420
|
+
# ------------------------------------------------------------------
|
|
421
|
+
# Image I/O helpers (mirrors AutoGeoImageProcessor)
|
|
422
|
+
# ------------------------------------------------------------------
|
|
423
|
+
|
|
424
|
+
@staticmethod
def load_geotiff(
    source: Union[str, "rasterio.DatasetReader"],
    window: Optional[Window] = None,
    bands: Optional[List[int]] = None,
) -> Tuple[np.ndarray, Dict]:
    """Read pixels and georeferencing information from a GeoTIFF.

    Args:
        source: Filesystem path, or an already-open rasterio dataset. A path
            is opened here and closed before returning; a dataset passed in
            is left open for the caller.
        window: Optional rasterio Window restricting the read to a subset.
        bands: Optional 1-indexed band numbers to read; all bands are read
            when this is falsy.

    Returns:
        ``(data, metadata)`` where ``data`` is the array returned by
        ``src.read`` (bands first) and ``metadata`` holds ``profile``,
        ``crs``, ``transform``, ``bounds``, ``width``, ``height`` and
        ``count``. With a window, the profile's size and transform describe
        the window rather than the full raster.
    """
    owns_dataset = isinstance(source, str)
    src = rasterio.open(source) if owns_dataset else source

    try:
        band_selection = (bands,) if bands else ()
        data = src.read(*band_selection, window=window)

        profile = src.profile.copy()
        if window:
            profile.update(
                height=window.height,
                width=window.width,
                transform=src.window_transform(window),
            )
            extent = rasterio.windows.bounds(window, src.transform)
        else:
            extent = src.bounds

        metadata = dict(
            profile=profile,
            crs=src.crs,
            transform=profile["transform"],
            bounds=extent,
            width=profile["width"],
            height=profile["height"],
            count=data.shape[0],
        )
    finally:
        # Only close datasets this method opened itself
        if owns_dataset:
            src.close()

    return data, metadata
|
|
481
|
+
|
|
482
|
+
@staticmethod
def load_image(
    source: Union[str, np.ndarray, "Image.Image"],
    window: Optional[Window] = None,
    bands: Optional[List[int]] = None,
) -> Tuple[np.ndarray, Optional[Dict]]:
    """Load an image from a path, numpy array, or PIL Image.

    String paths that rasterio can open and that either carry a CRS or end
    in ``.tif``/``.tiff`` are read with :meth:`load_geotiff` (returning
    georeferencing metadata). Other paths are decoded with PIL and converted
    to RGB.

    Args:
        source: Path to image file, numpy array, or PIL Image.
        window: Optional rasterio Window (only used for GeoTIFF paths).
        bands: List of band indices (only for GeoTIFF paths, 1-indexed).

    Returns:
        Tuple of (image array in CHW format, metadata dict or None). For
        numpy input, 2-D arrays gain a leading band axis and 3-D arrays whose
        last axis has 1, 3 or 4 entries are treated as HWC and transposed;
        other arrays are returned unchanged.

    Raises:
        TypeError: If *source* is not a str, numpy array, or PIL Image.
    """
    if isinstance(source, str):
        # Probe with rasterio only to decide how to read; the probe handle is
        # closed before the real read so the file is not opened twice at once.
        try:
            with rasterio.open(source) as src:
                is_geotiff = src.crs is not None or source.lower().endswith(
                    (".tif", ".tiff")
                )
        except rasterio.errors.RasterioIOError:
            is_geotiff = False  # not a rasterio-compatible file; use PIL

        if is_geotiff:
            # Outside the probe's try so genuine read errors are not masked
            # by a silent fall-through to PIL.
            return ONNXGeoModel.load_geotiff(source, window, bands)

        # Context manager releases the file handle immediately.
        with Image.open(source) as image:
            data = np.array(image.convert("RGB")).transpose(2, 0, 1)
        return data, None

    elif isinstance(source, np.ndarray):
        if source.ndim == 2:
            source = source[np.newaxis, :, :]
        elif source.ndim == 3 and source.shape[2] in [1, 3, 4]:
            # NOTE(review): assumes a trailing axis of size 1/3/4 means HWC;
            # a CHW array with width 1, 3 or 4 would be mis-transposed.
            source = source.transpose(2, 0, 1)
        return source, None

    elif isinstance(source, Image.Image):
        data = np.array(source.convert("RGB")).transpose(2, 0, 1)
        return data, None

    else:
        raise TypeError(f"Unsupported source type: {type(source)}")
|
|
525
|
+
|
|
526
|
+
# ------------------------------------------------------------------
|
|
527
|
+
# Preprocessing
|
|
528
|
+
# ------------------------------------------------------------------
|
|
529
|
+
|
|
530
|
+
def _prepare_input(
|
|
531
|
+
self,
|
|
532
|
+
data: np.ndarray,
|
|
533
|
+
target_height: Optional[int] = None,
|
|
534
|
+
target_width: Optional[int] = None,
|
|
535
|
+
) -> np.ndarray:
|
|
536
|
+
"""Prepare a CHW uint‑capable array for the ONNX model.
|
|
537
|
+
|
|
538
|
+
The method converts to 3‑channel RGB, normalizes to ``[0, 1]``
|
|
539
|
+
float32, resizes to the model's expected spatial dimensions and
|
|
540
|
+
adds a batch dimension.
|
|
541
|
+
|
|
542
|
+
Args:
|
|
543
|
+
data: Image array in CHW format.
|
|
544
|
+
target_height: Target height. Defaults to model metadata.
|
|
545
|
+
target_width: Target width. Defaults to model metadata.
|
|
546
|
+
|
|
547
|
+
Returns:
|
|
548
|
+
Numpy array of shape ``(1, 3, H, W)`` ready for the ONNX
|
|
549
|
+
session.
|
|
550
|
+
"""
|
|
551
|
+
# Lazy import to avoid QGIS opencv conflicts
|
|
552
|
+
import cv2
|
|
553
|
+
|
|
554
|
+
# CHW → HWC
|
|
555
|
+
if data.ndim == 3:
|
|
556
|
+
img = data.transpose(1, 2, 0)
|
|
557
|
+
else:
|
|
558
|
+
img = data
|
|
559
|
+
|
|
560
|
+
# Ensure 3 channels
|
|
561
|
+
if img.ndim == 2:
|
|
562
|
+
img = np.stack([img] * 3, axis=-1)
|
|
563
|
+
elif img.shape[-1] == 1:
|
|
564
|
+
img = np.repeat(img, 3, axis=-1)
|
|
565
|
+
elif img.shape[-1] > 3:
|
|
566
|
+
img = img[..., :3]
|
|
567
|
+
|
|
568
|
+
# Percentile normalization → uint8
|
|
569
|
+
if img.dtype != np.uint8:
|
|
570
|
+
for i in range(img.shape[-1]):
|
|
571
|
+
band = img[..., i].astype(np.float32)
|
|
572
|
+
p2, p98 = np.percentile(band, [2, 98])
|
|
573
|
+
if p98 > p2:
|
|
574
|
+
img[..., i] = np.clip((band - p2) / (p98 - p2), 0, 1)
|
|
575
|
+
else:
|
|
576
|
+
img[..., i] = 0
|
|
577
|
+
img = (img * 255).astype(np.uint8)
|
|
578
|
+
|
|
579
|
+
# Resize to model expected size if needed
|
|
580
|
+
th = target_height or self._model_height
|
|
581
|
+
tw = target_width or self._model_width
|
|
582
|
+
if th and tw and (img.shape[0] != th or img.shape[1] != tw):
|
|
583
|
+
img = cv2.resize(img, (tw, th), interpolation=cv2.INTER_LINEAR)
|
|
584
|
+
|
|
585
|
+
# Normalize to float32 [0, 1]
|
|
586
|
+
img = img.astype(np.float32) / 255.0
|
|
587
|
+
|
|
588
|
+
# HWC → NCHW
|
|
589
|
+
tensor = img.transpose(2, 0, 1)[np.newaxis, ...]
|
|
590
|
+
return tensor
|
|
591
|
+
|
|
592
|
+
# ------------------------------------------------------------------
|
|
593
|
+
# Prediction
|
|
594
|
+
# ------------------------------------------------------------------
|
|
595
|
+
|
|
596
|
+
def predict(
    self,
    source: Union[str, np.ndarray, "Image.Image"],
    output_path: Optional[str] = None,
    output_vector_path: Optional[str] = None,
    window: Optional[Window] = None,
    bands: Optional[List[int]] = None,
    threshold: float = 0.5,
    box_threshold: float = 0.3,
    min_object_area: int = 100,
    simplify_tolerance: float = 1.0,
    batch_size: int = 1,
    return_probabilities: bool = False,
    **kwargs: Any,
) -> Dict[str, Any]:
    """Run inference on a GeoTIFF or image.

    This method follows the same interface as
    :meth:`~geoai.auto.AutoGeoModel.predict`. Images larger than
    ``self.tile_size`` in either dimension are processed tile by tile,
    except for classification tasks.

    Args:
        source: Input image path, ``http(s)://`` URL, numpy array, or PIL
            Image.
        output_path: Path to save output GeoTIFF (segmentation / depth).
            Only written when the input carries georeferencing metadata.
        output_vector_path: Path to save vectorised output. Only written
            when the input is georeferenced and the result has a ``mask``.
        window: Optional rasterio Window for reading a subset.
        bands: Band indices to read (1-indexed).
        threshold: Threshold for binary masks (segmentation).
        box_threshold: Confidence threshold for detections. NOTE(review):
            not currently forwarded to the prediction helpers.
        min_object_area: Minimum polygon area in pixels for
            vectorization.
        simplify_tolerance: Tolerance for polygon simplification.
        batch_size: Batch size for tiled processing (reserved for
            future use).
        return_probabilities: Whether to return probability maps.
        **kwargs: Extra keyword arguments. ``timeout`` (seconds, default
            60) bounds URL downloads; other keys are ignored.

    Returns:
        Dictionary with results (``mask``, ``class``, ``boxes`` etc.)
        depending on the task, plus ``metadata``; ``output_path``,
        ``vector_path`` and ``geodataframe`` are added when files are
        written.

    Raises:
        requests.HTTPError: If a URL source returns an error status.

    Example:
        >>> model = ONNXGeoModel("segformer.onnx",
        ...                      task="semantic-segmentation")
        >>> result = model.predict("input.tif", output_path="output.tif")
    """
    import warnings

    # Handle URL sources
    if isinstance(source, str) and source.startswith(("http://", "https://")):
        import io

        import requests

        # Bounded wait and explicit status check instead of decoding an
        # error page; the response body is fully read, then released.
        response = requests.get(source, timeout=kwargs.get("timeout", 60))
        response.raise_for_status()
        with Image.open(io.BytesIO(response.content)) as pil_image:
            data = np.array(pil_image.convert("RGB")).transpose(2, 0, 1)
        metadata = None
    else:
        data, metadata = self.load_image(source, window, bands)

    # Spatial size (2-D arrays are a single band)
    height, width = data.shape[-2:]

    # Classification never uses tiled processing
    use_tiled = (
        height > self.tile_size or width > self.tile_size
    ) and self.task not in ("classification", "image-classification")

    if use_tiled:
        result = self._predict_tiled(
            data,
            metadata,
            threshold=threshold,
            return_probabilities=return_probabilities,
        )
    else:
        result = self._predict_single(
            data,
            metadata,
            threshold=threshold,
            return_probabilities=return_probabilities,
        )

    if (output_path or output_vector_path) and not metadata:
        warnings.warn(
            "output_path / output_vector_path were given but the input has no "
            "georeferencing metadata; no output files were written."
        )

    # Save GeoTIFF
    if output_path and metadata:
        out_data = result.get("mask", result.get("output"))
        if out_data is not None:
            self._save_geotiff(out_data, output_path, metadata, nodata=0)
            result["output_path"] = output_path

    # Save vector
    if output_vector_path and metadata and "mask" in result:
        gdf = self.mask_to_vector(
            result["mask"],
            metadata,
            threshold=threshold,
            min_object_area=min_object_area,
            simplify_tolerance=simplify_tolerance,
        )
        if gdf is not None and len(gdf) > 0:
            gdf.to_file(output_vector_path)
            result["vector_path"] = output_vector_path
            result["geodataframe"] = gdf

    return result
|
|
699
|
+
|
|
700
|
+
# ------------------------------------------------------------------
|
|
701
|
+
# Internal prediction helpers
|
|
702
|
+
# ------------------------------------------------------------------
|
|
703
|
+
|
|
704
|
+
def _predict_single(
|
|
705
|
+
self,
|
|
706
|
+
data: np.ndarray,
|
|
707
|
+
metadata: Optional[Dict],
|
|
708
|
+
threshold: float = 0.5,
|
|
709
|
+
return_probabilities: bool = False,
|
|
710
|
+
) -> Dict[str, Any]:
|
|
711
|
+
"""Run inference on a single (non-tiled) image."""
|
|
712
|
+
# Lazy import to avoid QGIS opencv conflicts
|
|
713
|
+
import cv2
|
|
714
|
+
|
|
715
|
+
original_h = data.shape[1] if data.ndim == 3 else data.shape[0]
|
|
716
|
+
original_w = data.shape[2] if data.ndim == 3 else data.shape[1]
|
|
717
|
+
|
|
718
|
+
input_tensor = self._prepare_input(data)
|
|
719
|
+
outputs = self.session.run(self.output_names, {self.input_name: input_tensor})
|
|
720
|
+
|
|
721
|
+
result = self._process_outputs(
|
|
722
|
+
outputs, (original_h, original_w), threshold, return_probabilities
|
|
723
|
+
)
|
|
724
|
+
result["metadata"] = metadata
|
|
725
|
+
return result
|
|
726
|
+
|
|
727
|
+
def _predict_tiled(
    self,
    data: np.ndarray,
    metadata: Optional[Dict],
    threshold: float = 0.5,
    return_probabilities: bool = False,
) -> Dict[str, Any]:
    """Run tiled (sliding-window) inference for large images.

    Window origins are ``self.tile_size - 2 * self.overlap`` pixels apart,
    and each window extends ``self.overlap`` pixels past its grid cell on
    every side (clipped to the image). Each window goes through
    ``self._predict_single``. Its ``"mask"`` (or ``"output"``) is squeezed to
    2-D, resized to the window size, and accumulated. Overlapping values
    are then averaged.

    NOTE(review): for segmentation, ``_process_outputs`` returns per-pixel
    argmax class indices as ``"mask"``. Averaging those indices and
    thresholding therefore yields a binary foreground mask rather than a
    multi-class map. The ``"probabilities"`` entry is this averaged map, not
    a softmax. Confirm this is the intended behaviour.

    Args:
        data: Image array, 2-D ``(H, W)`` or 3-D ``(C, H, W)``.
        metadata: Geospatial metadata; attached unchanged to the result.
        threshold: Forwarded to ``self._predict_single`` and also used to
            binarise the averaged map.
        return_probabilities: If True, also return the averaged float map.

    Returns:
        Dictionary with ``"mask"`` (uint8, averaged map ``> threshold``),
        ``"probabilities"`` (averaged float32 map, or ``None``) and
        ``"metadata"``.

    Raises:
        ValueError: If ``self.tile_size <= 2 * self.overlap``. A zero stride
            would divide by zero, and a negative stride would tile only part
            of the image, silently leaving the rest as zeros.
    """
    # Stride between window origins; validated before any work is done.
    effective = self.tile_size - 2 * self.overlap
    if effective <= 0:
        raise ValueError(
            f"tile_size ({self.tile_size}) must be greater than twice the "
            f"overlap ({self.overlap}) for tiled inference"
        )

    # Lazy import to avoid QGIS opencv conflicts
    import cv2

    if data.ndim == 3:
        _, height, width = data.shape
    else:
        height, width = data.shape

    n_x = max(1, int(np.ceil(width / effective)))
    n_y = max(1, int(np.ceil(height / effective)))
    total = n_x * n_y

    # Running sum of tile outputs, and how many tiles covered each pixel.
    mask_output = np.zeros((height, width), dtype=np.float32)
    count_output = np.zeros((height, width), dtype=np.float32)

    print(f"Processing {total} tiles ({n_x}x{n_y})")

    with tqdm(total=total, desc="Processing tiles") as pbar:
        for iy in range(n_y):
            for ix in range(n_x):
                x_start = max(0, ix * effective - self.overlap)
                y_start = max(0, iy * effective - self.overlap)
                x_end = min(width, (ix + 1) * effective + self.overlap)
                y_end = min(height, (iy + 1) * effective + self.overlap)

                if data.ndim == 3:
                    tile = data[:, y_start:y_end, x_start:x_end]
                else:
                    tile = data[y_start:y_end, x_start:x_end]

                try:
                    tile_result = self._predict_single(
                        tile, None, threshold, return_probabilities
                    )
                    tile_mask = tile_result.get("mask", tile_result.get("output"))
                    if tile_mask is not None:
                        # Reduce model output to a single 2-D plane.
                        if tile_mask.ndim > 2:
                            tile_mask = tile_mask.squeeze()
                        if tile_mask.ndim > 2:
                            tile_mask = tile_mask[0]

                        tile_h = y_end - y_start
                        tile_w = x_end - x_start
                        if tile_mask.shape != (tile_h, tile_w):
                            tile_mask = cv2.resize(
                                tile_mask.astype(np.float32),
                                (tile_w, tile_h),
                                interpolation=cv2.INTER_LINEAR,
                            )

                        mask_output[y_start:y_end, x_start:x_end] += tile_mask
                        count_output[y_start:y_end, x_start:x_end] += 1
                except Exception as e:
                    # Best effort: a failing tile is reported and skipped;
                    # pixels covered only by it keep a value of 0.
                    print(f"Error processing tile ({ix}, {iy}): {e}")

                pbar.update(1)

    # Pixels that no tile contributed to stay 0 instead of dividing by zero.
    count_output = np.maximum(count_output, 1)
    mask_output = mask_output / count_output

    return {
        "mask": (mask_output > threshold).astype(np.uint8),
        "probabilities": mask_output if return_probabilities else None,
        "metadata": metadata,
    }
|
|
801
|
+
|
|
802
|
+
# ------------------------------------------------------------------
|
|
803
|
+
# Output processing
|
|
804
|
+
# ------------------------------------------------------------------
|
|
805
|
+
|
|
806
|
+
def _process_outputs(
    self,
    outputs: List[np.ndarray],
    original_size: Tuple[int, int],
    threshold: float = 0.5,
    return_probabilities: bool = False,
) -> Dict[str, Any]:
    """Map raw ONNX outputs to a result dict according to ``self.task``.

    Outputs by task:

    * ``"segmentation"`` / ``"semantic-segmentation"``:
        * If ``outputs[0]`` is 4-D (assumed ``(1, C, H, W)`` logits):
            * ``"mask"``: uint8 per-pixel argmax, nearest-neighbour resized
              to *original_size*.
            * ``"probabilities"`` (optional): softmax over ``C``, at model
              resolution.
        * Any other rank is returned unchanged as ``"output"``.
    * ``"classification"`` / ``"image-classification"``:
        * ``"class"``.
        * ``"probabilities"``: softmax over the last axis.
        * ``"label"``: only when ``self.id2label`` is set.
    * ``"object-detection"``:
        * With a second (boxes) output: ``"boxes"``, ``"scores"`` and
          ``"labels"`` for queries whose best sigmoid score exceeds
          *threshold*.
        * Otherwise ``outputs[0]`` as ``"output"``.
    * ``"depth-estimation"``:
        * ``"depth"``, also stored as ``"output"``, bilinearly resized to
          *original_size*.
    * Any other task: ``outputs[0]`` as ``"output"``.

    Args:
        outputs: List of numpy arrays returned by ``session.run()``.
        original_size: ``(height, width)`` of the input before resizing.
        threshold: Score threshold for object detection. Not used by the
            other tasks.
        return_probabilities: Whether to include segmentation probability
            maps.

    Returns:
        Result dictionary.
    """
    result: Dict[str, Any] = {}
    oh, ow = original_size

    if self.task in ("segmentation", "semantic-segmentation"):
        logits = outputs[0]  # expected (1, C, H, W)
        if logits.ndim == 4:
            # Softmax is monotonic, so argmax over the raw logits equals
            # argmax over the probabilities. The softmax is only computed
            # when probabilities are requested.
            mask = logits.argmax(axis=1)[0]  # (H, W); assumes batch size 1

            if mask.shape != (oh, ow):
                # Lazy import to avoid QGIS opencv conflicts
                import cv2

                mask = cv2.resize(
                    mask.astype(np.float32),
                    (ow, oh),
                    interpolation=cv2.INTER_NEAREST,
                )

            result["mask"] = mask.astype(np.uint8)
            if return_probabilities:
                exp = np.exp(logits - logits.max(axis=1, keepdims=True))
                probs = exp / exp.sum(axis=1, keepdims=True)
                result["probabilities"] = probs.squeeze()
        else:
            # Unrecognised layout (e.g. a single-channel map). Expose it raw
            # instead of returning a result with no model output at all.
            result["output"] = logits

    elif self.task in ("classification", "image-classification"):
        logits = outputs[0]  # expected (1, C)
        exp = np.exp(logits - logits.max(axis=-1, keepdims=True))
        probs = exp / exp.sum(axis=-1, keepdims=True)
        pred = int(probs.argmax(axis=-1).squeeze())
        result["class"] = pred
        result["probabilities"] = probs.squeeze()
        if self.id2label:
            # Label maps read from JSON configs may be keyed by strings
            # ("0", "1", ...), so fall back to the string key.
            result["label"] = self.id2label.get(
                pred, self.id2label.get(str(pred), str(pred))
            )

    elif self.task == "object-detection":
        logits = outputs[0]  # expected (1, N, num_classes)
        pred_boxes = outputs[1] if len(outputs) > 1 else None  # (1, N, 4)
        if pred_boxes is not None:
            # NOTE(review): scores use an independent sigmoid per class, and
            # boxes are returned exactly as the model emits them (no
            # coordinate conversion). Softmax heads with a trailing
            # "no object" class may need different handling; confirm per model.
            scores_all = 1.0 / (1.0 + np.exp(-logits))
            scores = scores_all.max(axis=-1).squeeze()  # (N,)
            labels = scores_all.argmax(axis=-1).squeeze()  # (N,)
            boxes = pred_boxes.squeeze()  # (N, 4)

            keep = scores > threshold
            result["boxes"] = boxes[keep]
            result["scores"] = scores[keep]
            result["labels"] = labels[keep]
        else:
            # Without a boxes output, expose the raw logits.
            result["output"] = logits

    elif self.task == "depth-estimation":
        depth = outputs[0].squeeze()
        if depth.shape != (oh, ow):
            # Lazy import to avoid QGIS opencv conflicts
            import cv2

            depth = cv2.resize(
                depth.astype(np.float32),
                (ow, oh),
                interpolation=cv2.INTER_LINEAR,
            )
        result["output"] = depth
        result["depth"] = depth

    else:
        # Fallback – expose raw outputs
        result["output"] = outputs[0]

    return result
|
|
892
|
+
|
|
893
|
+
# ------------------------------------------------------------------
|
|
894
|
+
# Vectorization
|
|
895
|
+
# ------------------------------------------------------------------
|
|
896
|
+
|
|
897
|
+
@staticmethod
def mask_to_vector(
    mask: np.ndarray,
    metadata: Dict,
    threshold: float = 0.5,
    min_object_area: int = 100,
    max_object_area: Optional[int] = None,
    simplify_tolerance: float = 1.0,
) -> Optional[gpd.GeoDataFrame]:
    """Convert a raster mask to vector polygons.

    Floating-point masks are binarised with ``mask > threshold``; masks of
    any other dtype are binarised with ``mask > 0``. Foreground regions are
    polygonised with ``metadata["transform"]`` and tagged with
    ``metadata["crs"]``.

    Args:
        mask: Binary, label, or probability mask array.
        metadata: Geospatial metadata dictionary providing ``"crs"`` and
            ``"transform"``.
        threshold: Threshold for binarizing floating-point masks.
        min_object_area: Minimum polygon area in pixels.
        max_object_area: Maximum polygon area in pixels. Ignored when
            *None* or 0.
        simplify_tolerance: Simplification tolerance in pixels, scaled to
            map units by the pixel width. Values ``<= 0`` disable
            simplification.

    Returns:
        GeoDataFrame with ``geometry`` and ``class`` columns, or *None* in
        any of these cases:

        * the metadata lacks a CRS or transform;
        * vectorization raises;
        * no polygon passes the filters.
    """
    if metadata is None or metadata.get("crs") is None:
        print("Warning: No CRS information available for vectorization")
        return None

    # issubdtype covers every float width (float16/32/64), not only the two
    # widths an explicit tuple comparison would list.
    if np.issubdtype(mask.dtype, np.floating):
        mask = (mask > threshold).astype(np.uint8)
    else:
        mask = (mask > 0).astype(np.uint8)

    transform = metadata.get("transform")
    crs = metadata.get("crs")
    if transform is None:
        print("Warning: No transform available for vectorization")
        return None

    polygons: List = []
    values: List = []

    try:
        # Area of one pixel in map units. The absolute affine determinant is
        # correct for north-up grids (negative ``e``), grids with a negative
        # x resolution, and rotated/sheared transforms alike. The product
        # ``a * |e|`` goes negative in the second case and is wrong in the
        # third.
        pixel_size = abs(transform.a * transform.e - transform.b * transform.d)
        for geom, value in shapes(mask, transform=transform):
            if value <= 0:
                continue
            poly = shape(geom)
            pixel_area = poly.area / pixel_size
            if pixel_area < min_object_area:
                continue
            if max_object_area and pixel_area > max_object_area:
                continue
            if simplify_tolerance > 0:
                poly = poly.simplify(
                    simplify_tolerance * abs(transform.a),
                    preserve_topology=True,
                )
            if poly.is_valid and not poly.is_empty:
                polygons.append(poly)
                values.append(value)
    except Exception as e:
        print(f"Error during vectorization: {e}")
        return None

    if not polygons:
        return None

    return gpd.GeoDataFrame(
        {"geometry": polygons, "class": values},
        crs=crs,
    )
|
|
966
|
+
|
|
967
|
+
# ------------------------------------------------------------------
|
|
968
|
+
# GeoTIFF / vector save helpers
|
|
969
|
+
# ------------------------------------------------------------------
|
|
970
|
+
|
|
971
|
+
@staticmethod
def _save_geotiff(
    data: np.ndarray,
    output_path: str,
    metadata: Dict,
    dtype: Optional[str] = None,
    compress: str = "lzw",
    nodata: Optional[float] = None,
) -> str:
    """Save an array as a GeoTIFF with geospatial metadata.

    The rasterio profile in ``metadata["profile"]`` is copied. Its
    ``dtype``, ``count``, ``height``, ``width`` and ``compress`` entries are
    overwritten to match *data*. Parent directories are created as needed.

    Args:
        data: Array to save (2D or 3D in CHW format).
        output_path: Output file path.
        metadata: Metadata dictionary from :meth:`load_geotiff`.
        dtype: Output data type. If *None*, inferred from *data*.
        compress: Compression method.
        nodata: NoData value. If *None*, the value inherited from the source
            profile is kept, unless an integer output dtype cannot
            represent it. Such values (e.g. -9999 or NaN with ``uint8``) are
            dropped.

    Returns:
        Path to the saved file.
    """
    profile = metadata["profile"].copy()
    if dtype is None:
        dtype = str(data.dtype)

    if data.ndim == 2:
        count = 1
        height, width = data.shape
    else:
        count = data.shape[0]
        height, width = data.shape[1], data.shape[2]

    profile.update(
        {
            "dtype": dtype,
            "count": count,
            "height": height,
            "width": width,
            "compress": compress,
        }
    )
    if nodata is not None:
        profile["nodata"] = nodata
    else:
        # A float source's nodata (e.g. -9999 or NaN) is carried over in the
        # copied profile. Writing it to an integer output such as a uint8
        # mask would make rasterio reject the profile, so drop it instead.
        inherited = profile.get("nodata")
        if inherited is not None:
            out_dtype = np.dtype(dtype)
            if np.issubdtype(out_dtype, np.integer):
                info = np.iinfo(out_dtype)
                if np.isnan(inherited) or not (info.min <= inherited <= info.max):
                    profile.pop("nodata", None)

    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)

    with rasterio.open(output_path, "w", **profile) as dst:
        if data.ndim == 2:
            dst.write(data, 1)
        else:
            dst.write(data)

    return output_path
|
|
1025
|
+
|
|
1026
|
+
@staticmethod
def save_vector(
    gdf: gpd.GeoDataFrame,
    output_path: str,
    driver: Optional[str] = None,
) -> str:
    """Write a GeoDataFrame to disk, creating parent directories as needed.

    Args:
        gdf: GeoDataFrame to write.
        output_path: Destination file path.
        driver: Driver name passed to ``GeoDataFrame.to_file``. When omitted,
            it is chosen from the lower-cased file extension:

            * ``.geojson`` / ``.json``: GeoJSON
            * ``.gpkg``: GPKG
            * ``.shp``: ESRI Shapefile
            * ``.parquet``: Parquet
            * ``.fgb``: FlatGeobuf
            * any other extension: GeoJSON

    Returns:
        *output_path*, unchanged.
    """
    parent_dir = os.path.dirname(output_path)
    os.makedirs(parent_dir if parent_dir else ".", exist_ok=True)

    if driver is None:
        suffix = os.path.splitext(output_path)[1].lower()
        drivers_by_suffix = {
            ".geojson": "GeoJSON",
            ".json": "GeoJSON",
            ".gpkg": "GPKG",
            ".shp": "ESRI Shapefile",
            ".parquet": "Parquet",
            ".fgb": "FlatGeobuf",
        }
        driver = drivers_by_suffix.get(suffix, "GeoJSON")

    gdf.to_file(output_path, driver=driver)
    return output_path
|
|
1056
|
+
|
|
1057
|
+
def __repr__(self) -> str:
    """Return a debug string with the model path, task and active providers."""
    path = self.model_path
    task = self.task
    providers = self.session.get_providers()
    return (
        f"ONNXGeoModel(path={path!r}, task={task!r}, "
        f"providers={providers!r})"
    )
|
|
1062
|
+
|
|
1063
|
+
|
|
1064
|
+
# ---------------------------------------------------------------------------
|
|
1065
|
+
# Convenience functions
|
|
1066
|
+
# ---------------------------------------------------------------------------
|
|
1067
|
+
|
|
1068
|
+
|
|
1069
|
+
def onnx_semantic_segmentation(
    input_path: str,
    output_path: str,
    model_path: str,
    output_vector_path: Optional[str] = None,
    threshold: float = 0.5,
    tile_size: int = 1024,
    overlap: int = 128,
    min_object_area: int = 100,
    simplify_tolerance: float = 1.0,
    providers: Optional[List[str]] = None,
    **kwargs: Any,
) -> Dict[str, Any]:
    """Segment a GeoTIFF with an ONNX model in a single call.

    Builds an :class:`ONNXGeoModel` for the ``"semantic-segmentation"`` task
    with the given tiling settings and execution providers. Everything else
    is forwarded to :meth:`ONNXGeoModel.predict`.

    Args:
        input_path: Path to input GeoTIFF.
        output_path: Path to save output segmentation GeoTIFF.
        model_path: Path to the ONNX model file.
        output_vector_path: Optional path to save vectorised output.
        threshold: Threshold forwarded to ``predict``.
        tile_size: Tile size for processing large images.
        overlap: Overlap between tiles.
        min_object_area: Minimum object area for vectorization.
        simplify_tolerance: Tolerance for polygon simplification.
        providers: ONNX Runtime execution providers.
        **kwargs: Additional arguments passed to :meth:`ONNXGeoModel.predict`.

    Returns:
        The result dictionary returned by :meth:`ONNXGeoModel.predict`.

    Example:
        >>> result = onnx_semantic_segmentation(
        ...     "input.tif",
        ...     "output.tif",
        ...     "segformer.onnx",
        ...     output_vector_path="output.geojson",
        ... )
    """
    segmenter = ONNXGeoModel(
        model_path,
        task="semantic-segmentation",
        providers=providers,
        tile_size=tile_size,
        overlap=overlap,
    )
    # kwargs cannot contain these names: they are named parameters of this
    # function, so the merge never overrides them.
    predict_options: Dict[str, Any] = {
        "output_path": output_path,
        "output_vector_path": output_vector_path,
        "threshold": threshold,
        "min_object_area": min_object_area,
        "simplify_tolerance": simplify_tolerance,
    }
    predict_options.update(kwargs)
    return segmenter.predict(input_path, **predict_options)
|
|
1126
|
+
|
|
1127
|
+
|
|
1128
|
+
def onnx_image_classification(
    input_path: str,
    model_path: str,
    providers: Optional[List[str]] = None,
    **kwargs: Any,
) -> Dict[str, Any]:
    """Classify an image with an ONNX model in a single call.

    Builds an :class:`ONNXGeoModel` for the ``"image-classification"`` task
    and returns the output of its ``predict`` method.

    Args:
        input_path: Path to input image or GeoTIFF.
        model_path: Path to the ONNX model file.
        providers: ONNX Runtime execution providers.
        **kwargs: Additional arguments passed to :meth:`ONNXGeoModel.predict`.

    Returns:
        Dictionary with ``class`` and ``probabilities``, plus ``label`` when
        the model has an id-to-label mapping.

    Example:
        >>> result = onnx_image_classification("image.tif", "classifier.onnx")
        >>> print(result["class"], result.get("label"))
    """
    classifier = ONNXGeoModel(
        model_path,
        task="image-classification",
        providers=providers,
    )
    prediction = classifier.predict(input_path, **kwargs)
    return prediction
|