geoai-py 0.24.0__py2.py3-none-any.whl → 0.26.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
geoai/utils.py CHANGED
@@ -64,7 +64,7 @@ def view_raster(
64
64
  client_args: Optional[Dict] = {"cors_all": False},
65
65
  basemap: Optional[str] = "OpenStreetMap",
66
66
  basemap_args: Optional[Dict] = None,
67
- backend: Optional[str] = "ipyleaflet",
67
+ backend: Optional[str] = "folium",
68
68
  **kwargs: Any,
69
69
  ) -> Any:
70
70
  """
@@ -87,7 +87,7 @@ def view_raster(
87
87
  client_args (Optional[Dict], optional): Additional arguments for the client. Defaults to {"cors_all": False}.
88
88
  basemap (Optional[str], optional): The basemap to use. Defaults to "OpenStreetMap".
89
89
  basemap_args (Optional[Dict], optional): Additional arguments for the basemap. Defaults to None.
90
- backend (Optional[str], optional): The backend to use. Defaults to "ipyleaflet".
90
+ backend (Optional[str], optional): The backend to use. Defaults to "folium".
91
91
  **kwargs (Any): Additional keyword arguments.
92
92
 
93
93
  Returns:
@@ -123,26 +123,39 @@ def view_raster(
123
123
  if isinstance(source, dict):
124
124
  source = dict_to_image(source)
125
125
 
126
- if (
127
- isinstance(source, str)
128
- and source.lower().endswith(".tif")
129
- and source.startswith("http")
130
- ):
131
- if indexes is not None:
132
- kwargs["bidx"] = indexes
133
- if colormap is not None:
134
- kwargs["colormap_name"] = colormap
135
- if attribution is None:
136
- attribution = "TiTiler"
137
-
138
- m.add_cog_layer(
139
- source,
140
- name=layer_name,
141
- opacity=opacity,
142
- attribution=attribution,
143
- zoom_to_layer=zoom_to_layer,
144
- **kwargs,
145
- )
126
+ if isinstance(source, str) and source.startswith("http"):
127
+ if backend == "folium":
128
+
129
+ m.add_geotiff(
130
+ url=source,
131
+ name=layer_name,
132
+ opacity=opacity,
133
+ attribution=attribution,
134
+ fit_bounds=zoom_to_layer,
135
+ palette=colormap,
136
+ vmin=vmin,
137
+ vmax=vmax,
138
+ **kwargs,
139
+ )
140
+ m.add_layer_control()
141
+ m.add_opacity_control()
142
+
143
+ else:
144
+ if indexes is not None:
145
+ kwargs["bidx"] = indexes
146
+ if colormap is not None:
147
+ kwargs["colormap_name"] = colormap
148
+ if attribution is None:
149
+ attribution = "TiTiler"
150
+
151
+ m.add_cog_layer(
152
+ source,
153
+ name=layer_name,
154
+ opacity=opacity,
155
+ attribution=attribution,
156
+ zoom_to_layer=zoom_to_layer,
157
+ **kwargs,
158
+ )
146
159
  else:
147
160
  m.add_raster(
148
161
  source=source,
@@ -381,6 +394,394 @@ def calc_stats(dataset, divide_by: float = 1.0) -> Tuple[np.ndarray, np.ndarray]
381
394
  return accum_mean / len(files), accum_std / len(files)
382
395
 
383
396
 
397
+ def calc_iou(
398
+ ground_truth: Union[str, np.ndarray, torch.Tensor],
399
+ prediction: Union[str, np.ndarray, torch.Tensor],
400
+ num_classes: Optional[int] = None,
401
+ ignore_index: Optional[int] = None,
402
+ smooth: float = 1e-6,
403
+ band: int = 1,
404
+ ) -> Union[float, np.ndarray]:
405
+ """
406
+ Calculate Intersection over Union (IoU) between ground truth and prediction masks.
407
+
408
+ This function computes the IoU metric for segmentation tasks. It supports both
409
+ binary and multi-class segmentation, and can handle numpy arrays, PyTorch tensors,
410
+ or file paths to raster files.
411
+
412
+ Args:
413
+ ground_truth (Union[str, np.ndarray, torch.Tensor]): Ground truth segmentation mask.
414
+ Can be a file path (str) to a raster file, numpy array, or PyTorch tensor.
415
+ For binary segmentation: shape (H, W) with values {0, 1}.
416
+ For multi-class segmentation: shape (H, W) with class indices.
417
+ prediction (Union[str, np.ndarray, torch.Tensor]): Predicted segmentation mask.
418
+ Can be a file path (str) to a raster file, numpy array, or PyTorch tensor.
419
+ Should have the same shape and format as ground_truth.
420
+ num_classes (Optional[int], optional): Number of classes for multi-class segmentation.
421
+ If None, assumes binary segmentation. Defaults to None.
422
+ ignore_index (Optional[int], optional): Class index to ignore in computation.
423
+ Useful for ignoring background or unlabeled pixels. Defaults to None.
424
+ smooth (float, optional): Smoothing factor to avoid division by zero.
425
+ Defaults to 1e-6.
426
+ band (int, optional): Band index to read from raster file (1-based indexing).
427
+ Only used when input is a file path. Defaults to 1.
428
+
429
+ Returns:
430
+ Union[float, np.ndarray]: For binary segmentation, returns a single float IoU score.
431
+ For multi-class segmentation, returns an array of IoU scores for each class.
432
+
433
+ Examples:
434
+ >>> # Binary segmentation with arrays
435
+ >>> gt = np.array([[0, 0, 1, 1], [0, 1, 1, 1]])
436
+ >>> pred = np.array([[0, 0, 1, 1], [0, 0, 1, 1]])
437
+ >>> iou = calc_iou(gt, pred)
438
+ >>> print(f"IoU: {iou:.4f}")
439
+ IoU: 0.8000
440
+
441
+ >>> # Multi-class segmentation
442
+ >>> gt = np.array([[0, 0, 1, 1], [0, 2, 2, 1]])
443
+ >>> pred = np.array([[0, 0, 1, 1], [0, 0, 2, 2]])
444
+ >>> iou = calc_iou(gt, pred, num_classes=3)
445
+ >>> print(f"IoU per class: {iou}")
446
+ IoU per class: [0.7500 0.6667 0.3333]
447
+
448
+ >>> # Using PyTorch tensors
449
+ >>> gt_tensor = torch.tensor([[0, 0, 1, 1], [0, 1, 1, 1]])
450
+ >>> pred_tensor = torch.tensor([[0, 0, 1, 1], [0, 0, 1, 1]])
451
+ >>> iou = calc_iou(gt_tensor, pred_tensor)
452
+ >>> print(f"IoU: {iou:.4f}")
453
+ IoU: 0.8000
454
+
455
+ >>> # Using raster file paths
456
+ >>> iou = calc_iou("ground_truth.tif", "prediction.tif", num_classes=3)
457
+ >>> print(f"Mean IoU: {np.nanmean(iou):.4f}")
458
+ Mean IoU: 0.7500
459
+ """
460
+ # Load from file if string path is provided
461
+ if isinstance(ground_truth, str):
462
+ with rasterio.open(ground_truth) as src:
463
+ ground_truth = src.read(band)
464
+ if isinstance(prediction, str):
465
+ with rasterio.open(prediction) as src:
466
+ prediction = src.read(band)
467
+
468
+ # Convert to numpy if torch tensor
469
+ if isinstance(ground_truth, torch.Tensor):
470
+ ground_truth = ground_truth.cpu().numpy()
471
+ if isinstance(prediction, torch.Tensor):
472
+ prediction = prediction.cpu().numpy()
473
+
474
+ # Ensure inputs have the same shape
475
+ if ground_truth.shape != prediction.shape:
476
+ raise ValueError(
477
+ f"Shape mismatch: ground_truth {ground_truth.shape} vs prediction {prediction.shape}"
478
+ )
479
+
480
+ # Binary segmentation
481
+ if num_classes is None:
482
+ ground_truth = ground_truth.astype(bool)
483
+ prediction = prediction.astype(bool)
484
+
485
+ intersection = np.logical_and(ground_truth, prediction).sum()
486
+ union = np.logical_or(ground_truth, prediction).sum()
487
+
488
+ if union == 0:
489
+ return 1.0 if intersection == 0 else 0.0
490
+
491
+ iou = (intersection + smooth) / (union + smooth)
492
+ return float(iou)
493
+
494
+ # Multi-class segmentation
495
+ else:
496
+ iou_per_class = []
497
+
498
+ for class_idx in range(num_classes):
499
+ # Handle ignored class by appending np.nan
500
+ if ignore_index is not None and class_idx == ignore_index:
501
+ iou_per_class.append(np.nan)
502
+ continue
503
+
504
+ # Create binary masks for current class
505
+ gt_class = (ground_truth == class_idx).astype(bool)
506
+ pred_class = (prediction == class_idx).astype(bool)
507
+
508
+ intersection = np.logical_and(gt_class, pred_class).sum()
509
+ union = np.logical_or(gt_class, pred_class).sum()
510
+
511
+ if union == 0:
512
+ # If class is not present in both gt and pred
513
+ iou_per_class.append(np.nan)
514
+ else:
515
+ iou_per_class.append((intersection + smooth) / (union + smooth))
516
+
517
+ return np.array(iou_per_class)
518
+
519
+
520
+ def calc_f1_score(
521
+ ground_truth: Union[str, np.ndarray, torch.Tensor],
522
+ prediction: Union[str, np.ndarray, torch.Tensor],
523
+ num_classes: Optional[int] = None,
524
+ ignore_index: Optional[int] = None,
525
+ smooth: float = 1e-6,
526
+ band: int = 1,
527
+ ) -> Union[float, np.ndarray]:
528
+ """
529
+ Calculate F1 score between ground truth and prediction masks.
530
+
531
+ The F1 score is the harmonic mean of precision and recall, computed as:
532
+ F1 = 2 * (precision * recall) / (precision + recall)
533
+ where precision = TP / (TP + FP) and recall = TP / (TP + FN).
534
+
535
+ This function supports both binary and multi-class segmentation, and can handle
536
+ numpy arrays, PyTorch tensors, or file paths to raster files.
537
+
538
+ Args:
539
+ ground_truth (Union[str, np.ndarray, torch.Tensor]): Ground truth segmentation mask.
540
+ Can be a file path (str) to a raster file, numpy array, or PyTorch tensor.
541
+ For binary segmentation: shape (H, W) with values {0, 1}.
542
+ For multi-class segmentation: shape (H, W) with class indices.
543
+ prediction (Union[str, np.ndarray, torch.Tensor]): Predicted segmentation mask.
544
+ Can be a file path (str) to a raster file, numpy array, or PyTorch tensor.
545
+ Should have the same shape and format as ground_truth.
546
+ num_classes (Optional[int], optional): Number of classes for multi-class segmentation.
547
+ If None, assumes binary segmentation. Defaults to None.
548
+ ignore_index (Optional[int], optional): Class index to ignore in computation.
549
+ Useful for ignoring background or unlabeled pixels. Defaults to None.
550
+ smooth (float, optional): Smoothing factor to avoid division by zero.
551
+ Defaults to 1e-6.
552
+ band (int, optional): Band index to read from raster file (1-based indexing).
553
+ Only used when input is a file path. Defaults to 1.
554
+
555
+ Returns:
556
+ Union[float, np.ndarray]: For binary segmentation, returns a single float F1 score.
557
+ For multi-class segmentation, returns an array of F1 scores for each class.
558
+
559
+ Examples:
560
+ >>> # Binary segmentation with arrays
561
+ >>> gt = np.array([[0, 0, 1, 1], [0, 1, 1, 1]])
562
+ >>> pred = np.array([[0, 0, 1, 1], [0, 0, 1, 1]])
563
+ >>> f1 = calc_f1_score(gt, pred)
564
+ >>> print(f"F1 Score: {f1:.4f}")
565
+ F1 Score: 0.8889
566
+
567
+ >>> # Multi-class segmentation
568
+ >>> gt = np.array([[0, 0, 1, 1], [0, 2, 2, 1]])
569
+ >>> pred = np.array([[0, 0, 1, 1], [0, 0, 2, 2]])
570
+ >>> f1 = calc_f1_score(gt, pred, num_classes=3)
571
+ >>> print(f"F1 Score per class: {f1}")
572
+ F1 Score per class: [0.8571 0.8000 0.5000]
573
+
574
+ >>> # Using PyTorch tensors
575
+ >>> gt_tensor = torch.tensor([[0, 0, 1, 1], [0, 1, 1, 1]])
576
+ >>> pred_tensor = torch.tensor([[0, 0, 1, 1], [0, 0, 1, 1]])
577
+ >>> f1 = calc_f1_score(gt_tensor, pred_tensor)
578
+ >>> print(f"F1 Score: {f1:.4f}")
579
+ F1 Score: 0.8889
580
+
581
+ >>> # Using raster file paths
582
+ >>> f1 = calc_f1_score("ground_truth.tif", "prediction.tif", num_classes=3)
583
+ >>> print(f"Mean F1: {np.nanmean(f1):.4f}")
584
+ Mean F1: 0.7302
585
+ """
586
+ # Load from file if string path is provided
587
+ if isinstance(ground_truth, str):
588
+ with rasterio.open(ground_truth) as src:
589
+ ground_truth = src.read(band)
590
+ if isinstance(prediction, str):
591
+ with rasterio.open(prediction) as src:
592
+ prediction = src.read(band)
593
+
594
+ # Convert to numpy if torch tensor
595
+ if isinstance(ground_truth, torch.Tensor):
596
+ ground_truth = ground_truth.cpu().numpy()
597
+ if isinstance(prediction, torch.Tensor):
598
+ prediction = prediction.cpu().numpy()
599
+
600
+ # Ensure inputs have the same shape
601
+ if ground_truth.shape != prediction.shape:
602
+ raise ValueError(
603
+ f"Shape mismatch: ground_truth {ground_truth.shape} vs prediction {prediction.shape}"
604
+ )
605
+
606
+ # Binary segmentation
607
+ if num_classes is None:
608
+ ground_truth = ground_truth.astype(bool)
609
+ prediction = prediction.astype(bool)
610
+
611
+ # Calculate True Positives, False Positives, False Negatives
612
+ tp = np.logical_and(ground_truth, prediction).sum()
613
+ fp = np.logical_and(~ground_truth, prediction).sum()
614
+ fn = np.logical_and(ground_truth, ~prediction).sum()
615
+
616
+ # Calculate precision and recall
617
+ precision = (tp + smooth) / (tp + fp + smooth)
618
+ recall = (tp + smooth) / (tp + fn + smooth)
619
+
620
+ # Calculate F1 score
621
+ f1 = 2 * (precision * recall) / (precision + recall + smooth)
622
+ return float(f1)
623
+
624
+ # Multi-class segmentation
625
+ else:
626
+ f1_per_class = []
627
+
628
+ for class_idx in range(num_classes):
629
+ # Mark ignored class with np.nan
630
+ if ignore_index is not None and class_idx == ignore_index:
631
+ f1_per_class.append(np.nan)
632
+ continue
633
+
634
+ # Create binary masks for current class
635
+ gt_class = (ground_truth == class_idx).astype(bool)
636
+ pred_class = (prediction == class_idx).astype(bool)
637
+
638
+ # Calculate True Positives, False Positives, False Negatives
639
+ tp = np.logical_and(gt_class, pred_class).sum()
640
+ fp = np.logical_and(~gt_class, pred_class).sum()
641
+ fn = np.logical_and(gt_class, ~pred_class).sum()
642
+
643
+ # Calculate precision and recall
644
+ precision = (tp + smooth) / (tp + fp + smooth)
645
+ recall = (tp + smooth) / (tp + fn + smooth)
646
+
647
+ # Calculate F1 score
648
+ if tp + fp + fn == 0:
649
+ # If class is not present in both gt and pred
650
+ f1_per_class.append(np.nan)
651
+ else:
652
+ f1 = 2 * (precision * recall) / (precision + recall + smooth)
653
+ f1_per_class.append(f1)
654
+
655
+ return np.array(f1_per_class)
656
+
657
+
658
+ def calc_segmentation_metrics(
659
+ ground_truth: Union[str, np.ndarray, torch.Tensor],
660
+ prediction: Union[str, np.ndarray, torch.Tensor],
661
+ num_classes: Optional[int] = None,
662
+ ignore_index: Optional[int] = None,
663
+ smooth: float = 1e-6,
664
+ metrics: List[str] = ["iou", "f1"],
665
+ band: int = 1,
666
+ ) -> Dict[str, Union[float, np.ndarray]]:
667
+ """
668
+ Calculate multiple segmentation metrics between ground truth and prediction masks.
669
+
670
+ This is a convenient wrapper function that computes multiple metrics at once,
671
+ including IoU (Intersection over Union) and F1 score. It supports both binary
672
+ and multi-class segmentation, and can handle numpy arrays, PyTorch tensors,
673
+ or file paths to raster files.
674
+
675
+ Args:
676
+ ground_truth (Union[str, np.ndarray, torch.Tensor]): Ground truth segmentation mask.
677
+ Can be a file path (str) to a raster file, numpy array, or PyTorch tensor.
678
+ For binary segmentation: shape (H, W) with values {0, 1}.
679
+ For multi-class segmentation: shape (H, W) with class indices.
680
+ prediction (Union[str, np.ndarray, torch.Tensor]): Predicted segmentation mask.
681
+ Can be a file path (str) to a raster file, numpy array, or PyTorch tensor.
682
+ Should have the same shape and format as ground_truth.
683
+ num_classes (Optional[int], optional): Number of classes for multi-class segmentation.
684
+ If None, assumes binary segmentation. Defaults to None.
685
+ ignore_index (Optional[int], optional): Class index to ignore in computation.
686
+ Useful for ignoring background or unlabeled pixels. Defaults to None.
687
+ smooth (float, optional): Smoothing factor to avoid division by zero.
688
+ Defaults to 1e-6.
689
+ metrics (List[str], optional): List of metrics to calculate.
690
+ Options: "iou", "f1". Defaults to ["iou", "f1"].
691
+ band (int, optional): Band index to read from raster file (1-based indexing).
692
+ Only used when input is a file path. Defaults to 1.
693
+
694
+ Returns:
695
+ Dict[str, Union[float, np.ndarray]]: Dictionary containing the computed metrics.
696
+ Keys are metric names ("iou", "f1"), values are the metric scores.
697
+ For binary segmentation, values are floats.
698
+ For multi-class segmentation, values are numpy arrays with per-class scores.
699
+ Also includes "mean_iou" and "mean_f1" for multi-class segmentation
700
+ (mean computed over valid classes, ignoring NaN values).
701
+
702
+ Examples:
703
+ >>> # Binary segmentation with arrays
704
+ >>> gt = np.array([[0, 0, 1, 1], [0, 1, 1, 1]])
705
+ >>> pred = np.array([[0, 0, 1, 1], [0, 0, 1, 1]])
706
+ >>> metrics = calc_segmentation_metrics(gt, pred)
707
+ >>> print(f"IoU: {metrics['iou']:.4f}, F1: {metrics['f1']:.4f}")
708
+ IoU: 0.8000, F1: 0.8889
709
+
710
+ >>> # Multi-class segmentation
711
+ >>> gt = np.array([[0, 0, 1, 1], [0, 2, 2, 1]])
712
+ >>> pred = np.array([[0, 0, 1, 1], [0, 0, 2, 2]])
713
+ >>> metrics = calc_segmentation_metrics(gt, pred, num_classes=3)
714
+ >>> print(f"Mean IoU: {metrics['mean_iou']:.4f}")
715
+ >>> print(f"Mean F1: {metrics['mean_f1']:.4f}")
716
+ >>> print(f"Per-class IoU: {metrics['iou']}")
717
+ Mean IoU: 0.5833
718
+ Mean F1: 0.7190
719
+ Per-class IoU: [0.7500 0.6667 0.3333]
720
+
721
+ >>> # Calculate only IoU
722
+ >>> metrics = calc_segmentation_metrics(gt, pred, num_classes=3, metrics=["iou"])
723
+ >>> print(f"Mean IoU: {metrics['mean_iou']:.4f}")
724
+ Mean IoU: 0.5833
725
+
726
+ >>> # Using PyTorch tensors
727
+ >>> gt_tensor = torch.tensor([[0, 0, 1, 1], [0, 1, 1, 1]])
728
+ >>> pred_tensor = torch.tensor([[0, 0, 1, 1], [0, 0, 1, 1]])
729
+ >>> metrics = calc_segmentation_metrics(gt_tensor, pred_tensor)
730
+ >>> print(f"IoU: {metrics['iou']:.4f}, F1: {metrics['f1']:.4f}")
731
+ IoU: 0.8000, F1: 0.8889
732
+
733
+ >>> # Using raster file paths
734
+ >>> metrics = calc_segmentation_metrics("ground_truth.tif", "prediction.tif", num_classes=3)
735
+ >>> print(f"Mean IoU: {metrics['mean_iou']:.4f}")
736
+ >>> print(f"Mean F1: {metrics['mean_f1']:.4f}")
737
+ Mean IoU: 0.6111
738
+ Mean F1: 0.7302
739
+ """
740
+ results = {}
741
+
742
+ # Calculate IoU if requested
743
+ if "iou" in metrics:
744
+ iou = calc_iou(
745
+ ground_truth,
746
+ prediction,
747
+ num_classes=num_classes,
748
+ ignore_index=ignore_index,
749
+ smooth=smooth,
750
+ band=band,
751
+ )
752
+ results["iou"] = iou
753
+
754
+ # Add mean IoU for multi-class
755
+ if num_classes is not None and isinstance(iou, np.ndarray):
756
+ # Calculate mean ignoring NaN values
757
+ valid_ious = iou[~np.isnan(iou)]
758
+ results["mean_iou"] = (
759
+ float(np.mean(valid_ious)) if len(valid_ious) > 0 else 0.0
760
+ )
761
+
762
+ # Calculate F1 score if requested
763
+ if "f1" in metrics:
764
+ f1 = calc_f1_score(
765
+ ground_truth,
766
+ prediction,
767
+ num_classes=num_classes,
768
+ ignore_index=ignore_index,
769
+ smooth=smooth,
770
+ band=band,
771
+ )
772
+ results["f1"] = f1
773
+
774
+ # Add mean F1 for multi-class
775
+ if num_classes is not None and isinstance(f1, np.ndarray):
776
+ # Calculate mean ignoring NaN values
777
+ valid_f1s = f1[~np.isnan(f1)]
778
+ results["mean_f1"] = (
779
+ float(np.mean(valid_f1s)) if len(valid_f1s) > 0 else 0.0
780
+ )
781
+
782
+ return results
783
+
784
+
384
785
  def dict_to_rioxarray(data_dict: Dict) -> xr.DataArray:
385
786
  """Convert a dictionary to a xarray DataArray. The dictionary should contain the
