geoai-py 0.27.0__py2.py3-none-any.whl → 0.28.0__py2.py3-none-any.whl

geoai/onnx.py ADDED
@@ -0,0 +1,1155 @@
+ """ONNX Runtime support for geospatial model inference.
+
+ This module provides ONNXGeoModel for loading and running inference with
+ ONNX models on geospatial data (GeoTIFF), and export_to_onnx for converting
+ existing PyTorch/Hugging Face models to ONNX format.
+
+ Supported tasks:
+     - Semantic segmentation (e.g., SegFormer, Mask2Former)
+     - Image classification (e.g., ViT, ResNet)
+     - Object detection (e.g., DETR, YOLOS)
+     - Depth estimation (e.g., Depth Anything, DPT)
+
+ Requirements:
+     - onnx
+     - onnxruntime (or onnxruntime-gpu for GPU acceleration)
+
+ Install with::
+
+     pip install geoai-py[onnx]
+
+ Example:
+     >>> from geoai import export_to_onnx, ONNXGeoModel
+     >>> # Export a HuggingFace model to ONNX
+     >>> export_to_onnx(
+     ...     "nvidia/segformer-b0-finetuned-ade-512-512",
+     ...     "segformer.onnx",
+     ...     task="semantic-segmentation",
+     ... )
+     >>> # Load and run inference with the ONNX model
+     >>> model = ONNXGeoModel("segformer.onnx", task="semantic-segmentation")
+     >>> result = model.predict("input.tif", output_path="output.tif")
+ """
+
+ import json
+ import os
+ from typing import Any, Dict, List, Optional, Tuple, Union
+
+ import geopandas as gpd
+ import numpy as np
+ import rasterio
+ from PIL import Image
+ from rasterio.features import shapes
+ from rasterio.windows import Window
+ from shapely.geometry import shape
+ from tqdm import tqdm
+
+
+ def _check_onnx_deps() -> None:
+     """Check that onnx and onnxruntime are installed.
+
+     Raises:
+         ImportError: If onnx or onnxruntime is not installed.
+     """
+     try:
+         import onnx  # noqa: F401
+     except ImportError:
+         raise ImportError(
+             "The 'onnx' package is required for ONNX support. "
+             "Install it with: pip install geoai-py[onnx]"
+         )
+
+     try:
+         import onnxruntime  # noqa: F401
+     except ImportError:
+         raise ImportError(
+             "The 'onnxruntime' package is required for ONNX support. "
+             "Install it with: pip install geoai-py[onnx] "
+             "(use 'onnxruntime-gpu' for GPU acceleration)"
+         )
+
+
+ def _check_torch_deps() -> None:
+     """Check that torch and transformers are installed (needed for export).
+
+     Raises:
+         ImportError: If torch or transformers is not installed.
+     """
+     try:
+         import torch  # noqa: F401
+     except ImportError:
+         raise ImportError(
+             "PyTorch is required for exporting models to ONNX. "
+             "Install it from https://pytorch.org/"
+         )
+
+     try:
+         import transformers  # noqa: F401
+     except ImportError:
+         raise ImportError(
+             "The 'transformers' package is required for exporting "
+             "Hugging Face models to ONNX. "
+             "Install it with: pip install transformers"
+         )
+
+
+ # ---------------------------------------------------------------------------
+ # Export helpers
+ # ---------------------------------------------------------------------------
+
+
+ def export_to_onnx(
+     model_name_or_path: str,
+     output_path: str,
+     task: Optional[str] = None,
+     input_height: int = 512,
+     input_width: int = 512,
+     input_channels: int = 3,
+     opset_version: int = 17,
+     dynamic_axes: Optional[Dict[str, Dict[int, str]]] = None,
+     simplify: bool = True,
+     device: Optional[str] = None,
+     **kwargs: Any,
+ ) -> str:
+     """Export a PyTorch / Hugging Face model to ONNX format.
+
+     Args:
+         model_name_or_path: Hugging Face model name or local checkpoint path.
+         output_path: Path where the ``.onnx`` file will be saved.
+         task: Model task. One of ``"semantic-segmentation"``,
+             ``"image-classification"``, ``"object-detection"``, or
+             ``"depth-estimation"``. If *None* the function tries to infer
+             the task from the model configuration.
+         input_height: Height of the dummy input tensor (pixels).
+         input_width: Width of the dummy input tensor (pixels).
+         input_channels: Number of input channels (default 3 for RGB).
+         opset_version: ONNX opset version (default 17).
+         dynamic_axes: Optional mapping of dynamic axes for variable-size
+             inputs/outputs. When *None* a sensible default is used so that
+             batch size and spatial dimensions are dynamic.
+         simplify: Whether to simplify the exported graph with
+             ``onnxsim.simplify`` (requires the ``onnxsim`` package).
+         device: Device used for tracing (``"cpu"`` recommended for export).
+         **kwargs: Extra keyword arguments forwarded to
+             ``AutoModel.from_pretrained``.
+
+     Returns:
+         Absolute path to the saved ONNX file.
+
+     Raises:
+         ImportError: If required packages are missing.
+         ValueError: If the task cannot be determined.
+
+     Example:
+         >>> export_to_onnx(
+         ...     "nvidia/segformer-b0-finetuned-ade-512-512",
+         ...     "segformer.onnx",
+         ...     task="semantic-segmentation",
+         ... )
+         'segformer.onnx'
+     """
+     _check_torch_deps()
+     import onnx  # noqa: F811
+     import torch
+     from transformers import (
+         AutoConfig,
+         AutoImageProcessor,
+         AutoModelForDepthEstimation,
+         AutoModelForImageClassification,
+         AutoModelForObjectDetection,
+         AutoModelForSemanticSegmentation,
+     )
+
+     if device is None:
+         device = "cpu"
+
+     # ------------------------------------------------------------------
+     # Load model
+     # ------------------------------------------------------------------
+     task_model_map = {
+         "segmentation": AutoModelForSemanticSegmentation,
+         "semantic-segmentation": AutoModelForSemanticSegmentation,
+         "classification": AutoModelForImageClassification,
+         "image-classification": AutoModelForImageClassification,
+         "object-detection": AutoModelForObjectDetection,
+         "depth-estimation": AutoModelForDepthEstimation,
+     }
+
+     if task and task in task_model_map:
+         model_cls = task_model_map[task]
+     else:
+         # Try to infer from config
+         try:
+             config = AutoConfig.from_pretrained(model_name_or_path)
+             architectures = getattr(config, "architectures", [])
+             if any("Segmentation" in a for a in architectures):
+                 model_cls = AutoModelForSemanticSegmentation
+                 task = task or "semantic-segmentation"
+             elif any("Classification" in a for a in architectures):
+                 model_cls = AutoModelForImageClassification
+                 task = task or "image-classification"
+             elif any("Detection" in a for a in architectures):
+                 model_cls = AutoModelForObjectDetection
+                 task = task or "object-detection"
+             elif any("Depth" in a for a in architectures):
+                 model_cls = AutoModelForDepthEstimation
+                 task = task or "depth-estimation"
+             else:
+                 raise ValueError(
+                     f"Cannot infer task from model config. "
+                     f"Found architectures: {architectures}. "
+                     f"Please specify task= explicitly."
+                 )
+         except Exception as exc:
+             raise ValueError(
+                 "Cannot determine the model task. Please specify task= explicitly."
+             ) from exc
+
+     model = model_cls.from_pretrained(model_name_or_path, **kwargs)
+     model = model.to(device).eval()
+
+     # Try loading the image processor to get expected input size
+     try:
+         processor = AutoImageProcessor.from_pretrained(model_name_or_path)
+         if hasattr(processor, "size"):
+             size = processor.size
+             if isinstance(size, dict):
+                 input_height = size.get("height", input_height)
+                 input_width = size.get("width", input_width)
+             elif isinstance(size, (list, tuple)) and len(size) == 2:
+                 input_height, input_width = size
+     except Exception:
+         pass  # processor introspection is optional; fall back to defaults
+
+     # ------------------------------------------------------------------
+     # Build dummy input & dynamic axes
+     # ------------------------------------------------------------------
+     dummy_input = torch.randn(
+         1, input_channels, input_height, input_width, device=device
+     )
+
+     input_names = ["pixel_values"]
+
+     if task in ("segmentation", "semantic-segmentation", "depth-estimation"):
+         output_names = ["logits"]
+     elif task in ("classification", "image-classification"):
+         output_names = ["logits"]
+     elif task == "object-detection":
+         output_names = ["logits", "pred_boxes"]
+     else:
+         output_names = ["output"]
+
+     if dynamic_axes is None:
+         dynamic_axes = {
+             "pixel_values": {0: "batch", 2: "height", 3: "width"},
+         }
+         for name in output_names:
+             dynamic_axes[name] = {0: "batch"}
+
+     # ------------------------------------------------------------------
+     # Export
+     # ------------------------------------------------------------------
+     os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
+
+     torch.onnx.export(
+         model,
+         ({"pixel_values": dummy_input},),
+         output_path,
+         input_names=input_names,
+         output_names=output_names,
+         dynamic_axes=dynamic_axes,
+         opset_version=opset_version,
+         do_constant_folding=True,
+     )
+
+     # Validate
+     onnx_model = onnx.load(output_path)
+     onnx.checker.check_model(onnx_model)
+
+     # Optional simplification
+     if simplify:
+         try:
+             import onnxsim
+
+             onnx_model_simplified, check = onnxsim.simplify(onnx_model)
+             if check:
+                 onnx.save(onnx_model_simplified, output_path)
+         except ImportError:
+             pass  # onnxsim is optional
+         except Exception:
+             pass  # simplification can fail for some models; keep original
+
+     # ------------------------------------------------------------------
+     # Save metadata alongside the model
+     # ------------------------------------------------------------------
+     meta = {
+         "model_name": model_name_or_path,
+         "task": task,
+         "input_height": input_height,
+         "input_width": input_width,
+         "input_channels": input_channels,
+         "opset_version": opset_version,
+         "output_names": output_names,
+     }
+
+     # Include id2label when available
+     config = model.config if hasattr(model, "config") else None
+     if config and hasattr(config, "id2label"):
+         meta["id2label"] = {str(k): v for k, v in config.id2label.items()}
+     if config and hasattr(config, "num_labels"):
+         meta["num_labels"] = config.num_labels
+
+     meta_path = output_path + ".json"
+     with open(meta_path, "w") as fh:
+         json.dump(meta, fh, indent=2)
+
+     print(f"ONNX model exported to {output_path}")
+     print(f"Metadata saved to {meta_path}")
+     return os.path.abspath(output_path)
+
+
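A quick way to sanity-check an export is to open the saved file with onnxruntime and confirm that the dynamic axes registered above survive as symbolic dimensions. A minimal sketch, assuming the defaults used by export_to_onnx (the file name is illustrative):

    import onnxruntime as ort

    sess = ort.InferenceSession("segformer.onnx", providers=["CPUExecutionProvider"])
    for inp in sess.get_inputs():
        # Dynamic dims show up as strings, e.g. ['batch', 3, 'height', 'width']
        print(inp.name, inp.shape)
    for out in sess.get_outputs():
        print(out.name, out.shape)
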
+ # ---------------------------------------------------------------------------
+ # ONNXGeoModel
+ # ---------------------------------------------------------------------------
+
+
+ class ONNXGeoModel:
+     """ONNX Runtime model for geospatial inference with GeoTIFF support.
+
+     This class mirrors the :class:`~geoai.auto.AutoGeoModel` API but uses
+     ONNX Runtime instead of PyTorch for inference, enabling deployment on
+     edge devices and environments without GPU drivers.
+
+     Attributes:
+         session: The ``onnxruntime.InferenceSession`` instance.
+         task (str): The model task (e.g. ``"semantic-segmentation"``).
+         tile_size (int): Tile size used for processing large images.
+         overlap (int): Overlap between adjacent tiles.
+         metadata (dict): Model metadata loaded from the sidecar JSON file.
+
+     Example:
+         >>> model = ONNXGeoModel("segformer.onnx", task="semantic-segmentation")
+         >>> result = model.predict("input.tif", output_path="output.tif")
+     """
+
+     def __init__(
+         self,
+         model_path: str,
+         task: Optional[str] = None,
+         providers: Optional[List[str]] = None,
+         tile_size: int = 1024,
+         overlap: int = 128,
+         metadata: Optional[Dict[str, Any]] = None,
+     ) -> None:
+         """Load an ONNX model for geospatial inference.
+
+         Args:
+             model_path: Path to the ``.onnx`` model file.
+             task: Model task. One of ``"semantic-segmentation"``,
+                 ``"image-classification"``, ``"object-detection"``, or
+                 ``"depth-estimation"``. If *None*, the task is read from the
+                 sidecar ``<model>.onnx.json`` metadata file.
+             providers: ONNX Runtime execution providers in priority order.
+                 Defaults to all providers reported by
+                 ``onnxruntime.get_available_providers()``.
+             tile_size: Tile size for processing large images.
+             overlap: Overlap between adjacent tiles (in pixels).
+             metadata: Optional pre-loaded metadata dict. When *None* the
+                 constructor looks for ``<model_path>.json``.
+
+         Raises:
+             FileNotFoundError: If *model_path* does not exist.
+             ImportError: If onnxruntime is not installed.
+         """
+         _check_onnx_deps()
+         import onnxruntime as ort
+
+         if not os.path.isfile(model_path):
+             raise FileNotFoundError(f"ONNX model not found: {model_path}")
+
+         self.model_path = os.path.abspath(model_path)
+         self.tile_size = tile_size
+         self.overlap = overlap
+
+         # Load sidecar metadata
+         if metadata is not None:
+             self.metadata = metadata
+         else:
+             meta_path = model_path + ".json"
+             if os.path.isfile(meta_path):
+                 with open(meta_path) as fh:
+                     self.metadata = json.load(fh)
+             else:
+                 self.metadata = {}
+
+         # Resolve task
+         self.task = task or self.metadata.get("task")
+
+         # Label mapping
+         self.id2label: Dict[int, str] = {}
+         raw = self.metadata.get("id2label", {})
+         if raw:
+             self.id2label = {int(k): v for k, v in raw.items()}
+
+         # Create session
+         if providers is None:
+             providers = ort.get_available_providers()
+         self.session = ort.InferenceSession(model_path, providers=providers)
+
+         # Inspect inputs / outputs
+         self.input_name = self.session.get_inputs()[0].name
+         self.input_shape = self.session.get_inputs()[0].shape  # may have str dims
+         self.output_names = [o.name for o in self.session.get_outputs()]
+
+         # Determine expected spatial size from metadata or model input shape
+         self._model_height = self.metadata.get("input_height")
+         self._model_width = self.metadata.get("input_width")
+         if self._model_height is None and isinstance(self.input_shape, list):
+             if len(self.input_shape) == 4:
+                 h, w = self.input_shape[2], self.input_shape[3]
+                 if isinstance(h, int) and isinstance(w, int):
+                     self._model_height = h
+                     self._model_width = w
+
+         active = self.session.get_providers()
+         print(f"ONNX model loaded from {model_path}")
+         print(f"Execution providers: {active}")
+         if self.task:
+             print(f"Task: {self.task}")
+
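Execution-provider order matters: ONNX Runtime falls back left to right, so pinning the list is the usual way to force CPU-only inference or to prefer CUDA when onnxruntime-gpu is installed. A small sketch using this class's own constructor (the model path is illustrative):

    # Prefer CUDA when available, fall back to CPU otherwise
    model = ONNXGeoModel(
        "segformer.onnx",
        task="semantic-segmentation",
        providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
    )

    # Force CPU-only inference (e.g., for reproducible CI runs)
    cpu_model = ONNXGeoModel("segformer.onnx", providers=["CPUExecutionProvider"])
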
+     # ------------------------------------------------------------------
+     # Image I/O helpers (mirrors AutoGeoImageProcessor)
+     # ------------------------------------------------------------------
+
+     @staticmethod
+     def load_geotiff(
+         source: Union[str, "rasterio.DatasetReader"],
+         window: Optional[Window] = None,
+         bands: Optional[List[int]] = None,
+     ) -> Tuple[np.ndarray, Dict]:
+         """Load a GeoTIFF file and return data with metadata.
+
+         Args:
+             source: Path to GeoTIFF file or open rasterio DatasetReader.
+             window: Optional rasterio Window for reading a subset.
+             bands: List of band indices to read (1-indexed).
+
+         Returns:
+             Tuple of (image array in CHW format, metadata dict).
+         """
+         should_close = False
+         if isinstance(source, str):
+             src = rasterio.open(source)
+             should_close = True
+         else:
+             src = source
+
+         try:
+             if bands:
+                 data = src.read(bands, window=window)
+             else:
+                 data = src.read(window=window)
+
+             profile = src.profile.copy()
+             if window:
+                 profile.update(
+                     {
+                         "height": window.height,
+                         "width": window.width,
+                         "transform": src.window_transform(window),
+                     }
+                 )
+
+             metadata = {
+                 "profile": profile,
+                 "crs": src.crs,
+                 "transform": profile["transform"],
+                 "bounds": (
+                     src.bounds
+                     if not window
+                     else rasterio.windows.bounds(window, src.transform)
+                 ),
+                 "width": profile["width"],
+                 "height": profile["height"],
+                 "count": data.shape[0],
+             }
+         finally:
+             if should_close:
+                 src.close()
+
+         return data, metadata
+
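Because load_geotiff accepts a rasterio Window, it can pull a georeferenced subset without reading the whole raster; the returned profile and transform are already adjusted to the window. A minimal sketch (file name and offsets are illustrative):

    from rasterio.windows import Window

    # Read a 512x512 chip starting at column 1024, row 2048
    chip, meta = ONNXGeoModel.load_geotiff(
        "scene.tif", window=Window(1024, 2048, 512, 512)
    )
    print(chip.shape)          # (bands, 512, 512)
    print(meta["transform"])   # transform shifted to the window origin
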
+     @staticmethod
+     def load_image(
+         source: Union[str, np.ndarray, "Image.Image"],
+         window: Optional[Window] = None,
+         bands: Optional[List[int]] = None,
+     ) -> Tuple[np.ndarray, Optional[Dict]]:
+         """Load an image from various sources.
+
+         Args:
+             source: Path to image file, numpy array, or PIL Image.
+             window: Optional rasterio Window (only for GeoTIFF).
+             bands: List of band indices (only for GeoTIFF, 1-indexed).
+
+         Returns:
+             Tuple of (image array in CHW format, metadata dict or None).
+         """
+         if isinstance(source, str):
+             try:
+                 with rasterio.open(source) as src:
+                     if src.crs is not None or source.lower().endswith(
+                         (".tif", ".tiff")
+                     ):
+                         return ONNXGeoModel.load_geotiff(source, window, bands)
+             except rasterio.errors.RasterioIOError:
+                 pass  # not a rasterio-compatible file; fall through to PIL
+
+             image = Image.open(source).convert("RGB")
+             data = np.array(image).transpose(2, 0, 1)
+             return data, None
+
+         elif isinstance(source, np.ndarray):
+             if source.ndim == 2:
+                 source = source[np.newaxis, :, :]
+             elif source.ndim == 3 and source.shape[2] in [1, 3, 4]:
+                 source = source.transpose(2, 0, 1)
+             return source, None
+
+         elif isinstance(source, Image.Image):
+             data = np.array(source.convert("RGB")).transpose(2, 0, 1)
+             return data, None
+
+         else:
+             raise TypeError(f"Unsupported source type: {type(source)}")
+
+     # ------------------------------------------------------------------
+     # Preprocessing
+     # ------------------------------------------------------------------
+
+     def _prepare_input(
+         self,
+         data: np.ndarray,
+         target_height: Optional[int] = None,
+         target_width: Optional[int] = None,
+     ) -> np.ndarray:
+         """Prepare a CHW image array as input for the ONNX model.
+
+         The method converts to 3-channel RGB, applies a 2-98 percentile
+         contrast stretch for non-uint8 data, resizes to the model's
+         expected spatial dimensions, normalizes to ``[0, 1]`` float32,
+         and adds a batch dimension.
+
+         Args:
+             data: Image array in CHW format.
+             target_height: Target height. Defaults to model metadata.
+             target_width: Target width. Defaults to model metadata.
+
+         Returns:
+             Numpy array of shape ``(1, 3, H, W)`` ready for the ONNX
+             session.
+         """
+         # Lazy import to avoid QGIS opencv conflicts
+         import cv2
+
+         # CHW → HWC
+         if data.ndim == 3:
+             img = data.transpose(1, 2, 0)
+         else:
+             img = data
+
+         # Ensure 3 channels
+         if img.ndim == 2:
+             img = np.stack([img] * 3, axis=-1)
+         elif img.shape[-1] == 1:
+             img = np.repeat(img, 3, axis=-1)
+         elif img.shape[-1] > 3:
+             img = img[..., :3]
+
+         # Percentile normalization → uint8 (work in float32 so the
+         # normalized [0, 1] values are not truncated by an integer dtype)
+         if img.dtype != np.uint8:
+             img = img.astype(np.float32)
+             for i in range(img.shape[-1]):
+                 band = img[..., i]
+                 p2, p98 = np.percentile(band, [2, 98])
+                 if p98 > p2:
+                     img[..., i] = np.clip((band - p2) / (p98 - p2), 0, 1)
+                 else:
+                     img[..., i] = 0
+             img = (img * 255).astype(np.uint8)
+
+         # Resize to model expected size if needed
+         th = target_height or self._model_height
+         tw = target_width or self._model_width
+         if th and tw and (img.shape[0] != th or img.shape[1] != tw):
+             img = cv2.resize(img, (tw, th), interpolation=cv2.INTER_LINEAR)
+
+         # Normalize to float32 [0, 1]
+         img = img.astype(np.float32) / 255.0
+
+         # HWC → NCHW
+         tensor = img.transpose(2, 0, 1)[np.newaxis, ...]
+         return tensor
+
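The 2-98 percentile stretch is what lets a 16-bit satellite band feed a model trained on 8-bit RGB: values at or below the 2nd percentile map to 0, values at or above the 98th map to 255. A standalone numpy sketch of the same arithmetic on synthetic data:

    import numpy as np

    band = np.random.randint(0, 10000, (512, 512)).astype(np.float32)  # fake 16-bit band
    p2, p98 = np.percentile(band, [2, 98])
    stretched = np.clip((band - p2) / (p98 - p2), 0, 1)
    as_uint8 = (stretched * 255).astype(np.uint8)  # full 0-255 range regardless of sensor scaling
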
+     # ------------------------------------------------------------------
+     # Prediction
+     # ------------------------------------------------------------------
+
+     def predict(
+         self,
+         source: Union[str, np.ndarray, "Image.Image"],
+         output_path: Optional[str] = None,
+         output_vector_path: Optional[str] = None,
+         window: Optional[Window] = None,
+         bands: Optional[List[int]] = None,
+         threshold: float = 0.5,
+         box_threshold: float = 0.3,
+         min_object_area: int = 100,
+         simplify_tolerance: float = 1.0,
+         batch_size: int = 1,
+         return_probabilities: bool = False,
+         **kwargs: Any,
+     ) -> Dict[str, Any]:
+         """Run inference on a GeoTIFF or image.
+
+         This method follows the same interface as
+         :meth:`~geoai.auto.AutoGeoModel.predict`.
+
+         Args:
+             source: Input image path, numpy array, or PIL Image.
+             output_path: Path to save output GeoTIFF (segmentation / depth).
+             output_vector_path: Path to save vectorised output.
+             window: Optional rasterio Window for reading a subset.
+             bands: Band indices to read (1-indexed).
+             threshold: Threshold for binary masks (segmentation).
+             box_threshold: Confidence threshold for detections (reserved
+                 for future use; ``threshold`` is currently applied).
+             min_object_area: Minimum polygon area in pixels for
+                 vectorization.
+             simplify_tolerance: Tolerance for polygon simplification.
+             batch_size: Batch size for tiled processing (reserved for
+                 future use).
+             return_probabilities: Whether to return probability maps.
+             **kwargs: Extra keyword arguments (currently unused).
+
+         Returns:
+             Dictionary with results (``mask``, ``class``, ``boxes``, etc.)
+             depending on the task, plus ``metadata``.
+
+         Example:
+             >>> model = ONNXGeoModel("segformer.onnx",
+             ...                      task="semantic-segmentation")
+             >>> result = model.predict("input.tif", output_path="output.tif")
+         """
+         # Handle URL sources
+         if isinstance(source, str) and source.startswith(("http://", "https://")):
+             import requests
+
+             pil_image = Image.open(requests.get(source, stream=True).raw)
+             data = np.array(pil_image.convert("RGB")).transpose(2, 0, 1)
+             metadata = None
+         else:
+             data, metadata = self.load_image(source, window, bands)
+
+         # Determine spatial size
+         if data.ndim == 3:
+             _, height, width = data.shape
+         else:
+             height, width = data.shape
+
+         # Classification never uses tiled processing
+         use_tiled = (
+             height > self.tile_size or width > self.tile_size
+         ) and self.task not in ("classification", "image-classification")
+
+         if use_tiled:
+             result = self._predict_tiled(
+                 data,
+                 metadata,
+                 threshold=threshold,
+                 return_probabilities=return_probabilities,
+             )
+         else:
+             result = self._predict_single(
+                 data,
+                 metadata,
+                 threshold=threshold,
+                 return_probabilities=return_probabilities,
+             )
+
+         # Save GeoTIFF
+         if output_path and metadata:
+             out_data = result.get("mask", result.get("output"))
+             if out_data is not None:
+                 self._save_geotiff(out_data, output_path, metadata, nodata=0)
+                 result["output_path"] = output_path
+
+         # Save vector
+         if output_vector_path and metadata and "mask" in result:
+             gdf = self.mask_to_vector(
+                 result["mask"],
+                 metadata,
+                 threshold=threshold,
+                 min_object_area=min_object_area,
+                 simplify_tolerance=simplify_tolerance,
+             )
+             if gdf is not None and len(gdf) > 0:
+                 gdf.to_file(output_vector_path)
+                 result["vector_path"] = output_vector_path
+                 result["geodataframe"] = gdf
+
+         return result
+
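predict also accepts in-memory inputs, which is handy for smoke-testing a session without touching disk; without geospatial metadata the GeoTIFF and vector outputs are skipped and only the raw result dict comes back. A sketch with random data and illustrative shapes, assuming a segmentation model with 4D logits:

    import numpy as np

    model = ONNXGeoModel("segformer.onnx", task="semantic-segmentation")
    chip = np.random.randint(0, 255, (3, 512, 512), dtype=np.uint8)  # CHW array
    result = model.predict(chip)      # no metadata -> nothing written to disk
    print(result["mask"].shape)       # (512, 512) class-index mask
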
+     # ------------------------------------------------------------------
+     # Internal prediction helpers
+     # ------------------------------------------------------------------
+
+     def _predict_single(
+         self,
+         data: np.ndarray,
+         metadata: Optional[Dict],
+         threshold: float = 0.5,
+         return_probabilities: bool = False,
+     ) -> Dict[str, Any]:
+         """Run inference on a single (non-tiled) image."""
+         original_h = data.shape[1] if data.ndim == 3 else data.shape[0]
+         original_w = data.shape[2] if data.ndim == 3 else data.shape[1]
+
+         input_tensor = self._prepare_input(data)
+         outputs = self.session.run(self.output_names, {self.input_name: input_tensor})
+
+         result = self._process_outputs(
+             outputs, (original_h, original_w), threshold, return_probabilities
+         )
+         result["metadata"] = metadata
+         return result
+
+     def _predict_tiled(
+         self,
+         data: np.ndarray,
+         metadata: Optional[Dict],
+         threshold: float = 0.5,
+         return_probabilities: bool = False,
+     ) -> Dict[str, Any]:
+         """Run tiled inference for large images."""
+         # Lazy import to avoid QGIS opencv conflicts
+         import cv2
+
+         if data.ndim == 3:
+             _, height, width = data.shape
+         else:
+             height, width = data.shape
+
+         effective = self.tile_size - 2 * self.overlap
+         n_x = max(1, int(np.ceil(width / effective)))
+         n_y = max(1, int(np.ceil(height / effective)))
+         total = n_x * n_y
+
+         mask_output = np.zeros((height, width), dtype=np.float32)
+         count_output = np.zeros((height, width), dtype=np.float32)
+
+         print(f"Processing {total} tiles ({n_x}x{n_y})")
+
+         with tqdm(total=total, desc="Processing tiles") as pbar:
+             for iy in range(n_y):
+                 for ix in range(n_x):
+                     x_start = max(0, ix * effective - self.overlap)
+                     y_start = max(0, iy * effective - self.overlap)
+                     x_end = min(width, (ix + 1) * effective + self.overlap)
+                     y_end = min(height, (iy + 1) * effective + self.overlap)
+
+                     if data.ndim == 3:
+                         tile = data[:, y_start:y_end, x_start:x_end]
+                     else:
+                         tile = data[y_start:y_end, x_start:x_end]
+
+                     try:
+                         tile_result = self._predict_single(
+                             tile, None, threshold, return_probabilities
+                         )
+                         tile_mask = tile_result.get("mask", tile_result.get("output"))
+                         if tile_mask is not None:
+                             if tile_mask.ndim > 2:
+                                 tile_mask = tile_mask.squeeze()
+                             if tile_mask.ndim > 2:
+                                 tile_mask = tile_mask[0]
+
+                             tile_h = y_end - y_start
+                             tile_w = x_end - x_start
+                             if tile_mask.shape != (tile_h, tile_w):
+                                 tile_mask = cv2.resize(
+                                     tile_mask.astype(np.float32),
+                                     (tile_w, tile_h),
+                                     interpolation=cv2.INTER_LINEAR,
+                                 )
+
+                             mask_output[y_start:y_end, x_start:x_end] += tile_mask
+                             count_output[y_start:y_end, x_start:x_end] += 1
+                     except Exception as e:
+                         print(f"Error processing tile ({ix}, {iy}): {e}")
+
+                     pbar.update(1)
+
+         count_output = np.maximum(count_output, 1)
+         mask_output = mask_output / count_output
+
+         return {
+             "mask": (mask_output > threshold).astype(np.uint8),
+             "probabilities": mask_output if return_probabilities else None,
+             "metadata": metadata,
+         }
+
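To make the tiling math concrete: with the defaults tile_size=1024 and overlap=128, the effective stride is 1024 - 2*128 = 768 pixels, and each tile is padded by the overlap on every side that is not an image edge. For a hypothetical 3000 x 2000 raster that gives ceil(3000/768) = 4 columns and ceil(2000/768) = 3 rows, i.e. 12 tiles:

    import numpy as np

    tile_size, overlap = 1024, 128
    width, height = 3000, 2000                      # illustrative raster size
    effective = tile_size - 2 * overlap             # 768-pixel stride
    n_x = max(1, int(np.ceil(width / effective)))   # 4 tile columns
    n_y = max(1, int(np.ceil(height / effective)))  # 3 tile rows
    # First interior tile (ix=1): columns 640..1664 (768 core + 128 overlap per side)
    x_start = max(0, 1 * effective - overlap)
    x_end = min(width, 2 * effective + overlap)
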
+     # ------------------------------------------------------------------
+     # Output processing
+     # ------------------------------------------------------------------
+
+     def _process_outputs(
+         self,
+         outputs: List[np.ndarray],
+         original_size: Tuple[int, int],
+         threshold: float = 0.5,
+         return_probabilities: bool = False,
+     ) -> Dict[str, Any]:
+         """Map raw ONNX outputs to a result dict.
+
+         Args:
+             outputs: List of numpy arrays returned by ``session.run()``.
+             original_size: ``(height, width)`` of the input before
+                 resizing.
+             threshold: Binary threshold for segmentation masks.
+             return_probabilities: Whether to include probability maps.
+
+         Returns:
+             Result dictionary.
+         """
+         # Lazy import to avoid QGIS opencv conflicts
+         import cv2
+
+         result: Dict[str, Any] = {}
+         oh, ow = original_size
+
+         if self.task in ("segmentation", "semantic-segmentation"):
+             logits = outputs[0]  # (1, C, H, W)
+             if logits.ndim == 4:
+                 # Softmax → argmax
+                 exp = np.exp(logits - logits.max(axis=1, keepdims=True))
+                 probs = exp / exp.sum(axis=1, keepdims=True)
+                 mask = probs.argmax(axis=1).squeeze()  # (H, W)
+
+                 if mask.shape != (oh, ow):
+                     mask = cv2.resize(
+                         mask.astype(np.float32),
+                         (ow, oh),
+                         interpolation=cv2.INTER_NEAREST,
+                     )
+
+                 result["mask"] = mask.astype(np.uint8)
+                 if return_probabilities:
+                     result["probabilities"] = probs.squeeze()
+
+         elif self.task in ("classification", "image-classification"):
+             logits = outputs[0]  # (1, C)
+             exp = np.exp(logits - logits.max(axis=-1, keepdims=True))
+             probs = exp / exp.sum(axis=-1, keepdims=True)
+             pred = int(probs.argmax(axis=-1).squeeze())
+             result["class"] = pred
+             result["probabilities"] = probs.squeeze()
+             if self.id2label:
+                 result["label"] = self.id2label.get(pred, str(pred))
+
+         elif self.task == "object-detection":
+             logits = outputs[0]  # (1, N, num_classes)
+             pred_boxes = outputs[1] if len(outputs) > 1 else None  # (1, N, 4)
+             if pred_boxes is not None:
+                 # Sigmoid scores
+                 scores_all = 1.0 / (1.0 + np.exp(-logits))
+                 scores = scores_all.max(axis=-1).squeeze()  # (N,)
+                 labels = scores_all.argmax(axis=-1).squeeze()  # (N,)
+                 boxes = pred_boxes.squeeze()  # (N, 4)
+
+                 keep = scores > threshold
+                 result["boxes"] = boxes[keep]
+                 result["scores"] = scores[keep]
+                 result["labels"] = labels[keep]
+
+         elif self.task == "depth-estimation":
+             depth = outputs[0].squeeze()
+             if depth.shape != (oh, ow):
+                 depth = cv2.resize(
+                     depth.astype(np.float32),
+                     (ow, oh),
+                     interpolation=cv2.INTER_LINEAR,
+                 )
+             result["output"] = depth
+             result["depth"] = depth
+
+         else:
+             # Fallback – expose raw outputs
+             result["output"] = outputs[0]
+
+         return result
+
+     # ------------------------------------------------------------------
+     # Vectorization
+     # ------------------------------------------------------------------
+
+     @staticmethod
+     def mask_to_vector(
+         mask: np.ndarray,
+         metadata: Dict,
+         threshold: float = 0.5,
+         min_object_area: int = 100,
+         max_object_area: Optional[int] = None,
+         simplify_tolerance: float = 1.0,
+     ) -> Optional[gpd.GeoDataFrame]:
+         """Convert a raster mask to vector polygons.
+
+         Args:
+             mask: Binary or probability mask array.
+             metadata: Geospatial metadata dictionary.
+             threshold: Threshold for binarizing probability masks.
+             min_object_area: Minimum polygon area in pixels.
+             max_object_area: Maximum polygon area in pixels (optional).
+             simplify_tolerance: Tolerance for polygon simplification.
+
+         Returns:
+             GeoDataFrame with polygon geometries, or *None* if no valid
+             polygons are found.
+         """
+         if metadata is None or metadata.get("crs") is None:
+             print("Warning: No CRS information available for vectorization")
+             return None
+
+         if mask.dtype in (np.float32, np.float64):
+             mask = (mask > threshold).astype(np.uint8)
+         else:
+             mask = (mask > 0).astype(np.uint8)
+
+         transform = metadata.get("transform")
+         crs = metadata.get("crs")
+         if transform is None:
+             print("Warning: No transform available for vectorization")
+             return None
+
+         polygons: List = []
+         values: List = []
+
+         try:
+             for geom, value in shapes(mask, transform=transform):
+                 if value > 0:
+                     poly = shape(geom)
+                     # Area in pixels = CRS-unit area / area of one pixel
+                     pixel_area = poly.area / abs(transform.a * transform.e)
+                     if pixel_area < min_object_area:
+                         continue
+                     if max_object_area and pixel_area > max_object_area:
+                         continue
+                     if simplify_tolerance > 0:
+                         poly = poly.simplify(
+                             simplify_tolerance * abs(transform.a),
+                             preserve_topology=True,
+                         )
+                     if poly.is_valid and not poly.is_empty:
+                         polygons.append(poly)
+                         values.append(value)
+         except Exception as e:
+             print(f"Error during vectorization: {e}")
+             return None
+
+         if not polygons:
+             return None
+
+         return gpd.GeoDataFrame(
+             {"geometry": polygons, "class": values},
+             crs=crs,
+         )
+
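Since mask_to_vector and save_vector are static, they can also be used on their own to polygonize a mask produced elsewhere. A sketch under the assumption that the metadata dict carries the "crs", "transform", "height", and "width" keys produced by load_geotiff (paths illustrative):

    import numpy as np

    data, meta = ONNXGeoModel.load_geotiff("scene.tif")
    mask = np.zeros((meta["height"], meta["width"]), dtype=np.uint8)
    mask[100:400, 100:400] = 1  # pretend this came from a model

    gdf = ONNXGeoModel.mask_to_vector(mask, meta, min_object_area=50)
    if gdf is not None:
        ONNXGeoModel.save_vector(gdf, "objects.gpkg")
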
+     # ------------------------------------------------------------------
+     # GeoTIFF / vector save helpers
+     # ------------------------------------------------------------------
+
+     @staticmethod
+     def _save_geotiff(
+         data: np.ndarray,
+         output_path: str,
+         metadata: Dict,
+         dtype: Optional[str] = None,
+         compress: str = "lzw",
+         nodata: Optional[float] = None,
+     ) -> str:
+         """Save an array as a GeoTIFF with geospatial metadata.
+
+         Args:
+             data: Array to save (2D or 3D in CHW format).
+             output_path: Output file path.
+             metadata: Metadata dictionary from :meth:`load_geotiff`.
+             dtype: Output data type. If *None*, inferred from *data*.
+             compress: Compression method.
+             nodata: NoData value.
+
+         Returns:
+             Path to the saved file.
+         """
+         profile = metadata["profile"].copy()
+         if dtype is None:
+             dtype = str(data.dtype)
+
+         if data.ndim == 2:
+             count = 1
+             height, width = data.shape
+         else:
+             count = data.shape[0]
+             height, width = data.shape[1], data.shape[2]
+
+         profile.update(
+             {
+                 "dtype": dtype,
+                 "count": count,
+                 "height": height,
+                 "width": width,
+                 "compress": compress,
+             }
+         )
+         if nodata is not None:
+             profile["nodata"] = nodata
+
+         os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
+
+         with rasterio.open(output_path, "w", **profile) as dst:
+             if data.ndim == 2:
+                 dst.write(data, 1)
+             else:
+                 dst.write(data)
+
+         return output_path
+
+     @staticmethod
+     def save_vector(
+         gdf: gpd.GeoDataFrame,
+         output_path: str,
+         driver: Optional[str] = None,
+     ) -> str:
+         """Save a GeoDataFrame to file.
+
+         Args:
+             gdf: GeoDataFrame to save.
+             output_path: Output file path.
+             driver: File driver (auto-detected from extension if *None*).
+
+         Returns:
+             Path to the saved file.
+         """
+         os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
+         if driver is None:
+             ext = os.path.splitext(output_path)[1].lower()
+             driver_map = {
+                 ".geojson": "GeoJSON",
+                 ".json": "GeoJSON",
+                 ".gpkg": "GPKG",
+                 ".shp": "ESRI Shapefile",
+                 ".parquet": "Parquet",
+                 ".fgb": "FlatGeobuf",
+             }
+             driver = driver_map.get(ext, "GeoJSON")
+         gdf.to_file(output_path, driver=driver)
+         return output_path
+
+     def __repr__(self) -> str:
+         return (
+             f"ONNXGeoModel(path={self.model_path!r}, task={self.task!r}, "
+             f"providers={self.session.get_providers()!r})"
+         )
+
+
+ # ---------------------------------------------------------------------------
+ # Convenience functions
+ # ---------------------------------------------------------------------------
+
+
+ def onnx_semantic_segmentation(
+     input_path: str,
+     output_path: str,
+     model_path: str,
+     output_vector_path: Optional[str] = None,
+     threshold: float = 0.5,
+     tile_size: int = 1024,
+     overlap: int = 128,
+     min_object_area: int = 100,
+     simplify_tolerance: float = 1.0,
+     providers: Optional[List[str]] = None,
+     **kwargs: Any,
+ ) -> Dict[str, Any]:
+     """Perform semantic segmentation using an ONNX model on a GeoTIFF.
+
+     This is a convenience wrapper around :class:`ONNXGeoModel`.
+
+     Args:
+         input_path: Path to input GeoTIFF.
+         output_path: Path to save output segmentation GeoTIFF.
+         model_path: Path to the ONNX model file.
+         output_vector_path: Optional path to save vectorised output.
+         threshold: Threshold for binary masks.
+         tile_size: Tile size for processing large images.
+         overlap: Overlap between tiles.
+         min_object_area: Minimum object area for vectorization.
+         simplify_tolerance: Tolerance for polygon simplification.
+         providers: ONNX Runtime execution providers.
+         **kwargs: Additional arguments passed to :meth:`ONNXGeoModel.predict`.
+
+     Returns:
+         Dictionary with results.
+
+     Example:
+         >>> result = onnx_semantic_segmentation(
+         ...     "input.tif",
+         ...     "output.tif",
+         ...     "segformer.onnx",
+         ...     output_vector_path="output.geojson",
+         ... )
+     """
+     model = ONNXGeoModel(
+         model_path,
+         task="semantic-segmentation",
+         providers=providers,
+         tile_size=tile_size,
+         overlap=overlap,
+     )
+     return model.predict(
+         input_path,
+         output_path=output_path,
+         output_vector_path=output_vector_path,
+         threshold=threshold,
+         min_object_area=min_object_area,
+         simplify_tolerance=simplify_tolerance,
+         **kwargs,
+     )
+
+
+ def onnx_image_classification(
+     input_path: str,
+     model_path: str,
+     providers: Optional[List[str]] = None,
+     **kwargs: Any,
+ ) -> Dict[str, Any]:
+     """Classify an image using an ONNX model.
+
+     Args:
+         input_path: Path to input image or GeoTIFF.
+         model_path: Path to the ONNX model file.
+         providers: ONNX Runtime execution providers.
+         **kwargs: Additional arguments passed to :meth:`ONNXGeoModel.predict`.
+
+     Returns:
+         Dictionary with ``class``, ``label`` (if available), and
+         ``probabilities``.
+
+     Example:
+         >>> result = onnx_image_classification("image.tif", "classifier.onnx")
+         >>> print(result["class"], result["label"])
+     """
+     model = ONNXGeoModel(
+         model_path,
+         task="image-classification",
+         providers=providers,
+     )
+     return model.predict(input_path, **kwargs)