386
787
  following keys: "crs", "bounds", and "image". It can be generated from a TorchGeo
@@ -691,8 +1092,9 @@ def view_vector(
691
1092
 
692
1093
  def view_vector_interactive(
693
1094
  vector_data: Union[str, gpd.GeoDataFrame],
694
- layer_name: str = "Vector Layer",
1095
+ layer_name: str = "Vector",
695
1096
  tiles_args: Optional[Dict] = None,
1097
+ opacity: float = 0.7,
696
1098
  **kwargs: Any,
697
1099
  ) -> Any:
698
1100
  """
@@ -704,9 +1106,10 @@ def view_vector_interactive(
704
1106
 
705
1107
  Args:
706
1108
  vector_data (geopandas.GeoDataFrame): The vector dataset to visualize.
707
- layer_name (str, optional): The name of the layer. Defaults to "Vector Layer".
1109
+ layer_name (str, optional): The name of the layer. Defaults to "Vector".
708
1110
  tiles_args (dict, optional): Additional arguments for the localtileserver client.
709
1111
  get_folium_tile_layer function. Defaults to None.
1112
+ opacity (float, optional): The opacity of the layer. Defaults to 0.7.
710
1113
  **kwargs: Additional keyword arguments to pass to GeoDataFrame.explore() function.
711
1114
  See https://geopandas.org/en/stable/docs/reference/api/geopandas.GeoDataFrame.explore.html
712
1115
 
@@ -721,9 +1124,8 @@ def view_vector_interactive(
721
1124
  >>> roads = gpd.read_file("roads.shp")
722
1125
  >>> view_vector_interactive(roads, figsize=(12, 8))
723
1126
  """
724
- import folium
725
- import folium.plugins as plugins
726
- from leafmap import cog_tile
1127
+
1128
+ from leafmap.foliumap import Map
727
1129
  from localtileserver import TileClient, get_folium_tile_layer
728
1130
 
729
1131
  google_tiles = {
@@ -749,9 +1151,17 @@ def view_vector_interactive(
749
1151
  },
750
1152
  }
751
1153
 
1154
+ # Make it compatible with binder and JupyterHub
1155
+ if os.environ.get("JUPYTERHUB_SERVICE_PREFIX") is not None:
1156
+ os.environ["LOCALTILESERVER_CLIENT_PREFIX"] = (
1157
+ f"{os.environ['JUPYTERHUB_SERVICE_PREFIX'].lstrip('/')}/proxy/{{port}}"
1158
+ )
1159
+
752
1160
  basemap_layer_name = None
753
1161
  raster_layer = None
754
1162
 
1163
+ m = Map()
1164
+
755
1165
  if "tiles" in kwargs and isinstance(kwargs["tiles"], str):
756
1166
  if kwargs["tiles"].title() in google_tiles:
757
1167
  basemap_layer_name = google_tiles[kwargs["tiles"].title()]["name"]
@@ -762,14 +1172,17 @@ def view_vector_interactive(
762
1172
  tiles_args = {}
763
1173
  if kwargs["tiles"].lower().startswith("http"):
764
1174
  basemap_layer_name = "Remote Raster"
765
- kwargs["tiles"] = cog_tile(kwargs["tiles"], **tiles_args)
766
- kwargs["attr"] = "TiTiler"
1175
+ m.add_geotiff(kwargs["tiles"], name=basemap_layer_name, **tiles_args)
767
1176
  else:
768
1177
  basemap_layer_name = "Local Raster"
769
1178
  client = TileClient(kwargs["tiles"])
770
1179
  raster_layer = get_folium_tile_layer(client, **tiles_args)
771
- kwargs["tiles"] = raster_layer.tiles
772
- kwargs["attr"] = "localtileserver"
1180
+ m.add_tile_layer(
1181
+ raster_layer.tiles,
1182
+ name=basemap_layer_name,
1183
+ attribution="localtileserver",
1184
+ **tiles_args,
1185
+ )
773
1186
 
774
1187
  if "max_zoom" not in kwargs:
775
1188
  kwargs["max_zoom"] = 30
@@ -784,23 +1197,18 @@ def view_vector_interactive(
784
1197
  if not isinstance(vector_data, gpd.GeoDataFrame):
785
1198
  raise TypeError("Input data must be a GeoDataFrame")
786
1199
 
787
- layer_control = kwargs.pop("layer_control", True)
788
- fullscreen_control = kwargs.pop("fullscreen_control", True)
1200
+ if "column" in kwargs:
1201
+ if "legend_position" not in kwargs:
1202
+ kwargs["legend_position"] = "bottomleft"
1203
+ if "cmap" not in kwargs:
1204
+ kwargs["cmap"] = "viridis"
1205
+ m.add_data(vector_data, layer_name=layer_name, opacity=opacity, **kwargs)
789
1206
 
790
- m = vector_data.explore(**kwargs)
791
-
792
- # Change the layer name
793
- for layer in m._children.values():
794
- if isinstance(layer, folium.GeoJson):
795
- layer.layer_name = layer_name
796
- if isinstance(layer, folium.TileLayer) and basemap_layer_name:
797
- layer.layer_name = basemap_layer_name
798
-
799
- if layer_control:
800
- m.add_child(folium.LayerControl())
1207
+ else:
1208
+ m.add_gdf(vector_data, layer_name=layer_name, opacity=opacity, **kwargs)
801
1209
 
802
- if fullscreen_control:
803
- plugins.Fullscreen().add_to(m)
1210
+ m.add_layer_control()
1211
+ m.add_opacity_control()
804
1212
 
805
1213
  return m
806
1214
 
@@ -2594,10 +3002,86 @@ def batch_vector_to_raster(
2594
3002
  return output_files
2595
3003
 
2596
3004
 
3005
+ def get_default_augmentation_transforms(
3006
+ tile_size: int = 256,
3007
+ include_normalize: bool = False,
3008
+ mean: Tuple[float, float, float] = (0.485, 0.456, 0.406),
3009
+ std: Tuple[float, float, float] = (0.229, 0.224, 0.225),
3010
+ ) -> Any:
3011
+ """
3012
+ Get default data augmentation transforms for geospatial imagery using albumentations.
3013
+
3014
+ This function returns a composition of augmentation transforms commonly used
3015
+ for remote sensing and geospatial data. The transforms include geometric
3016
+ transformations (flips, rotations) and photometric adjustments (brightness,
3017
+ contrast, saturation).
3018
+
3019
+ Args:
3020
+ tile_size (int): Target size for tiles. Defaults to 256.
3021
+ include_normalize (bool): Whether to include normalization transform.
3022
+ Defaults to False. Set to True if using for training with pretrained models.
3023
+ mean (tuple): Mean values for normalization (RGB). Defaults to ImageNet values.
3024
+ std (tuple): Standard deviation for normalization (RGB). Defaults to ImageNet values.
3025
+
3026
+ Returns:
3027
+ albumentations.Compose: A composition of augmentation transforms.
3028
+
3029
+ Example:
3030
+ >>> import albumentations as A
3031
+ >>> # Get default transforms
3032
+ >>> transform = get_default_augmentation_transforms()
3033
+ >>> # Apply to image and mask
3034
+ >>> augmented = transform(image=image, mask=mask)
3035
+ >>> aug_image = augmented['image']
3036
+ >>> aug_mask = augmented['mask']
3037
+ """
3038
+ try:
3039
+ import albumentations as A
3040
+ except ImportError:
3041
+ raise ImportError(
3042
+ "albumentations is required for data augmentation. "
3043
+ "Install it with: pip install albumentations"
3044
+ )
3045
+
3046
+ transforms_list = [
3047
+ # Geometric transforms
3048
+ A.HorizontalFlip(p=0.5),
3049
+ A.VerticalFlip(p=0.5),
3050
+ A.RandomRotate90(p=0.5),
3051
+ A.ShiftScaleRotate(
3052
+ shift_limit=0.1,
3053
+ scale_limit=0.1,
3054
+ rotate_limit=45,
3055
+ border_mode=0,
3056
+ p=0.5,
3057
+ ),
3058
+ # Photometric transforms
3059
+ A.RandomBrightnessContrast(
3060
+ brightness_limit=0.2,
3061
+ contrast_limit=0.2,
3062
+ p=0.5,
3063
+ ),
3064
+ A.HueSaturationValue(
3065
+ hue_shift_limit=10,
3066
+ sat_shift_limit=20,
3067
+ val_shift_limit=10,
3068
+ p=0.3,
3069
+ ),
3070
+ A.GaussNoise(var_limit=(10.0, 50.0), p=0.2),
3071
+ A.GaussianBlur(blur_limit=(3, 5), p=0.2),
3072
+ ]
3073
+
3074
+ # Add normalization if requested
3075
+ if include_normalize:
3076
+ transforms_list.append(A.Normalize(mean=mean, std=std))
3077
+
3078
+ return A.Compose(transforms_list)
3079
+
3080
+
2597
3081
  def export_geotiff_tiles(
2598
3082
  in_raster,
2599
3083
  out_folder,
2600
- in_class_data,
3084
+ in_class_data=None,
2601
3085
  tile_size=256,
2602
3086
  stride=128,
2603
3087
  class_value_field="class",
@@ -2607,6 +3091,10 @@ def export_geotiff_tiles(
2607
3091
  all_touched=True,
2608
3092
  create_overview=False,
2609
3093
  skip_empty_tiles=False,
3094
+ metadata_format="PASCAL_VOC",
3095
+ apply_augmentation=False,
3096
+ augmentation_count=3,
3097
+ augmentation_transforms=None,
2610
3098
  ):
2611
3099
  """
2612
3100
  Export georeferenced GeoTIFF tiles and labels from raster and classification data.
@@ -2614,7 +3102,8 @@ def export_geotiff_tiles(
2614
3102
  Args:
2615
3103
  in_raster (str): Path to input raster image
2616
3104
  out_folder (str): Path to output folder
2617
- in_class_data (str): Path to classification data - can be vector file or raster
3105
+ in_class_data (str, optional): Path to classification data - can be vector file or raster.
3106
+ If None, only image tiles will be exported without labels. Defaults to None.
2618
3107
  tile_size (int): Size of tiles in pixels (square)
2619
3108
  stride (int): Step size between tiles
2620
3109
  class_value_field (str): Field containing class values (for vector data)
@@ -2624,38 +3113,95 @@ def export_geotiff_tiles(
2624
3113
  all_touched (bool): Whether to use all_touched=True in rasterization (for vector data)
2625
3114
  create_overview (bool): Whether to create an overview image of all tiles
2626
3115
  skip_empty_tiles (bool): If True, skip tiles with no features
3116
+ metadata_format (str): Output metadata format (PASCAL_VOC, COCO, YOLO). Default: PASCAL_VOC
3117
+ apply_augmentation (bool): If True, generate augmented versions of each tile.
3118
+ This will create multiple variants of each tile using data augmentation techniques.
3119
+ Defaults to False.
3120
+ augmentation_count (int): Number of augmented versions to generate per tile
3121
+ (only used if apply_augmentation=True). Defaults to 3.
3122
+ augmentation_transforms (albumentations.Compose, optional): Custom augmentation transforms.
3123
+ If None and apply_augmentation=True, uses default transforms from
3124
+ get_default_augmentation_transforms(). Should be an albumentations.Compose object.
3125
+ Defaults to None.
3126
+
3127
+ Returns:
3128
+ None: Tiles and labels are saved to out_folder.
3129
+
3130
+ Example:
3131
+ >>> # Export tiles without augmentation
3132
+ >>> export_geotiff_tiles('image.tif', 'output/', 'labels.tif')
3133
+ >>>
3134
+ >>> # Export tiles with default augmentation (3 augmented versions per tile)
3135
+ >>> export_geotiff_tiles('image.tif', 'output/', 'labels.tif',
3136
+ ... apply_augmentation=True)
3137
+ >>>
3138
+ >>> # Export with custom augmentation
3139
+ >>> import albumentations as A
3140
+ >>> custom_transform = A.Compose([
3141
+ ... A.HorizontalFlip(p=0.5),
3142
+ ... A.RandomBrightnessContrast(p=0.5),
3143
+ ... ])
3144
+ >>> export_geotiff_tiles('image.tif', 'output/', 'labels.tif',
3145
+ ... apply_augmentation=True,
3146
+ ... augmentation_count=5,
3147
+ ... augmentation_transforms=custom_transform)
2627
3148
  """
2628
3149
 
2629
3150
  import logging
2630
3151
 
2631
3152
  logging.getLogger("rasterio").setLevel(logging.ERROR)
2632
3153
 
3154
+ # Initialize augmentation transforms if needed
3155
+ if apply_augmentation:
3156
+ if augmentation_transforms is None:
3157
+ augmentation_transforms = get_default_augmentation_transforms(
3158
+ tile_size=tile_size
3159
+ )
3160
+ if not quiet:
3161
+ print(
3162
+ f"Data augmentation enabled: generating {augmentation_count} augmented versions per tile"
3163
+ )
3164
+
2633
3165
  # Create output directories
2634
3166
  os.makedirs(out_folder, exist_ok=True)
2635
3167
  image_dir = os.path.join(out_folder, "images")
2636
3168
  os.makedirs(image_dir, exist_ok=True)
2637
- label_dir = os.path.join(out_folder, "labels")
2638
- os.makedirs(label_dir, exist_ok=True)
2639
- ann_dir = os.path.join(out_folder, "annotations")
2640
- os.makedirs(ann_dir, exist_ok=True)
2641
3169
 
2642
- # Determine if class data is raster or vector
3170
+ # Only create label and annotation directories if class data is provided
3171
+ if in_class_data is not None:
3172
+ label_dir = os.path.join(out_folder, "labels")
3173
+ os.makedirs(label_dir, exist_ok=True)
3174
+
3175
+ # Create annotation directory based on metadata format
3176
+ if metadata_format in ["PASCAL_VOC", "COCO"]:
3177
+ ann_dir = os.path.join(out_folder, "annotations")
3178
+ os.makedirs(ann_dir, exist_ok=True)
3179
+
3180
+ # Initialize COCO annotations dictionary
3181
+ if metadata_format == "COCO":
3182
+ coco_annotations = {"images": [], "annotations": [], "categories": []}
3183
+ ann_id = 0
3184
+
3185
+ # Determine if class data is raster or vector (only if class data provided)
2643
3186
  is_class_data_raster = False
2644
- if isinstance(in_class_data, str):
2645
- file_ext = Path(in_class_data).suffix.lower()
2646
- # Common raster extensions
2647
- if file_ext in [".tif", ".tiff", ".img", ".jp2", ".png", ".bmp", ".gif"]:
2648
- try:
2649
- with rasterio.open(in_class_data) as src:
2650
- is_class_data_raster = True
3187
+ if in_class_data is not None:
3188
+ if isinstance(in_class_data, str):
3189
+ file_ext = Path(in_class_data).suffix.lower()
3190
+ # Common raster extensions
3191
+ if file_ext in [".tif", ".tiff", ".img", ".jp2", ".png", ".bmp", ".gif"]:
3192
+ try:
3193
+ with rasterio.open(in_class_data) as src:
3194
+ is_class_data_raster = True
3195
+ if not quiet:
3196
+ print(f"Detected in_class_data as raster: {in_class_data}")
3197
+ print(f"Raster CRS: {src.crs}")
3198
+ print(f"Raster dimensions: {src.width} x {src.height}")
3199
+ except Exception:
3200
+ is_class_data_raster = False
2651
3201
  if not quiet:
2652
- print(f"Detected in_class_data as raster: {in_class_data}")
2653
- print(f"Raster CRS: {src.crs}")
2654
- print(f"Raster dimensions: {src.width} x {src.height}")
2655
- except Exception:
2656
- is_class_data_raster = False
2657
- if not quiet:
2658
- print(f"Unable to open {in_class_data} as raster, trying as vector")
3202
+ print(
3203
+ f"Unable to open {in_class_data} as raster, trying as vector"
3204
+ )
2659
3205
 
2660
3206
  # Open the input raster
2661
3207
  with rasterio.open(in_raster) as src:
@@ -2675,10 +3221,10 @@ def export_geotiff_tiles(
2675
3221
  if max_tiles is None:
2676
3222
  max_tiles = total_tiles
2677
3223
 
2678
- # Process classification data
3224
+ # Process classification data (only if class data provided)
2679
3225
  class_to_id = {}
2680
3226
 
2681
- if is_class_data_raster:
3227
+ if in_class_data is not None and is_class_data_raster:
2682
3228
  # Load raster class data
2683
3229
  with rasterio.open(in_class_data) as class_src:
2684
3230
  # Check if raster CRS matches
@@ -2711,7 +3257,18 @@ def export_geotiff_tiles(
2711
3257
 
2712
3258
  # Create class mapping
2713
3259
  class_to_id = {int(cls): i + 1 for i, cls in enumerate(unique_classes)}
2714
- else:
3260
+
3261
+ # Populate COCO categories
3262
+ if metadata_format == "COCO":
3263
+ for cls_val in unique_classes:
3264
+ coco_annotations["categories"].append(
3265
+ {
3266
+ "id": class_to_id[int(cls_val)],
3267
+ "name": str(int(cls_val)),
3268
+ "supercategory": "object",
3269
+ }
3270
+ )
3271
+ elif in_class_data is not None:
2715
3272
  # Load vector class data
2716
3273
  try:
2717
3274
  gdf = gpd.read_file(in_class_data)
@@ -2740,12 +3297,33 @@ def export_geotiff_tiles(
2740
3297
  )
2741
3298
  # Create class mapping
2742
3299
  class_to_id = {cls: i + 1 for i, cls in enumerate(unique_classes)}
3300
+
3301
+ # Populate COCO categories
3302
+ if metadata_format == "COCO":
3303
+ for cls_val in unique_classes:
3304
+ coco_annotations["categories"].append(
3305
+ {
3306
+ "id": class_to_id[cls_val],
3307
+ "name": str(cls_val),
3308
+ "supercategory": "object",
3309
+ }
3310
+ )
2743
3311
  else:
2744
3312
  if not quiet:
2745
3313
  print(
2746
3314
  f"WARNING: '{class_value_field}' not found in vector data. Using default class ID 1."
2747
3315
  )
2748
3316
  class_to_id = {1: 1} # Default mapping
3317
+
3318
+ # Populate COCO categories with default
3319
+ if metadata_format == "COCO":
3320
+ coco_annotations["categories"].append(
3321
+ {
3322
+ "id": 1,
3323
+ "name": "object",
3324
+ "supercategory": "object",
3325
+ }
3326
+ )
2749
3327
  except Exception as e:
2750
3328
  raise ValueError(f"Error processing vector data: {e}")
2751
3329
 
@@ -2812,8 +3390,8 @@ def export_geotiff_tiles(
2812
3390
  label_mask = np.zeros((tile_size, tile_size), dtype=np.uint8)
2813
3391
  has_features = False
2814
3392
 
2815
- # Process classification data to create labels
2816
- if is_class_data_raster:
3393
+ # Process classification data to create labels (only if class data provided)
3394
+ if in_class_data is not None and is_class_data_raster:
2817
3395
  # For raster class data
2818
3396
  with rasterio.open(in_class_data) as class_src:
2819
3397
  # Calculate window in class raster
@@ -2863,7 +3441,7 @@ def export_geotiff_tiles(
2863
3441
  except Exception as e:
2864
3442
  pbar.write(f"Error reading class raster window: {e}")
2865
3443
  stats["errors"] += 1
2866
- else:
3444
+ elif in_class_data is not None:
2867
3445
  # For vector class data
2868
3446
  # Find features that intersect with window
2869
3447
  window_features = gdf[gdf.intersects(window_bounds)]
@@ -2906,8 +3484,8 @@ def export_geotiff_tiles(
2906
3484
  pbar.write(f"Error rasterizing feature {idx}: {e}")
2907
3485
  stats["errors"] += 1
2908
3486
 
2909
- # Skip tile if no features and skip_empty_tiles is True
2910
- if skip_empty_tiles and not has_features:
3487
+ # Skip tile if no features and skip_empty_tiles is True (only when class data provided)
3488
+ if in_class_data is not None and skip_empty_tiles and not has_features:
2911
3489
  pbar.update(1)
2912
3490
  tile_index += 1
2913
3491
  continue
@@ -2915,119 +3493,316 @@ def export_geotiff_tiles(
2915
3493
  # Read image data
2916
3494
  image_data = src.read(window=window)
2917
3495
 
2918
- # Export image as GeoTIFF
2919
- image_path = os.path.join(image_dir, f"tile_{tile_index:06d}.tif")
3496
+ # Helper function to save a single tile (original or augmented)
3497
+ def save_tile(
3498
+ img_data,
3499
+ lbl_mask,
3500
+ tile_id,
3501
+ img_profile,
3502
+ window_trans,
3503
+ is_augmented=False,
3504
+ ):
3505
+ """Save a single image and label tile."""
3506
+ # Export image as GeoTIFF
3507
+ image_path = os.path.join(image_dir, f"tile_{tile_id:06d}.tif")
2920
3508
 
2921
- # Create profile for image GeoTIFF
2922
- image_profile = src.profile.copy()
2923
- image_profile.update(
2924
- {
2925
- "height": tile_size,
2926
- "width": tile_size,
2927
- "count": image_data.shape[0],
2928
- "transform": window_transform,
2929
- }
3509
+ # Update profile
3510
+ img_profile_copy = img_profile.copy()
3511
+ img_profile_copy.update(
3512
+ {
3513
+ "height": tile_size,
3514
+ "width": tile_size,
3515
+ "count": img_data.shape[0],
3516
+ "transform": window_trans,
3517
+ }
3518
+ )
3519
+
3520
+ # Save image as GeoTIFF
3521
+ try:
3522
+ with rasterio.open(image_path, "w", **img_profile_copy) as dst:
3523
+ dst.write(img_data)
3524
+ stats["total_tiles"] += 1
3525
+ except Exception as e:
3526
+ pbar.write(f"ERROR saving image GeoTIFF: {e}")
3527
+ stats["errors"] += 1
3528
+ return
3529
+
3530
+ # Export label as GeoTIFF (only if class data provided)
3531
+ if in_class_data is not None:
3532
+ # Create profile for label GeoTIFF
3533
+ label_profile = {
3534
+ "driver": "GTiff",
3535
+ "height": tile_size,
3536
+ "width": tile_size,
3537
+ "count": 1,
3538
+ "dtype": "uint8",
3539
+ "crs": src.crs,
3540
+ "transform": window_trans,
3541
+ }
3542
+
3543
+ label_path = os.path.join(label_dir, f"tile_{tile_id:06d}.tif")
3544
+ try:
3545
+ with rasterio.open(label_path, "w", **label_profile) as dst:
3546
+ dst.write(lbl_mask.astype(np.uint8), 1)
3547
+
3548
+ if not is_augmented and np.any(lbl_mask > 0):
3549
+ stats["tiles_with_features"] += 1
3550
+ stats["feature_pixels"] += np.count_nonzero(lbl_mask)
3551
+ except Exception as e:
3552
+ pbar.write(f"ERROR saving label GeoTIFF: {e}")
3553
+ stats["errors"] += 1
3554
+
3555
+ # Save original tile
3556
+ save_tile(
3557
+ image_data,
3558
+ label_mask,
3559
+ tile_index,
3560
+ src.profile,
3561
+ window_transform,
3562
+ is_augmented=False,
2930
3563
  )
2931
3564
 
2932
- # Save image as GeoTIFF
2933
- try:
2934
- with rasterio.open(image_path, "w", **image_profile) as dst:
2935
- dst.write(image_data)
2936
- stats["total_tiles"] += 1
2937
- except Exception as e:
2938
- pbar.write(f"ERROR saving image GeoTIFF: {e}")
2939
- stats["errors"] += 1
3565
+ # Generate and save augmented tiles if enabled
3566
+ if apply_augmentation:
3567
+ for aug_idx in range(augmentation_count):
3568
+ # Prepare image for augmentation (convert from CHW to HWC)
3569
+ img_for_aug = np.transpose(image_data, (1, 2, 0))
3570
+
3571
+ # Ensure uint8 data type for albumentations
3572
+ # Albumentations expects uint8 for most transforms
3573
+ if not np.issubdtype(img_for_aug.dtype, np.uint8):
3574
+ # If image is float, scale to 0-255 and convert to uint8
3575
+ if np.issubdtype(img_for_aug.dtype, np.floating):
3576
+ img_for_aug = (
3577
+ (img_for_aug * 255).clip(0, 255).astype(np.uint8)
3578
+ )
3579
+ else:
3580
+ img_for_aug = img_for_aug.astype(np.uint8)
2940
3581
 
2941
- # Create profile for label GeoTIFF
2942
- label_profile = {
2943
- "driver": "GTiff",
2944
- "height": tile_size,
2945
- "width": tile_size,
2946
- "count": 1,
2947
- "dtype": "uint8",
2948
- "crs": src.crs,
2949
- "transform": window_transform,
2950
- }
3582
+ # Apply augmentation
3583
+ try:
3584
+ if in_class_data is not None:
3585
+ # Augment both image and mask
3586
+ augmented = augmentation_transforms(
3587
+ image=img_for_aug, mask=label_mask
3588
+ )
3589
+ aug_image = augmented["image"]
3590
+ aug_mask = augmented["mask"]
3591
+ else:
3592
+ # Augment only image
3593
+ augmented = augmentation_transforms(image=img_for_aug)
3594
+ aug_image = augmented["image"]
3595
+ aug_mask = label_mask
2951
3596
 
2952
- # Export label as GeoTIFF
2953
- label_path = os.path.join(label_dir, f"tile_{tile_index:06d}.tif")
2954
- try:
2955
- with rasterio.open(label_path, "w", **label_profile) as dst:
2956
- dst.write(label_mask.astype(np.uint8), 1)
3597
+ # Convert back from HWC to CHW
3598
+ aug_image = np.transpose(aug_image, (2, 0, 1))
2957
3599
 
2958
- if has_features:
2959
- stats["tiles_with_features"] += 1
2960
- stats["feature_pixels"] += np.count_nonzero(label_mask)
2961
- except Exception as e:
2962
- pbar.write(f"ERROR saving label GeoTIFF: {e}")
2963
- stats["errors"] += 1
3600
+ # Ensure correct dtype for saving
3601
+ aug_image = aug_image.astype(image_data.dtype)
3602
+
3603
+ # Generate unique tile ID for augmented version
3604
+ # Use a collision-free numbering scheme: (tile_index * (augmentation_count + 1)) + aug_idx + 1
3605
+ aug_tile_id = (
3606
+ (tile_index * (augmentation_count + 1)) + aug_idx + 1
3607
+ )
3608
+
3609
+ # Save augmented tile
3610
+ save_tile(
3611
+ aug_image,
3612
+ aug_mask,
3613
+ aug_tile_id,
3614
+ src.profile,
3615
+ window_transform,
3616
+ is_augmented=True,
3617
+ )
3618
+
3619
+ except Exception as e:
3620
+ pbar.write(
3621
+ f"ERROR applying augmentation {aug_idx} to tile {tile_index}: {e}"
3622
+ )
3623
+ stats["errors"] += 1
2964
3624
 
2965
- # Create XML annotation for object detection if using vector class data
3625
+ # Create annotations for object detection if using vector class data
2966
3626
  if (
2967
- not is_class_data_raster
3627
+ in_class_data is not None
3628
+ and not is_class_data_raster
2968
3629
  and "gdf" in locals()
2969
3630
  and len(window_features) > 0
2970
3631
  ):
2971
- # Create XML annotation
2972
- root = ET.Element("annotation")
2973
- ET.SubElement(root, "folder").text = "images"
2974
- ET.SubElement(root, "filename").text = f"tile_{tile_index:06d}.tif"
3632
+ if metadata_format == "PASCAL_VOC":
3633
+ # Create XML annotation
3634
+ root = ET.Element("annotation")
3635
+ ET.SubElement(root, "folder").text = "images"
3636
+ ET.SubElement(root, "filename").text = (
3637
+ f"tile_{tile_index:06d}.tif"
3638
+ )
2975
3639
 
2976
- size = ET.SubElement(root, "size")
2977
- ET.SubElement(size, "width").text = str(tile_size)
2978
- ET.SubElement(size, "height").text = str(tile_size)
2979
- ET.SubElement(size, "depth").text = str(image_data.shape[0])
3640
+ size = ET.SubElement(root, "size")
3641
+ ET.SubElement(size, "width").text = str(tile_size)
3642
+ ET.SubElement(size, "height").text = str(tile_size)
3643
+ ET.SubElement(size, "depth").text = str(image_data.shape[0])
3644
+
3645
+ # Add georeference information
3646
+ geo = ET.SubElement(root, "georeference")
3647
+ ET.SubElement(geo, "crs").text = str(src.crs)
3648
+ ET.SubElement(geo, "transform").text = str(
3649
+ window_transform
3650
+ ).replace("\n", "")
3651
+ ET.SubElement(geo, "bounds").text = (
3652
+ f"{minx}, {miny}, {maxx}, {maxy}"
3653
+ )
2980
3654
 
2981
- # Add georeference information
2982
- geo = ET.SubElement(root, "georeference")
2983
- ET.SubElement(geo, "crs").text = str(src.crs)
2984
- ET.SubElement(geo, "transform").text = str(
2985
- window_transform
2986
- ).replace("\n", "")
2987
- ET.SubElement(geo, "bounds").text = (
2988
- f"{minx}, {miny}, {maxx}, {maxy}"
2989
- )
3655
+ # Add objects
3656
+ for idx, feature in window_features.iterrows():
3657
+ # Get feature class
3658
+ if class_value_field in feature:
3659
+ class_val = feature[class_value_field]
3660
+ else:
3661
+ class_val = "object"
2990
3662
 
2991
- # Add objects
2992
- for idx, feature in window_features.iterrows():
2993
- # Get feature class
2994
- if class_value_field in feature:
2995
- class_val = feature[class_value_field]
2996
- else:
2997
- class_val = "object"
3663
+ # Get geometry bounds in pixel coordinates
3664
+ geom = feature.geometry.intersection(window_bounds)
3665
+ if not geom.is_empty:
3666
+ # Get bounds in world coordinates
3667
+ minx_f, miny_f, maxx_f, maxy_f = geom.bounds
3668
+
3669
+ # Convert to pixel coordinates
3670
+ col_min, row_min = ~window_transform * (minx_f, maxy_f)
3671
+ col_max, row_max = ~window_transform * (maxx_f, miny_f)
3672
+
3673
+ # Ensure coordinates are within tile bounds
3674
+ xmin = max(0, min(tile_size, int(col_min)))
3675
+ ymin = max(0, min(tile_size, int(row_min)))
3676
+ xmax = max(0, min(tile_size, int(col_max)))
3677
+ ymax = max(0, min(tile_size, int(row_max)))
3678
+
3679
+ # Only add if the box has non-zero area
3680
+ if xmax > xmin and ymax > ymin:
3681
+ obj = ET.SubElement(root, "object")
3682
+ ET.SubElement(obj, "name").text = str(class_val)
3683
+ ET.SubElement(obj, "difficult").text = "0"
3684
+
3685
+ bbox = ET.SubElement(obj, "bndbox")
3686
+ ET.SubElement(bbox, "xmin").text = str(xmin)
3687
+ ET.SubElement(bbox, "ymin").text = str(ymin)
3688
+ ET.SubElement(bbox, "xmax").text = str(xmax)
3689
+ ET.SubElement(bbox, "ymax").text = str(ymax)
3690
+
3691
+ # Save XML
3692
+ tree = ET.ElementTree(root)
3693
+ xml_path = os.path.join(ann_dir, f"tile_{tile_index:06d}.xml")
3694
+ tree.write(xml_path)
2998
3695
 
2999
- # Get geometry bounds in pixel coordinates
3000
- geom = feature.geometry.intersection(window_bounds)
3001
- if not geom.is_empty:
3002
- # Get bounds in world coordinates
3003
- minx_f, miny_f, maxx_f, maxy_f = geom.bounds
3004
-
3005
- # Convert to pixel coordinates
3006
- col_min, row_min = ~window_transform * (minx_f, maxy_f)
3007
- col_max, row_max = ~window_transform * (maxx_f, miny_f)
3008
-
3009
- # Ensure coordinates are within tile bounds
3010
- xmin = max(0, min(tile_size, int(col_min)))
3011
- ymin = max(0, min(tile_size, int(row_min)))
3012
- xmax = max(0, min(tile_size, int(col_max)))
3013
- ymax = max(0, min(tile_size, int(row_max)))
3014
-
3015
- # Only add if the box has non-zero area
3016
- if xmax > xmin and ymax > ymin:
3017
- obj = ET.SubElement(root, "object")
3018
- ET.SubElement(obj, "name").text = str(class_val)
3019
- ET.SubElement(obj, "difficult").text = "0"
3020
-
3021
- bbox = ET.SubElement(obj, "bndbox")
3022
- ET.SubElement(bbox, "xmin").text = str(xmin)
3023
- ET.SubElement(bbox, "ymin").text = str(ymin)
3024
- ET.SubElement(bbox, "xmax").text = str(xmax)
3025
- ET.SubElement(bbox, "ymax").text = str(ymax)
3696
+ elif metadata_format == "COCO":
3697
+ # Add image info
3698
+ image_id = tile_index
3699
+ coco_annotations["images"].append(
3700
+ {
3701
+ "id": image_id,
3702
+ "file_name": f"tile_{tile_index:06d}.tif",
3703
+ "width": tile_size,
3704
+ "height": tile_size,
3705
+ "crs": str(src.crs),
3706
+ "transform": str(window_transform),
3707
+ }
3708
+ )
3026
3709
 
3027
- # Save XML
3028
- tree = ET.ElementTree(root)
3029
- xml_path = os.path.join(ann_dir, f"tile_{tile_index:06d}.xml")
3030
- tree.write(xml_path)
3710
+ # Add annotations for each feature
3711
+ for _, feature in window_features.iterrows():
3712
+ # Get feature class
3713
+ if class_value_field in feature:
3714
+ class_val = feature[class_value_field]
3715
+ category_id = class_to_id.get(class_val, 1)
3716
+ else:
3717
+ category_id = 1
3718
+
3719
+ # Get geometry bounds
3720
+ geom = feature.geometry.intersection(window_bounds)
3721
+ if not geom.is_empty:
3722
+ # Get bounds in world coordinates
3723
+ minx_f, miny_f, maxx_f, maxy_f = geom.bounds
3724
+
3725
+ # Convert to pixel coordinates
3726
+ col_min, row_min = ~window_transform * (minx_f, maxy_f)
3727
+ col_max, row_max = ~window_transform * (maxx_f, miny_f)
3728
+
3729
+ # Ensure coordinates are within tile bounds
3730
+ xmin = max(0, min(tile_size, int(col_min)))
3731
+ ymin = max(0, min(tile_size, int(row_min)))
3732
+ xmax = max(0, min(tile_size, int(col_max)))
3733
+ ymax = max(0, min(tile_size, int(row_max)))
3734
+
3735
+ # Skip if box is too small
3736
+ if xmax - xmin < 1 or ymax - ymin < 1:
3737
+ continue
3738
+
3739
+ width = xmax - xmin
3740
+ height = ymax - ymin
3741
+
3742
+ # Add annotation
3743
+ ann_id += 1
3744
+ coco_annotations["annotations"].append(
3745
+ {
3746
+ "id": ann_id,
3747
+ "image_id": image_id,
3748
+ "category_id": category_id,
3749
+ "bbox": [xmin, ymin, width, height],
3750
+ "area": width * height,
3751
+ "iscrowd": 0,
3752
+ }
3753
+ )
3754
+
3755
+ elif metadata_format == "YOLO":
3756
+ # Create YOLO format annotations
3757
+ yolo_annotations = []
3758
+
3759
+ for _, feature in window_features.iterrows():
3760
+ # Get feature class
3761
+ if class_value_field in feature:
3762
+ class_val = feature[class_value_field]
3763
+ # YOLO uses 0-indexed class IDs
3764
+ class_id = class_to_id.get(class_val, 1) - 1
3765
+ else:
3766
+ class_id = 0
3767
+
3768
+ # Get geometry bounds
3769
+ geom = feature.geometry.intersection(window_bounds)
3770
+ if not geom.is_empty:
3771
+ # Get bounds in world coordinates
3772
+ minx_f, miny_f, maxx_f, maxy_f = geom.bounds
3773
+
3774
+ # Convert to pixel coordinates
3775
+ col_min, row_min = ~window_transform * (minx_f, maxy_f)
3776
+ col_max, row_max = ~window_transform * (maxx_f, miny_f)
3777
+
3778
+ # Ensure coordinates are within tile bounds
3779
+ xmin = max(0, min(tile_size, col_min))
3780
+ ymin = max(0, min(tile_size, row_min))
3781
+ xmax = max(0, min(tile_size, col_max))
3782
+ ymax = max(0, min(tile_size, row_max))
3783
+
3784
+ # Skip if box is too small
3785
+ if xmax - xmin < 1 or ymax - ymin < 1:
3786
+ continue
3787
+
3788
+ # Calculate normalized coordinates (YOLO format)
3789
+ x_center = ((xmin + xmax) / 2) / tile_size
3790
+ y_center = ((ymin + ymax) / 2) / tile_size
3791
+ width = (xmax - xmin) / tile_size
3792
+ height = (ymax - ymin) / tile_size
3793
+
3794
+ # Add YOLO annotation line
3795
+ yolo_annotations.append(
3796
+ f"{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}"
3797
+ )
3798
+
3799
+ # Save YOLO annotations to text file
3800
+ if yolo_annotations:
3801
+ yolo_path = os.path.join(
3802
+ label_dir, f"tile_{tile_index:06d}.txt"
3803
+ )
3804
+ with open(yolo_path, "w") as f:
3805
+ f.write("\n".join(yolo_annotations))
3031
3806
 
3032
3807
  # Update progress bar
3033
3808
  pbar.update(1)
@@ -3045,6 +3820,39 @@ def export_geotiff_tiles(
3045
3820
  # Close progress bar
3046
3821
  pbar.close()
3047
3822
 
3823
+ # Save COCO annotations if applicable (only if class data provided)
3824
+ if in_class_data is not None and metadata_format == "COCO":
3825
+ try:
3826
+ with open(os.path.join(ann_dir, "instances.json"), "w") as f:
3827
+ json.dump(coco_annotations, f, indent=2)
3828
+ if not quiet:
3829
+ print(
3830
+ f"Saved COCO annotations: {len(coco_annotations['images'])} images, "
3831
+ f"{len(coco_annotations['annotations'])} annotations, "
3832
+ f"{len(coco_annotations['categories'])} categories"
3833
+ )
3834
+ except Exception as e:
3835
+ if not quiet:
3836
+ print(f"ERROR saving COCO annotations: {e}")
3837
+ stats["errors"] += 1
3838
+
3839
+ # Save YOLO classes file if applicable (only if class data provided)
3840
+ if in_class_data is not None and metadata_format == "YOLO":
3841
+ try:
3842
+ # Create classes.txt with class names
3843
+ classes_path = os.path.join(out_folder, "classes.txt")
3844
+ # Sort by class ID to ensure correct order
3845
+ sorted_classes = sorted(class_to_id.items(), key=lambda x: x[1])
3846
+ with open(classes_path, "w") as f:
3847
+ for class_val, _ in sorted_classes:
3848
+ f.write(f"{class_val}\n")
3849
+ if not quiet:
3850
+ print(f"Saved YOLO classes file with {len(class_to_id)} classes")
3851
+ except Exception as e:
3852
+ if not quiet:
3853
+ print(f"ERROR saving YOLO classes file: {e}")
3854
+ stats["errors"] += 1
3855
+
3048
3856
  # Create overview image if requested
3049
3857
  if create_overview and stats["tile_coordinates"]:
3050
3858
  try:
@@ -3062,13 +3870,14 @@ def export_geotiff_tiles(
3062
3870
  if not quiet:
3063
3871
  print("\n------- Export Summary -------")
3064
3872
  print(f"Total tiles exported: {stats['total_tiles']}")
3065
- print(
3066
- f"Tiles with features: {stats['tiles_with_features']} ({stats['tiles_with_features']/max(1, stats['total_tiles'])*100:.1f}%)"
3067
- )
3068
- if stats["tiles_with_features"] > 0:
3873
+ if in_class_data is not None:
3069
3874
  print(
3070
- f"Average feature pixels per tile: {stats['feature_pixels']/stats['tiles_with_features']:.1f}"
3875
+ f"Tiles with features: {stats['tiles_with_features']} ({stats['tiles_with_features']/max(1, stats['total_tiles'])*100:.1f}%)"
3071
3876
  )
3877
+ if stats["tiles_with_features"] > 0:
3878
+ print(
3879
+ f"Average feature pixels per tile: {stats['feature_pixels']/stats['tiles_with_features']:.1f}"
3880
+ )
3072
3881
  if stats["errors"] > 0:
3073
3882
  print(f"Errors encountered: {stats['errors']}")
3074
3883
  print(f"Output saved to: {out_folder}")
@@ -3077,7 +3886,6 @@ def export_geotiff_tiles(
3077
3886
  if stats["total_tiles"] > 0:
3078
3887
  print("\n------- Georeference Verification -------")
3079
3888
  sample_image = os.path.join(image_dir, f"tile_0.tif")
3080
- sample_label = os.path.join(label_dir, f"tile_0.tif")
3081
3889
 
3082
3890
  if os.path.exists(sample_image):
3083
3891
  try:
@@ -3093,19 +3901,22 @@ def export_geotiff_tiles(
3093
3901
  except Exception as e:
3094
3902
  print(f"Error verifying image georeference: {e}")
3095
3903
 
3096
- if os.path.exists(sample_label):
3097
- try:
3098
- with rasterio.open(sample_label) as lbl:
3099
- print(f"Label CRS: {lbl.crs}")
3100
- print(f"Label transform: {lbl.transform}")
3101
- print(
3102
- f"Label has georeference: {lbl.crs is not None and lbl.transform is not None}"
3103
- )
3104
- print(
3105
- f"Label dimensions: {lbl.width}x{lbl.height}, {lbl.count} bands, {lbl.dtypes[0]} type"
3106
- )
3107
- except Exception as e:
3108
- print(f"Error verifying label georeference: {e}")
3904
+ # Only verify label if class data was provided
3905
+ if in_class_data is not None:
3906
+ sample_label = os.path.join(label_dir, f"tile_0.tif")
3907
+ if os.path.exists(sample_label):
3908
+ try:
3909
+ with rasterio.open(sample_label) as lbl:
3910
+ print(f"Label CRS: {lbl.crs}")
3911
+ print(f"Label transform: {lbl.transform}")
3912
+ print(
3913
+ f"Label has georeference: {lbl.crs is not None and lbl.transform is not None}"
3914
+ )
3915
+ print(
3916
+ f"Label dimensions: {lbl.width}x{lbl.height}, {lbl.count} bands, {lbl.dtypes[0]} type"
3917
+ )
3918
+ except Exception as e:
3919
+ print(f"Error verifying label georeference: {e}")
3109
3920
 
3110
3921
  # Return statistics dictionary for further processing if needed
3111
3922
  return stats
@@ -3113,8 +3924,9 @@ def export_geotiff_tiles(
3113
3924
 
3114
3925
  def export_geotiff_tiles_batch(
3115
3926
  images_folder,
3116
- masks_folder,
3117
- output_folder,
3927
+ masks_folder=None,
3928
+ masks_file=None,
3929
+ output_folder=None,
3118
3930
  tile_size=256,
3119
3931
  stride=128,
3120
3932
  class_value_field="class",
@@ -3122,25 +3934,43 @@ def export_geotiff_tiles_batch(
3122
3934
  max_tiles=None,
3123
3935
  quiet=False,
3124
3936
  all_touched=True,
3125
- create_overview=False,
3126
3937
  skip_empty_tiles=False,
3127
3938
  image_extensions=None,
3128
3939
  mask_extensions=None,
3940
+ match_by_name=False,
3941
+ metadata_format="PASCAL_VOC",
3129
3942
  ) -> Dict[str, Any]:
3130
3943
  """
3131
- Export georeferenced GeoTIFF tiles from folders of images and masks.
3944
+ Export georeferenced GeoTIFF tiles from images and optionally masks.
3945
+
3946
+ This function supports four modes:
3947
+ 1. Images only (no masks) - when neither masks_file nor masks_folder is provided
3948
+ 2. Single vector file covering all images (masks_file parameter)
3949
+ 3. Multiple vector files, one per image (masks_folder parameter)
3950
+ 4. Multiple raster mask files (masks_folder parameter)
3132
3951
 
3133
- This function processes multiple image-mask pairs from input folders,
3134
- generating tiles for each pair. All image tiles are saved to a single
3135
- 'images' folder and all mask tiles to a single 'masks' folder.
3952
+ For mode 1 (images only), only image tiles will be exported without labels.
3136
3953
 
3137
- Images and masks are paired by their sorted order (alphabetically), not by
3138
- filename matching. The number of images and masks must be equal.
3954
+ For mode 2 (single vector file), specify masks_file path. The function will
3955
+ use spatial intersection to determine which features apply to each image.
3956
+
3957
+ For mode 3/4 (multiple mask files), specify masks_folder path. Images and masks
3958
+ are paired either by matching filenames (match_by_name=True) or by sorted order
3959
+ (match_by_name=False).
3960
+
3961
+ All image tiles are saved to a single 'images' folder and all mask tiles (if provided)
3962
+ to a single 'masks' folder within the output directory.
3139
3963
 
3140
3964
  Args:
3141
3965
  images_folder (str): Path to folder containing raster images
3142
- masks_folder (str): Path to folder containing classification masks/vectors
3143
- output_folder (str): Path to output folder
3966
+ masks_folder (str, optional): Path to folder containing classification masks/vectors.
3967
+ Use this for multiple mask files (one per image or raster masks). If not provided
3968
+ and masks_file is also not provided, only image tiles will be exported.
3969
+ masks_file (str, optional): Path to a single vector file covering all images.
3970
+ Use this for a single GeoJSON/Shapefile that covers multiple images. If not provided
3971
+ and masks_folder is also not provided, only image tiles will be exported.
3972
+ output_folder (str, optional): Path to output folder. If None, creates 'tiles'
3973
+ subfolder in images_folder.
3144
3974
  tile_size (int): Size of tiles in pixels (square)
3145
3975
  stride (int): Step size between tiles
3146
3976
  class_value_field (str): Field containing class values (for vector data)
@@ -3152,18 +3982,63 @@ def export_geotiff_tiles_batch(
3152
3982
  skip_empty_tiles (bool): If True, skip tiles with no features
3153
3983
  image_extensions (list): List of image file extensions to process (default: common raster formats)
3154
3984
  mask_extensions (list): List of mask file extensions to process (default: common raster/vector formats)
3985
+ match_by_name (bool): If True, match image and mask files by base filename.
3986
+ If False, match by sorted order (alphabetically). Only applies when masks_folder is used.
3987
+ metadata_format (str): Annotation format - "PASCAL_VOC" (XML), "COCO" (JSON), or "YOLO" (TXT).
3988
+ Default is "PASCAL_VOC".
3989
+
3990
+ Returns:
3991
+ Dict[str, Any]: Dictionary containing batch processing statistics
3992
+
3993
+ Raises:
3994
+ ValueError: If no images found, or if masks_folder and masks_file are both specified,
3995
+ or if counts don't match when using masks_folder with match_by_name=False.
3996
+
3997
+ Examples:
3998
+ # Images only (no masks)
3999
+ >>> stats = export_geotiff_tiles_batch(
4000
+ ... images_folder='data/images',
4001
+ ... output_folder='output/tiles'
4002
+ ... )
3155
4003
 
3156
- Returns:
3157
- Dict[str, Any]: Dictionary containing batch processing statistics
4004
+ # Single vector file covering all images
4005
+ >>> stats = export_geotiff_tiles_batch(
4006
+ ... images_folder='data/images',
4007
+ ... masks_file='data/buildings.geojson',
4008
+ ... output_folder='output/tiles'
4009
+ ... )
3158
4010
 
3159
- Raises:
3160
- ValueError: If no images or masks found, or if counts don't match
4011
+ # Multiple vector files, matched by filename
4012
+ >>> stats = export_geotiff_tiles_batch(
4013
+ ... images_folder='data/images',
4014
+ ... masks_folder='data/masks',
4015
+ ... output_folder='output/tiles',
4016
+ ... match_by_name=True
4017
+ ... )
4018
+
4019
+ # Multiple mask files, matched by sorted order
4020
+ >>> stats = export_geotiff_tiles_batch(
4021
+ ... images_folder='data/images',
4022
+ ... masks_folder='data/masks',
4023
+ ... output_folder='output/tiles',
4024
+ ... match_by_name=False
4025
+ ... )
3161
4026
  """
3162
4027
 
3163
4028
  import logging
3164
4029
 
3165
4030
  logging.getLogger("rasterio").setLevel(logging.ERROR)
3166
4031
 
4032
+ # Validate input parameters
4033
+ if masks_folder is not None and masks_file is not None:
4034
+ raise ValueError(
4035
+ "Cannot specify both masks_folder and masks_file. Please use only one."
4036
+ )
4037
+
4038
+ # Default output folder if not specified
4039
+ if output_folder is None:
4040
+ output_folder = os.path.join(images_folder, "tiles")
4041
+
3167
4042
  # Default extensions if not provided
3168
4043
  if image_extensions is None:
3169
4044
  image_extensions = [".tif", ".tiff", ".jpg", ".jpeg", ".png", ".jp2", ".img"]
@@ -3190,9 +4065,37 @@ def export_geotiff_tiles_batch(
3190
4065
  # Create output folder structure
3191
4066
  os.makedirs(output_folder, exist_ok=True)
3192
4067
  output_images_dir = os.path.join(output_folder, "images")
3193
- output_masks_dir = os.path.join(output_folder, "masks")
3194
4068
  os.makedirs(output_images_dir, exist_ok=True)
3195
- os.makedirs(output_masks_dir, exist_ok=True)
4069
+
4070
+ # Only create masks directory if masks are provided
4071
+ output_masks_dir = None
4072
+ if masks_folder is not None or masks_file is not None:
4073
+ output_masks_dir = os.path.join(output_folder, "masks")
4074
+ os.makedirs(output_masks_dir, exist_ok=True)
4075
+
4076
+ # Create annotation directory based on metadata format (only if masks are provided)
4077
+ ann_dir = None
4078
+ if (masks_folder is not None or masks_file is not None) and metadata_format in [
4079
+ "PASCAL_VOC",
4080
+ "COCO",
4081
+ ]:
4082
+ ann_dir = os.path.join(output_folder, "annotations")
4083
+ os.makedirs(ann_dir, exist_ok=True)
4084
+
4085
+ # Initialize COCO annotations dictionary (only if masks are provided)
4086
+ coco_annotations = None
4087
+ if (
4088
+ masks_folder is not None or masks_file is not None
4089
+ ) and metadata_format == "COCO":
4090
+ coco_annotations = {"images": [], "annotations": [], "categories": []}
4091
+
4092
+ # Initialize YOLO class set (only if masks are provided)
4093
+ yolo_classes = (
4094
+ set()
4095
+ if (masks_folder is not None or masks_file is not None)
4096
+ and metadata_format == "YOLO"
4097
+ else None
4098
+ )
3196
4099
 
3197
4100
  # Get list of image files
3198
4101
  image_files = []
@@ -3200,30 +4103,105 @@ def export_geotiff_tiles_batch(
3200
4103
  pattern = os.path.join(images_folder, f"*{ext}")
3201
4104
  image_files.extend(glob.glob(pattern))
3202
4105
 
3203
- # Get list of mask files
3204
- mask_files = []
3205
- for ext in mask_extensions:
3206
- pattern = os.path.join(masks_folder, f"*{ext}")
3207
- mask_files.extend(glob.glob(pattern))
3208
-
3209
4106
  # Sort files for consistent processing
3210
4107
  image_files.sort()
3211
- mask_files.sort()
3212
4108
 
3213
4109
  if not image_files:
3214
4110
  raise ValueError(
3215
4111
  f"No image files found in {images_folder} with extensions {image_extensions}"
3216
4112
  )
3217
4113
 
3218
- if not mask_files:
3219
- raise ValueError(
3220
- f"No mask files found in {masks_folder} with extensions {mask_extensions}"
3221
- )
4114
+ # Handle different mask input modes
4115
+ use_single_mask_file = masks_file is not None
4116
+ has_masks = masks_file is not None or masks_folder is not None
4117
+ mask_files = []
4118
+ image_mask_pairs = []
3222
4119
 
3223
- if len(image_files) != len(mask_files):
3224
- raise ValueError(
3225
- f"Number of image files ({len(image_files)}) does not match number of mask files ({len(mask_files)})"
3226
- )
4120
+ if not has_masks:
4121
+ # Mode 0: No masks - create pairs with None for mask
4122
+ for image_file in image_files:
4123
+ image_mask_pairs.append((image_file, None, None))
4124
+
4125
+ elif use_single_mask_file:
4126
+ # Mode 1: Single vector file covering all images
4127
+ if not os.path.exists(masks_file):
4128
+ raise ValueError(f"Mask file not found: {masks_file}")
4129
+
4130
+ # Load the single mask file once - will be spatially filtered per image
4131
+ single_mask_gdf = gpd.read_file(masks_file)
4132
+
4133
+ if not quiet:
4134
+ print(f"Using single mask file: {masks_file}")
4135
+ print(
4136
+ f"Mask contains {len(single_mask_gdf)} features in CRS: {single_mask_gdf.crs}"
4137
+ )
4138
+
4139
+ # Create pairs with the same mask file for all images
4140
+ for image_file in image_files:
4141
+ image_mask_pairs.append((image_file, masks_file, single_mask_gdf))
4142
+
4143
+ else:
4144
+ # Mode 2/3: Multiple mask files (vector or raster)
4145
+ # Get list of mask files
4146
+ for ext in mask_extensions:
4147
+ pattern = os.path.join(masks_folder, f"*{ext}")
4148
+ mask_files.extend(glob.glob(pattern))
4149
+
4150
+ # Sort files for consistent processing
4151
+ mask_files.sort()
4152
+
4153
+ if not mask_files:
4154
+ raise ValueError(
4155
+ f"No mask files found in {masks_folder} with extensions {mask_extensions}"
4156
+ )
4157
+
4158
+ # Match images to masks
4159
+ if match_by_name:
4160
+ # Match by base filename
4161
+ image_dict = {
4162
+ os.path.splitext(os.path.basename(f))[0]: f for f in image_files
4163
+ }
4164
+ mask_dict = {
4165
+ os.path.splitext(os.path.basename(f))[0]: f for f in mask_files
4166
+ }
4167
+
4168
+ # Find matching pairs
4169
+ for img_base, img_path in image_dict.items():
4170
+ if img_base in mask_dict:
4171
+ image_mask_pairs.append((img_path, mask_dict[img_base], None))
4172
+ else:
4173
+ if not quiet:
4174
+ print(f"Warning: No mask found for image {img_base}")
4175
+
4176
+ if not image_mask_pairs:
4177
+ # Provide detailed error message with found files
4178
+ image_bases = list(image_dict.keys())
4179
+ mask_bases = list(mask_dict.keys())
4180
+ error_msg = (
4181
+ "No matching image-mask pairs found when matching by filename. "
4182
+ "Check that image and mask files have matching base names.\n"
4183
+ f"Found {len(image_bases)} image(s): "
4184
+ f"{', '.join(image_bases[:5]) if image_bases else 'None found'}"
4185
+ f"{'...' if len(image_bases) > 5 else ''}\n"
4186
+ f"Found {len(mask_bases)} mask(s): "
4187
+ f"{', '.join(mask_bases[:5]) if mask_bases else 'None found'}"
4188
+ f"{'...' if len(mask_bases) > 5 else ''}\n"
4189
+ "Tip: Set match_by_name=False to match by sorted order, or ensure filenames match."
4190
+ )
4191
+ raise ValueError(error_msg)
4192
+
4193
+ else:
4194
+ # Match by sorted order
4195
+ if len(image_files) != len(mask_files):
4196
+ raise ValueError(
4197
+ f"Number of image files ({len(image_files)}) does not match "
4198
+ f"number of mask files ({len(mask_files)}) when matching by sorted order. "
4199
+ f"Use match_by_name=True for filename-based matching."
4200
+ )
4201
+
4202
+ # Create pairs by sorted order
4203
+ for image_file, mask_file in zip(image_files, mask_files):
4204
+ image_mask_pairs.append((image_file, mask_file, None))
3227
4205
 
3228
4206
  # Initialize batch statistics
3229
4207
  batch_stats = {
@@ -3237,23 +4215,28 @@ def export_geotiff_tiles_batch(
3237
4215
  }
3238
4216
 
3239
4217
  if not quiet:
3240
- print(
3241
- f"Found {len(image_files)} image files and {len(mask_files)} mask files to process"
3242
- )
3243
- print(f"Processing batch from {images_folder} and {masks_folder}")
4218
+ if not has_masks:
4219
+ print(
4220
+ f"Found {len(image_files)} image files to process (images only, no masks)"
4221
+ )
4222
+ elif use_single_mask_file:
4223
+ print(f"Found {len(image_files)} image files to process")
4224
+ print(f"Using single mask file: {masks_file}")
4225
+ else:
4226
+ print(f"Found {len(image_mask_pairs)} matching image-mask pairs to process")
4227
+ print(f"Processing batch from {images_folder} and {masks_folder}")
3244
4228
  print(f"Output folder: {output_folder}")
3245
4229
  print("-" * 60)
3246
4230
 
3247
4231
  # Global tile counter for unique naming
3248
4232
  global_tile_counter = 0
3249
4233
 
3250
- # Process each image-mask pair by sorted order
3251
- for idx, (image_file, mask_file) in enumerate(
4234
+ # Process each image-mask pair
4235
+ for idx, (image_file, mask_file, mask_gdf) in enumerate(
3252
4236
  tqdm(
3253
- zip(image_files, mask_files),
4237
+ image_mask_pairs,
3254
4238
  desc="Processing image pairs",
3255
4239
  disable=quiet,
3256
- total=len(image_files),
3257
4240
  )
3258
4241
  ):
3259
4242
  batch_stats["total_image_pairs"] += 1
@@ -3265,9 +4248,17 @@ def export_geotiff_tiles_batch(
3265
4248
  if not quiet:
3266
4249
  print(f"\nProcessing: {base_name}")
3267
4250
  print(f" Image: {os.path.basename(image_file)}")
3268
- print(f" Mask: {os.path.basename(mask_file)}")
4251
+ if mask_file is not None:
4252
+ if use_single_mask_file:
4253
+ print(
4254
+ f" Mask: {os.path.basename(mask_file)} (spatially filtered)"
4255
+ )
4256
+ else:
4257
+ print(f" Mask: {os.path.basename(mask_file)}")
4258
+ else:
4259
+ print(f" Mask: None (images only)")
3269
4260
 
3270
- # Process the image-mask pair manually to get direct control over tile saving
4261
+ # Process the image-mask pair
3271
4262
  tiles_generated = _process_image_mask_pair(
3272
4263
  image_file=image_file,
3273
4264
  mask_file=mask_file,
@@ -3283,6 +4274,15 @@ def export_geotiff_tiles_batch(
3283
4274
  all_touched=all_touched,
3284
4275
  skip_empty_tiles=skip_empty_tiles,
3285
4276
  quiet=quiet,
4277
+ mask_gdf=mask_gdf, # Pass pre-loaded GeoDataFrame if using single mask
4278
+ use_single_mask_file=use_single_mask_file,
4279
+ metadata_format=metadata_format,
4280
+ ann_dir=(
4281
+ ann_dir
4282
+ if "ann_dir" in locals()
4283
+ and metadata_format in ["PASCAL_VOC", "COCO"]
4284
+ else None
4285
+ ),
3286
4286
  )
3287
4287
 
3288
4288
  # Update counters
@@ -3304,6 +4304,23 @@ def export_geotiff_tiles_batch(
3304
4304
  }
3305
4305
  )
3306
4306
 
4307
+ # Aggregate COCO annotations
4308
+ if metadata_format == "COCO" and "coco_data" in tiles_generated:
4309
+ coco_data = tiles_generated["coco_data"]
4310
+ # Add images and annotations
4311
+ coco_annotations["images"].extend(coco_data.get("images", []))
4312
+ coco_annotations["annotations"].extend(coco_data.get("annotations", []))
4313
+ # Merge categories (avoid duplicates)
4314
+ for cat in coco_data.get("categories", []):
4315
+ if not any(
4316
+ c["id"] == cat["id"] for c in coco_annotations["categories"]
4317
+ ):
4318
+ coco_annotations["categories"].append(cat)
4319
+
4320
+ # Aggregate YOLO classes
4321
+ if metadata_format == "YOLO" and "yolo_classes" in tiles_generated:
4322
+ yolo_classes.update(tiles_generated["yolo_classes"])
4323
+
3307
4324
  except Exception as e:
3308
4325
  if not quiet:
3309
4326
  print(f"ERROR processing {base_name}: {e}")
@@ -3312,6 +4329,33 @@ def export_geotiff_tiles_batch(
3312
4329
  )
3313
4330
  batch_stats["errors"] += 1
3314
4331
 
4332
+ # Save aggregated COCO annotations
4333
+ if metadata_format == "COCO" and coco_annotations:
4334
+ import json
4335
+
4336
+ coco_path = os.path.join(ann_dir, "instances.json")
4337
+ with open(coco_path, "w") as f:
4338
+ json.dump(coco_annotations, f, indent=2)
4339
+ if not quiet:
4340
+ print(f"\nSaved COCO annotations: {coco_path}")
4341
+ print(
4342
+ f" Images: {len(coco_annotations['images'])}, "
4343
+ f"Annotations: {len(coco_annotations['annotations'])}, "
4344
+ f"Categories: {len(coco_annotations['categories'])}"
4345
+ )
4346
+
4347
+ # Save aggregated YOLO classes
4348
+ if metadata_format == "YOLO" and yolo_classes:
4349
+ classes_path = os.path.join(output_folder, "labels", "classes.txt")
4350
+ os.makedirs(os.path.dirname(classes_path), exist_ok=True)
4351
+ sorted_classes = sorted(yolo_classes)
4352
+ with open(classes_path, "w") as f:
4353
+ for cls in sorted_classes:
4354
+ f.write(f"{cls}\n")
4355
+ if not quiet:
4356
+ print(f"\nSaved YOLO classes: {classes_path}")
4357
+ print(f" Total classes: {len(sorted_classes)}")
4358
+
3315
4359
  # Print batch summary
3316
4360
  if not quiet:
3317
4361
  print("\n" + "=" * 60)
@@ -3334,7 +4378,12 @@ def export_geotiff_tiles_batch(
3334
4378
 
3335
4379
  print(f"Output saved to: {output_folder}")
3336
4380
  print(f" Images: {output_images_dir}")
3337
- print(f" Masks: {output_masks_dir}")
4381
+ if output_masks_dir is not None:
4382
+ print(f" Masks: {output_masks_dir}")
4383
+ if metadata_format in ["PASCAL_VOC", "COCO"] and ann_dir is not None:
4384
+ print(f" Annotations: {ann_dir}")
4385
+ elif metadata_format == "YOLO":
4386
+ print(f" Labels: {os.path.join(output_folder, 'labels')}")
3338
4387
 
3339
4388
  # List failed files if any
3340
4389
  if batch_stats["failed_files"]:
@@ -3360,18 +4409,26 @@ def _process_image_mask_pair(
3360
4409
  all_touched=True,
3361
4410
  skip_empty_tiles=False,
3362
4411
  quiet=False,
4412
+ mask_gdf=None,
4413
+ use_single_mask_file=False,
4414
+ metadata_format="PASCAL_VOC",
4415
+ ann_dir=None,
3363
4416
  ):
3364
4417
  """
3365
4418
  Process a single image-mask pair and save tiles directly to output directories.
3366
4419
 
4420
+ Args:
4421
+ mask_gdf (GeoDataFrame, optional): Pre-loaded GeoDataFrame when using single mask file
4422
+ use_single_mask_file (bool): If True, spatially filter mask_gdf to image bounds
4423
+
3367
4424
  Returns:
3368
4425
  dict: Statistics for this image-mask pair
3369
4426
  """
3370
4427
  import warnings
3371
4428
 
3372
- # Determine if mask data is raster or vector
4429
+ # Determine if mask data is raster or vector (only if mask_file is provided)
3373
4430
  is_class_data_raster = False
3374
- if isinstance(mask_file, str):
4431
+ if mask_file is not None and isinstance(mask_file, str):
3375
4432
  file_ext = Path(mask_file).suffix.lower()
3376
4433
  # Common raster extensions
3377
4434
  if file_ext in [".tif", ".tiff", ".img", ".jp2", ".png", ".bmp", ".gif"]:
@@ -3388,6 +4445,13 @@ def _process_image_mask_pair(
3388
4445
  "errors": 0,
3389
4446
  }
3390
4447
 
4448
+ # Initialize COCO/YOLO tracking for this image
4449
+ if metadata_format == "COCO":
4450
+ stats["coco_data"] = {"images": [], "annotations": [], "categories": []}
4451
+ coco_ann_id = 0
4452
+ if metadata_format == "YOLO":
4453
+ stats["yolo_classes"] = set()
4454
+
3391
4455
  # Open the input raster
3392
4456
  with rasterio.open(image_file) as src:
3393
4457
  # Calculate number of tiles
@@ -3398,10 +4462,10 @@ def _process_image_mask_pair(
3398
4462
  if max_tiles is None:
3399
4463
  max_tiles = total_tiles
3400
4464
 
3401
- # Process classification data
4465
+ # Process classification data (only if mask_file is provided)
3402
4466
  class_to_id = {}
3403
4467
 
3404
- if is_class_data_raster:
4468
+ if mask_file is not None and is_class_data_raster:
3405
4469
  # Load raster class data
3406
4470
  with rasterio.open(mask_file) as class_src:
3407
4471
  # Check if raster CRS matches
@@ -3428,14 +4492,39 @@ def _process_image_mask_pair(
3428
4492
 
3429
4493
  # Create class mapping
3430
4494
  class_to_id = {int(cls): i + 1 for i, cls in enumerate(unique_classes)}
3431
- else:
4495
+ elif mask_file is not None:
3432
4496
  # Load vector class data
3433
4497
  try:
3434
- gdf = gpd.read_file(mask_file)
4498
+ if use_single_mask_file and mask_gdf is not None:
4499
+ # Using pre-loaded single mask file - spatially filter to image bounds
4500
+ # Get image bounds
4501
+ image_bounds = box(*src.bounds)
4502
+ image_gdf = gpd.GeoDataFrame(
4503
+ {"geometry": [image_bounds]}, crs=src.crs
4504
+ )
3435
4505
 
3436
- # Always reproject to match raster CRS
3437
- if gdf.crs != src.crs:
3438
- gdf = gdf.to_crs(src.crs)
4506
+ # Reproject mask if needed
4507
+ if mask_gdf.crs != src.crs:
4508
+ mask_gdf_reprojected = mask_gdf.to_crs(src.crs)
4509
+ else:
4510
+ mask_gdf_reprojected = mask_gdf
4511
+
4512
+ # Spatially filter features that intersect with image bounds
4513
+ gdf = mask_gdf_reprojected[
4514
+ mask_gdf_reprojected.intersects(image_bounds)
4515
+ ].copy()
4516
+
4517
+ if not quiet and len(gdf) > 0:
4518
+ print(
4519
+ f" Filtered to {len(gdf)} features intersecting image bounds"
4520
+ )
4521
+ else:
4522
+ # Load individual mask file
4523
+ gdf = gpd.read_file(mask_file)
4524
+
4525
+ # Always reproject to match raster CRS
4526
+ if gdf.crs != src.crs:
4527
+ gdf = gdf.to_crs(src.crs)
3439
4528
 
3440
4529
  # Apply buffer if specified
3441
4530
  if buffer_radius > 0:
@@ -3455,9 +4544,6 @@ def _process_image_mask_pair(
3455
4544
  tile_index = 0
3456
4545
  for y in range(num_tiles_y):
3457
4546
  for x in range(num_tiles_x):
3458
- if tile_index >= max_tiles:
3459
- break
3460
-
3461
4547
  # Calculate window coordinates
3462
4548
  window_x = x * stride
3463
4549
  window_y = y * stride
@@ -3482,12 +4568,12 @@ def _process_image_mask_pair(
3482
4568
 
3483
4569
  window_bounds = box(minx, miny, maxx, maxy)
3484
4570
 
3485
- # Create label mask
4571
+ # Create label mask (only if mask_file is provided)
3486
4572
  label_mask = np.zeros((tile_size, tile_size), dtype=np.uint8)
3487
4573
  has_features = False
3488
4574
 
3489
- # Process classification data to create labels
3490
- if is_class_data_raster:
4575
+ # Process classification data to create labels (only if mask_file is provided)
4576
+ if mask_file is not None and is_class_data_raster:
3491
4577
  # For raster class data
3492
4578
  with rasterio.open(mask_file) as class_src:
3493
4579
  # Get corresponding window in class raster
@@ -3520,7 +4606,7 @@ def _process_image_mask_pair(
3520
4606
  if not quiet:
3521
4607
  print(f"Error reading class raster window: {e}")
3522
4608
  stats["errors"] += 1
3523
- else:
4609
+ elif mask_file is not None:
3524
4610
  # For vector class data
3525
4611
  # Find features that intersect with window
3526
4612
  window_features = gdf[gdf.intersects(window_bounds)]
@@ -3558,11 +4644,14 @@ def _process_image_mask_pair(
3558
4644
  print(f"Error rasterizing feature {idx}: {e}")
3559
4645
  stats["errors"] += 1
3560
4646
 
3561
- # Skip tile if no features and skip_empty_tiles is True
3562
- if skip_empty_tiles and not has_features:
3563
- tile_index += 1
4647
+ # Skip tile if no features and skip_empty_tiles is True (only applies when masks are provided)
4648
+ if mask_file is not None and skip_empty_tiles and not has_features:
3564
4649
  continue
3565
4650
 
4651
+ # Check if we've reached max_tiles before saving
4652
+ if tile_index >= max_tiles:
4653
+ break
4654
+
3566
4655
  # Generate unique tile name
3567
4656
  tile_name = f"{base_name}_{global_tile_counter + tile_index:06d}"
3568
4657
 
@@ -3593,29 +4682,225 @@ def _process_image_mask_pair(
3593
4682
  print(f"ERROR saving image GeoTIFF: {e}")
3594
4683
  stats["errors"] += 1
3595
4684
 
3596
- # Create profile for label GeoTIFF
3597
- label_profile = {
3598
- "driver": "GTiff",
3599
- "height": tile_size,
3600
- "width": tile_size,
3601
- "count": 1,
3602
- "dtype": "uint8",
3603
- "crs": src.crs,
3604
- "transform": window_transform,
3605
- }
4685
+ # Export label as GeoTIFF (only if mask_file and output_masks_dir are provided)
4686
+ if mask_file is not None and output_masks_dir is not None:
4687
+ # Create profile for label GeoTIFF
4688
+ label_profile = {
4689
+ "driver": "GTiff",
4690
+ "height": tile_size,
4691
+ "width": tile_size,
4692
+ "count": 1,
4693
+ "dtype": "uint8",
4694
+ "crs": src.crs,
4695
+ "transform": window_transform,
4696
+ }
3606
4697
 
3607
- # Export label as GeoTIFF
3608
- label_path = os.path.join(output_masks_dir, f"{tile_name}.tif")
3609
- try:
3610
- with rasterio.open(label_path, "w", **label_profile) as dst:
3611
- dst.write(label_mask.astype(np.uint8), 1)
4698
+ label_path = os.path.join(output_masks_dir, f"{tile_name}.tif")
4699
+ try:
4700
+ with rasterio.open(label_path, "w", **label_profile) as dst:
4701
+ dst.write(label_mask.astype(np.uint8), 1)
3612
4702
 
3613
- if has_features:
3614
- stats["tiles_with_features"] += 1
3615
- except Exception as e:
3616
- if not quiet:
3617
- print(f"ERROR saving label GeoTIFF: {e}")
3618
- stats["errors"] += 1
4703
+ if has_features:
4704
+ stats["tiles_with_features"] += 1
4705
+ except Exception as e:
4706
+ if not quiet:
4707
+ print(f"ERROR saving label GeoTIFF: {e}")
4708
+ stats["errors"] += 1
4709
+
4710
+ # Generate annotation metadata based on format (only if mask_file is provided)
4711
+ if (
4712
+ mask_file is not None
4713
+ and metadata_format == "PASCAL_VOC"
4714
+ and ann_dir
4715
+ ):
4716
+ # Create PASCAL VOC XML annotation
4717
+ from lxml import etree as ET
4718
+
4719
+ annotation = ET.Element("annotation")
4720
+ ET.SubElement(annotation, "folder").text = os.path.basename(
4721
+ output_images_dir
4722
+ )
4723
+ ET.SubElement(annotation, "filename").text = f"{tile_name}.tif"
4724
+ ET.SubElement(annotation, "path").text = image_path
4725
+
4726
+ source = ET.SubElement(annotation, "source")
4727
+ ET.SubElement(source, "database").text = "GeoAI"
4728
+
4729
+ size = ET.SubElement(annotation, "size")
4730
+ ET.SubElement(size, "width").text = str(tile_size)
4731
+ ET.SubElement(size, "height").text = str(tile_size)
4732
+ ET.SubElement(size, "depth").text = str(image_data.shape[0])
4733
+
4734
+ ET.SubElement(annotation, "segmented").text = "1"
4735
+
4736
+ # Find connected components for instance segmentation
4737
+ from scipy import ndimage
4738
+
4739
+ for class_id in np.unique(label_mask):
4740
+ if class_id == 0:
4741
+ continue
4742
+
4743
+ class_mask = (label_mask == class_id).astype(np.uint8)
4744
+ labeled_array, num_features = ndimage.label(class_mask)
4745
+
4746
+ for instance_id in range(1, num_features + 1):
4747
+ instance_mask = labeled_array == instance_id
4748
+ coords = np.argwhere(instance_mask)
4749
+
4750
+ if len(coords) == 0:
4751
+ continue
4752
+
4753
+ ymin, xmin = coords.min(axis=0)
4754
+ ymax, xmax = coords.max(axis=0)
4755
+
4756
+ obj = ET.SubElement(annotation, "object")
4757
+ class_name = next(
4758
+ (k for k, v in class_to_id.items() if v == class_id),
4759
+ str(class_id),
4760
+ )
4761
+ ET.SubElement(obj, "name").text = str(class_name)
4762
+ ET.SubElement(obj, "pose").text = "Unspecified"
4763
+ ET.SubElement(obj, "truncated").text = "0"
4764
+ ET.SubElement(obj, "difficult").text = "0"
4765
+
4766
+ bndbox = ET.SubElement(obj, "bndbox")
4767
+ ET.SubElement(bndbox, "xmin").text = str(int(xmin))
4768
+ ET.SubElement(bndbox, "ymin").text = str(int(ymin))
4769
+ ET.SubElement(bndbox, "xmax").text = str(int(xmax))
4770
+ ET.SubElement(bndbox, "ymax").text = str(int(ymax))
4771
+
4772
+ # Save XML file
4773
+ xml_path = os.path.join(ann_dir, f"{tile_name}.xml")
4774
+ tree = ET.ElementTree(annotation)
4775
+ tree.write(xml_path, pretty_print=True, encoding="utf-8")
4776
+
4777
+ elif mask_file is not None and metadata_format == "COCO":
4778
+ # Add COCO image entry
4779
+ image_id = int(global_tile_counter + tile_index)
4780
+ stats["coco_data"]["images"].append(
4781
+ {
4782
+ "id": image_id,
4783
+ "file_name": f"{tile_name}.tif",
4784
+ "width": int(tile_size),
4785
+ "height": int(tile_size),
4786
+ }
4787
+ )
4788
+
4789
+ # Add COCO categories (only once per unique class)
4790
+ for class_val, class_id in class_to_id.items():
4791
+ if not any(
4792
+ c["id"] == class_id
4793
+ for c in stats["coco_data"]["categories"]
4794
+ ):
4795
+ stats["coco_data"]["categories"].append(
4796
+ {
4797
+ "id": int(class_id),
4798
+ "name": str(class_val),
4799
+ "supercategory": "object",
4800
+ }
4801
+ )
4802
+
4803
+ # Add COCO annotations (instance segmentation)
4804
+ from scipy import ndimage
4805
+ from skimage import measure
4806
+
4807
+ for class_id in np.unique(label_mask):
4808
+ if class_id == 0:
4809
+ continue
4810
+
4811
+ class_mask = (label_mask == class_id).astype(np.uint8)
4812
+ labeled_array, num_features = ndimage.label(class_mask)
4813
+
4814
+ for instance_id in range(1, num_features + 1):
4815
+ instance_mask = (labeled_array == instance_id).astype(
4816
+ np.uint8
4817
+ )
4818
+ coords = np.argwhere(instance_mask)
4819
+
4820
+ if len(coords) == 0:
4821
+ continue
4822
+
4823
+ ymin, xmin = coords.min(axis=0)
4824
+ ymax, xmax = coords.max(axis=0)
4825
+
4826
+ bbox = [
4827
+ int(xmin),
4828
+ int(ymin),
4829
+ int(xmax - xmin),
4830
+ int(ymax - ymin),
4831
+ ]
4832
+ area = int(np.sum(instance_mask))
4833
+
4834
+ # Find contours for segmentation
4835
+ contours = measure.find_contours(instance_mask, 0.5)
4836
+ segmentation = []
4837
+ for contour in contours:
4838
+ contour = np.flip(contour, axis=1)
4839
+ segmentation_points = contour.ravel().tolist()
4840
+ if len(segmentation_points) >= 6:
4841
+ segmentation.append(segmentation_points)
4842
+
4843
+ if segmentation:
4844
+ stats["coco_data"]["annotations"].append(
4845
+ {
4846
+ "id": int(coco_ann_id),
4847
+ "image_id": int(image_id),
4848
+ "category_id": int(class_id),
4849
+ "bbox": bbox,
4850
+ "area": area,
4851
+ "segmentation": segmentation,
4852
+ "iscrowd": 0,
4853
+ }
4854
+ )
4855
+ coco_ann_id += 1
4856
+
4857
+ elif mask_file is not None and metadata_format == "YOLO":
4858
+ # Create YOLO labels directory if needed
4859
+ labels_dir = os.path.join(
4860
+ os.path.dirname(output_images_dir), "labels"
4861
+ )
4862
+ os.makedirs(labels_dir, exist_ok=True)
4863
+
4864
+ # Generate YOLO annotation file
4865
+ yolo_path = os.path.join(labels_dir, f"{tile_name}.txt")
4866
+ from scipy import ndimage
4867
+
4868
+ with open(yolo_path, "w") as yolo_file:
4869
+ for class_id in np.unique(label_mask):
4870
+ if class_id == 0:
4871
+ continue
4872
+
4873
+ # Track class for classes.txt
4874
+ class_name = next(
4875
+ (k for k, v in class_to_id.items() if v == class_id),
4876
+ str(class_id),
4877
+ )
4878
+ stats["yolo_classes"].add(class_name)
4879
+
4880
+ class_mask = (label_mask == class_id).astype(np.uint8)
4881
+ labeled_array, num_features = ndimage.label(class_mask)
4882
+
4883
+ for instance_id in range(1, num_features + 1):
4884
+ instance_mask = labeled_array == instance_id
4885
+ coords = np.argwhere(instance_mask)
4886
+
4887
+ if len(coords) == 0:
4888
+ continue
4889
+
4890
+ ymin, xmin = coords.min(axis=0)
4891
+ ymax, xmax = coords.max(axis=0)
4892
+
4893
+ # Convert to YOLO format (normalized center coordinates)
4894
+ x_center = ((xmin + xmax) / 2) / tile_size
4895
+ y_center = ((ymin + ymax) / 2) / tile_size
4896
+ width = (xmax - xmin) / tile_size
4897
+ height = (ymax - ymin) / tile_size
4898
+
4899
+ # YOLO uses 0-based class indices
4900
+ yolo_class_id = class_id - 1
4901
+ yolo_file.write(
4902
+ f"{yolo_class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}\n"
4903
+ )
3619
4904
 
3620
4905
  tile_index += 1
3621
4906
  if tile_index >= max_tiles:
@@ -3627,6 +4912,179 @@ def _process_image_mask_pair(
3627
4912
  return stats
3628
4913
 
3629
4914
 
4915
def display_training_tiles(
    output_dir,
    num_tiles=6,
    figsize=(18, 6),
    cmap="gray",
    save_path=None,
):
    """
    Display image and mask tile pairs from training data output.

    Args:
        output_dir (str): Path to output directory containing 'images' and 'masks' subdirectories
        num_tiles (int): Number of tile pairs to display (default: 6)
        figsize (tuple): Figure size as (width, height) in inches (default: (18, 6))
        cmap (str): Colormap for mask display (default: 'gray')
        save_path (str, optional): If provided, save figure to this path instead of displaying

    Returns:
        tuple: (fig, axes) matplotlib figure and axes objects

    Example:
        >>> fig, axes = display_training_tiles('output/tiles', num_tiles=6)
        >>> # Or save to file
        >>> display_training_tiles('output/tiles', num_tiles=4, save_path='tiles_preview.png')
    """
    import matplotlib.pyplot as plt

    # Get list of image tiles
    images_dir = os.path.join(output_dir, "images")
    if not os.path.exists(images_dir):
        raise ValueError(f"Images directory not found: {images_dir}")

    # Only consider raster files so that stray directory entries
    # (e.g. .DS_Store, subdirectories) do not crash rasterio.open below.
    raster_exts = (".tif", ".tiff", ".img", ".jp2", ".png", ".jpg", ".jpeg")
    image_tiles = sorted(
        f for f in os.listdir(images_dir) if f.lower().endswith(raster_exts)
    )[:num_tiles]

    if not image_tiles:
        raise ValueError(f"No image tiles found in {images_dir}")

    # Limit to available tiles
    num_tiles = min(num_tiles, len(image_tiles))

    # Create figure with subplots: top row = images, bottom row = masks
    fig, axes = plt.subplots(2, num_tiles, figsize=figsize)

    # Handle case where num_tiles is 1 (subplots returns a 1-D axes array)
    if num_tiles == 1:
        axes = axes.reshape(2, 1)

    masks_dir = os.path.join(output_dir, "masks")
    for idx, tile_name in enumerate(image_tiles):
        # Load and display image tile
        image_path = os.path.join(images_dir, tile_name)
        with rasterio.open(image_path) as src:
            show(src, ax=axes[0, idx], title=f"Image {idx+1}")

        # Load and display the matching mask tile; the mask may be absent
        # (e.g. images-only exports), in which case show a placeholder.
        mask_path = os.path.join(masks_dir, tile_name)
        if os.path.exists(mask_path):
            with rasterio.open(mask_path) as src:
                show(src, ax=axes[1, idx], title=f"Mask {idx+1}", cmap=cmap)
        else:
            axes[1, idx].text(
                0.5,
                0.5,
                "Mask not found",
                ha="center",
                va="center",
                transform=axes[1, idx].transAxes,
            )
            axes[1, idx].set_title(f"Mask {idx+1}")

    plt.tight_layout()

    # Save or show
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches="tight")
        plt.close(fig)
        print(f"Figure saved to: {save_path}")
    else:
        plt.show()

    return fig, axes
+
4996
+
4997
def display_image_with_vector(
    image_path,
    vector_path,
    figsize=(16, 8),
    vector_color="red",
    vector_linewidth=1,
    vector_facecolor="none",
    save_path=None,
):
    """
    Show a raster image side by side with the same raster overlaid by vector features.

    Args:
        image_path (str): Path to raster image file
        vector_path (str): Path to vector file (GeoJSON, Shapefile, etc.)
        figsize (tuple): Figure size as (width, height) in inches (default: (16, 8))
        vector_color (str): Edge color for vector features (default: 'red')
        vector_linewidth (float): Line width for vector features (default: 1)
        vector_facecolor (str): Fill color for vector features (default: 'none')
        save_path (str, optional): If provided, save figure to this path instead of displaying

    Returns:
        tuple: (fig, axes, info_dict) where info_dict contains image and vector metadata

    Example:
        >>> fig, axes, info = display_image_with_vector(
        ...     'image.tif',
        ...     'buildings.geojson',
        ...     vector_color='blue'
        ... )
        >>> print(f"Number of features: {info['num_features']}")
    """
    import matplotlib.pyplot as plt

    # Guard clauses: fail fast if either input path is missing
    if not os.path.exists(image_path):
        raise ValueError(f"Image file not found: {image_path}")
    if not os.path.exists(vector_path):
        raise ValueError(f"Vector file not found: {vector_path}")

    fig, (left_ax, right_ax) = plt.subplots(1, 2, figsize=figsize)

    with rasterio.open(image_path) as raster:
        # Left panel: the raster by itself
        show(raster, ax=left_ax, title="Image")

        # Load the vector features and align their CRS with the raster
        features = gpd.read_file(vector_path)
        if features.crs != raster.crs:
            features = features.to_crs(raster.crs)

        # Right panel: the raster with the vector overlay on top
        show(
            raster,
            ax=right_ax,
            title=f"Image with {len(features)} Vector Features",
        )
        features.plot(
            ax=right_ax,
            facecolor=vector_facecolor,
            edgecolor=vector_color,
            linewidth=vector_linewidth,
        )

        # Metadata handed back to the caller for further analysis
        info = {
            "image_shape": raster.shape,
            "image_crs": raster.crs,
            "image_bounds": raster.bounds,
            "num_features": len(features),
            "vector_crs": features.crs,
            "vector_bounds": features.total_bounds,
        }

    plt.tight_layout()

    # Save to disk or display interactively
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches="tight")
        plt.close(fig)
        print(f"Figure saved to: {save_path}")
    else:
        plt.show()

    return fig, (left_ax, right_ax), info
5086
+
5087
+
3630
5088
  def create_overview_image(
3631
5089
  src, tile_coordinates, output_path, tile_size, stride, geojson_path=None
3632
5090
  ) -> str:
@@ -7521,17 +8979,39 @@ def write_colormap(
7521
8979
 
7522
8980
  def plot_performance_metrics(
7523
8981
  history_path: str,
7524
- figsize: Tuple[int, int] = (15, 5),
8982
+ figsize: Optional[Tuple[int, int]] = None,
7525
8983
  verbose: bool = True,
7526
8984
  save_path: Optional[str] = None,
8985
+ csv_path: Optional[str] = None,
7527
8986
  kwargs: Optional[Dict] = None,
7528
- ) -> None:
7529
- """Plot performance metrics from a history object.
8987
+ ) -> pd.DataFrame:
8988
+ """Plot performance metrics from a training history object and return as DataFrame.
8989
+
8990
+ This function loads training history, plots available metrics (loss, IoU, F1,
8991
+ precision, recall), optionally exports to CSV, and returns all metrics as a
8992
+ pandas DataFrame for further analysis.
7530
8993
 
7531
8994
  Args:
7532
- history_path: The history object to plot.
7533
- figsize: The figure size.
7534
- verbose: Whether to print the best and final metrics.
8995
+ history_path (str): Path to the saved training history (.pth file).
8996
+ figsize (Optional[Tuple[int, int]]): Figure size in inches. If None,
8997
+ automatically determined based on number of metrics.
8998
+ verbose (bool): Whether to print best and final metric values. Defaults to True.
8999
+ save_path (Optional[str]): Path to save the plot image. If None, plot is not saved.
9000
+ csv_path (Optional[str]): Path to export metrics as CSV. If None, CSV is not exported.
9001
+ kwargs (Optional[Dict]): Additional keyword arguments for plt.savefig().
9002
+
9003
+ Returns:
9004
+ pd.DataFrame: DataFrame containing all metrics with columns for epoch and each metric.
9005
+ Columns include: 'epoch', 'train_loss', 'val_loss', 'val_iou', 'val_f1',
9006
+ 'val_precision', 'val_recall' (depending on availability in history).
9007
+
9008
+ Example:
9009
+ >>> df = plot_performance_metrics(
9010
+ ... 'training_history.pth',
9011
+ ... save_path='metrics_plot.png',
9012
+ ... csv_path='metrics.csv'
9013
+ ... )
9014
+ >>> print(df.head())
7535
9015
  """
7536
9016
  if kwargs is None:
7537
9017
  kwargs = {}
@@ -7541,65 +9021,135 @@ def plot_performance_metrics(
7541
9021
  train_loss_key = "train_losses" if "train_losses" in history else "train_loss"
7542
9022
  val_loss_key = "val_losses" if "val_losses" in history else "val_loss"
7543
9023
  val_iou_key = "val_ious" if "val_ious" in history else "val_iou"
7544
- val_dice_key = "val_dices" if "val_dices" in history else "val_dice"
9024
+ # Support both new (f1) and old (dice) key formats for backward compatibility
9025
+ val_f1_key = (
9026
+ "val_f1s"
9027
+ if "val_f1s" in history
9028
+ else ("val_dices" if "val_dices" in history else "val_dice")
9029
+ )
9030
+ # Add support for precision and recall
9031
+ val_precision_key = (
9032
+ "val_precisions" if "val_precisions" in history else "val_precision"
9033
+ )
9034
+ val_recall_key = "val_recalls" if "val_recalls" in history else "val_recall"
9035
+
9036
+ # Collect available metrics for plotting
9037
+ available_metrics = []
9038
+ metric_info = {
9039
+ "Loss": (train_loss_key, val_loss_key, ["Train Loss", "Val Loss"]),
9040
+ "IoU": (val_iou_key, None, ["Val IoU"]),
9041
+ "F1": (val_f1_key, None, ["Val F1"]),
9042
+ "Precision": (val_precision_key, None, ["Val Precision"]),
9043
+ "Recall": (val_recall_key, None, ["Val Recall"]),
9044
+ }
7545
9045
 
7546
- # Determine number of subplots based on available metrics
7547
- has_dice = val_dice_key in history
7548
- n_plots = 3 if has_dice else 2
7549
- figsize = (15, 5) if has_dice else (10, 5)
9046
+ for metric_name, (key1, key2, labels) in metric_info.items():
9047
+ if key1 in history or (key2 and key2 in history):
9048
+ available_metrics.append((metric_name, key1, key2, labels))
7550
9049
 
7551
- plt.figure(figsize=figsize)
9050
+ # Determine number of subplots and figure size
9051
+ n_plots = len(available_metrics)
9052
+ if figsize is None:
9053
+ figsize = (5 * n_plots, 5)
7552
9054
 
7553
- # Plot loss
7554
- plt.subplot(1, n_plots, 1)
9055
+ # Create DataFrame for all metrics
9056
+ n_epochs = 0
9057
+ df_data = {}
9058
+
9059
+ # Add epochs
9060
+ if "epochs" in history:
9061
+ df_data["epoch"] = history["epochs"]
9062
+ n_epochs = len(history["epochs"])
9063
+ elif train_loss_key in history:
9064
+ n_epochs = len(history[train_loss_key])
9065
+ df_data["epoch"] = list(range(1, n_epochs + 1))
9066
+
9067
+ # Add all available metrics to DataFrame
7555
9068
  if train_loss_key in history:
7556
- plt.plot(history[train_loss_key], label="Train Loss")
9069
+ df_data["train_loss"] = history[train_loss_key]
7557
9070
  if val_loss_key in history:
7558
- plt.plot(history[val_loss_key], label="Val Loss")
7559
- plt.title("Loss")
7560
- plt.xlabel("Epoch")
7561
- plt.ylabel("Loss")
7562
- plt.legend()
7563
- plt.grid(True)
7564
-
7565
- # Plot IoU
7566
- plt.subplot(1, n_plots, 2)
9071
+ df_data["val_loss"] = history[val_loss_key]
7567
9072
  if val_iou_key in history:
7568
- plt.plot(history[val_iou_key], label="Val IoU")
7569
- plt.title("IoU Score")
7570
- plt.xlabel("Epoch")
7571
- plt.ylabel("IoU")
7572
- plt.legend()
7573
- plt.grid(True)
7574
-
7575
- # Plot Dice if available
7576
- if has_dice:
7577
- plt.subplot(1, n_plots, 3)
7578
- plt.plot(history[val_dice_key], label="Val Dice")
7579
- plt.title("Dice Score")
7580
- plt.xlabel("Epoch")
7581
- plt.ylabel("Dice")
7582
- plt.legend()
7583
- plt.grid(True)
9073
+ df_data["val_iou"] = history[val_iou_key]
9074
+ if val_f1_key in history:
9075
+ df_data["val_f1"] = history[val_f1_key]
9076
+ if val_precision_key in history:
9077
+ df_data["val_precision"] = history[val_precision_key]
9078
+ if val_recall_key in history:
9079
+ df_data["val_recall"] = history[val_recall_key]
9080
+
9081
+ # Create DataFrame
9082
+ df = pd.DataFrame(df_data)
9083
+
9084
+ # Export to CSV if requested
9085
+ if csv_path:
9086
+ df.to_csv(csv_path, index=False)
9087
+ if verbose:
9088
+ print(f"Metrics exported to: {csv_path}")
9089
+
9090
+ # Create plots
9091
+ if n_plots > 0:
9092
+ fig, axes = plt.subplots(1, n_plots, figsize=figsize)
9093
+ if n_plots == 1:
9094
+ axes = [axes]
9095
+
9096
+ for idx, (metric_name, key1, key2, labels) in enumerate(available_metrics):
9097
+ ax = axes[idx]
9098
+
9099
+ if metric_name == "Loss":
9100
+ # Special handling for loss (has both train and val)
9101
+ if key1 in history:
9102
+ ax.plot(history[key1], label=labels[0])
9103
+ if key2 and key2 in history:
9104
+ ax.plot(history[key2], label=labels[1])
9105
+ else:
9106
+ # Single metric plots
9107
+ if key1 in history:
9108
+ ax.plot(history[key1], label=labels[0])
7584
9109
 
7585
- plt.tight_layout()
9110
+ ax.set_title(metric_name)
9111
+ ax.set_xlabel("Epoch")
9112
+ ax.set_ylabel(metric_name)
9113
+ ax.legend()
9114
+ ax.grid(True)
7586
9115
 
7587
- if save_path:
7588
- if "dpi" not in kwargs:
7589
- kwargs["dpi"] = 150
7590
- if "bbox_inches" not in kwargs:
7591
- kwargs["bbox_inches"] = "tight"
7592
- plt.savefig(save_path, **kwargs)
9116
+ plt.tight_layout()
7593
9117
 
7594
- plt.show()
9118
+ if save_path:
9119
+ if "dpi" not in kwargs:
9120
+ kwargs["dpi"] = 150
9121
+ if "bbox_inches" not in kwargs:
9122
+ kwargs["bbox_inches"] = "tight"
9123
+ plt.savefig(save_path, **kwargs)
9124
+
9125
+ plt.show()
7595
9126
 
9127
+ # Print summary statistics
7596
9128
  if verbose:
9129
+ print("\n=== Performance Metrics Summary ===")
7597
9130
  if val_iou_key in history:
7598
- print(f"Best IoU: {max(history[val_iou_key]):.4f}")
7599
- print(f"Final IoU: {history[val_iou_key][-1]:.4f}")
7600
- if val_dice_key in history:
7601
- print(f"Best Dice: {max(history[val_dice_key]):.4f}")
7602
- print(f"Final Dice: {history[val_dice_key][-1]:.4f}")
9131
+ print(
9132
+ f"IoU - Best: {max(history[val_iou_key]):.4f} | Final: {history[val_iou_key][-1]:.4f}"
9133
+ )
9134
+ if val_f1_key in history:
9135
+ print(
9136
+ f"F1 - Best: {max(history[val_f1_key]):.4f} | Final: {history[val_f1_key][-1]:.4f}"
9137
+ )
9138
+ if val_precision_key in history:
9139
+ print(
9140
+ f"Precision - Best: {max(history[val_precision_key]):.4f} | Final: {history[val_precision_key][-1]:.4f}"
9141
+ )
9142
+ if val_recall_key in history:
9143
+ print(
9144
+ f"Recall - Best: {max(history[val_recall_key]):.4f} | Final: {history[val_recall_key][-1]:.4f}"
9145
+ )
9146
+ if val_loss_key in history:
9147
+ print(
9148
+ f"Val Loss - Best: {min(history[val_loss_key]):.4f} | Final: {history[val_loss_key][-1]:.4f}"
9149
+ )
9150
+ print("===================================\n")
9151
+
9152
+ return df
7603
9153
 
7604
9154
 
7605
9155
  def get_device() -> torch.device: